Browse Source

Protect reads of `sendReady.Body` with mutex and temp buffer (#102)

This commit fixes an edge case where the `body` passed to
`waitForSendErr` can be written to after returning from the function.
This happens when `sendReady` is buffered on the `sendCh` and the
session is shutdown or the write times out.

When this condition happens and `waitForSendErr` has not yet exited, the
`body` is safely copied into a temporary buffer in `send`. Otherwise
`waitForSendErr` safely created a copy of the `body` and exits, this
essentially results in double buffering for the edge case which seems
acceptable.
Mathias Fredriksson 2 years ago
parent
commit
3aa5700c94
1 changed files with 50 additions and 4 deletions
  1. 50 4
      session.go

+ 50 - 4
session.go

@@ -2,6 +2,7 @@ package yamux
 
 import (
 	"bufio"
+	"bytes"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -80,6 +81,7 @@ type Session struct {
 // or to directly send a header
 type sendReady struct {
 	Hdr  []byte
+	mu   *sync.Mutex // Protects Body from unsafe reads.
 	Body []byte
 	Err  chan error
 }
@@ -373,7 +375,7 @@ func (s *Session) waitForSendErr(hdr header, body []byte, errCh chan error) erro
 		timerPool.Put(t)
 	}()
 
-	ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
+	ready := sendReady{Hdr: hdr, mu: &sync.Mutex{}, Body: body, Err: errCh}
 	select {
 	case s.sendCh <- ready:
 	case <-s.shutdownCh:
@@ -382,12 +384,34 @@ func (s *Session) waitForSendErr(hdr header, body []byte, errCh chan error) erro
 		return ErrConnectionWriteTimeout
 	}
 
+	bodyCopy := func() {
+		if body == nil {
+			return // A nil body is ignored.
+		}
+
+		// In the event of session shutdown or connection write timeout,
+		// we need to prevent `send` from reading the body buffer after
+		// returning from this function since the caller may re-use the
+		// underlying array.
+		ready.mu.Lock()
+		defer ready.mu.Unlock()
+
+		if ready.Body == nil {
+			return // Body was already copied in `send`.
+		}
+		newBody := make([]byte, len(body))
+		copy(newBody, body)
+		ready.Body = newBody
+	}
+
 	select {
 	case err := <-errCh:
 		return err
 	case <-s.shutdownCh:
+		bodyCopy()
 		return ErrSessionShutdown
 	case <-timer.C:
+		bodyCopy()
 		return ErrConnectionWriteTimeout
 	}
 }
@@ -420,7 +444,10 @@ func (s *Session) sendNoWait(hdr header) error {
 
 // send is a long running goroutine that sends data
 func (s *Session) send() {
+	var bodyBuf bytes.Buffer
 	for {
+		bodyBuf.Reset()
+
 		select {
 		case ready := <-s.sendCh:
 			// Send a header if ready
@@ -438,9 +465,28 @@ func (s *Session) send() {
 				}
 			}
 
-			// Send data from a body if given
-			if ready.Body != nil {
-				_, err := s.conn.Write(ready.Body)
+			if ready.mu != nil {
+				ready.mu.Lock()
+				if ready.Body != nil {
+					// Copy the body into the buffer to avoid
+					// holding a mutex lock during the write.
+					_, err := bodyBuf.Write(ready.Body)
+					if err != nil {
+						ready.Body = nil
+						ready.mu.Unlock()
+						s.logger.Printf("[ERR] yamux: Failed to copy body into buffer: %v", err)
+						asyncSendErr(ready.Err, err)
+						s.exitErr(err)
+						return
+					}
+					ready.Body = nil
+				}
+				ready.mu.Unlock()
+			}
+
+			if bodyBuf.Len() > 0 {
+				// Send data from a body if given
+				_, err := s.conn.Write(bodyBuf.Bytes())
 				if err != nil {
 					s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
 					asyncSendErr(ready.Err, err)