|
@@ -3,6 +3,7 @@ package yamux
|
|
import (
|
|
import (
|
|
"fmt"
|
|
"fmt"
|
|
"io"
|
|
"io"
|
|
|
|
+ "math"
|
|
"net"
|
|
"net"
|
|
"sync"
|
|
"sync"
|
|
"time"
|
|
"time"
|
|
@@ -20,6 +21,33 @@ var (
|
|
// ErrSessionShutdown is used if there is a shutdown during
|
|
// ErrSessionShutdown is used if there is a shutdown during
|
|
// an operation
|
|
// an operation
|
|
ErrSessionShutdown = fmt.Errorf("session shutdown")
|
|
ErrSessionShutdown = fmt.Errorf("session shutdown")
|
|
|
|
+
|
|
|
|
+ // ErrStreamsExhausted is returned if we have no more
|
|
|
|
+ // stream ids to issue
|
|
|
|
+ ErrStreamsExhausted = fmt.Errorf("streams exhausted")
|
|
|
|
+
|
|
|
|
+ // ErrDuplicateStream is used if a duplicate stream is
|
|
|
|
+ // opened inbound
|
|
|
|
+ ErrDuplicateStream = fmt.Errorf("duplicate stream initiated")
|
|
|
|
+
|
|
|
|
+ // ErrMissingStream indicates a stream was named which
|
|
|
|
+ // does not exist.
|
|
|
|
+ ErrMissingStream = fmt.Errorf("missing stream references")
|
|
|
|
+
|
|
|
|
+ // ErrReceiveWindowExceeded indicates the window was exceeded
|
|
|
|
+ ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded")
|
|
|
|
+
|
|
|
|
+ // ErrTimeout is used when we reach an IO deadline
|
|
|
|
+ ErrTimeout = fmt.Errorf("i/o deadline reached")
|
|
|
|
+
|
|
|
|
+ // ErrStreamClosed is returned when using a closed stream
|
|
|
|
+ ErrStreamClosed = fmt.Errorf("stream closed")
|
|
|
|
+
|
|
|
|
+ // ErrUnexpectedFlag is set when we get an unexpected flag
|
|
|
|
+ ErrUnexpectedFlag = fmt.Errorf("unexpected flag")
|
|
|
|
+
|
|
|
|
+ // ErrRemoteGoAway is used when we get a go away from the other side
|
|
|
|
+ ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections")
|
|
)
|
|
)
|
|
|
|
|
|
// Session is used to wrap a reliable ordered connection and to
|
|
// Session is used to wrap a reliable ordered connection and to
|
|
@@ -34,17 +62,26 @@ type Session struct {
|
|
// conn is the underlying connection
|
|
// conn is the underlying connection
|
|
conn io.ReadWriteCloser
|
|
conn io.ReadWriteCloser
|
|
|
|
|
|
- // nextStreamID is the next stream we should
|
|
|
|
- // send. This depends if we are a client/server.
|
|
|
|
- nextStreamID uint32
|
|
|
|
-
|
|
|
|
// pings is used to track inflight pings
|
|
// pings is used to track inflight pings
|
|
pings map[uint32]chan struct{}
|
|
pings map[uint32]chan struct{}
|
|
pingID uint32
|
|
pingID uint32
|
|
pingLock sync.Mutex
|
|
pingLock sync.Mutex
|
|
|
|
|
|
|
|
+ // remoteGoAway indicates the remote side does
|
|
|
|
+ // not want futher connections
|
|
|
|
+ remoteGoAway bool
|
|
|
|
+
|
|
|
|
+ // localGoAway indicates that we should stop
|
|
|
|
+ // accepting futher connections
|
|
|
|
+ localGoAway bool
|
|
|
|
+
|
|
|
|
+ // nextStreamID is the next stream we should
|
|
|
|
+ // send. This depends if we are a client/server.
|
|
|
|
+ nextStreamID uint32
|
|
|
|
+
|
|
// streams maps a stream id to a stream
|
|
// streams maps a stream id to a stream
|
|
- streams map[uint32]*Stream
|
|
|
|
|
|
+ streams map[uint32]*Stream
|
|
|
|
+ streamLock sync.RWMutex
|
|
|
|
|
|
// acceptCh is used to pass ready streams to the client
|
|
// acceptCh is used to pass ready streams to the client
|
|
acceptCh chan *Stream
|
|
acceptCh chan *Stream
|
|
@@ -82,9 +119,9 @@ func (y *yamuxAddr) String() string {
|
|
// sendReady is used to either mark a stream as ready
|
|
// sendReady is used to either mark a stream as ready
|
|
// or to directly send a header
|
|
// or to directly send a header
|
|
type sendReady struct {
|
|
type sendReady struct {
|
|
- StreamID uint32
|
|
|
|
- Hdr []byte
|
|
|
|
- Err chan error
|
|
|
|
|
|
+ Hdr []byte
|
|
|
|
+ Body io.Reader
|
|
|
|
+ Err chan error
|
|
}
|
|
}
|
|
|
|
|
|
// newSession is used to construct a new session
|
|
// newSession is used to construct a new session
|
|
@@ -112,9 +149,41 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
|
|
return s
|
|
return s
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// isShutdown does a safe check to see if we have shutdown
|
|
|
|
+func (s *Session) isShutdown() bool {
|
|
|
|
+ select {
|
|
|
|
+ case <-s.shutdownCh:
|
|
|
|
+ return true
|
|
|
|
+ default:
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
// Open is used to create a new stream
|
|
// Open is used to create a new stream
|
|
func (s *Session) Open() (*Stream, error) {
|
|
func (s *Session) Open() (*Stream, error) {
|
|
- return nil, nil
|
|
|
|
|
|
+ if s.isShutdown() {
|
|
|
|
+ return nil, ErrSessionShutdown
|
|
|
|
+ }
|
|
|
|
+ if s.remoteGoAway {
|
|
|
|
+ return nil, ErrRemoteGoAway
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ s.streamLock.Lock()
|
|
|
|
+ defer s.streamLock.Unlock()
|
|
|
|
+
|
|
|
|
+ // Check if we've exhaused the streams
|
|
|
|
+ id := s.nextStreamID
|
|
|
|
+ if id >= math.MaxUint32-1 {
|
|
|
|
+ return nil, ErrStreamsExhausted
|
|
|
|
+ }
|
|
|
|
+ s.nextStreamID += 2
|
|
|
|
+
|
|
|
|
+ // Register the stream
|
|
|
|
+ stream := newStream(s, id, streamInit)
|
|
|
|
+ s.streams[id] = stream
|
|
|
|
+
|
|
|
|
+ // Send the window update to create
|
|
|
|
+ return stream, stream.sendWindowUpdate()
|
|
}
|
|
}
|
|
|
|
|
|
// Accept is used to block until the next available stream
|
|
// Accept is used to block until the next available stream
|
|
@@ -144,8 +213,25 @@ func (s *Session) Close() error {
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
s.shutdown = true
|
|
s.shutdown = true
|
|
|
|
+ if s.shutdownErr == nil {
|
|
|
|
+ s.shutdownErr = ErrSessionShutdown
|
|
|
|
+ }
|
|
close(s.shutdownCh)
|
|
close(s.shutdownCh)
|
|
s.conn.Close()
|
|
s.conn.Close()
|
|
|
|
+
|
|
|
|
+ s.streamLock.Lock()
|
|
|
|
+ defer s.streamLock.Unlock()
|
|
|
|
+ for _, stream := range s.streams {
|
|
|
|
+ stream.forceClose()
|
|
|
|
+ }
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// GoAway can be used to prevent accepting further
|
|
|
|
+// connections. It does not close the underlying conn.
|
|
|
|
+func (s *Session) GoAway() error {
|
|
|
|
+ s.localGoAway = true
|
|
|
|
+ s.goAway(goAwayNormal)
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
@@ -189,7 +275,7 @@ func (s *Session) Ping() (time.Duration, error) {
|
|
// Send the ping request
|
|
// Send the ping request
|
|
hdr := header(make([]byte, headerSize))
|
|
hdr := header(make([]byte, headerSize))
|
|
hdr.encode(typePing, flagSYN, 0, id)
|
|
hdr.encode(typePing, flagSYN, 0, id)
|
|
- if err := s.waitForSend(hdr); err != nil {
|
|
|
|
|
|
+ if err := s.waitForSend(hdr, nil); err != nil {
|
|
return 0, err
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
|
|
@@ -219,9 +305,9 @@ func (s *Session) keepalive() {
|
|
}
|
|
}
|
|
|
|
|
|
// waitForSend waits to send a header, checking for a potential shutdown
|
|
// waitForSend waits to send a header, checking for a potential shutdown
|
|
-func (s *Session) waitForSend(hdr header) error {
|
|
|
|
|
|
+func (s *Session) waitForSend(hdr header, body io.Reader) error {
|
|
errCh := make(chan error, 1)
|
|
errCh := make(chan error, 1)
|
|
- ready := sendReady{Hdr: hdr, Err: errCh}
|
|
|
|
|
|
+ ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
|
|
select {
|
|
select {
|
|
case s.sendCh <- ready:
|
|
case s.sendCh <- ready:
|
|
case <-s.shutdownCh:
|
|
case <-s.shutdownCh:
|
|
@@ -250,11 +336,6 @@ func (s *Session) send() {
|
|
for {
|
|
for {
|
|
select {
|
|
select {
|
|
case ready := <-s.sendCh:
|
|
case ready := <-s.sendCh:
|
|
- // Send data from a stream if ready
|
|
|
|
- if ready.StreamID != 0 {
|
|
|
|
-
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
// Send a header if ready
|
|
// Send a header if ready
|
|
if ready.Hdr != nil {
|
|
if ready.Hdr != nil {
|
|
sent := 0
|
|
sent := 0
|
|
@@ -263,10 +344,23 @@ func (s *Session) send() {
|
|
if err != nil {
|
|
if err != nil {
|
|
s.exitErr(err)
|
|
s.exitErr(err)
|
|
asyncSendErr(ready.Err, err)
|
|
asyncSendErr(ready.Err, err)
|
|
|
|
+ return
|
|
}
|
|
}
|
|
sent += n
|
|
sent += n
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ // Send data from a body if given
|
|
|
|
+ if ready.Body != nil {
|
|
|
|
+ _, err := io.Copy(s.conn, ready.Body)
|
|
|
|
+ if err != nil {
|
|
|
|
+ s.exitErr(err)
|
|
|
|
+ asyncSendErr(ready.Err, err)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // No error, successful send
|
|
asyncSendErr(ready.Err, nil)
|
|
asyncSendErr(ready.Err, nil)
|
|
case <-s.shutdownCh:
|
|
case <-s.shutdownCh:
|
|
return
|
|
return
|
|
@@ -277,7 +371,7 @@ func (s *Session) send() {
|
|
// recv is a long running goroutine that accepts new data
|
|
// recv is a long running goroutine that accepts new data
|
|
func (s *Session) recv() {
|
|
func (s *Session) recv() {
|
|
hdr := header(make([]byte, headerSize))
|
|
hdr := header(make([]byte, headerSize))
|
|
- for {
|
|
|
|
|
|
+ for !s.isShutdown() {
|
|
// Read the header
|
|
// Read the header
|
|
if _, err := io.ReadFull(s.conn, hdr); err != nil {
|
|
if _, err := io.ReadFull(s.conn, hdr); err != nil {
|
|
s.exitErr(err)
|
|
s.exitErr(err)
|
|
@@ -294,13 +388,22 @@ func (s *Session) recv() {
|
|
msgType := hdr.MsgType()
|
|
msgType := hdr.MsgType()
|
|
switch msgType {
|
|
switch msgType {
|
|
case typeData:
|
|
case typeData:
|
|
- s.handleData(hdr)
|
|
|
|
|
|
+ fallthrough
|
|
case typeWindowUpdate:
|
|
case typeWindowUpdate:
|
|
- s.handleWindowUpdate(hdr)
|
|
|
|
- case typePing:
|
|
|
|
- s.handlePing(hdr)
|
|
|
|
|
|
+ if err := s.handleStreamMessage(hdr); err != nil {
|
|
|
|
+ s.exitErr(err)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
case typeGoAway:
|
|
case typeGoAway:
|
|
- s.handleGoAway(hdr)
|
|
|
|
|
|
+ if err := s.handleGoAway(hdr); err != nil {
|
|
|
|
+ s.exitErr(err)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ case typePing:
|
|
|
|
+ if err := s.handlePing(hdr); err != nil {
|
|
|
|
+ s.exitErr(err)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
default:
|
|
default:
|
|
s.exitErr(ErrInvalidMsgType)
|
|
s.exitErr(ErrInvalidMsgType)
|
|
return
|
|
return
|
|
@@ -308,28 +411,46 @@ func (s *Session) recv() {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-// handleData is invokde for a typeData frame
|
|
|
|
-func (s *Session) handleData(hdr header) {
|
|
|
|
- flags := hdr.Flags()
|
|
|
|
-
|
|
|
|
|
|
+// handleStreamMessage handles either a data or window update frame
|
|
|
|
+func (s *Session) handleStreamMessage(hdr header) error {
|
|
// Check for a new stream creation
|
|
// Check for a new stream creation
|
|
|
|
+ id := hdr.StreamID()
|
|
|
|
+ flags := hdr.Flags()
|
|
if flags&flagSYN == flagSYN {
|
|
if flags&flagSYN == flagSYN {
|
|
- s.createStream(hdr.StreamID())
|
|
|
|
|
|
+ if err := s.incomingStream(id); err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
}
|
|
}
|
|
-}
|
|
|
|
|
|
|
|
-// handleWindowUpdate is invokde for a typeWindowUpdate frame
|
|
|
|
-func (s *Session) handleWindowUpdate(hdr header) {
|
|
|
|
- flags := hdr.Flags()
|
|
|
|
|
|
+ // Get the stream
|
|
|
|
+ s.streamLock.RLock()
|
|
|
|
+ stream := s.streams[id]
|
|
|
|
+ s.streamLock.RUnlock()
|
|
|
|
|
|
- // Check for a new stream creation
|
|
|
|
- if flags&flagSYN == flagSYN {
|
|
|
|
- s.createStream(hdr.StreamID())
|
|
|
|
|
|
+ // Make sure we have a stream
|
|
|
|
+ if stream == nil {
|
|
|
|
+ s.goAway(goAwayProtoErr)
|
|
|
|
+ return ErrMissingStream
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Check if this is a window update
|
|
|
|
+ if hdr.MsgType() == typeWindowUpdate {
|
|
|
|
+ if err := stream.incrSendWindow(hdr, flags); err != nil {
|
|
|
|
+ s.goAway(goAwayProtoErr)
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ // Read the new data
|
|
|
|
+ if err := stream.readData(hdr, flags, s.conn); err != nil {
|
|
|
|
+ s.goAway(goAwayProtoErr)
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ return nil
|
|
}
|
|
}
|
|
|
|
|
|
// handlePing is invokde for a typePing frame
|
|
// handlePing is invokde for a typePing frame
|
|
-func (s *Session) handlePing(hdr header) {
|
|
|
|
|
|
+func (s *Session) handlePing(hdr header) error {
|
|
flags := hdr.Flags()
|
|
flags := hdr.Flags()
|
|
pingID := hdr.Length()
|
|
pingID := hdr.Length()
|
|
|
|
|
|
@@ -338,7 +459,7 @@ func (s *Session) handlePing(hdr header) {
|
|
hdr := header(make([]byte, headerSize))
|
|
hdr := header(make([]byte, headerSize))
|
|
hdr.encode(typePing, flagACK, 0, pingID)
|
|
hdr.encode(typePing, flagACK, 0, pingID)
|
|
s.sendNoWait(hdr)
|
|
s.sendNoWait(hdr)
|
|
- return
|
|
|
|
|
|
+ return nil
|
|
}
|
|
}
|
|
|
|
|
|
// Handle a response
|
|
// Handle a response
|
|
@@ -349,16 +470,30 @@ func (s *Session) handlePing(hdr header) {
|
|
close(ch)
|
|
close(ch)
|
|
}
|
|
}
|
|
s.pingLock.Unlock()
|
|
s.pingLock.Unlock()
|
|
|
|
+ return nil
|
|
}
|
|
}
|
|
|
|
|
|
// handleGoAway is invokde for a typeGoAway frame
|
|
// handleGoAway is invokde for a typeGoAway frame
|
|
-func (s *Session) handleGoAway(hdr header) {
|
|
|
|
-
|
|
|
|
|
|
+func (s *Session) handleGoAway(hdr header) error {
|
|
|
|
+ code := hdr.Length()
|
|
|
|
+ switch code {
|
|
|
|
+ case goAwayNormal:
|
|
|
|
+ s.remoteGoAway = true
|
|
|
|
+ case goAwayProtoErr:
|
|
|
|
+ return fmt.Errorf("yamux protocol error")
|
|
|
|
+ case goAwayInternalErr:
|
|
|
|
+ return fmt.Errorf("remote yamux internal error")
|
|
|
|
+ default:
|
|
|
|
+ return fmt.Errorf("unexpected go away received")
|
|
|
|
+ }
|
|
|
|
+ return nil
|
|
}
|
|
}
|
|
|
|
|
|
// exitErr is used to handle an error that is causing
|
|
// exitErr is used to handle an error that is causing
|
|
// the listener to exit.
|
|
// the listener to exit.
|
|
func (s *Session) exitErr(err error) {
|
|
func (s *Session) exitErr(err error) {
|
|
|
|
+ s.shutdownErr = err
|
|
|
|
+ s.Close()
|
|
}
|
|
}
|
|
|
|
|
|
// goAway is used to send a goAway message
|
|
// goAway is used to send a goAway message
|
|
@@ -368,7 +503,49 @@ func (s *Session) goAway(reason uint32) {
|
|
s.sendNoWait(hdr)
|
|
s.sendNoWait(hdr)
|
|
}
|
|
}
|
|
|
|
|
|
-// createStream is used to create a new stream
|
|
|
|
-func (s *Session) createStream(id uint32) {
|
|
|
|
- // TODO
|
|
|
|
|
|
+// incomingStream is used to create a new incoming stream
|
|
|
|
+func (s *Session) incomingStream(id uint32) error {
|
|
|
|
+ // Reject immediately if we are doing a go away
|
|
|
|
+ if s.localGoAway {
|
|
|
|
+ hdr := header(make([]byte, headerSize))
|
|
|
|
+ hdr.encode(typeWindowUpdate, flagRST, id, 0)
|
|
|
|
+ s.sendNoWait(hdr)
|
|
|
|
+ return nil
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ s.streamLock.Lock()
|
|
|
|
+ defer s.streamLock.Unlock()
|
|
|
|
+
|
|
|
|
+ // Check if stream already exists
|
|
|
|
+ if _, ok := s.streams[id]; ok {
|
|
|
|
+ s.goAway(goAwayProtoErr)
|
|
|
|
+ s.exitErr(ErrDuplicateStream)
|
|
|
|
+ return nil
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Register the stream
|
|
|
|
+ stream := newStream(s, id, streamSYNReceived)
|
|
|
|
+ s.streams[id] = stream
|
|
|
|
+
|
|
|
|
+ // Check if we've exceeded the backlog
|
|
|
|
+ select {
|
|
|
|
+ case s.acceptCh <- stream:
|
|
|
|
+ return nil
|
|
|
|
+ default:
|
|
|
|
+ // Backlog exceeded! RST the stream
|
|
|
|
+ delete(s.streams, id)
|
|
|
|
+ stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
|
|
|
|
+ s.sendNoWait(stream.sendHdr)
|
|
|
|
+ }
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// closeStream is used to close a stream once both sides have
|
|
|
|
+// issued a close.
|
|
|
|
+func (s *Session) closeStream(id uint32, withLock bool) {
|
|
|
|
+ if !withLock {
|
|
|
|
+ s.streamLock.Lock()
|
|
|
|
+ defer s.streamLock.Unlock()
|
|
|
|
+ }
|
|
|
|
+ delete(s.streams, id)
|
|
}
|
|
}
|