Bladeren bron

wait for both recv and send routines to complete on close (#105)

this change ensure that there is no write operations after close is
called. This was not the case before and this led to race conditions
in the caller code because there was no way to ensure there would not be
any write anymore after the close.

Co-authored-by: Jeff Mitchell <jeffrey.mitchell@gmail.com>
Clément Michaud 2 jaren geleden
bovenliggende
commit
aad893ec06
2 gewijzigde bestanden met toevoegingen van 38 en 20 verwijderingen
  1. 26 13
      session.go
  2. 12 7
      session_test.go

+ 26 - 13
session.go

@@ -69,12 +69,14 @@ type Session struct {
 	// recvDoneCh is closed when recv() exits to avoid a race
 	// between stream registration and stream shutdown
 	recvDoneCh chan struct{}
+	sendDoneCh chan struct{}
 
 	// shutdown is used to safely close a session
-	shutdown     bool
-	shutdownErr  error
-	shutdownCh   chan struct{}
-	shutdownLock sync.Mutex
+	shutdown        bool
+	shutdownErr     error
+	shutdownCh      chan struct{}
+	shutdownLock    sync.Mutex
+	shutdownErrLock sync.Mutex
 }
 
 // sendReady is used to either mark a stream as ready
@@ -105,6 +107,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 		acceptCh:   make(chan *Stream, config.AcceptBacklog),
 		sendCh:     make(chan *sendReady, 64),
 		recvDoneCh: make(chan struct{}),
+		sendDoneCh: make(chan struct{}),
 		shutdownCh: make(chan struct{}),
 	}
 	if client {
@@ -257,10 +260,15 @@ func (s *Session) Close() error {
 		return nil
 	}
 	s.shutdown = true
+
+	s.shutdownErrLock.Lock()
 	if s.shutdownErr == nil {
 		s.shutdownErr = ErrSessionShutdown
 	}
+	s.shutdownErrLock.Unlock()
+
 	close(s.shutdownCh)
+
 	s.conn.Close()
 	<-s.recvDoneCh
 
@@ -269,17 +277,18 @@ func (s *Session) Close() error {
 	for _, stream := range s.streams {
 		stream.forceClose()
 	}
+	<-s.sendDoneCh
 	return nil
 }
 
 // exitErr is used to handle an error that is causing the
 // session to terminate.
 func (s *Session) exitErr(err error) {
-	s.shutdownLock.Lock()
+	s.shutdownErrLock.Lock()
 	if s.shutdownErr == nil {
 		s.shutdownErr = err
 	}
-	s.shutdownLock.Unlock()
+	s.shutdownErrLock.Unlock()
 	s.Close()
 }
 
@@ -444,6 +453,13 @@ func (s *Session) sendNoWait(hdr header) error {
 
 // send is a long running goroutine that sends data
 func (s *Session) send() {
+	if err := s.sendLoop(); err != nil {
+		s.exitErr(err)
+	}
+}
+
+func (s *Session) sendLoop() error {
+	defer close(s.sendDoneCh)
 	var bodyBuf bytes.Buffer
 	for {
 		bodyBuf.Reset()
@@ -456,8 +472,7 @@ func (s *Session) send() {
 				if err != nil {
 					s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
 					asyncSendErr(ready.Err, err)
-					s.exitErr(err)
-					return
+					return err
 				}
 			}
 
@@ -471,8 +486,7 @@ func (s *Session) send() {
 					ready.mu.Unlock()
 					s.logger.Printf("[ERR] yamux: Failed to copy body into buffer: %v", err)
 					asyncSendErr(ready.Err, err)
-					s.exitErr(err)
-					return
+					return err
 				}
 				ready.Body = nil
 			}
@@ -484,15 +498,14 @@ func (s *Session) send() {
 				if err != nil {
 					s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
 					asyncSendErr(ready.Err, err)
-					s.exitErr(err)
-					return
+					return err
 				}
 			}
 
 			// No error, successful send
 			asyncSendErr(ready.Err, nil)
 		case <-s.shutdownCh:
-			return
+			return nil
 		}
 	}
 }

+ 12 - 7
session_test.go

@@ -463,16 +463,16 @@ func TestSendData_Small(t *testing.T) {
 	}()
 	select {
 	case <-doneCh:
+		if client.NumStreams() != 0 {
+			t.Fatalf("bad")
+		}
+		if server.NumStreams() != 0 {
+			t.Fatalf("bad")
+		}
+		return
 	case <-time.After(time.Second):
 		panic("timeout")
 	}
-
-	if client.NumStreams() != 0 {
-		t.Fatalf("bad")
-	}
-	if server.NumStreams() != 0 {
-		t.Fatalf("bad")
-	}
 }
 
 func TestSendData_Large(t *testing.T) {
@@ -1047,6 +1047,8 @@ func TestKeepAlive_Timeout(t *testing.T) {
 		t.Fatalf("timeout waiting for timeout")
 	}
 
+	clientConn.writeBlocker.Unlock()
+
 	if !server.IsClosed() {
 		t.Fatalf("server should have closed")
 	}
@@ -1243,6 +1245,7 @@ func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
 
 		conn := client.conn.(*pipeConn)
 		conn.writeBlocker.Lock()
+		defer conn.writeBlocker.Unlock()
 
 		_, err = stream.Read(make([]byte, flood))
 		if err != ErrConnectionWriteTimeout {
@@ -1338,6 +1341,7 @@ func TestSession_sendNoWait_Timeout(t *testing.T) {
 
 		conn := client.conn.(*pipeConn)
 		conn.writeBlocker.Lock()
+		defer conn.writeBlocker.Unlock()
 
 		hdr := header(make([]byte, headerSize))
 		hdr.encode(typePing, flagACK, 0, 0)
@@ -1458,6 +1462,7 @@ func TestSession_ConnectionWriteTimeout(t *testing.T) {
 
 		conn := client.conn.(*pipeConn)
 		conn.writeBlocker.Lock()
+		defer conn.writeBlocker.Unlock()
 
 		// Since the write goroutine is blocked then this will return a
 		// timeout since it can't get feedback about whether the write