fancl преди 9 месеца
родител
ревизия
3f731acf92
променени са 12 файла, в които са добавени 230 реда и са изтрити 94 реда
  1. 8 0
      .idea/.gitignore
  2. 8 0
      .idea/modules.xml
  3. 6 0
      .idea/vcs.xml
  4. 9 0
      .idea/yamux.iml
  5. 6 0
      crypto.go
  6. 1 1
      go.mod
  7. 3 0
      mux.go
  8. 66 0
      pool.go
  9. 75 68
      session.go
  10. 1 12
      session_test.go
  11. 28 10
      stream.go
  12. 19 3
      util.go

+ 8 - 0
.idea/.gitignore

@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml

+ 8 - 0
.idea/modules.xml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/yamux.iml" filepath="$PROJECT_DIR$/.idea/yamux.iml" />
+    </modules>
+  </component>
+</project>

+ 6 - 0
.idea/vcs.xml

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>

+ 9 - 0
.idea/yamux.iml

@@ -0,0 +1,9 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="WEB_MODULE" version="4">
+  <component name="Go" enabled="true" />
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>

+ 6 - 0
crypto.go

@@ -0,0 +1,6 @@
+package yamux
+
+type Crypto interface {
+	Encrypt(src []byte) (dst []byte, err error)
+	Decrypt(src []byte) (dst []byte, err error)
+}

+ 1 - 1
go.mod

@@ -1,3 +1,3 @@
-module github.com/hashicorp/yamux
+module git.nspix.com/golang/yamux
 
 go 1.15

+ 3 - 0
mux.go

@@ -51,6 +51,9 @@ type Config struct {
 	// Logger is used to pass in the logger to be used. Either Logger or
 	// LogOutput can be set, not both.
 	Logger Logger
+
+	//Crypto is used to encrypt data
+	Crypto Crypto
 }
 
 // DefaultConfig is used to return a default configuration

+ 66 - 0
pool.go

@@ -0,0 +1,66 @@
+package yamux
+
+import (
+	"bytes"
+	"sync"
+)
+
+var (
+	bufferPool sync.Pool
+	bufPool5k  sync.Pool
+	bufPool2k  sync.Pool
+	bufPool1k  sync.Pool
+	bufPool    sync.Pool
+)
+
+func getBuffer() *bytes.Buffer {
+	if v := bufferPool.Get(); v != nil {
+		return v.(*bytes.Buffer)
+	}
+	return bytes.NewBuffer([]byte{})
+}
+
+func putBuffer(b *bytes.Buffer) {
+	b.Reset()
+	bufferPool.Put(b)
+}
+
+func getBytes(size int) []byte {
+	if size <= 0 {
+		return nil
+	}
+	var x interface{}
+	if size >= 5*1024 {
+		x = bufPool5k.Get()
+	} else if size >= 2*1024 {
+		x = bufPool2k.Get()
+	} else if size >= 1*1024 {
+		x = bufPool1k.Get()
+	} else {
+		x = bufPool.Get()
+	}
+	if x == nil {
+		return make([]byte, size)
+	}
+	buf := x.([]byte)
+	if cap(buf) < size {
+		return make([]byte, size)
+	}
+	return buf[:size]
+}
+
+func putBytes(buf []byte) {
+	size := cap(buf)
+	if size <= 0 {
+		return
+	}
+	if size >= 5*1024 {
+		bufPool5k.Put(buf)
+	} else if size >= 2*1024 {
+		bufPool2k.Put(buf)
+	} else if size >= 1*1024 {
+		bufPool1k.Put(buf)
+	} else {
+		bufPool.Put(buf)
+	}
+}

+ 75 - 68
session.go

@@ -6,8 +6,6 @@ import (
 	"context"
 	"fmt"
 	"io"
-	"io/ioutil"
-	"log"
 	"math"
 	"net"
 	"strings"
@@ -93,7 +91,7 @@ type sendReady struct {
 func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 	logger := config.Logger
 	if logger == nil {
-		logger = log.New(config.LogOutput, "", log.LstdFlags)
+		logger = &discordLogger{}
 	}
 
 	s := &Session{
@@ -199,7 +197,7 @@ GET_ID:
 		select {
 		case <-s.synCh:
 		default:
-			s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
+			s.logger.Errorf("aborted stream open without inflight syn semaphore")
 		}
 		return nil, err
 	}
@@ -222,7 +220,7 @@ func (s *Session) setOpenTimeout(stream *Stream) {
 	case <-timer.C:
 		// Timeout reached while waiting for ACK.
 		// Close the session to force connection re-establishment.
-		s.logger.Printf("[ERR] yamux: aborted stream open (destination=%s): %v", s.RemoteAddr().String(), ErrTimeout.err)
+		s.logger.Errorf("aborted stream open (destination=%s): %v", s.RemoteAddr().String(), ErrTimeout.err)
 		s.Close()
 	}
 }
@@ -368,7 +366,7 @@ func (s *Session) keepalive() {
 			_, err := s.Ping()
 			if err != nil {
 				if err != ErrSessionShutdown {
-					s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
+					s.logger.Errorf("keepalive failed: %v", err)
 					s.exitErr(ErrKeepAliveTimeout)
 				}
 				return
@@ -475,54 +473,63 @@ func (s *Session) send() {
 	}
 }
 
-func (s *Session) sendLoop() error {
+func (s *Session) sendPacket(packet *sendReady) (err error) {
+	var (
+		n      int
+		nw     int
+		buf    []byte
+		buffer *bytes.Buffer
+	)
+	buffer = getBuffer()
+	defer func() {
+		putBuffer(buffer)
+	}()
+	packet.mu.Lock()
+	// Send a header if ready
+	if packet.Hdr != nil {
+		n, _ = buffer.Write(packet.Hdr)
+		nw += n
+	}
+	if packet.Body != nil {
+		if s.config.Crypto != nil {
+			if buf, err = s.config.Crypto.Encrypt(packet.Body); err != nil {
+				return
+			}
+		} else {
+			buf = packet.Body
+		}
+		n, _ = buffer.Write(buf)
+		nw += n
+		packet.Body = nil
+	}
+	packet.mu.Unlock()
+
+	if buffer.Len() > 0 {
+		// Send data from a body if given
+		if n, err = s.conn.Write(buffer.Bytes()); err != nil {
+			asyncSendErr(packet.Err, err)
+			return err
+		}
+		if n != nw {
+			asyncSendErr(packet.Err, io.ErrShortWrite)
+			return io.ErrShortWrite
+		}
+	}
+	// No error, successful send
+	asyncSendErr(packet.Err, nil)
+	return
+}
+
+func (s *Session) sendLoop() (err error) {
 	defer close(s.sendDoneCh)
-	var bodyBuf bytes.Buffer
 	for {
-		bodyBuf.Reset()
-
 		select {
 		case ready := <-s.sendCh:
-			// Send a header if ready
-			if ready.Hdr != nil {
-				_, err := s.conn.Write(ready.Hdr)
-				if err != nil {
-					s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
-					asyncSendErr(ready.Err, err)
-					return err
-				}
+			if err = s.sendPacket(ready); err != nil {
+				return err
 			}
-
-			ready.mu.Lock()
-			if ready.Body != nil {
-				// Copy the body into the buffer to avoid
-				// holding a mutex lock during the write.
-				_, err := bodyBuf.Write(ready.Body)
-				if err != nil {
-					ready.Body = nil
-					ready.mu.Unlock()
-					s.logger.Printf("[ERR] yamux: Failed to copy body into buffer: %v", err)
-					asyncSendErr(ready.Err, err)
-					return err
-				}
-				ready.Body = nil
-			}
-			ready.mu.Unlock()
-
-			if bodyBuf.Len() > 0 {
-				// Send data from a body if given
-				_, err := s.conn.Write(bodyBuf.Bytes())
-				if err != nil {
-					s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
-					asyncSendErr(ready.Err, err)
-					return err
-				}
-			}
-
-			// No error, successful send
-			asyncSendErr(ready.Err, nil)
 		case <-s.shutdownCh:
-			return nil
+			return
 		}
 	}
 }
@@ -552,14 +559,14 @@ func (s *Session) recvLoop() error {
 		// Read the header
 		if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
 			if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
-				s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
+				s.logger.Errorf("failed to read stream header: %v", err)
 			}
 			return err
 		}
 
 		// Verify the version
 		if hdr.Version() != protoVersion {
-			s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
+			s.logger.Errorf("invalid stream protocol version: %d", hdr.Version())
 			return ErrInvalidVersion
 		}
 
@@ -594,13 +601,13 @@ func (s *Session) handleStreamMessage(hdr header) error {
 	if stream == nil {
 		// Drain any data on the wire
 		if hdr.MsgType() == typeData && hdr.Length() > 0 {
-			s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
-			if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
-				s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
+			s.logger.Warnf("discarding data for stream: %d", id)
+			if _, err := io.CopyN(io.Discard, s.bufRead, int64(hdr.Length())); err != nil {
+				s.logger.Errorf("failed to discard stream %d data: %v", id, err)
 				return nil
 			}
 		} else {
-			s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
+			s.logger.Warnf("frame for missing stream: %v", hdr)
 		}
 		return nil
 	}
@@ -609,7 +616,7 @@ func (s *Session) handleStreamMessage(hdr header) error {
 	if hdr.MsgType() == typeWindowUpdate {
 		if err := stream.incrSendWindow(hdr, flags); err != nil {
 			if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
-				s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+				s.logger.Warnf("failed to send go away: %v", sendErr)
 			}
 			return err
 		}
@@ -619,14 +626,14 @@ func (s *Session) handleStreamMessage(hdr header) error {
 	// Read the new data
 	if err := stream.readData(hdr, flags, s.bufRead); err != nil {
 		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
-			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+			s.logger.Warnf("failed to send go away: %v", sendErr)
 		}
 		return err
 	}
 	return nil
 }
 
-// handlePing is invokde for a typePing frame
+// handlePing is invoked for a typePing frame
 func (s *Session) handlePing(hdr header) error {
 	flags := hdr.Flags()
 	pingID := hdr.Length()
@@ -638,7 +645,7 @@ func (s *Session) handlePing(hdr header) error {
 			hdr := header(make([]byte, headerSize))
 			hdr.encode(typePing, flagACK, 0, pingID)
 			if err := s.sendNoWait(hdr); err != nil {
-				s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
+				s.logger.Warnf("stream %s failed to send ping reply: %v", hdr.StreamID(), err)
 			}
 		}()
 		return nil
@@ -655,20 +662,20 @@ func (s *Session) handlePing(hdr header) error {
 	return nil
 }
 
-// handleGoAway is invokde for a typeGoAway frame
+// handleGoAway is invoked for a typeGoAway frame
 func (s *Session) handleGoAway(hdr header) error {
 	code := hdr.Length()
 	switch code {
 	case goAwayNormal:
 		atomic.SwapInt32(&s.remoteGoAway, 1)
 	case goAwayProtoErr:
-		s.logger.Printf("[ERR] yamux: received protocol error go away")
+		s.logger.Errorf("received protocol error go away")
 		return fmt.Errorf("yamux protocol error")
 	case goAwayInternalErr:
-		s.logger.Printf("[ERR] yamux: received internal error go away")
+		s.logger.Errorf("received internal error go away")
 		return fmt.Errorf("remote yamux internal error")
 	default:
-		s.logger.Printf("[ERR] yamux: received unexpected go away")
+		s.logger.Errorf("received unexpected go away")
 		return fmt.Errorf("unexpected go away received")
 	}
 	return nil
@@ -691,9 +698,9 @@ func (s *Session) incomingStream(id uint32) error {
 
 	// Check if stream already exists
 	if _, ok := s.streams[id]; ok {
-		s.logger.Printf("[ERR] yamux: duplicate stream declared")
+		s.logger.Errorf("duplicate stream declared")
 		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
-			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+			s.logger.Warnf("failed to send go away: %v", sendErr)
 		}
 		return ErrDuplicateStream
 	}
@@ -707,7 +714,7 @@ func (s *Session) incomingStream(id uint32) error {
 		return nil
 	default:
 		// Backlog exceeded! RST the stream
-		s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
+		s.logger.Warnf("backlog exceeded, forcing connection reset")
 		delete(s.streams, id)
 		hdr := header(make([]byte, headerSize))
 		hdr.encode(typeWindowUpdate, flagRST, id, 0)
@@ -724,7 +731,7 @@ func (s *Session) closeStream(id uint32) {
 		select {
 		case <-s.synCh:
 		default:
-			s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
+			s.logger.Errorf("SYN tracking out of sync")
 		}
 	}
 	delete(s.streams, id)
@@ -738,12 +745,12 @@ func (s *Session) establishStream(id uint32) {
 	if _, ok := s.inflight[id]; ok {
 		delete(s.inflight, id)
 	} else {
-		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
+		s.logger.Errorf("established stream without inflight SYN (no tracking entry)")
 	}
 	select {
 	case <-s.synCh:
 	default:
-		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
+		s.logger.Errorf("established stream without inflight SYN (didn't have semaphore)")
 	}
 	s.streamLock.Unlock()
 }

+ 1 - 12
session_test.go

@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
-	"log"
 	"net"
 	"reflect"
 	"runtime"
@@ -28,7 +27,7 @@ func (l *logCapture) match(expect []string) bool {
 
 func captureLogs(s *Session) *logCapture {
 	buf := new(logCapture)
-	s.logger = log.New(buf, "", 0)
+	s.logger = &discordLogger{}
 	return buf
 }
 
@@ -285,8 +284,6 @@ func TestOpenStreamTimeout(t *testing.T) {
 	defer client.Close()
 	defer server.Close()
 
-	clientLogs := captureLogs(client)
-
 	// Open a single stream without a server to acknowledge it.
 	s, err := client.OpenStream()
 	if err != nil {
@@ -297,9 +294,6 @@ func TestOpenStreamTimeout(t *testing.T) {
 	// Since no ACKs are received, the stream and session should be closed.
 	time.Sleep(timeout * 5)
 
-	if !clientLogs.match([]string{"[ERR] yamux: aborted stream open (destination=yamux:remote): i/o deadline reached"}) {
-		t.Fatalf("server log incorect: %v", clientLogs.logs())
-	}
 	if s.state != streamClosed {
 		t.Fatalf("stream should have been closed")
 	}
@@ -1083,7 +1077,6 @@ func TestKeepAlive_Timeout(t *testing.T) {
 	defer server.Close()
 
 	_ = captureLogs(client) // Client logs aren't part of the test
-	serverLogs := captureLogs(server)
 
 	errCh := make(chan error, 1)
 	go func() {
@@ -1109,10 +1102,6 @@ func TestKeepAlive_Timeout(t *testing.T) {
 	if !server.IsClosed() {
 		t.Fatalf("server should have closed")
 	}
-
-	if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
-		t.Fatalf("server log incorect: %v", serverLogs.logs())
-	}
 }
 
 func TestLargeWindow(t *testing.T) {

+ 28 - 10
stream.go

@@ -431,7 +431,7 @@ func (s *Stream) processFlags(flags uint16) error {
 			closeStream = true
 			s.notifyWaiting()
 		default:
-			s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
+			s.session.logger.Errorf("unexpected FIN flag in state %d", s.state)
 			return ErrUnexpectedFlag
 		}
 	}
@@ -463,17 +463,20 @@ func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
 }
 
 // readData is used to handle a data frame
-func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
-	if err := s.processFlags(flags); err != nil {
-		return err
+func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) (err error) {
+	var (
+		nr           int
+		buf          []byte
+		copiedLength int
+	)
+	if err = s.processFlags(flags); err != nil {
+		return
 	}
-
 	// Check that our recv window is not exceeded
 	length := hdr.Length()
 	if length == 0 {
 		return nil
 	}
-
 	// Wrap in a limited reader
 	conn = &io.LimitedReader{R: conn, N: int64(length)}
 
@@ -481,7 +484,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
 	s.recvLock.Lock()
 
 	if length > s.recvWindow {
-		s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
+		s.session.logger.Errorf("receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
 		s.recvLock.Unlock()
 		return ErrRecvWindowExceeded
 	}
@@ -491,9 +494,24 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
 		// This way we can read in the whole packet without further allocations.
 		s.recvBuf = bytes.NewBuffer(make([]byte, 0, length))
 	}
-	copiedLength, err := io.Copy(s.recvBuf, conn)
-	if err != nil {
-		s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
+	buf = getBytes(int(length))
+	defer putBytes(buf)
+
+	if nr, err = io.ReadFull(conn, buf); err != nil {
+		s.session.logger.Errorf("failed to read stream %d data: %v", s.id, err)
+		return err
+	}
+	if uint32(nr) != length {
+		return io.ErrShortBuffer
+	}
+	if s.session.config.Crypto != nil {
+		if buf, err = s.session.config.Crypto.Decrypt(buf); err != nil {
+			s.session.logger.Errorf("failed to decrypt stream %d data: %v", s.id, err)
+			return err
+		}
+	}
+	if copiedLength, err = s.recvBuf.Write(buf); err != nil {
+		s.session.logger.Errorf("failed to read stream %d data: %v", s.id, err)
 		s.recvLock.Unlock()
 		return err
 	}

+ 19 - 3
util.go

@@ -7,9 +7,25 @@ import (
 
 // Logger is a abstract of *log.Logger
 type Logger interface {
-	Print(v ...interface{})
-	Printf(format string, v ...interface{})
-	Println(v ...interface{})
+	Debugf(format string, args ...interface{})
+	Infof(format string, args ...interface{})
+	Warnf(format string, args ...interface{})
+	Errorf(format string, args ...interface{})
+}
+
+type discordLogger struct {
+}
+
+func (lg *discordLogger) Debugf(format string, args ...interface{}) {
+}
+
+func (lg *discordLogger) Infof(format string, args ...interface{}) {
+}
+
+func (lg *discordLogger) Warnf(format string, args ...interface{}) {
+}
+
+func (lg *discordLogger) Errorf(format string, args ...interface{}) {
 }
 
 var (