session.go 18 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755
  1. package yamux
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "fmt"
  7. "io"
  8. "math"
  9. "net"
  10. "strings"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. )
  // Session is used to wrap a reliable ordered connection and to
// multiplex it into multiple streams. newSession starts its background
// goroutines (recv, send and, when enabled, keepalive); Close tears the
// session down.
type Session struct {
	// remoteGoAway is set to 1 (atomically) once the remote side sends a
	// normal GoAway, i.e. it does not want further streams; OpenStream then
	// fails with ErrRemoteGoAway. Must be first for alignment.
	// NOTE(review): per the sync/atomic docs only 64-bit atomic fields need
	// explicit placement; int32 fields are always suitably aligned.
	remoteGoAway int32

	// localGoAway is set to 1 (atomically) by goAway; from then on
	// incomingStream rejects new streams with a RST. Kept at the top of the
	// struct next to remoteGoAway.
	localGoAway int32

	// nextStreamID is the next stream ID this side will allocate. It starts
	// at 1 for clients and 2 for servers and advances by 2, so the two sides
	// never pick the same ID. Updated atomically by OpenStream.
	nextStreamID uint32

	// config holds our configuration.
	config *Config

	// logger is used for our logs.
	logger Logger

	// conn is the underlying connection.
	conn io.ReadWriteCloser

	// bufRead is a buffered reader over conn, used by the receive loop.
	bufRead *bufio.Reader

	// pings tracks in-flight pings: ping ID -> channel that is closed when
	// the matching reply arrives. pingID is the next ID to hand out. Both
	// are protected by pingLock.
	pings    map[uint32]chan struct{}
	pingID   uint32
	pingLock sync.Mutex

	// streams maps a stream id to a stream, and inflight has an entry
	// for any outgoing stream that has not yet been established. Both are
	// protected by streamLock.
	streams    map[uint32]*Stream
	inflight   map[uint32]struct{}
	streamLock sync.Mutex

	// synCh acts like a semaphore. It is sized to the AcceptBacklog, which
	// is assumed to be symmetric between the client and server. This lets
	// the opening side block instead of exceeding the peer's backlog.
	synCh chan struct{}

	// acceptCh passes streams opened by the peer to AcceptStream. It is
	// buffered to AcceptBacklog; when it is full, new streams are reset.
	acceptCh chan *Stream

	// sendCh is used to mark a stream as ready to send,
	// or to send a header out directly. It is drained by the send loop.
	sendCh chan *sendReady

	// recvDoneCh is closed when recv() exits to avoid a race
	// between stream registration and stream shutdown; sendDoneCh is
	// closed when the send loop exits. Close waits on both.
	recvDoneCh chan struct{}
	sendDoneCh chan struct{}

	// shutdown is set once by Close (under shutdownLock), which also closes
	// shutdownCh. shutdownErr records the first reason for the shutdown
	// and is protected by shutdownErrLock.
	shutdown        bool
	shutdownErr     error
	shutdownCh      chan struct{}
	shutdownLock    sync.Mutex
	shutdownErrLock sync.Mutex
}
  // sendReady is a unit of work for the send loop: an optional header and an
// optional body to write to the connection, plus a channel on which the
// outcome of the write is reported.
type sendReady struct {
	// Hdr is the frame header to write, or nil.
	Hdr []byte
	// mu guards Body: the send loop reads and then clears it, while
	// bodyClone may swap in a private copy after the waiter gave up.
	mu   sync.Mutex // Protects Body from unsafe reads.
	Body []byte
	// Err receives the result of the write. It may be nil (sendNoWait
	// never sets it).
	Err chan error
}
  73. // newSession is used to construct a new session
  74. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  75. logger := config.Logger
  76. if logger == nil {
  77. logger = &discordLogger{}
  78. }
  79. s := &Session{
  80. config: config,
  81. logger: logger,
  82. conn: conn,
  83. bufRead: bufio.NewReader(conn),
  84. pings: make(map[uint32]chan struct{}),
  85. streams: make(map[uint32]*Stream),
  86. inflight: make(map[uint32]struct{}),
  87. synCh: make(chan struct{}, config.AcceptBacklog),
  88. acceptCh: make(chan *Stream, config.AcceptBacklog),
  89. sendCh: make(chan *sendReady, 64),
  90. recvDoneCh: make(chan struct{}),
  91. sendDoneCh: make(chan struct{}),
  92. shutdownCh: make(chan struct{}),
  93. }
  94. if client {
  95. s.nextStreamID = 1
  96. } else {
  97. s.nextStreamID = 2
  98. }
  99. go s.recv()
  100. go s.send()
  101. if config.EnableKeepAlive {
  102. go s.keepalive()
  103. }
  104. return s
  105. }
  106. // IsClosed does a safe check to see if we have shutdown
  107. func (s *Session) IsClosed() bool {
  108. select {
  109. case <-s.shutdownCh:
  110. return true
  111. default:
  112. return false
  113. }
  114. }
  115. // CloseChan returns a read-only channel which is closed as
  116. // soon as the session is closed.
  117. func (s *Session) CloseChan() <-chan struct{} {
  118. return s.shutdownCh
  119. }
  120. // NumStreams returns the number of currently open streams
  121. func (s *Session) NumStreams() int {
  122. s.streamLock.Lock()
  123. num := len(s.streams)
  124. s.streamLock.Unlock()
  125. return num
  126. }
  127. // Open is used to create a new stream as a net.Conn
  128. func (s *Session) Open() (net.Conn, error) {
  129. conn, err := s.OpenStream()
  130. if err != nil {
  131. return nil, err
  132. }
  133. return conn, nil
  134. }
  135. // OpenStream is used to create a new stream
  136. func (s *Session) OpenStream() (*Stream, error) {
  137. if s.IsClosed() {
  138. return nil, ErrSessionShutdown
  139. }
  140. if atomic.LoadInt32(&s.remoteGoAway) == 1 {
  141. return nil, ErrRemoteGoAway
  142. }
  143. // Block if we have too many inflight SYNs
  144. select {
  145. case s.synCh <- struct{}{}:
  146. case <-s.shutdownCh:
  147. return nil, ErrSessionShutdown
  148. }
  149. GET_ID:
  150. // Get an ID, and check for stream exhaustion
  151. id := atomic.LoadUint32(&s.nextStreamID)
  152. if id >= math.MaxUint32-1 {
  153. return nil, ErrStreamsExhausted
  154. }
  155. if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
  156. goto GET_ID
  157. }
  158. // Register the stream
  159. stream := newStream(s, id, streamInit)
  160. s.streamLock.Lock()
  161. s.streams[id] = stream
  162. s.inflight[id] = struct{}{}
  163. s.streamLock.Unlock()
  164. if s.config.StreamOpenTimeout > 0 {
  165. go s.setOpenTimeout(stream)
  166. }
  167. // Send the window update to create
  168. if err := stream.sendWindowUpdate(); err != nil {
  169. select {
  170. case <-s.synCh:
  171. default:
  172. s.logger.Errorf("aborted stream open without inflight syn semaphore")
  173. }
  174. return nil, err
  175. }
  176. return stream, nil
  177. }
  178. // setOpenTimeout implements a timeout for streams that are opened but not established.
  179. // If the StreamOpenTimeout is exceeded we assume the peer is unable to ACK,
  180. // and close the session.
  181. // The number of running timers is bounded by the capacity of the synCh.
  182. func (s *Session) setOpenTimeout(stream *Stream) {
  183. timer := time.NewTimer(s.config.StreamOpenTimeout)
  184. defer timer.Stop()
  185. select {
  186. case <-stream.establishCh:
  187. return
  188. case <-s.shutdownCh:
  189. return
  190. case <-timer.C:
  191. // Timeout reached while waiting for ACK.
  192. // Close the session to force connection re-establishment.
  193. s.logger.Errorf("aborted stream open (destination=%s): %v", s.RemoteAddr().String(), ErrTimeout.err)
  194. s.Close()
  195. }
  196. }
  197. // Accept is used to block until the next available stream
  198. // is ready to be accepted.
  199. func (s *Session) Accept() (net.Conn, error) {
  200. conn, err := s.AcceptStream()
  201. if err != nil {
  202. return nil, err
  203. }
  204. return conn, err
  205. }
  206. // AcceptStream is used to block until the next available stream
  207. // is ready to be accepted.
  208. func (s *Session) AcceptStream() (*Stream, error) {
  209. select {
  210. case stream := <-s.acceptCh:
  211. if err := stream.sendWindowUpdate(); err != nil {
  212. return nil, err
  213. }
  214. return stream, nil
  215. case <-s.shutdownCh:
  216. return nil, s.shutdownErr
  217. }
  218. }
  219. // AcceptStream is used to block until the next available stream
  220. // is ready to be accepted.
  221. func (s *Session) AcceptStreamWithContext(ctx context.Context) (*Stream, error) {
  222. select {
  223. case <-ctx.Done():
  224. return nil, ctx.Err()
  225. case stream := <-s.acceptCh:
  226. if err := stream.sendWindowUpdate(); err != nil {
  227. return nil, err
  228. }
  229. return stream, nil
  230. case <-s.shutdownCh:
  231. return nil, s.shutdownErr
  232. }
  233. }
  234. // Close is used to close the session and all streams.
  235. // Attempts to send a GoAway before closing the connection.
  236. func (s *Session) Close() error {
  237. s.shutdownLock.Lock()
  238. defer s.shutdownLock.Unlock()
  239. if s.shutdown {
  240. return nil
  241. }
  242. s.shutdown = true
  243. s.shutdownErrLock.Lock()
  244. if s.shutdownErr == nil {
  245. s.shutdownErr = ErrSessionShutdown
  246. }
  247. s.shutdownErrLock.Unlock()
  248. close(s.shutdownCh)
  249. s.conn.Close()
  250. <-s.recvDoneCh
  251. s.streamLock.Lock()
  252. defer s.streamLock.Unlock()
  253. for _, stream := range s.streams {
  254. stream.forceClose()
  255. }
  256. <-s.sendDoneCh
  257. return nil
  258. }
  259. // exitErr is used to handle an error that is causing the
  260. // session to terminate.
  261. func (s *Session) exitErr(err error) {
  262. s.shutdownErrLock.Lock()
  263. if s.shutdownErr == nil {
  264. s.shutdownErr = err
  265. }
  266. s.shutdownErrLock.Unlock()
  267. s.Close()
  268. }
  269. // GoAway can be used to prevent accepting further
  270. // connections. It does not close the underlying conn.
  271. func (s *Session) GoAway() error {
  272. return s.waitForSend(s.goAway(goAwayNormal), nil)
  273. }
  274. // goAway is used to send a goAway message
  275. func (s *Session) goAway(reason uint32) header {
  276. atomic.SwapInt32(&s.localGoAway, 1)
  277. hdr := header(make([]byte, headerSize))
  278. hdr.encode(typeGoAway, 0, 0, reason)
  279. return hdr
  280. }
  281. // Ping is used to measure the RTT response time
  282. func (s *Session) Ping() (time.Duration, error) {
  283. // Get a channel for the ping
  284. ch := make(chan struct{})
  285. // Get a new ping id, mark as pending
  286. s.pingLock.Lock()
  287. id := s.pingID
  288. s.pingID++
  289. s.pings[id] = ch
  290. s.pingLock.Unlock()
  291. // Send the ping request
  292. hdr := header(make([]byte, headerSize))
  293. hdr.encode(typePing, flagSYN, 0, id)
  294. if err := s.waitForSend(hdr, nil); err != nil {
  295. return 0, err
  296. }
  297. // Wait for a response
  298. start := time.Now()
  299. select {
  300. case <-ch:
  301. case <-time.After(s.config.ConnectionWriteTimeout):
  302. s.pingLock.Lock()
  303. delete(s.pings, id) // Ignore it if a response comes later.
  304. s.pingLock.Unlock()
  305. return 0, ErrTimeout
  306. case <-s.shutdownCh:
  307. return 0, ErrSessionShutdown
  308. }
  309. // Compute the RTT
  310. return time.Now().Sub(start), nil
  311. }
  312. // keepalive is a long running goroutine that periodically does
  313. // a ping to keep the connection alive.
  314. func (s *Session) keepalive() {
  315. for {
  316. select {
  317. case <-time.After(s.config.KeepAliveInterval):
  318. _, err := s.Ping()
  319. if err != nil {
  320. if err != ErrSessionShutdown {
  321. s.logger.Errorf("keepalive failed: %v", err)
  322. s.exitErr(ErrKeepAliveTimeout)
  323. }
  324. return
  325. }
  326. case <-s.shutdownCh:
  327. return
  328. }
  329. }
  330. }
  331. func (s *Session) bodyClone(buf []byte, p *sendReady) {
  332. if buf == nil {
  333. return // A nil body is ignored.
  334. }
  335. // In the event of session shutdown or connection write timeout,
  336. // we need to prevent `send` from reading the body buffer after
  337. // returning from this function since the caller may re-use the
  338. // underlying array.
  339. p.mu.Lock()
  340. defer p.mu.Unlock()
  341. if p.Body == nil {
  342. return // Body was already copied in `send`.
  343. }
  344. newBody := make([]byte, len(buf))
  345. copy(newBody, buf)
  346. p.Body = newBody
  347. }
  348. // waitForSendErr waits to send a header, checking for a potential shutdown
  349. func (s *Session) waitForSend(hdr header, body []byte) (err error) {
  350. errCh := make(chan error, 1)
  351. err = s.waitForSendErr(hdr, body, errCh)
  352. return
  353. }
  354. // waitForSendErr waits to send a header with optional data, checking for a
  355. // potential shutdown. Since there's the expectation that sends can happen
  356. // in a timely manner, we enforce the connection write timeout here.
  357. func (s *Session) waitForSendErr(hdr header, body []byte, errCh chan error) error {
  358. t := timerPool.Get()
  359. timer := t.(*time.Timer)
  360. timer.Reset(s.config.ConnectionWriteTimeout)
  361. defer func() {
  362. timer.Stop()
  363. select {
  364. case <-timer.C:
  365. default:
  366. }
  367. timerPool.Put(t)
  368. }()
  369. ready := &sendReady{Hdr: hdr, Body: body, Err: errCh}
  370. select {
  371. case s.sendCh <- ready:
  372. case <-s.shutdownCh:
  373. return ErrSessionShutdown
  374. case <-timer.C:
  375. return ErrConnectionWriteTimeout
  376. }
  377. select {
  378. case err := <-errCh:
  379. return err
  380. case <-s.shutdownCh:
  381. s.bodyClone(body, ready)
  382. return ErrSessionShutdown
  383. case <-timer.C:
  384. s.bodyClone(body, ready)
  385. return ErrConnectionWriteTimeout
  386. }
  387. }
  388. // sendNoWait does a send without waiting. Since there's the expectation that
  389. // the send happens right here, we enforce the connection write timeout if we
  390. // can't queue the header to be sent.
  391. func (s *Session) sendNoWait(hdr header) error {
  392. t := timerPool.Get()
  393. timer := t.(*time.Timer)
  394. timer.Reset(s.config.ConnectionWriteTimeout)
  395. defer func() {
  396. timer.Stop()
  397. select {
  398. case <-timer.C:
  399. default:
  400. }
  401. timerPool.Put(t)
  402. }()
  403. select {
  404. case s.sendCh <- &sendReady{Hdr: hdr}:
  405. return nil
  406. case <-s.shutdownCh:
  407. return ErrSessionShutdown
  408. case <-timer.C:
  409. return ErrConnectionWriteTimeout
  410. }
  411. }
  412. // send is a long running goroutine that sends data
  413. func (s *Session) send() {
  414. if err := s.sendLoop(); err != nil {
  415. s.exitErr(err)
  416. }
  417. }
  418. func (s *Session) sendPacket(packet *sendReady) (err error) {
  419. var (
  420. n int
  421. nw int
  422. buf []byte
  423. buffer *bytes.Buffer
  424. )
  425. buffer = getBuffer()
  426. defer func() {
  427. putBuffer(buffer)
  428. }()
  429. packet.mu.Lock()
  430. // Send a header if ready
  431. if packet.Hdr != nil {
  432. n, _ = buffer.Write(packet.Hdr)
  433. nw += n
  434. }
  435. if packet.Body != nil {
  436. if s.config.Crypto != nil {
  437. if buf, err = s.config.Crypto.Encrypt(packet.Body); err != nil {
  438. return
  439. }
  440. } else {
  441. buf = packet.Body
  442. }
  443. n, _ = buffer.Write(buf)
  444. nw += n
  445. packet.Body = nil
  446. }
  447. packet.mu.Unlock()
  448. if buffer.Len() > 0 {
  449. // Send data from a body if given
  450. if n, err = s.conn.Write(buffer.Bytes()); err != nil {
  451. asyncSendErr(packet.Err, err)
  452. return err
  453. }
  454. if n != nw {
  455. asyncSendErr(packet.Err, io.ErrShortWrite)
  456. return io.ErrShortWrite
  457. }
  458. }
  459. // No error, successful send
  460. asyncSendErr(packet.Err, nil)
  461. return
  462. }
  463. func (s *Session) sendLoop() (err error) {
  464. defer close(s.sendDoneCh)
  465. for {
  466. select {
  467. case ready := <-s.sendCh:
  468. if err = s.sendPacket(ready); err != nil {
  469. return err
  470. }
  471. case <-s.shutdownCh:
  472. return
  473. }
  474. }
  475. }
  476. // recv is a long running goroutine that accepts new data
  477. func (s *Session) recv() {
  478. if err := s.recvLoop(); err != nil {
  479. s.exitErr(err)
  480. }
  481. }
  482. // Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
  483. var (
  484. handlers = []func(*Session, header) error{
  485. typeData: (*Session).handleStreamMessage,
  486. typeWindowUpdate: (*Session).handleStreamMessage,
  487. typePing: (*Session).handlePing,
  488. typeGoAway: (*Session).handleGoAway,
  489. }
  490. )
  491. // recvLoop continues to receive data until a fatal error is encountered
  492. func (s *Session) recvLoop() error {
  493. defer close(s.recvDoneCh)
  494. hdr := header(make([]byte, headerSize))
  495. for {
  496. // Read the header
  497. if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
  498. if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
  499. s.logger.Errorf("failed to read stream header: %v", err)
  500. }
  501. return err
  502. }
  503. // Verify the version
  504. if hdr.Version() != protoVersion {
  505. s.logger.Errorf("invalid stream protocol version: %d", hdr.Version())
  506. return ErrInvalidVersion
  507. }
  508. mt := hdr.MsgType()
  509. if mt < typeData || mt > typeGoAway {
  510. return ErrInvalidMsgType
  511. }
  512. if err := handlers[mt](s, hdr); err != nil {
  513. return err
  514. }
  515. }
  516. }
  517. // handleStreamMessage handles either a data or window update frame
  518. func (s *Session) handleStreamMessage(hdr header) error {
  519. // Check for a new stream creation
  520. id := hdr.StreamID()
  521. flags := hdr.Flags()
  522. if flags&flagSYN == flagSYN {
  523. if err := s.incomingStream(id); err != nil {
  524. return err
  525. }
  526. }
  527. // Get the stream
  528. s.streamLock.Lock()
  529. stream := s.streams[id]
  530. s.streamLock.Unlock()
  531. // If we do not have a stream, likely we sent a RST
  532. if stream == nil {
  533. // Drain any data on the wire
  534. if hdr.MsgType() == typeData && hdr.Length() > 0 {
  535. s.logger.Warnf("discarding data for stream: %d", id)
  536. if _, err := io.CopyN(io.Discard, s.bufRead, int64(hdr.Length())); err != nil {
  537. s.logger.Errorf("failed to discard stream %d data: %v", id, err)
  538. return nil
  539. }
  540. } else {
  541. s.logger.Warnf("frame for missing stream: %v", hdr)
  542. }
  543. return nil
  544. }
  545. // Check if this is a window update
  546. if hdr.MsgType() == typeWindowUpdate {
  547. if err := stream.incrSendWindow(hdr, flags); err != nil {
  548. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  549. s.logger.Warnf("failed to send go away: %v", sendErr)
  550. }
  551. return err
  552. }
  553. return nil
  554. }
  555. // Read the new data
  556. if err := stream.readData(hdr, flags, s.bufRead); err != nil {
  557. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  558. s.logger.Warnf("failed to send go away: %v", sendErr)
  559. }
  560. return err
  561. }
  562. return nil
  563. }
  564. // handlePing is invoked for a typePing frame
  565. func (s *Session) handlePing(hdr header) error {
  566. flags := hdr.Flags()
  567. pingID := hdr.Length()
  568. // Check if this is a query, respond back in a separate context so we
  569. // don't interfere with the receiving thread blocking for the write.
  570. if flags&flagSYN == flagSYN {
  571. go func() {
  572. hdr := header(make([]byte, headerSize))
  573. hdr.encode(typePing, flagACK, 0, pingID)
  574. if err := s.sendNoWait(hdr); err != nil {
  575. s.logger.Warnf("stream %s failed to send ping reply: %v", hdr.StreamID(), err)
  576. }
  577. }()
  578. return nil
  579. }
  580. // Handle a response
  581. s.pingLock.Lock()
  582. ch := s.pings[pingID]
  583. if ch != nil {
  584. delete(s.pings, pingID)
  585. close(ch)
  586. }
  587. s.pingLock.Unlock()
  588. return nil
  589. }
  590. // handleGoAway is invoked for a typeGoAway frame
  591. func (s *Session) handleGoAway(hdr header) error {
  592. code := hdr.Length()
  593. switch code {
  594. case goAwayNormal:
  595. atomic.SwapInt32(&s.remoteGoAway, 1)
  596. case goAwayProtoErr:
  597. s.logger.Errorf("received protocol error go away")
  598. return fmt.Errorf("yamux protocol error")
  599. case goAwayInternalErr:
  600. s.logger.Errorf("received internal error go away")
  601. return fmt.Errorf("remote yamux internal error")
  602. default:
  603. s.logger.Errorf("received unexpected go away")
  604. return fmt.Errorf("unexpected go away received")
  605. }
  606. return nil
  607. }
  608. // incomingStream is used to create a new incoming stream
  609. func (s *Session) incomingStream(id uint32) error {
  610. // Reject immediately if we are doing a go away
  611. if atomic.LoadInt32(&s.localGoAway) == 1 {
  612. hdr := header(make([]byte, headerSize))
  613. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  614. return s.sendNoWait(hdr)
  615. }
  616. // Allocate a new stream
  617. stream := newStream(s, id, streamSYNReceived)
  618. s.streamLock.Lock()
  619. defer s.streamLock.Unlock()
  620. // Check if stream already exists
  621. if _, ok := s.streams[id]; ok {
  622. s.logger.Errorf("duplicate stream declared")
  623. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  624. s.logger.Warnf("failed to send go away: %v", sendErr)
  625. }
  626. return ErrDuplicateStream
  627. }
  628. // Register the stream
  629. s.streams[id] = stream
  630. // Check if we've exceeded the backlog
  631. select {
  632. case s.acceptCh <- stream:
  633. return nil
  634. default:
  635. // Backlog exceeded! RST the stream
  636. s.logger.Warnf("backlog exceeded, forcing connection reset")
  637. delete(s.streams, id)
  638. hdr := header(make([]byte, headerSize))
  639. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  640. return s.sendNoWait(hdr)
  641. }
  642. }
  643. // closeStream is used to close a stream once both sides have
  644. // issued a close. If there was an in-flight SYN and the stream
  645. // was not yet established, then this will give the credit back.
  646. func (s *Session) closeStream(id uint32) {
  647. s.streamLock.Lock()
  648. if _, ok := s.inflight[id]; ok {
  649. select {
  650. case <-s.synCh:
  651. default:
  652. s.logger.Errorf("SYN tracking out of sync")
  653. }
  654. }
  655. delete(s.streams, id)
  656. s.streamLock.Unlock()
  657. }
  658. // establishStream is used to mark a stream that was in the
  659. // SYN Sent state as established.
  660. func (s *Session) establishStream(id uint32) {
  661. s.streamLock.Lock()
  662. if _, ok := s.inflight[id]; ok {
  663. delete(s.inflight, id)
  664. } else {
  665. s.logger.Errorf("established stream without inflight SYN (no tracking entry)")
  666. }
  667. select {
  668. case <-s.synCh:
  669. default:
  670. s.logger.Errorf("established stream without inflight SYN (didn't have semaphore)")
  671. }
  672. s.streamLock.Unlock()
  673. }