session.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. package yamux
  2. import (
  3. "bufio"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "log"
  8. "math"
  9. "net"
  10. "strings"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. )
  15. // Session is used to wrap a reliable ordered connection and to
  16. // multiplex it into multiple streams.
  17. type Session struct {
  18. // remoteGoAway indicates the remote side does
  19. // not want futher connections. Must be first for alignment.
  20. remoteGoAway int32
  21. // localGoAway indicates that we should stop
  22. // accepting futher connections. Must be first for alignment.
  23. localGoAway int32
  24. // nextStreamID is the next stream we should
  25. // send. This depends if we are a client/server.
  26. nextStreamID uint32
  27. // config holds our configuration
  28. config *Config
  29. // logger is used for our logs
  30. logger *log.Logger
  31. // conn is the underlying connection
  32. conn io.ReadWriteCloser
  33. // bufRead is a buffered reader
  34. bufRead *bufio.Reader
  35. // pings is used to track inflight pings
  36. pings map[uint32]chan struct{}
  37. pingID uint32
  38. pingLock sync.Mutex
  39. // streams maps a stream id to a stream
  40. streams map[uint32]*Stream
  41. streamLock sync.Mutex
  42. // acceptCh is used to pass ready streams to the client
  43. acceptCh chan *Stream
  44. // sendCh is used to mark a stream as ready to send,
  45. // or to send a header out directly.
  46. sendCh chan sendReady
  47. // recvDoneCh is closed when recv() exits to avoid a race
  48. // between stream registration and stream shutdown
  49. recvDoneCh chan struct{}
  50. // shutdown is used to safely close a session
  51. shutdown bool
  52. shutdownErr error
  53. shutdownCh chan struct{}
  54. shutdownLock sync.Mutex
  55. }
  56. // sendReady is used to either mark a stream as ready
  57. // or to directly send a header
  58. type sendReady struct {
  59. Hdr []byte
  60. Body io.Reader
  61. Err chan error
  62. }
  63. // newSession is used to construct a new session
  64. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  65. s := &Session{
  66. config: config,
  67. logger: log.New(config.LogOutput, "", log.LstdFlags),
  68. conn: conn,
  69. bufRead: bufio.NewReader(conn),
  70. pings: make(map[uint32]chan struct{}),
  71. streams: make(map[uint32]*Stream),
  72. acceptCh: make(chan *Stream, config.AcceptBacklog),
  73. sendCh: make(chan sendReady, 64),
  74. recvDoneCh: make(chan struct{}),
  75. shutdownCh: make(chan struct{}),
  76. }
  77. if client {
  78. s.nextStreamID = 1
  79. } else {
  80. s.nextStreamID = 2
  81. }
  82. go s.recv()
  83. go s.send()
  84. if config.EnableKeepAlive {
  85. go s.keepalive()
  86. }
  87. return s
  88. }
  89. // IsClosed does a safe check to see if we have shutdown
  90. func (s *Session) IsClosed() bool {
  91. select {
  92. case <-s.shutdownCh:
  93. return true
  94. default:
  95. return false
  96. }
  97. }
  98. // NumStreams returns the number of currently open streams
  99. func (s *Session) NumStreams() int {
  100. s.streamLock.Lock()
  101. num := len(s.streams)
  102. s.streamLock.Unlock()
  103. return num
  104. }
  105. // Open is used to create a new stream as a net.Conn
  106. func (s *Session) Open() (net.Conn, error) {
  107. return s.OpenStream()
  108. }
  109. // OpenStream is used to create a new stream
  110. func (s *Session) OpenStream() (*Stream, error) {
  111. if s.IsClosed() {
  112. return nil, ErrSessionShutdown
  113. }
  114. if atomic.LoadInt32(&s.remoteGoAway) == 1 {
  115. return nil, ErrRemoteGoAway
  116. }
  117. GET_ID:
  118. // Get and ID, and check for stream exhaustion
  119. id := atomic.LoadUint32(&s.nextStreamID)
  120. if id >= math.MaxUint32-1 {
  121. return nil, ErrStreamsExhausted
  122. }
  123. if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
  124. goto GET_ID
  125. }
  126. // Register the stream
  127. stream := newStream(s, id, streamInit)
  128. s.streamLock.Lock()
  129. s.streams[id] = stream
  130. s.streamLock.Unlock()
  131. // Send the window update to create
  132. return stream, stream.sendWindowUpdate()
  133. }
  134. // Accept is used to block until the next available stream
  135. // is ready to be accepted.
  136. func (s *Session) Accept() (net.Conn, error) {
  137. return s.AcceptStream()
  138. }
  139. // AcceptStream is used to block until the next available stream
  140. // is ready to be accepted.
  141. func (s *Session) AcceptStream() (*Stream, error) {
  142. select {
  143. case stream := <-s.acceptCh:
  144. return stream, stream.sendWindowUpdate()
  145. case <-s.shutdownCh:
  146. return nil, s.shutdownErr
  147. }
  148. }
  149. // Close is used to close the session and all streams.
  150. // Attempts to send a GoAway before closing the connection.
  151. func (s *Session) Close() error {
  152. s.shutdownLock.Lock()
  153. defer s.shutdownLock.Unlock()
  154. if s.shutdown {
  155. return nil
  156. }
  157. s.shutdown = true
  158. if s.shutdownErr == nil {
  159. s.shutdownErr = ErrSessionShutdown
  160. }
  161. close(s.shutdownCh)
  162. s.conn.Close()
  163. <-s.recvDoneCh
  164. s.streamLock.Lock()
  165. defer s.streamLock.Unlock()
  166. for _, stream := range s.streams {
  167. stream.forceClose()
  168. }
  169. return nil
  170. }
  171. // exitErr is used to handle an error that is causing the
  172. // session to terminate.
  173. func (s *Session) exitErr(err error) {
  174. s.shutdownLock.Lock()
  175. if s.shutdownErr == nil {
  176. s.shutdownErr = err
  177. }
  178. s.shutdownLock.Unlock()
  179. s.Close()
  180. }
  181. // GoAway can be used to prevent accepting further
  182. // connections. It does not close the underlying conn.
  183. func (s *Session) GoAway() error {
  184. return s.waitForSend(s.goAway(goAwayNormal), nil)
  185. }
  186. // goAway is used to send a goAway message
  187. func (s *Session) goAway(reason uint32) header {
  188. atomic.SwapInt32(&s.localGoAway, 1)
  189. hdr := header(make([]byte, headerSize))
  190. hdr.encode(typeGoAway, 0, 0, reason)
  191. return hdr
  192. }
  193. // Ping is used to measure the RTT response time
  194. func (s *Session) Ping() (time.Duration, error) {
  195. // Get a channel for the ping
  196. ch := make(chan struct{})
  197. // Get a new ping id, mark as pending
  198. s.pingLock.Lock()
  199. id := s.pingID
  200. s.pingID++
  201. s.pings[id] = ch
  202. s.pingLock.Unlock()
  203. // Send the ping request
  204. hdr := header(make([]byte, headerSize))
  205. hdr.encode(typePing, flagSYN, 0, id)
  206. if err := s.waitForSend(hdr, nil); err != nil {
  207. return 0, err
  208. }
  209. // Wait for a response
  210. start := time.Now()
  211. select {
  212. case <-ch:
  213. case <-s.shutdownCh:
  214. return 0, ErrSessionShutdown
  215. }
  216. // Compute the RTT
  217. return time.Now().Sub(start), nil
  218. }
  219. // keepalive is a long running goroutine that periodically does
  220. // a ping to keep the connection alive.
  221. func (s *Session) keepalive() {
  222. for {
  223. select {
  224. case <-time.After(s.config.KeepAliveInterval):
  225. s.Ping()
  226. case <-s.shutdownCh:
  227. return
  228. }
  229. }
  230. }
  231. // waitForSendErr waits to send a header, checking for a potential shutdown
  232. func (s *Session) waitForSend(hdr header, body io.Reader) error {
  233. errCh := make(chan error, 1)
  234. return s.waitForSendErr(hdr, body, errCh)
  235. }
  236. // waitForSendErr waits to send a header, checking for a potential shutdown
  237. func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
  238. ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
  239. select {
  240. case s.sendCh <- ready:
  241. case <-s.shutdownCh:
  242. return ErrSessionShutdown
  243. }
  244. select {
  245. case err := <-errCh:
  246. return err
  247. case <-s.shutdownCh:
  248. return ErrSessionShutdown
  249. }
  250. }
  251. // sendNoWait does a send without waiting
  252. func (s *Session) sendNoWait(hdr header) error {
  253. select {
  254. case s.sendCh <- sendReady{Hdr: hdr}:
  255. return nil
  256. case <-s.shutdownCh:
  257. return ErrSessionShutdown
  258. }
  259. }
  260. // send is a long running goroutine that sends data
  261. func (s *Session) send() {
  262. for {
  263. select {
  264. case ready := <-s.sendCh:
  265. // Send a header if ready
  266. if ready.Hdr != nil {
  267. sent := 0
  268. for sent < len(ready.Hdr) {
  269. n, err := s.conn.Write(ready.Hdr[sent:])
  270. if err != nil {
  271. s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
  272. asyncSendErr(ready.Err, err)
  273. s.exitErr(err)
  274. return
  275. }
  276. sent += n
  277. }
  278. }
  279. // Send data from a body if given
  280. if ready.Body != nil {
  281. _, err := io.Copy(s.conn, ready.Body)
  282. if err != nil {
  283. s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
  284. asyncSendErr(ready.Err, err)
  285. s.exitErr(err)
  286. return
  287. }
  288. }
  289. // No error, successful send
  290. asyncSendErr(ready.Err, nil)
  291. case <-s.shutdownCh:
  292. return
  293. }
  294. }
  295. }
  296. // recv is a long running goroutine that accepts new data
  297. func (s *Session) recv() {
  298. if err := s.recvLoop(); err != nil {
  299. s.exitErr(err)
  300. }
  301. }
  302. // recvLoop continues to receive data until a fatal error is encountered
  303. func (s *Session) recvLoop() error {
  304. defer close(s.recvDoneCh)
  305. hdr := header(make([]byte, headerSize))
  306. var handler func(header) error
  307. for {
  308. // Read the header
  309. if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
  310. if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
  311. s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
  312. }
  313. return err
  314. }
  315. // Verify the version
  316. if hdr.Version() != protoVersion {
  317. s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
  318. return ErrInvalidVersion
  319. }
  320. // Switch on the type
  321. switch hdr.MsgType() {
  322. case typeData:
  323. handler = s.handleStreamMessage
  324. case typeWindowUpdate:
  325. handler = s.handleStreamMessage
  326. case typeGoAway:
  327. handler = s.handleGoAway
  328. case typePing:
  329. handler = s.handlePing
  330. default:
  331. return ErrInvalidMsgType
  332. }
  333. // Invoke the handler
  334. if err := handler(hdr); err != nil {
  335. return err
  336. }
  337. }
  338. }
  339. // handleStreamMessage handles either a data or window update frame
  340. func (s *Session) handleStreamMessage(hdr header) error {
  341. // Check for a new stream creation
  342. id := hdr.StreamID()
  343. flags := hdr.Flags()
  344. if flags&flagSYN == flagSYN {
  345. if err := s.incomingStream(id); err != nil {
  346. return err
  347. }
  348. }
  349. // Get the stream
  350. s.streamLock.Lock()
  351. stream := s.streams[id]
  352. s.streamLock.Unlock()
  353. // If we do not have a stream, likely we sent a RST
  354. if stream == nil {
  355. // Drain any data on the wire
  356. if hdr.MsgType() == typeData && hdr.Length() > 0 {
  357. s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
  358. if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
  359. s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
  360. return nil
  361. }
  362. } else {
  363. s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
  364. }
  365. return nil
  366. }
  367. // Check if this is a window update
  368. if hdr.MsgType() == typeWindowUpdate {
  369. if err := stream.incrSendWindow(hdr, flags); err != nil {
  370. s.sendNoWait(s.goAway(goAwayProtoErr))
  371. return err
  372. }
  373. return nil
  374. }
  375. // Read the new data
  376. if err := stream.readData(hdr, flags, s.bufRead); err != nil {
  377. s.sendNoWait(s.goAway(goAwayProtoErr))
  378. return err
  379. }
  380. return nil
  381. }
  382. // handlePing is invokde for a typePing frame
  383. func (s *Session) handlePing(hdr header) error {
  384. flags := hdr.Flags()
  385. pingID := hdr.Length()
  386. // Check if this is a query, respond back
  387. if flags&flagSYN == flagSYN {
  388. hdr := header(make([]byte, headerSize))
  389. hdr.encode(typePing, flagACK, 0, pingID)
  390. s.sendNoWait(hdr)
  391. return nil
  392. }
  393. // Handle a response
  394. s.pingLock.Lock()
  395. ch := s.pings[pingID]
  396. if ch != nil {
  397. delete(s.pings, pingID)
  398. close(ch)
  399. }
  400. s.pingLock.Unlock()
  401. return nil
  402. }
  403. // handleGoAway is invokde for a typeGoAway frame
  404. func (s *Session) handleGoAway(hdr header) error {
  405. code := hdr.Length()
  406. switch code {
  407. case goAwayNormal:
  408. atomic.SwapInt32(&s.remoteGoAway, 1)
  409. case goAwayProtoErr:
  410. s.logger.Printf("[ERR] yamux: received protocol error go away")
  411. return fmt.Errorf("yamux protocol error")
  412. case goAwayInternalErr:
  413. s.logger.Printf("[ERR] yamux: received internal error go away")
  414. return fmt.Errorf("remote yamux internal error")
  415. default:
  416. s.logger.Printf("[ERR] yamux: received unexpected go away")
  417. return fmt.Errorf("unexpected go away received")
  418. }
  419. return nil
  420. }
  421. // incomingStream is used to create a new incoming stream
  422. func (s *Session) incomingStream(id uint32) error {
  423. // Reject immediately if we are doing a go away
  424. if atomic.LoadInt32(&s.localGoAway) == 1 {
  425. hdr := header(make([]byte, headerSize))
  426. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  427. return s.sendNoWait(hdr)
  428. }
  429. // Allocate a new stream
  430. stream := newStream(s, id, streamSYNReceived)
  431. s.streamLock.Lock()
  432. defer s.streamLock.Unlock()
  433. // Check if stream already exists
  434. if _, ok := s.streams[id]; ok {
  435. s.logger.Printf("[ERR] yamux: duplicate stream declared")
  436. s.sendNoWait(s.goAway(goAwayProtoErr))
  437. return ErrDuplicateStream
  438. }
  439. // Register the stream
  440. s.streams[id] = stream
  441. // Check if we've exceeded the backlog
  442. select {
  443. case s.acceptCh <- stream:
  444. return nil
  445. default:
  446. // Backlog exceeded! RST the stream
  447. s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
  448. delete(s.streams, id)
  449. stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
  450. return s.sendNoWait(stream.sendHdr)
  451. }
  452. }
  453. // closeStream is used to close a stream once both sides have
  454. // issued a close.
  455. func (s *Session) closeStream(id uint32) {
  456. s.streamLock.Lock()
  457. delete(s.streams, id)
  458. s.streamLock.Unlock()
  459. }