stream.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. package yamux
  2. import (
  3. "bytes"
  4. "io"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. )
  // streamState is a point in a Stream's lifecycle. Transitions are made
  // while holding the owning Stream's stateLock.
  type streamState int

  // Stream lifecycle states, numbered consecutively from zero.
  const (
  	streamInit        streamState = iota // no SYN sent yet
  	streamSYNSent                        // SYN sent, waiting for the peer's ACK
  	streamSYNReceived                    // peer's SYN seen, our ACK not yet sent
  	streamEstablished                    // handshake complete
  	streamLocalClose                     // closed by the local side
  	streamRemoteClose                    // peer sent FIN
  	streamClosed                         // both sides closed, or force-closed
  	streamReset                          // peer sent RST
  )
  20. // Stream is used to represent a logical stream
  21. // within a session.
  22. type Stream struct {
  23. recvWindow uint32
  24. sendWindow uint32
  25. id uint32
  26. session *Session
  27. state streamState
  28. stateLock sync.Mutex
  29. recvBuf bytes.Buffer
  30. recvLock sync.Mutex
  31. controlHdr header
  32. controlErr chan error
  33. controlHdrLock sync.Mutex
  34. sendHdr header
  35. sendErr chan error
  36. sendLock sync.Mutex
  37. recvNotifyCh chan struct{}
  38. sendNotifyCh chan struct{}
  39. readDeadline time.Time
  40. writeDeadline time.Time
  41. }
  42. // newStream is used to construct a new stream within
  43. // a given session for an ID
  44. func newStream(session *Session, id uint32, state streamState) *Stream {
  45. s := &Stream{
  46. id: id,
  47. session: session,
  48. state: state,
  49. controlHdr: header(make([]byte, headerSize)),
  50. controlErr: make(chan error, 1),
  51. sendHdr: header(make([]byte, headerSize)),
  52. sendErr: make(chan error, 1),
  53. recvWindow: initialStreamWindow,
  54. sendWindow: initialStreamWindow,
  55. recvNotifyCh: make(chan struct{}, 1),
  56. sendNotifyCh: make(chan struct{}, 1),
  57. }
  58. return s
  59. }
  60. // Session returns the associated stream session
  61. func (s *Stream) Session() *Session {
  62. return s.session
  63. }
  64. // StreamID returns the ID of this stream
  65. func (s *Stream) StreamID() uint32 {
  66. return s.id
  67. }
  68. // Read is used to read from the stream
  69. func (s *Stream) Read(b []byte) (n int, err error) {
  70. defer asyncNotify(s.recvNotifyCh)
  71. START:
  72. s.stateLock.Lock()
  73. switch s.state {
  74. case streamLocalClose:
  75. fallthrough
  76. case streamRemoteClose:
  77. fallthrough
  78. case streamClosed:
  79. if s.recvBuf.Len() == 0 {
  80. s.stateLock.Unlock()
  81. return 0, io.EOF
  82. }
  83. case streamReset:
  84. s.stateLock.Unlock()
  85. return 0, ErrConnectionReset
  86. }
  87. s.stateLock.Unlock()
  88. // If there is no data available, block
  89. s.recvLock.Lock()
  90. if s.recvBuf.Len() == 0 {
  91. s.recvLock.Unlock()
  92. goto WAIT
  93. }
  94. // Read any bytes
  95. n, _ = s.recvBuf.Read(b)
  96. s.recvLock.Unlock()
  97. // Send a window update potentially
  98. err = s.sendWindowUpdate()
  99. return n, err
  100. WAIT:
  101. var timeout <-chan time.Time
  102. if !s.readDeadline.IsZero() {
  103. delay := s.readDeadline.Sub(time.Now())
  104. timeout = time.After(delay)
  105. }
  106. select {
  107. case <-s.recvNotifyCh:
  108. goto START
  109. case <-timeout:
  110. return 0, ErrTimeout
  111. }
  112. }
  113. // Write is used to write to the stream
  114. func (s *Stream) Write(b []byte) (n int, err error) {
  115. s.sendLock.Lock()
  116. defer s.sendLock.Unlock()
  117. total := 0
  118. for total < len(b) {
  119. n, err := s.write(b[total:])
  120. total += n
  121. if err != nil {
  122. return total, err
  123. }
  124. }
  125. return total, nil
  126. }
  127. // write is used to write to the stream, may return on
  128. // a short write.
  129. func (s *Stream) write(b []byte) (n int, err error) {
  130. var flags uint16
  131. var max uint32
  132. var body io.Reader
  133. START:
  134. s.stateLock.Lock()
  135. switch s.state {
  136. case streamLocalClose:
  137. fallthrough
  138. case streamClosed:
  139. s.stateLock.Unlock()
  140. return 0, ErrStreamClosed
  141. case streamReset:
  142. s.stateLock.Unlock()
  143. return 0, ErrConnectionReset
  144. }
  145. s.stateLock.Unlock()
  146. // If there is no data available, block
  147. window := atomic.LoadUint32(&s.sendWindow)
  148. if window == 0 {
  149. goto WAIT
  150. }
  151. // Determine the flags if any
  152. flags = s.sendFlags()
  153. // Send up to our send window
  154. max = min(window, uint32(len(b)))
  155. body = bytes.NewReader(b[:max])
  156. // Send the header
  157. s.sendHdr.encode(typeData, flags, s.id, max)
  158. if err := s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
  159. return 0, err
  160. }
  161. // Reduce our send window
  162. atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
  163. // Unlock
  164. return int(max), err
  165. WAIT:
  166. var timeout <-chan time.Time
  167. if !s.writeDeadline.IsZero() {
  168. delay := s.writeDeadline.Sub(time.Now())
  169. timeout = time.After(delay)
  170. }
  171. select {
  172. case <-s.sendNotifyCh:
  173. goto START
  174. case <-timeout:
  175. return 0, ErrTimeout
  176. }
  177. return 0, nil
  178. }
  179. // sendFlags determines any flags that are appropriate
  180. // based on the current stream state
  181. func (s *Stream) sendFlags() uint16 {
  182. s.stateLock.Lock()
  183. defer s.stateLock.Unlock()
  184. var flags uint16
  185. switch s.state {
  186. case streamInit:
  187. flags |= flagSYN
  188. s.state = streamSYNSent
  189. case streamSYNReceived:
  190. flags |= flagACK
  191. s.state = streamEstablished
  192. }
  193. return flags
  194. }
  195. // sendWindowUpdate potentially sends a window update enabling
  196. // further writes to take place. Must be invoked with the lock.
  197. func (s *Stream) sendWindowUpdate() error {
  198. s.controlHdrLock.Lock()
  199. defer s.controlHdrLock.Unlock()
  200. // Determine the delta update
  201. max := s.session.config.MaxStreamWindowSize
  202. delta := max - atomic.LoadUint32(&s.recvWindow)
  203. // Determine the flags if any
  204. flags := s.sendFlags()
  205. // Check if we can omit the update
  206. if delta < (max/2) && flags == 0 {
  207. return nil
  208. }
  209. // Update our window
  210. atomic.AddUint32(&s.recvWindow, delta)
  211. // Send the header
  212. s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
  213. if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
  214. return err
  215. }
  216. return nil
  217. }
  218. // sendClose is used to send a FIN
  219. func (s *Stream) sendClose() error {
  220. s.controlHdrLock.Lock()
  221. defer s.controlHdrLock.Unlock()
  222. flags := s.sendFlags()
  223. flags |= flagFIN
  224. s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
  225. if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
  226. return err
  227. }
  228. return nil
  229. }
  230. // Close is used to close the stream
  231. func (s *Stream) Close() error {
  232. s.stateLock.Lock()
  233. switch s.state {
  234. // Opened means we need to signal a close
  235. case streamSYNSent:
  236. fallthrough
  237. case streamSYNReceived:
  238. fallthrough
  239. case streamEstablished:
  240. s.state = streamLocalClose
  241. goto SEND_CLOSE
  242. case streamLocalClose:
  243. case streamRemoteClose:
  244. s.state = streamClosed
  245. s.session.closeStream(s.id)
  246. goto SEND_CLOSE
  247. case streamClosed:
  248. case streamReset:
  249. default:
  250. panic("unhandled state")
  251. }
  252. s.stateLock.Unlock()
  253. return nil
  254. SEND_CLOSE:
  255. s.stateLock.Unlock()
  256. s.sendClose()
  257. s.notifyWaiting()
  258. return nil
  259. }
  260. // forceClose is used for when the session is exiting
  261. func (s *Stream) forceClose() {
  262. s.stateLock.Lock()
  263. s.state = streamClosed
  264. s.stateLock.Unlock()
  265. s.notifyWaiting()
  266. }
  267. // processFlags is used to update the state of the stream
  268. // based on set flags, if any. Lock must be held
  269. func (s *Stream) processFlags(flags uint16) error {
  270. s.stateLock.Lock()
  271. defer s.stateLock.Unlock()
  272. if flags&flagACK == flagACK {
  273. if s.state == streamSYNSent {
  274. s.state = streamEstablished
  275. }
  276. } else if flags&flagFIN == flagFIN {
  277. switch s.state {
  278. case streamSYNSent:
  279. fallthrough
  280. case streamSYNReceived:
  281. fallthrough
  282. case streamEstablished:
  283. s.state = streamRemoteClose
  284. s.notifyWaiting()
  285. case streamLocalClose:
  286. s.state = streamClosed
  287. s.session.closeStream(s.id)
  288. s.notifyWaiting()
  289. default:
  290. s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
  291. return ErrUnexpectedFlag
  292. }
  293. } else if flags&flagRST == flagRST {
  294. s.state = streamReset
  295. s.session.closeStream(s.id)
  296. s.notifyWaiting()
  297. }
  298. return nil
  299. }
  300. // notifyWaiting notifies all the waiting channels
  301. func (s *Stream) notifyWaiting() {
  302. asyncNotify(s.recvNotifyCh)
  303. asyncNotify(s.sendNotifyCh)
  304. }
  305. // incrSendWindow updates the size of our send window
  306. func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
  307. if err := s.processFlags(flags); err != nil {
  308. return err
  309. }
  310. // Increase window, unblock a sender
  311. atomic.AddUint32(&s.sendWindow, hdr.Length())
  312. asyncNotify(s.sendNotifyCh)
  313. return nil
  314. }
  315. // readData is used to handle a data frame
  316. func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
  317. if err := s.processFlags(flags); err != nil {
  318. return err
  319. }
  320. // Check that our recv window is not exceeded
  321. length := hdr.Length()
  322. if length == 0 {
  323. return nil
  324. }
  325. if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
  326. s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
  327. return ErrRecvWindowExceeded
  328. }
  329. // Wrap in a limited reader
  330. conn = &io.LimitedReader{R: conn, N: int64(length)}
  331. // Copy into buffer
  332. s.recvLock.Lock()
  333. if _, err := io.Copy(&s.recvBuf, conn); err != nil {
  334. s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
  335. s.recvLock.Unlock()
  336. return err
  337. }
  338. // Decrement the receive window
  339. atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
  340. s.recvLock.Unlock()
  341. // Unblock any readers
  342. asyncNotify(s.recvNotifyCh)
  343. return nil
  344. }
  345. // SetDeadline sets the read and write deadlines
  346. func (s *Stream) SetDeadline(t time.Time) error {
  347. if err := s.SetReadDeadline(t); err != nil {
  348. return err
  349. }
  350. if err := s.SetWriteDeadline(t); err != nil {
  351. return err
  352. }
  353. return nil
  354. }
  355. // SetReadDeadline sets the deadline for future Read calls.
  356. func (s *Stream) SetReadDeadline(t time.Time) error {
  357. s.readDeadline = t
  358. return nil
  359. }
  360. // SetWriteDeadline sets the deadline for future Write calls
  361. func (s *Stream) SetWriteDeadline(t time.Time) error {
  362. s.writeDeadline = t
  363. return nil
  364. }