lxg 4 rokov pred
rodič
commit
0abaf47dc5
2 zmenil súbory, kde vykonal 27 pridanie a 5 odobranie
  1. 13 3
      gateway/rpc/client.go
  2. 14 2
      gateway/rpc/server.go

+ 13 - 3
gateway/rpc/client.go

@@ -90,7 +90,10 @@ func (c *Client) rdyLoop() {
 	defer atomic.StoreInt32(&c.isConnected, 0)
 	for {
 		if frame, err := readFrame(c.conn); err == nil {
-			if frame.Func == FuncResponse {
+			switch frame.Func {
+			case FuncPing:
+				c.pintAt = time.Now()
+			case FuncResponse:
 				c.transactionLocker.RLock()
 				ch, ok := c.transaction[frame.Sequence]
 				c.transactionLocker.RUnlock()
@@ -100,9 +103,9 @@ func (c *Client) rdyLoop() {
 					} else {
 						ch.Cancel()
 					}
+				} else {
+					log.Warnf("RPC: connection %s response %d dropped", c.conn.LocalAddr(), frame.Sequence)
 				}
-			} else if frame.Func == FuncPing {
-				c.pintAt = time.Now()
 			}
 		} else {
 			log.Infof("RPC: connection %s closed", c.conn.LocalAddr())
@@ -182,6 +185,8 @@ func (c *Client) Do(ctx context.Context, req *Request) (res *Response, err error
 			//canceled
 			err = io.ErrClosedPipe
 		}
+	case <-c.exitChan:
+		err = io.ErrClosedPipe
 	case <-ctx.Done():
 		trans.Cancel()
 		err = errors.New("Client.Timeout exceeded while awaiting response")
@@ -191,6 +196,11 @@ func (c *Client) Do(ctx context.Context, req *Request) (res *Response, err error
 
 func (c *Client) Close() (err error) {
 	if atomic.CompareAndSwapInt32(&c.exitFlag, 0, 1) {
+		c.transactionLocker.Lock()
+		for _, t := range c.transaction {
+			t.Cancel()
+		}
+		c.transactionLocker.Unlock()
 		if c.conn != nil {
 			err = c.conn.Close()
 		}

+ 14 - 2
gateway/rpc/server.go

@@ -1,10 +1,11 @@
 package rpc
 
 import (
-	"git.nspix.com/golang/micro/log"
 	"net"
 	"sync"
 	"sync/atomic"
+
+	"git.nspix.com/golang/micro/log"
 )
 
 type HandleFunc func(ctx *Context) error
@@ -14,6 +15,7 @@ type Server struct {
 	ch         chan *Request
 	ctxPool    sync.Pool
 	serviceMap sync.Map // map[string]HandleFunc
+	sessions   sync.Map
 	exitFlag   int32
 	exitChan   chan struct{}
 }
@@ -85,7 +87,11 @@ func (svr *Server) process(conn net.Conn) {
 		err   error
 		frame *Frame
 	)
+	log.Infof("RPC: connection %s connecting", conn.RemoteAddr())
+	svr.sessions.Store(conn.LocalAddr().String(), conn)
 	defer func() {
+		log.Infof("RPC: connection %s closed", conn.RemoteAddr())
+		svr.sessions.Delete(conn.LocalAddr().String())
 		_ = conn.Close()
 	}()
 	for {
@@ -98,7 +104,7 @@ func (svr *Server) process(conn net.Conn) {
 				log.Warnf("RPC: write ping frame error: %s", err.Error())
 			}
 		case FuncRequest:
-			//读取一个请求
+			//read request
 			if req, err2 := ReadRequest(frame.Data); err2 == nil {
 				req.reset(frame.Sequence, conn)
 				select {
@@ -135,6 +141,12 @@ func (svr *Server) Serve(l net.Listener) (err error) {
 
 func (svr *Server) Close() (err error) {
 	if atomic.CompareAndSwapInt32(&svr.exitFlag, 0, 1) {
+		//clear sessions
+		svr.sessions.Range(func(key, value interface{}) bool {
+			c := value.(net.Conn)
+			c.Close()
+			return true
+		})
 		err = svr.listener.Close()
 		close(svr.exitChan)
 	}