Bläddra i källkod

Adding more tests

Armon Dadgar 10 år sedan
förälder
incheckning
bb02fb9068
3 ändrade filer med 753 tillägg och 50 borttagningar
  1. 219 42
      session.go
  2. 229 0
      session_test.go
  3. 305 8
      stream.go

+ 219 - 42
session.go

@@ -3,6 +3,7 @@ package yamux
 import (
 	"fmt"
 	"io"
+	"math"
 	"net"
 	"sync"
 	"time"
@@ -20,6 +21,33 @@ var (
 	// ErrSessionShutdown is used if there is a shutdown during
 	// an operation
 	ErrSessionShutdown = fmt.Errorf("session shutdown")
+
+	// ErrStreamsExhausted is returned if we have no more
+	// stream ids to issue
+	ErrStreamsExhausted = fmt.Errorf("streams exhausted")
+
+	// ErrDuplicateStream is used if a duplicate stream is
+	// opened inbound
+	ErrDuplicateStream = fmt.Errorf("duplicate stream initiated")
+
+	// ErrMissingStream indicates a stream was named which
+	// does not exist.
+	ErrMissingStream = fmt.Errorf("missing stream references")
+
+	// ErrReceiveWindowExceeded indicates the window was exceeded
+	ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded")
+
+	// ErrTimeout is used when we reach an IO deadline
+	ErrTimeout = fmt.Errorf("i/o deadline reached")
+
+	// ErrStreamClosed is returned when using a closed stream
+	ErrStreamClosed = fmt.Errorf("stream closed")
+
+	// ErrUnexpectedFlag is set when we get an unexpected flag
+	ErrUnexpectedFlag = fmt.Errorf("unexpected flag")
+
+	// ErrRemoteGoAway is used when we get a go away from the other side
+	ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections")
 )
 
 // Session is used to wrap a reliable ordered connection and to
@@ -34,17 +62,26 @@ type Session struct {
 	// conn is the underlying connection
 	conn io.ReadWriteCloser
 
-	// nextStreamID is the next stream we should
-	// send. This depends if we are a client/server.
-	nextStreamID uint32
-
 	// pings is used to track inflight pings
 	pings    map[uint32]chan struct{}
 	pingID   uint32
 	pingLock sync.Mutex
 
+	// remoteGoAway indicates the remote side does
+	// not want futher connections
+	remoteGoAway bool
+
+	// localGoAway indicates that we should stop
+	// accepting futher connections
+	localGoAway bool
+
+	// nextStreamID is the next stream we should
+	// send. This depends if we are a client/server.
+	nextStreamID uint32
+
 	// streams maps a stream id to a stream
-	streams map[uint32]*Stream
+	streams    map[uint32]*Stream
+	streamLock sync.RWMutex
 
 	// acceptCh is used to pass ready streams to the client
 	acceptCh chan *Stream
@@ -82,9 +119,9 @@ func (y *yamuxAddr) String() string {
 // sendReady is used to either mark a stream as ready
 // or to directly send a header
 type sendReady struct {
-	StreamID uint32
-	Hdr      []byte
-	Err      chan error
+	Hdr  []byte
+	Body io.Reader
+	Err  chan error
 }
 
 // newSession is used to construct a new session
@@ -112,9 +149,41 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 	return s
 }
 
+// isShutdown does a safe check to see if we have shutdown
+func (s *Session) isShutdown() bool {
+	select {
+	case <-s.shutdownCh:
+		return true
+	default:
+		return false
+	}
+}
+
 // Open is used to create a new stream
 func (s *Session) Open() (*Stream, error) {
-	return nil, nil
+	if s.isShutdown() {
+		return nil, ErrSessionShutdown
+	}
+	if s.remoteGoAway {
+		return nil, ErrRemoteGoAway
+	}
+
+	s.streamLock.Lock()
+	defer s.streamLock.Unlock()
+
+	// Check if we've exhaused the streams
+	id := s.nextStreamID
+	if id >= math.MaxUint32-1 {
+		return nil, ErrStreamsExhausted
+	}
+	s.nextStreamID += 2
+
+	// Register the stream
+	stream := newStream(s, id, streamInit)
+	s.streams[id] = stream
+
+	// Send the window update to create
+	return stream, stream.sendWindowUpdate()
 }
 
 // Accept is used to block until the next available stream
@@ -144,8 +213,25 @@ func (s *Session) Close() error {
 		return nil
 	}
 	s.shutdown = true
+	if s.shutdownErr == nil {
+		s.shutdownErr = ErrSessionShutdown
+	}
 	close(s.shutdownCh)
 	s.conn.Close()
+
+	s.streamLock.Lock()
+	defer s.streamLock.Unlock()
+	for _, stream := range s.streams {
+		stream.forceClose()
+	}
+	return nil
+}
+
+// GoAway can be used to prevent accepting further
+// connections. It does not close the underlying conn.
+func (s *Session) GoAway() error {
+	s.localGoAway = true
+	s.goAway(goAwayNormal)
 	return nil
 }
 
@@ -189,7 +275,7 @@ func (s *Session) Ping() (time.Duration, error) {
 	// Send the ping request
 	hdr := header(make([]byte, headerSize))
 	hdr.encode(typePing, flagSYN, 0, id)
-	if err := s.waitForSend(hdr); err != nil {
+	if err := s.waitForSend(hdr, nil); err != nil {
 		return 0, err
 	}
 
@@ -219,9 +305,9 @@ func (s *Session) keepalive() {
 }
 
 // waitForSend waits to send a header, checking for a potential shutdown
-func (s *Session) waitForSend(hdr header) error {
+func (s *Session) waitForSend(hdr header, body io.Reader) error {
 	errCh := make(chan error, 1)
-	ready := sendReady{Hdr: hdr, Err: errCh}
+	ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
 	select {
 	case s.sendCh <- ready:
 	case <-s.shutdownCh:
@@ -250,11 +336,6 @@ func (s *Session) send() {
 	for {
 		select {
 		case ready := <-s.sendCh:
-			// Send data from a stream if ready
-			if ready.StreamID != 0 {
-
-			}
-
 			// Send a header if ready
 			if ready.Hdr != nil {
 				sent := 0
@@ -263,10 +344,23 @@ func (s *Session) send() {
 					if err != nil {
 						s.exitErr(err)
 						asyncSendErr(ready.Err, err)
+						return
 					}
 					sent += n
 				}
 			}
+
+			// Send data from a body if given
+			if ready.Body != nil {
+				_, err := io.Copy(s.conn, ready.Body)
+				if err != nil {
+					s.exitErr(err)
+					asyncSendErr(ready.Err, err)
+					return
+				}
+			}
+
+			// No error, successful send
 			asyncSendErr(ready.Err, nil)
 		case <-s.shutdownCh:
 			return
@@ -277,7 +371,7 @@ func (s *Session) send() {
 // recv is a long running goroutine that accepts new data
 func (s *Session) recv() {
 	hdr := header(make([]byte, headerSize))
-	for {
+	for !s.isShutdown() {
 		// Read the header
 		if _, err := io.ReadFull(s.conn, hdr); err != nil {
 			s.exitErr(err)
@@ -294,13 +388,22 @@ func (s *Session) recv() {
 		msgType := hdr.MsgType()
 		switch msgType {
 		case typeData:
-			s.handleData(hdr)
+			fallthrough
 		case typeWindowUpdate:
-			s.handleWindowUpdate(hdr)
-		case typePing:
-			s.handlePing(hdr)
+			if err := s.handleStreamMessage(hdr); err != nil {
+				s.exitErr(err)
+				return
+			}
 		case typeGoAway:
-			s.handleGoAway(hdr)
+			if err := s.handleGoAway(hdr); err != nil {
+				s.exitErr(err)
+				return
+			}
+		case typePing:
+			if err := s.handlePing(hdr); err != nil {
+				s.exitErr(err)
+				return
+			}
 		default:
 			s.exitErr(ErrInvalidMsgType)
 			return
@@ -308,28 +411,46 @@ func (s *Session) recv() {
 	}
 }
 
-// handleData is invokde for a typeData frame
-func (s *Session) handleData(hdr header) {
-	flags := hdr.Flags()
-
+// handleStreamMessage handles either a data or window update frame
+func (s *Session) handleStreamMessage(hdr header) error {
 	// Check for a new stream creation
+	id := hdr.StreamID()
+	flags := hdr.Flags()
 	if flags&flagSYN == flagSYN {
-		s.createStream(hdr.StreamID())
+		if err := s.incomingStream(id); err != nil {
+			return err
+		}
 	}
-}
 
-// handleWindowUpdate is invokde for a typeWindowUpdate frame
-func (s *Session) handleWindowUpdate(hdr header) {
-	flags := hdr.Flags()
+	// Get the stream
+	s.streamLock.RLock()
+	stream := s.streams[id]
+	s.streamLock.RUnlock()
 
-	// Check for a new stream creation
-	if flags&flagSYN == flagSYN {
-		s.createStream(hdr.StreamID())
+	// Make sure we have a stream
+	if stream == nil {
+		s.goAway(goAwayProtoErr)
+		return ErrMissingStream
+	}
+
+	// Check if this is a window update
+	if hdr.MsgType() == typeWindowUpdate {
+		if err := stream.incrSendWindow(hdr, flags); err != nil {
+			s.goAway(goAwayProtoErr)
+			return err
+		}
 	}
+
+	// Read the new data
+	if err := stream.readData(hdr, flags, s.conn); err != nil {
+		s.goAway(goAwayProtoErr)
+		return err
+	}
+	return nil
 }
 
 // handlePing is invokde for a typePing frame
-func (s *Session) handlePing(hdr header) {
+func (s *Session) handlePing(hdr header) error {
 	flags := hdr.Flags()
 	pingID := hdr.Length()
 
@@ -338,7 +459,7 @@ func (s *Session) handlePing(hdr header) {
 		hdr := header(make([]byte, headerSize))
 		hdr.encode(typePing, flagACK, 0, pingID)
 		s.sendNoWait(hdr)
-		return
+		return nil
 	}
 
 	// Handle a response
@@ -349,16 +470,30 @@ func (s *Session) handlePing(hdr header) {
 		close(ch)
 	}
 	s.pingLock.Unlock()
+	return nil
 }
 
 // handleGoAway is invokde for a typeGoAway frame
-func (s *Session) handleGoAway(hdr header) {
-
+func (s *Session) handleGoAway(hdr header) error {
+	code := hdr.Length()
+	switch code {
+	case goAwayNormal:
+		s.remoteGoAway = true
+	case goAwayProtoErr:
+		return fmt.Errorf("yamux protocol error")
+	case goAwayInternalErr:
+		return fmt.Errorf("remote yamux internal error")
+	default:
+		return fmt.Errorf("unexpected go away received")
+	}
+	return nil
 }
 
 // exitErr is used to handle an error that is causing
 // the listener to exit.
 func (s *Session) exitErr(err error) {
+	s.shutdownErr = err
+	s.Close()
 }
 
 // goAway is used to send a goAway message
@@ -368,7 +503,49 @@ func (s *Session) goAway(reason uint32) {
 	s.sendNoWait(hdr)
 }
 
-// createStream is used to create a new stream
-func (s *Session) createStream(id uint32) {
-	// TODO
+// incomingStream is used to create a new incoming stream
+func (s *Session) incomingStream(id uint32) error {
+	// Reject immediately if we are doing a go away
+	if s.localGoAway {
+		hdr := header(make([]byte, headerSize))
+		hdr.encode(typeWindowUpdate, flagRST, id, 0)
+		s.sendNoWait(hdr)
+		return nil
+	}
+
+	s.streamLock.Lock()
+	defer s.streamLock.Unlock()
+
+	// Check if stream already exists
+	if _, ok := s.streams[id]; ok {
+		s.goAway(goAwayProtoErr)
+		s.exitErr(ErrDuplicateStream)
+		return nil
+	}
+
+	// Register the stream
+	stream := newStream(s, id, streamSYNReceived)
+	s.streams[id] = stream
+
+	// Check if we've exceeded the backlog
+	select {
+	case s.acceptCh <- stream:
+		return nil
+	default:
+		// Backlog exceeded! RST the stream
+		delete(s.streams, id)
+		stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
+		s.sendNoWait(stream.sendHdr)
+	}
+	return nil
+}
+
+// closeStream is used to close a stream once both sides have
+// issued a close.
+func (s *Session) closeStream(id uint32, withLock bool) {
+	if !withLock {
+		s.streamLock.Lock()
+		defer s.streamLock.Unlock()
+	}
+	delete(s.streams, id)
 }

+ 229 - 0
session_test.go

@@ -2,7 +2,9 @@ package yamux
 
 import (
 	"io"
+	"sync"
 	"testing"
+	"time"
 )
 
 type pipeConn struct {
@@ -53,3 +55,230 @@ func TestPing(t *testing.T) {
 		t.Fatalf("bad: %v", rtt)
 	}
 }
+
+func TestAccept(t *testing.T) {
+	conn1, conn2 := testConn()
+	client := Client(conn1, nil)
+	defer client.Close()
+
+	server := Server(conn2, nil)
+	defer server.Close()
+
+	wg := &sync.WaitGroup{}
+	wg.Add(4)
+
+	go func() {
+		defer wg.Done()
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if id := stream.StreamID(); id != 1 {
+			t.Fatalf("bad: %v", id)
+		}
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := client.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if id := stream.StreamID(); id != 2 {
+			t.Fatalf("bad: %v", id)
+		}
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := server.Open()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if id := stream.StreamID(); id != 2 {
+			t.Fatalf("bad: %v", id)
+		}
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := client.Open()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if id := stream.StreamID(); id != 1 {
+			t.Fatalf("bad: %v", id)
+		}
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	doneCh := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(doneCh)
+	}()
+
+	select {
+	case <-doneCh:
+	case <-time.After(time.Second):
+		panic("timeout")
+	}
+}
+
+func TestSendData_Small(t *testing.T) {
+	conn1, conn2 := testConn()
+	client := Client(conn1, nil)
+	defer client.Close()
+
+	server := Server(conn2, nil)
+	defer server.Close()
+
+	wg := &sync.WaitGroup{}
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+
+		buf := make([]byte, 4)
+		for i := 0; i < 1000; i++ {
+			n, err := stream.Read(buf)
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			}
+			if n != 4 {
+				t.Fatalf("short read: %d", n)
+			}
+			if string(buf) != "test" {
+				t.Fatalf("bad: %s", buf)
+			}
+		}
+
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := client.Open()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+
+		for i := 0; i < 1000; i++ {
+			n, err := stream.Write([]byte("test"))
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			}
+			if n != 4 {
+				t.Fatalf("short write %d", n)
+			}
+		}
+
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	doneCh := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(doneCh)
+	}()
+	select {
+	case <-doneCh:
+	case <-time.After(time.Second):
+		panic("timeout")
+	}
+}
+
+func TestSendData_Large(t *testing.T) {
+	conn1, conn2 := testConn()
+	client := Client(conn1, nil)
+	defer client.Close()
+
+	server := Server(conn2, nil)
+	defer server.Close()
+
+	data := make([]byte, 512*1024)
+	for idx := range data {
+		data[idx] = byte(idx % 256)
+	}
+
+	wg := &sync.WaitGroup{}
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+
+		buf := make([]byte, 4*1024)
+		for i := 0; i < 128; i++ {
+			n, err := stream.Read(buf)
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			}
+			if n != 4*1024 {
+				t.Fatalf("short read: %d", n)
+			}
+			for idx := range buf {
+				if buf[idx] != byte(idx%256) {
+					t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
+				}
+			}
+		}
+
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := client.Open()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+
+		n, err := stream.Write(data)
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if n != len(data) {
+			t.Fatalf("short write %d", n)
+		}
+
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	doneCh := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(doneCh)
+	}()
+	select {
+	case <-doneCh:
+	case <-time.After(time.Second):
+		panic("timeout")
+	}
+}

+ 305 - 8
stream.go

@@ -2,14 +2,19 @@ package yamux
 
 import (
 	"bytes"
+	"compress/lzw"
+	"io"
+	"log"
 	"net"
+	"sync"
 	"time"
 )
 
 type streamState int
 
 const (
-	streamSYNSent streamState = iota
+	streamInit streamState = iota
+	streamSYNSent
 	streamSYNReceived
 	streamEstablished
 	streamLocalClose
@@ -24,17 +29,35 @@ type Stream struct {
 	session *Session
 
 	state streamState
+	lock  sync.Mutex
 
-	recvBuf    bytes.Buffer
-	recvWindow uint32
+	recvBuf bytes.Buffer
+	sendHdr header
 
-	sendBuf    bytes.Buffer
+	recvWindow uint32
 	sendWindow uint32
 
+	notifyCh chan struct{}
+
 	readDeadline  time.Time
 	writeDeadline time.Time
 }
 
+// newStream is used to construct a new stream within
+// a given session for an ID
+func newStream(session *Session, id uint32, state streamState) *Stream {
+	s := &Stream{
+		id:         id,
+		session:    session,
+		state:      state,
+		recvWindow: initialStreamWindow,
+		sendWindow: initialStreamWindow,
+		notifyCh:   make(chan struct{}, 1),
+		sendHdr:    header(make([]byte, headerSize)),
+	}
+	return s
+}
+
 // Session returns the associated stream session
 func (s *Stream) Session() *Session {
 	return s.session
@@ -46,18 +69,212 @@ func (s *Stream) StreamID() uint32 {
 }
 
 // Read is used to read from the stream
-func (s *Stream) Read(b []byte) (int, error) {
-	return 0, nil
+func (s *Stream) Read(b []byte) (n int, err error) {
+START:
+	s.lock.Lock()
+	switch s.state {
+	case streamRemoteClose:
+		fallthrough
+	case streamClosed:
+		if s.recvBuf.Len() == 0 {
+			s.lock.Unlock()
+			return 0, io.EOF
+		}
+	}
+
+	// If there is no data available, block
+	if s.recvBuf.Len() == 0 {
+		s.lock.Unlock()
+		goto WAIT
+	}
+
+	// Read any bytes
+	n, _ = s.recvBuf.Read(b)
+
+	// Send a window update potentially
+	err = s.sendWindowUpdate()
+	s.lock.Unlock()
+	return n, err
+
+WAIT:
+	var timeout <-chan time.Time
+	if !s.readDeadline.IsZero() {
+		delay := s.readDeadline.Sub(time.Now())
+		timeout = time.After(delay)
+	}
+	select {
+	case <-s.notifyCh:
+		goto START
+	case <-timeout:
+		return 0, ErrTimeout
+	}
 }
 
 // Write is used to write to the stream
-func (s *Stream) Write(b []byte) (int, error) {
+func (s *Stream) Write(b []byte) (n int, err error) {
+	total := 0
+	for total < len(b) {
+		n, err := s.write(b[total:])
+		total += n
+		if err != nil {
+			return total, err
+		}
+	}
+	return total, nil
+}
+
+// write is used to write to the stream, may return on
+// a short write.
+func (s *Stream) write(b []byte) (n int, err error) {
+	var flags uint16
+	var max uint32
+	var body io.Reader
+START:
+	s.lock.Lock()
+	switch s.state {
+	case streamLocalClose:
+		fallthrough
+	case streamClosed:
+		s.lock.Unlock()
+		return 0, ErrStreamClosed
+	}
+
+	// If there is no data available, block
+	if s.sendWindow == 0 {
+		s.lock.Unlock()
+		goto WAIT
+	}
+
+	// Determine the flags if any
+	flags = s.sendFlags()
+
+	// Send up to our send window
+	max = min(s.sendWindow, uint32(len(b)))
+	body = bytes.NewReader(b[:max])
+
+	// TODO: Compress
+
+	// Send the header
+	s.sendHdr.encode(typeData, flags, s.id, max)
+	if err := s.session.waitForSend(s.sendHdr, body); err != nil {
+		s.lock.Unlock()
+		return 0, err
+	}
+
+	// Reduce our send window
+	s.sendWindow -= max
+
+	// Unlock
+	s.lock.Unlock()
+	return int(max), err
+
+WAIT:
+	var timeout <-chan time.Time
+	if !s.writeDeadline.IsZero() {
+		delay := s.writeDeadline.Sub(time.Now())
+		timeout = time.After(delay)
+	}
+	select {
+	case <-s.notifyCh:
+		goto START
+	case <-timeout:
+		return 0, ErrTimeout
+	}
 	return 0, nil
 }
 
+// sendFlags determines any flags that are appropriate
+// based on the current stream state
+func (s *Stream) sendFlags() uint16 {
+	// Determine the flags if any
+	var flags uint16
+	switch s.state {
+	case streamInit:
+		flags |= flagSYN
+		s.state = streamSYNSent
+	case streamSYNReceived:
+		flags |= flagACK
+		s.state = streamEstablished
+	}
+	return flags
+}
+
+// sendWindowUpdate potentially sends a window update enabling
+// further writes to take place. Must be invoked with the lock.
+func (s *Stream) sendWindowUpdate() error {
+	// Determine the delta update
+	max := s.session.config.MaxStreamWindowSize
+	delta := max - s.recvWindow
+
+	// Determine the flags if any
+	flags := s.sendFlags()
+
+	// Check if we can omit the update
+	if delta < (max/2) && flags == 0 {
+		return nil
+	}
+
+	// Send the header
+	s.sendHdr.encode(typeWindowUpdate, flags, s.id, delta)
+	if err := s.session.waitForSend(s.sendHdr, nil); err != nil {
+		return err
+	}
+	log.Printf("Window Update %d +%d", s.id, delta)
+
+	// Update our window
+	s.recvWindow += delta
+	return nil
+}
+
+// sendClose is used to send a FIN
+func (s *Stream) sendClose() error {
+	flags := s.sendFlags()
+	flags |= flagFIN
+	s.sendHdr.encode(typeWindowUpdate, flags, s.id, 0)
+	if err := s.session.waitForSend(s.sendHdr, nil); err != nil {
+		return err
+	}
+	return nil
+}
+
 // Close is used to close the stream
 func (s *Stream) Close() error {
-	return nil
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	switch s.state {
+	// Local or full close means nothing to do
+	case streamLocalClose:
+		fallthrough
+	case streamClosed:
+		return nil
+
+	// Remote close, weneed to send FIN and we are done
+	case streamRemoteClose:
+		s.state = streamClosed
+		s.session.closeStream(s.id, false)
+		s.sendClose()
+		return nil
+
+	// Opened means we need to signal a close
+	case streamSYNSent:
+		fallthrough
+	case streamSYNReceived:
+		fallthrough
+	case streamEstablished:
+		s.state = streamLocalClose
+		s.sendClose()
+		return nil
+	}
+	panic("unhandled state")
+}
+
+// forceClose is used for when the session is exiting
+func (s *Stream) forceClose() {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+	s.state = streamClosed
+	asyncNotify(s.notifyCh)
 }
 
 // LocalAddr returns the local address
@@ -92,3 +309,83 @@ func (s *Stream) SetWriteDeadline(t time.Time) error {
 	s.writeDeadline = t
 	return nil
 }
+
+// processFlags is used to update the state of the stream
+// based on set flags, if any. Lock must be held
+func (s *Stream) processFlags(flags uint16) error {
+	if flags&flagACK == flagACK {
+		if s.state == streamSYNSent {
+			s.state = streamEstablished
+		}
+
+	} else if flags&flagFIN == flagFIN {
+		switch s.state {
+		case streamSYNSent:
+			fallthrough
+		case streamSYNReceived:
+			fallthrough
+		case streamEstablished:
+			s.state = streamRemoteClose
+		case streamLocalClose:
+			s.state = streamClosed
+			s.session.closeStream(s.id, true)
+		default:
+			return ErrUnexpectedFlag
+		}
+	} else if flags&flagRST == flagRST {
+		s.state = streamClosed
+		s.session.closeStream(s.id, true)
+	}
+	return nil
+}
+
+// incrSendWindow updates the size of our send window
+func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+	if err := s.processFlags(flags); err != nil {
+		return err
+	}
+
+	// Increase window, unblock a sender
+	s.sendWindow += hdr.Length()
+	asyncNotify(s.notifyCh)
+	return nil
+}
+
+// readData is used to handle a data frame
+func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+	if err := s.processFlags(flags); err != nil {
+		return err
+	}
+
+	// Check that our recv window is not exceeded
+	length := hdr.Length()
+	if length > s.recvWindow {
+		return ErrRecvWindowExceeded
+	}
+
+	// Decrement the receive window
+	s.recvWindow -= length
+
+	// Wrap in a limited reader
+	conn = &io.LimitedReader{R: conn, N: int64(length)}
+
+	// Handle potential data compression
+	if flags&flagLZW == flagLZW {
+		cr := lzw.NewReader(conn, lzw.MSB, 8)
+		defer cr.Close()
+		conn = cr
+	}
+
+	// Copy to our buffer
+	if _, err := io.Copy(&s.recvBuf, conn); err != nil {
+		return err
+	}
+
+	// Unblock any readers
+	asyncNotify(s.notifyCh)
+	return nil
+}