session_test.go 30 KB


  1. package yamux
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"testing"
	"time"
)
  16. type logCapture struct{ bytes.Buffer }
  17. func (l *logCapture) logs() []string {
  18. return strings.Split(strings.TrimSpace(l.String()), "\n")
  19. }
  20. func (l *logCapture) match(expect []string) bool {
  21. return reflect.DeepEqual(l.logs(), expect)
  22. }
  23. func captureLogs(s *Session) *logCapture {
  24. buf := new(logCapture)
  25. s.logger = &discordLogger{}
  26. return buf
  27. }
  28. type pipeConn struct {
  29. reader *io.PipeReader
  30. writer *io.PipeWriter
  31. writeBlocker sync.Mutex
  32. }
  33. func (p *pipeConn) Read(b []byte) (int, error) {
  34. return p.reader.Read(b)
  35. }
  36. func (p *pipeConn) Write(b []byte) (int, error) {
  37. p.writeBlocker.Lock()
  38. defer p.writeBlocker.Unlock()
  39. return p.writer.Write(b)
  40. }
  41. func (p *pipeConn) Close() error {
  42. p.reader.Close()
  43. return p.writer.Close()
  44. }
  45. func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
  46. read1, write1 := io.Pipe()
  47. read2, write2 := io.Pipe()
  48. conn1 := &pipeConn{reader: read1, writer: write2}
  49. conn2 := &pipeConn{reader: read2, writer: write1}
  50. return conn1, conn2
  51. }
  52. func testConf() *Config {
  53. conf := DefaultConfig()
  54. conf.AcceptBacklog = 64
  55. conf.KeepAliveInterval = 100 * time.Millisecond
  56. conf.ConnectionWriteTimeout = 250 * time.Millisecond
  57. return conf
  58. }
  59. func testConfNoKeepAlive() *Config {
  60. conf := testConf()
  61. conf.EnableKeepAlive = false
  62. return conf
  63. }
  64. func testClientServer() (*Session, *Session) {
  65. return testClientServerConfig(testConf())
  66. }
  67. func testClientServerConfig(conf *Config) (*Session, *Session) {
  68. conn1, conn2 := testConn()
  69. client, _ := Client(conn1, conf)
  70. server, _ := Server(conn2, conf)
  71. return client, server
  72. }
  73. func TestPing(t *testing.T) {
  74. client, server := testClientServer()
  75. defer client.Close()
  76. defer server.Close()
  77. rtt, err := client.Ping()
  78. if err != nil {
  79. t.Fatalf("err: %v", err)
  80. }
  81. if rtt == 0 {
  82. t.Fatalf("bad: %v", rtt)
  83. }
  84. rtt, err = server.Ping()
  85. if err != nil {
  86. t.Fatalf("err: %v", err)
  87. }
  88. if rtt == 0 {
  89. t.Fatalf("bad: %v", rtt)
  90. }
  91. }
  92. func TestPing_Timeout(t *testing.T) {
  93. client, server := testClientServerConfig(testConfNoKeepAlive())
  94. defer client.Close()
  95. defer server.Close()
  96. // Prevent the client from responding
  97. clientConn := client.conn.(*pipeConn)
  98. clientConn.writeBlocker.Lock()
  99. errCh := make(chan error, 1)
  100. go func() {
  101. _, err := server.Ping() // Ping via the server session
  102. errCh <- err
  103. }()
  104. select {
  105. case err := <-errCh:
  106. if err != ErrTimeout {
  107. t.Fatalf("err: %v", err)
  108. }
  109. case <-time.After(client.config.ConnectionWriteTimeout * 2):
  110. t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
  111. }
  112. // Verify that we recover, even if we gave up
  113. clientConn.writeBlocker.Unlock()
  114. go func() {
  115. _, err := server.Ping() // Ping via the server session
  116. errCh <- err
  117. }()
  118. select {
  119. case err := <-errCh:
  120. if err != nil {
  121. t.Fatalf("err: %v", err)
  122. }
  123. case <-time.After(client.config.ConnectionWriteTimeout):
  124. t.Fatalf("timeout")
  125. }
  126. }
  127. func TestCloseBeforeAck(t *testing.T) {
  128. cfg := testConf()
  129. cfg.AcceptBacklog = 8
  130. client, server := testClientServerConfig(cfg)
  131. defer client.Close()
  132. defer server.Close()
  133. for i := 0; i < 8; i++ {
  134. s, err := client.OpenStream()
  135. if err != nil {
  136. t.Fatal(err)
  137. }
  138. s.Close()
  139. }
  140. for i := 0; i < 8; i++ {
  141. s, err := server.AcceptStream()
  142. if err != nil {
  143. t.Fatal(err)
  144. }
  145. s.Close()
  146. }
  147. done := make(chan struct{})
  148. go func() {
  149. defer close(done)
  150. s, err := client.OpenStream()
  151. if err != nil {
  152. t.Fatal(err)
  153. }
  154. s.Close()
  155. }()
  156. select {
  157. case <-done:
  158. case <-time.After(time.Second * 5):
  159. t.Fatal("timed out trying to open stream")
  160. }
  161. }
  162. func TestAccept(t *testing.T) {
  163. client, server := testClientServer()
  164. defer client.Close()
  165. defer server.Close()
  166. if client.NumStreams() != 0 {
  167. t.Fatalf("bad")
  168. }
  169. if server.NumStreams() != 0 {
  170. t.Fatalf("bad")
  171. }
  172. wg := &sync.WaitGroup{}
  173. wg.Add(4)
  174. go func() {
  175. defer wg.Done()
  176. stream, err := server.AcceptStream()
  177. if err != nil {
  178. t.Fatalf("err: %v", err)
  179. }
  180. if id := stream.StreamID(); id != 1 {
  181. t.Fatalf("bad: %v", id)
  182. }
  183. if err := stream.Close(); err != nil {
  184. t.Fatalf("err: %v", err)
  185. }
  186. }()
  187. go func() {
  188. defer wg.Done()
  189. stream, err := client.AcceptStream()
  190. if err != nil {
  191. t.Fatalf("err: %v", err)
  192. }
  193. if id := stream.StreamID(); id != 2 {
  194. t.Fatalf("bad: %v", id)
  195. }
  196. if err := stream.Close(); err != nil {
  197. t.Fatalf("err: %v", err)
  198. }
  199. }()
  200. go func() {
  201. defer wg.Done()
  202. stream, err := server.OpenStream()
  203. if err != nil {
  204. t.Fatalf("err: %v", err)
  205. }
  206. if id := stream.StreamID(); id != 2 {
  207. t.Fatalf("bad: %v", id)
  208. }
  209. if err := stream.Close(); err != nil {
  210. t.Fatalf("err: %v", err)
  211. }
  212. }()
  213. go func() {
  214. defer wg.Done()
  215. stream, err := client.OpenStream()
  216. if err != nil {
  217. t.Fatalf("err: %v", err)
  218. }
  219. if id := stream.StreamID(); id != 1 {
  220. t.Fatalf("bad: %v", id)
  221. }
  222. if err := stream.Close(); err != nil {
  223. t.Fatalf("err: %v", err)
  224. }
  225. }()
  226. doneCh := make(chan struct{})
  227. go func() {
  228. wg.Wait()
  229. close(doneCh)
  230. }()
  231. select {
  232. case <-doneCh:
  233. case <-time.After(time.Second):
  234. panic("timeout")
  235. }
  236. }
  237. func TestOpenStreamTimeout(t *testing.T) {
  238. const timeout = 25 * time.Millisecond
  239. cfg := testConf()
  240. cfg.StreamOpenTimeout = timeout
  241. client, server := testClientServerConfig(cfg)
  242. defer client.Close()
  243. defer server.Close()
  244. // Open a single stream without a server to acknowledge it.
  245. s, err := client.OpenStream()
  246. if err != nil {
  247. t.Fatal(err)
  248. }
  249. // Sleep for longer than the stream open timeout.
  250. // Since no ACKs are received, the stream and session should be closed.
  251. time.Sleep(timeout * 5)
  252. if s.state != streamClosed {
  253. t.Fatalf("stream should have been closed")
  254. }
  255. if !client.IsClosed() {
  256. t.Fatalf("session should have been closed")
  257. }
  258. }
  259. func TestClose_closeTimeout(t *testing.T) {
  260. conf := testConf()
  261. conf.StreamCloseTimeout = 10 * time.Millisecond
  262. client, server := testClientServerConfig(conf)
  263. defer client.Close()
  264. defer server.Close()
  265. if client.NumStreams() != 0 {
  266. t.Fatalf("bad")
  267. }
  268. if server.NumStreams() != 0 {
  269. t.Fatalf("bad")
  270. }
  271. wg := &sync.WaitGroup{}
  272. wg.Add(2)
  273. // Open a stream on the client but only close it on the server.
  274. // We want to see if the stream ever gets cleaned up on the client.
  275. var clientStream *Stream
  276. go func() {
  277. defer wg.Done()
  278. var err error
  279. clientStream, err = client.OpenStream()
  280. if err != nil {
  281. t.Fatalf("err: %v", err)
  282. }
  283. }()
  284. go func() {
  285. defer wg.Done()
  286. stream, err := server.AcceptStream()
  287. if err != nil {
  288. t.Fatalf("err: %v", err)
  289. }
  290. if err := stream.Close(); err != nil {
  291. t.Fatalf("err: %v", err)
  292. }
  293. }()
  294. doneCh := make(chan struct{})
  295. go func() {
  296. wg.Wait()
  297. close(doneCh)
  298. }()
  299. select {
  300. case <-doneCh:
  301. case <-time.After(time.Second):
  302. panic("timeout")
  303. }
  304. // We should have zero streams after our timeout period
  305. time.Sleep(100 * time.Millisecond)
  306. if v := server.NumStreams(); v > 0 {
  307. t.Fatalf("should have zero streams: %d", v)
  308. }
  309. if v := client.NumStreams(); v > 0 {
  310. t.Fatalf("should have zero streams: %d", v)
  311. }
  312. if _, err := clientStream.Write([]byte("hello")); err == nil {
  313. t.Fatal("should error on write")
  314. } else if err.Error() != "connection reset" {
  315. t.Fatalf("expected connection reset, got %q", err)
  316. }
  317. }
  318. func TestNonNilInterface(t *testing.T) {
  319. _, server := testClientServer()
  320. server.Close()
  321. conn, err := server.Accept()
  322. if err != nil && conn != nil {
  323. t.Error("bad: accept should return a connection of nil value")
  324. }
  325. conn, err = server.Open()
  326. if err != nil && conn != nil {
  327. t.Error("bad: open should return a connection of nil value")
  328. }
  329. }
  330. func TestSendData_Small(t *testing.T) {
  331. client, server := testClientServer()
  332. defer client.Close()
  333. defer server.Close()
  334. wg := &sync.WaitGroup{}
  335. wg.Add(2)
  336. go func() {
  337. defer wg.Done()
  338. stream, err := server.AcceptStream()
  339. if err != nil {
  340. t.Fatalf("err: %v", err)
  341. }
  342. if server.NumStreams() != 1 {
  343. t.Fatalf("bad")
  344. }
  345. buf := make([]byte, 4)
  346. for i := 0; i < 1000; i++ {
  347. n, err := stream.Read(buf)
  348. if err != nil {
  349. t.Fatalf("err: %v", err)
  350. }
  351. if n != 4 {
  352. t.Fatalf("short read: %d", n)
  353. }
  354. if string(buf) != "test" {
  355. t.Fatalf("bad: %s", buf)
  356. }
  357. }
  358. if err := stream.Close(); err != nil {
  359. t.Fatalf("err: %v", err)
  360. }
  361. }()
  362. go func() {
  363. defer wg.Done()
  364. stream, err := client.Open()
  365. if err != nil {
  366. t.Fatalf("err: %v", err)
  367. }
  368. if client.NumStreams() != 1 {
  369. t.Fatalf("bad")
  370. }
  371. for i := 0; i < 1000; i++ {
  372. n, err := stream.Write([]byte("test"))
  373. if err != nil {
  374. t.Fatalf("err: %v", err)
  375. }
  376. if n != 4 {
  377. t.Fatalf("short write %d", n)
  378. }
  379. }
  380. if err := stream.Close(); err != nil {
  381. t.Fatalf("err: %v", err)
  382. }
  383. }()
  384. doneCh := make(chan struct{})
  385. go func() {
  386. wg.Wait()
  387. close(doneCh)
  388. }()
  389. select {
  390. case <-doneCh:
  391. if client.NumStreams() != 0 {
  392. t.Fatalf("bad")
  393. }
  394. if server.NumStreams() != 0 {
  395. t.Fatalf("bad")
  396. }
  397. return
  398. case <-time.After(time.Second):
  399. panic("timeout")
  400. }
  401. }
  402. func TestSendData_Large(t *testing.T) {
  403. client, server := testClientServer()
  404. defer client.Close()
  405. defer server.Close()
  406. const (
  407. sendSize = 250 * 1024 * 1024
  408. recvSize = 4 * 1024
  409. )
  410. data := make([]byte, sendSize)
  411. for idx := range data {
  412. data[idx] = byte(idx % 256)
  413. }
  414. wg := &sync.WaitGroup{}
  415. wg.Add(2)
  416. go func() {
  417. defer wg.Done()
  418. stream, err := server.AcceptStream()
  419. if err != nil {
  420. t.Fatalf("err: %v", err)
  421. }
  422. var sz int
  423. buf := make([]byte, recvSize)
  424. for i := 0; i < sendSize/recvSize; i++ {
  425. n, err := stream.Read(buf)
  426. if err != nil {
  427. t.Fatalf("err: %v", err)
  428. }
  429. if n != recvSize {
  430. t.Fatalf("short read: %d", n)
  431. }
  432. sz += n
  433. for idx := range buf {
  434. if buf[idx] != byte(idx%256) {
  435. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  436. }
  437. }
  438. }
  439. if err := stream.Close(); err != nil {
  440. t.Fatalf("err: %v", err)
  441. }
  442. t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
  443. }()
  444. go func() {
  445. defer wg.Done()
  446. stream, err := client.Open()
  447. if err != nil {
  448. t.Fatalf("err: %v", err)
  449. }
  450. n, err := stream.Write(data)
  451. if err != nil {
  452. t.Fatalf("err: %v", err)
  453. }
  454. if n != len(data) {
  455. t.Fatalf("short write %d", n)
  456. }
  457. if err := stream.Close(); err != nil {
  458. t.Fatalf("err: %v", err)
  459. }
  460. }()
  461. doneCh := make(chan struct{})
  462. go func() {
  463. wg.Wait()
  464. close(doneCh)
  465. }()
  466. select {
  467. case <-doneCh:
  468. return
  469. case <-time.After(5 * time.Second):
  470. panic("timeout")
  471. }
  472. }
  473. func TestGoAway(t *testing.T) {
  474. client, server := testClientServer()
  475. defer client.Close()
  476. defer server.Close()
  477. if err := server.GoAway(); err != nil {
  478. t.Fatalf("err: %v", err)
  479. }
  480. _, err := client.Open()
  481. if err != ErrRemoteGoAway {
  482. t.Fatalf("err: %v", err)
  483. }
  484. }
  485. func TestManyStreams(t *testing.T) {
  486. client, server := testClientServer()
  487. defer client.Close()
  488. defer server.Close()
  489. wg := &sync.WaitGroup{}
  490. acceptor := func(i int) {
  491. defer wg.Done()
  492. stream, err := server.AcceptStream()
  493. if err != nil {
  494. t.Fatalf("err: %v", err)
  495. }
  496. defer stream.Close()
  497. buf := make([]byte, 512)
  498. for {
  499. n, err := stream.Read(buf)
  500. if err == io.EOF {
  501. return
  502. }
  503. if err != nil {
  504. t.Fatalf("err: %v", err)
  505. }
  506. if n == 0 {
  507. t.Fatalf("err: %v", err)
  508. }
  509. }
  510. }
  511. sender := func(i int) {
  512. defer wg.Done()
  513. stream, err := client.Open()
  514. if err != nil {
  515. t.Fatalf("err: %v", err)
  516. }
  517. defer stream.Close()
  518. msg := fmt.Sprintf("%08d", i)
  519. for i := 0; i < 1000; i++ {
  520. n, err := stream.Write([]byte(msg))
  521. if err != nil {
  522. t.Fatalf("err: %v", err)
  523. }
  524. if n != len(msg) {
  525. t.Fatalf("short write %d", n)
  526. }
  527. }
  528. }
  529. for i := 0; i < 50; i++ {
  530. wg.Add(2)
  531. go acceptor(i)
  532. go sender(i)
  533. }
  534. wg.Wait()
  535. }
  536. func TestManyStreams_PingPong(t *testing.T) {
  537. client, server := testClientServer()
  538. defer client.Close()
  539. defer server.Close()
  540. wg := &sync.WaitGroup{}
  541. ping := []byte("ping")
  542. pong := []byte("pong")
  543. acceptor := func(i int) {
  544. defer wg.Done()
  545. stream, err := server.AcceptStream()
  546. if err != nil {
  547. t.Fatalf("err: %v", err)
  548. }
  549. defer stream.Close()
  550. buf := make([]byte, 4)
  551. for {
  552. // Read the 'ping'
  553. n, err := stream.Read(buf)
  554. if err == io.EOF {
  555. return
  556. }
  557. if err != nil {
  558. t.Fatalf("err: %v", err)
  559. }
  560. if n != 4 {
  561. t.Fatalf("err: %v", err)
  562. }
  563. if !bytes.Equal(buf, ping) {
  564. t.Fatalf("bad: %s", buf)
  565. }
  566. // Shrink the internal buffer!
  567. stream.Shrink()
  568. // Write out the 'pong'
  569. n, err = stream.Write(pong)
  570. if err != nil {
  571. t.Fatalf("err: %v", err)
  572. }
  573. if n != 4 {
  574. t.Fatalf("err: %v", err)
  575. }
  576. }
  577. }
  578. sender := func(i int) {
  579. defer wg.Done()
  580. stream, err := client.OpenStream()
  581. if err != nil {
  582. t.Fatalf("err: %v", err)
  583. }
  584. defer stream.Close()
  585. buf := make([]byte, 4)
  586. for i := 0; i < 1000; i++ {
  587. // Send the 'ping'
  588. n, err := stream.Write(ping)
  589. if err != nil {
  590. t.Fatalf("err: %v", err)
  591. }
  592. if n != 4 {
  593. t.Fatalf("short write %d", n)
  594. }
  595. // Read the 'pong'
  596. n, err = stream.Read(buf)
  597. if err != nil {
  598. t.Fatalf("err: %v", err)
  599. }
  600. if n != 4 {
  601. t.Fatalf("err: %v", err)
  602. }
  603. if !bytes.Equal(buf, pong) {
  604. t.Fatalf("bad: %s", buf)
  605. }
  606. // Shrink the buffer
  607. stream.Shrink()
  608. }
  609. }
  610. for i := 0; i < 50; i++ {
  611. wg.Add(2)
  612. go acceptor(i)
  613. go sender(i)
  614. }
  615. wg.Wait()
  616. }
  617. func TestHalfClose(t *testing.T) {
  618. client, server := testClientServer()
  619. defer client.Close()
  620. defer server.Close()
  621. stream, err := client.Open()
  622. if err != nil {
  623. t.Fatalf("err: %v", err)
  624. }
  625. if _, err = stream.Write([]byte("a")); err != nil {
  626. t.Fatalf("err: %v", err)
  627. }
  628. stream2, err := server.Accept()
  629. if err != nil {
  630. t.Fatalf("err: %v", err)
  631. }
  632. stream2.Close() // Half close
  633. buf := make([]byte, 4)
  634. n, err := stream2.Read(buf)
  635. if err != nil {
  636. t.Fatalf("err: %v", err)
  637. }
  638. if n != 1 {
  639. t.Fatalf("bad: %v", n)
  640. }
  641. // Send more
  642. if _, err = stream.Write([]byte("bcd")); err != nil {
  643. t.Fatalf("err: %v", err)
  644. }
  645. stream.Close()
  646. // Read after close
  647. n, err = stream2.Read(buf)
  648. if err != nil {
  649. t.Fatalf("err: %v", err)
  650. }
  651. if n != 3 {
  652. t.Fatalf("bad: %v", n)
  653. }
  654. // EOF after close
  655. n, err = stream2.Read(buf)
  656. if err != io.EOF {
  657. t.Fatalf("err: %v", err)
  658. }
  659. if n != 0 {
  660. t.Fatalf("bad: %v", n)
  661. }
  662. }
  663. func TestHalfCloseSessionShutdown(t *testing.T) {
  664. client, server := testClientServer()
  665. defer client.Close()
  666. defer server.Close()
  667. // dataSize must be large enough to ensure the server will send a window
  668. // update
  669. dataSize := int64(server.config.MaxStreamWindowSize)
  670. data := make([]byte, dataSize)
  671. for idx := range data {
  672. data[idx] = byte(idx % 256)
  673. }
  674. stream, err := client.Open()
  675. if err != nil {
  676. t.Fatalf("err: %v", err)
  677. }
  678. if _, err = stream.Write(data); err != nil {
  679. t.Fatalf("err: %v", err)
  680. }
  681. stream2, err := server.Accept()
  682. if err != nil {
  683. t.Fatalf("err: %v", err)
  684. }
  685. if err := stream.Close(); err != nil {
  686. t.Fatalf("err: %v", err)
  687. }
  688. // Shut down the session of the sending side. This should not cause reads
  689. // to fail on the receiving side.
  690. if err := client.Close(); err != nil {
  691. t.Fatalf("err: %v", err)
  692. }
  693. buf := make([]byte, dataSize)
  694. n, err := stream2.Read(buf)
  695. if err != nil {
  696. t.Fatalf("err: %v", err)
  697. }
  698. if int64(n) != dataSize {
  699. t.Fatalf("bad: %v", n)
  700. }
  701. // EOF after close
  702. n, err = stream2.Read(buf)
  703. if err != io.EOF {
  704. t.Fatalf("err: %v", err)
  705. }
  706. if n != 0 {
  707. t.Fatalf("bad: %v", n)
  708. }
  709. }
  710. func TestReadDeadline(t *testing.T) {
  711. client, server := testClientServer()
  712. defer client.Close()
  713. defer server.Close()
  714. stream, err := client.Open()
  715. if err != nil {
  716. t.Fatalf("err: %v", err)
  717. }
  718. defer stream.Close()
  719. stream2, err := server.Accept()
  720. if err != nil {
  721. t.Fatalf("err: %v", err)
  722. }
  723. defer stream2.Close()
  724. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  725. t.Fatalf("err: %v", err)
  726. }
  727. buf := make([]byte, 4)
  728. _, err = stream.Read(buf)
  729. if err != ErrTimeout {
  730. t.Fatalf("err: %v", err)
  731. }
  732. // See https://github.com/hashicorp/yamux/issues/90
  733. // The standard library's http server package will read from connections in
  734. // the background to detect if they are alive.
  735. //
  736. // It sets a read deadline on connections and detect if the returned error
  737. // is a network timeout error which implements net.Error.
  738. //
  739. // The HTTP server will cancel all server requests if it isn't timeout error
  740. // from the connection.
  741. //
  742. // We assert that we return an error meeting the interface to avoid
  743. // accidently breaking yamux session compatability with the standard
  744. // library's http server implementation.
  745. if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
  746. t.Fatalf("reading timeout error is expected to implement net.Error and return true when calling Timeout()")
  747. }
  748. }
  749. func TestReadDeadline_BlockedRead(t *testing.T) {
  750. client, server := testClientServer()
  751. defer client.Close()
  752. defer server.Close()
  753. stream, err := client.Open()
  754. if err != nil {
  755. t.Fatalf("err: %v", err)
  756. }
  757. defer stream.Close()
  758. stream2, err := server.Accept()
  759. if err != nil {
  760. t.Fatalf("err: %v", err)
  761. }
  762. defer stream2.Close()
  763. // Start a read that will block
  764. errCh := make(chan error, 1)
  765. go func() {
  766. buf := make([]byte, 4)
  767. _, err := stream.Read(buf)
  768. errCh <- err
  769. close(errCh)
  770. }()
  771. // Wait to ensure the read has started.
  772. time.Sleep(5 * time.Millisecond)
  773. // Update the read deadline
  774. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  775. t.Fatalf("err: %v", err)
  776. }
  777. select {
  778. case <-time.After(100 * time.Millisecond):
  779. t.Fatal("expected read timeout")
  780. case err := <-errCh:
  781. if err != ErrTimeout {
  782. t.Fatalf("expected ErrTimeout; got %v", err)
  783. }
  784. }
  785. }
  786. func TestWriteDeadline(t *testing.T) {
  787. client, server := testClientServer()
  788. defer client.Close()
  789. defer server.Close()
  790. stream, err := client.Open()
  791. if err != nil {
  792. t.Fatalf("err: %v", err)
  793. }
  794. defer stream.Close()
  795. stream2, err := server.Accept()
  796. if err != nil {
  797. t.Fatalf("err: %v", err)
  798. }
  799. defer stream2.Close()
  800. if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
  801. t.Fatalf("err: %v", err)
  802. }
  803. buf := make([]byte, 512)
  804. for i := 0; i < int(initialStreamWindow); i++ {
  805. _, err := stream.Write(buf)
  806. if err != nil && err == ErrTimeout {
  807. return
  808. } else if err != nil {
  809. t.Fatalf("err: %v", err)
  810. }
  811. }
  812. t.Fatalf("Expected timeout")
  813. }
  814. func TestWriteDeadline_BlockedWrite(t *testing.T) {
  815. client, server := testClientServer()
  816. defer client.Close()
  817. defer server.Close()
  818. stream, err := client.Open()
  819. if err != nil {
  820. t.Fatalf("err: %v", err)
  821. }
  822. defer stream.Close()
  823. stream2, err := server.Accept()
  824. if err != nil {
  825. t.Fatalf("err: %v", err)
  826. }
  827. defer stream2.Close()
  828. // Start a goroutine making writes that will block
  829. errCh := make(chan error, 1)
  830. go func() {
  831. buf := make([]byte, 512)
  832. for i := 0; i < int(initialStreamWindow); i++ {
  833. _, err := stream.Write(buf)
  834. if err == nil {
  835. continue
  836. }
  837. errCh <- err
  838. close(errCh)
  839. return
  840. }
  841. close(errCh)
  842. }()
  843. // Wait to ensure the write has started.
  844. time.Sleep(5 * time.Millisecond)
  845. // Update the write deadline
  846. if err := stream.SetWriteDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  847. t.Fatalf("err: %v", err)
  848. }
  849. select {
  850. case <-time.After(1 * time.Second):
  851. t.Fatal("expected write timeout")
  852. case err := <-errCh:
  853. if err != ErrTimeout {
  854. t.Fatalf("expected ErrTimeout; got %v", err)
  855. }
  856. }
  857. }
  858. func TestBacklogExceeded(t *testing.T) {
  859. client, server := testClientServer()
  860. defer client.Close()
  861. defer server.Close()
  862. // Fill the backlog
  863. max := client.config.AcceptBacklog
  864. for i := 0; i < max; i++ {
  865. stream, err := client.Open()
  866. if err != nil {
  867. t.Fatalf("err: %v", err)
  868. }
  869. defer stream.Close()
  870. if _, err := stream.Write([]byte("foo")); err != nil {
  871. t.Fatalf("err: %v", err)
  872. }
  873. }
  874. // Attempt to open a new stream
  875. errCh := make(chan error, 1)
  876. go func() {
  877. _, err := client.Open()
  878. errCh <- err
  879. }()
  880. // Shutdown the server
  881. go func() {
  882. time.Sleep(10 * time.Millisecond)
  883. server.Close()
  884. }()
  885. select {
  886. case err := <-errCh:
  887. if err == nil {
  888. t.Fatalf("open should fail")
  889. }
  890. case <-time.After(time.Second):
  891. t.Fatalf("timeout")
  892. }
  893. }
  894. func TestKeepAlive(t *testing.T) {
  895. client, server := testClientServer()
  896. defer client.Close()
  897. defer server.Close()
  898. time.Sleep(200 * time.Millisecond)
  899. // Ping value should increase
  900. client.pingLock.Lock()
  901. defer client.pingLock.Unlock()
  902. if client.pingID == 0 {
  903. t.Fatalf("should ping")
  904. }
  905. server.pingLock.Lock()
  906. defer server.pingLock.Unlock()
  907. if server.pingID == 0 {
  908. t.Fatalf("should ping")
  909. }
  910. }
  911. func TestKeepAlive_Timeout(t *testing.T) {
  912. conn1, conn2 := testConn()
  913. clientConf := testConf()
  914. clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
  915. clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
  916. client, _ := Client(conn1, clientConf)
  917. defer client.Close()
  918. server, _ := Server(conn2, testConf())
  919. defer server.Close()
  920. _ = captureLogs(client) // Client logs aren't part of the test
  921. errCh := make(chan error, 1)
  922. go func() {
  923. _, err := server.Accept() // Wait until server closes
  924. errCh <- err
  925. }()
  926. // Prevent the client from responding
  927. clientConn := client.conn.(*pipeConn)
  928. clientConn.writeBlocker.Lock()
  929. select {
  930. case err := <-errCh:
  931. if err != ErrKeepAliveTimeout {
  932. t.Fatalf("unexpected error: %v", err)
  933. }
  934. case <-time.After(1 * time.Second):
  935. t.Fatalf("timeout waiting for timeout")
  936. }
  937. clientConn.writeBlocker.Unlock()
  938. if !server.IsClosed() {
  939. t.Fatalf("server should have closed")
  940. }
  941. }
  942. func TestLargeWindow(t *testing.T) {
  943. conf := DefaultConfig()
  944. conf.MaxStreamWindowSize *= 2
  945. client, server := testClientServerConfig(conf)
  946. defer client.Close()
  947. defer server.Close()
  948. stream, err := client.Open()
  949. if err != nil {
  950. t.Fatalf("err: %v", err)
  951. }
  952. defer stream.Close()
  953. stream2, err := server.Accept()
  954. if err != nil {
  955. t.Fatalf("err: %v", err)
  956. }
  957. defer stream2.Close()
  958. stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
  959. buf := make([]byte, conf.MaxStreamWindowSize)
  960. n, err := stream.Write(buf)
  961. if err != nil {
  962. t.Fatalf("err: %v", err)
  963. }
  964. if n != len(buf) {
  965. t.Fatalf("short write: %d", n)
  966. }
  967. }
  968. type UnlimitedReader struct{}
  969. func (u *UnlimitedReader) Read(p []byte) (int, error) {
  970. runtime.Gosched()
  971. return len(p), nil
  972. }
  973. func TestSendData_VeryLarge(t *testing.T) {
  974. client, server := testClientServer()
  975. defer client.Close()
  976. defer server.Close()
  977. var n int64 = 1 * 1024 * 1024 * 1024
  978. var workers int = 16
  979. wg := &sync.WaitGroup{}
  980. wg.Add(workers * 2)
  981. for i := 0; i < workers; i++ {
  982. go func() {
  983. defer wg.Done()
  984. stream, err := server.AcceptStream()
  985. if err != nil {
  986. t.Fatalf("err: %v", err)
  987. }
  988. defer stream.Close()
  989. buf := make([]byte, 4)
  990. _, err = stream.Read(buf)
  991. if err != nil {
  992. t.Fatalf("err: %v", err)
  993. }
  994. if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
  995. t.Fatalf("bad header")
  996. }
  997. recv, err := io.Copy(ioutil.Discard, stream)
  998. if err != nil {
  999. t.Fatalf("err: %v", err)
  1000. }
  1001. if recv != n {
  1002. t.Fatalf("bad: %v", recv)
  1003. }
  1004. }()
  1005. }
  1006. for i := 0; i < workers; i++ {
  1007. go func() {
  1008. defer wg.Done()
  1009. stream, err := client.Open()
  1010. if err != nil {
  1011. t.Fatalf("err: %v", err)
  1012. }
  1013. defer stream.Close()
  1014. _, err = stream.Write([]byte{0, 1, 2, 3})
  1015. if err != nil {
  1016. t.Fatalf("err: %v", err)
  1017. }
  1018. unlimited := &UnlimitedReader{}
  1019. sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
  1020. if err != nil {
  1021. t.Fatalf("err: %v", err)
  1022. }
  1023. if sent != n {
  1024. t.Fatalf("bad: %v", sent)
  1025. }
  1026. }()
  1027. }
  1028. doneCh := make(chan struct{})
  1029. go func() {
  1030. wg.Wait()
  1031. close(doneCh)
  1032. }()
  1033. select {
  1034. case <-doneCh:
  1035. case <-time.After(20 * time.Second):
  1036. panic("timeout")
  1037. }
  1038. }
  1039. func TestBacklogExceeded_Accept(t *testing.T) {
  1040. client, server := testClientServer()
  1041. defer client.Close()
  1042. defer server.Close()
  1043. max := 5 * client.config.AcceptBacklog
  1044. go func() {
  1045. for i := 0; i < max; i++ {
  1046. stream, err := server.Accept()
  1047. if err != nil {
  1048. t.Fatalf("err: %v", err)
  1049. }
  1050. defer stream.Close()
  1051. }
  1052. }()
  1053. // Fill the backlog
  1054. for i := 0; i < max; i++ {
  1055. stream, err := client.Open()
  1056. if err != nil {
  1057. t.Fatalf("err: %v", err)
  1058. }
  1059. defer stream.Close()
  1060. if _, err := stream.Write([]byte("foo")); err != nil {
  1061. t.Fatalf("err: %v", err)
  1062. }
  1063. }
  1064. }
  1065. func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
  1066. client, server := testClientServerConfig(testConfNoKeepAlive())
  1067. defer client.Close()
  1068. defer server.Close()
  1069. var wg sync.WaitGroup
  1070. wg.Add(2)
  1071. // Choose a huge flood size that we know will result in a window update.
  1072. flood := int64(client.config.MaxStreamWindowSize) - 1
  1073. // The server will accept a new stream and then flood data to it.
  1074. go func() {
  1075. defer wg.Done()
  1076. stream, err := server.AcceptStream()
  1077. if err != nil {
  1078. t.Fatalf("err: %v", err)
  1079. }
  1080. defer stream.Close()
  1081. n, err := stream.Write(make([]byte, flood))
  1082. if err != nil {
  1083. t.Fatalf("err: %v", err)
  1084. }
  1085. if int64(n) != flood {
  1086. t.Fatalf("short write: %d", n)
  1087. }
  1088. }()
  1089. // The client will open a stream, block outbound writes, and then
  1090. // listen to the flood from the server, which should time out since
  1091. // it won't be able to send the window update.
  1092. go func() {
  1093. defer wg.Done()
  1094. stream, err := client.OpenStream()
  1095. if err != nil {
  1096. t.Fatalf("err: %v", err)
  1097. }
  1098. defer stream.Close()
  1099. conn := client.conn.(*pipeConn)
  1100. conn.writeBlocker.Lock()
  1101. defer conn.writeBlocker.Unlock()
  1102. _, err = stream.Read(make([]byte, flood))
  1103. if err != ErrConnectionWriteTimeout {
  1104. t.Fatalf("err: %v", err)
  1105. }
  1106. }()
  1107. wg.Wait()
  1108. }
  1109. func TestSession_PartialReadWindowUpdate(t *testing.T) {
  1110. client, server := testClientServerConfig(testConfNoKeepAlive())
  1111. defer client.Close()
  1112. defer server.Close()
  1113. var wg sync.WaitGroup
  1114. wg.Add(1)
  1115. // Choose a huge flood size that we know will result in a window update.
  1116. flood := int64(client.config.MaxStreamWindowSize)
  1117. var wr *Stream
  1118. // The server will accept a new stream and then flood data to it.
  1119. go func() {
  1120. defer wg.Done()
  1121. var err error
  1122. wr, err = server.AcceptStream()
  1123. if err != nil {
  1124. t.Fatalf("err: %v", err)
  1125. }
  1126. defer wr.Close()
  1127. if wr.sendWindow != client.config.MaxStreamWindowSize {
  1128. t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
  1129. }
  1130. n, err := wr.Write(make([]byte, flood))
  1131. if err != nil {
  1132. t.Fatalf("err: %v", err)
  1133. }
  1134. if int64(n) != flood {
  1135. t.Fatalf("short write: %d", n)
  1136. }
  1137. if wr.sendWindow != 0 {
  1138. t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
  1139. }
  1140. }()
  1141. stream, err := client.OpenStream()
  1142. if err != nil {
  1143. t.Fatalf("err: %v", err)
  1144. }
  1145. defer stream.Close()
  1146. wg.Wait()
  1147. _, err = stream.Read(make([]byte, flood/2+1))
  1148. if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
  1149. t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
  1150. }
  1151. }
  1152. func TestSession_sendNoWait_Timeout(t *testing.T) {
  1153. client, server := testClientServerConfig(testConfNoKeepAlive())
  1154. defer client.Close()
  1155. defer server.Close()
  1156. var wg sync.WaitGroup
  1157. wg.Add(2)
  1158. go func() {
  1159. defer wg.Done()
  1160. stream, err := server.AcceptStream()
  1161. if err != nil {
  1162. t.Fatalf("err: %v", err)
  1163. }
  1164. defer stream.Close()
  1165. }()
  1166. // The client will open the stream and then block outbound writes, we'll
  1167. // probe sendNoWait once it gets into that state.
  1168. go func() {
  1169. defer wg.Done()
  1170. stream, err := client.OpenStream()
  1171. if err != nil {
  1172. t.Fatalf("err: %v", err)
  1173. }
  1174. defer stream.Close()
  1175. conn := client.conn.(*pipeConn)
  1176. conn.writeBlocker.Lock()
  1177. defer conn.writeBlocker.Unlock()
  1178. hdr := header(make([]byte, headerSize))
  1179. hdr.encode(typePing, flagACK, 0, 0)
  1180. for {
  1181. err = client.sendNoWait(hdr)
  1182. if err == nil {
  1183. continue
  1184. } else if err == ErrConnectionWriteTimeout {
  1185. break
  1186. } else {
  1187. t.Fatalf("err: %v", err)
  1188. }
  1189. }
  1190. }()
  1191. wg.Wait()
  1192. }
  1193. func TestSession_PingOfDeath(t *testing.T) {
  1194. client, server := testClientServerConfig(testConfNoKeepAlive())
  1195. defer client.Close()
  1196. defer server.Close()
  1197. var wg sync.WaitGroup
  1198. wg.Add(2)
  1199. var doPingOfDeath sync.Mutex
  1200. doPingOfDeath.Lock()
  1201. // This is used later to block outbound writes.
  1202. conn := server.conn.(*pipeConn)
  1203. // The server will accept a stream, block outbound writes, and then
  1204. // flood its send channel so that no more headers can be queued.
  1205. go func() {
  1206. defer wg.Done()
  1207. stream, err := server.AcceptStream()
  1208. if err != nil {
  1209. t.Fatalf("err: %v", err)
  1210. }
  1211. defer stream.Close()
  1212. conn.writeBlocker.Lock()
  1213. for {
  1214. hdr := header(make([]byte, headerSize))
  1215. hdr.encode(typePing, 0, 0, 0)
  1216. err = server.sendNoWait(hdr)
  1217. if err == nil {
  1218. continue
  1219. } else if err == ErrConnectionWriteTimeout {
  1220. break
  1221. } else {
  1222. t.Fatalf("err: %v", err)
  1223. }
  1224. }
  1225. doPingOfDeath.Unlock()
  1226. }()
  1227. // The client will open a stream and then send the server a ping once it
  1228. // can no longer write. This makes sure the server doesn't deadlock reads
  1229. // while trying to reply to the ping with no ability to write.
  1230. go func() {
  1231. defer wg.Done()
  1232. stream, err := client.OpenStream()
  1233. if err != nil {
  1234. t.Fatalf("err: %v", err)
  1235. }
  1236. defer stream.Close()
  1237. // This ping will never unblock because the ping id will never
  1238. // show up in a response.
  1239. doPingOfDeath.Lock()
  1240. go func() { client.Ping() }()
  1241. // Wait for a while to make sure the previous ping times out,
  1242. // then turn writes back on and make sure a ping works again.
  1243. time.Sleep(2 * server.config.ConnectionWriteTimeout)
  1244. conn.writeBlocker.Unlock()
  1245. if _, err = client.Ping(); err != nil {
  1246. t.Fatalf("err: %v", err)
  1247. }
  1248. }()
  1249. wg.Wait()
  1250. }
  1251. func TestSession_ConnectionWriteTimeout(t *testing.T) {
  1252. client, server := testClientServerConfig(testConfNoKeepAlive())
  1253. defer client.Close()
  1254. defer server.Close()
  1255. var wg sync.WaitGroup
  1256. wg.Add(2)
  1257. go func() {
  1258. defer wg.Done()
  1259. stream, err := server.AcceptStream()
  1260. if err != nil {
  1261. t.Fatalf("err: %v", err)
  1262. }
  1263. defer stream.Close()
  1264. }()
  1265. // The client will open the stream and then block outbound writes, we'll
  1266. // tee up a write and make sure it eventually times out.
  1267. go func() {
  1268. defer wg.Done()
  1269. stream, err := client.OpenStream()
  1270. if err != nil {
  1271. t.Fatalf("err: %v", err)
  1272. }
  1273. defer stream.Close()
  1274. conn := client.conn.(*pipeConn)
  1275. conn.writeBlocker.Lock()
  1276. defer conn.writeBlocker.Unlock()
  1277. // Since the write goroutine is blocked then this will return a
  1278. // timeout since it can't get feedback about whether the write
  1279. // worked.
  1280. n, err := stream.Write([]byte("hello"))
  1281. if err != ErrConnectionWriteTimeout {
  1282. t.Fatalf("err: %v", err)
  1283. }
  1284. if n != 0 {
  1285. t.Fatalf("lied about writes: %d", n)
  1286. }
  1287. }()
  1288. wg.Wait()
  1289. }
  1290. func TestCancelAccept(t *testing.T) {
  1291. _, server := testClientServer()
  1292. defer server.Close()
  1293. ctx, cancel := context.WithCancel(context.Background())
  1294. var wg sync.WaitGroup
  1295. wg.Add(1)
  1296. go func() {
  1297. defer wg.Done()
  1298. stream, err := server.AcceptStreamWithContext(ctx)
  1299. if err != context.Canceled {
  1300. t.Fatalf("err: %v", err)
  1301. }
  1302. if stream != nil {
  1303. defer stream.Close()
  1304. }
  1305. }()
  1306. cancel()
  1307. wg.Wait()
  1308. }