Explorar o código

Reducing lock granularity

Armon Dadgar %!s(int64=10) %!d(string=hai) anos
pai
achega
6bd753dec7
Modificáronse 2 ficheiros con 120 adicións e 46 borrados
  1. 61 0
      session_test.go
  2. 59 46
      stream.go

+ 61 - 0
session_test.go

@@ -1,6 +1,7 @@
 package yamux
 
 import (
+	"fmt"
 	"io"
 	"sync"
 	"testing"
@@ -292,3 +293,63 @@ func TestGoAway(t *testing.T) {
 		t.Fatalf("err: %v", err)
 	}
 }
+
+func TestManyStreams(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	wg := &sync.WaitGroup{}
+
+	acceptor := func(i int) {
+		defer wg.Done()
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		buf := make([]byte, 512)
+		for {
+			n, err := stream.Read(buf)
+			println("read")
+			if err == io.EOF {
+				return
+			}
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			}
+			if n == 0 {
+				t.Fatalf("err: %v", err)
+			}
+		}
+	}
+	sender := func(i int) {
+		defer wg.Done()
+		stream, err := client.Open()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		msg := fmt.Sprintf("%08d", i)
+		for i := 0; i < 1000; i++ {
+			n, err := stream.Write([]byte(msg))
+			println("write")
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			}
+			if n != len(msg) {
+				t.Fatalf("short write %d", n)
+			}
+		}
+	}
+
+	for i := 0; i < 50; i++ {
+		wg.Add(2)
+		go acceptor(i)
+		go sender(i)
+	}
+
+	wg.Wait()
+}

+ 59 - 46
stream.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"io"
 	"sync"
+	"sync/atomic"
 	"time"
 )
 
@@ -22,17 +23,20 @@ const (
 // Stream is used to represent a logical stream
 // within a session.
 type Stream struct {
+	recvWindow uint32
+	sendWindow uint32
+
 	id      uint32
 	session *Session
 
-	state streamState
-	lock  sync.Mutex
+	state     streamState
+	stateLock sync.Mutex
 
-	recvBuf bytes.Buffer
-	sendHdr header
+	recvBuf  bytes.Buffer
+	recvLock sync.Mutex
 
-	recvWindow uint32
-	sendWindow uint32
+	sendHdr  header
+	sendLock sync.Mutex
 
 	notifyCh chan struct{}
 
@@ -68,29 +72,31 @@ func (s *Stream) StreamID() uint32 {
 // Read is used to read from the stream
 func (s *Stream) Read(b []byte) (n int, err error) {
 START:
-	s.lock.Lock()
+	s.stateLock.Lock()
 	switch s.state {
 	case streamRemoteClose:
 		fallthrough
 	case streamClosed:
 		if s.recvBuf.Len() == 0 {
-			s.lock.Unlock()
+			s.stateLock.Unlock()
 			return 0, io.EOF
 		}
 	}
+	s.stateLock.Unlock()
 
 	// If there is no data available, block
+	s.recvLock.Lock()
 	if s.recvBuf.Len() == 0 {
-		s.lock.Unlock()
+		s.recvLock.Unlock()
 		goto WAIT
 	}
 
 	// Read any bytes
 	n, _ = s.recvBuf.Read(b)
+	s.recvLock.Unlock()
 
 	// Send a window update potentially
 	err = s.sendWindowUpdate()
-	s.lock.Unlock()
 	return n, err
 
 WAIT:
@@ -127,18 +133,22 @@ func (s *Stream) write(b []byte) (n int, err error) {
 	var max uint32
 	var body io.Reader
 START:
-	s.lock.Lock()
+	s.stateLock.Lock()
 	switch s.state {
 	case streamLocalClose:
 		fallthrough
 	case streamClosed:
-		s.lock.Unlock()
+		s.stateLock.Unlock()
 		return 0, ErrStreamClosed
 	}
+	s.stateLock.Unlock()
+
+	// Lock the send
+	s.sendLock.Lock()
 
 	// If there is no data available, block
-	if s.sendWindow == 0 {
-		s.lock.Unlock()
+	if atomic.LoadUint32(&s.sendWindow) == 0 {
+		s.sendLock.Unlock()
 		goto WAIT
 	}
 
@@ -152,15 +162,15 @@ START:
 	// Send the header
 	s.sendHdr.encode(typeData, flags, s.id, max)
 	if err := s.session.waitForSend(s.sendHdr, body); err != nil {
-		s.lock.Unlock()
+		s.sendLock.Unlock()
 		return 0, err
 	}
 
 	// Reduce our send window
-	s.sendWindow -= max
+	atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
+	s.sendLock.Unlock()
 
 	// Unlock
-	s.lock.Unlock()
 	return int(max), err
 
 WAIT:
@@ -181,7 +191,8 @@ WAIT:
 // sendFlags determines any flags that are appropriate
 // based on the current stream state
 func (s *Stream) sendFlags() uint16 {
-	// Determine the flags if any
+	s.stateLock.Lock()
+	defer s.stateLock.Unlock()
 	var flags uint16
 	switch s.state {
 	case streamInit:
@@ -233,23 +244,8 @@ func (s *Stream) sendClose() error {
 
 // Close is used to close the stream
 func (s *Stream) Close() error {
-	s.lock.Lock()
-	defer s.lock.Unlock()
-
+	s.stateLock.Lock()
 	switch s.state {
-	// Local or full close means nothing to do
-	case streamLocalClose:
-		fallthrough
-	case streamClosed:
-		return nil
-
-	// Remote close, weneed to send FIN and we are done
-	case streamRemoteClose:
-		s.state = streamClosed
-		s.session.closeStream(s.id, false)
-		s.sendClose()
-		return nil
-
 	// Opened means we need to signal a close
 	case streamSYNSent:
 		fallthrough
@@ -257,23 +253,39 @@ func (s *Stream) Close() error {
 		fallthrough
 	case streamEstablished:
 		s.state = streamLocalClose
-		s.sendClose()
-		return nil
+		goto SEND_CLOSE
+
+	case streamLocalClose:
+	case streamRemoteClose:
+		s.state = streamClosed
+		s.session.closeStream(s.id, false)
+		goto SEND_CLOSE
+
+	case streamClosed:
+	default:
+		panic("unhandled state")
 	}
-	panic("unhandled state")
+	s.stateLock.Unlock()
+	return nil
+SEND_CLOSE:
+	s.stateLock.Unlock()
+	s.sendClose()
+	return nil
 }
 
 // forceClose is used for when the session is exiting
 func (s *Stream) forceClose() {
-	s.lock.Lock()
-	defer s.lock.Unlock()
+	s.stateLock.Lock()
 	s.state = streamClosed
+	s.stateLock.Unlock()
 	asyncNotify(s.notifyCh)
 }
 
 // processFlags is used to update the state of the stream
 // based on set flags, if any. Lock must be held
 func (s *Stream) processFlags(flags uint16) error {
+	s.stateLock.Lock()
+	defer s.stateLock.Unlock()
 	if flags&flagACK == flagACK {
 		if s.state == streamSYNSent {
 			s.state = streamEstablished
@@ -302,42 +314,43 @@ func (s *Stream) processFlags(flags uint16) error {
 
 // incrSendWindow updates the size of our send window
 func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
-	s.lock.Lock()
-	defer s.lock.Unlock()
 	if err := s.processFlags(flags); err != nil {
 		return err
 	}
 
 	// Increase window, unblock a sender
-	s.sendWindow += hdr.Length()
+	atomic.AddUint32(&s.sendWindow, hdr.Length())
 	asyncNotify(s.notifyCh)
 	return nil
 }
 
 // readData is used to handle a data frame
 func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
-	s.lock.Lock()
-	defer s.lock.Unlock()
 	if err := s.processFlags(flags); err != nil {
 		return err
 	}
 
 	// Check that our recv window is not exceeded
 	length := hdr.Length()
-	if length > s.recvWindow {
+	if length == 0 {
+		return nil
+	}
+	if length > atomic.LoadUint32(&s.recvWindow) {
 		return ErrRecvWindowExceeded
 	}
 
 	// Decrement the receive window
-	s.recvWindow -= length
+	atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
 
 	// Wrap in a limited reader
 	conn = &io.LimitedReader{R: conn, N: int64(length)}
 
 	// Copy to our buffer
+	s.recvLock.Lock()
 	if _, err := io.Copy(&s.recvBuf, conn); err != nil {
 		return err
 	}
+	s.recvLock.Unlock()
 
 	// Unblock any readers
 	asyncNotify(s.notifyCh)