Armon Dadgar преди 10 години
родител
ревизия
b5c8b56dbf
променени са 8 файла, в които са добавени 664 реда и са изтрити 11 реда
  1. 8 7
      README.md
  2. 54 4
      const.go
  3. 4 0
      const_test.go
  4. 63 0
      mux.go
  5. 374 0
      session.go
  6. 55 0
      session_test.go
  7. 94 0
      stream.go
  8. 12 0
      util.go

+ 8 - 7
README.md

@@ -17,14 +17,15 @@ Yamux uses a streaming connection underneath, but imposes a message
 framing so that it can be shared between many logical streams. Each
 frame contains a header like:
 
-* Version (4 bits)
-* Type (4 bits)
-* Flags (8 bits)
+* Version (8 bits)
+* Type (8 bits)
+* Flags (16 bits)
 * StreamID (32 bits)
 * Length (32 bits)
 
-This means that each header has a 10 byte overhead. Each field
-is described below:
+This means that each header has a 12 byte overhead.
+All fields are encoded in network order (big endian).
+Each field is described below:
 
 ## Version Field
 
@@ -53,10 +54,10 @@ The flags field is used to provide additional information related
 to the message type. The following flags are supported:
 
 * 0x1 SYN - Signals the start of a new stream. May be sent with a data or
-  window update message.
+  window update message. Also sent with a ping to indicate outbound.
 
 * 0x2 ACK - Acknowledges the start of a new stream. May be sent with a data
-  or window update message.
+  or window update message. Also sent with a ping to indicate response.
 
 * 0x4 FIN - Performs a half-close of a new stream. May be sent with a data
   message or window update.

+ 54 - 4
const.go

@@ -1,8 +1,13 @@
 package yamux
 
+import (
+	"encoding/binary"
+	"fmt"
+)
+
 const (
 	// protoVersion is the only version we support
-	protoVersion = 0
+	protoVersion uint8 = 0
 )
 
 const (
@@ -28,7 +33,7 @@ const (
 const (
 	// SYN is sent to signal a new stream. May
 	// be sent with a data payload
-	flagSYN uint8 = 1 << iota
+	flagSYN uint16 = 1 << iota
 
 	// ACK is sent to acknowledge a new stream. May
 	// be sent with a data payload
@@ -48,10 +53,10 @@ const (
 
 const (
 	// initialSessionWindow is the initial session window size
-	initialSessionWindow = 2 * 1024 * 1024
+	initialSessionWindow uint32 = 2 * 1024 * 1024
 
 	// initialStreamWindow is the initial stream window size
-	initialStreamWindow = 256 * 1024
+	initialStreamWindow uint32 = 256 * 1024
 )
 
 const (
@@ -64,3 +69,48 @@ const (
 	// goAwayInternalErr sent on an internal error
 	goAwayInternalErr
 )
+
+const (
+	sizeOfVersion  = 1
+	sizeOfType     = 1
+	sizeOfFlags    = 2
+	sizeOfStreamID = 4
+	sizeOfLength   = 4
+	headerSize     = sizeOfVersion + sizeOfType + sizeOfFlags +
+		sizeOfStreamID + sizeOfLength
+)
+
+type header []byte
+
+func (h header) Version() uint8 {
+	return h[0]
+}
+
+func (h header) MsgType() uint8 {
+	return h[1]
+}
+
+func (h header) Flags() uint16 {
+	return binary.BigEndian.Uint16(h[2:4])
+}
+
+func (h header) StreamID() uint32 {
+	return binary.BigEndian.Uint32(h[4:8])
+}
+
+func (h header) Length() uint32 {
+	return binary.BigEndian.Uint32(h[8:12])
+}
+
+func (h header) String() string {
+	return fmt.Sprintf("Vsn:%d Type:%d Flags:%d StreamID:%d Length:%d",
+		h.Version(), h.MsgType(), h.Flags(), h.StreamID(), h.Length())
+}
+
+func (h header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) {
+	h[0] = protoVersion
+	h[1] = msgType
+	binary.BigEndian.PutUint16(h[2:4], flags)
+	binary.BigEndian.PutUint32(h[4:8], streamID)
+	binary.BigEndian.PutUint32(h[8:12], length)
+}

+ 4 - 0
const_test.go

@@ -47,4 +47,8 @@ func TestConst(t *testing.T) {
 	if goAwayInternalErr != 2 {
 		t.Fatalf("bad: %v", goAwayInternalErr)
 	}
+
+	if headerSize != 12 {
+		t.Fatalf("bad header size")
+	}
 }

+ 63 - 0
mux.go

@@ -0,0 +1,63 @@
+package yamux
+
+import (
+	"io"
+	"time"
+)
+
+// Config is used to tune the Yamux session
+type Config struct {
+	// AcceptBacklog is used to limit how many streams may be
+	// waiting an accept.
+	AcceptBacklog int
+
+	// EnableCompression is used to control if we compress
+	// outgoing data. We have no control over incoming data.
+	EnableCompression bool
+
+	// EnableKeepalive is used to do a period keep alive
+	// messages using a ping.
+	EnableKeepAlive bool
+
+	// KeepAliveInterval is how often to perform the keep alive
+	KeepAliveInterval time.Duration
+
+	// MaxSessionWindowSize is used to control the maximum
+	// window size that we allow for a session.
+	MaxSessionWindowSize uint32
+
+	// MaxStreamWindowSize is used to control the maximum
+	// window size that we allow for a stream.
+	MaxStreamWindowSize uint32
+}
+
+// DefaultConfig is used to return a default configuration
+func DefaultConfig() *Config {
+	return &Config{
+		AcceptBacklog:        256,
+		EnableCompression:    true,
+		EnableKeepAlive:      true,
+		KeepAliveInterval:    30 * time.Second,
+		MaxSessionWindowSize: initialSessionWindow,
+		MaxStreamWindowSize:  initialStreamWindow,
+	}
+}
+
+// Server is used to initialize a new server-side connection.
+// There must be at most one server-side connection. If a nil config is
+// provided, the DefaultConfiguration will be used.
+func Server(conn io.ReadWriteCloser, config *Config) *Session {
+	if config == nil {
+		config = DefaultConfig()
+	}
+	return newSession(config, conn, false)
+}
+
+// Client is used to initialize a new client-side connection.
+// There must be at most one client-side connection.
+func Client(conn io.ReadWriteCloser, config *Config) *Session {
+	if config == nil {
+		config = DefaultConfig()
+	}
+	return newSession(config, conn, true)
+}

+ 374 - 0
session.go

@@ -0,0 +1,374 @@
+package yamux
+
+import (
+	"fmt"
+	"io"
+	"net"
+	"sync"
+	"time"
+)
+
+var (
+	// ErrInvalidVersion means we received a frame with an
+	// invalid version
+	ErrInvalidVersion = fmt.Errorf("invalid protocol version")
+
+	// ErrInvalidMsgType means we received a frame with an
+	// invalid message type
+	ErrInvalidMsgType = fmt.Errorf("invalid msg type")
+
+	// ErrSessionShutdown is used if there is a shutdown during
+	// an operation
+	ErrSessionShutdown = fmt.Errorf("session shutdown")
+)
+
+// Session is used to wrap a reliable ordered connection and to
+// multiplex it into multiple streams.
+type Session struct {
+	// client is true if we are a client size connection
+	client bool
+
+	// config holds our configuration
+	config *Config
+
+	// conn is the underlying connection
+	conn io.ReadWriteCloser
+
+	// nextStreamID is the next stream we should
+	// send. This depends if we are a client/server.
+	nextStreamID uint32
+
+	// pings is used to track inflight pings
+	pings    map[uint32]chan struct{}
+	pingID   uint32
+	pingLock sync.Mutex
+
+	// streams maps a stream id to a stream
+	streams map[uint32]*Stream
+
+	// acceptCh is used to pass ready streams to the client
+	acceptCh chan *Stream
+
+	// sendCh is used to mark a stream as ready to send,
+	// or to send a header out directly.
+	sendCh chan sendReady
+
+	// shutdown is used to safely close a session
+	shutdown     bool
+	shutdownErr  error
+	shutdownCh   chan struct{}
+	shutdownLock sync.Mutex
+}
+
+// hasAddr is used to get the address from the underlying connection
+type hasAddr interface {
+	LocalAddr() net.Addr
+	RemoteAddr() net.Addr
+}
+
+// yamuxAddr is used when we cannot get the underlying address
+type yamuxAddr struct {
+	Addr string
+}
+
+func (*yamuxAddr) Network() string {
+	return "yamux"
+}
+
+func (y *yamuxAddr) String() string {
+	return fmt.Sprintf("yamux:%s", y.Addr)
+}
+
+// sendReady is used to either mark a stream as ready
+// or to directly send a header
+type sendReady struct {
+	StreamID uint32
+	Hdr      []byte
+	Err      chan error
+}
+
+// newSession is used to construct a new session
+func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
+	s := &Session{
+		client:     client,
+		config:     config,
+		conn:       conn,
+		pings:      make(map[uint32]chan struct{}),
+		streams:    make(map[uint32]*Stream),
+		acceptCh:   make(chan *Stream, config.AcceptBacklog),
+		sendCh:     make(chan sendReady, 64),
+		shutdownCh: make(chan struct{}),
+	}
+	if client {
+		s.nextStreamID = 1
+	} else {
+		s.nextStreamID = 2
+	}
+	go s.recv()
+	go s.send()
+	if config.EnableKeepAlive {
+		go s.keepalive()
+	}
+	return s
+}
+
+// Open is used to create a new stream
+func (s *Session) Open() (*Stream, error) {
+	return nil, nil
+}
+
+// Accept is used to block until the next available stream
+// is ready to be accepted.
+func (s *Session) Accept() (net.Conn, error) {
+	return s.AcceptStream()
+}
+
+// AcceptStream is used to block until the next available stream
+// is ready to be accepted.
+func (s *Session) AcceptStream() (*Stream, error) {
+	select {
+	case stream := <-s.acceptCh:
+		return stream, nil
+	case <-s.shutdownCh:
+		return nil, s.shutdownErr
+	}
+}
+
+// Close is used to close the session and all streams.
+// Attempts to send a GoAway before closing the connection.
+func (s *Session) Close() error {
+	s.shutdownLock.Lock()
+	defer s.shutdownLock.Unlock()
+
+	if s.shutdown {
+		return nil
+	}
+	s.shutdown = true
+	close(s.shutdownCh)
+	s.conn.Close()
+	return nil
+}
+
+// Addr is used to get the address of the listener.
+func (s *Session) Addr() net.Addr {
+	return s.LocalAddr()
+}
+
+// LocalAddr is used to get the local address of the
+// underlying connection.
+func (s *Session) LocalAddr() net.Addr {
+	addr, ok := s.conn.(hasAddr)
+	if !ok {
+		return &yamuxAddr{"local"}
+	}
+	return addr.LocalAddr()
+}
+
+// RemoteAddr is used to get the address of remote end
+// of the underlying connection
+func (s *Session) RemoteAddr() net.Addr {
+	addr, ok := s.conn.(hasAddr)
+	if !ok {
+		return &yamuxAddr{"remote"}
+	}
+	return addr.RemoteAddr()
+}
+
+// Ping is used to measure the RTT response time
+func (s *Session) Ping() (time.Duration, error) {
+	// Get a channel for the ping
+	ch := make(chan struct{})
+
+	// Get a new ping id, mark as pending
+	s.pingLock.Lock()
+	id := s.pingID
+	s.pingID++
+	s.pings[id] = ch
+	s.pingLock.Unlock()
+
+	// Send the ping request
+	hdr := header(make([]byte, headerSize))
+	hdr.encode(typePing, flagSYN, 0, id)
+	if err := s.waitForSend(hdr); err != nil {
+		return 0, err
+	}
+
+	// Wait for a response
+	start := time.Now()
+	select {
+	case <-ch:
+	case <-s.shutdownCh:
+		return 0, ErrSessionShutdown
+	}
+
+	// Compute the RTT
+	return time.Now().Sub(start), nil
+}
+
+// keepalive is a long running goroutine that periodically does
+// a ping to keep the connection alive.
+func (s *Session) keepalive() {
+	for {
+		select {
+		case <-time.After(s.config.KeepAliveInterval):
+			s.Ping()
+		case <-s.shutdownCh:
+			return
+		}
+	}
+}
+
+// waitForSend waits to send a header, checking for a potential shutdown
+func (s *Session) waitForSend(hdr header) error {
+	errCh := make(chan error, 1)
+	ready := sendReady{Hdr: hdr, Err: errCh}
+	select {
+	case s.sendCh <- ready:
+	case <-s.shutdownCh:
+		return ErrSessionShutdown
+	}
+	select {
+	case err := <-errCh:
+		return err
+	case <-s.shutdownCh:
+		return ErrSessionShutdown
+	}
+}
+
+// sendNoWait does a send without waiting
+func (s *Session) sendNoWait(hdr header) error {
+	select {
+	case s.sendCh <- sendReady{Hdr: hdr}:
+		return nil
+	case <-s.shutdownCh:
+		return ErrSessionShutdown
+	}
+}
+
+// send is a long running goroutine that sends data
+func (s *Session) send() {
+	for {
+		select {
+		case ready := <-s.sendCh:
+			// Send data from a stream if ready
+			if ready.StreamID != 0 {
+
+			}
+
+			// Send a header if ready
+			if ready.Hdr != nil {
+				sent := 0
+				for sent < len(ready.Hdr) {
+					n, err := s.conn.Write(ready.Hdr[sent:])
+					if err != nil {
+						s.exitErr(err)
+						asyncSendErr(ready.Err, err)
+					}
+					sent += n
+				}
+			}
+			asyncSendErr(ready.Err, nil)
+		case <-s.shutdownCh:
+			return
+		}
+	}
+}
+
+// recv is a long running goroutine that accepts new data
+func (s *Session) recv() {
+	hdr := header(make([]byte, headerSize))
+	for {
+		// Read the header
+		if _, err := io.ReadFull(s.conn, hdr); err != nil {
+			s.exitErr(err)
+			return
+		}
+
+		// Verify the version
+		if hdr.Version() != protoVersion {
+			s.exitErr(ErrInvalidVersion)
+			return
+		}
+
+		// Switch on the type
+		msgType := hdr.MsgType()
+		switch msgType {
+		case typeData:
+			s.handleData(hdr)
+		case typeWindowUpdate:
+			s.handleWindowUpdate(hdr)
+		case typePing:
+			s.handlePing(hdr)
+		case typeGoAway:
+			s.handleGoAway(hdr)
+		default:
+			s.exitErr(ErrInvalidMsgType)
+			return
+		}
+	}
+}
+
+// handleData is invokde for a typeData frame
+func (s *Session) handleData(hdr header) {
+	flags := hdr.Flags()
+
+	// Check for a new stream creation
+	if flags&flagSYN == flagSYN {
+		s.createStream(hdr.StreamID())
+	}
+}
+
+// handleWindowUpdate is invokde for a typeWindowUpdate frame
+func (s *Session) handleWindowUpdate(hdr header) {
+	flags := hdr.Flags()
+
+	// Check for a new stream creation
+	if flags&flagSYN == flagSYN {
+		s.createStream(hdr.StreamID())
+	}
+}
+
+// handlePing is invokde for a typePing frame
+func (s *Session) handlePing(hdr header) {
+	flags := hdr.Flags()
+	pingID := hdr.Length()
+
+	// Check if this is a query, respond back
+	if flags&flagSYN == flagSYN {
+		hdr := header(make([]byte, headerSize))
+		hdr.encode(typePing, flagACK, 0, pingID)
+		s.sendNoWait(hdr)
+		return
+	}
+
+	// Handle a response
+	s.pingLock.Lock()
+	ch := s.pings[pingID]
+	if ch != nil {
+		delete(s.pings, pingID)
+		close(ch)
+	}
+	s.pingLock.Unlock()
+}
+
+// handleGoAway is invokde for a typeGoAway frame
+func (s *Session) handleGoAway(hdr header) {
+
+}
+
+// exitErr is used to handle an error that is causing
+// the listener to exit.
+func (s *Session) exitErr(err error) {
+}
+
+// goAway is used to send a goAway message
+func (s *Session) goAway(reason uint32) {
+	hdr := header(make([]byte, headerSize))
+	hdr.encode(typeGoAway, 0, 0, reason)
+	s.sendNoWait(hdr)
+}
+
+// createStream is used to create a new stream
+func (s *Session) createStream(id uint32) {
+	// TODO
+}

+ 55 - 0
session_test.go

@@ -0,0 +1,55 @@
+package yamux
+
+import (
+	"io"
+	"testing"
+)
+
+type pipeConn struct {
+	reader *io.PipeReader
+	writer *io.PipeWriter
+}
+
+func (p *pipeConn) Read(b []byte) (int, error) {
+	return p.reader.Read(b)
+}
+
+func (p *pipeConn) Write(b []byte) (int, error) {
+	return p.writer.Write(b)
+}
+
+func (p *pipeConn) Close() error {
+	p.reader.Close()
+	return p.writer.Close()
+}
+
+func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
+	read1, write1 := io.Pipe()
+	read2, write2 := io.Pipe()
+	return &pipeConn{read1, write2}, &pipeConn{read2, write1}
+}
+
+func TestPing(t *testing.T) {
+	conn1, conn2 := testConn()
+	client := Client(conn1, nil)
+	defer client.Close()
+
+	server := Server(conn2, nil)
+	defer server.Close()
+
+	rtt, err := client.Ping()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if rtt == 0 {
+		t.Fatalf("bad: %v", rtt)
+	}
+
+	rtt, err = server.Ping()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if rtt == 0 {
+		t.Fatalf("bad: %v", rtt)
+	}
+}

+ 94 - 0
stream.go

@@ -0,0 +1,94 @@
+package yamux
+
+import (
+	"bytes"
+	"net"
+	"time"
+)
+
+type streamState int
+
+const (
+	streamSYNSent streamState = iota
+	streamSYNReceived
+	streamEstablished
+	streamLocalClose
+	streamRemoteClose
+	streamClosed
+)
+
+// Stream is used to represent a logical stream
+// within a session.
+type Stream struct {
+	id      uint32
+	session *Session
+
+	state streamState
+
+	recvBuf    bytes.Buffer
+	recvWindow uint32
+
+	sendBuf    bytes.Buffer
+	sendWindow uint32
+
+	readDeadline  time.Time
+	writeDeadline time.Time
+}
+
+// Session returns the associated stream session
+func (s *Stream) Session() *Session {
+	return s.session
+}
+
+// StreamID returns the ID of this stream
+func (s *Stream) StreamID() uint32 {
+	return s.id
+}
+
+// Read is used to read from the stream
+func (s *Stream) Read(b []byte) (int, error) {
+	return 0, nil
+}
+
+// Write is used to write to the stream
+func (s *Stream) Write(b []byte) (int, error) {
+	return 0, nil
+}
+
+// Close is used to close the stream
+func (s *Stream) Close() error {
+	return nil
+}
+
+// LocalAddr returns the local address
+func (s *Stream) LocalAddr() net.Addr {
+	return s.session.LocalAddr()
+}
+
+// LocalAddr returns the remote address
+func (s *Stream) RemoteAddr() net.Addr {
+	return s.session.RemoteAddr()
+}
+
+// SetDeadline sets the read and write deadlines
+func (s *Stream) SetDeadline(t time.Time) error {
+	if err := s.SetReadDeadline(t); err != nil {
+		return err
+	}
+	if err := s.SetWriteDeadline(t); err != nil {
+		return err
+	}
+	return nil
+}
+
+// SetReadDeadline sets the deadline for future Read calls.
+func (s *Stream) SetReadDeadline(t time.Time) error {
+	s.readDeadline = t
+	return nil
+}
+
+// SetWriteDeadline sets the deadline for future Write calls
+func (s *Stream) SetWriteDeadline(t time.Time) error {
+	s.writeDeadline = t
+	return nil
+}

+ 12 - 0
util.go

@@ -0,0 +1,12 @@
+package yamux
+
+// asyncSendErr is used to try an async send of an error
+func asyncSendErr(ch chan error, err error) {
+	if ch == nil {
+		return
+	}
+	select {
+	case ch <- err:
+	default:
+	}
+}