|
@@ -0,0 +1,374 @@
|
|
|
+package yamux
|
|
|
+
|
|
|
+import (
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "net"
|
|
|
+ "sync"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+var (
|
|
|
+ // ErrInvalidVersion means we received a frame with an
|
|
|
+ // invalid version
|
|
|
+ ErrInvalidVersion = fmt.Errorf("invalid protocol version")
|
|
|
+
|
|
|
+ // ErrInvalidMsgType means we received a frame with an
|
|
|
+ // invalid message type
|
|
|
+ ErrInvalidMsgType = fmt.Errorf("invalid msg type")
|
|
|
+
|
|
|
+ // ErrSessionShutdown is used if there is a shutdown during
|
|
|
+ // an operation
|
|
|
+ ErrSessionShutdown = fmt.Errorf("session shutdown")
|
|
|
+)
|
|
|
+
|
|
|
+// Session is used to wrap a reliable ordered connection and to
|
|
|
+// multiplex it into multiple streams.
|
|
|
+type Session struct {
|
|
|
+ // client is true if we are a client size connection
|
|
|
+ client bool
|
|
|
+
|
|
|
+ // config holds our configuration
|
|
|
+ config *Config
|
|
|
+
|
|
|
+ // conn is the underlying connection
|
|
|
+ conn io.ReadWriteCloser
|
|
|
+
|
|
|
+ // nextStreamID is the next stream we should
|
|
|
+ // send. This depends if we are a client/server.
|
|
|
+ nextStreamID uint32
|
|
|
+
|
|
|
+ // pings is used to track inflight pings
|
|
|
+ pings map[uint32]chan struct{}
|
|
|
+ pingID uint32
|
|
|
+ pingLock sync.Mutex
|
|
|
+
|
|
|
+ // streams maps a stream id to a stream
|
|
|
+ streams map[uint32]*Stream
|
|
|
+
|
|
|
+ // acceptCh is used to pass ready streams to the client
|
|
|
+ acceptCh chan *Stream
|
|
|
+
|
|
|
+ // sendCh is used to mark a stream as ready to send,
|
|
|
+ // or to send a header out directly.
|
|
|
+ sendCh chan sendReady
|
|
|
+
|
|
|
+ // shutdown is used to safely close a session
|
|
|
+ shutdown bool
|
|
|
+ shutdownErr error
|
|
|
+ shutdownCh chan struct{}
|
|
|
+ shutdownLock sync.Mutex
|
|
|
+}
|
|
|
+
|
|
|
+// hasAddr is used to get the address from the underlying connection
|
|
|
+type hasAddr interface {
|
|
|
+ LocalAddr() net.Addr
|
|
|
+ RemoteAddr() net.Addr
|
|
|
+}
|
|
|
+
|
|
|
+// yamuxAddr is used when we cannot get the underlying address
|
|
|
+type yamuxAddr struct {
|
|
|
+ Addr string
|
|
|
+}
|
|
|
+
|
|
|
+func (*yamuxAddr) Network() string {
|
|
|
+ return "yamux"
|
|
|
+}
|
|
|
+
|
|
|
+func (y *yamuxAddr) String() string {
|
|
|
+ return fmt.Sprintf("yamux:%s", y.Addr)
|
|
|
+}
|
|
|
+
|
|
|
+// sendReady is used to either mark a stream as ready
|
|
|
+// or to directly send a header
|
|
|
+type sendReady struct {
|
|
|
+ StreamID uint32
|
|
|
+ Hdr []byte
|
|
|
+ Err chan error
|
|
|
+}
|
|
|
+
|
|
|
+// newSession is used to construct a new session
|
|
|
+func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
|
|
|
+ s := &Session{
|
|
|
+ client: client,
|
|
|
+ config: config,
|
|
|
+ conn: conn,
|
|
|
+ pings: make(map[uint32]chan struct{}),
|
|
|
+ streams: make(map[uint32]*Stream),
|
|
|
+ acceptCh: make(chan *Stream, config.AcceptBacklog),
|
|
|
+ sendCh: make(chan sendReady, 64),
|
|
|
+ shutdownCh: make(chan struct{}),
|
|
|
+ }
|
|
|
+ if client {
|
|
|
+ s.nextStreamID = 1
|
|
|
+ } else {
|
|
|
+ s.nextStreamID = 2
|
|
|
+ }
|
|
|
+ go s.recv()
|
|
|
+ go s.send()
|
|
|
+ if config.EnableKeepAlive {
|
|
|
+ go s.keepalive()
|
|
|
+ }
|
|
|
+ return s
|
|
|
+}
|
|
|
+
|
|
|
+// Open is used to create a new stream
|
|
|
+func (s *Session) Open() (*Stream, error) {
|
|
|
+ return nil, nil
|
|
|
+}
|
|
|
+
|
|
|
+// Accept is used to block until the next available stream
|
|
|
+// is ready to be accepted.
|
|
|
+func (s *Session) Accept() (net.Conn, error) {
|
|
|
+ return s.AcceptStream()
|
|
|
+}
|
|
|
+
|
|
|
+// AcceptStream is used to block until the next available stream
|
|
|
+// is ready to be accepted.
|
|
|
+func (s *Session) AcceptStream() (*Stream, error) {
|
|
|
+ select {
|
|
|
+ case stream := <-s.acceptCh:
|
|
|
+ return stream, nil
|
|
|
+ case <-s.shutdownCh:
|
|
|
+ return nil, s.shutdownErr
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Close is used to close the session and all streams.
|
|
|
+// Attempts to send a GoAway before closing the connection.
|
|
|
+func (s *Session) Close() error {
|
|
|
+ s.shutdownLock.Lock()
|
|
|
+ defer s.shutdownLock.Unlock()
|
|
|
+
|
|
|
+ if s.shutdown {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ s.shutdown = true
|
|
|
+ close(s.shutdownCh)
|
|
|
+ s.conn.Close()
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// Addr is used to get the address of the listener.
|
|
|
+func (s *Session) Addr() net.Addr {
|
|
|
+ return s.LocalAddr()
|
|
|
+}
|
|
|
+
|
|
|
+// LocalAddr is used to get the local address of the
|
|
|
+// underlying connection.
|
|
|
+func (s *Session) LocalAddr() net.Addr {
|
|
|
+ addr, ok := s.conn.(hasAddr)
|
|
|
+ if !ok {
|
|
|
+ return &yamuxAddr{"local"}
|
|
|
+ }
|
|
|
+ return addr.LocalAddr()
|
|
|
+}
|
|
|
+
|
|
|
+// RemoteAddr is used to get the address of remote end
|
|
|
+// of the underlying connection
|
|
|
+func (s *Session) RemoteAddr() net.Addr {
|
|
|
+ addr, ok := s.conn.(hasAddr)
|
|
|
+ if !ok {
|
|
|
+ return &yamuxAddr{"remote"}
|
|
|
+ }
|
|
|
+ return addr.RemoteAddr()
|
|
|
+}
|
|
|
+
|
|
|
+// Ping is used to measure the RTT response time
|
|
|
+func (s *Session) Ping() (time.Duration, error) {
|
|
|
+ // Get a channel for the ping
|
|
|
+ ch := make(chan struct{})
|
|
|
+
|
|
|
+ // Get a new ping id, mark as pending
|
|
|
+ s.pingLock.Lock()
|
|
|
+ id := s.pingID
|
|
|
+ s.pingID++
|
|
|
+ s.pings[id] = ch
|
|
|
+ s.pingLock.Unlock()
|
|
|
+
|
|
|
+ // Send the ping request
|
|
|
+ hdr := header(make([]byte, headerSize))
|
|
|
+ hdr.encode(typePing, flagSYN, 0, id)
|
|
|
+ if err := s.waitForSend(hdr); err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // Wait for a response
|
|
|
+ start := time.Now()
|
|
|
+ select {
|
|
|
+ case <-ch:
|
|
|
+ case <-s.shutdownCh:
|
|
|
+ return 0, ErrSessionShutdown
|
|
|
+ }
|
|
|
+
|
|
|
+ // Compute the RTT
|
|
|
+ return time.Now().Sub(start), nil
|
|
|
+}
|
|
|
+
|
|
|
+// keepalive is a long running goroutine that periodically does
|
|
|
+// a ping to keep the connection alive.
|
|
|
+func (s *Session) keepalive() {
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-time.After(s.config.KeepAliveInterval):
|
|
|
+ s.Ping()
|
|
|
+ case <-s.shutdownCh:
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// waitForSend waits to send a header, checking for a potential shutdown
|
|
|
+func (s *Session) waitForSend(hdr header) error {
|
|
|
+ errCh := make(chan error, 1)
|
|
|
+ ready := sendReady{Hdr: hdr, Err: errCh}
|
|
|
+ select {
|
|
|
+ case s.sendCh <- ready:
|
|
|
+ case <-s.shutdownCh:
|
|
|
+ return ErrSessionShutdown
|
|
|
+ }
|
|
|
+ select {
|
|
|
+ case err := <-errCh:
|
|
|
+ return err
|
|
|
+ case <-s.shutdownCh:
|
|
|
+ return ErrSessionShutdown
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// sendNoWait does a send without waiting
|
|
|
+func (s *Session) sendNoWait(hdr header) error {
|
|
|
+ select {
|
|
|
+ case s.sendCh <- sendReady{Hdr: hdr}:
|
|
|
+ return nil
|
|
|
+ case <-s.shutdownCh:
|
|
|
+ return ErrSessionShutdown
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// send is a long running goroutine that sends data
|
|
|
+func (s *Session) send() {
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case ready := <-s.sendCh:
|
|
|
+ // Send data from a stream if ready
|
|
|
+ if ready.StreamID != 0 {
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ // Send a header if ready
|
|
|
+ if ready.Hdr != nil {
|
|
|
+ sent := 0
|
|
|
+ for sent < len(ready.Hdr) {
|
|
|
+ n, err := s.conn.Write(ready.Hdr[sent:])
|
|
|
+ if err != nil {
|
|
|
+ s.exitErr(err)
|
|
|
+ asyncSendErr(ready.Err, err)
|
|
|
+ }
|
|
|
+ sent += n
|
|
|
+ }
|
|
|
+ }
|
|
|
+ asyncSendErr(ready.Err, nil)
|
|
|
+ case <-s.shutdownCh:
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// recv is a long running goroutine that accepts new data
|
|
|
+func (s *Session) recv() {
|
|
|
+ hdr := header(make([]byte, headerSize))
|
|
|
+ for {
|
|
|
+ // Read the header
|
|
|
+ if _, err := io.ReadFull(s.conn, hdr); err != nil {
|
|
|
+ s.exitErr(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Verify the version
|
|
|
+ if hdr.Version() != protoVersion {
|
|
|
+ s.exitErr(ErrInvalidVersion)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Switch on the type
|
|
|
+ msgType := hdr.MsgType()
|
|
|
+ switch msgType {
|
|
|
+ case typeData:
|
|
|
+ s.handleData(hdr)
|
|
|
+ case typeWindowUpdate:
|
|
|
+ s.handleWindowUpdate(hdr)
|
|
|
+ case typePing:
|
|
|
+ s.handlePing(hdr)
|
|
|
+ case typeGoAway:
|
|
|
+ s.handleGoAway(hdr)
|
|
|
+ default:
|
|
|
+ s.exitErr(ErrInvalidMsgType)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// handleData is invokde for a typeData frame
|
|
|
+func (s *Session) handleData(hdr header) {
|
|
|
+ flags := hdr.Flags()
|
|
|
+
|
|
|
+ // Check for a new stream creation
|
|
|
+ if flags&flagSYN == flagSYN {
|
|
|
+ s.createStream(hdr.StreamID())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// handleWindowUpdate is invokde for a typeWindowUpdate frame
|
|
|
+func (s *Session) handleWindowUpdate(hdr header) {
|
|
|
+ flags := hdr.Flags()
|
|
|
+
|
|
|
+ // Check for a new stream creation
|
|
|
+ if flags&flagSYN == flagSYN {
|
|
|
+ s.createStream(hdr.StreamID())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// handlePing is invokde for a typePing frame
|
|
|
+func (s *Session) handlePing(hdr header) {
|
|
|
+ flags := hdr.Flags()
|
|
|
+ pingID := hdr.Length()
|
|
|
+
|
|
|
+ // Check if this is a query, respond back
|
|
|
+ if flags&flagSYN == flagSYN {
|
|
|
+ hdr := header(make([]byte, headerSize))
|
|
|
+ hdr.encode(typePing, flagACK, 0, pingID)
|
|
|
+ s.sendNoWait(hdr)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Handle a response
|
|
|
+ s.pingLock.Lock()
|
|
|
+ ch := s.pings[pingID]
|
|
|
+ if ch != nil {
|
|
|
+ delete(s.pings, pingID)
|
|
|
+ close(ch)
|
|
|
+ }
|
|
|
+ s.pingLock.Unlock()
|
|
|
+}
|
|
|
+
|
|
|
+// handleGoAway is invokde for a typeGoAway frame
|
|
|
+func (s *Session) handleGoAway(hdr header) {
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+// exitErr is used to handle an error that is causing
|
|
|
+// the listener to exit.
|
|
|
+func (s *Session) exitErr(err error) {
|
|
|
+}
|
|
|
+
|
|
|
+// goAway is used to send a goAway message
|
|
|
+func (s *Session) goAway(reason uint32) {
|
|
|
+ hdr := header(make([]byte, headerSize))
|
|
|
+ hdr.encode(typeGoAway, 0, 0, reason)
|
|
|
+ s.sendNoWait(hdr)
|
|
|
+}
|
|
|
+
|
|
|
+// createStream is used to create a new stream
|
|
|
+func (s *Session) createStream(id uint32) {
|
|
|
+ // TODO
|
|
|
+}
|