|
@@ -57,6 +57,10 @@ type Session struct {
|
|
// or to send a header out directly.
|
|
// or to send a header out directly.
|
|
sendCh chan sendReady
|
|
sendCh chan sendReady
|
|
|
|
|
|
|
|
+ // recvDoneCh is closed when recv() exits to avoid a race
|
|
|
|
+ // between stream registration and stream shutdown
|
|
|
|
+ recvDoneCh chan struct{}
|
|
|
|
+
|
|
// shutdown is used to safely close a session
|
|
// shutdown is used to safely close a session
|
|
shutdown bool
|
|
shutdown bool
|
|
shutdownErr error
|
|
shutdownErr error
|
|
@@ -83,6 +87,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
|
|
streams: make(map[uint32]*Stream),
|
|
streams: make(map[uint32]*Stream),
|
|
acceptCh: make(chan *Stream, config.AcceptBacklog),
|
|
acceptCh: make(chan *Stream, config.AcceptBacklog),
|
|
sendCh: make(chan sendReady, 64),
|
|
sendCh: make(chan sendReady, 64),
|
|
|
|
+ recvDoneCh: make(chan struct{}),
|
|
shutdownCh: make(chan struct{}),
|
|
shutdownCh: make(chan struct{}),
|
|
}
|
|
}
|
|
if client {
|
|
if client {
|
|
@@ -182,6 +187,7 @@ func (s *Session) Close() error {
|
|
}
|
|
}
|
|
close(s.shutdownCh)
|
|
close(s.shutdownCh)
|
|
s.conn.Close()
|
|
s.conn.Close()
|
|
|
|
+ <-s.recvDoneCh
|
|
|
|
|
|
s.streamLock.Lock()
|
|
s.streamLock.Lock()
|
|
defer s.streamLock.Unlock()
|
|
defer s.streamLock.Unlock()
|
|
@@ -333,6 +339,14 @@ func (s *Session) send() {
|
|
|
|
|
|
// recv is a long running goroutine that accepts new data
|
|
// recv is a long running goroutine that accepts new data
|
|
func (s *Session) recv() {
|
|
func (s *Session) recv() {
|
|
|
|
+ if err := s.recvLoop(); err != nil {
|
|
|
|
+ s.exitErr(err)
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// recvLoop continues to receive data until a fatal error is encountered
|
|
|
|
+func (s *Session) recvLoop() error {
|
|
|
|
+ defer close(s.recvDoneCh)
|
|
hdr := header(make([]byte, headerSize))
|
|
hdr := header(make([]byte, headerSize))
|
|
var handler func(header) error
|
|
var handler func(header) error
|
|
for {
|
|
for {
|
|
@@ -341,15 +355,13 @@ func (s *Session) recv() {
|
|
if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
|
|
if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
|
|
s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
|
|
s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
|
|
}
|
|
}
|
|
- s.exitErr(err)
|
|
|
|
- return
|
|
|
|
|
|
+ return err
|
|
}
|
|
}
|
|
|
|
|
|
// Verify the version
|
|
// Verify the version
|
|
if hdr.Version() != protoVersion {
|
|
if hdr.Version() != protoVersion {
|
|
s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
|
|
s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
|
|
- s.exitErr(ErrInvalidVersion)
|
|
|
|
- return
|
|
|
|
|
|
+ return ErrInvalidVersion
|
|
}
|
|
}
|
|
|
|
|
|
// Switch on the type
|
|
// Switch on the type
|
|
@@ -363,14 +375,12 @@ func (s *Session) recv() {
|
|
case typePing:
|
|
case typePing:
|
|
handler = s.handlePing
|
|
handler = s.handlePing
|
|
default:
|
|
default:
|
|
- s.exitErr(ErrInvalidMsgType)
|
|
|
|
- return
|
|
|
|
|
|
+ return ErrInvalidMsgType
|
|
}
|
|
}
|
|
|
|
|
|
// Invoke the handler
|
|
// Invoke the handler
|
|
if err := handler(hdr); err != nil {
|
|
if err := handler(hdr); err != nil {
|
|
- s.exitErr(err)
|
|
|
|
- return
|
|
|
|
|
|
+ return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|