stream.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. package yamux
  2. import (
  3. "bytes"
  4. "io"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. )
  // streamState records where a Stream is in its open/close lifecycle.
  type streamState int

  // Lifecycle states. Values come from iota in declaration order, so the
  // order of this list must not change.
  const (
  	streamInit        streamState = iota // created, no SYN sent yet
  	streamSYNSent                        // SYN sent, waiting for the peer's ACK
  	streamSYNReceived                    // peer's SYN seen, our ACK not yet sent
  	streamEstablished                    // handshake complete
  	streamLocalClose                     // we sent FIN; peer may still send
  	streamRemoteClose                    // peer sent FIN; we may still send
  	streamClosed                         // both directions closed
  	streamReset                          // peer sent RST
  )
  20. // Stream is used to represent a logical stream
  21. // within a session.
  22. type Stream struct {
  23. recvWindow uint32
  24. sendWindow uint32
  25. id uint32
  26. session *Session
  27. state streamState
  28. stateLock sync.Mutex
  29. recvBuf bytes.Buffer
  30. recvLock sync.Mutex
  31. controlHdr header
  32. controlErr chan error
  33. controlHdrLock sync.Mutex
  34. sendHdr header
  35. sendErr chan error
  36. sendLock sync.Mutex
  37. recvNotifyCh chan struct{}
  38. sendNotifyCh chan struct{}
  39. readDeadline time.Time
  40. writeDeadline time.Time
  41. }
  42. // newStream is used to construct a new stream within
  43. // a given session for an ID
  44. func newStream(session *Session, id uint32, state streamState) *Stream {
  45. s := &Stream{
  46. id: id,
  47. session: session,
  48. state: state,
  49. controlHdr: header(make([]byte, headerSize)),
  50. controlErr: make(chan error, 1),
  51. sendHdr: header(make([]byte, headerSize)),
  52. sendErr: make(chan error, 1),
  53. recvWindow: initialStreamWindow,
  54. sendWindow: initialStreamWindow,
  55. recvNotifyCh: make(chan struct{}, 1),
  56. sendNotifyCh: make(chan struct{}, 1),
  57. }
  58. return s
  59. }
  60. // Session returns the associated stream session
  61. func (s *Stream) Session() *Session {
  62. return s.session
  63. }
  64. // StreamID returns the ID of this stream
  65. func (s *Stream) StreamID() uint32 {
  66. return s.id
  67. }
  68. // Read is used to read from the stream
  69. func (s *Stream) Read(b []byte) (n int, err error) {
  70. defer asyncNotify(s.recvNotifyCh)
  71. START:
  72. s.stateLock.Lock()
  73. switch s.state {
  74. case streamRemoteClose:
  75. fallthrough
  76. case streamClosed:
  77. if s.recvBuf.Len() == 0 {
  78. s.stateLock.Unlock()
  79. return 0, io.EOF
  80. }
  81. case streamReset:
  82. s.stateLock.Unlock()
  83. return 0, ErrConnectionReset
  84. }
  85. s.stateLock.Unlock()
  86. // If there is no data available, block
  87. s.recvLock.Lock()
  88. if s.recvBuf.Len() == 0 {
  89. s.recvLock.Unlock()
  90. goto WAIT
  91. }
  92. // Read any bytes
  93. n, _ = s.recvBuf.Read(b)
  94. s.recvLock.Unlock()
  95. // Send a window update potentially
  96. err = s.sendWindowUpdate()
  97. return n, err
  98. WAIT:
  99. var timeout <-chan time.Time
  100. if !s.readDeadline.IsZero() {
  101. delay := s.readDeadline.Sub(time.Now())
  102. timeout = time.After(delay)
  103. }
  104. select {
  105. case <-s.recvNotifyCh:
  106. goto START
  107. case <-timeout:
  108. return 0, ErrTimeout
  109. }
  110. }
  111. // Write is used to write to the stream
  112. func (s *Stream) Write(b []byte) (n int, err error) {
  113. s.sendLock.Lock()
  114. defer s.sendLock.Unlock()
  115. total := 0
  116. for total < len(b) {
  117. n, err := s.write(b[total:])
  118. total += n
  119. if err != nil {
  120. return total, err
  121. }
  122. }
  123. return total, nil
  124. }
  125. // write is used to write to the stream, may return on
  126. // a short write.
  127. func (s *Stream) write(b []byte) (n int, err error) {
  128. var flags uint16
  129. var max uint32
  130. var body io.Reader
  131. START:
  132. s.stateLock.Lock()
  133. switch s.state {
  134. case streamLocalClose:
  135. fallthrough
  136. case streamClosed:
  137. s.stateLock.Unlock()
  138. return 0, ErrStreamClosed
  139. case streamReset:
  140. s.stateLock.Unlock()
  141. return 0, ErrConnectionReset
  142. }
  143. s.stateLock.Unlock()
  144. // If there is no data available, block
  145. window := atomic.LoadUint32(&s.sendWindow)
  146. if window == 0 {
  147. goto WAIT
  148. }
  149. // Determine the flags if any
  150. flags = s.sendFlags()
  151. // Send up to our send window
  152. max = min(window, uint32(len(b)))
  153. body = bytes.NewReader(b[:max])
  154. // Send the header
  155. s.sendHdr.encode(typeData, flags, s.id, max)
  156. if err := s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
  157. return 0, err
  158. }
  159. // Reduce our send window
  160. atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
  161. // Unlock
  162. return int(max), err
  163. WAIT:
  164. var timeout <-chan time.Time
  165. if !s.writeDeadline.IsZero() {
  166. delay := s.writeDeadline.Sub(time.Now())
  167. timeout = time.After(delay)
  168. }
  169. select {
  170. case <-s.sendNotifyCh:
  171. goto START
  172. case <-timeout:
  173. return 0, ErrTimeout
  174. }
  175. return 0, nil
  176. }
  177. // sendFlags determines any flags that are appropriate
  178. // based on the current stream state
  179. func (s *Stream) sendFlags() uint16 {
  180. s.stateLock.Lock()
  181. defer s.stateLock.Unlock()
  182. var flags uint16
  183. switch s.state {
  184. case streamInit:
  185. flags |= flagSYN
  186. s.state = streamSYNSent
  187. case streamSYNReceived:
  188. flags |= flagACK
  189. s.state = streamEstablished
  190. }
  191. return flags
  192. }
  193. // sendWindowUpdate potentially sends a window update enabling
  194. // further writes to take place. Must be invoked with the lock.
  195. func (s *Stream) sendWindowUpdate() error {
  196. s.controlHdrLock.Lock()
  197. defer s.controlHdrLock.Unlock()
  198. // Determine the delta update
  199. max := s.session.config.MaxStreamWindowSize
  200. delta := max - atomic.LoadUint32(&s.recvWindow)
  201. // Determine the flags if any
  202. flags := s.sendFlags()
  203. // Check if we can omit the update
  204. if delta < (max/2) && flags == 0 {
  205. return nil
  206. }
  207. // Update our window
  208. atomic.AddUint32(&s.recvWindow, delta)
  209. // Send the header
  210. s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
  211. if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
  212. return err
  213. }
  214. return nil
  215. }
  216. // sendClose is used to send a FIN
  217. func (s *Stream) sendClose() error {
  218. s.controlHdrLock.Lock()
  219. defer s.controlHdrLock.Unlock()
  220. flags := s.sendFlags()
  221. flags |= flagFIN
  222. s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
  223. if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
  224. return err
  225. }
  226. return nil
  227. }
  228. // Close is used to close the stream
  229. func (s *Stream) Close() error {
  230. s.stateLock.Lock()
  231. switch s.state {
  232. // Opened means we need to signal a close
  233. case streamSYNSent:
  234. fallthrough
  235. case streamSYNReceived:
  236. fallthrough
  237. case streamEstablished:
  238. s.state = streamLocalClose
  239. goto SEND_CLOSE
  240. case streamLocalClose:
  241. case streamRemoteClose:
  242. s.state = streamClosed
  243. s.session.closeStream(s.id)
  244. goto SEND_CLOSE
  245. case streamClosed:
  246. case streamReset:
  247. default:
  248. panic("unhandled state")
  249. }
  250. s.stateLock.Unlock()
  251. return nil
  252. SEND_CLOSE:
  253. s.stateLock.Unlock()
  254. s.sendClose()
  255. s.notifyWaiting()
  256. return nil
  257. }
  258. // forceClose is used for when the session is exiting
  259. func (s *Stream) forceClose() {
  260. s.stateLock.Lock()
  261. s.state = streamClosed
  262. s.stateLock.Unlock()
  263. s.notifyWaiting()
  264. }
  265. // processFlags is used to update the state of the stream
  266. // based on set flags, if any. Lock must be held
  267. func (s *Stream) processFlags(flags uint16) error {
  268. s.stateLock.Lock()
  269. defer s.stateLock.Unlock()
  270. if flags&flagACK == flagACK {
  271. if s.state == streamSYNSent {
  272. s.state = streamEstablished
  273. }
  274. } else if flags&flagFIN == flagFIN {
  275. switch s.state {
  276. case streamSYNSent:
  277. fallthrough
  278. case streamSYNReceived:
  279. fallthrough
  280. case streamEstablished:
  281. s.state = streamRemoteClose
  282. s.notifyWaiting()
  283. case streamLocalClose:
  284. s.state = streamClosed
  285. s.session.closeStream(s.id)
  286. s.notifyWaiting()
  287. default:
  288. s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
  289. return ErrUnexpectedFlag
  290. }
  291. } else if flags&flagRST == flagRST {
  292. s.state = streamReset
  293. s.session.closeStream(s.id)
  294. s.notifyWaiting()
  295. }
  296. return nil
  297. }
  298. // notifyWaiting notifies all the waiting channels
  299. func (s *Stream) notifyWaiting() {
  300. asyncNotify(s.recvNotifyCh)
  301. asyncNotify(s.sendNotifyCh)
  302. }
  303. // incrSendWindow updates the size of our send window
  304. func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
  305. if err := s.processFlags(flags); err != nil {
  306. return err
  307. }
  308. // Increase window, unblock a sender
  309. atomic.AddUint32(&s.sendWindow, hdr.Length())
  310. asyncNotify(s.sendNotifyCh)
  311. return nil
  312. }
  313. // readData is used to handle a data frame
  314. func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
  315. if err := s.processFlags(flags); err != nil {
  316. return err
  317. }
  318. // Check that our recv window is not exceeded
  319. length := hdr.Length()
  320. if length == 0 {
  321. return nil
  322. }
  323. if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
  324. s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
  325. return ErrRecvWindowExceeded
  326. }
  327. // Wrap in a limited reader
  328. conn = &io.LimitedReader{R: conn, N: int64(length)}
  329. // Copy into buffer
  330. s.recvLock.Lock()
  331. if _, err := io.Copy(&s.recvBuf, conn); err != nil {
  332. s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
  333. s.recvLock.Unlock()
  334. return err
  335. }
  336. // Decrement the receive window
  337. atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
  338. s.recvLock.Unlock()
  339. // Unblock any readers
  340. asyncNotify(s.recvNotifyCh)
  341. return nil
  342. }
  343. // SetDeadline sets the read and write deadlines
  344. func (s *Stream) SetDeadline(t time.Time) error {
  345. if err := s.SetReadDeadline(t); err != nil {
  346. return err
  347. }
  348. if err := s.SetWriteDeadline(t); err != nil {
  349. return err
  350. }
  351. return nil
  352. }
  353. // SetReadDeadline sets the deadline for future Read calls.
  354. func (s *Stream) SetReadDeadline(t time.Time) error {
  355. s.readDeadline = t
  356. return nil
  357. }
  358. // SetWriteDeadline sets the deadline for future Write calls
  359. func (s *Stream) SetWriteDeadline(t time.Time) error {
  360. s.writeDeadline = t
  361. return nil
  362. }