瀏覽代碼

wait for both recv and send routines to complete on close (#105)

this change ensure that there is no write operations after close is
called. This was not the case before and this led to race conditions
in the caller code because there was no way to ensure there would not be
any write anymore after the close.

Co-authored-by: Jeff Mitchell <jeffrey.mitchell@gmail.com>
Clément Michaud 2 年之前
父節點
當前提交
aad893ec06
共有 2 個文件被更改,包括 38 次插入20 次删除
  1. 26 13
      session.go
  2. 12 7
      session_test.go

+ 26 - 13
session.go

@@ -69,12 +69,14 @@ type Session struct {
 	// recvDoneCh is closed when recv() exits to avoid a race
 	// recvDoneCh is closed when recv() exits to avoid a race
 	// between stream registration and stream shutdown
 	// between stream registration and stream shutdown
 	recvDoneCh chan struct{}
 	recvDoneCh chan struct{}
+	sendDoneCh chan struct{}
 
 
 	// shutdown is used to safely close a session
 	// shutdown is used to safely close a session
-	shutdown     bool
-	shutdownErr  error
-	shutdownCh   chan struct{}
-	shutdownLock sync.Mutex
+	shutdown        bool
+	shutdownErr     error
+	shutdownCh      chan struct{}
+	shutdownLock    sync.Mutex
+	shutdownErrLock sync.Mutex
 }
 }
 
 
 // sendReady is used to either mark a stream as ready
 // sendReady is used to either mark a stream as ready
@@ -105,6 +107,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 		acceptCh:   make(chan *Stream, config.AcceptBacklog),
 		acceptCh:   make(chan *Stream, config.AcceptBacklog),
 		sendCh:     make(chan *sendReady, 64),
 		sendCh:     make(chan *sendReady, 64),
 		recvDoneCh: make(chan struct{}),
 		recvDoneCh: make(chan struct{}),
+		sendDoneCh: make(chan struct{}),
 		shutdownCh: make(chan struct{}),
 		shutdownCh: make(chan struct{}),
 	}
 	}
 	if client {
 	if client {
@@ -257,10 +260,15 @@ func (s *Session) Close() error {
 		return nil
 		return nil
 	}
 	}
 	s.shutdown = true
 	s.shutdown = true
+
+	s.shutdownErrLock.Lock()
 	if s.shutdownErr == nil {
 	if s.shutdownErr == nil {
 		s.shutdownErr = ErrSessionShutdown
 		s.shutdownErr = ErrSessionShutdown
 	}
 	}
+	s.shutdownErrLock.Unlock()
+
 	close(s.shutdownCh)
 	close(s.shutdownCh)
+
 	s.conn.Close()
 	s.conn.Close()
 	<-s.recvDoneCh
 	<-s.recvDoneCh
 
 
@@ -269,17 +277,18 @@ func (s *Session) Close() error {
 	for _, stream := range s.streams {
 	for _, stream := range s.streams {
 		stream.forceClose()
 		stream.forceClose()
 	}
 	}
+	<-s.sendDoneCh
 	return nil
 	return nil
 }
 }
 
 
 // exitErr is used to handle an error that is causing the
 // exitErr is used to handle an error that is causing the
 // session to terminate.
 // session to terminate.
 func (s *Session) exitErr(err error) {
 func (s *Session) exitErr(err error) {
-	s.shutdownLock.Lock()
+	s.shutdownErrLock.Lock()
 	if s.shutdownErr == nil {
 	if s.shutdownErr == nil {
 		s.shutdownErr = err
 		s.shutdownErr = err
 	}
 	}
-	s.shutdownLock.Unlock()
+	s.shutdownErrLock.Unlock()
 	s.Close()
 	s.Close()
 }
 }
 
 
@@ -444,6 +453,13 @@ func (s *Session) sendNoWait(hdr header) error {
 
 
 // send is a long running goroutine that sends data
 // send is a long running goroutine that sends data
 func (s *Session) send() {
 func (s *Session) send() {
+	if err := s.sendLoop(); err != nil {
+		s.exitErr(err)
+	}
+}
+
+func (s *Session) sendLoop() error {
+	defer close(s.sendDoneCh)
 	var bodyBuf bytes.Buffer
 	var bodyBuf bytes.Buffer
 	for {
 	for {
 		bodyBuf.Reset()
 		bodyBuf.Reset()
@@ -456,8 +472,7 @@ func (s *Session) send() {
 				if err != nil {
 				if err != nil {
 					s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
 					s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
 					asyncSendErr(ready.Err, err)
 					asyncSendErr(ready.Err, err)
-					s.exitErr(err)
-					return
+					return err
 				}
 				}
 			}
 			}
 
 
@@ -471,8 +486,7 @@ func (s *Session) send() {
 					ready.mu.Unlock()
 					ready.mu.Unlock()
 					s.logger.Printf("[ERR] yamux: Failed to copy body into buffer: %v", err)
 					s.logger.Printf("[ERR] yamux: Failed to copy body into buffer: %v", err)
 					asyncSendErr(ready.Err, err)
 					asyncSendErr(ready.Err, err)
-					s.exitErr(err)
-					return
+					return err
 				}
 				}
 				ready.Body = nil
 				ready.Body = nil
 			}
 			}
@@ -484,15 +498,14 @@ func (s *Session) send() {
 				if err != nil {
 				if err != nil {
 					s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
 					s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
 					asyncSendErr(ready.Err, err)
 					asyncSendErr(ready.Err, err)
-					s.exitErr(err)
-					return
+					return err
 				}
 				}
 			}
 			}
 
 
 			// No error, successful send
 			// No error, successful send
 			asyncSendErr(ready.Err, nil)
 			asyncSendErr(ready.Err, nil)
 		case <-s.shutdownCh:
 		case <-s.shutdownCh:
-			return
+			return nil
 		}
 		}
 	}
 	}
 }
 }

+ 12 - 7
session_test.go

@@ -463,16 +463,16 @@ func TestSendData_Small(t *testing.T) {
 	}()
 	}()
 	select {
 	select {
 	case <-doneCh:
 	case <-doneCh:
+		if client.NumStreams() != 0 {
+			t.Fatalf("bad")
+		}
+		if server.NumStreams() != 0 {
+			t.Fatalf("bad")
+		}
+		return
 	case <-time.After(time.Second):
 	case <-time.After(time.Second):
 		panic("timeout")
 		panic("timeout")
 	}
 	}
-
-	if client.NumStreams() != 0 {
-		t.Fatalf("bad")
-	}
-	if server.NumStreams() != 0 {
-		t.Fatalf("bad")
-	}
 }
 }
 
 
 func TestSendData_Large(t *testing.T) {
 func TestSendData_Large(t *testing.T) {
@@ -1047,6 +1047,8 @@ func TestKeepAlive_Timeout(t *testing.T) {
 		t.Fatalf("timeout waiting for timeout")
 		t.Fatalf("timeout waiting for timeout")
 	}
 	}
 
 
+	clientConn.writeBlocker.Unlock()
+
 	if !server.IsClosed() {
 	if !server.IsClosed() {
 		t.Fatalf("server should have closed")
 		t.Fatalf("server should have closed")
 	}
 	}
@@ -1243,6 +1245,7 @@ func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
 
 
 		conn := client.conn.(*pipeConn)
 		conn := client.conn.(*pipeConn)
 		conn.writeBlocker.Lock()
 		conn.writeBlocker.Lock()
+		defer conn.writeBlocker.Unlock()
 
 
 		_, err = stream.Read(make([]byte, flood))
 		_, err = stream.Read(make([]byte, flood))
 		if err != ErrConnectionWriteTimeout {
 		if err != ErrConnectionWriteTimeout {
@@ -1338,6 +1341,7 @@ func TestSession_sendNoWait_Timeout(t *testing.T) {
 
 
 		conn := client.conn.(*pipeConn)
 		conn := client.conn.(*pipeConn)
 		conn.writeBlocker.Lock()
 		conn.writeBlocker.Lock()
+		defer conn.writeBlocker.Unlock()
 
 
 		hdr := header(make([]byte, headerSize))
 		hdr := header(make([]byte, headerSize))
 		hdr.encode(typePing, flagACK, 0, 0)
 		hdr.encode(typePing, flagACK, 0, 0)
@@ -1458,6 +1462,7 @@ func TestSession_ConnectionWriteTimeout(t *testing.T) {
 
 
 		conn := client.conn.(*pipeConn)
 		conn := client.conn.(*pipeConn)
 		conn.writeBlocker.Lock()
 		conn.writeBlocker.Lock()
+		defer conn.writeBlocker.Unlock()
 
 
 		// Since the write goroutine is blocked then this will return a
 		// Since the write goroutine is blocked then this will return a
 		// timeout since it can't get feedback about whether the write
 		// timeout since it can't get feedback about whether the write