瀏覽代碼

Split notify channels

Armon Dadgar 10 年之前
父節點
當前提交
c41b80d852
共有 1 個文件被更改,包括 27 次插入14 次删除
  1. 27 14
      stream.go

+ 27 - 14
stream.go

@@ -41,7 +41,8 @@ type Stream struct {
 	sendHdr  header
 	sendLock sync.Mutex
 
-	notifyCh chan struct{}
+	recvNotifyCh chan struct{}
+	sendNotifyCh chan struct{}
 
 	readDeadline  time.Time
 	writeDeadline time.Time
@@ -51,14 +52,15 @@ type Stream struct {
 // a given session for an ID
 func newStream(session *Session, id uint32, state streamState) *Stream {
 	s := &Stream{
-		id:         id,
-		session:    session,
-		state:      state,
-		controlHdr: header(make([]byte, headerSize)),
-		sendHdr:    header(make([]byte, headerSize)),
-		recvWindow: initialStreamWindow,
-		sendWindow: initialStreamWindow,
-		notifyCh:   make(chan struct{}, 1),
+		id:           id,
+		session:      session,
+		state:        state,
+		controlHdr:   header(make([]byte, headerSize)),
+		sendHdr:      header(make([]byte, headerSize)),
+		recvWindow:   initialStreamWindow,
+		sendWindow:   initialStreamWindow,
+		recvNotifyCh: make(chan struct{}, 1),
+		sendNotifyCh: make(chan struct{}, 1),
 	}
 	return s
 }
@@ -75,6 +77,7 @@ func (s *Stream) StreamID() uint32 {
 
 // Read is used to read from the stream
 func (s *Stream) Read(b []byte) (n int, err error) {
+	defer asyncNotify(s.recvNotifyCh)
 START:
 	s.stateLock.Lock()
 	switch s.state {
@@ -110,7 +113,7 @@ WAIT:
 		timeout = time.After(delay)
 	}
 	select {
-	case <-s.notifyCh:
+	case <-s.recvNotifyCh:
 		goto START
 	case <-timeout:
 		return 0, ErrTimeout
@@ -180,7 +183,7 @@ WAIT:
 		timeout = time.After(delay)
 	}
 	select {
-	case <-s.notifyCh:
+	case <-s.sendNotifyCh:
 		goto START
 	case <-timeout:
 		return 0, ErrTimeout
@@ -276,6 +279,7 @@ func (s *Stream) Close() error {
 SEND_CLOSE:
 	s.stateLock.Unlock()
 	s.sendClose()
+	s.notifyWaiting()
 	return nil
 }
 
@@ -284,7 +288,7 @@ func (s *Stream) forceClose() {
 	s.stateLock.Lock()
 	s.state = streamClosed
 	s.stateLock.Unlock()
-	asyncNotify(s.notifyCh)
+	s.notifyWaiting()
 }
 
 // processFlags is used to update the state of the stream
@@ -305,19 +309,28 @@ func (s *Stream) processFlags(flags uint16) error {
 			fallthrough
 		case streamEstablished:
 			s.state = streamRemoteClose
+			s.notifyWaiting()
 		case streamLocalClose:
 			s.state = streamClosed
 			s.session.closeStream(s.id, true)
+			s.notifyWaiting()
 		default:
 			return ErrUnexpectedFlag
 		}
 	} else if flags&flagRST == flagRST {
 		s.state = streamClosed
 		s.session.closeStream(s.id, true)
+		s.notifyWaiting()
 	}
 	return nil
 }
 
+// notifyWaiting notifies all the waiting channels
+func (s *Stream) notifyWaiting() {
+	asyncNotify(s.recvNotifyCh)
+	asyncNotify(s.sendNotifyCh)
+}
+
 // incrSendWindow updates the size of our send window
 func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
 	if err := s.processFlags(flags); err != nil {
@@ -326,7 +339,7 @@ func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
 
 	// Increase window, unblock a sender
 	atomic.AddUint32(&s.sendWindow, hdr.Length())
-	asyncNotify(s.notifyCh)
+	asyncNotify(s.sendNotifyCh)
 	return nil
 }
 
@@ -359,7 +372,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
 	s.recvLock.Unlock()
 
 	// Unblock any readers
-	asyncNotify(s.notifyCh)
+	asyncNotify(s.recvNotifyCh)
 	return nil
 }