Forráskód Böngészése

Session close waits for receive loop to terminate

Armon Dadgar 10 éve
szülő
commit
24e7d75fd7
1 módosított fájl, 18 hozzáadás és 8 törlés
  1. 18 8
      session.go

+ 18 - 8
session.go

@@ -57,6 +57,10 @@ type Session struct {
 	// or to send a header out directly.
 	sendCh chan sendReady
 
+	// recvDoneCh is closed when recv() exits to avoid a race
+	// between stream registration and stream shutdown
+	recvDoneCh chan struct{}
+
 	// shutdown is used to safely close a session
 	shutdown     bool
 	shutdownErr  error
@@ -83,6 +87,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 		streams:    make(map[uint32]*Stream),
 		acceptCh:   make(chan *Stream, config.AcceptBacklog),
 		sendCh:     make(chan sendReady, 64),
+		recvDoneCh: make(chan struct{}),
 		shutdownCh: make(chan struct{}),
 	}
 	if client {
@@ -182,6 +187,7 @@ func (s *Session) Close() error {
 	}
 	close(s.shutdownCh)
 	s.conn.Close()
+	<-s.recvDoneCh
 
 	s.streamLock.Lock()
 	defer s.streamLock.Unlock()
@@ -333,6 +339,14 @@ func (s *Session) send() {
 
 // recv is a long running goroutine that accepts new data
 func (s *Session) recv() {
+	if err := s.recvLoop(); err != nil {
+		s.exitErr(err)
+	}
+}
+
+// recvLoop continues to receive data until a fatal error is encountered
+func (s *Session) recvLoop() error {
+	defer close(s.recvDoneCh)
 	hdr := header(make([]byte, headerSize))
 	var handler func(header) error
 	for {
@@ -341,15 +355,13 @@ func (s *Session) recv() {
 			if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
 				s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
 			}
-			s.exitErr(err)
-			return
+			return err
 		}
 
 		// Verify the version
 		if hdr.Version() != protoVersion {
 			s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
-			s.exitErr(ErrInvalidVersion)
-			return
+			return ErrInvalidVersion
 		}
 
 		// Switch on the type
@@ -363,14 +375,12 @@ func (s *Session) recv() {
 		case typePing:
 			handler = s.handlePing
 		default:
-			s.exitErr(ErrInvalidMsgType)
-			return
+			return ErrInvalidMsgType
 		}
 
 		// Invoke the handler
 		if err := handler(hdr); err != nil {
-			s.exitErr(err)
-			return
+			return err
 		}
 	}
 }