Browse Source

Merge pull request #9 from hashicorp/b-deadlock

Fixes read/write loop and ping deadlocks
James Phillips 9 years ago
parent
commit
4036a347b6
4 changed files with 234 additions and 10 deletions
  1. 4 0
      const.go
  2. 7 0
      mux.go
  3. 42 9
      session.go
  4. 181 1
      session_test.go

+ 4 - 0
const.go

@@ -29,6 +29,10 @@ var (
 	// ErrReceiveWindowExceeded indicates the window was exceeded
 	ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded")
 
+	// ErrHeaderWriteTimeout indicates that we hit an IO deadline waiting
+	// for a header to be written.
+	ErrHeaderWriteTimeout = fmt.Errorf("header write timeout")
+
 	// ErrTimeout is used when we reach an IO deadline
 	ErrTimeout = fmt.Errorf("i/o deadline reached")
 

+ 7 - 0
mux.go

@@ -20,6 +20,12 @@ type Config struct {
 	// KeepAliveInterval is how often to perform the keep alive
 	KeepAliveInterval time.Duration
 
+	// HeaderWriteTimeout is how long we will wait to perform a blocking
+	// operation writing a header, after which we will throw an error and
+	// close the stream. Headers are small, so this should be set to a value
+	// after which you suspect there is something wrong with the connection.
+	HeaderWriteTimeout time.Duration
+
 	// MaxStreamWindowSize is used to control the maximum
 	// window size that we allow for a stream.
 	MaxStreamWindowSize uint32
@@ -34,6 +40,7 @@ func DefaultConfig() *Config {
 		AcceptBacklog:       256,
 		EnableKeepAlive:     true,
 		KeepAliveInterval:   30 * time.Second,
+		HeaderWriteTimeout:  10 * time.Second,
 		MaxStreamWindowSize: initialStreamWindow,
 		LogOutput:           os.Stderr,
 	}

+ 42 - 9
session.go

@@ -299,29 +299,51 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error {
 	return s.waitForSendErr(hdr, body, errCh)
 }
 
-// waitForSendErr waits to send a header, checking for a potential shutdown
+// waitForSendErr waits to send a header with optional data, checking for a
+// potential shutdown. If the body is not supplied then we will enforce the
+// configured HeaderWriteTimeout, since this is a small control header.
 func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
+	var timeout <- chan time.Time
+	if body == nil {
+		timer := time.NewTimer(s.config.HeaderWriteTimeout)
+		defer timer.Stop()
+
+		timeout = timer.C
+	}
+
 	ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
 	select {
 	case s.sendCh <- ready:
 	case <-s.shutdownCh:
 		return ErrSessionShutdown
+	case <-timeout:
+		return ErrHeaderWriteTimeout
 	}
+
 	select {
 	case err := <-errCh:
 		return err
 	case <-s.shutdownCh:
 		return ErrSessionShutdown
+	case <-timeout:
+		return ErrHeaderWriteTimeout
 	}
 }
 
-// sendNoWait does a send without waiting
+// sendNoWait does a send without waiting. Since there's still a case where
+// sendCh itself can be full, we will enforce the configured HeaderWriteTimeout,
+// since this is a small control header.
 func (s *Session) sendNoWait(hdr header) error {
+	timer := time.NewTimer(s.config.HeaderWriteTimeout)
+	defer timer.Stop()
+
 	select {
 	case s.sendCh <- sendReady{Hdr: hdr}:
 		return nil
 	case <-s.shutdownCh:
 		return ErrSessionShutdown
+	case <-timer.C:
+		return ErrHeaderWriteTimeout
 	}
 }
 
@@ -446,7 +468,9 @@ func (s *Session) handleStreamMessage(hdr header) error {
 	// Check if this is a window update
 	if hdr.MsgType() == typeWindowUpdate {
 		if err := stream.incrSendWindow(hdr, flags); err != nil {
-			s.sendNoWait(s.goAway(goAwayProtoErr))
+			if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+				s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+			}
 			return err
 		}
 		return nil
@@ -454,7 +478,9 @@ func (s *Session) handleStreamMessage(hdr header) error {
 
 	// Read the new data
 	if err := stream.readData(hdr, flags, s.bufRead); err != nil {
-		s.sendNoWait(s.goAway(goAwayProtoErr))
+		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+		}
 		return err
 	}
 	return nil
@@ -465,11 +491,16 @@ func (s *Session) handlePing(hdr header) error {
 	flags := hdr.Flags()
 	pingID := hdr.Length()
 
-	// Check if this is a query, respond back
+	// Check if this is a query, respond back in a separate context so we
+	// don't interfere with the receiving thread blocking for the write.
 	if flags&flagSYN == flagSYN {
-		hdr := header(make([]byte, headerSize))
-		hdr.encode(typePing, flagACK, 0, pingID)
-		s.sendNoWait(hdr)
+		go func() {
+			hdr := header(make([]byte, headerSize))
+			hdr.encode(typePing, flagACK, 0, pingID)
+			if err := s.sendNoWait(hdr); err != nil {
+				s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
+			}
+		}()
 		return nil
 	}
 
@@ -521,7 +552,9 @@ func (s *Session) incomingStream(id uint32) error {
 	// Check if stream already exists
 	if _, ok := s.streams[id]; ok {
 		s.logger.Printf("[ERR] yamux: duplicate stream declared")
-		s.sendNoWait(s.goAway(goAwayProtoErr))
+		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+		}
 		return ErrDuplicateStream
 	}
 

+ 181 - 1
session_test.go

@@ -14,6 +14,7 @@ import (
 type pipeConn struct {
 	reader *io.PipeReader
 	writer *io.PipeWriter
+	writeBlocker sync.Mutex
 }
 
 func (p *pipeConn) Read(b []byte) (int, error) {
@@ -21,6 +22,8 @@ func (p *pipeConn) Read(b []byte) (int, error) {
 }
 
 func (p *pipeConn) Write(b []byte) (int, error) {
+	p.writeBlocker.Lock()
+	defer p.writeBlocker.Unlock()
 	return p.writer.Write(b)
 }
 
@@ -32,13 +35,16 @@ func (p *pipeConn) Close() error {
 func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
 	read1, write1 := io.Pipe()
 	read2, write2 := io.Pipe()
-	return &pipeConn{read1, write2}, &pipeConn{read2, write1}
+	conn1 := &pipeConn{reader: read1, writer: write2}
+	conn2 := &pipeConn{reader: read2, writer: write1}
+	return conn1, conn2
 }
 
 func testClientServer() (*Session, *Session) {
 	conf := DefaultConfig()
 	conf.AcceptBacklog = 64
 	conf.KeepAliveInterval = 100 * time.Millisecond
+	conf.HeaderWriteTimeout = 250 * time.Millisecond
 	return testClientServerConfig(conf)
 }
 
@@ -799,3 +805,177 @@ func TestBacklogExceeded_Accept(t *testing.T) {
 		}
 	}
 }
+
+func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	// Choose a huge flood size that we know will result in a window update.
+	flood := int64(client.config.MaxStreamWindowSize) - 1
+
+	// The server will accept a new stream and then flood data to it.
+	go func() {
+		defer wg.Done()
+
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		n, err := stream.Write(make([]byte, flood))
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if int64(n) != flood {
+			t.Fatalf("short write: %d", n)
+		}
+	}()
+
+	// The client will open a stream, block outbound writes, and then
+	// listen to the flood from the server, which should time out since
+	// it won't be able to send the window update.
+	go func() {
+		defer wg.Done()
+
+		stream, err := client.OpenStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		conn := client.conn.(*pipeConn)
+		conn.writeBlocker.Lock()
+
+		_, err = stream.Read(make([]byte, flood))
+		if err != ErrHeaderWriteTimeout {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	wg.Wait()
+}
+
+func TestSession_sendNoWait_Timeout(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+	}()
+
+	// The client will open the stream and then block outbound writes, we'll
+	// probe sendNoWait once it gets into that state.
+	go func() {
+		defer wg.Done()
+
+		stream, err := client.OpenStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		conn := client.conn.(*pipeConn)
+		conn.writeBlocker.Lock()
+
+		hdr := header(make([]byte, headerSize))
+		hdr.encode(typePing, flagACK, 0, 0)
+		for {
+			err = client.sendNoWait(hdr)
+			if err == nil {
+				continue
+			} else if err == ErrHeaderWriteTimeout {
+				break
+			} else {
+				t.Fatalf("err: %v", err)
+			}
+		}
+	}()
+
+	wg.Wait()
+}
+
+func TestSession_PingOfDeath(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	var doPingOfDeath sync.Mutex
+	doPingOfDeath.Lock()
+
+	// This is used later to block outbound writes.
+	conn := server.conn.(*pipeConn)
+
+	// The server will accept a stream, block outbound writes, and then
+	// flood its send channel so that no more headers can be queued.
+	go func() {
+		defer wg.Done()
+
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		conn.writeBlocker.Lock()
+		for {
+			hdr := header(make([]byte, headerSize))
+			hdr.encode(typePing, 0, 0, 0)
+			err = server.sendNoWait(hdr)
+			if err == nil {
+				continue
+			} else if err == ErrHeaderWriteTimeout {
+				break
+			} else {
+				t.Fatalf("err: %v", err)
+			}
+		}
+
+		doPingOfDeath.Unlock()
+	}()
+
+	// The client will open a stream and then send the server a ping once it
+	// can no longer write. This makes sure the server doesn't deadlock reads
+	// while trying to reply to the ping with no ability to write.
+	go func() {
+		defer wg.Done()
+
+		stream, err := client.OpenStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		// This ping will never unblock because the ping id will never
+		// show up in a response.
+		doPingOfDeath.Lock()
+		go func() { client.Ping() }()
+
+		// Wait for a while to make sure the previous ping times out,
+		// then turn writes back on and make sure a ping works again.
+		time.Sleep(2 * server.config.HeaderWriteTimeout)
+		conn.writeBlocker.Unlock()
+		if _, err = client.Ping(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	wg.Wait()
+}