Bladeren bron

Merge pull request #88 from hashicorp/feature/close-timeout

add StreamCloseTimeout
Mitchell Hashimoto 4 jaren geleden
bovenliggende
commit
a95892c5f8
4 gewijzigde bestanden met toevoegingen van 129 en 2 verwijderingen
  1. 2 0
      go.mod
  2. 8 0
      mux.go
  3. 70 0
      session_test.go
  4. 49 2
      stream.go

+ 2 - 0
go.mod

@@ -1 +1,3 @@
 module github.com/hashicorp/yamux
+
+go 1.15

+ 8 - 0
mux.go

@@ -31,6 +31,13 @@ type Config struct {
 	// window size that we allow for a stream.
 	MaxStreamWindowSize uint32
 
+	// StreamCloseTimeout is the maximum time that a stream will be allowed
+	// to remain in a half-closed state when `Close` is called before forcibly
+	// closing the connection. Forcibly closed connections will empty the
+	// receive buffer, drop any future packets received for that stream,
+	// and send a RST to the remote side.
+	StreamCloseTimeout time.Duration
+
 	// LogOutput is used to control the log destination. Either Logger or
 	// LogOutput can be set, not both.
 	LogOutput io.Writer
@@ -48,6 +55,7 @@ func DefaultConfig() *Config {
 		KeepAliveInterval:      30 * time.Second,
 		ConnectionWriteTimeout: 10 * time.Second,
 		MaxStreamWindowSize:    initialStreamWindow,
+		StreamCloseTimeout:     5 * time.Minute,
 		LogOutput:              os.Stderr,
 	}
 }

+ 70 - 0
session_test.go

@@ -273,6 +273,76 @@ func TestAccept(t *testing.T) {
 	}
 }
 
+func TestClose_closeTimeout(t *testing.T) {
+	conf := testConf()
+	conf.StreamCloseTimeout = 10 * time.Millisecond
+	client, server := testClientServerConfig(conf)
+	defer client.Close()
+	defer server.Close()
+
+	if client.NumStreams() != 0 {
+		t.Fatalf("bad")
+	}
+	if server.NumStreams() != 0 {
+		t.Fatalf("bad")
+	}
+
+	wg := &sync.WaitGroup{}
+	wg.Add(2)
+
+	// Open a stream on the client but only close it on the server.
+	// We want to see if the stream ever gets cleaned up on the client.
+
+	var clientStream *Stream
+	go func() {
+		defer wg.Done()
+		var err error
+		clientStream, err = client.OpenStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	doneCh := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(doneCh)
+	}()
+
+	select {
+	case <-doneCh:
+	case <-time.After(time.Second):
+		panic("timeout")
+	}
+
+	// We should have zero streams after our timeout period
+	time.Sleep(100 * time.Millisecond)
+
+	if v := server.NumStreams(); v > 0 {
+		t.Fatalf("should have zero streams: %d", v)
+	}
+	if v := client.NumStreams(); v > 0 {
+		t.Fatalf("should have zero streams: %d", v)
+	}
+
+	if _, err := clientStream.Write([]byte("hello")); err == nil {
+		t.Fatal("should error on write")
+	} else if err.Error() != "connection reset" {
+		t.Fatalf("expected connection reset, got %q", err)
+	}
+}
+
 func TestNonNilInterface(t *testing.T) {
 	_, server := testClientServer()
 	server.Close()

+ 49 - 2
stream.go

@@ -49,6 +49,10 @@ type Stream struct {
 
 	readDeadline  atomic.Value // time.Time
 	writeDeadline atomic.Value // time.Time
+
+	// closeTimer is set with stateLock held to honor the StreamCloseTimeout
+	// setting on Session.
+	closeTimer *time.Timer
 }
 
 // newStream is used to construct a new stream within
@@ -312,6 +316,27 @@ func (s *Stream) Close() error {
 	s.stateLock.Unlock()
 	return nil
 SEND_CLOSE:
+	// This shouldn't happen (the more realistic scenario to cancel the
+	// timer is via processFlags) but just in case this ever happens, we
+	// cancel the timer to prevent dangling timers.
+	if s.closeTimer != nil {
+		s.closeTimer.Stop()
+		s.closeTimer = nil
+	}
+
+	// If we have a StreamCloseTimeout set we start the timeout timer.
+	// We do this only if we're not already closing the stream since that
+	// means this was a graceful close.
+	//
+	// This prevents memory leaks if one side (this side) closes and the
+	// remote side poorly behaves and never responds with a FIN to complete
+	// the close. After the specified timeout, we clean our resources up no
+	// matter what.
+	if !closeStream && s.session.config.StreamCloseTimeout > 0 {
+		s.closeTimer = time.AfterFunc(
+			s.session.config.StreamCloseTimeout, s.closeTimeout)
+	}
+
 	s.stateLock.Unlock()
 	s.sendClose()
 	s.notifyWaiting()
@@ -321,6 +346,22 @@ SEND_CLOSE:
 	return nil
 }
 
+// closeTimeout is called after StreamCloseTimeout during a close to
+// close this stream.
+func (s *Stream) closeTimeout() {
+	// Close our side forcibly
+	s.forceClose()
+
+	// Free the stream from the session map
+	s.session.closeStream(s.id)
+
+	// Send a RST so the remote side closes too.
+	s.sendLock.Lock()
+	defer s.sendLock.Unlock()
+	s.sendHdr.encode(typeWindowUpdate, flagRST, s.id, 0)
+	s.session.sendNoWait(s.sendHdr)
+}
+
 // forceClose is used for when the session is exiting
 func (s *Stream) forceClose() {
 	s.stateLock.Lock()
@@ -332,16 +373,22 @@ func (s *Stream) forceClose() {
 // processFlags is used to update the state of the stream
 // based on set flags, if any. Lock must be held
 func (s *Stream) processFlags(flags uint16) error {
+	s.stateLock.Lock()
+	defer s.stateLock.Unlock()
+
 	// Close the stream when the flags indicate so. NOTE(review): this hunk
 	// moves the stateLock acquisition above this deferred closure, and Go
 	// runs defers in LIFO order — so the closure below now executes BEFORE
 	// the deferred Unlock, meaning closeStream is called while stateLock is
 	// held. The previous "without holding the state lock" behavior no longer
 	// applies; verify closeStream does not re-acquire stateLock (deadlock).
 	closeStream := false
 	defer func() {
 		if closeStream {
+			if s.closeTimer != nil {
+				// Stop our close timeout timer since we gracefully closed
+				s.closeTimer.Stop()
+			}
+
 			s.session.closeStream(s.id)
 		}
 	}()
 
-	s.stateLock.Lock()
-	defer s.stateLock.Unlock()
 	if flags&flagACK == flagACK {
 		if s.state == streamSYNSent {
 			s.state = streamEstablished