|
@@ -6,8 +6,6 @@ import (
|
|
"context"
|
|
"context"
|
|
"fmt"
|
|
"fmt"
|
|
"io"
|
|
"io"
|
|
- "io/ioutil"
|
|
|
|
- "log"
|
|
|
|
"math"
|
|
"math"
|
|
"net"
|
|
"net"
|
|
"strings"
|
|
"strings"
|
|
@@ -93,7 +91,7 @@ type sendReady struct {
|
|
func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
|
|
func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
|
|
logger := config.Logger
|
|
logger := config.Logger
|
|
if logger == nil {
|
|
if logger == nil {
|
|
- logger = log.New(config.LogOutput, "", log.LstdFlags)
|
|
|
|
|
|
+ logger = &discordLogger{}
|
|
}
|
|
}
|
|
|
|
|
|
s := &Session{
|
|
s := &Session{
|
|
@@ -199,7 +197,7 @@ GET_ID:
|
|
select {
|
|
select {
|
|
case <-s.synCh:
|
|
case <-s.synCh:
|
|
default:
|
|
default:
|
|
- s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
|
|
|
|
|
|
+ s.logger.Errorf("aborted stream open without inflight syn semaphore")
|
|
}
|
|
}
|
|
return nil, err
|
|
return nil, err
|
|
}
|
|
}
|
|
@@ -222,7 +220,7 @@ func (s *Session) setOpenTimeout(stream *Stream) {
|
|
case <-timer.C:
|
|
case <-timer.C:
|
|
// Timeout reached while waiting for ACK.
|
|
// Timeout reached while waiting for ACK.
|
|
// Close the session to force connection re-establishment.
|
|
// Close the session to force connection re-establishment.
|
|
- s.logger.Printf("[ERR] yamux: aborted stream open (destination=%s): %v", s.RemoteAddr().String(), ErrTimeout.err)
|
|
|
|
|
|
+ s.logger.Errorf("aborted stream open (destination=%s): %v", s.RemoteAddr().String(), ErrTimeout.err)
|
|
s.Close()
|
|
s.Close()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -368,7 +366,7 @@ func (s *Session) keepalive() {
|
|
_, err := s.Ping()
|
|
_, err := s.Ping()
|
|
if err != nil {
|
|
if err != nil {
|
|
if err != ErrSessionShutdown {
|
|
if err != ErrSessionShutdown {
|
|
- s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
|
|
|
|
|
|
+ s.logger.Errorf("keepalive failed: %v", err)
|
|
s.exitErr(ErrKeepAliveTimeout)
|
|
s.exitErr(ErrKeepAliveTimeout)
|
|
}
|
|
}
|
|
return
|
|
return
|
|
@@ -475,54 +473,63 @@ func (s *Session) send() {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-func (s *Session) sendLoop() error {
|
|
|
|
|
|
+func (s *Session) sendPacket(packet *sendReady) (err error) {
|
|
|
|
+ var (
|
|
|
|
+ n int
|
|
|
|
+ nw int
|
|
|
|
+ buf []byte
|
|
|
|
+ buffer *bytes.Buffer
|
|
|
|
+ )
|
|
|
|
+ buffer = getBuffer()
|
|
|
|
+ defer func() {
|
|
|
|
+ putBuffer(buffer)
|
|
|
|
+ }()
|
|
|
|
+ packet.mu.Lock()
|
|
|
|
+ // Send a header if ready
|
|
|
|
+ if packet.Hdr != nil {
|
|
|
|
+ n, _ = buffer.Write(packet.Hdr)
|
|
|
|
+ nw += n
|
|
|
|
+ }
|
|
|
|
+ if packet.Body != nil {
|
|
|
|
+ if s.config.Crypto != nil {
|
|
|
|
+ if buf, err = s.config.Crypto.Encrypt(packet.Body); err != nil {
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ } else {
|
|
|
|
+ buf = packet.Body
|
|
|
|
+ }
|
|
|
|
+ n, _ = buffer.Write(buf)
|
|
|
|
+ nw += n
|
|
|
|
+ packet.Body = nil
|
|
|
|
+ }
|
|
|
|
+ packet.mu.Unlock()
|
|
|
|
+
|
|
|
|
+ if buffer.Len() > 0 {
|
|
|
|
+ // Send data from a body if given
|
|
|
|
+ if n, err = s.conn.Write(buffer.Bytes()); err != nil {
|
|
|
|
+ asyncSendErr(packet.Err, err)
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ if n != nw {
|
|
|
|
+ asyncSendErr(packet.Err, io.ErrShortWrite)
|
|
|
|
+ return io.ErrShortWrite
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // No error, successful send
|
|
|
|
+ asyncSendErr(packet.Err, nil)
|
|
|
|
+ return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (s *Session) sendLoop() (err error) {
|
|
defer close(s.sendDoneCh)
|
|
defer close(s.sendDoneCh)
|
|
- var bodyBuf bytes.Buffer
|
|
|
|
for {
|
|
for {
|
|
- bodyBuf.Reset()
|
|
|
|
-
|
|
|
|
select {
|
|
select {
|
|
case ready := <-s.sendCh:
|
|
case ready := <-s.sendCh:
|
|
- // Send a header if ready
|
|
|
|
- if ready.Hdr != nil {
|
|
|
|
- _, err := s.conn.Write(ready.Hdr)
|
|
|
|
- if err != nil {
|
|
|
|
- s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
|
|
|
|
- asyncSendErr(ready.Err, err)
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
|
|
+ if err = s.sendPacket(ready); err != nil {
|
|
|
|
+ return err
|
|
}
|
|
}
|
|
-
|
|
|
|
- ready.mu.Lock()
|
|
|
|
- if ready.Body != nil {
|
|
|
|
- // Copy the body into the buffer to avoid
|
|
|
|
- // holding a mutex lock during the write.
|
|
|
|
- _, err := bodyBuf.Write(ready.Body)
|
|
|
|
- if err != nil {
|
|
|
|
- ready.Body = nil
|
|
|
|
- ready.mu.Unlock()
|
|
|
|
- s.logger.Printf("[ERR] yamux: Failed to copy body into buffer: %v", err)
|
|
|
|
- asyncSendErr(ready.Err, err)
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
- ready.Body = nil
|
|
|
|
- }
|
|
|
|
- ready.mu.Unlock()
|
|
|
|
-
|
|
|
|
- if bodyBuf.Len() > 0 {
|
|
|
|
- // Send data from a body if given
|
|
|
|
- _, err := s.conn.Write(bodyBuf.Bytes())
|
|
|
|
- if err != nil {
|
|
|
|
- s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
|
|
|
|
- asyncSendErr(ready.Err, err)
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- // No error, successful send
|
|
|
|
- asyncSendErr(ready.Err, nil)
|
|
|
|
case <-s.shutdownCh:
|
|
case <-s.shutdownCh:
|
|
- return nil
|
|
|
|
|
|
+ return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -552,14 +559,14 @@ func (s *Session) recvLoop() error {
|
|
// Read the header
|
|
// Read the header
|
|
if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
|
|
if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
|
|
if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
|
|
if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
|
|
- s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
|
|
|
|
|
|
+ s.logger.Errorf("failed to read stream header: %v", err)
|
|
}
|
|
}
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
|
|
|
|
// Verify the version
|
|
// Verify the version
|
|
if hdr.Version() != protoVersion {
|
|
if hdr.Version() != protoVersion {
|
|
- s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
|
|
|
|
|
|
+ s.logger.Errorf("invalid stream protocol version: %d", hdr.Version())
|
|
return ErrInvalidVersion
|
|
return ErrInvalidVersion
|
|
}
|
|
}
|
|
|
|
|
|
@@ -594,13 +601,13 @@ func (s *Session) handleStreamMessage(hdr header) error {
|
|
if stream == nil {
|
|
if stream == nil {
|
|
// Drain any data on the wire
|
|
// Drain any data on the wire
|
|
if hdr.MsgType() == typeData && hdr.Length() > 0 {
|
|
if hdr.MsgType() == typeData && hdr.Length() > 0 {
|
|
- s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
|
|
|
|
- if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
|
|
|
|
- s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
|
|
|
|
|
|
+ s.logger.Warnf("discarding data for stream: %d", id)
|
|
|
|
+ if _, err := io.CopyN(io.Discard, s.bufRead, int64(hdr.Length())); err != nil {
|
|
|
|
+ s.logger.Errorf("failed to discard stream %d data: %v", id, err)
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
} else {
|
|
} else {
|
|
- s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
|
|
|
|
|
|
+ s.logger.Warnf("frame for missing stream: %v", hdr)
|
|
}
|
|
}
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
@@ -609,7 +616,7 @@ func (s *Session) handleStreamMessage(hdr header) error {
|
|
if hdr.MsgType() == typeWindowUpdate {
|
|
if hdr.MsgType() == typeWindowUpdate {
|
|
if err := stream.incrSendWindow(hdr, flags); err != nil {
|
|
if err := stream.incrSendWindow(hdr, flags); err != nil {
|
|
if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
|
|
if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
|
|
- s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
|
|
|
|
|
|
+ s.logger.Warnf("failed to send go away: %v", sendErr)
|
|
}
|
|
}
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
@@ -619,14 +626,14 @@ func (s *Session) handleStreamMessage(hdr header) error {
|
|
// Read the new data
|
|
// Read the new data
|
|
if err := stream.readData(hdr, flags, s.bufRead); err != nil {
|
|
if err := stream.readData(hdr, flags, s.bufRead); err != nil {
|
|
if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
|
|
if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
|
|
- s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
|
|
|
|
|
|
+ s.logger.Warnf("failed to send go away: %v", sendErr)
|
|
}
|
|
}
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
-// handlePing is invokde for a typePing frame
|
|
|
|
|
|
+// handlePing is invoked for a typePing frame
|
|
func (s *Session) handlePing(hdr header) error {
|
|
func (s *Session) handlePing(hdr header) error {
|
|
flags := hdr.Flags()
|
|
flags := hdr.Flags()
|
|
pingID := hdr.Length()
|
|
pingID := hdr.Length()
|
|
@@ -638,7 +645,7 @@ func (s *Session) handlePing(hdr header) error {
|
|
hdr := header(make([]byte, headerSize))
|
|
hdr := header(make([]byte, headerSize))
|
|
hdr.encode(typePing, flagACK, 0, pingID)
|
|
hdr.encode(typePing, flagACK, 0, pingID)
|
|
if err := s.sendNoWait(hdr); err != nil {
|
|
if err := s.sendNoWait(hdr); err != nil {
|
|
- s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
|
|
|
|
|
|
+ s.logger.Warnf("stream %s failed to send ping reply: %v", hdr.StreamID(), err)
|
|
}
|
|
}
|
|
}()
|
|
}()
|
|
return nil
|
|
return nil
|
|
@@ -655,20 +662,20 @@ func (s *Session) handlePing(hdr header) error {
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
-// handleGoAway is invokde for a typeGoAway frame
|
|
|
|
|
|
+// handleGoAway is invoked for a typeGoAway frame
|
|
func (s *Session) handleGoAway(hdr header) error {
|
|
func (s *Session) handleGoAway(hdr header) error {
|
|
code := hdr.Length()
|
|
code := hdr.Length()
|
|
switch code {
|
|
switch code {
|
|
case goAwayNormal:
|
|
case goAwayNormal:
|
|
atomic.SwapInt32(&s.remoteGoAway, 1)
|
|
atomic.SwapInt32(&s.remoteGoAway, 1)
|
|
case goAwayProtoErr:
|
|
case goAwayProtoErr:
|
|
- s.logger.Printf("[ERR] yamux: received protocol error go away")
|
|
|
|
|
|
+ s.logger.Errorf("received protocol error go away")
|
|
return fmt.Errorf("yamux protocol error")
|
|
return fmt.Errorf("yamux protocol error")
|
|
case goAwayInternalErr:
|
|
case goAwayInternalErr:
|
|
- s.logger.Printf("[ERR] yamux: received internal error go away")
|
|
|
|
|
|
+ s.logger.Errorf("received internal error go away")
|
|
return fmt.Errorf("remote yamux internal error")
|
|
return fmt.Errorf("remote yamux internal error")
|
|
default:
|
|
default:
|
|
- s.logger.Printf("[ERR] yamux: received unexpected go away")
|
|
|
|
|
|
+ s.logger.Errorf("received unexpected go away")
|
|
return fmt.Errorf("unexpected go away received")
|
|
return fmt.Errorf("unexpected go away received")
|
|
}
|
|
}
|
|
return nil
|
|
return nil
|
|
@@ -691,9 +698,9 @@ func (s *Session) incomingStream(id uint32) error {
|
|
|
|
|
|
// Check if stream already exists
|
|
// Check if stream already exists
|
|
if _, ok := s.streams[id]; ok {
|
|
if _, ok := s.streams[id]; ok {
|
|
- s.logger.Printf("[ERR] yamux: duplicate stream declared")
|
|
|
|
|
|
+ s.logger.Errorf("duplicate stream declared")
|
|
if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
|
|
if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
|
|
- s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
|
|
|
|
|
|
+ s.logger.Warnf("failed to send go away: %v", sendErr)
|
|
}
|
|
}
|
|
return ErrDuplicateStream
|
|
return ErrDuplicateStream
|
|
}
|
|
}
|
|
@@ -707,7 +714,7 @@ func (s *Session) incomingStream(id uint32) error {
|
|
return nil
|
|
return nil
|
|
default:
|
|
default:
|
|
// Backlog exceeded! RST the stream
|
|
// Backlog exceeded! RST the stream
|
|
- s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
|
|
|
|
|
|
+ s.logger.Warnf("backlog exceeded, forcing connection reset")
|
|
delete(s.streams, id)
|
|
delete(s.streams, id)
|
|
hdr := header(make([]byte, headerSize))
|
|
hdr := header(make([]byte, headerSize))
|
|
hdr.encode(typeWindowUpdate, flagRST, id, 0)
|
|
hdr.encode(typeWindowUpdate, flagRST, id, 0)
|
|
@@ -724,7 +731,7 @@ func (s *Session) closeStream(id uint32) {
|
|
select {
|
|
select {
|
|
case <-s.synCh:
|
|
case <-s.synCh:
|
|
default:
|
|
default:
|
|
- s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
|
|
|
|
|
|
+ s.logger.Errorf("SYN tracking out of sync")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
delete(s.streams, id)
|
|
delete(s.streams, id)
|
|
@@ -738,12 +745,12 @@ func (s *Session) establishStream(id uint32) {
|
|
if _, ok := s.inflight[id]; ok {
|
|
if _, ok := s.inflight[id]; ok {
|
|
delete(s.inflight, id)
|
|
delete(s.inflight, id)
|
|
} else {
|
|
} else {
|
|
- s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
|
|
|
|
|
|
+ s.logger.Errorf("established stream without inflight SYN (no tracking entry)")
|
|
}
|
|
}
|
|
select {
|
|
select {
|
|
case <-s.synCh:
|
|
case <-s.synCh:
|
|
default:
|
|
default:
|
|
- s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
|
|
|
|
|
|
+ s.logger.Errorf("established stream without inflight SYN (didn't have semaphore)")
|
|
}
|
|
}
|
|
s.streamLock.Unlock()
|
|
s.streamLock.Unlock()
|
|
}
|
|
}
|