session_test.go 28 KB


  1. package yamux
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "log"
  8. "net"
  9. "reflect"
  10. "runtime"
  11. "strings"
  12. "sync"
  13. "testing"
  14. "time"
  15. )
  // logCapture records everything a *log.Logger writes into it (via the
  // embedded bytes.Buffer) so tests can assert on the output.
  type logCapture struct{ bytes.Buffer }

  // logs returns the captured text, trimmed of surrounding whitespace and
  // split on newlines. An empty capture yields []string{""}.
  func (l *logCapture) logs() []string {
  	text := strings.TrimSpace(l.String())
  	return strings.Split(text, "\n")
  }

  // match reports whether the captured lines equal expect exactly, in order.
  func (l *logCapture) match(expect []string) bool {
  	got := l.logs()
  	return reflect.DeepEqual(got, expect)
  }
  23. func captureLogs(s *Session) *logCapture {
  24. buf := new(logCapture)
  25. s.logger = log.New(buf, "", 0)
  26. return buf
  27. }
  // pipeConn is one end of an in-memory duplex connection built from two
  // io.Pipes. Every Write first acquires writeBlocker, so a test that holds
  // that mutex stalls all outbound writes on this end.
  type pipeConn struct {
  	reader       *io.PipeReader
  	writer       *io.PipeWriter
  	writeBlocker sync.Mutex
  }

  // Read reads bytes written by the opposite end.
  func (p *pipeConn) Read(b []byte) (int, error) {
  	n, err := p.reader.Read(b)
  	return n, err
  }

  // Write waits for writeBlocker, then sends b to the opposite end.
  func (p *pipeConn) Write(b []byte) (n int, err error) {
  	p.writeBlocker.Lock()
  	defer p.writeBlocker.Unlock()
  	n, err = p.writer.Write(b)
  	return n, err
  }

  // Close closes both halves of this end. io.PipeReader.Close always
  // returns nil, so only the writer's result is reported.
  func (p *pipeConn) Close() error {
  	_ = p.reader.Close()
  	return p.writer.Close()
  }

  // testConn returns two cross-wired pipeConns: whatever one end writes,
  // the other end reads.
  func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
  	toFirstR, toFirstW := io.Pipe()   // second -> first
  	toSecondR, toSecondW := io.Pipe() // first -> second
  	first := &pipeConn{reader: toFirstR, writer: toSecondW}
  	second := &pipeConn{reader: toSecondR, writer: toFirstW}
  	return first, second
  }
  52. func testConf() *Config {
  53. conf := DefaultConfig()
  54. conf.AcceptBacklog = 64
  55. conf.KeepAliveInterval = 100 * time.Millisecond
  56. conf.ConnectionWriteTimeout = 250 * time.Millisecond
  57. return conf
  58. }
  59. func testConfNoKeepAlive() *Config {
  60. conf := testConf()
  61. conf.EnableKeepAlive = false
  62. return conf
  63. }
  64. func testClientServer() (*Session, *Session) {
  65. return testClientServerConfig(testConf())
  66. }
  67. func testClientServerConfig(conf *Config) (*Session, *Session) {
  68. conn1, conn2 := testConn()
  69. client, _ := Client(conn1, conf)
  70. server, _ := Server(conn2, conf)
  71. return client, server
  72. }
  73. func TestPing(t *testing.T) {
  74. client, server := testClientServer()
  75. defer client.Close()
  76. defer server.Close()
  77. rtt, err := client.Ping()
  78. if err != nil {
  79. t.Fatalf("err: %v", err)
  80. }
  81. if rtt == 0 {
  82. t.Fatalf("bad: %v", rtt)
  83. }
  84. rtt, err = server.Ping()
  85. if err != nil {
  86. t.Fatalf("err: %v", err)
  87. }
  88. if rtt == 0 {
  89. t.Fatalf("bad: %v", rtt)
  90. }
  91. }
  92. func TestPing_Timeout(t *testing.T) {
  93. client, server := testClientServerConfig(testConfNoKeepAlive())
  94. defer client.Close()
  95. defer server.Close()
  96. // Prevent the client from responding
  97. clientConn := client.conn.(*pipeConn)
  98. clientConn.writeBlocker.Lock()
  99. errCh := make(chan error, 1)
  100. go func() {
  101. _, err := server.Ping() // Ping via the server session
  102. errCh <- err
  103. }()
  104. select {
  105. case err := <-errCh:
  106. if err != ErrTimeout {
  107. t.Fatalf("err: %v", err)
  108. }
  109. case <-time.After(client.config.ConnectionWriteTimeout * 2):
  110. t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
  111. }
  112. // Verify that we recover, even if we gave up
  113. clientConn.writeBlocker.Unlock()
  114. go func() {
  115. _, err := server.Ping() // Ping via the server session
  116. errCh <- err
  117. }()
  118. select {
  119. case err := <-errCh:
  120. if err != nil {
  121. t.Fatalf("err: %v", err)
  122. }
  123. case <-time.After(client.config.ConnectionWriteTimeout):
  124. t.Fatalf("timeout")
  125. }
  126. }
  127. func TestCloseBeforeAck(t *testing.T) {
  128. cfg := testConf()
  129. cfg.AcceptBacklog = 8
  130. client, server := testClientServerConfig(cfg)
  131. defer client.Close()
  132. defer server.Close()
  133. for i := 0; i < 8; i++ {
  134. s, err := client.OpenStream()
  135. if err != nil {
  136. t.Fatal(err)
  137. }
  138. s.Close()
  139. }
  140. for i := 0; i < 8; i++ {
  141. s, err := server.AcceptStream()
  142. if err != nil {
  143. t.Fatal(err)
  144. }
  145. s.Close()
  146. }
  147. done := make(chan struct{})
  148. go func() {
  149. defer close(done)
  150. s, err := client.OpenStream()
  151. if err != nil {
  152. t.Fatal(err)
  153. }
  154. s.Close()
  155. }()
  156. select {
  157. case <-done:
  158. case <-time.After(time.Second * 5):
  159. t.Fatal("timed out trying to open stream")
  160. }
  161. }
  162. func TestAccept(t *testing.T) {
  163. client, server := testClientServer()
  164. defer client.Close()
  165. defer server.Close()
  166. if client.NumStreams() != 0 {
  167. t.Fatalf("bad")
  168. }
  169. if server.NumStreams() != 0 {
  170. t.Fatalf("bad")
  171. }
  172. wg := &sync.WaitGroup{}
  173. wg.Add(4)
  174. go func() {
  175. defer wg.Done()
  176. stream, err := server.AcceptStream()
  177. if err != nil {
  178. t.Fatalf("err: %v", err)
  179. }
  180. if id := stream.StreamID(); id != 1 {
  181. t.Fatalf("bad: %v", id)
  182. }
  183. if err := stream.Close(); err != nil {
  184. t.Fatalf("err: %v", err)
  185. }
  186. }()
  187. go func() {
  188. defer wg.Done()
  189. stream, err := client.AcceptStream()
  190. if err != nil {
  191. t.Fatalf("err: %v", err)
  192. }
  193. if id := stream.StreamID(); id != 2 {
  194. t.Fatalf("bad: %v", id)
  195. }
  196. if err := stream.Close(); err != nil {
  197. t.Fatalf("err: %v", err)
  198. }
  199. }()
  200. go func() {
  201. defer wg.Done()
  202. stream, err := server.OpenStream()
  203. if err != nil {
  204. t.Fatalf("err: %v", err)
  205. }
  206. if id := stream.StreamID(); id != 2 {
  207. t.Fatalf("bad: %v", id)
  208. }
  209. if err := stream.Close(); err != nil {
  210. t.Fatalf("err: %v", err)
  211. }
  212. }()
  213. go func() {
  214. defer wg.Done()
  215. stream, err := client.OpenStream()
  216. if err != nil {
  217. t.Fatalf("err: %v", err)
  218. }
  219. if id := stream.StreamID(); id != 1 {
  220. t.Fatalf("bad: %v", id)
  221. }
  222. if err := stream.Close(); err != nil {
  223. t.Fatalf("err: %v", err)
  224. }
  225. }()
  226. doneCh := make(chan struct{})
  227. go func() {
  228. wg.Wait()
  229. close(doneCh)
  230. }()
  231. select {
  232. case <-doneCh:
  233. case <-time.After(time.Second):
  234. panic("timeout")
  235. }
  236. }
  237. func TestClose_closeTimeout(t *testing.T) {
  238. conf := testConf()
  239. conf.StreamCloseTimeout = 10 * time.Millisecond
  240. client, server := testClientServerConfig(conf)
  241. defer client.Close()
  242. defer server.Close()
  243. if client.NumStreams() != 0 {
  244. t.Fatalf("bad")
  245. }
  246. if server.NumStreams() != 0 {
  247. t.Fatalf("bad")
  248. }
  249. wg := &sync.WaitGroup{}
  250. wg.Add(2)
  251. // Open a stream on the client but only close it on the server.
  252. // We want to see if the stream ever gets cleaned up on the client.
  253. var clientStream *Stream
  254. go func() {
  255. defer wg.Done()
  256. var err error
  257. clientStream, err = client.OpenStream()
  258. if err != nil {
  259. t.Fatalf("err: %v", err)
  260. }
  261. }()
  262. go func() {
  263. defer wg.Done()
  264. stream, err := server.AcceptStream()
  265. if err != nil {
  266. t.Fatalf("err: %v", err)
  267. }
  268. if err := stream.Close(); err != nil {
  269. t.Fatalf("err: %v", err)
  270. }
  271. }()
  272. doneCh := make(chan struct{})
  273. go func() {
  274. wg.Wait()
  275. close(doneCh)
  276. }()
  277. select {
  278. case <-doneCh:
  279. case <-time.After(time.Second):
  280. panic("timeout")
  281. }
  282. // We should have zero streams after our timeout period
  283. time.Sleep(100 * time.Millisecond)
  284. if v := server.NumStreams(); v > 0 {
  285. t.Fatalf("should have zero streams: %d", v)
  286. }
  287. if v := client.NumStreams(); v > 0 {
  288. t.Fatalf("should have zero streams: %d", v)
  289. }
  290. if _, err := clientStream.Write([]byte("hello")); err == nil {
  291. t.Fatal("should error on write")
  292. } else if err.Error() != "connection reset" {
  293. t.Fatalf("expected connection reset, got %q", err)
  294. }
  295. }
  296. func TestNonNilInterface(t *testing.T) {
  297. _, server := testClientServer()
  298. server.Close()
  299. conn, err := server.Accept()
  300. if err != nil && conn != nil {
  301. t.Error("bad: accept should return a connection of nil value")
  302. }
  303. conn, err = server.Open()
  304. if err != nil && conn != nil {
  305. t.Error("bad: open should return a connection of nil value")
  306. }
  307. }
  308. func TestSendData_Small(t *testing.T) {
  309. client, server := testClientServer()
  310. defer client.Close()
  311. defer server.Close()
  312. wg := &sync.WaitGroup{}
  313. wg.Add(2)
  314. go func() {
  315. defer wg.Done()
  316. stream, err := server.AcceptStream()
  317. if err != nil {
  318. t.Fatalf("err: %v", err)
  319. }
  320. if server.NumStreams() != 1 {
  321. t.Fatalf("bad")
  322. }
  323. buf := make([]byte, 4)
  324. for i := 0; i < 1000; i++ {
  325. n, err := stream.Read(buf)
  326. if err != nil {
  327. t.Fatalf("err: %v", err)
  328. }
  329. if n != 4 {
  330. t.Fatalf("short read: %d", n)
  331. }
  332. if string(buf) != "test" {
  333. t.Fatalf("bad: %s", buf)
  334. }
  335. }
  336. if err := stream.Close(); err != nil {
  337. t.Fatalf("err: %v", err)
  338. }
  339. }()
  340. go func() {
  341. defer wg.Done()
  342. stream, err := client.Open()
  343. if err != nil {
  344. t.Fatalf("err: %v", err)
  345. }
  346. if client.NumStreams() != 1 {
  347. t.Fatalf("bad")
  348. }
  349. for i := 0; i < 1000; i++ {
  350. n, err := stream.Write([]byte("test"))
  351. if err != nil {
  352. t.Fatalf("err: %v", err)
  353. }
  354. if n != 4 {
  355. t.Fatalf("short write %d", n)
  356. }
  357. }
  358. if err := stream.Close(); err != nil {
  359. t.Fatalf("err: %v", err)
  360. }
  361. }()
  362. doneCh := make(chan struct{})
  363. go func() {
  364. wg.Wait()
  365. close(doneCh)
  366. }()
  367. select {
  368. case <-doneCh:
  369. case <-time.After(time.Second):
  370. panic("timeout")
  371. }
  372. if client.NumStreams() != 0 {
  373. t.Fatalf("bad")
  374. }
  375. if server.NumStreams() != 0 {
  376. t.Fatalf("bad")
  377. }
  378. }
  379. func TestSendData_Large(t *testing.T) {
  380. client, server := testClientServer()
  381. defer client.Close()
  382. defer server.Close()
  383. const (
  384. sendSize = 250 * 1024 * 1024
  385. recvSize = 4 * 1024
  386. )
  387. data := make([]byte, sendSize)
  388. for idx := range data {
  389. data[idx] = byte(idx % 256)
  390. }
  391. wg := &sync.WaitGroup{}
  392. wg.Add(2)
  393. go func() {
  394. defer wg.Done()
  395. stream, err := server.AcceptStream()
  396. if err != nil {
  397. t.Fatalf("err: %v", err)
  398. }
  399. var sz int
  400. buf := make([]byte, recvSize)
  401. for i := 0; i < sendSize/recvSize; i++ {
  402. n, err := stream.Read(buf)
  403. if err != nil {
  404. t.Fatalf("err: %v", err)
  405. }
  406. if n != recvSize {
  407. t.Fatalf("short read: %d", n)
  408. }
  409. sz += n
  410. for idx := range buf {
  411. if buf[idx] != byte(idx%256) {
  412. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  413. }
  414. }
  415. }
  416. if err := stream.Close(); err != nil {
  417. t.Fatalf("err: %v", err)
  418. }
  419. t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
  420. }()
  421. go func() {
  422. defer wg.Done()
  423. stream, err := client.Open()
  424. if err != nil {
  425. t.Fatalf("err: %v", err)
  426. }
  427. n, err := stream.Write(data)
  428. if err != nil {
  429. t.Fatalf("err: %v", err)
  430. }
  431. if n != len(data) {
  432. t.Fatalf("short write %d", n)
  433. }
  434. if err := stream.Close(); err != nil {
  435. t.Fatalf("err: %v", err)
  436. }
  437. }()
  438. doneCh := make(chan struct{})
  439. go func() {
  440. wg.Wait()
  441. close(doneCh)
  442. }()
  443. select {
  444. case <-doneCh:
  445. case <-time.After(5 * time.Second):
  446. panic("timeout")
  447. }
  448. }
  449. func TestGoAway(t *testing.T) {
  450. client, server := testClientServer()
  451. defer client.Close()
  452. defer server.Close()
  453. if err := server.GoAway(); err != nil {
  454. t.Fatalf("err: %v", err)
  455. }
  456. _, err := client.Open()
  457. if err != ErrRemoteGoAway {
  458. t.Fatalf("err: %v", err)
  459. }
  460. }
  461. func TestManyStreams(t *testing.T) {
  462. client, server := testClientServer()
  463. defer client.Close()
  464. defer server.Close()
  465. wg := &sync.WaitGroup{}
  466. acceptor := func(i int) {
  467. defer wg.Done()
  468. stream, err := server.AcceptStream()
  469. if err != nil {
  470. t.Fatalf("err: %v", err)
  471. }
  472. defer stream.Close()
  473. buf := make([]byte, 512)
  474. for {
  475. n, err := stream.Read(buf)
  476. if err == io.EOF {
  477. return
  478. }
  479. if err != nil {
  480. t.Fatalf("err: %v", err)
  481. }
  482. if n == 0 {
  483. t.Fatalf("err: %v", err)
  484. }
  485. }
  486. }
  487. sender := func(i int) {
  488. defer wg.Done()
  489. stream, err := client.Open()
  490. if err != nil {
  491. t.Fatalf("err: %v", err)
  492. }
  493. defer stream.Close()
  494. msg := fmt.Sprintf("%08d", i)
  495. for i := 0; i < 1000; i++ {
  496. n, err := stream.Write([]byte(msg))
  497. if err != nil {
  498. t.Fatalf("err: %v", err)
  499. }
  500. if n != len(msg) {
  501. t.Fatalf("short write %d", n)
  502. }
  503. }
  504. }
  505. for i := 0; i < 50; i++ {
  506. wg.Add(2)
  507. go acceptor(i)
  508. go sender(i)
  509. }
  510. wg.Wait()
  511. }
  512. func TestManyStreams_PingPong(t *testing.T) {
  513. client, server := testClientServer()
  514. defer client.Close()
  515. defer server.Close()
  516. wg := &sync.WaitGroup{}
  517. ping := []byte("ping")
  518. pong := []byte("pong")
  519. acceptor := func(i int) {
  520. defer wg.Done()
  521. stream, err := server.AcceptStream()
  522. if err != nil {
  523. t.Fatalf("err: %v", err)
  524. }
  525. defer stream.Close()
  526. buf := make([]byte, 4)
  527. for {
  528. // Read the 'ping'
  529. n, err := stream.Read(buf)
  530. if err == io.EOF {
  531. return
  532. }
  533. if err != nil {
  534. t.Fatalf("err: %v", err)
  535. }
  536. if n != 4 {
  537. t.Fatalf("err: %v", err)
  538. }
  539. if !bytes.Equal(buf, ping) {
  540. t.Fatalf("bad: %s", buf)
  541. }
  542. // Shrink the internal buffer!
  543. stream.Shrink()
  544. // Write out the 'pong'
  545. n, err = stream.Write(pong)
  546. if err != nil {
  547. t.Fatalf("err: %v", err)
  548. }
  549. if n != 4 {
  550. t.Fatalf("err: %v", err)
  551. }
  552. }
  553. }
  554. sender := func(i int) {
  555. defer wg.Done()
  556. stream, err := client.OpenStream()
  557. if err != nil {
  558. t.Fatalf("err: %v", err)
  559. }
  560. defer stream.Close()
  561. buf := make([]byte, 4)
  562. for i := 0; i < 1000; i++ {
  563. // Send the 'ping'
  564. n, err := stream.Write(ping)
  565. if err != nil {
  566. t.Fatalf("err: %v", err)
  567. }
  568. if n != 4 {
  569. t.Fatalf("short write %d", n)
  570. }
  571. // Read the 'pong'
  572. n, err = stream.Read(buf)
  573. if err != nil {
  574. t.Fatalf("err: %v", err)
  575. }
  576. if n != 4 {
  577. t.Fatalf("err: %v", err)
  578. }
  579. if !bytes.Equal(buf, pong) {
  580. t.Fatalf("bad: %s", buf)
  581. }
  582. // Shrink the buffer
  583. stream.Shrink()
  584. }
  585. }
  586. for i := 0; i < 50; i++ {
  587. wg.Add(2)
  588. go acceptor(i)
  589. go sender(i)
  590. }
  591. wg.Wait()
  592. }
  593. func TestHalfClose(t *testing.T) {
  594. client, server := testClientServer()
  595. defer client.Close()
  596. defer server.Close()
  597. stream, err := client.Open()
  598. if err != nil {
  599. t.Fatalf("err: %v", err)
  600. }
  601. if _, err = stream.Write([]byte("a")); err != nil {
  602. t.Fatalf("err: %v", err)
  603. }
  604. stream2, err := server.Accept()
  605. if err != nil {
  606. t.Fatalf("err: %v", err)
  607. }
  608. stream2.Close() // Half close
  609. buf := make([]byte, 4)
  610. n, err := stream2.Read(buf)
  611. if err != nil {
  612. t.Fatalf("err: %v", err)
  613. }
  614. if n != 1 {
  615. t.Fatalf("bad: %v", n)
  616. }
  617. // Send more
  618. if _, err = stream.Write([]byte("bcd")); err != nil {
  619. t.Fatalf("err: %v", err)
  620. }
  621. stream.Close()
  622. // Read after close
  623. n, err = stream2.Read(buf)
  624. if err != nil {
  625. t.Fatalf("err: %v", err)
  626. }
  627. if n != 3 {
  628. t.Fatalf("bad: %v", n)
  629. }
  630. // EOF after close
  631. n, err = stream2.Read(buf)
  632. if err != io.EOF {
  633. t.Fatalf("err: %v", err)
  634. }
  635. if n != 0 {
  636. t.Fatalf("bad: %v", n)
  637. }
  638. }
  639. func TestReadDeadline(t *testing.T) {
  640. client, server := testClientServer()
  641. defer client.Close()
  642. defer server.Close()
  643. stream, err := client.Open()
  644. if err != nil {
  645. t.Fatalf("err: %v", err)
  646. }
  647. defer stream.Close()
  648. stream2, err := server.Accept()
  649. if err != nil {
  650. t.Fatalf("err: %v", err)
  651. }
  652. defer stream2.Close()
  653. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  654. t.Fatalf("err: %v", err)
  655. }
  656. buf := make([]byte, 4)
  657. _, err = stream.Read(buf)
  658. if err != ErrTimeout {
  659. t.Fatalf("err: %v", err)
  660. }
  661. // See https://github.com/hashicorp/yamux/issues/90
  662. // Standard http server package will read from connections in background to detect if it's alive.
  663. // It sets read deadline on connections and detect if the returned error is timeout error which implements net.Error.
  664. // HTTP server will cancel all server requests if it isn't timeout error from connections.
  665. if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
  666. t.Fatalf("error of reading timeout is expected to implement net.Error and return true when calling Timeout(), but not")
  667. }
  668. }
  669. func TestReadDeadline_BlockedRead(t *testing.T) {
  670. client, server := testClientServer()
  671. defer client.Close()
  672. defer server.Close()
  673. stream, err := client.Open()
  674. if err != nil {
  675. t.Fatalf("err: %v", err)
  676. }
  677. defer stream.Close()
  678. stream2, err := server.Accept()
  679. if err != nil {
  680. t.Fatalf("err: %v", err)
  681. }
  682. defer stream2.Close()
  683. // Start a read that will block
  684. errCh := make(chan error, 1)
  685. go func() {
  686. buf := make([]byte, 4)
  687. _, err := stream.Read(buf)
  688. errCh <- err
  689. close(errCh)
  690. }()
  691. // Wait to ensure the read has started.
  692. time.Sleep(5 * time.Millisecond)
  693. // Update the read deadline
  694. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  695. t.Fatalf("err: %v", err)
  696. }
  697. select {
  698. case <-time.After(100 * time.Millisecond):
  699. t.Fatal("expected read timeout")
  700. case err := <-errCh:
  701. if err != ErrTimeout {
  702. t.Fatalf("expected ErrTimeout; got %v", err)
  703. }
  704. }
  705. }
  706. func TestWriteDeadline(t *testing.T) {
  707. client, server := testClientServer()
  708. defer client.Close()
  709. defer server.Close()
  710. stream, err := client.Open()
  711. if err != nil {
  712. t.Fatalf("err: %v", err)
  713. }
  714. defer stream.Close()
  715. stream2, err := server.Accept()
  716. if err != nil {
  717. t.Fatalf("err: %v", err)
  718. }
  719. defer stream2.Close()
  720. if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
  721. t.Fatalf("err: %v", err)
  722. }
  723. buf := make([]byte, 512)
  724. for i := 0; i < int(initialStreamWindow); i++ {
  725. _, err := stream.Write(buf)
  726. if err != nil && err == ErrTimeout {
  727. return
  728. } else if err != nil {
  729. t.Fatalf("err: %v", err)
  730. }
  731. }
  732. t.Fatalf("Expected timeout")
  733. }
  734. func TestWriteDeadline_BlockedWrite(t *testing.T) {
  735. client, server := testClientServer()
  736. defer client.Close()
  737. defer server.Close()
  738. stream, err := client.Open()
  739. if err != nil {
  740. t.Fatalf("err: %v", err)
  741. }
  742. defer stream.Close()
  743. stream2, err := server.Accept()
  744. if err != nil {
  745. t.Fatalf("err: %v", err)
  746. }
  747. defer stream2.Close()
  748. // Start a goroutine making writes that will block
  749. errCh := make(chan error, 1)
  750. go func() {
  751. buf := make([]byte, 512)
  752. for i := 0; i < int(initialStreamWindow); i++ {
  753. _, err := stream.Write(buf)
  754. if err == nil {
  755. continue
  756. }
  757. errCh <- err
  758. close(errCh)
  759. return
  760. }
  761. close(errCh)
  762. }()
  763. // Wait to ensure the write has started.
  764. time.Sleep(5 * time.Millisecond)
  765. // Update the write deadline
  766. if err := stream.SetWriteDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  767. t.Fatalf("err: %v", err)
  768. }
  769. select {
  770. case <-time.After(1 * time.Second):
  771. t.Fatal("expected write timeout")
  772. case err := <-errCh:
  773. if err != ErrTimeout {
  774. t.Fatalf("expected ErrTimeout; got %v", err)
  775. }
  776. }
  777. }
  778. func TestBacklogExceeded(t *testing.T) {
  779. client, server := testClientServer()
  780. defer client.Close()
  781. defer server.Close()
  782. // Fill the backlog
  783. max := client.config.AcceptBacklog
  784. for i := 0; i < max; i++ {
  785. stream, err := client.Open()
  786. if err != nil {
  787. t.Fatalf("err: %v", err)
  788. }
  789. defer stream.Close()
  790. if _, err := stream.Write([]byte("foo")); err != nil {
  791. t.Fatalf("err: %v", err)
  792. }
  793. }
  794. // Attempt to open a new stream
  795. errCh := make(chan error, 1)
  796. go func() {
  797. _, err := client.Open()
  798. errCh <- err
  799. }()
  800. // Shutdown the server
  801. go func() {
  802. time.Sleep(10 * time.Millisecond)
  803. server.Close()
  804. }()
  805. select {
  806. case err := <-errCh:
  807. if err == nil {
  808. t.Fatalf("open should fail")
  809. }
  810. case <-time.After(time.Second):
  811. t.Fatalf("timeout")
  812. }
  813. }
  814. func TestKeepAlive(t *testing.T) {
  815. client, server := testClientServer()
  816. defer client.Close()
  817. defer server.Close()
  818. time.Sleep(200 * time.Millisecond)
  819. // Ping value should increase
  820. client.pingLock.Lock()
  821. defer client.pingLock.Unlock()
  822. if client.pingID == 0 {
  823. t.Fatalf("should ping")
  824. }
  825. server.pingLock.Lock()
  826. defer server.pingLock.Unlock()
  827. if server.pingID == 0 {
  828. t.Fatalf("should ping")
  829. }
  830. }
  831. func TestKeepAlive_Timeout(t *testing.T) {
  832. conn1, conn2 := testConn()
  833. clientConf := testConf()
  834. clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
  835. clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
  836. client, _ := Client(conn1, clientConf)
  837. defer client.Close()
  838. server, _ := Server(conn2, testConf())
  839. defer server.Close()
  840. _ = captureLogs(client) // Client logs aren't part of the test
  841. serverLogs := captureLogs(server)
  842. errCh := make(chan error, 1)
  843. go func() {
  844. _, err := server.Accept() // Wait until server closes
  845. errCh <- err
  846. }()
  847. // Prevent the client from responding
  848. clientConn := client.conn.(*pipeConn)
  849. clientConn.writeBlocker.Lock()
  850. select {
  851. case err := <-errCh:
  852. if err != ErrKeepAliveTimeout {
  853. t.Fatalf("unexpected error: %v", err)
  854. }
  855. case <-time.After(1 * time.Second):
  856. t.Fatalf("timeout waiting for timeout")
  857. }
  858. if !server.IsClosed() {
  859. t.Fatalf("server should have closed")
  860. }
  861. if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
  862. t.Fatalf("server log incorect: %v", serverLogs.logs())
  863. }
  864. }
  865. func TestLargeWindow(t *testing.T) {
  866. conf := DefaultConfig()
  867. conf.MaxStreamWindowSize *= 2
  868. client, server := testClientServerConfig(conf)
  869. defer client.Close()
  870. defer server.Close()
  871. stream, err := client.Open()
  872. if err != nil {
  873. t.Fatalf("err: %v", err)
  874. }
  875. defer stream.Close()
  876. stream2, err := server.Accept()
  877. if err != nil {
  878. t.Fatalf("err: %v", err)
  879. }
  880. defer stream2.Close()
  881. stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
  882. buf := make([]byte, conf.MaxStreamWindowSize)
  883. n, err := stream.Write(buf)
  884. if err != nil {
  885. t.Fatalf("err: %v", err)
  886. }
  887. if n != len(buf) {
  888. t.Fatalf("short write: %d", n)
  889. }
  890. }
  // UnlimitedReader is an io.Reader that never runs dry. Every Read reports
  // len(p) bytes and a nil error, leaving p's contents untouched. Wrap it in
  // io.LimitReader to get a stream of a fixed size.
  type UnlimitedReader struct{}

  // Read yields the processor, then claims to have filled all of p.
  func (u *UnlimitedReader) Read(p []byte) (int, error) {
  	runtime.Gosched()
  	filled := len(p)
  	return filled, nil
  }
  896. func TestSendData_VeryLarge(t *testing.T) {
  897. client, server := testClientServer()
  898. defer client.Close()
  899. defer server.Close()
  900. var n int64 = 1 * 1024 * 1024 * 1024
  901. var workers int = 16
  902. wg := &sync.WaitGroup{}
  903. wg.Add(workers * 2)
  904. for i := 0; i < workers; i++ {
  905. go func() {
  906. defer wg.Done()
  907. stream, err := server.AcceptStream()
  908. if err != nil {
  909. t.Fatalf("err: %v", err)
  910. }
  911. defer stream.Close()
  912. buf := make([]byte, 4)
  913. _, err = stream.Read(buf)
  914. if err != nil {
  915. t.Fatalf("err: %v", err)
  916. }
  917. if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
  918. t.Fatalf("bad header")
  919. }
  920. recv, err := io.Copy(ioutil.Discard, stream)
  921. if err != nil {
  922. t.Fatalf("err: %v", err)
  923. }
  924. if recv != n {
  925. t.Fatalf("bad: %v", recv)
  926. }
  927. }()
  928. }
  929. for i := 0; i < workers; i++ {
  930. go func() {
  931. defer wg.Done()
  932. stream, err := client.Open()
  933. if err != nil {
  934. t.Fatalf("err: %v", err)
  935. }
  936. defer stream.Close()
  937. _, err = stream.Write([]byte{0, 1, 2, 3})
  938. if err != nil {
  939. t.Fatalf("err: %v", err)
  940. }
  941. unlimited := &UnlimitedReader{}
  942. sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
  943. if err != nil {
  944. t.Fatalf("err: %v", err)
  945. }
  946. if sent != n {
  947. t.Fatalf("bad: %v", sent)
  948. }
  949. }()
  950. }
  951. doneCh := make(chan struct{})
  952. go func() {
  953. wg.Wait()
  954. close(doneCh)
  955. }()
  956. select {
  957. case <-doneCh:
  958. case <-time.After(20 * time.Second):
  959. panic("timeout")
  960. }
  961. }
  962. func TestBacklogExceeded_Accept(t *testing.T) {
  963. client, server := testClientServer()
  964. defer client.Close()
  965. defer server.Close()
  966. max := 5 * client.config.AcceptBacklog
  967. go func() {
  968. for i := 0; i < max; i++ {
  969. stream, err := server.Accept()
  970. if err != nil {
  971. t.Fatalf("err: %v", err)
  972. }
  973. defer stream.Close()
  974. }
  975. }()
  976. // Fill the backlog
  977. for i := 0; i < max; i++ {
  978. stream, err := client.Open()
  979. if err != nil {
  980. t.Fatalf("err: %v", err)
  981. }
  982. defer stream.Close()
  983. if _, err := stream.Write([]byte("foo")); err != nil {
  984. t.Fatalf("err: %v", err)
  985. }
  986. }
  987. }
  988. func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
  989. client, server := testClientServerConfig(testConfNoKeepAlive())
  990. defer client.Close()
  991. defer server.Close()
  992. var wg sync.WaitGroup
  993. wg.Add(2)
  994. // Choose a huge flood size that we know will result in a window update.
  995. flood := int64(client.config.MaxStreamWindowSize) - 1
  996. // The server will accept a new stream and then flood data to it.
  997. go func() {
  998. defer wg.Done()
  999. stream, err := server.AcceptStream()
  1000. if err != nil {
  1001. t.Fatalf("err: %v", err)
  1002. }
  1003. defer stream.Close()
  1004. n, err := stream.Write(make([]byte, flood))
  1005. if err != nil {
  1006. t.Fatalf("err: %v", err)
  1007. }
  1008. if int64(n) != flood {
  1009. t.Fatalf("short write: %d", n)
  1010. }
  1011. }()
  1012. // The client will open a stream, block outbound writes, and then
  1013. // listen to the flood from the server, which should time out since
  1014. // it won't be able to send the window update.
  1015. go func() {
  1016. defer wg.Done()
  1017. stream, err := client.OpenStream()
  1018. if err != nil {
  1019. t.Fatalf("err: %v", err)
  1020. }
  1021. defer stream.Close()
  1022. conn := client.conn.(*pipeConn)
  1023. conn.writeBlocker.Lock()
  1024. _, err = stream.Read(make([]byte, flood))
  1025. if err != ErrConnectionWriteTimeout {
  1026. t.Fatalf("err: %v", err)
  1027. }
  1028. }()
  1029. wg.Wait()
  1030. }
  1031. func TestSession_PartialReadWindowUpdate(t *testing.T) {
  1032. client, server := testClientServerConfig(testConfNoKeepAlive())
  1033. defer client.Close()
  1034. defer server.Close()
  1035. var wg sync.WaitGroup
  1036. wg.Add(1)
  1037. // Choose a huge flood size that we know will result in a window update.
  1038. flood := int64(client.config.MaxStreamWindowSize)
  1039. var wr *Stream
  1040. // The server will accept a new stream and then flood data to it.
  1041. go func() {
  1042. defer wg.Done()
  1043. var err error
  1044. wr, err = server.AcceptStream()
  1045. if err != nil {
  1046. t.Fatalf("err: %v", err)
  1047. }
  1048. defer wr.Close()
  1049. if wr.sendWindow != client.config.MaxStreamWindowSize {
  1050. t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
  1051. }
  1052. n, err := wr.Write(make([]byte, flood))
  1053. if err != nil {
  1054. t.Fatalf("err: %v", err)
  1055. }
  1056. if int64(n) != flood {
  1057. t.Fatalf("short write: %d", n)
  1058. }
  1059. if wr.sendWindow != 0 {
  1060. t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
  1061. }
  1062. }()
  1063. stream, err := client.OpenStream()
  1064. if err != nil {
  1065. t.Fatalf("err: %v", err)
  1066. }
  1067. defer stream.Close()
  1068. wg.Wait()
  1069. _, err = stream.Read(make([]byte, flood/2+1))
  1070. if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
  1071. t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
  1072. }
  1073. }
  1074. func TestSession_sendNoWait_Timeout(t *testing.T) {
  1075. client, server := testClientServerConfig(testConfNoKeepAlive())
  1076. defer client.Close()
  1077. defer server.Close()
  1078. var wg sync.WaitGroup
  1079. wg.Add(2)
  1080. go func() {
  1081. defer wg.Done()
  1082. stream, err := server.AcceptStream()
  1083. if err != nil {
  1084. t.Fatalf("err: %v", err)
  1085. }
  1086. defer stream.Close()
  1087. }()
  1088. // The client will open the stream and then block outbound writes, we'll
  1089. // probe sendNoWait once it gets into that state.
  1090. go func() {
  1091. defer wg.Done()
  1092. stream, err := client.OpenStream()
  1093. if err != nil {
  1094. t.Fatalf("err: %v", err)
  1095. }
  1096. defer stream.Close()
  1097. conn := client.conn.(*pipeConn)
  1098. conn.writeBlocker.Lock()
  1099. hdr := header(make([]byte, headerSize))
  1100. hdr.encode(typePing, flagACK, 0, 0)
  1101. for {
  1102. err = client.sendNoWait(hdr)
  1103. if err == nil {
  1104. continue
  1105. } else if err == ErrConnectionWriteTimeout {
  1106. break
  1107. } else {
  1108. t.Fatalf("err: %v", err)
  1109. }
  1110. }
  1111. }()
  1112. wg.Wait()
  1113. }
  1114. func TestSession_PingOfDeath(t *testing.T) {
  1115. client, server := testClientServerConfig(testConfNoKeepAlive())
  1116. defer client.Close()
  1117. defer server.Close()
  1118. var wg sync.WaitGroup
  1119. wg.Add(2)
  1120. var doPingOfDeath sync.Mutex
  1121. doPingOfDeath.Lock()
  1122. // This is used later to block outbound writes.
  1123. conn := server.conn.(*pipeConn)
  1124. // The server will accept a stream, block outbound writes, and then
  1125. // flood its send channel so that no more headers can be queued.
  1126. go func() {
  1127. defer wg.Done()
  1128. stream, err := server.AcceptStream()
  1129. if err != nil {
  1130. t.Fatalf("err: %v", err)
  1131. }
  1132. defer stream.Close()
  1133. conn.writeBlocker.Lock()
  1134. for {
  1135. hdr := header(make([]byte, headerSize))
  1136. hdr.encode(typePing, 0, 0, 0)
  1137. err = server.sendNoWait(hdr)
  1138. if err == nil {
  1139. continue
  1140. } else if err == ErrConnectionWriteTimeout {
  1141. break
  1142. } else {
  1143. t.Fatalf("err: %v", err)
  1144. }
  1145. }
  1146. doPingOfDeath.Unlock()
  1147. }()
  1148. // The client will open a stream and then send the server a ping once it
  1149. // can no longer write. This makes sure the server doesn't deadlock reads
  1150. // while trying to reply to the ping with no ability to write.
  1151. go func() {
  1152. defer wg.Done()
  1153. stream, err := client.OpenStream()
  1154. if err != nil {
  1155. t.Fatalf("err: %v", err)
  1156. }
  1157. defer stream.Close()
  1158. // This ping will never unblock because the ping id will never
  1159. // show up in a response.
  1160. doPingOfDeath.Lock()
  1161. go func() { client.Ping() }()
  1162. // Wait for a while to make sure the previous ping times out,
  1163. // then turn writes back on and make sure a ping works again.
  1164. time.Sleep(2 * server.config.ConnectionWriteTimeout)
  1165. conn.writeBlocker.Unlock()
  1166. if _, err = client.Ping(); err != nil {
  1167. t.Fatalf("err: %v", err)
  1168. }
  1169. }()
  1170. wg.Wait()
  1171. }
  1172. func TestSession_ConnectionWriteTimeout(t *testing.T) {
  1173. client, server := testClientServerConfig(testConfNoKeepAlive())
  1174. defer client.Close()
  1175. defer server.Close()
  1176. var wg sync.WaitGroup
  1177. wg.Add(2)
  1178. go func() {
  1179. defer wg.Done()
  1180. stream, err := server.AcceptStream()
  1181. if err != nil {
  1182. t.Fatalf("err: %v", err)
  1183. }
  1184. defer stream.Close()
  1185. }()
  1186. // The client will open the stream and then block outbound writes, we'll
  1187. // tee up a write and make sure it eventually times out.
  1188. go func() {
  1189. defer wg.Done()
  1190. stream, err := client.OpenStream()
  1191. if err != nil {
  1192. t.Fatalf("err: %v", err)
  1193. }
  1194. defer stream.Close()
  1195. conn := client.conn.(*pipeConn)
  1196. conn.writeBlocker.Lock()
  1197. // Since the write goroutine is blocked then this will return a
  1198. // timeout since it can't get feedback about whether the write
  1199. // worked.
  1200. n, err := stream.Write([]byte("hello"))
  1201. if err != ErrConnectionWriteTimeout {
  1202. t.Fatalf("err: %v", err)
  1203. }
  1204. if n != 0 {
  1205. t.Fatalf("lied about writes: %d", n)
  1206. }
  1207. }()
  1208. wg.Wait()
  1209. }