瀏覽代碼

server添加上下文的支持

lxg 3 年之前
父節點
當前提交
d148437165
共有 8 個文件被更改,包括 69 次插入33 次删除
  1. 10 0
      gateway/cli/context.go
  2. 9 3
      gateway/cli/server.go
  3. 3 3
      gateway/http/server.go
  4. 0 1
      gateway/rpc/client.go
  5. 5 3
      gateway/rpc/server.go
  6. 19 1
      helper/utils/utils.go
  7. 20 19
      options.go
  8. 3 3
      service.go

+ 10 - 0
gateway/cli/context.go

@@ -1,6 +1,7 @@
 package cli
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -26,6 +27,15 @@ type Context struct {
 	locker   sync.RWMutex
 	Values   map[string]interface{}
 	response *Response
+	ctx      context.Context
+}
+
+func (ctx *Context) Context() context.Context {
+	return ctx.ctx
+}
+
+func (ctx *Context) WithContext(c context.Context) {
+	ctx.ctx = c
 }
 
 //HasArgument 是否有指定的参数

+ 9 - 3
gateway/cli/server.go

@@ -1,6 +1,7 @@
 package cli
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	byte2 "git.nspix.com/golang/micro/helper/pool/byte"
@@ -25,6 +26,7 @@ type HandleFunc func(ctx *Context) (err error)
 
 type Server struct {
 	seq        int32
+	ctx        context.Context
 	locker     sync.RWMutex
 	executor   *Executor
 	contextMap map[int32]*Context
@@ -75,9 +77,11 @@ func (svr *Server) process(id int32, conn net.Conn) (err error) {
 				err = svr.writePack(conn, &Frame{Type: PacketTypeEcho, Data: buf})
 			}
 		case PacketTypeData:
+			var (
+				ok  bool
+				ctx *Context
+			)
 			tokens := strings.Fields(strings.TrimSpace(string(reqPacket.Data)))
-			var ok bool
-			var ctx *Context
 			svr.locker.Lock()
 			if ctx, ok = svr.contextMap[id]; !ok {
 				ctx = &Context{
@@ -87,6 +91,7 @@ func (svr *Server) process(id int32, conn net.Conn) (err error) {
 				svr.contextMap[id] = ctx
 			}
 			svr.locker.Unlock()
+			ctx.WithContext(svr.ctx)
 			ctx.reset(strings.TrimSpace(string(reqPacket.Data)))
 			if res, err = svr.executor.Do(ctx, tokens...); err == nil {
 				if res.Code == 0 {
@@ -198,9 +203,10 @@ func (svr *Server) Append(child *Executor) {
 	svr.executor.Append(child)
 }
 
-func New() *Server {
+func New(ctx context.Context) *Server {
 	return &Server{
 		seq:        0,
+		ctx:        ctx,
 		executor:   NewExecutor("ROOT"),
 		contextMap: make(map[int32]*Context),
 	}

+ 3 - 3
gateway/http/server.go

@@ -26,6 +26,7 @@ type Node struct {
 }
 
 type Server struct {
+	ctx            context.Context
 	svr            *http.Server
 	middleware     []Middleware
 	keepAlive      bool
@@ -79,8 +80,6 @@ func (r *Server) Websocket(path string, h websocket.Handler) {
 	r.Handler("GET", path, r.createWebsocket(h))
 }
 
-
-
 func (r *Server) Handle(method string, path string, h HandleFunc, middleware ...Middleware) {
 	r.router.Handle(method, path, func(writer http.ResponseWriter, request *http.Request, params router.Params) {
 		ctx := getContext()
@@ -178,8 +177,9 @@ func (r *Server) Shutdown(ctx context.Context) (err error) {
 	return
 }
 
-func New() *Server {
+func New(ctx context.Context) *Server {
 	return &Server{
+		ctx:    ctx,
 		router: router.New(),
 	}
 }

+ 0 - 1
gateway/rpc/client.go

@@ -86,7 +86,6 @@ func (c *Client) eventLoop() {
 }
 
 func (c *Client) rdyLoop() {
-	log.Infof("RPC: connection %s connected", c.conn.LocalAddr())
 	defer atomic.StoreInt32(&c.isConnected, 0)
 	for {
 		if frame, err := readFrame(c.conn); err == nil {

+ 5 - 3
gateway/rpc/server.go

@@ -1,6 +1,7 @@
 package rpc
 
 import (
+	"context"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -11,6 +12,7 @@ import (
 type HandleFunc func(ctx *Context) error
 
 type Server struct {
+	ctx        context.Context
 	listener   net.Listener
 	ch         chan *Request
 	ctxPool    sync.Pool
@@ -87,10 +89,8 @@ func (svr *Server) process(conn net.Conn) {
 		err   error
 		frame *Frame
 	)
-	log.Infof("RPC: connection %s connecting", conn.RemoteAddr())
 	svr.sessions.Store(conn.LocalAddr().String(), conn)
 	defer func() {
-		log.Infof("RPC: connection %s closed", conn.RemoteAddr())
 		svr.sessions.Delete(conn.LocalAddr().String())
 		_ = conn.Close()
 	}()
@@ -102,6 +102,7 @@ func (svr *Server) process(conn net.Conn) {
 		case FuncPing:
 			if err = writeFrame(conn, &Frame{Func: FuncPing}); err != nil {
 				log.Warnf("RPC: write ping frame error: %s", err.Error())
+				return
 			}
 		case FuncRequest:
 			//read request
@@ -153,8 +154,9 @@ func (svr *Server) Close() (err error) {
 	return
 }
 
-func NewServer() *Server {
+func New(ctx context.Context) *Server {
 	return &Server{
+		ctx:      ctx,
 		ch:       make(chan *Request, 10),
 		exitChan: make(chan struct{}),
 	}

+ 19 - 1
helper/utils/utils.go

@@ -1,6 +1,7 @@
 package utils
 
 import (
+	"crypto/md5"
 	"encoding/binary"
 	"errors"
 	"math/rand"
@@ -12,6 +13,24 @@ import (
 	"time"
 )
 
+//MD5 get byte md5
+func MD5(b []byte) []byte {
+	hash := md5.New()
+	hash.Write(b)
+	return hash.Sum(nil)
+}
+
+//MD5File get file md5
+func MD5File(filename string) (b []byte, err error) {
+	var (
+		buf []byte
+	)
+	if buf, err = os.ReadFile(filename); err != nil {
+		return
+	}
+	return MD5(buf), nil
+}
+
 //LowerFirst Make a string's first character lowercase
 func LowerFirst(s string) string {
 	isFirst := true
@@ -55,7 +74,6 @@ func InArray(needle interface{}, haystack interface{}) bool {
 	default:
 		panic("haystack: haystack type must be slice, array or map")
 	}
-
 	return false
 }
 

+ 20 - 19
options.go

@@ -9,25 +9,25 @@ import (
 
 type (
 	Options struct {
-		Zone                    string            //注册域
-		Name                    string            //名称
-		Version                 string            //版本号
-		EnableHttp              bool              //启用HTTP功能
-		EnableRPC               bool              //启用RPC功能
-		EnableInternalListener  bool              //启用内置网络监听服务
-		DisableRegister         bool              //禁用注册
-		registry                registry.Registry //注册仓库
-		Server                  Server            //加载的服务
-		Port                    int               //绑定端口
-		Address                 string            //绑定地址
-		EnableHttpPProf         bool              //启用HTTP调试工具
-		EnableStats             bool              //启用数据统计
-		EnableLogPrefix         bool              //启用日志前缀
-		EnableCli               bool              //启用cli模式
-		EnableReport            bool              //启用数据上报
-		RegistryArguments       map[string]string //注册参数
-		Context                 context.Context
-		shortName               string
+		Zone                   string            //注册域
+		Name                   string            //名称
+		Version                string            //版本号
+		EnableHttp             bool              //启用HTTP功能
+		EnableRPC              bool              //启用RPC功能
+		EnableInternalListener bool              //启用内置网络监听服务
+		DisableRegister        bool              //禁用注册
+		registry               registry.Registry //注册仓库
+		Server                 Server            //加载的服务
+		Port                   int               //绑定端口
+		Address                string            //绑定地址
+		EnableHttpPProf        bool              //启用HTTP调试工具
+		EnableStats            bool              //启用数据统计
+		EnableLogPrefix        bool              //启用日志前缀
+		EnableCli              bool              //启用cli模式
+		EnableReport           bool              //启用数据上报
+		RegistryArguments      map[string]string //注册参数
+		Context                context.Context
+		shortName              string
 	}
 
 	Option func(o *Options)
@@ -121,6 +121,7 @@ func NewOptions() *Options {
 		EnableInternalListener: true,
 		EnableLogPrefix:        true,
 		EnableReport:           true,
+		EnableHttpPProf:        true,
 		Context:                context.Background(),
 		registry:               registry.DefaultRegistry,
 	}

+ 3 - 3
service.go

@@ -296,6 +296,7 @@ func (svr *Service) instance() *registry.ServiceNode {
 }
 
 func (svr *Service) startHTTPServe() (err error) {
+	svr.httpSvr = http.New(svr.ctx)
 	l := gateway.NewListener(svr.listener.Addr())
 	if err = svr.gateway.Attaches([][]byte{[]byte("GET"), []byte("POST"), []byte("PUT"), []byte("DELETE"), []byte("OPTIONS")}, l); err == nil {
 		svr.async(func() {
@@ -336,6 +337,7 @@ func (svr *Service) startHTTPServe() (err error) {
 }
 
 func (svr *Service) startRPCServe() (err error) {
+	svr.rpcSvr = rpc.New(svr.ctx)
 	l := gateway.NewListener(svr.listener.Addr())
 	if err = svr.gateway.Attach([]byte("RPC"), l); err == nil {
 		svr.async(func() {
@@ -352,6 +354,7 @@ func (svr *Service) startRPCServe() (err error) {
 }
 
 func (svr *Service) startCliServe() (err error) {
+	svr.cliSvr = cli.New(svr.ctx)
 	l := gateway.NewListener(svr.listener.Addr())
 	if err = svr.gateway.Attach([]byte("CLI"), l); err == nil {
 		svr.async(func() {
@@ -524,9 +527,6 @@ func New(opts ...Option) *Service {
 	svr := &Service{
 		opts:        o,
 		upTime:      time.Now(),
-		httpSvr:     http.New(),
-		cliSvr:      cli.New(),
-		rpcSvr:      rpc.NewServer(),
 		registry:    o.registry,
 		tickTimer:   time.NewTimer(math.MaxInt64),
 		tickTree:    btree.New(64),