stream.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. package yamux
  2. import (
  3. "bytes"
  4. "io"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. )
  9. type streamState int
  10. const (
  11. streamInit streamState = iota
  12. streamSYNSent
  13. streamSYNReceived
  14. streamEstablished
  15. streamLocalClose
  16. streamRemoteClose
  17. streamClosed
  18. )
  19. // Stream is used to represent a logical stream
  20. // within a session.
  21. type Stream struct {
  22. recvWindow uint32
  23. sendWindow uint32
  24. id uint32
  25. session *Session
  26. state streamState
  27. stateLock sync.Mutex
  28. recvBuf bytes.Buffer
  29. waitingBuffer *directBuffer
  30. recvLock sync.Mutex
  31. controlHdr header
  32. controlHdrLock sync.Mutex
  33. sendHdr header
  34. sendLock sync.Mutex
  35. recvNotifyCh chan struct{}
  36. sendNotifyCh chan struct{}
  37. readDeadline time.Time
  38. writeDeadline time.Time
  39. }
  40. type directBuffer struct {
  41. buf []byte
  42. bytes int
  43. }
  44. // newStream is used to construct a new stream within
  45. // a given session for an ID
  46. func newStream(session *Session, id uint32, state streamState) *Stream {
  47. s := &Stream{
  48. id: id,
  49. session: session,
  50. state: state,
  51. controlHdr: header(make([]byte, headerSize)),
  52. sendHdr: header(make([]byte, headerSize)),
  53. recvWindow: initialStreamWindow,
  54. sendWindow: initialStreamWindow,
  55. recvNotifyCh: make(chan struct{}, 1),
  56. sendNotifyCh: make(chan struct{}, 1),
  57. }
  58. return s
  59. }
  60. // Session returns the associated stream session
  61. func (s *Stream) Session() *Session {
  62. return s.session
  63. }
  64. // StreamID returns the ID of this stream
  65. func (s *Stream) StreamID() uint32 {
  66. return s.id
  67. }
  68. // Read is used to read from the stream
  69. func (s *Stream) Read(b []byte) (n int, err error) {
  70. defer asyncNotify(s.recvNotifyCh)
  71. var dBuf *directBuffer
  72. START:
  73. s.stateLock.Lock()
  74. switch s.state {
  75. case streamRemoteClose:
  76. fallthrough
  77. case streamClosed:
  78. if s.recvBuf.Len() == 0 {
  79. s.stateLock.Unlock()
  80. return 0, io.EOF
  81. }
  82. }
  83. s.stateLock.Unlock()
  84. // If there is no data available, block
  85. s.recvLock.Lock()
  86. if s.recvBuf.Len() == 0 {
  87. // Mark ourself as waiting potentially
  88. if s.waitingBuffer == nil {
  89. dBuf = &directBuffer{buf: b}
  90. s.waitingBuffer = dBuf
  91. }
  92. s.recvLock.Unlock()
  93. goto WAIT
  94. }
  95. // Read any bytes
  96. n, _ = s.recvBuf.Read(b)
  97. s.recvLock.Unlock()
  98. // Send a window update potentially
  99. err = s.sendWindowUpdate()
  100. return n, err
  101. WAIT:
  102. var timeout <-chan time.Time
  103. if !s.readDeadline.IsZero() {
  104. delay := s.readDeadline.Sub(time.Now())
  105. timeout = time.After(delay)
  106. }
  107. select {
  108. case <-s.recvNotifyCh:
  109. if dBuf != nil && dBuf.bytes > 0 {
  110. return dBuf.bytes, nil
  111. }
  112. goto START
  113. case <-timeout:
  114. if dBuf != nil && dBuf.bytes > 0 {
  115. return dBuf.bytes, nil
  116. }
  117. return 0, ErrTimeout
  118. }
  119. }
  120. // Write is used to write to the stream
  121. func (s *Stream) Write(b []byte) (n int, err error) {
  122. s.sendLock.Lock()
  123. defer s.sendLock.Unlock()
  124. total := 0
  125. for total < len(b) {
  126. n, err := s.write(b[total:])
  127. total += n
  128. if err != nil {
  129. return total, err
  130. }
  131. }
  132. return total, nil
  133. }
  134. // write is used to write to the stream, may return on
  135. // a short write.
  136. func (s *Stream) write(b []byte) (n int, err error) {
  137. var flags uint16
  138. var max uint32
  139. var body io.Reader
  140. START:
  141. s.stateLock.Lock()
  142. switch s.state {
  143. case streamLocalClose:
  144. fallthrough
  145. case streamClosed:
  146. s.stateLock.Unlock()
  147. return 0, ErrStreamClosed
  148. }
  149. s.stateLock.Unlock()
  150. // If there is no data available, block
  151. if atomic.LoadUint32(&s.sendWindow) == 0 {
  152. goto WAIT
  153. }
  154. // Determine the flags if any
  155. flags = s.sendFlags()
  156. // Send up to our send window
  157. max = min(s.sendWindow, uint32(len(b)))
  158. body = bytes.NewReader(b[:max])
  159. // Send the header
  160. s.sendHdr.encode(typeData, flags, s.id, max)
  161. if err := s.session.waitForSend(s.sendHdr, body); err != nil {
  162. return 0, err
  163. }
  164. // Reduce our send window
  165. atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
  166. // Unlock
  167. return int(max), err
  168. WAIT:
  169. var timeout <-chan time.Time
  170. if !s.writeDeadline.IsZero() {
  171. delay := s.writeDeadline.Sub(time.Now())
  172. timeout = time.After(delay)
  173. }
  174. select {
  175. case <-s.sendNotifyCh:
  176. goto START
  177. case <-timeout:
  178. return 0, ErrTimeout
  179. }
  180. return 0, nil
  181. }
  182. // sendFlags determines any flags that are appropriate
  183. // based on the current stream state
  184. func (s *Stream) sendFlags() uint16 {
  185. s.stateLock.Lock()
  186. defer s.stateLock.Unlock()
  187. var flags uint16
  188. switch s.state {
  189. case streamInit:
  190. flags |= flagSYN
  191. s.state = streamSYNSent
  192. case streamSYNReceived:
  193. flags |= flagACK
  194. s.state = streamEstablished
  195. }
  196. return flags
  197. }
  198. // sendWindowUpdate potentially sends a window update enabling
  199. // further writes to take place. Must be invoked with the lock.
  200. func (s *Stream) sendWindowUpdate() error {
  201. s.controlHdrLock.Lock()
  202. defer s.controlHdrLock.Unlock()
  203. // Determine the delta update
  204. max := s.session.config.MaxStreamWindowSize
  205. delta := max - s.recvWindow
  206. // Determine the flags if any
  207. flags := s.sendFlags()
  208. // Check if we can omit the update
  209. if delta < (max/2) && flags == 0 {
  210. return nil
  211. }
  212. // Send the header
  213. s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
  214. if err := s.session.waitForSend(s.controlHdr, nil); err != nil {
  215. return err
  216. }
  217. // Update our window
  218. s.recvWindow += delta
  219. return nil
  220. }
  221. // sendClose is used to send a FIN
  222. func (s *Stream) sendClose() error {
  223. s.controlHdrLock.Lock()
  224. defer s.controlHdrLock.Unlock()
  225. flags := s.sendFlags()
  226. flags |= flagFIN
  227. s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
  228. if err := s.session.sendNoWait(s.controlHdr); err != nil {
  229. return err
  230. }
  231. return nil
  232. }
  233. // Close is used to close the stream
  234. func (s *Stream) Close() error {
  235. s.stateLock.Lock()
  236. switch s.state {
  237. // Opened means we need to signal a close
  238. case streamSYNSent:
  239. fallthrough
  240. case streamSYNReceived:
  241. fallthrough
  242. case streamEstablished:
  243. s.state = streamLocalClose
  244. goto SEND_CLOSE
  245. case streamLocalClose:
  246. case streamRemoteClose:
  247. s.state = streamClosed
  248. s.session.closeStream(s.id, false)
  249. goto SEND_CLOSE
  250. case streamClosed:
  251. default:
  252. panic("unhandled state")
  253. }
  254. s.stateLock.Unlock()
  255. return nil
  256. SEND_CLOSE:
  257. s.stateLock.Unlock()
  258. s.sendClose()
  259. s.notifyWaiting()
  260. return nil
  261. }
  262. // forceClose is used for when the session is exiting
  263. func (s *Stream) forceClose() {
  264. s.stateLock.Lock()
  265. s.state = streamClosed
  266. s.stateLock.Unlock()
  267. s.notifyWaiting()
  268. }
  269. // processFlags is used to update the state of the stream
  270. // based on set flags, if any. Lock must be held
  271. func (s *Stream) processFlags(flags uint16) error {
  272. s.stateLock.Lock()
  273. defer s.stateLock.Unlock()
  274. if flags&flagACK == flagACK {
  275. if s.state == streamSYNSent {
  276. s.state = streamEstablished
  277. }
  278. } else if flags&flagFIN == flagFIN {
  279. switch s.state {
  280. case streamSYNSent:
  281. fallthrough
  282. case streamSYNReceived:
  283. fallthrough
  284. case streamEstablished:
  285. s.state = streamRemoteClose
  286. s.notifyWaiting()
  287. case streamLocalClose:
  288. s.state = streamClosed
  289. s.session.closeStream(s.id, true)
  290. s.notifyWaiting()
  291. default:
  292. return ErrUnexpectedFlag
  293. }
  294. } else if flags&flagRST == flagRST {
  295. s.state = streamClosed
  296. s.session.closeStream(s.id, true)
  297. s.notifyWaiting()
  298. }
  299. return nil
  300. }
  301. // notifyWaiting notifies all the waiting channels
  302. func (s *Stream) notifyWaiting() {
  303. asyncNotify(s.recvNotifyCh)
  304. asyncNotify(s.sendNotifyCh)
  305. }
  306. // incrSendWindow updates the size of our send window
  307. func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
  308. if err := s.processFlags(flags); err != nil {
  309. return err
  310. }
  311. // Increase window, unblock a sender
  312. atomic.AddUint32(&s.sendWindow, hdr.Length())
  313. asyncNotify(s.sendNotifyCh)
  314. return nil
  315. }
  316. // readData is used to handle a data frame
  317. func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
  318. if err := s.processFlags(flags); err != nil {
  319. return err
  320. }
  321. // Check that our recv window is not exceeded
  322. length := hdr.Length()
  323. if length == 0 {
  324. return nil
  325. }
  326. if length > atomic.LoadUint32(&s.recvWindow) {
  327. return ErrRecvWindowExceeded
  328. }
  329. // Decrement the receive window
  330. atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
  331. // Wrap in a limited reader
  332. conn = &io.LimitedReader{R: conn, N: int64(length)}
  333. // Copy into waiting buffer if any
  334. s.recvLock.Lock()
  335. if s.waitingBuffer != nil {
  336. n, err := conn.Read(s.waitingBuffer.buf)
  337. s.waitingBuffer.bytes = n
  338. s.waitingBuffer = nil
  339. if err != nil {
  340. s.recvLock.Unlock()
  341. return err
  342. }
  343. }
  344. // Copy to our buffer anything left
  345. if _, err := io.Copy(&s.recvBuf, conn); err != nil {
  346. s.recvLock.Unlock()
  347. return err
  348. }
  349. s.recvLock.Unlock()
  350. // Unblock any readers
  351. asyncNotify(s.recvNotifyCh)
  352. return nil
  353. }
  354. // SetDeadline sets the read and write deadlines
  355. func (s *Stream) SetDeadline(t time.Time) error {
  356. if err := s.SetReadDeadline(t); err != nil {
  357. return err
  358. }
  359. if err := s.SetWriteDeadline(t); err != nil {
  360. return err
  361. }
  362. return nil
  363. }
  364. // SetReadDeadline sets the deadline for future Read calls.
  365. func (s *Stream) SetReadDeadline(t time.Time) error {
  366. s.readDeadline = t
  367. return nil
  368. }
  369. // SetWriteDeadline sets the deadline for future Write calls
  370. func (s *Stream) SetWriteDeadline(t time.Time) error {
  371. s.writeDeadline = t
  372. return nil
  373. }