stream.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. package yamux
  2. import (
  3. "bytes"
  4. "io"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. )
  // streamState is the lifecycle position of a Stream.
type streamState int

const (
	// streamInit: created locally, nothing sent yet.
	streamInit streamState = iota
	// streamSYNSent: our SYN has been sent, awaiting the peer's ACK.
	streamSYNSent
	// streamSYNReceived: the peer's SYN arrived, our ACK not yet sent.
	streamSYNReceived
	// streamEstablished: both directions are open.
	streamEstablished
	// streamLocalClose: we sent FIN; the peer may still send.
	streamLocalClose
	// streamRemoteClose: the peer sent FIN; we may still send.
	streamRemoteClose
	// streamClosed: both sides are done (or the stream was reset).
	streamClosed
)
  19. // Stream is used to represent a logical stream
  20. // within a session.
  21. type Stream struct {
  22. recvWindow uint32
  23. sendWindow uint32
  24. id uint32
  25. session *Session
  26. state streamState
  27. stateLock sync.Mutex
  28. recvBuf bytes.Buffer
  29. waitingBuffer *directBuffer
  30. recvLock sync.Mutex
  31. controlHdr header
  32. controlErr chan error
  33. controlHdrLock sync.Mutex
  34. sendHdr header
  35. sendErr chan error
  36. sendLock sync.Mutex
  37. recvNotifyCh chan struct{}
  38. sendNotifyCh chan struct{}
  39. readDeadline time.Time
  40. writeDeadline time.Time
  41. }
  // directBuffer is a blocked reader's destination that readData can fill
// directly, bypassing recvBuf.
type directBuffer struct {
	buf   []byte // the slice the reader passed to Read
	bytes int    // how many bytes readData copied into buf
}
  46. // newStream is used to construct a new stream within
  47. // a given session for an ID
  48. func newStream(session *Session, id uint32, state streamState) *Stream {
  49. s := &Stream{
  50. id: id,
  51. session: session,
  52. state: state,
  53. controlHdr: header(make([]byte, headerSize)),
  54. controlErr: make(chan error, 1),
  55. sendHdr: header(make([]byte, headerSize)),
  56. sendErr: make(chan error, 1),
  57. recvWindow: initialStreamWindow,
  58. sendWindow: initialStreamWindow,
  59. recvNotifyCh: make(chan struct{}, 1),
  60. sendNotifyCh: make(chan struct{}, 1),
  61. }
  62. return s
  63. }
  64. // Session returns the associated stream session
  65. func (s *Stream) Session() *Session {
  66. return s.session
  67. }
  68. // StreamID returns the ID of this stream
  69. func (s *Stream) StreamID() uint32 {
  70. return s.id
  71. }
  72. // Read is used to read from the stream
  73. func (s *Stream) Read(b []byte) (n int, err error) {
  74. defer asyncNotify(s.recvNotifyCh)
  75. var dBuf *directBuffer
  76. START:
  77. s.stateLock.Lock()
  78. switch s.state {
  79. case streamRemoteClose:
  80. fallthrough
  81. case streamClosed:
  82. if s.recvBuf.Len() == 0 {
  83. s.stateLock.Unlock()
  84. return 0, io.EOF
  85. }
  86. }
  87. s.stateLock.Unlock()
  88. // If there is no data available, block
  89. s.recvLock.Lock()
  90. if s.recvBuf.Len() == 0 {
  91. // Mark ourself as waiting potentially
  92. if s.waitingBuffer == nil {
  93. dBuf = &directBuffer{buf: b}
  94. s.waitingBuffer = dBuf
  95. }
  96. s.recvLock.Unlock()
  97. goto WAIT
  98. }
  99. // Read any bytes
  100. n, _ = s.recvBuf.Read(b)
  101. s.recvLock.Unlock()
  102. // Send a window update potentially
  103. err = s.sendWindowUpdate()
  104. return n, err
  105. WAIT:
  106. var timeout <-chan time.Time
  107. if !s.readDeadline.IsZero() {
  108. delay := s.readDeadline.Sub(time.Now())
  109. timeout = time.After(delay)
  110. }
  111. select {
  112. case <-s.recvNotifyCh:
  113. if dBuf != nil && dBuf.bytes > 0 {
  114. return dBuf.bytes, nil
  115. }
  116. goto START
  117. case <-timeout:
  118. if dBuf != nil && dBuf.bytes > 0 {
  119. return dBuf.bytes, nil
  120. }
  121. return 0, ErrTimeout
  122. }
  123. }
  124. // Write is used to write to the stream
  125. func (s *Stream) Write(b []byte) (n int, err error) {
  126. s.sendLock.Lock()
  127. defer s.sendLock.Unlock()
  128. total := 0
  129. for total < len(b) {
  130. n, err := s.write(b[total:])
  131. total += n
  132. if err != nil {
  133. return total, err
  134. }
  135. }
  136. return total, nil
  137. }
  138. // write is used to write to the stream, may return on
  139. // a short write.
  140. func (s *Stream) write(b []byte) (n int, err error) {
  141. var flags uint16
  142. var max uint32
  143. var body io.Reader
  144. START:
  145. s.stateLock.Lock()
  146. switch s.state {
  147. case streamLocalClose:
  148. fallthrough
  149. case streamClosed:
  150. s.stateLock.Unlock()
  151. return 0, ErrStreamClosed
  152. }
  153. s.stateLock.Unlock()
  154. // If there is no data available, block
  155. if atomic.LoadUint32(&s.sendWindow) == 0 {
  156. goto WAIT
  157. }
  158. // Determine the flags if any
  159. flags = s.sendFlags()
  160. // Send up to our send window
  161. max = min(s.sendWindow, uint32(len(b)))
  162. body = bytes.NewReader(b[:max])
  163. // Send the header
  164. s.sendHdr.encode(typeData, flags, s.id, max)
  165. if err := s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
  166. return 0, err
  167. }
  168. // Reduce our send window
  169. atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
  170. // Unlock
  171. return int(max), err
  172. WAIT:
  173. var timeout <-chan time.Time
  174. if !s.writeDeadline.IsZero() {
  175. delay := s.writeDeadline.Sub(time.Now())
  176. timeout = time.After(delay)
  177. }
  178. select {
  179. case <-s.sendNotifyCh:
  180. goto START
  181. case <-timeout:
  182. return 0, ErrTimeout
  183. }
  184. return 0, nil
  185. }
  186. // sendFlags determines any flags that are appropriate
  187. // based on the current stream state
  188. func (s *Stream) sendFlags() uint16 {
  189. s.stateLock.Lock()
  190. defer s.stateLock.Unlock()
  191. var flags uint16
  192. switch s.state {
  193. case streamInit:
  194. flags |= flagSYN
  195. s.state = streamSYNSent
  196. case streamSYNReceived:
  197. flags |= flagACK
  198. s.state = streamEstablished
  199. }
  200. return flags
  201. }
  202. // sendWindowUpdate potentially sends a window update enabling
  203. // further writes to take place. Must be invoked with the lock.
  204. func (s *Stream) sendWindowUpdate() error {
  205. s.controlHdrLock.Lock()
  206. defer s.controlHdrLock.Unlock()
  207. // Determine the delta update
  208. max := s.session.config.MaxStreamWindowSize
  209. delta := max - s.recvWindow
  210. // Determine the flags if any
  211. flags := s.sendFlags()
  212. // Check if we can omit the update
  213. if delta < (max/2) && flags == 0 {
  214. return nil
  215. }
  216. // Send the header
  217. s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
  218. if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
  219. return err
  220. }
  221. // Update our window
  222. s.recvWindow += delta
  223. return nil
  224. }
  225. // sendClose is used to send a FIN
  226. func (s *Stream) sendClose() error {
  227. s.controlHdrLock.Lock()
  228. defer s.controlHdrLock.Unlock()
  229. flags := s.sendFlags()
  230. flags |= flagFIN
  231. s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
  232. if err := s.session.sendNoWait(s.controlHdr); err != nil {
  233. return err
  234. }
  235. return nil
  236. }
  237. // Close is used to close the stream
  238. func (s *Stream) Close() error {
  239. s.stateLock.Lock()
  240. switch s.state {
  241. // Opened means we need to signal a close
  242. case streamSYNSent:
  243. fallthrough
  244. case streamSYNReceived:
  245. fallthrough
  246. case streamEstablished:
  247. s.state = streamLocalClose
  248. goto SEND_CLOSE
  249. case streamLocalClose:
  250. case streamRemoteClose:
  251. s.state = streamClosed
  252. s.session.closeStream(s.id, false)
  253. goto SEND_CLOSE
  254. case streamClosed:
  255. default:
  256. panic("unhandled state")
  257. }
  258. s.stateLock.Unlock()
  259. return nil
  260. SEND_CLOSE:
  261. s.stateLock.Unlock()
  262. s.sendClose()
  263. s.notifyWaiting()
  264. return nil
  265. }
  266. // forceClose is used for when the session is exiting
  267. func (s *Stream) forceClose() {
  268. s.stateLock.Lock()
  269. s.state = streamClosed
  270. s.stateLock.Unlock()
  271. s.notifyWaiting()
  272. }
  273. // processFlags is used to update the state of the stream
  274. // based on set flags, if any. Lock must be held
  275. func (s *Stream) processFlags(flags uint16) error {
  276. s.stateLock.Lock()
  277. defer s.stateLock.Unlock()
  278. if flags&flagACK == flagACK {
  279. if s.state == streamSYNSent {
  280. s.state = streamEstablished
  281. }
  282. } else if flags&flagFIN == flagFIN {
  283. switch s.state {
  284. case streamSYNSent:
  285. fallthrough
  286. case streamSYNReceived:
  287. fallthrough
  288. case streamEstablished:
  289. s.state = streamRemoteClose
  290. s.notifyWaiting()
  291. case streamLocalClose:
  292. s.state = streamClosed
  293. s.session.closeStream(s.id, true)
  294. s.notifyWaiting()
  295. default:
  296. return ErrUnexpectedFlag
  297. }
  298. } else if flags&flagRST == flagRST {
  299. s.state = streamClosed
  300. s.session.closeStream(s.id, true)
  301. s.notifyWaiting()
  302. }
  303. return nil
  304. }
  305. // notifyWaiting notifies all the waiting channels
  306. func (s *Stream) notifyWaiting() {
  307. asyncNotify(s.recvNotifyCh)
  308. asyncNotify(s.sendNotifyCh)
  309. }
  310. // incrSendWindow updates the size of our send window
  311. func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
  312. if err := s.processFlags(flags); err != nil {
  313. return err
  314. }
  315. // Increase window, unblock a sender
  316. atomic.AddUint32(&s.sendWindow, hdr.Length())
  317. asyncNotify(s.sendNotifyCh)
  318. return nil
  319. }
  320. // readData is used to handle a data frame
  321. func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
  322. if err := s.processFlags(flags); err != nil {
  323. return err
  324. }
  325. // Check that our recv window is not exceeded
  326. length := hdr.Length()
  327. if length == 0 {
  328. return nil
  329. }
  330. if length > atomic.LoadUint32(&s.recvWindow) {
  331. return ErrRecvWindowExceeded
  332. }
  333. // Decrement the receive window
  334. atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
  335. // Wrap in a limited reader
  336. conn = &io.LimitedReader{R: conn, N: int64(length)}
  337. // Copy into waiting buffer if any
  338. s.recvLock.Lock()
  339. if s.waitingBuffer != nil {
  340. n, err := conn.Read(s.waitingBuffer.buf)
  341. s.waitingBuffer.bytes = n
  342. s.waitingBuffer = nil
  343. if err != nil {
  344. s.recvLock.Unlock()
  345. return err
  346. }
  347. }
  348. // Copy to our buffer anything left
  349. if _, err := io.Copy(&s.recvBuf, conn); err != nil {
  350. s.recvLock.Unlock()
  351. return err
  352. }
  353. s.recvLock.Unlock()
  354. // Unblock any readers
  355. asyncNotify(s.recvNotifyCh)
  356. return nil
  357. }
  358. // SetDeadline sets the read and write deadlines
  359. func (s *Stream) SetDeadline(t time.Time) error {
  360. if err := s.SetReadDeadline(t); err != nil {
  361. return err
  362. }
  363. if err := s.SetWriteDeadline(t); err != nil {
  364. return err
  365. }
  366. return nil
  367. }
  368. // SetReadDeadline sets the deadline for future Read calls.
  369. func (s *Stream) SetReadDeadline(t time.Time) error {
  370. s.readDeadline = t
  371. return nil
  372. }
  373. // SetWriteDeadline sets the deadline for future Write calls
  374. func (s *Stream) SetWriteDeadline(t time.Time) error {
  375. s.writeDeadline = t
  376. return nil
  377. }