session.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. package yamux
  2. import (
  3. "fmt"
  4. "io"
  5. "math"
  6. "net"
  7. "sync"
  8. "time"
  9. )
  var (
	// ErrInvalidVersion means we received a frame with an
	// invalid version.
	ErrInvalidVersion = fmt.Errorf("invalid protocol version")

	// ErrInvalidMsgType means we received a frame with an
	// invalid message type.
	ErrInvalidMsgType = fmt.Errorf("invalid msg type")

	// ErrSessionShutdown is used if there is a shutdown during
	// an operation.
	ErrSessionShutdown = fmt.Errorf("session shutdown")

	// ErrStreamsExhausted is returned if we have no more
	// stream IDs to issue.
	ErrStreamsExhausted = fmt.Errorf("streams exhausted")

	// ErrDuplicateStream is used if a duplicate stream is
	// opened inbound.
	ErrDuplicateStream = fmt.Errorf("duplicate stream initiated")

	// ErrMissingStream indicates a stream was named which
	// does not exist.
	ErrMissingStream = fmt.Errorf("missing stream references")

	// ErrRecvWindowExceeded indicates the receive window was exceeded.
	ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded")

	// ErrTimeout is used when we reach an I/O deadline.
	ErrTimeout = fmt.Errorf("i/o deadline reached")

	// ErrStreamClosed is returned when using a closed stream.
	ErrStreamClosed = fmt.Errorf("stream closed")

	// ErrUnexpectedFlag is set when we get an unexpected flag.
	ErrUnexpectedFlag = fmt.Errorf("unexpected flag")

	// ErrRemoteGoAway is used when we get a go away from the other side.
	ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections")
)
  40. // Session is used to wrap a reliable ordered connection and to
  41. // multiplex it into multiple streams.
  42. type Session struct {
  43. // client is true if we are a client size connection
  44. client bool
  45. // config holds our configuration
  46. config *Config
  47. // conn is the underlying connection
  48. conn io.ReadWriteCloser
  49. // pings is used to track inflight pings
  50. pings map[uint32]chan struct{}
  51. pingID uint32
  52. pingLock sync.Mutex
  53. // remoteGoAway indicates the remote side does
  54. // not want futher connections
  55. remoteGoAway bool
  56. // localGoAway indicates that we should stop
  57. // accepting futher connections
  58. localGoAway bool
  59. // nextStreamID is the next stream we should
  60. // send. This depends if we are a client/server.
  61. nextStreamID uint32
  62. // streams maps a stream id to a stream
  63. streams map[uint32]*Stream
  64. streamLock sync.RWMutex
  65. // acceptCh is used to pass ready streams to the client
  66. acceptCh chan *Stream
  67. // sendCh is used to mark a stream as ready to send,
  68. // or to send a header out directly.
  69. sendCh chan sendReady
  70. // shutdown is used to safely close a session
  71. shutdown bool
  72. shutdownErr error
  73. shutdownCh chan struct{}
  74. shutdownLock sync.Mutex
  75. }
  // hasAddr is the subset of net.Conn that reports endpoint addresses.
// LocalAddr and RemoteAddr type-assert the underlying connection against it
// and fall back to a placeholder when the assertion fails.
type hasAddr interface {
	RemoteAddr() net.Addr
	LocalAddr() net.Addr
}
  // yamuxAddr is the placeholder net.Addr handed out when the underlying
// connection cannot report its own address. Addr holds a short label such as
// "local" or "remote".
type yamuxAddr struct {
	Addr string
}

// Network reports the pseudo network name "yamux".
func (*yamuxAddr) Network() string { return "yamux" }

// String renders the address as "yamux:" followed by the label.
func (y *yamuxAddr) String() string { return "yamux:" + y.Addr }
  // sendReady is one unit of work for the send goroutine: an optional header
// written first, an optional body copied after it, and an optional channel
// that is told how the write went.
type sendReady struct {
	Hdr  []byte     // bytes written first; nil to skip
	Body io.Reader  // copied to the connection after Hdr; nil to skip
	Err  chan error // receives the write result via asyncSendErr; may be nil
}
  98. // newSession is used to construct a new session
  99. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  100. s := &Session{
  101. client: client,
  102. config: config,
  103. conn: conn,
  104. pings: make(map[uint32]chan struct{}),
  105. streams: make(map[uint32]*Stream),
  106. acceptCh: make(chan *Stream, config.AcceptBacklog),
  107. sendCh: make(chan sendReady, 64),
  108. shutdownCh: make(chan struct{}),
  109. }
  110. if client {
  111. s.nextStreamID = 1
  112. } else {
  113. s.nextStreamID = 2
  114. }
  115. go s.recv()
  116. go s.send()
  117. if config.EnableKeepAlive {
  118. go s.keepalive()
  119. }
  120. return s
  121. }
  122. // isShutdown does a safe check to see if we have shutdown
  123. func (s *Session) isShutdown() bool {
  124. select {
  125. case <-s.shutdownCh:
  126. return true
  127. default:
  128. return false
  129. }
  130. }
  131. // Open is used to create a new stream
  132. func (s *Session) Open() (*Stream, error) {
  133. if s.isShutdown() {
  134. return nil, ErrSessionShutdown
  135. }
  136. if s.remoteGoAway {
  137. return nil, ErrRemoteGoAway
  138. }
  139. s.streamLock.Lock()
  140. defer s.streamLock.Unlock()
  141. // Check if we've exhaused the streams
  142. id := s.nextStreamID
  143. if id >= math.MaxUint32-1 {
  144. return nil, ErrStreamsExhausted
  145. }
  146. s.nextStreamID += 2
  147. // Register the stream
  148. stream := newStream(s, id, streamInit)
  149. s.streams[id] = stream
  150. // Send the window update to create
  151. return stream, stream.sendWindowUpdate()
  152. }
  153. // Accept is used to block until the next available stream
  154. // is ready to be accepted.
  155. func (s *Session) Accept() (net.Conn, error) {
  156. return s.AcceptStream()
  157. }
  158. // AcceptStream is used to block until the next available stream
  159. // is ready to be accepted.
  160. func (s *Session) AcceptStream() (*Stream, error) {
  161. select {
  162. case stream := <-s.acceptCh:
  163. return stream, nil
  164. case <-s.shutdownCh:
  165. return nil, s.shutdownErr
  166. }
  167. }
  168. // Close is used to close the session and all streams.
  169. // Attempts to send a GoAway before closing the connection.
  170. func (s *Session) Close() error {
  171. s.shutdownLock.Lock()
  172. defer s.shutdownLock.Unlock()
  173. if s.shutdown {
  174. return nil
  175. }
  176. s.shutdown = true
  177. if s.shutdownErr == nil {
  178. s.shutdownErr = ErrSessionShutdown
  179. }
  180. close(s.shutdownCh)
  181. s.conn.Close()
  182. s.streamLock.Lock()
  183. defer s.streamLock.Unlock()
  184. for _, stream := range s.streams {
  185. stream.forceClose()
  186. }
  187. return nil
  188. }
  189. // GoAway can be used to prevent accepting further
  190. // connections. It does not close the underlying conn.
  191. func (s *Session) GoAway() error {
  192. s.localGoAway = true
  193. s.goAway(goAwayNormal)
  194. return nil
  195. }
  196. // Addr is used to get the address of the listener.
  197. func (s *Session) Addr() net.Addr {
  198. return s.LocalAddr()
  199. }
  200. // LocalAddr is used to get the local address of the
  201. // underlying connection.
  202. func (s *Session) LocalAddr() net.Addr {
  203. addr, ok := s.conn.(hasAddr)
  204. if !ok {
  205. return &yamuxAddr{"local"}
  206. }
  207. return addr.LocalAddr()
  208. }
  209. // RemoteAddr is used to get the address of remote end
  210. // of the underlying connection
  211. func (s *Session) RemoteAddr() net.Addr {
  212. addr, ok := s.conn.(hasAddr)
  213. if !ok {
  214. return &yamuxAddr{"remote"}
  215. }
  216. return addr.RemoteAddr()
  217. }
  218. // Ping is used to measure the RTT response time
  219. func (s *Session) Ping() (time.Duration, error) {
  220. // Get a channel for the ping
  221. ch := make(chan struct{})
  222. // Get a new ping id, mark as pending
  223. s.pingLock.Lock()
  224. id := s.pingID
  225. s.pingID++
  226. s.pings[id] = ch
  227. s.pingLock.Unlock()
  228. // Send the ping request
  229. hdr := header(make([]byte, headerSize))
  230. hdr.encode(typePing, flagSYN, 0, id)
  231. if err := s.waitForSend(hdr, nil); err != nil {
  232. return 0, err
  233. }
  234. // Wait for a response
  235. start := time.Now()
  236. select {
  237. case <-ch:
  238. case <-s.shutdownCh:
  239. return 0, ErrSessionShutdown
  240. }
  241. // Compute the RTT
  242. return time.Now().Sub(start), nil
  243. }
  244. // keepalive is a long running goroutine that periodically does
  245. // a ping to keep the connection alive.
  246. func (s *Session) keepalive() {
  247. for {
  248. select {
  249. case <-time.After(s.config.KeepAliveInterval):
  250. s.Ping()
  251. case <-s.shutdownCh:
  252. return
  253. }
  254. }
  255. }
  256. // waitForSend waits to send a header, checking for a potential shutdown
  257. func (s *Session) waitForSend(hdr header, body io.Reader) error {
  258. errCh := make(chan error, 1)
  259. ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
  260. select {
  261. case s.sendCh <- ready:
  262. case <-s.shutdownCh:
  263. return ErrSessionShutdown
  264. }
  265. select {
  266. case err := <-errCh:
  267. return err
  268. case <-s.shutdownCh:
  269. return ErrSessionShutdown
  270. }
  271. }
  272. // sendNoWait does a send without waiting
  273. func (s *Session) sendNoWait(hdr header) error {
  274. select {
  275. case s.sendCh <- sendReady{Hdr: hdr}:
  276. return nil
  277. case <-s.shutdownCh:
  278. return ErrSessionShutdown
  279. }
  280. }
  281. // send is a long running goroutine that sends data
  282. func (s *Session) send() {
  283. for {
  284. select {
  285. case ready := <-s.sendCh:
  286. // Send a header if ready
  287. if ready.Hdr != nil {
  288. sent := 0
  289. for sent < len(ready.Hdr) {
  290. n, err := s.conn.Write(ready.Hdr[sent:])
  291. if err != nil {
  292. s.exitErr(err)
  293. asyncSendErr(ready.Err, err)
  294. return
  295. }
  296. sent += n
  297. }
  298. }
  299. // Send data from a body if given
  300. if ready.Body != nil {
  301. _, err := io.Copy(s.conn, ready.Body)
  302. if err != nil {
  303. s.exitErr(err)
  304. asyncSendErr(ready.Err, err)
  305. return
  306. }
  307. }
  308. // No error, successful send
  309. asyncSendErr(ready.Err, nil)
  310. case <-s.shutdownCh:
  311. return
  312. }
  313. }
  314. }
  315. // recv is a long running goroutine that accepts new data
  316. func (s *Session) recv() {
  317. hdr := header(make([]byte, headerSize))
  318. for !s.isShutdown() {
  319. // Read the header
  320. if _, err := io.ReadFull(s.conn, hdr); err != nil {
  321. s.exitErr(err)
  322. return
  323. }
  324. // Verify the version
  325. if hdr.Version() != protoVersion {
  326. s.exitErr(ErrInvalidVersion)
  327. return
  328. }
  329. // Switch on the type
  330. msgType := hdr.MsgType()
  331. switch msgType {
  332. case typeData:
  333. fallthrough
  334. case typeWindowUpdate:
  335. if err := s.handleStreamMessage(hdr); err != nil {
  336. s.exitErr(err)
  337. return
  338. }
  339. case typeGoAway:
  340. if err := s.handleGoAway(hdr); err != nil {
  341. s.exitErr(err)
  342. return
  343. }
  344. case typePing:
  345. if err := s.handlePing(hdr); err != nil {
  346. s.exitErr(err)
  347. return
  348. }
  349. default:
  350. s.exitErr(ErrInvalidMsgType)
  351. return
  352. }
  353. }
  354. }
  355. // handleStreamMessage handles either a data or window update frame
  356. func (s *Session) handleStreamMessage(hdr header) error {
  357. // Check for a new stream creation
  358. id := hdr.StreamID()
  359. flags := hdr.Flags()
  360. if flags&flagSYN == flagSYN {
  361. if err := s.incomingStream(id); err != nil {
  362. return err
  363. }
  364. }
  365. // Get the stream
  366. s.streamLock.RLock()
  367. stream := s.streams[id]
  368. s.streamLock.RUnlock()
  369. // Make sure we have a stream
  370. if stream == nil {
  371. s.goAway(goAwayProtoErr)
  372. return ErrMissingStream
  373. }
  374. // Check if this is a window update
  375. if hdr.MsgType() == typeWindowUpdate {
  376. if err := stream.incrSendWindow(hdr, flags); err != nil {
  377. s.goAway(goAwayProtoErr)
  378. return err
  379. }
  380. }
  381. // Read the new data
  382. if err := stream.readData(hdr, flags, s.conn); err != nil {
  383. s.goAway(goAwayProtoErr)
  384. return err
  385. }
  386. return nil
  387. }
  388. // handlePing is invokde for a typePing frame
  389. func (s *Session) handlePing(hdr header) error {
  390. flags := hdr.Flags()
  391. pingID := hdr.Length()
  392. // Check if this is a query, respond back
  393. if flags&flagSYN == flagSYN {
  394. hdr := header(make([]byte, headerSize))
  395. hdr.encode(typePing, flagACK, 0, pingID)
  396. s.sendNoWait(hdr)
  397. return nil
  398. }
  399. // Handle a response
  400. s.pingLock.Lock()
  401. ch := s.pings[pingID]
  402. if ch != nil {
  403. delete(s.pings, pingID)
  404. close(ch)
  405. }
  406. s.pingLock.Unlock()
  407. return nil
  408. }
  409. // handleGoAway is invokde for a typeGoAway frame
  410. func (s *Session) handleGoAway(hdr header) error {
  411. code := hdr.Length()
  412. switch code {
  413. case goAwayNormal:
  414. s.remoteGoAway = true
  415. case goAwayProtoErr:
  416. return fmt.Errorf("yamux protocol error")
  417. case goAwayInternalErr:
  418. return fmt.Errorf("remote yamux internal error")
  419. default:
  420. return fmt.Errorf("unexpected go away received")
  421. }
  422. return nil
  423. }
  424. // exitErr is used to handle an error that is causing
  425. // the listener to exit.
  426. func (s *Session) exitErr(err error) {
  427. s.shutdownErr = err
  428. s.Close()
  429. }
  430. // goAway is used to send a goAway message
  431. func (s *Session) goAway(reason uint32) {
  432. hdr := header(make([]byte, headerSize))
  433. hdr.encode(typeGoAway, 0, 0, reason)
  434. s.sendNoWait(hdr)
  435. }
  436. // incomingStream is used to create a new incoming stream
  437. func (s *Session) incomingStream(id uint32) error {
  438. // Reject immediately if we are doing a go away
  439. if s.localGoAway {
  440. hdr := header(make([]byte, headerSize))
  441. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  442. s.sendNoWait(hdr)
  443. return nil
  444. }
  445. s.streamLock.Lock()
  446. defer s.streamLock.Unlock()
  447. // Check if stream already exists
  448. if _, ok := s.streams[id]; ok {
  449. s.goAway(goAwayProtoErr)
  450. s.exitErr(ErrDuplicateStream)
  451. return nil
  452. }
  453. // Register the stream
  454. stream := newStream(s, id, streamSYNReceived)
  455. s.streams[id] = stream
  456. // Check if we've exceeded the backlog
  457. select {
  458. case s.acceptCh <- stream:
  459. return nil
  460. default:
  461. // Backlog exceeded! RST the stream
  462. delete(s.streams, id)
  463. stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
  464. s.sendNoWait(stream.sendHdr)
  465. }
  466. return nil
  467. }
  468. // closeStream is used to close a stream once both sides have
  469. // issued a close.
  470. func (s *Session) closeStream(id uint32, withLock bool) {
  471. if !withLock {
  472. s.streamLock.Lock()
  473. defer s.streamLock.Unlock()
  474. }
  475. delete(s.streams, id)
  476. }