stream.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. package yamux
  2. import (
  3. "bytes"
  4. "compress/lzw"
  5. "io"
  6. "log"
  7. "sync"
  8. "time"
  9. )
  // streamState is the lifecycle state of a Stream.
  type streamState int

  // The states a Stream moves through. Values are assigned by iota in
  // declaration order, starting at 0.
  const (
  	// streamInit: nothing sent yet; the next outgoing frame carries SYN.
  	streamInit streamState = iota
  	// streamSYNSent: our SYN is out; an incoming ACK establishes the stream.
  	streamSYNSent
  	// streamSYNReceived: the next outgoing frame carries ACK.
  	streamSYNReceived
  	// streamEstablished: open in both directions.
  	streamEstablished
  	// streamLocalClose: we sent FIN; further writes fail with ErrStreamClosed.
  	streamLocalClose
  	// streamRemoteClose: the peer sent FIN; reads return io.EOF once drained.
  	streamRemoteClose
  	// streamClosed: closed in both directions, reset, or force-closed.
  	streamClosed
  )
  20. // Stream is used to represent a logical stream
  21. // within a session.
  22. type Stream struct {
  23. id uint32
  24. session *Session
  25. state streamState
  26. lock sync.Mutex
  27. recvBuf bytes.Buffer
  28. sendHdr header
  29. recvWindow uint32
  30. sendWindow uint32
  31. notifyCh chan struct{}
  32. readDeadline time.Time
  33. writeDeadline time.Time
  34. }
  35. // newStream is used to construct a new stream within
  36. // a given session for an ID
  37. func newStream(session *Session, id uint32, state streamState) *Stream {
  38. s := &Stream{
  39. id: id,
  40. session: session,
  41. state: state,
  42. recvWindow: initialStreamWindow,
  43. sendWindow: initialStreamWindow,
  44. notifyCh: make(chan struct{}, 1),
  45. sendHdr: header(make([]byte, headerSize)),
  46. }
  47. return s
  48. }
  49. // Session returns the associated stream session
  50. func (s *Stream) Session() *Session {
  51. return s.session
  52. }
  53. // StreamID returns the ID of this stream
  54. func (s *Stream) StreamID() uint32 {
  55. return s.id
  56. }
  57. // Read is used to read from the stream
  58. func (s *Stream) Read(b []byte) (n int, err error) {
  59. START:
  60. s.lock.Lock()
  61. switch s.state {
  62. case streamRemoteClose:
  63. fallthrough
  64. case streamClosed:
  65. if s.recvBuf.Len() == 0 {
  66. s.lock.Unlock()
  67. return 0, io.EOF
  68. }
  69. }
  70. // If there is no data available, block
  71. if s.recvBuf.Len() == 0 {
  72. s.lock.Unlock()
  73. goto WAIT
  74. }
  75. // Read any bytes
  76. n, _ = s.recvBuf.Read(b)
  77. // Send a window update potentially
  78. err = s.sendWindowUpdate()
  79. s.lock.Unlock()
  80. return n, err
  81. WAIT:
  82. var timeout <-chan time.Time
  83. if !s.readDeadline.IsZero() {
  84. delay := s.readDeadline.Sub(time.Now())
  85. timeout = time.After(delay)
  86. }
  87. select {
  88. case <-s.notifyCh:
  89. goto START
  90. case <-timeout:
  91. return 0, ErrTimeout
  92. }
  93. }
  94. // Write is used to write to the stream
  95. func (s *Stream) Write(b []byte) (n int, err error) {
  96. total := 0
  97. for total < len(b) {
  98. n, err := s.write(b[total:])
  99. total += n
  100. if err != nil {
  101. return total, err
  102. }
  103. }
  104. return total, nil
  105. }
  106. // write is used to write to the stream, may return on
  107. // a short write.
  108. func (s *Stream) write(b []byte) (n int, err error) {
  109. var flags uint16
  110. var max uint32
  111. var body io.Reader
  112. START:
  113. s.lock.Lock()
  114. switch s.state {
  115. case streamLocalClose:
  116. fallthrough
  117. case streamClosed:
  118. s.lock.Unlock()
  119. return 0, ErrStreamClosed
  120. }
  121. // If there is no data available, block
  122. if s.sendWindow == 0 {
  123. s.lock.Unlock()
  124. goto WAIT
  125. }
  126. // Determine the flags if any
  127. flags = s.sendFlags()
  128. // Send up to our send window
  129. max = min(s.sendWindow, uint32(len(b)))
  130. body = bytes.NewReader(b[:max])
  131. // TODO: Compress
  132. // Send the header
  133. s.sendHdr.encode(typeData, flags, s.id, max)
  134. if err := s.session.waitForSend(s.sendHdr, body); err != nil {
  135. s.lock.Unlock()
  136. return 0, err
  137. }
  138. // Reduce our send window
  139. s.sendWindow -= max
  140. // Unlock
  141. s.lock.Unlock()
  142. return int(max), err
  143. WAIT:
  144. var timeout <-chan time.Time
  145. if !s.writeDeadline.IsZero() {
  146. delay := s.writeDeadline.Sub(time.Now())
  147. timeout = time.After(delay)
  148. }
  149. select {
  150. case <-s.notifyCh:
  151. goto START
  152. case <-timeout:
  153. return 0, ErrTimeout
  154. }
  155. return 0, nil
  156. }
  157. // sendFlags determines any flags that are appropriate
  158. // based on the current stream state
  159. func (s *Stream) sendFlags() uint16 {
  160. // Determine the flags if any
  161. var flags uint16
  162. switch s.state {
  163. case streamInit:
  164. flags |= flagSYN
  165. s.state = streamSYNSent
  166. case streamSYNReceived:
  167. flags |= flagACK
  168. s.state = streamEstablished
  169. }
  170. return flags
  171. }
  172. // sendWindowUpdate potentially sends a window update enabling
  173. // further writes to take place. Must be invoked with the lock.
  174. func (s *Stream) sendWindowUpdate() error {
  175. // Determine the delta update
  176. max := s.session.config.MaxStreamWindowSize
  177. delta := max - s.recvWindow
  178. // Determine the flags if any
  179. flags := s.sendFlags()
  180. // Check if we can omit the update
  181. if delta < (max/2) && flags == 0 {
  182. return nil
  183. }
  184. // Send the header
  185. s.sendHdr.encode(typeWindowUpdate, flags, s.id, delta)
  186. if err := s.session.waitForSend(s.sendHdr, nil); err != nil {
  187. return err
  188. }
  189. log.Printf("Window Update %d +%d", s.id, delta)
  190. // Update our window
  191. s.recvWindow += delta
  192. return nil
  193. }
  194. // sendClose is used to send a FIN
  195. func (s *Stream) sendClose() error {
  196. flags := s.sendFlags()
  197. flags |= flagFIN
  198. s.sendHdr.encode(typeWindowUpdate, flags, s.id, 0)
  199. if err := s.session.waitForSend(s.sendHdr, nil); err != nil {
  200. return err
  201. }
  202. return nil
  203. }
  204. // Close is used to close the stream
  205. func (s *Stream) Close() error {
  206. s.lock.Lock()
  207. defer s.lock.Unlock()
  208. switch s.state {
  209. // Local or full close means nothing to do
  210. case streamLocalClose:
  211. fallthrough
  212. case streamClosed:
  213. return nil
  214. // Remote close, weneed to send FIN and we are done
  215. case streamRemoteClose:
  216. s.state = streamClosed
  217. s.session.closeStream(s.id, false)
  218. s.sendClose()
  219. return nil
  220. // Opened means we need to signal a close
  221. case streamSYNSent:
  222. fallthrough
  223. case streamSYNReceived:
  224. fallthrough
  225. case streamEstablished:
  226. s.state = streamLocalClose
  227. s.sendClose()
  228. return nil
  229. }
  230. panic("unhandled state")
  231. }
  232. // forceClose is used for when the session is exiting
  233. func (s *Stream) forceClose() {
  234. s.lock.Lock()
  235. defer s.lock.Unlock()
  236. s.state = streamClosed
  237. asyncNotify(s.notifyCh)
  238. }
  239. // SetDeadline sets the read and write deadlines
  240. func (s *Stream) SetDeadline(t time.Time) error {
  241. if err := s.SetReadDeadline(t); err != nil {
  242. return err
  243. }
  244. if err := s.SetWriteDeadline(t); err != nil {
  245. return err
  246. }
  247. return nil
  248. }
  249. // SetReadDeadline sets the deadline for future Read calls.
  250. func (s *Stream) SetReadDeadline(t time.Time) error {
  251. s.readDeadline = t
  252. return nil
  253. }
  254. // SetWriteDeadline sets the deadline for future Write calls
  255. func (s *Stream) SetWriteDeadline(t time.Time) error {
  256. s.writeDeadline = t
  257. return nil
  258. }
  259. // processFlags is used to update the state of the stream
  260. // based on set flags, if any. Lock must be held
  261. func (s *Stream) processFlags(flags uint16) error {
  262. if flags&flagACK == flagACK {
  263. if s.state == streamSYNSent {
  264. s.state = streamEstablished
  265. }
  266. } else if flags&flagFIN == flagFIN {
  267. switch s.state {
  268. case streamSYNSent:
  269. fallthrough
  270. case streamSYNReceived:
  271. fallthrough
  272. case streamEstablished:
  273. s.state = streamRemoteClose
  274. case streamLocalClose:
  275. s.state = streamClosed
  276. s.session.closeStream(s.id, true)
  277. default:
  278. return ErrUnexpectedFlag
  279. }
  280. } else if flags&flagRST == flagRST {
  281. s.state = streamClosed
  282. s.session.closeStream(s.id, true)
  283. }
  284. return nil
  285. }
  286. // incrSendWindow updates the size of our send window
  287. func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
  288. s.lock.Lock()
  289. defer s.lock.Unlock()
  290. if err := s.processFlags(flags); err != nil {
  291. return err
  292. }
  293. // Increase window, unblock a sender
  294. s.sendWindow += hdr.Length()
  295. asyncNotify(s.notifyCh)
  296. return nil
  297. }
  298. // readData is used to handle a data frame
  299. func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
  300. s.lock.Lock()
  301. defer s.lock.Unlock()
  302. if err := s.processFlags(flags); err != nil {
  303. return err
  304. }
  305. // Check that our recv window is not exceeded
  306. length := hdr.Length()
  307. if length > s.recvWindow {
  308. return ErrRecvWindowExceeded
  309. }
  310. // Decrement the receive window
  311. s.recvWindow -= length
  312. // Wrap in a limited reader
  313. conn = &io.LimitedReader{R: conn, N: int64(length)}
  314. // Handle potential data compression
  315. if flags&flagLZW == flagLZW {
  316. cr := lzw.NewReader(conn, lzw.MSB, 8)
  317. defer cr.Close()
  318. conn = cr
  319. }
  320. // Copy to our buffer
  321. if _, err := io.Copy(&s.recvBuf, conn); err != nil {
  322. return err
  323. }
  324. // Unblock any readers
  325. asyncNotify(s.notifyCh)
  326. return nil
  327. }