session_test.go 26 KB


  1. package yamux
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "log"
  8. "reflect"
  9. "runtime"
  10. "strings"
  11. "sync"
  12. "testing"
  13. "time"
  14. )
  // logCapture accumulates everything a logger writes so tests can assert on
// the exact lines that were emitted.
type logCapture struct{ bytes.Buffer }

// logs returns the captured output split on newlines, after trimming leading
// and trailing whitespace. An empty capture yields a single empty string.
func (l *logCapture) logs() []string {
	text := strings.TrimSpace(l.String())
	return strings.Split(text, "\n")
}

// match reports whether the captured lines equal expect exactly, in order.
func (l *logCapture) match(expect []string) bool {
	got := l.logs()
	return reflect.DeepEqual(got, expect)
}
  22. func captureLogs(s *Session) *logCapture {
  23. buf := new(logCapture)
  24. s.logger = log.New(buf, "", 0)
  25. return buf
  26. }
  // pipeConn is one end of an in-memory connection assembled from io.Pipes.
// Tests lock writeBlocker to stop this end from sending anything: every Write
// waits on that mutex before touching the pipe.
type pipeConn struct {
	reader       *io.PipeReader
	writer       *io.PipeWriter
	writeBlocker sync.Mutex
}

// Read receives bytes from the pipe this end reads from.
func (p *pipeConn) Read(b []byte) (int, error) {
	n, err := p.reader.Read(b)
	return n, err
}

// Write sends b on the outbound pipe. It first acquires writeBlocker, so it
// blocks for as long as a test holds that lock.
func (p *pipeConn) Write(b []byte) (int, error) {
	p.writeBlocker.Lock()
	defer p.writeBlocker.Unlock()
	n, err := p.writer.Write(b)
	return n, err
}

// Close shuts both pipes. Only the writer's error is reported; the io
// documentation states PipeReader.Close always returns nil.
func (p *pipeConn) Close() error {
	_ = p.reader.Close()
	return p.writer.Close()
}
  44. func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
  45. read1, write1 := io.Pipe()
  46. read2, write2 := io.Pipe()
  47. conn1 := &pipeConn{reader: read1, writer: write2}
  48. conn2 := &pipeConn{reader: read2, writer: write1}
  49. return conn1, conn2
  50. }
  51. func testConf() *Config {
  52. conf := DefaultConfig()
  53. conf.AcceptBacklog = 64
  54. conf.KeepAliveInterval = 100 * time.Millisecond
  55. conf.ConnectionWriteTimeout = 250 * time.Millisecond
  56. return conf
  57. }
  58. func testConfNoKeepAlive() *Config {
  59. conf := testConf()
  60. conf.EnableKeepAlive = false
  61. return conf
  62. }
  63. func testClientServer() (*Session, *Session) {
  64. return testClientServerConfig(testConf())
  65. }
  66. func testClientServerConfig(conf *Config) (*Session, *Session) {
  67. conn1, conn2 := testConn()
  68. client, _ := Client(conn1, conf)
  69. server, _ := Server(conn2, conf)
  70. return client, server
  71. }
  72. func TestPing(t *testing.T) {
  73. client, server := testClientServer()
  74. defer client.Close()
  75. defer server.Close()
  76. rtt, err := client.Ping()
  77. if err != nil {
  78. t.Fatalf("err: %v", err)
  79. }
  80. if rtt == 0 {
  81. t.Fatalf("bad: %v", rtt)
  82. }
  83. rtt, err = server.Ping()
  84. if err != nil {
  85. t.Fatalf("err: %v", err)
  86. }
  87. if rtt == 0 {
  88. t.Fatalf("bad: %v", rtt)
  89. }
  90. }
  91. func TestPing_Timeout(t *testing.T) {
  92. client, server := testClientServerConfig(testConfNoKeepAlive())
  93. defer client.Close()
  94. defer server.Close()
  95. // Prevent the client from responding
  96. clientConn := client.conn.(*pipeConn)
  97. clientConn.writeBlocker.Lock()
  98. errCh := make(chan error, 1)
  99. go func() {
  100. _, err := server.Ping() // Ping via the server session
  101. errCh <- err
  102. }()
  103. select {
  104. case err := <-errCh:
  105. if err != ErrTimeout {
  106. t.Fatalf("err: %v", err)
  107. }
  108. case <-time.After(client.config.ConnectionWriteTimeout * 2):
  109. t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
  110. }
  111. // Verify that we recover, even if we gave up
  112. clientConn.writeBlocker.Unlock()
  113. go func() {
  114. _, err := server.Ping() // Ping via the server session
  115. errCh <- err
  116. }()
  117. select {
  118. case err := <-errCh:
  119. if err != nil {
  120. t.Fatalf("err: %v", err)
  121. }
  122. case <-time.After(client.config.ConnectionWriteTimeout):
  123. t.Fatalf("timeout")
  124. }
  125. }
  126. func TestCloseBeforeAck(t *testing.T) {
  127. cfg := testConf()
  128. cfg.AcceptBacklog = 8
  129. client, server := testClientServerConfig(cfg)
  130. defer client.Close()
  131. defer server.Close()
  132. for i := 0; i < 8; i++ {
  133. s, err := client.OpenStream()
  134. if err != nil {
  135. t.Fatal(err)
  136. }
  137. s.Close()
  138. }
  139. for i := 0; i < 8; i++ {
  140. s, err := server.AcceptStream()
  141. if err != nil {
  142. t.Fatal(err)
  143. }
  144. s.Close()
  145. }
  146. done := make(chan struct{})
  147. go func() {
  148. defer close(done)
  149. s, err := client.OpenStream()
  150. if err != nil {
  151. t.Fatal(err)
  152. }
  153. s.Close()
  154. }()
  155. select {
  156. case <-done:
  157. case <-time.After(time.Second * 5):
  158. t.Fatal("timed out trying to open stream")
  159. }
  160. }
  161. func TestAccept(t *testing.T) {
  162. client, server := testClientServer()
  163. defer client.Close()
  164. defer server.Close()
  165. if client.NumStreams() != 0 {
  166. t.Fatalf("bad")
  167. }
  168. if server.NumStreams() != 0 {
  169. t.Fatalf("bad")
  170. }
  171. wg := &sync.WaitGroup{}
  172. wg.Add(4)
  173. go func() {
  174. defer wg.Done()
  175. stream, err := server.AcceptStream()
  176. if err != nil {
  177. t.Fatalf("err: %v", err)
  178. }
  179. if id := stream.StreamID(); id != 1 {
  180. t.Fatalf("bad: %v", id)
  181. }
  182. if err := stream.Close(); err != nil {
  183. t.Fatalf("err: %v", err)
  184. }
  185. }()
  186. go func() {
  187. defer wg.Done()
  188. stream, err := client.AcceptStream()
  189. if err != nil {
  190. t.Fatalf("err: %v", err)
  191. }
  192. if id := stream.StreamID(); id != 2 {
  193. t.Fatalf("bad: %v", id)
  194. }
  195. if err := stream.Close(); err != nil {
  196. t.Fatalf("err: %v", err)
  197. }
  198. }()
  199. go func() {
  200. defer wg.Done()
  201. stream, err := server.OpenStream()
  202. if err != nil {
  203. t.Fatalf("err: %v", err)
  204. }
  205. if id := stream.StreamID(); id != 2 {
  206. t.Fatalf("bad: %v", id)
  207. }
  208. if err := stream.Close(); err != nil {
  209. t.Fatalf("err: %v", err)
  210. }
  211. }()
  212. go func() {
  213. defer wg.Done()
  214. stream, err := client.OpenStream()
  215. if err != nil {
  216. t.Fatalf("err: %v", err)
  217. }
  218. if id := stream.StreamID(); id != 1 {
  219. t.Fatalf("bad: %v", id)
  220. }
  221. if err := stream.Close(); err != nil {
  222. t.Fatalf("err: %v", err)
  223. }
  224. }()
  225. doneCh := make(chan struct{})
  226. go func() {
  227. wg.Wait()
  228. close(doneCh)
  229. }()
  230. select {
  231. case <-doneCh:
  232. case <-time.After(time.Second):
  233. panic("timeout")
  234. }
  235. }
  236. func TestNonNilInterface(t *testing.T) {
  237. _, server := testClientServer()
  238. server.Close()
  239. conn, err := server.Accept()
  240. if err != nil && conn != nil {
  241. t.Error("bad: accept should return a connection of nil value")
  242. }
  243. conn, err = server.Open()
  244. if err != nil && conn != nil {
  245. t.Error("bad: open should return a connection of nil value")
  246. }
  247. }
  248. func TestSendData_Small(t *testing.T) {
  249. client, server := testClientServer()
  250. defer client.Close()
  251. defer server.Close()
  252. wg := &sync.WaitGroup{}
  253. wg.Add(2)
  254. go func() {
  255. defer wg.Done()
  256. stream, err := server.AcceptStream()
  257. if err != nil {
  258. t.Fatalf("err: %v", err)
  259. }
  260. if server.NumStreams() != 1 {
  261. t.Fatalf("bad")
  262. }
  263. buf := make([]byte, 4)
  264. for i := 0; i < 1000; i++ {
  265. n, err := stream.Read(buf)
  266. if err != nil {
  267. t.Fatalf("err: %v", err)
  268. }
  269. if n != 4 {
  270. t.Fatalf("short read: %d", n)
  271. }
  272. if string(buf) != "test" {
  273. t.Fatalf("bad: %s", buf)
  274. }
  275. }
  276. if err := stream.Close(); err != nil {
  277. t.Fatalf("err: %v", err)
  278. }
  279. }()
  280. go func() {
  281. defer wg.Done()
  282. stream, err := client.Open()
  283. if err != nil {
  284. t.Fatalf("err: %v", err)
  285. }
  286. if client.NumStreams() != 1 {
  287. t.Fatalf("bad")
  288. }
  289. for i := 0; i < 1000; i++ {
  290. n, err := stream.Write([]byte("test"))
  291. if err != nil {
  292. t.Fatalf("err: %v", err)
  293. }
  294. if n != 4 {
  295. t.Fatalf("short write %d", n)
  296. }
  297. }
  298. if err := stream.Close(); err != nil {
  299. t.Fatalf("err: %v", err)
  300. }
  301. }()
  302. doneCh := make(chan struct{})
  303. go func() {
  304. wg.Wait()
  305. close(doneCh)
  306. }()
  307. select {
  308. case <-doneCh:
  309. case <-time.After(time.Second):
  310. panic("timeout")
  311. }
  312. if client.NumStreams() != 0 {
  313. t.Fatalf("bad")
  314. }
  315. if server.NumStreams() != 0 {
  316. t.Fatalf("bad")
  317. }
  318. }
  319. func TestSendData_Large(t *testing.T) {
  320. client, server := testClientServer()
  321. defer client.Close()
  322. defer server.Close()
  323. const (
  324. sendSize = 250 * 1024 * 1024
  325. recvSize = 4 * 1024
  326. )
  327. data := make([]byte, sendSize)
  328. for idx := range data {
  329. data[idx] = byte(idx % 256)
  330. }
  331. wg := &sync.WaitGroup{}
  332. wg.Add(2)
  333. go func() {
  334. defer wg.Done()
  335. stream, err := server.AcceptStream()
  336. if err != nil {
  337. t.Fatalf("err: %v", err)
  338. }
  339. var sz int
  340. buf := make([]byte, recvSize)
  341. for i := 0; i < sendSize/recvSize; i++ {
  342. n, err := stream.Read(buf)
  343. if err != nil {
  344. t.Fatalf("err: %v", err)
  345. }
  346. if n != recvSize {
  347. t.Fatalf("short read: %d", n)
  348. }
  349. sz += n
  350. for idx := range buf {
  351. if buf[idx] != byte(idx%256) {
  352. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  353. }
  354. }
  355. }
  356. if err := stream.Close(); err != nil {
  357. t.Fatalf("err: %v", err)
  358. }
  359. t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
  360. }()
  361. go func() {
  362. defer wg.Done()
  363. stream, err := client.Open()
  364. if err != nil {
  365. t.Fatalf("err: %v", err)
  366. }
  367. n, err := stream.Write(data)
  368. if err != nil {
  369. t.Fatalf("err: %v", err)
  370. }
  371. if n != len(data) {
  372. t.Fatalf("short write %d", n)
  373. }
  374. if err := stream.Close(); err != nil {
  375. t.Fatalf("err: %v", err)
  376. }
  377. }()
  378. doneCh := make(chan struct{})
  379. go func() {
  380. wg.Wait()
  381. close(doneCh)
  382. }()
  383. select {
  384. case <-doneCh:
  385. case <-time.After(5 * time.Second):
  386. panic("timeout")
  387. }
  388. }
  389. func TestGoAway(t *testing.T) {
  390. client, server := testClientServer()
  391. defer client.Close()
  392. defer server.Close()
  393. if err := server.GoAway(); err != nil {
  394. t.Fatalf("err: %v", err)
  395. }
  396. _, err := client.Open()
  397. if err != ErrRemoteGoAway {
  398. t.Fatalf("err: %v", err)
  399. }
  400. }
  401. func TestManyStreams(t *testing.T) {
  402. client, server := testClientServer()
  403. defer client.Close()
  404. defer server.Close()
  405. wg := &sync.WaitGroup{}
  406. acceptor := func(i int) {
  407. defer wg.Done()
  408. stream, err := server.AcceptStream()
  409. if err != nil {
  410. t.Fatalf("err: %v", err)
  411. }
  412. defer stream.Close()
  413. buf := make([]byte, 512)
  414. for {
  415. n, err := stream.Read(buf)
  416. if err == io.EOF {
  417. return
  418. }
  419. if err != nil {
  420. t.Fatalf("err: %v", err)
  421. }
  422. if n == 0 {
  423. t.Fatalf("err: %v", err)
  424. }
  425. }
  426. }
  427. sender := func(i int) {
  428. defer wg.Done()
  429. stream, err := client.Open()
  430. if err != nil {
  431. t.Fatalf("err: %v", err)
  432. }
  433. defer stream.Close()
  434. msg := fmt.Sprintf("%08d", i)
  435. for i := 0; i < 1000; i++ {
  436. n, err := stream.Write([]byte(msg))
  437. if err != nil {
  438. t.Fatalf("err: %v", err)
  439. }
  440. if n != len(msg) {
  441. t.Fatalf("short write %d", n)
  442. }
  443. }
  444. }
  445. for i := 0; i < 50; i++ {
  446. wg.Add(2)
  447. go acceptor(i)
  448. go sender(i)
  449. }
  450. wg.Wait()
  451. }
  452. func TestManyStreams_PingPong(t *testing.T) {
  453. client, server := testClientServer()
  454. defer client.Close()
  455. defer server.Close()
  456. wg := &sync.WaitGroup{}
  457. ping := []byte("ping")
  458. pong := []byte("pong")
  459. acceptor := func(i int) {
  460. defer wg.Done()
  461. stream, err := server.AcceptStream()
  462. if err != nil {
  463. t.Fatalf("err: %v", err)
  464. }
  465. defer stream.Close()
  466. buf := make([]byte, 4)
  467. for {
  468. // Read the 'ping'
  469. n, err := stream.Read(buf)
  470. if err == io.EOF {
  471. return
  472. }
  473. if err != nil {
  474. t.Fatalf("err: %v", err)
  475. }
  476. if n != 4 {
  477. t.Fatalf("err: %v", err)
  478. }
  479. if !bytes.Equal(buf, ping) {
  480. t.Fatalf("bad: %s", buf)
  481. }
  482. // Shrink the internal buffer!
  483. stream.Shrink()
  484. // Write out the 'pong'
  485. n, err = stream.Write(pong)
  486. if err != nil {
  487. t.Fatalf("err: %v", err)
  488. }
  489. if n != 4 {
  490. t.Fatalf("err: %v", err)
  491. }
  492. }
  493. }
  494. sender := func(i int) {
  495. defer wg.Done()
  496. stream, err := client.OpenStream()
  497. if err != nil {
  498. t.Fatalf("err: %v", err)
  499. }
  500. defer stream.Close()
  501. buf := make([]byte, 4)
  502. for i := 0; i < 1000; i++ {
  503. // Send the 'ping'
  504. n, err := stream.Write(ping)
  505. if err != nil {
  506. t.Fatalf("err: %v", err)
  507. }
  508. if n != 4 {
  509. t.Fatalf("short write %d", n)
  510. }
  511. // Read the 'pong'
  512. n, err = stream.Read(buf)
  513. if err != nil {
  514. t.Fatalf("err: %v", err)
  515. }
  516. if n != 4 {
  517. t.Fatalf("err: %v", err)
  518. }
  519. if !bytes.Equal(buf, pong) {
  520. t.Fatalf("bad: %s", buf)
  521. }
  522. // Shrink the buffer
  523. stream.Shrink()
  524. }
  525. }
  526. for i := 0; i < 50; i++ {
  527. wg.Add(2)
  528. go acceptor(i)
  529. go sender(i)
  530. }
  531. wg.Wait()
  532. }
  533. func TestHalfClose(t *testing.T) {
  534. client, server := testClientServer()
  535. defer client.Close()
  536. defer server.Close()
  537. stream, err := client.Open()
  538. if err != nil {
  539. t.Fatalf("err: %v", err)
  540. }
  541. if _, err = stream.Write([]byte("a")); err != nil {
  542. t.Fatalf("err: %v", err)
  543. }
  544. stream2, err := server.Accept()
  545. if err != nil {
  546. t.Fatalf("err: %v", err)
  547. }
  548. stream2.Close() // Half close
  549. buf := make([]byte, 4)
  550. n, err := stream2.Read(buf)
  551. if err != nil {
  552. t.Fatalf("err: %v", err)
  553. }
  554. if n != 1 {
  555. t.Fatalf("bad: %v", n)
  556. }
  557. // Send more
  558. if _, err = stream.Write([]byte("bcd")); err != nil {
  559. t.Fatalf("err: %v", err)
  560. }
  561. stream.Close()
  562. // Read after close
  563. n, err = stream2.Read(buf)
  564. if err != nil {
  565. t.Fatalf("err: %v", err)
  566. }
  567. if n != 3 {
  568. t.Fatalf("bad: %v", n)
  569. }
  570. // EOF after close
  571. n, err = stream2.Read(buf)
  572. if err != io.EOF {
  573. t.Fatalf("err: %v", err)
  574. }
  575. if n != 0 {
  576. t.Fatalf("bad: %v", n)
  577. }
  578. }
  579. func TestReadDeadline(t *testing.T) {
  580. client, server := testClientServer()
  581. defer client.Close()
  582. defer server.Close()
  583. stream, err := client.Open()
  584. if err != nil {
  585. t.Fatalf("err: %v", err)
  586. }
  587. defer stream.Close()
  588. stream2, err := server.Accept()
  589. if err != nil {
  590. t.Fatalf("err: %v", err)
  591. }
  592. defer stream2.Close()
  593. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  594. t.Fatalf("err: %v", err)
  595. }
  596. buf := make([]byte, 4)
  597. if _, err := stream.Read(buf); err != ErrTimeout {
  598. t.Fatalf("err: %v", err)
  599. }
  600. }
  601. func TestReadDeadline_BlockedRead(t *testing.T) {
  602. client, server := testClientServer()
  603. defer client.Close()
  604. defer server.Close()
  605. stream, err := client.Open()
  606. if err != nil {
  607. t.Fatalf("err: %v", err)
  608. }
  609. defer stream.Close()
  610. stream2, err := server.Accept()
  611. if err != nil {
  612. t.Fatalf("err: %v", err)
  613. }
  614. defer stream2.Close()
  615. // Start a read that will block
  616. errCh := make(chan error, 1)
  617. go func() {
  618. buf := make([]byte, 4)
  619. _, err := stream.Read(buf)
  620. errCh <- err
  621. close(errCh)
  622. }()
  623. // Wait to ensure the read has started.
  624. time.Sleep(5 * time.Millisecond)
  625. // Update the read deadline
  626. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  627. t.Fatalf("err: %v", err)
  628. }
  629. select {
  630. case <-time.After(100 * time.Millisecond):
  631. t.Fatal("expected read timeout")
  632. case err := <-errCh:
  633. if err != ErrTimeout {
  634. t.Fatalf("expected ErrTimeout; got %v", err)
  635. }
  636. }
  637. }
  638. func TestWriteDeadline(t *testing.T) {
  639. client, server := testClientServer()
  640. defer client.Close()
  641. defer server.Close()
  642. stream, err := client.Open()
  643. if err != nil {
  644. t.Fatalf("err: %v", err)
  645. }
  646. defer stream.Close()
  647. stream2, err := server.Accept()
  648. if err != nil {
  649. t.Fatalf("err: %v", err)
  650. }
  651. defer stream2.Close()
  652. if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
  653. t.Fatalf("err: %v", err)
  654. }
  655. buf := make([]byte, 512)
  656. for i := 0; i < int(initialStreamWindow); i++ {
  657. _, err := stream.Write(buf)
  658. if err != nil && err == ErrTimeout {
  659. return
  660. } else if err != nil {
  661. t.Fatalf("err: %v", err)
  662. }
  663. }
  664. t.Fatalf("Expected timeout")
  665. }
  666. func TestWriteDeadline_BlockedWrite(t *testing.T) {
  667. client, server := testClientServer()
  668. defer client.Close()
  669. defer server.Close()
  670. stream, err := client.Open()
  671. if err != nil {
  672. t.Fatalf("err: %v", err)
  673. }
  674. defer stream.Close()
  675. stream2, err := server.Accept()
  676. if err != nil {
  677. t.Fatalf("err: %v", err)
  678. }
  679. defer stream2.Close()
  680. // Start a goroutine making writes that will block
  681. errCh := make(chan error, 1)
  682. go func() {
  683. buf := make([]byte, 512)
  684. for i := 0; i < int(initialStreamWindow); i++ {
  685. _, err := stream.Write(buf)
  686. if err == nil {
  687. continue
  688. }
  689. errCh <- err
  690. close(errCh)
  691. return
  692. }
  693. close(errCh)
  694. }()
  695. // Wait to ensure the write has started.
  696. time.Sleep(5 * time.Millisecond)
  697. // Update the write deadline
  698. if err := stream.SetWriteDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  699. t.Fatalf("err: %v", err)
  700. }
  701. select {
  702. case <-time.After(1 * time.Second):
  703. t.Fatal("expected write timeout")
  704. case err := <-errCh:
  705. if err != ErrTimeout {
  706. t.Fatalf("expected ErrTimeout; got %v", err)
  707. }
  708. }
  709. }
  710. func TestBacklogExceeded(t *testing.T) {
  711. client, server := testClientServer()
  712. defer client.Close()
  713. defer server.Close()
  714. // Fill the backlog
  715. max := client.config.AcceptBacklog
  716. for i := 0; i < max; i++ {
  717. stream, err := client.Open()
  718. if err != nil {
  719. t.Fatalf("err: %v", err)
  720. }
  721. defer stream.Close()
  722. if _, err := stream.Write([]byte("foo")); err != nil {
  723. t.Fatalf("err: %v", err)
  724. }
  725. }
  726. // Attempt to open a new stream
  727. errCh := make(chan error, 1)
  728. go func() {
  729. _, err := client.Open()
  730. errCh <- err
  731. }()
  732. // Shutdown the server
  733. go func() {
  734. time.Sleep(10 * time.Millisecond)
  735. server.Close()
  736. }()
  737. select {
  738. case err := <-errCh:
  739. if err == nil {
  740. t.Fatalf("open should fail")
  741. }
  742. case <-time.After(time.Second):
  743. t.Fatalf("timeout")
  744. }
  745. }
  746. func TestKeepAlive(t *testing.T) {
  747. client, server := testClientServer()
  748. defer client.Close()
  749. defer server.Close()
  750. time.Sleep(200 * time.Millisecond)
  751. // Ping value should increase
  752. client.pingLock.Lock()
  753. defer client.pingLock.Unlock()
  754. if client.pingID == 0 {
  755. t.Fatalf("should ping")
  756. }
  757. server.pingLock.Lock()
  758. defer server.pingLock.Unlock()
  759. if server.pingID == 0 {
  760. t.Fatalf("should ping")
  761. }
  762. }
  763. func TestKeepAlive_Timeout(t *testing.T) {
  764. conn1, conn2 := testConn()
  765. clientConf := testConf()
  766. clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
  767. clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
  768. client, _ := Client(conn1, clientConf)
  769. defer client.Close()
  770. server, _ := Server(conn2, testConf())
  771. defer server.Close()
  772. _ = captureLogs(client) // Client logs aren't part of the test
  773. serverLogs := captureLogs(server)
  774. errCh := make(chan error, 1)
  775. go func() {
  776. _, err := server.Accept() // Wait until server closes
  777. errCh <- err
  778. }()
  779. // Prevent the client from responding
  780. clientConn := client.conn.(*pipeConn)
  781. clientConn.writeBlocker.Lock()
  782. select {
  783. case err := <-errCh:
  784. if err != ErrKeepAliveTimeout {
  785. t.Fatalf("unexpected error: %v", err)
  786. }
  787. case <-time.After(1 * time.Second):
  788. t.Fatalf("timeout waiting for timeout")
  789. }
  790. if !server.IsClosed() {
  791. t.Fatalf("server should have closed")
  792. }
  793. if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
  794. t.Fatalf("server log incorect: %v", serverLogs.logs())
  795. }
  796. }
  797. func TestLargeWindow(t *testing.T) {
  798. conf := DefaultConfig()
  799. conf.MaxStreamWindowSize *= 2
  800. client, server := testClientServerConfig(conf)
  801. defer client.Close()
  802. defer server.Close()
  803. stream, err := client.Open()
  804. if err != nil {
  805. t.Fatalf("err: %v", err)
  806. }
  807. defer stream.Close()
  808. stream2, err := server.Accept()
  809. if err != nil {
  810. t.Fatalf("err: %v", err)
  811. }
  812. defer stream2.Close()
  813. stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
  814. buf := make([]byte, conf.MaxStreamWindowSize)
  815. n, err := stream.Write(buf)
  816. if err != nil {
  817. t.Fatalf("err: %v", err)
  818. }
  819. if n != len(buf) {
  820. t.Fatalf("short write: %d", n)
  821. }
  822. }
  // UnlimitedReader is an endless io.Reader: every Read reports that the whole
// buffer was filled, leaves its contents untouched, and never returns EOF.
type UnlimitedReader struct{}

// Read yields the processor via runtime.Gosched (presumably so tight copy
// loops give other goroutines a turn), then reports len(p) bytes and no error.
func (u *UnlimitedReader) Read(p []byte) (int, error) {
	runtime.Gosched()
	n := len(p)
	return n, nil
}
  828. func TestSendData_VeryLarge(t *testing.T) {
  829. client, server := testClientServer()
  830. defer client.Close()
  831. defer server.Close()
  832. var n int64 = 1 * 1024 * 1024 * 1024
  833. var workers int = 16
  834. wg := &sync.WaitGroup{}
  835. wg.Add(workers * 2)
  836. for i := 0; i < workers; i++ {
  837. go func() {
  838. defer wg.Done()
  839. stream, err := server.AcceptStream()
  840. if err != nil {
  841. t.Fatalf("err: %v", err)
  842. }
  843. defer stream.Close()
  844. buf := make([]byte, 4)
  845. _, err = stream.Read(buf)
  846. if err != nil {
  847. t.Fatalf("err: %v", err)
  848. }
  849. if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
  850. t.Fatalf("bad header")
  851. }
  852. recv, err := io.Copy(ioutil.Discard, stream)
  853. if err != nil {
  854. t.Fatalf("err: %v", err)
  855. }
  856. if recv != n {
  857. t.Fatalf("bad: %v", recv)
  858. }
  859. }()
  860. }
  861. for i := 0; i < workers; i++ {
  862. go func() {
  863. defer wg.Done()
  864. stream, err := client.Open()
  865. if err != nil {
  866. t.Fatalf("err: %v", err)
  867. }
  868. defer stream.Close()
  869. _, err = stream.Write([]byte{0, 1, 2, 3})
  870. if err != nil {
  871. t.Fatalf("err: %v", err)
  872. }
  873. unlimited := &UnlimitedReader{}
  874. sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
  875. if err != nil {
  876. t.Fatalf("err: %v", err)
  877. }
  878. if sent != n {
  879. t.Fatalf("bad: %v", sent)
  880. }
  881. }()
  882. }
  883. doneCh := make(chan struct{})
  884. go func() {
  885. wg.Wait()
  886. close(doneCh)
  887. }()
  888. select {
  889. case <-doneCh:
  890. case <-time.After(20 * time.Second):
  891. panic("timeout")
  892. }
  893. }
  894. func TestBacklogExceeded_Accept(t *testing.T) {
  895. client, server := testClientServer()
  896. defer client.Close()
  897. defer server.Close()
  898. max := 5 * client.config.AcceptBacklog
  899. go func() {
  900. for i := 0; i < max; i++ {
  901. stream, err := server.Accept()
  902. if err != nil {
  903. t.Fatalf("err: %v", err)
  904. }
  905. defer stream.Close()
  906. }
  907. }()
  908. // Fill the backlog
  909. for i := 0; i < max; i++ {
  910. stream, err := client.Open()
  911. if err != nil {
  912. t.Fatalf("err: %v", err)
  913. }
  914. defer stream.Close()
  915. if _, err := stream.Write([]byte("foo")); err != nil {
  916. t.Fatalf("err: %v", err)
  917. }
  918. }
  919. }
  920. func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
  921. client, server := testClientServerConfig(testConfNoKeepAlive())
  922. defer client.Close()
  923. defer server.Close()
  924. var wg sync.WaitGroup
  925. wg.Add(2)
  926. // Choose a huge flood size that we know will result in a window update.
  927. flood := int64(client.config.MaxStreamWindowSize) - 1
  928. // The server will accept a new stream and then flood data to it.
  929. go func() {
  930. defer wg.Done()
  931. stream, err := server.AcceptStream()
  932. if err != nil {
  933. t.Fatalf("err: %v", err)
  934. }
  935. defer stream.Close()
  936. n, err := stream.Write(make([]byte, flood))
  937. if err != nil {
  938. t.Fatalf("err: %v", err)
  939. }
  940. if int64(n) != flood {
  941. t.Fatalf("short write: %d", n)
  942. }
  943. }()
  944. // The client will open a stream, block outbound writes, and then
  945. // listen to the flood from the server, which should time out since
  946. // it won't be able to send the window update.
  947. go func() {
  948. defer wg.Done()
  949. stream, err := client.OpenStream()
  950. if err != nil {
  951. t.Fatalf("err: %v", err)
  952. }
  953. defer stream.Close()
  954. conn := client.conn.(*pipeConn)
  955. conn.writeBlocker.Lock()
  956. _, err = stream.Read(make([]byte, flood))
  957. if err != ErrConnectionWriteTimeout {
  958. t.Fatalf("err: %v", err)
  959. }
  960. }()
  961. wg.Wait()
  962. }
  963. func TestSession_PartialReadWindowUpdate(t *testing.T) {
  964. client, server := testClientServerConfig(testConfNoKeepAlive())
  965. defer client.Close()
  966. defer server.Close()
  967. var wg sync.WaitGroup
  968. wg.Add(1)
  969. // Choose a huge flood size that we know will result in a window update.
  970. flood := int64(client.config.MaxStreamWindowSize)
  971. var wr *Stream
  972. // The server will accept a new stream and then flood data to it.
  973. go func() {
  974. defer wg.Done()
  975. var err error
  976. wr, err = server.AcceptStream()
  977. if err != nil {
  978. t.Fatalf("err: %v", err)
  979. }
  980. defer wr.Close()
  981. if wr.sendWindow != client.config.MaxStreamWindowSize {
  982. t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
  983. }
  984. n, err := wr.Write(make([]byte, flood))
  985. if err != nil {
  986. t.Fatalf("err: %v", err)
  987. }
  988. if int64(n) != flood {
  989. t.Fatalf("short write: %d", n)
  990. }
  991. if wr.sendWindow != 0 {
  992. t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
  993. }
  994. }()
  995. stream, err := client.OpenStream()
  996. if err != nil {
  997. t.Fatalf("err: %v", err)
  998. }
  999. defer stream.Close()
  1000. wg.Wait()
  1001. _, err = stream.Read(make([]byte, flood/2+1))
  1002. if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
  1003. t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
  1004. }
  1005. }
  1006. func TestSession_sendNoWait_Timeout(t *testing.T) {
  1007. client, server := testClientServerConfig(testConfNoKeepAlive())
  1008. defer client.Close()
  1009. defer server.Close()
  1010. var wg sync.WaitGroup
  1011. wg.Add(2)
  1012. go func() {
  1013. defer wg.Done()
  1014. stream, err := server.AcceptStream()
  1015. if err != nil {
  1016. t.Fatalf("err: %v", err)
  1017. }
  1018. defer stream.Close()
  1019. }()
  1020. // The client will open the stream and then block outbound writes, we'll
  1021. // probe sendNoWait once it gets into that state.
  1022. go func() {
  1023. defer wg.Done()
  1024. stream, err := client.OpenStream()
  1025. if err != nil {
  1026. t.Fatalf("err: %v", err)
  1027. }
  1028. defer stream.Close()
  1029. conn := client.conn.(*pipeConn)
  1030. conn.writeBlocker.Lock()
  1031. hdr := header(make([]byte, headerSize))
  1032. hdr.encode(typePing, flagACK, 0, 0)
  1033. for {
  1034. err = client.sendNoWait(hdr)
  1035. if err == nil {
  1036. continue
  1037. } else if err == ErrConnectionWriteTimeout {
  1038. break
  1039. } else {
  1040. t.Fatalf("err: %v", err)
  1041. }
  1042. }
  1043. }()
  1044. wg.Wait()
  1045. }
  1046. func TestSession_PingOfDeath(t *testing.T) {
  1047. client, server := testClientServerConfig(testConfNoKeepAlive())
  1048. defer client.Close()
  1049. defer server.Close()
  1050. var wg sync.WaitGroup
  1051. wg.Add(2)
  1052. var doPingOfDeath sync.Mutex
  1053. doPingOfDeath.Lock()
  1054. // This is used later to block outbound writes.
  1055. conn := server.conn.(*pipeConn)
  1056. // The server will accept a stream, block outbound writes, and then
  1057. // flood its send channel so that no more headers can be queued.
  1058. go func() {
  1059. defer wg.Done()
  1060. stream, err := server.AcceptStream()
  1061. if err != nil {
  1062. t.Fatalf("err: %v", err)
  1063. }
  1064. defer stream.Close()
  1065. conn.writeBlocker.Lock()
  1066. for {
  1067. hdr := header(make([]byte, headerSize))
  1068. hdr.encode(typePing, 0, 0, 0)
  1069. err = server.sendNoWait(hdr)
  1070. if err == nil {
  1071. continue
  1072. } else if err == ErrConnectionWriteTimeout {
  1073. break
  1074. } else {
  1075. t.Fatalf("err: %v", err)
  1076. }
  1077. }
  1078. doPingOfDeath.Unlock()
  1079. }()
  1080. // The client will open a stream and then send the server a ping once it
  1081. // can no longer write. This makes sure the server doesn't deadlock reads
  1082. // while trying to reply to the ping with no ability to write.
  1083. go func() {
  1084. defer wg.Done()
  1085. stream, err := client.OpenStream()
  1086. if err != nil {
  1087. t.Fatalf("err: %v", err)
  1088. }
  1089. defer stream.Close()
  1090. // This ping will never unblock because the ping id will never
  1091. // show up in a response.
  1092. doPingOfDeath.Lock()
  1093. go func() { client.Ping() }()
  1094. // Wait for a while to make sure the previous ping times out,
  1095. // then turn writes back on and make sure a ping works again.
  1096. time.Sleep(2 * server.config.ConnectionWriteTimeout)
  1097. conn.writeBlocker.Unlock()
  1098. if _, err = client.Ping(); err != nil {
  1099. t.Fatalf("err: %v", err)
  1100. }
  1101. }()
  1102. wg.Wait()
  1103. }
  1104. func TestSession_ConnectionWriteTimeout(t *testing.T) {
  1105. client, server := testClientServerConfig(testConfNoKeepAlive())
  1106. defer client.Close()
  1107. defer server.Close()
  1108. var wg sync.WaitGroup
  1109. wg.Add(2)
  1110. go func() {
  1111. defer wg.Done()
  1112. stream, err := server.AcceptStream()
  1113. if err != nil {
  1114. t.Fatalf("err: %v", err)
  1115. }
  1116. defer stream.Close()
  1117. }()
  1118. // The client will open the stream and then block outbound writes, we'll
  1119. // tee up a write and make sure it eventually times out.
  1120. go func() {
  1121. defer wg.Done()
  1122. stream, err := client.OpenStream()
  1123. if err != nil {
  1124. t.Fatalf("err: %v", err)
  1125. }
  1126. defer stream.Close()
  1127. conn := client.conn.(*pipeConn)
  1128. conn.writeBlocker.Lock()
  1129. // Since the write goroutine is blocked then this will return a
  1130. // timeout since it can't get feedback about whether the write
  1131. // worked.
  1132. n, err := stream.Write([]byte("hello"))
  1133. if err != ErrConnectionWriteTimeout {
  1134. t.Fatalf("err: %v", err)
  1135. }
  1136. if n != 0 {
  1137. t.Fatalf("lied about writes: %d", n)
  1138. }
  1139. }()
  1140. wg.Wait()
  1141. }