session.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678
  1. package yamux
  2. import (
  3. "bufio"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "log"
  8. "math"
  9. "net"
  10. "strings"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. )
  15. // Session is used to wrap a reliable ordered connection and to
  16. // multiplex it into multiple streams.
  17. type Session struct {
  18. // remoteGoAway indicates the remote side does
  19. // not want futher connections. Must be first for alignment.
  20. remoteGoAway int32
  21. // localGoAway indicates that we should stop
  22. // accepting futher connections. Must be first for alignment.
  23. localGoAway int32
  24. // nextStreamID is the next stream we should
  25. // send. This depends if we are a client/server.
  26. nextStreamID uint32
  27. // config holds our configuration
  28. config *Config
  29. // logger is used for our logs
  30. logger *log.Logger
  31. // conn is the underlying connection
  32. conn io.ReadWriteCloser
  33. // bufRead is a buffered reader
  34. bufRead *bufio.Reader
  35. // pings is used to track inflight pings
  36. pings map[uint32]chan struct{}
  37. pingID uint32
  38. pingLock sync.Mutex
  39. // streams maps a stream id to a stream, and inflight has an entry
  40. // for any outgoing stream that has not yet been established. Both are
  41. // protected by streamLock.
  42. streams map[uint32]*Stream
  43. inflight map[uint32]struct{}
  44. streamLock sync.Mutex
  45. // synCh acts like a semaphore. It is sized to the AcceptBacklog which
  46. // is assumed to be symmetric between the client and server. This allows
  47. // the client to avoid exceeding the backlog and instead blocks the open.
  48. synCh chan struct{}
  49. // acceptCh is used to pass ready streams to the client
  50. acceptCh chan *Stream
  51. // sendCh is used to mark a stream as ready to send,
  52. // or to send a header out directly.
  53. sendCh chan sendReady
  54. // recvDoneCh is closed when recv() exits to avoid a race
  55. // between stream registration and stream shutdown
  56. recvDoneCh chan struct{}
  57. // shutdown is used to safely close a session
  58. shutdown bool
  59. shutdownErr error
  60. shutdownCh chan struct{}
  61. shutdownLock sync.Mutex
  62. }
  // sendReady is one unit of work for the send loop. Hdr and Body are
  // written to the connection in that order when they are non-nil, and
  // Err, when set, is the channel handed to asyncSendErr together with
  // the outcome of the write.
  type sendReady struct {
  	Hdr, Body []byte
  	Err       chan error
  }
  70. // newSession is used to construct a new session
  71. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  72. logger := config.Logger
  73. if logger == nil {
  74. logger = log.New(config.LogOutput, "", log.LstdFlags)
  75. }
  76. s := &Session{
  77. config: config,
  78. logger: logger,
  79. conn: conn,
  80. bufRead: bufio.NewReader(conn),
  81. pings: make(map[uint32]chan struct{}),
  82. streams: make(map[uint32]*Stream),
  83. inflight: make(map[uint32]struct{}),
  84. synCh: make(chan struct{}, config.AcceptBacklog),
  85. acceptCh: make(chan *Stream, config.AcceptBacklog),
  86. sendCh: make(chan sendReady, 64),
  87. recvDoneCh: make(chan struct{}),
  88. shutdownCh: make(chan struct{}),
  89. }
  90. if client {
  91. s.nextStreamID = 1
  92. } else {
  93. s.nextStreamID = 2
  94. }
  95. go s.recv()
  96. go s.send()
  97. if config.EnableKeepAlive {
  98. go s.keepalive()
  99. }
  100. return s
  101. }
  102. // IsClosed does a safe check to see if we have shutdown
  103. func (s *Session) IsClosed() bool {
  104. select {
  105. case <-s.shutdownCh:
  106. return true
  107. default:
  108. return false
  109. }
  110. }
  111. // CloseChan returns a read-only channel which is closed as
  112. // soon as the session is closed.
  113. func (s *Session) CloseChan() <-chan struct{} {
  114. return s.shutdownCh
  115. }
  116. // NumStreams returns the number of currently open streams
  117. func (s *Session) NumStreams() int {
  118. s.streamLock.Lock()
  119. num := len(s.streams)
  120. s.streamLock.Unlock()
  121. return num
  122. }
  123. // Open is used to create a new stream as a net.Conn
  124. func (s *Session) Open() (net.Conn, error) {
  125. conn, err := s.OpenStream()
  126. if err != nil {
  127. return nil, err
  128. }
  129. return conn, nil
  130. }
  131. // OpenStream is used to create a new stream
  132. func (s *Session) OpenStream() (*Stream, error) {
  133. if s.IsClosed() {
  134. return nil, ErrSessionShutdown
  135. }
  136. if atomic.LoadInt32(&s.remoteGoAway) == 1 {
  137. return nil, ErrRemoteGoAway
  138. }
  139. // Block if we have too many inflight SYNs
  140. select {
  141. case s.synCh <- struct{}{}:
  142. case <-s.shutdownCh:
  143. return nil, ErrSessionShutdown
  144. }
  145. GET_ID:
  146. // Get an ID, and check for stream exhaustion
  147. id := atomic.LoadUint32(&s.nextStreamID)
  148. if id >= math.MaxUint32-1 {
  149. return nil, ErrStreamsExhausted
  150. }
  151. if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
  152. goto GET_ID
  153. }
  154. // Register the stream
  155. stream := newStream(s, id, streamInit)
  156. s.streamLock.Lock()
  157. s.streams[id] = stream
  158. s.inflight[id] = struct{}{}
  159. s.streamLock.Unlock()
  160. if s.config.StreamOpenTimeout > 0 {
  161. go s.setOpenTimeout(stream)
  162. }
  163. // Send the window update to create
  164. if err := stream.sendWindowUpdate(); err != nil {
  165. select {
  166. case <-s.synCh:
  167. default:
  168. s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
  169. }
  170. return nil, err
  171. }
  172. return stream, nil
  173. }
  174. // setOpenTimeout implements a timeout for streams that are opened but not established.
  175. // If the StreamOpenTimeout is exceeded we assume the peer is unable to ACK,
  176. // and close the session.
  177. // The number of running timers is bounded by the capacity of the synCh.
  178. func (s *Session) setOpenTimeout(stream *Stream) {
  179. timer := time.NewTimer(s.config.StreamOpenTimeout)
  180. defer timer.Stop()
  181. select {
  182. case <-stream.establishCh:
  183. return
  184. case <-s.shutdownCh:
  185. return
  186. case <-timer.C:
  187. // Timeout reached while waiting for ACK.
  188. // Close the session to force connection re-establishment.
  189. s.logger.Printf("[ERR] yamux: aborted stream open (destination=%s): %v", s.RemoteAddr().String(), ErrTimeout.err)
  190. s.Close()
  191. }
  192. }
  193. // Accept is used to block until the next available stream
  194. // is ready to be accepted.
  195. func (s *Session) Accept() (net.Conn, error) {
  196. conn, err := s.AcceptStream()
  197. if err != nil {
  198. return nil, err
  199. }
  200. return conn, err
  201. }
  202. // AcceptStream is used to block until the next available stream
  203. // is ready to be accepted.
  204. func (s *Session) AcceptStream() (*Stream, error) {
  205. select {
  206. case stream := <-s.acceptCh:
  207. if err := stream.sendWindowUpdate(); err != nil {
  208. return nil, err
  209. }
  210. return stream, nil
  211. case <-s.shutdownCh:
  212. return nil, s.shutdownErr
  213. }
  214. }
  215. // Close is used to close the session and all streams.
  216. // Attempts to send a GoAway before closing the connection.
  217. func (s *Session) Close() error {
  218. s.shutdownLock.Lock()
  219. defer s.shutdownLock.Unlock()
  220. if s.shutdown {
  221. return nil
  222. }
  223. s.shutdown = true
  224. if s.shutdownErr == nil {
  225. s.shutdownErr = ErrSessionShutdown
  226. }
  227. close(s.shutdownCh)
  228. s.conn.Close()
  229. <-s.recvDoneCh
  230. s.streamLock.Lock()
  231. defer s.streamLock.Unlock()
  232. for _, stream := range s.streams {
  233. stream.forceClose()
  234. }
  235. return nil
  236. }
  237. // exitErr is used to handle an error that is causing the
  238. // session to terminate.
  239. func (s *Session) exitErr(err error) {
  240. s.shutdownLock.Lock()
  241. if s.shutdownErr == nil {
  242. s.shutdownErr = err
  243. }
  244. s.shutdownLock.Unlock()
  245. s.Close()
  246. }
  247. // GoAway can be used to prevent accepting further
  248. // connections. It does not close the underlying conn.
  249. func (s *Session) GoAway() error {
  250. return s.waitForSend(s.goAway(goAwayNormal), nil)
  251. }
  252. // goAway is used to send a goAway message
  253. func (s *Session) goAway(reason uint32) header {
  254. atomic.SwapInt32(&s.localGoAway, 1)
  255. hdr := header(make([]byte, headerSize))
  256. hdr.encode(typeGoAway, 0, 0, reason)
  257. return hdr
  258. }
  259. // Ping is used to measure the RTT response time
  260. func (s *Session) Ping() (time.Duration, error) {
  261. // Get a channel for the ping
  262. ch := make(chan struct{})
  263. // Get a new ping id, mark as pending
  264. s.pingLock.Lock()
  265. id := s.pingID
  266. s.pingID++
  267. s.pings[id] = ch
  268. s.pingLock.Unlock()
  269. // Send the ping request
  270. hdr := header(make([]byte, headerSize))
  271. hdr.encode(typePing, flagSYN, 0, id)
  272. if err := s.waitForSend(hdr, nil); err != nil {
  273. return 0, err
  274. }
  275. // Wait for a response
  276. start := time.Now()
  277. select {
  278. case <-ch:
  279. case <-time.After(s.config.ConnectionWriteTimeout):
  280. s.pingLock.Lock()
  281. delete(s.pings, id) // Ignore it if a response comes later.
  282. s.pingLock.Unlock()
  283. return 0, ErrTimeout
  284. case <-s.shutdownCh:
  285. return 0, ErrSessionShutdown
  286. }
  287. // Compute the RTT
  288. return time.Now().Sub(start), nil
  289. }
  290. // keepalive is a long running goroutine that periodically does
  291. // a ping to keep the connection alive.
  292. func (s *Session) keepalive() {
  293. for {
  294. select {
  295. case <-time.After(s.config.KeepAliveInterval):
  296. _, err := s.Ping()
  297. if err != nil {
  298. if err != ErrSessionShutdown {
  299. s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
  300. s.exitErr(ErrKeepAliveTimeout)
  301. }
  302. return
  303. }
  304. case <-s.shutdownCh:
  305. return
  306. }
  307. }
  308. }
  309. // waitForSendErr waits to send a header, checking for a potential shutdown
  310. func (s *Session) waitForSend(hdr header, body []byte) error {
  311. errCh := make(chan error, 1)
  312. return s.waitForSendErr(hdr, body, errCh)
  313. }
  314. // waitForSendErr waits to send a header with optional data, checking for a
  315. // potential shutdown. Since there's the expectation that sends can happen
  316. // in a timely manner, we enforce the connection write timeout here.
  317. func (s *Session) waitForSendErr(hdr header, body []byte, errCh chan error) error {
  318. t := timerPool.Get()
  319. timer := t.(*time.Timer)
  320. timer.Reset(s.config.ConnectionWriteTimeout)
  321. defer func() {
  322. timer.Stop()
  323. select {
  324. case <-timer.C:
  325. default:
  326. }
  327. timerPool.Put(t)
  328. }()
  329. ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
  330. select {
  331. case s.sendCh <- ready:
  332. case <-s.shutdownCh:
  333. return ErrSessionShutdown
  334. case <-timer.C:
  335. return ErrConnectionWriteTimeout
  336. }
  337. select {
  338. case err := <-errCh:
  339. return err
  340. case <-s.shutdownCh:
  341. return ErrSessionShutdown
  342. case <-timer.C:
  343. return ErrConnectionWriteTimeout
  344. }
  345. }
  346. // sendNoWait does a send without waiting. Since there's the expectation that
  347. // the send happens right here, we enforce the connection write timeout if we
  348. // can't queue the header to be sent.
  349. func (s *Session) sendNoWait(hdr header) error {
  350. t := timerPool.Get()
  351. timer := t.(*time.Timer)
  352. timer.Reset(s.config.ConnectionWriteTimeout)
  353. defer func() {
  354. timer.Stop()
  355. select {
  356. case <-timer.C:
  357. default:
  358. }
  359. timerPool.Put(t)
  360. }()
  361. select {
  362. case s.sendCh <- sendReady{Hdr: hdr}:
  363. return nil
  364. case <-s.shutdownCh:
  365. return ErrSessionShutdown
  366. case <-timer.C:
  367. return ErrConnectionWriteTimeout
  368. }
  369. }
  370. // send is a long running goroutine that sends data
  371. func (s *Session) send() {
  372. for {
  373. select {
  374. case ready := <-s.sendCh:
  375. // Send a header if ready
  376. if ready.Hdr != nil {
  377. sent := 0
  378. for sent < len(ready.Hdr) {
  379. n, err := s.conn.Write(ready.Hdr[sent:])
  380. if err != nil {
  381. s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
  382. asyncSendErr(ready.Err, err)
  383. s.exitErr(err)
  384. return
  385. }
  386. sent += n
  387. }
  388. }
  389. // Send data from a body if given
  390. if ready.Body != nil {
  391. _, err := s.conn.Write(ready.Body)
  392. if err != nil {
  393. s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
  394. asyncSendErr(ready.Err, err)
  395. s.exitErr(err)
  396. return
  397. }
  398. }
  399. // No error, successful send
  400. asyncSendErr(ready.Err, nil)
  401. case <-s.shutdownCh:
  402. return
  403. }
  404. }
  405. }
  406. // recv is a long running goroutine that accepts new data
  407. func (s *Session) recv() {
  408. if err := s.recvLoop(); err != nil {
  409. s.exitErr(err)
  410. }
  411. }
  412. // Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
  413. var (
  414. handlers = []func(*Session, header) error{
  415. typeData: (*Session).handleStreamMessage,
  416. typeWindowUpdate: (*Session).handleStreamMessage,
  417. typePing: (*Session).handlePing,
  418. typeGoAway: (*Session).handleGoAway,
  419. }
  420. )
  421. // recvLoop continues to receive data until a fatal error is encountered
  422. func (s *Session) recvLoop() error {
  423. defer close(s.recvDoneCh)
  424. hdr := header(make([]byte, headerSize))
  425. for {
  426. // Read the header
  427. if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
  428. if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
  429. s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
  430. }
  431. return err
  432. }
  433. // Verify the version
  434. if hdr.Version() != protoVersion {
  435. s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
  436. return ErrInvalidVersion
  437. }
  438. mt := hdr.MsgType()
  439. if mt < typeData || mt > typeGoAway {
  440. return ErrInvalidMsgType
  441. }
  442. if err := handlers[mt](s, hdr); err != nil {
  443. return err
  444. }
  445. }
  446. }
  447. // handleStreamMessage handles either a data or window update frame
  448. func (s *Session) handleStreamMessage(hdr header) error {
  449. // Check for a new stream creation
  450. id := hdr.StreamID()
  451. flags := hdr.Flags()
  452. if flags&flagSYN == flagSYN {
  453. if err := s.incomingStream(id); err != nil {
  454. return err
  455. }
  456. }
  457. // Get the stream
  458. s.streamLock.Lock()
  459. stream := s.streams[id]
  460. s.streamLock.Unlock()
  461. // If we do not have a stream, likely we sent a RST
  462. if stream == nil {
  463. // Drain any data on the wire
  464. if hdr.MsgType() == typeData && hdr.Length() > 0 {
  465. s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
  466. if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
  467. s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
  468. return nil
  469. }
  470. } else {
  471. s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
  472. }
  473. return nil
  474. }
  475. // Check if this is a window update
  476. if hdr.MsgType() == typeWindowUpdate {
  477. if err := stream.incrSendWindow(hdr, flags); err != nil {
  478. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  479. s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
  480. }
  481. return err
  482. }
  483. return nil
  484. }
  485. // Read the new data
  486. if err := stream.readData(hdr, flags, s.bufRead); err != nil {
  487. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  488. s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
  489. }
  490. return err
  491. }
  492. return nil
  493. }
  494. // handlePing is invokde for a typePing frame
  495. func (s *Session) handlePing(hdr header) error {
  496. flags := hdr.Flags()
  497. pingID := hdr.Length()
  498. // Check if this is a query, respond back in a separate context so we
  499. // don't interfere with the receiving thread blocking for the write.
  500. if flags&flagSYN == flagSYN {
  501. go func() {
  502. hdr := header(make([]byte, headerSize))
  503. hdr.encode(typePing, flagACK, 0, pingID)
  504. if err := s.sendNoWait(hdr); err != nil {
  505. s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
  506. }
  507. }()
  508. return nil
  509. }
  510. // Handle a response
  511. s.pingLock.Lock()
  512. ch := s.pings[pingID]
  513. if ch != nil {
  514. delete(s.pings, pingID)
  515. close(ch)
  516. }
  517. s.pingLock.Unlock()
  518. return nil
  519. }
  520. // handleGoAway is invokde for a typeGoAway frame
  521. func (s *Session) handleGoAway(hdr header) error {
  522. code := hdr.Length()
  523. switch code {
  524. case goAwayNormal:
  525. atomic.SwapInt32(&s.remoteGoAway, 1)
  526. case goAwayProtoErr:
  527. s.logger.Printf("[ERR] yamux: received protocol error go away")
  528. return fmt.Errorf("yamux protocol error")
  529. case goAwayInternalErr:
  530. s.logger.Printf("[ERR] yamux: received internal error go away")
  531. return fmt.Errorf("remote yamux internal error")
  532. default:
  533. s.logger.Printf("[ERR] yamux: received unexpected go away")
  534. return fmt.Errorf("unexpected go away received")
  535. }
  536. return nil
  537. }
  538. // incomingStream is used to create a new incoming stream
  539. func (s *Session) incomingStream(id uint32) error {
  540. // Reject immediately if we are doing a go away
  541. if atomic.LoadInt32(&s.localGoAway) == 1 {
  542. hdr := header(make([]byte, headerSize))
  543. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  544. return s.sendNoWait(hdr)
  545. }
  546. // Allocate a new stream
  547. stream := newStream(s, id, streamSYNReceived)
  548. s.streamLock.Lock()
  549. defer s.streamLock.Unlock()
  550. // Check if stream already exists
  551. if _, ok := s.streams[id]; ok {
  552. s.logger.Printf("[ERR] yamux: duplicate stream declared")
  553. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  554. s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
  555. }
  556. return ErrDuplicateStream
  557. }
  558. // Register the stream
  559. s.streams[id] = stream
  560. // Check if we've exceeded the backlog
  561. select {
  562. case s.acceptCh <- stream:
  563. return nil
  564. default:
  565. // Backlog exceeded! RST the stream
  566. s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
  567. delete(s.streams, id)
  568. stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
  569. return s.sendNoWait(stream.sendHdr)
  570. }
  571. }
  572. // closeStream is used to close a stream once both sides have
  573. // issued a close. If there was an in-flight SYN and the stream
  574. // was not yet established, then this will give the credit back.
  575. func (s *Session) closeStream(id uint32) {
  576. s.streamLock.Lock()
  577. if _, ok := s.inflight[id]; ok {
  578. select {
  579. case <-s.synCh:
  580. default:
  581. s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
  582. }
  583. }
  584. delete(s.streams, id)
  585. s.streamLock.Unlock()
  586. }
  587. // establishStream is used to mark a stream that was in the
  588. // SYN Sent state as established.
  589. func (s *Session) establishStream(id uint32) {
  590. s.streamLock.Lock()
  591. if _, ok := s.inflight[id]; ok {
  592. delete(s.inflight, id)
  593. } else {
  594. s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
  595. }
  596. select {
  597. case <-s.synCh:
  598. default:
  599. s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
  600. }
  601. s.streamLock.Unlock()
  602. }