stream.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. package yamux
  2. import (
  3. "bytes"
  4. "io"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. )
  9. type streamState int
  10. const (
  11. streamInit streamState = iota
  12. streamSYNSent
  13. streamSYNReceived
  14. streamEstablished
  15. streamLocalClose
  16. streamRemoteClose
  17. streamClosed
  18. )
  19. // Stream is used to represent a logical stream
  20. // within a session.
  21. type Stream struct {
  22. recvWindow uint32
  23. sendWindow uint32
  24. id uint32
  25. session *Session
  26. state streamState
  27. stateLock sync.Mutex
  28. recvBuf bytes.Buffer
  29. recvLock sync.Mutex
  30. controlHdr header
  31. controlHdrLock sync.Mutex
  32. sendHdr header
  33. sendLock sync.Mutex
  34. recvNotifyCh chan struct{}
  35. sendNotifyCh chan struct{}
  36. readDeadline time.Time
  37. writeDeadline time.Time
  38. }
  39. // newStream is used to construct a new stream within
  40. // a given session for an ID
  41. func newStream(session *Session, id uint32, state streamState) *Stream {
  42. s := &Stream{
  43. id: id,
  44. session: session,
  45. state: state,
  46. controlHdr: header(make([]byte, headerSize)),
  47. sendHdr: header(make([]byte, headerSize)),
  48. recvWindow: initialStreamWindow,
  49. sendWindow: initialStreamWindow,
  50. recvNotifyCh: make(chan struct{}, 1),
  51. sendNotifyCh: make(chan struct{}, 1),
  52. }
  53. return s
  54. }
  55. // Session returns the associated stream session
  56. func (s *Stream) Session() *Session {
  57. return s.session
  58. }
  59. // StreamID returns the ID of this stream
  60. func (s *Stream) StreamID() uint32 {
  61. return s.id
  62. }
  63. // Read is used to read from the stream
  64. func (s *Stream) Read(b []byte) (n int, err error) {
  65. defer asyncNotify(s.recvNotifyCh)
  66. START:
  67. s.stateLock.Lock()
  68. switch s.state {
  69. case streamRemoteClose:
  70. fallthrough
  71. case streamClosed:
  72. if s.recvBuf.Len() == 0 {
  73. s.stateLock.Unlock()
  74. return 0, io.EOF
  75. }
  76. }
  77. s.stateLock.Unlock()
  78. // If there is no data available, block
  79. s.recvLock.Lock()
  80. if s.recvBuf.Len() == 0 {
  81. s.recvLock.Unlock()
  82. goto WAIT
  83. }
  84. // Read any bytes
  85. n, _ = s.recvBuf.Read(b)
  86. s.recvLock.Unlock()
  87. // Send a window update potentially
  88. err = s.sendWindowUpdate()
  89. return n, err
  90. WAIT:
  91. var timeout <-chan time.Time
  92. if !s.readDeadline.IsZero() {
  93. delay := s.readDeadline.Sub(time.Now())
  94. timeout = time.After(delay)
  95. }
  96. select {
  97. case <-s.recvNotifyCh:
  98. goto START
  99. case <-timeout:
  100. return 0, ErrTimeout
  101. }
  102. }
  103. // Write is used to write to the stream
  104. func (s *Stream) Write(b []byte) (n int, err error) {
  105. s.sendLock.Lock()
  106. defer s.sendLock.Unlock()
  107. total := 0
  108. for total < len(b) {
  109. n, err := s.write(b[total:])
  110. total += n
  111. if err != nil {
  112. return total, err
  113. }
  114. }
  115. return total, nil
  116. }
  117. // write is used to write to the stream, may return on
  118. // a short write.
  119. func (s *Stream) write(b []byte) (n int, err error) {
  120. var flags uint16
  121. var max uint32
  122. var body io.Reader
  123. START:
  124. s.stateLock.Lock()
  125. switch s.state {
  126. case streamLocalClose:
  127. fallthrough
  128. case streamClosed:
  129. s.stateLock.Unlock()
  130. return 0, ErrStreamClosed
  131. }
  132. s.stateLock.Unlock()
  133. // If there is no data available, block
  134. if atomic.LoadUint32(&s.sendWindow) == 0 {
  135. goto WAIT
  136. }
  137. // Determine the flags if any
  138. flags = s.sendFlags()
  139. // Send up to our send window
  140. max = min(s.sendWindow, uint32(len(b)))
  141. body = bytes.NewReader(b[:max])
  142. // Send the header
  143. s.sendHdr.encode(typeData, flags, s.id, max)
  144. if err := s.session.waitForSend(s.sendHdr, body); err != nil {
  145. return 0, err
  146. }
  147. // Reduce our send window
  148. atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
  149. // Unlock
  150. return int(max), err
  151. WAIT:
  152. var timeout <-chan time.Time
  153. if !s.writeDeadline.IsZero() {
  154. delay := s.writeDeadline.Sub(time.Now())
  155. timeout = time.After(delay)
  156. }
  157. select {
  158. case <-s.sendNotifyCh:
  159. goto START
  160. case <-timeout:
  161. return 0, ErrTimeout
  162. }
  163. return 0, nil
  164. }
  165. // sendFlags determines any flags that are appropriate
  166. // based on the current stream state
  167. func (s *Stream) sendFlags() uint16 {
  168. s.stateLock.Lock()
  169. defer s.stateLock.Unlock()
  170. var flags uint16
  171. switch s.state {
  172. case streamInit:
  173. flags |= flagSYN
  174. s.state = streamSYNSent
  175. case streamSYNReceived:
  176. flags |= flagACK
  177. s.state = streamEstablished
  178. }
  179. return flags
  180. }
  181. // sendWindowUpdate potentially sends a window update enabling
  182. // further writes to take place. Must be invoked with the lock.
  183. func (s *Stream) sendWindowUpdate() error {
  184. s.controlHdrLock.Lock()
  185. defer s.controlHdrLock.Unlock()
  186. // Determine the delta update
  187. max := s.session.config.MaxStreamWindowSize
  188. delta := max - s.recvWindow
  189. // Determine the flags if any
  190. flags := s.sendFlags()
  191. // Check if we can omit the update
  192. if delta < (max/2) && flags == 0 {
  193. return nil
  194. }
  195. // Send the header
  196. s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
  197. if err := s.session.waitForSend(s.controlHdr, nil); err != nil {
  198. return err
  199. }
  200. // Update our window
  201. s.recvWindow += delta
  202. return nil
  203. }
  204. // sendClose is used to send a FIN
  205. func (s *Stream) sendClose() error {
  206. s.controlHdrLock.Lock()
  207. defer s.controlHdrLock.Unlock()
  208. flags := s.sendFlags()
  209. flags |= flagFIN
  210. s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
  211. if err := s.session.sendNoWait(s.controlHdr); err != nil {
  212. return err
  213. }
  214. return nil
  215. }
  216. // Close is used to close the stream
  217. func (s *Stream) Close() error {
  218. s.stateLock.Lock()
  219. switch s.state {
  220. // Opened means we need to signal a close
  221. case streamSYNSent:
  222. fallthrough
  223. case streamSYNReceived:
  224. fallthrough
  225. case streamEstablished:
  226. s.state = streamLocalClose
  227. goto SEND_CLOSE
  228. case streamLocalClose:
  229. case streamRemoteClose:
  230. s.state = streamClosed
  231. s.session.closeStream(s.id, false)
  232. goto SEND_CLOSE
  233. case streamClosed:
  234. default:
  235. panic("unhandled state")
  236. }
  237. s.stateLock.Unlock()
  238. return nil
  239. SEND_CLOSE:
  240. s.stateLock.Unlock()
  241. s.sendClose()
  242. s.notifyWaiting()
  243. return nil
  244. }
  245. // forceClose is used for when the session is exiting
  246. func (s *Stream) forceClose() {
  247. s.stateLock.Lock()
  248. s.state = streamClosed
  249. s.stateLock.Unlock()
  250. s.notifyWaiting()
  251. }
  252. // processFlags is used to update the state of the stream
  253. // based on set flags, if any. Lock must be held
  254. func (s *Stream) processFlags(flags uint16) error {
  255. s.stateLock.Lock()
  256. defer s.stateLock.Unlock()
  257. if flags&flagACK == flagACK {
  258. if s.state == streamSYNSent {
  259. s.state = streamEstablished
  260. }
  261. } else if flags&flagFIN == flagFIN {
  262. switch s.state {
  263. case streamSYNSent:
  264. fallthrough
  265. case streamSYNReceived:
  266. fallthrough
  267. case streamEstablished:
  268. s.state = streamRemoteClose
  269. s.notifyWaiting()
  270. case streamLocalClose:
  271. s.state = streamClosed
  272. s.session.closeStream(s.id, true)
  273. s.notifyWaiting()
  274. default:
  275. return ErrUnexpectedFlag
  276. }
  277. } else if flags&flagRST == flagRST {
  278. s.state = streamClosed
  279. s.session.closeStream(s.id, true)
  280. s.notifyWaiting()
  281. }
  282. return nil
  283. }
  284. // notifyWaiting notifies all the waiting channels
  285. func (s *Stream) notifyWaiting() {
  286. asyncNotify(s.recvNotifyCh)
  287. asyncNotify(s.sendNotifyCh)
  288. }
  289. // incrSendWindow updates the size of our send window
  290. func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
  291. if err := s.processFlags(flags); err != nil {
  292. return err
  293. }
  294. // Increase window, unblock a sender
  295. atomic.AddUint32(&s.sendWindow, hdr.Length())
  296. asyncNotify(s.sendNotifyCh)
  297. return nil
  298. }
  299. // readData is used to handle a data frame
  300. func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
  301. if err := s.processFlags(flags); err != nil {
  302. return err
  303. }
  304. // Check that our recv window is not exceeded
  305. length := hdr.Length()
  306. if length == 0 {
  307. return nil
  308. }
  309. if length > atomic.LoadUint32(&s.recvWindow) {
  310. return ErrRecvWindowExceeded
  311. }
  312. // Decrement the receive window
  313. atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
  314. // Wrap in a limited reader
  315. conn = &io.LimitedReader{R: conn, N: int64(length)}
  316. // Copy to our buffer
  317. s.recvLock.Lock()
  318. if _, err := io.Copy(&s.recvBuf, conn); err != nil {
  319. return err
  320. }
  321. s.recvLock.Unlock()
  322. // Unblock any readers
  323. asyncNotify(s.recvNotifyCh)
  324. return nil
  325. }
  326. // SetDeadline sets the read and write deadlines
  327. func (s *Stream) SetDeadline(t time.Time) error {
  328. if err := s.SetReadDeadline(t); err != nil {
  329. return err
  330. }
  331. if err := s.SetWriteDeadline(t); err != nil {
  332. return err
  333. }
  334. return nil
  335. }
  336. // SetReadDeadline sets the deadline for future Read calls.
  337. func (s *Stream) SetReadDeadline(t time.Time) error {
  338. s.readDeadline = t
  339. return nil
  340. }
  341. // SetWriteDeadline sets the deadline for future Write calls
  342. func (s *Stream) SetWriteDeadline(t time.Time) error {
  343. s.writeDeadline = t
  344. return nil
  345. }