|
@@ -69,12 +69,14 @@ type Session struct {
|
|
// recvDoneCh is closed when recv() exits to avoid a race
|
|
// recvDoneCh is closed when recv() exits to avoid a race
|
|
// between stream registration and stream shutdown
|
|
// between stream registration and stream shutdown
|
|
recvDoneCh chan struct{}
|
|
recvDoneCh chan struct{}
|
|
|
|
+ sendDoneCh chan struct{}
|
|
|
|
|
|
// shutdown is used to safely close a session
|
|
// shutdown is used to safely close a session
|
|
- shutdown bool
|
|
|
|
- shutdownErr error
|
|
|
|
- shutdownCh chan struct{}
|
|
|
|
- shutdownLock sync.Mutex
|
|
|
|
|
|
+ shutdown bool
|
|
|
|
+ shutdownErr error
|
|
|
|
+ shutdownCh chan struct{}
|
|
|
|
+ shutdownLock sync.Mutex
|
|
|
|
+ shutdownErrLock sync.Mutex
|
|
}
|
|
}
|
|
|
|
|
|
// sendReady is used to either mark a stream as ready
|
|
// sendReady is used to either mark a stream as ready
|
|
@@ -105,6 +107,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
|
|
acceptCh: make(chan *Stream, config.AcceptBacklog),
|
|
acceptCh: make(chan *Stream, config.AcceptBacklog),
|
|
sendCh: make(chan *sendReady, 64),
|
|
sendCh: make(chan *sendReady, 64),
|
|
recvDoneCh: make(chan struct{}),
|
|
recvDoneCh: make(chan struct{}),
|
|
|
|
+ sendDoneCh: make(chan struct{}),
|
|
shutdownCh: make(chan struct{}),
|
|
shutdownCh: make(chan struct{}),
|
|
}
|
|
}
|
|
if client {
|
|
if client {
|
|
@@ -257,10 +260,15 @@ func (s *Session) Close() error {
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
s.shutdown = true
|
|
s.shutdown = true
|
|
|
|
+
|
|
|
|
+ s.shutdownErrLock.Lock()
|
|
if s.shutdownErr == nil {
|
|
if s.shutdownErr == nil {
|
|
s.shutdownErr = ErrSessionShutdown
|
|
s.shutdownErr = ErrSessionShutdown
|
|
}
|
|
}
|
|
|
|
+ s.shutdownErrLock.Unlock()
|
|
|
|
+
|
|
close(s.shutdownCh)
|
|
close(s.shutdownCh)
|
|
|
|
+
|
|
s.conn.Close()
|
|
s.conn.Close()
|
|
<-s.recvDoneCh
|
|
<-s.recvDoneCh
|
|
|
|
|
|
@@ -269,17 +277,18 @@ func (s *Session) Close() error {
|
|
for _, stream := range s.streams {
|
|
for _, stream := range s.streams {
|
|
stream.forceClose()
|
|
stream.forceClose()
|
|
}
|
|
}
|
|
|
|
+ <-s.sendDoneCh
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
// exitErr is used to handle an error that is causing the
|
|
// exitErr is used to handle an error that is causing the
|
|
// session to terminate.
|
|
// session to terminate.
|
|
func (s *Session) exitErr(err error) {
|
|
func (s *Session) exitErr(err error) {
|
|
- s.shutdownLock.Lock()
|
|
|
|
|
|
+ s.shutdownErrLock.Lock()
|
|
if s.shutdownErr == nil {
|
|
if s.shutdownErr == nil {
|
|
s.shutdownErr = err
|
|
s.shutdownErr = err
|
|
}
|
|
}
|
|
- s.shutdownLock.Unlock()
|
|
|
|
|
|
+ s.shutdownErrLock.Unlock()
|
|
s.Close()
|
|
s.Close()
|
|
}
|
|
}
|
|
|
|
|
|
@@ -444,6 +453,13 @@ func (s *Session) sendNoWait(hdr header) error {
|
|
|
|
|
|
// send is a long running goroutine that sends data
|
|
// send is a long running goroutine that sends data
|
|
func (s *Session) send() {
|
|
func (s *Session) send() {
|
|
|
|
+ if err := s.sendLoop(); err != nil {
|
|
|
|
+ s.exitErr(err)
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (s *Session) sendLoop() error {
|
|
|
|
+ defer close(s.sendDoneCh)
|
|
var bodyBuf bytes.Buffer
|
|
var bodyBuf bytes.Buffer
|
|
for {
|
|
for {
|
|
bodyBuf.Reset()
|
|
bodyBuf.Reset()
|
|
@@ -456,8 +472,7 @@ func (s *Session) send() {
|
|
if err != nil {
|
|
if err != nil {
|
|
s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
|
|
s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
|
|
asyncSendErr(ready.Err, err)
|
|
asyncSendErr(ready.Err, err)
|
|
- s.exitErr(err)
|
|
|
|
- return
|
|
|
|
|
|
+ return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -471,8 +486,7 @@ func (s *Session) send() {
|
|
ready.mu.Unlock()
|
|
ready.mu.Unlock()
|
|
s.logger.Printf("[ERR] yamux: Failed to copy body into buffer: %v", err)
|
|
s.logger.Printf("[ERR] yamux: Failed to copy body into buffer: %v", err)
|
|
asyncSendErr(ready.Err, err)
|
|
asyncSendErr(ready.Err, err)
|
|
- s.exitErr(err)
|
|
|
|
- return
|
|
|
|
|
|
+ return err
|
|
}
|
|
}
|
|
ready.Body = nil
|
|
ready.Body = nil
|
|
}
|
|
}
|
|
@@ -484,15 +498,14 @@ func (s *Session) send() {
|
|
if err != nil {
|
|
if err != nil {
|
|
s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
|
|
s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
|
|
asyncSendErr(ready.Err, err)
|
|
asyncSendErr(ready.Err, err)
|
|
- s.exitErr(err)
|
|
|
|
- return
|
|
|
|
|
|
+ return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
// No error, successful send
|
|
// No error, successful send
|
|
asyncSendErr(ready.Err, nil)
|
|
asyncSendErr(ready.Err, nil)
|
|
case <-s.shutdownCh:
|
|
case <-s.shutdownCh:
|
|
- return
|
|
|
|
|
|
+ return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|