Browse Source

handle session

fancl 9 months ago
parent
commit
1441811ef0
2 changed files with 40 additions and 26 deletions
  1. 35 21
      session.go
  2. 5 5
      stream.go

+ 35 - 21
session.go

@@ -8,15 +8,21 @@ import (
 	"io"
 	"math"
 	"net"
+	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
 )
 
+var (
+	_sid uint64
+)
+
 // Session is used to wrap a reliable ordered connection and to
 // multiplex it into multiple streams.
 type Session struct {
+	id string
 	// remoteGoAway indicates the remote side does
 	// not want futher connections. Must be first for alignment.
 	remoteGoAway int32
@@ -93,7 +99,6 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 	if logger == nil {
 		logger = &discordLogger{}
 	}
-
 	s := &Session{
 		config:     config,
 		logger:     logger,
@@ -109,6 +114,11 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 		sendDoneCh: make(chan struct{}),
 		shutdownCh: make(chan struct{}),
 	}
+	if v, ok := conn.(hasAddr); ok {
+		s.id = v.RemoteAddr().String()
+	} else {
+		s.id = strconv.FormatInt(int64(atomic.AddUint64(&_sid, 1)), 10)
+	}
 	if client {
 		s.nextStreamID = 1
 	} else {
@@ -122,6 +132,10 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 	return s
 }
 
+func (s *Session) ID() string {
+	return s.id
+}
+
 // IsClosed does a safe check to see if we have shutdown
 func (s *Session) IsClosed() bool {
 	select {
@@ -197,7 +211,7 @@ GET_ID:
 		select {
 		case <-s.synCh:
 		default:
-			s.logger.Errorf("aborted stream open without inflight syn semaphore")
+			s.logger.Warnf("session %s aborted stream open without inflight syn semaphore", s.ID())
 		}
 		return nil, err
 	}
@@ -220,7 +234,7 @@ func (s *Session) setOpenTimeout(stream *Stream) {
 	case <-timer.C:
 		// Timeout reached while waiting for ACK.
 		// Close the session to force connection re-establishment.
-		s.logger.Errorf("aborted stream open (destination=%s): %v", s.RemoteAddr().String(), ErrTimeout.err)
+		s.logger.Warnf("session %s aborted stream open (destination=%s): %v", s.ID(), s.RemoteAddr().String(), ErrTimeout.err)
 		s.Close()
 	}
 }
@@ -366,7 +380,7 @@ func (s *Session) keepalive() {
 			_, err := s.Ping()
 			if err != nil {
 				if err != ErrSessionShutdown {
-					s.logger.Errorf("keepalive failed: %v", err)
+					s.logger.Warnf("session %s keepalive failed: %v", s.ID(), err)
 					s.exitErr(ErrKeepAliveTimeout)
 				}
 				return
@@ -558,14 +572,14 @@ func (s *Session) recvLoop() error {
 		// Read the header
 		if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
 			if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
-				s.logger.Errorf("failed to read stream header: %v", err)
+				s.logger.Warnf("session %s unavailable to read stream header: %v", s.ID(), err)
 			}
 			return err
 		}
 
 		// Verify the version
 		if hdr.Version() != protoVersion {
-			s.logger.Errorf("invalid stream protocol version: %d", hdr.Version())
+			s.logger.Warnf("session %s read invalid stream protocol version: %d", s.ID(), hdr.Version())
 			return ErrInvalidVersion
 		}
 
@@ -600,13 +614,13 @@ func (s *Session) handleStreamMessage(hdr header) error {
 	if stream == nil {
 		// Drain any data on the wire
 		if hdr.MsgType() == typeData && hdr.Length() > 0 {
-			s.logger.Warnf("discarding data for stream: %d", id)
+			s.logger.Warnf("session %s discarding data for stream: %d", s.ID(), id)
 			if _, err := io.CopyN(io.Discard, s.bufRead, int64(hdr.Length())); err != nil {
-				s.logger.Errorf("failed to discard stream %d data: %v", id, err)
+				s.logger.Warnf("session %s failed to discard stream %d data: %v", s.ID(), id, err)
 				return nil
 			}
 		} else {
-			s.logger.Warnf("frame for missing stream: %v", hdr)
+			s.logger.Warnf("session %s frame for missing stream: %v", s.ID(), hdr)
 		}
 		return nil
 	}
@@ -615,7 +629,7 @@ func (s *Session) handleStreamMessage(hdr header) error {
 	if hdr.MsgType() == typeWindowUpdate {
 		if err := stream.incrSendWindow(hdr, flags); err != nil {
 			if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
-				s.logger.Warnf("failed to send go away: %v", sendErr)
+				s.logger.Warnf("session %s failed to send go away: %v", s.ID(), sendErr)
 			}
 			return err
 		}
@@ -625,7 +639,7 @@ func (s *Session) handleStreamMessage(hdr header) error {
 	// Read the new data
 	if err := stream.readData(hdr, flags, s.bufRead); err != nil {
 		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
-			s.logger.Warnf("failed to send go away: %v", sendErr)
+			s.logger.Warnf("session %s failed to send go away: %v", s.ID(), sendErr)
 		}
 		return err
 	}
@@ -644,7 +658,7 @@ func (s *Session) handlePing(hdr header) error {
 			hdr := header(make([]byte, headerSize))
 			hdr.encode(typePing, flagACK, 0, pingID)
 			if err := s.sendNoWait(hdr); err != nil {
-				s.logger.Warnf("stream %s failed to send ping reply: %v", hdr.StreamID(), err)
+				s.logger.Warnf("session %s stream %s failed to send ping reply: %v", s.ID(), hdr.StreamID(), err)
 			}
 		}()
 		return nil
@@ -668,13 +682,13 @@ func (s *Session) handleGoAway(hdr header) error {
 	case goAwayNormal:
 		atomic.SwapInt32(&s.remoteGoAway, 1)
 	case goAwayProtoErr:
-		s.logger.Errorf("received protocol error go away")
+		s.logger.Warnf("session %s received protocol error go away", s.ID())
 		return fmt.Errorf("yamux protocol error")
 	case goAwayInternalErr:
-		s.logger.Errorf("received internal error go away")
+		s.logger.Warnf("session %s  received internal error go away", s.ID())
 		return fmt.Errorf("remote yamux internal error")
 	default:
-		s.logger.Errorf("received unexpected go away")
+		s.logger.Warnf("session %s received unexpected go away", s.ID())
 		return fmt.Errorf("unexpected go away received")
 	}
 	return nil
@@ -697,9 +711,9 @@ func (s *Session) incomingStream(id uint32) error {
 
 	// Check if stream already exists
 	if _, ok := s.streams[id]; ok {
-		s.logger.Errorf("duplicate stream declared")
+		s.logger.Warnf("session %s duplicate stream declared", s.ID())
 		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
-			s.logger.Warnf("failed to send go away: %v", sendErr)
+			s.logger.Debugf("session %s failed to send go away: %v", s.ID(), sendErr)
 		}
 		return ErrDuplicateStream
 	}
@@ -713,7 +727,7 @@ func (s *Session) incomingStream(id uint32) error {
 		return nil
 	default:
 		// Backlog exceeded! RST the stream
-		s.logger.Warnf("backlog exceeded, forcing connection reset")
+		s.logger.Warnf("session %s backlog exceeded, forcing connection reset", s.ID())
 		delete(s.streams, id)
 		hdr := header(make([]byte, headerSize))
 		hdr.encode(typeWindowUpdate, flagRST, id, 0)
@@ -730,7 +744,7 @@ func (s *Session) closeStream(id uint32) {
 		select {
 		case <-s.synCh:
 		default:
-			s.logger.Errorf("SYN tracking out of sync")
+			s.logger.Warnf("session %s SYN tracking out of sync", s.ID())
 		}
 	}
 	delete(s.streams, id)
@@ -744,12 +758,12 @@ func (s *Session) establishStream(id uint32) {
 	if _, ok := s.inflight[id]; ok {
 		delete(s.inflight, id)
 	} else {
-		s.logger.Errorf("established stream without inflight SYN (no tracking entry)")
+		s.logger.Warnf("session %s established stream without inflight SYN (no tracking entry)", s.ID())
 	}
 	select {
 	case <-s.synCh:
 	default:
-		s.logger.Errorf("established stream without inflight SYN (didn't have semaphore)")
+		s.logger.Warnf("session %s established stream without inflight SYN (didn't have semaphore)", s.ID())
 	}
 	s.streamLock.Unlock()
 }

+ 5 - 5
stream.go

@@ -431,7 +431,7 @@ func (s *Stream) processFlags(flags uint16) error {
 			closeStream = true
 			s.notifyWaiting()
 		default:
-			s.session.logger.Errorf("unexpected FIN flag in state %d", s.state)
+			s.session.logger.Warnf("unexpected FIN flag in state %d", s.state)
 			return ErrUnexpectedFlag
 		}
 	}
@@ -484,7 +484,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) (err error)
 	s.recvLock.Lock()
 
 	if length > s.recvWindow {
-		s.session.logger.Errorf("receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
+		s.session.logger.Warnf("receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
 		s.recvLock.Unlock()
 		return ErrRecvWindowExceeded
 	}
@@ -498,7 +498,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) (err error)
 	defer putBytes(buf)
 
 	if nr, err = io.ReadFull(conn, buf); err != nil {
-		s.session.logger.Errorf("failed to read stream %d data: %v", s.id, err)
+		s.session.logger.Warnf("failed to read stream %d data: %v", s.id, err)
 		return err
 	}
 	if uint32(nr) != length {
@@ -506,12 +506,12 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) (err error)
 	}
 	if s.session.config.Crypto != nil {
 		if buf, err = s.session.config.Crypto.Decrypt(buf); err != nil {
-			s.session.logger.Errorf("failed to decrypt stream %d data: %v", s.id, err)
+			s.session.logger.Warnf("failed to decrypt stream %d data: %v", s.id, err)
 			return err
 		}
 	}
 	if copiedLength, err = s.recvBuf.Write(buf); err != nil {
-		s.session.logger.Errorf("failed to read stream %d data: %v", s.id, err)
+		s.session.logger.Warnf("failed to read stream %d data: %v", s.id, err)
 		s.recvLock.Unlock()
 		return err
 	}