session_test.go 28 KB


  1. package yamux
  import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)
  // logCapture is an in-memory sink for a *log.Logger so tests can assert on
// exactly what was logged.
type logCapture struct{ bytes.Buffer }

// logs returns the captured output as lines, after trimming surrounding
// whitespace (including the final newline).
func (l *logCapture) logs() []string {
	trimmed := strings.TrimSpace(l.String())
	return strings.Split(trimmed, "\n")
}

// match reports whether the captured lines equal expect exactly, in order.
func (l *logCapture) match(expect []string) bool {
	got := l.logs()
	return reflect.DeepEqual(got, expect)
}
  22. func captureLogs(s *Session) *logCapture {
  23. buf := new(logCapture)
  24. s.logger = log.New(buf, "", 0)
  25. return buf
  26. }
  // pipeConn is one end of an in-memory duplex connection built from two
// io.Pipes. Every Write first acquires writeBlocker, so a test can hold that
// mutex to stall all outbound traffic from this end.
type pipeConn struct {
	reader       *io.PipeReader
	writer       *io.PipeWriter
	writeBlocker sync.Mutex
}

// Read reads from the inbound pipe.
func (p *pipeConn) Read(b []byte) (int, error) {
	return p.reader.Read(b)
}

// Write waits for writeBlocker, then writes to the outbound pipe.
func (p *pipeConn) Write(b []byte) (int, error) {
	p.writeBlocker.Lock()
	defer p.writeBlocker.Unlock()
	n, err := p.writer.Write(b)
	return n, err
}

// Close closes the inbound side (ignoring its error), then the outbound side,
// and returns the outbound side's error.
func (p *pipeConn) Close() error {
	_ = p.reader.Close()
	return p.writer.Close()
}

// testConn returns two connected pipeConns: bytes written to one can be read
// from the other, in both directions.
func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
	inA, outB := io.Pipe() // carries b -> a
	inB, outA := io.Pipe() // carries a -> b
	a := &pipeConn{reader: inA, writer: outA}
	b := &pipeConn{reader: inB, writer: outB}
	return a, b
}
  51. func testConf() *Config {
  52. conf := DefaultConfig()
  53. conf.AcceptBacklog = 64
  54. conf.KeepAliveInterval = 100 * time.Millisecond
  55. conf.ConnectionWriteTimeout = 250 * time.Millisecond
  56. return conf
  57. }
  58. func testConfNoKeepAlive() *Config {
  59. conf := testConf()
  60. conf.EnableKeepAlive = false
  61. return conf
  62. }
  63. func testClientServer() (*Session, *Session) {
  64. return testClientServerConfig(testConf())
  65. }
  66. func testClientServerConfig(conf *Config) (*Session, *Session) {
  67. conn1, conn2 := testConn()
  68. client, _ := Client(conn1, conf)
  69. server, _ := Server(conn2, conf)
  70. return client, server
  71. }
  72. func TestPing(t *testing.T) {
  73. client, server := testClientServer()
  74. defer client.Close()
  75. defer server.Close()
  76. rtt, err := client.Ping()
  77. if err != nil {
  78. t.Fatalf("err: %v", err)
  79. }
  80. if rtt == 0 {
  81. t.Fatalf("bad: %v", rtt)
  82. }
  83. rtt, err = server.Ping()
  84. if err != nil {
  85. t.Fatalf("err: %v", err)
  86. }
  87. if rtt == 0 {
  88. t.Fatalf("bad: %v", rtt)
  89. }
  90. }
  91. func TestPing_Timeout(t *testing.T) {
  92. client, server := testClientServerConfig(testConfNoKeepAlive())
  93. defer client.Close()
  94. defer server.Close()
  95. // Prevent the client from responding
  96. clientConn := client.conn.(*pipeConn)
  97. clientConn.writeBlocker.Lock()
  98. errCh := make(chan error, 1)
  99. go func() {
  100. _, err := server.Ping() // Ping via the server session
  101. errCh <- err
  102. }()
  103. select {
  104. case err := <-errCh:
  105. if err != ErrTimeout {
  106. t.Fatalf("err: %v", err)
  107. }
  108. case <-time.After(client.config.ConnectionWriteTimeout * 2):
  109. t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
  110. }
  111. // Verify that we recover, even if we gave up
  112. clientConn.writeBlocker.Unlock()
  113. go func() {
  114. _, err := server.Ping() // Ping via the server session
  115. errCh <- err
  116. }()
  117. select {
  118. case err := <-errCh:
  119. if err != nil {
  120. t.Fatalf("err: %v", err)
  121. }
  122. case <-time.After(client.config.ConnectionWriteTimeout):
  123. t.Fatalf("timeout")
  124. }
  125. }
  126. func TestCloseBeforeAck(t *testing.T) {
  127. cfg := testConf()
  128. cfg.AcceptBacklog = 8
  129. client, server := testClientServerConfig(cfg)
  130. defer client.Close()
  131. defer server.Close()
  132. for i := 0; i < 8; i++ {
  133. s, err := client.OpenStream()
  134. if err != nil {
  135. t.Fatal(err)
  136. }
  137. s.Close()
  138. }
  139. for i := 0; i < 8; i++ {
  140. s, err := server.AcceptStream()
  141. if err != nil {
  142. t.Fatal(err)
  143. }
  144. s.Close()
  145. }
  146. done := make(chan struct{})
  147. go func() {
  148. defer close(done)
  149. s, err := client.OpenStream()
  150. if err != nil {
  151. t.Fatal(err)
  152. }
  153. s.Close()
  154. }()
  155. select {
  156. case <-done:
  157. case <-time.After(time.Second * 5):
  158. t.Fatal("timed out trying to open stream")
  159. }
  160. }
  161. func TestAccept(t *testing.T) {
  162. client, server := testClientServer()
  163. defer client.Close()
  164. defer server.Close()
  165. if client.NumStreams() != 0 {
  166. t.Fatalf("bad")
  167. }
  168. if server.NumStreams() != 0 {
  169. t.Fatalf("bad")
  170. }
  171. wg := &sync.WaitGroup{}
  172. wg.Add(4)
  173. go func() {
  174. defer wg.Done()
  175. stream, err := server.AcceptStream()
  176. if err != nil {
  177. t.Fatalf("err: %v", err)
  178. }
  179. if id := stream.StreamID(); id != 1 {
  180. t.Fatalf("bad: %v", id)
  181. }
  182. if err := stream.Close(); err != nil {
  183. t.Fatalf("err: %v", err)
  184. }
  185. }()
  186. go func() {
  187. defer wg.Done()
  188. stream, err := client.AcceptStream()
  189. if err != nil {
  190. t.Fatalf("err: %v", err)
  191. }
  192. if id := stream.StreamID(); id != 2 {
  193. t.Fatalf("bad: %v", id)
  194. }
  195. if err := stream.Close(); err != nil {
  196. t.Fatalf("err: %v", err)
  197. }
  198. }()
  199. go func() {
  200. defer wg.Done()
  201. stream, err := server.OpenStream()
  202. if err != nil {
  203. t.Fatalf("err: %v", err)
  204. }
  205. if id := stream.StreamID(); id != 2 {
  206. t.Fatalf("bad: %v", id)
  207. }
  208. if err := stream.Close(); err != nil {
  209. t.Fatalf("err: %v", err)
  210. }
  211. }()
  212. go func() {
  213. defer wg.Done()
  214. stream, err := client.OpenStream()
  215. if err != nil {
  216. t.Fatalf("err: %v", err)
  217. }
  218. if id := stream.StreamID(); id != 1 {
  219. t.Fatalf("bad: %v", id)
  220. }
  221. if err := stream.Close(); err != nil {
  222. t.Fatalf("err: %v", err)
  223. }
  224. }()
  225. doneCh := make(chan struct{})
  226. go func() {
  227. wg.Wait()
  228. close(doneCh)
  229. }()
  230. select {
  231. case <-doneCh:
  232. case <-time.After(time.Second):
  233. panic("timeout")
  234. }
  235. }
  236. func TestClose_closeTimeout(t *testing.T) {
  237. conf := testConf()
  238. conf.StreamCloseTimeout = 10 * time.Millisecond
  239. client, server := testClientServerConfig(conf)
  240. defer client.Close()
  241. defer server.Close()
  242. if client.NumStreams() != 0 {
  243. t.Fatalf("bad")
  244. }
  245. if server.NumStreams() != 0 {
  246. t.Fatalf("bad")
  247. }
  248. wg := &sync.WaitGroup{}
  249. wg.Add(2)
  250. // Open a stream on the client but only close it on the server.
  251. // We want to see if the stream ever gets cleaned up on the client.
  252. var clientStream *Stream
  253. go func() {
  254. defer wg.Done()
  255. var err error
  256. clientStream, err = client.OpenStream()
  257. if err != nil {
  258. t.Fatalf("err: %v", err)
  259. }
  260. }()
  261. go func() {
  262. defer wg.Done()
  263. stream, err := server.AcceptStream()
  264. if err != nil {
  265. t.Fatalf("err: %v", err)
  266. }
  267. if err := stream.Close(); err != nil {
  268. t.Fatalf("err: %v", err)
  269. }
  270. }()
  271. doneCh := make(chan struct{})
  272. go func() {
  273. wg.Wait()
  274. close(doneCh)
  275. }()
  276. select {
  277. case <-doneCh:
  278. case <-time.After(time.Second):
  279. panic("timeout")
  280. }
  281. // We should have zero streams after our timeout period
  282. time.Sleep(100 * time.Millisecond)
  283. if v := server.NumStreams(); v > 0 {
  284. t.Fatalf("should have zero streams: %d", v)
  285. }
  286. if v := client.NumStreams(); v > 0 {
  287. t.Fatalf("should have zero streams: %d", v)
  288. }
  289. if _, err := clientStream.Write([]byte("hello")); err == nil {
  290. t.Fatal("should error on write")
  291. } else if err.Error() != "connection reset" {
  292. t.Fatalf("expected connection reset, got %q", err)
  293. }
  294. }
  295. func TestNonNilInterface(t *testing.T) {
  296. _, server := testClientServer()
  297. server.Close()
  298. conn, err := server.Accept()
  299. if err != nil && conn != nil {
  300. t.Error("bad: accept should return a connection of nil value")
  301. }
  302. conn, err = server.Open()
  303. if err != nil && conn != nil {
  304. t.Error("bad: open should return a connection of nil value")
  305. }
  306. }
  307. func TestSendData_Small(t *testing.T) {
  308. client, server := testClientServer()
  309. defer client.Close()
  310. defer server.Close()
  311. wg := &sync.WaitGroup{}
  312. wg.Add(2)
  313. go func() {
  314. defer wg.Done()
  315. stream, err := server.AcceptStream()
  316. if err != nil {
  317. t.Fatalf("err: %v", err)
  318. }
  319. if server.NumStreams() != 1 {
  320. t.Fatalf("bad")
  321. }
  322. buf := make([]byte, 4)
  323. for i := 0; i < 1000; i++ {
  324. n, err := stream.Read(buf)
  325. if err != nil {
  326. t.Fatalf("err: %v", err)
  327. }
  328. if n != 4 {
  329. t.Fatalf("short read: %d", n)
  330. }
  331. if string(buf) != "test" {
  332. t.Fatalf("bad: %s", buf)
  333. }
  334. }
  335. if err := stream.Close(); err != nil {
  336. t.Fatalf("err: %v", err)
  337. }
  338. }()
  339. go func() {
  340. defer wg.Done()
  341. stream, err := client.Open()
  342. if err != nil {
  343. t.Fatalf("err: %v", err)
  344. }
  345. if client.NumStreams() != 1 {
  346. t.Fatalf("bad")
  347. }
  348. for i := 0; i < 1000; i++ {
  349. n, err := stream.Write([]byte("test"))
  350. if err != nil {
  351. t.Fatalf("err: %v", err)
  352. }
  353. if n != 4 {
  354. t.Fatalf("short write %d", n)
  355. }
  356. }
  357. if err := stream.Close(); err != nil {
  358. t.Fatalf("err: %v", err)
  359. }
  360. }()
  361. doneCh := make(chan struct{})
  362. go func() {
  363. wg.Wait()
  364. close(doneCh)
  365. }()
  366. select {
  367. case <-doneCh:
  368. case <-time.After(time.Second):
  369. panic("timeout")
  370. }
  371. if client.NumStreams() != 0 {
  372. t.Fatalf("bad")
  373. }
  374. if server.NumStreams() != 0 {
  375. t.Fatalf("bad")
  376. }
  377. }
  378. func TestSendData_Large(t *testing.T) {
  379. client, server := testClientServer()
  380. defer client.Close()
  381. defer server.Close()
  382. const (
  383. sendSize = 250 * 1024 * 1024
  384. recvSize = 4 * 1024
  385. )
  386. data := make([]byte, sendSize)
  387. for idx := range data {
  388. data[idx] = byte(idx % 256)
  389. }
  390. wg := &sync.WaitGroup{}
  391. wg.Add(2)
  392. go func() {
  393. defer wg.Done()
  394. stream, err := server.AcceptStream()
  395. if err != nil {
  396. t.Fatalf("err: %v", err)
  397. }
  398. var sz int
  399. buf := make([]byte, recvSize)
  400. for i := 0; i < sendSize/recvSize; i++ {
  401. n, err := stream.Read(buf)
  402. if err != nil {
  403. t.Fatalf("err: %v", err)
  404. }
  405. if n != recvSize {
  406. t.Fatalf("short read: %d", n)
  407. }
  408. sz += n
  409. for idx := range buf {
  410. if buf[idx] != byte(idx%256) {
  411. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  412. }
  413. }
  414. }
  415. if err := stream.Close(); err != nil {
  416. t.Fatalf("err: %v", err)
  417. }
  418. t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
  419. }()
  420. go func() {
  421. defer wg.Done()
  422. stream, err := client.Open()
  423. if err != nil {
  424. t.Fatalf("err: %v", err)
  425. }
  426. n, err := stream.Write(data)
  427. if err != nil {
  428. t.Fatalf("err: %v", err)
  429. }
  430. if n != len(data) {
  431. t.Fatalf("short write %d", n)
  432. }
  433. if err := stream.Close(); err != nil {
  434. t.Fatalf("err: %v", err)
  435. }
  436. }()
  437. doneCh := make(chan struct{})
  438. go func() {
  439. wg.Wait()
  440. close(doneCh)
  441. }()
  442. select {
  443. case <-doneCh:
  444. case <-time.After(5 * time.Second):
  445. panic("timeout")
  446. }
  447. }
  448. func TestGoAway(t *testing.T) {
  449. client, server := testClientServer()
  450. defer client.Close()
  451. defer server.Close()
  452. if err := server.GoAway(); err != nil {
  453. t.Fatalf("err: %v", err)
  454. }
  455. _, err := client.Open()
  456. if err != ErrRemoteGoAway {
  457. t.Fatalf("err: %v", err)
  458. }
  459. }
  460. func TestManyStreams(t *testing.T) {
  461. client, server := testClientServer()
  462. defer client.Close()
  463. defer server.Close()
  464. wg := &sync.WaitGroup{}
  465. acceptor := func(i int) {
  466. defer wg.Done()
  467. stream, err := server.AcceptStream()
  468. if err != nil {
  469. t.Fatalf("err: %v", err)
  470. }
  471. defer stream.Close()
  472. buf := make([]byte, 512)
  473. for {
  474. n, err := stream.Read(buf)
  475. if err == io.EOF {
  476. return
  477. }
  478. if err != nil {
  479. t.Fatalf("err: %v", err)
  480. }
  481. if n == 0 {
  482. t.Fatalf("err: %v", err)
  483. }
  484. }
  485. }
  486. sender := func(i int) {
  487. defer wg.Done()
  488. stream, err := client.Open()
  489. if err != nil {
  490. t.Fatalf("err: %v", err)
  491. }
  492. defer stream.Close()
  493. msg := fmt.Sprintf("%08d", i)
  494. for i := 0; i < 1000; i++ {
  495. n, err := stream.Write([]byte(msg))
  496. if err != nil {
  497. t.Fatalf("err: %v", err)
  498. }
  499. if n != len(msg) {
  500. t.Fatalf("short write %d", n)
  501. }
  502. }
  503. }
  504. for i := 0; i < 50; i++ {
  505. wg.Add(2)
  506. go acceptor(i)
  507. go sender(i)
  508. }
  509. wg.Wait()
  510. }
  511. func TestManyStreams_PingPong(t *testing.T) {
  512. client, server := testClientServer()
  513. defer client.Close()
  514. defer server.Close()
  515. wg := &sync.WaitGroup{}
  516. ping := []byte("ping")
  517. pong := []byte("pong")
  518. acceptor := func(i int) {
  519. defer wg.Done()
  520. stream, err := server.AcceptStream()
  521. if err != nil {
  522. t.Fatalf("err: %v", err)
  523. }
  524. defer stream.Close()
  525. buf := make([]byte, 4)
  526. for {
  527. // Read the 'ping'
  528. n, err := stream.Read(buf)
  529. if err == io.EOF {
  530. return
  531. }
  532. if err != nil {
  533. t.Fatalf("err: %v", err)
  534. }
  535. if n != 4 {
  536. t.Fatalf("err: %v", err)
  537. }
  538. if !bytes.Equal(buf, ping) {
  539. t.Fatalf("bad: %s", buf)
  540. }
  541. // Shrink the internal buffer!
  542. stream.Shrink()
  543. // Write out the 'pong'
  544. n, err = stream.Write(pong)
  545. if err != nil {
  546. t.Fatalf("err: %v", err)
  547. }
  548. if n != 4 {
  549. t.Fatalf("err: %v", err)
  550. }
  551. }
  552. }
  553. sender := func(i int) {
  554. defer wg.Done()
  555. stream, err := client.OpenStream()
  556. if err != nil {
  557. t.Fatalf("err: %v", err)
  558. }
  559. defer stream.Close()
  560. buf := make([]byte, 4)
  561. for i := 0; i < 1000; i++ {
  562. // Send the 'ping'
  563. n, err := stream.Write(ping)
  564. if err != nil {
  565. t.Fatalf("err: %v", err)
  566. }
  567. if n != 4 {
  568. t.Fatalf("short write %d", n)
  569. }
  570. // Read the 'pong'
  571. n, err = stream.Read(buf)
  572. if err != nil {
  573. t.Fatalf("err: %v", err)
  574. }
  575. if n != 4 {
  576. t.Fatalf("err: %v", err)
  577. }
  578. if !bytes.Equal(buf, pong) {
  579. t.Fatalf("bad: %s", buf)
  580. }
  581. // Shrink the buffer
  582. stream.Shrink()
  583. }
  584. }
  585. for i := 0; i < 50; i++ {
  586. wg.Add(2)
  587. go acceptor(i)
  588. go sender(i)
  589. }
  590. wg.Wait()
  591. }
  592. func TestHalfClose(t *testing.T) {
  593. client, server := testClientServer()
  594. defer client.Close()
  595. defer server.Close()
  596. stream, err := client.Open()
  597. if err != nil {
  598. t.Fatalf("err: %v", err)
  599. }
  600. if _, err = stream.Write([]byte("a")); err != nil {
  601. t.Fatalf("err: %v", err)
  602. }
  603. stream2, err := server.Accept()
  604. if err != nil {
  605. t.Fatalf("err: %v", err)
  606. }
  607. stream2.Close() // Half close
  608. buf := make([]byte, 4)
  609. n, err := stream2.Read(buf)
  610. if err != nil {
  611. t.Fatalf("err: %v", err)
  612. }
  613. if n != 1 {
  614. t.Fatalf("bad: %v", n)
  615. }
  616. // Send more
  617. if _, err = stream.Write([]byte("bcd")); err != nil {
  618. t.Fatalf("err: %v", err)
  619. }
  620. stream.Close()
  621. // Read after close
  622. n, err = stream2.Read(buf)
  623. if err != nil {
  624. t.Fatalf("err: %v", err)
  625. }
  626. if n != 3 {
  627. t.Fatalf("bad: %v", n)
  628. }
  629. // EOF after close
  630. n, err = stream2.Read(buf)
  631. if err != io.EOF {
  632. t.Fatalf("err: %v", err)
  633. }
  634. if n != 0 {
  635. t.Fatalf("bad: %v", n)
  636. }
  637. }
  638. func TestReadDeadline(t *testing.T) {
  639. client, server := testClientServer()
  640. defer client.Close()
  641. defer server.Close()
  642. stream, err := client.Open()
  643. if err != nil {
  644. t.Fatalf("err: %v", err)
  645. }
  646. defer stream.Close()
  647. stream2, err := server.Accept()
  648. if err != nil {
  649. t.Fatalf("err: %v", err)
  650. }
  651. defer stream2.Close()
  652. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  653. t.Fatalf("err: %v", err)
  654. }
  655. buf := make([]byte, 4)
  656. if _, err := stream.Read(buf); err != ErrTimeout {
  657. t.Fatalf("err: %v", err)
  658. }
  659. }
  660. func TestReadDeadline_BlockedRead(t *testing.T) {
  661. client, server := testClientServer()
  662. defer client.Close()
  663. defer server.Close()
  664. stream, err := client.Open()
  665. if err != nil {
  666. t.Fatalf("err: %v", err)
  667. }
  668. defer stream.Close()
  669. stream2, err := server.Accept()
  670. if err != nil {
  671. t.Fatalf("err: %v", err)
  672. }
  673. defer stream2.Close()
  674. // Start a read that will block
  675. errCh := make(chan error, 1)
  676. go func() {
  677. buf := make([]byte, 4)
  678. _, err := stream.Read(buf)
  679. errCh <- err
  680. close(errCh)
  681. }()
  682. // Wait to ensure the read has started.
  683. time.Sleep(5 * time.Millisecond)
  684. // Update the read deadline
  685. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  686. t.Fatalf("err: %v", err)
  687. }
  688. select {
  689. case <-time.After(100 * time.Millisecond):
  690. t.Fatal("expected read timeout")
  691. case err := <-errCh:
  692. if err != ErrTimeout {
  693. t.Fatalf("expected ErrTimeout; got %v", err)
  694. }
  695. }
  696. }
  697. func TestWriteDeadline(t *testing.T) {
  698. client, server := testClientServer()
  699. defer client.Close()
  700. defer server.Close()
  701. stream, err := client.Open()
  702. if err != nil {
  703. t.Fatalf("err: %v", err)
  704. }
  705. defer stream.Close()
  706. stream2, err := server.Accept()
  707. if err != nil {
  708. t.Fatalf("err: %v", err)
  709. }
  710. defer stream2.Close()
  711. if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
  712. t.Fatalf("err: %v", err)
  713. }
  714. buf := make([]byte, 512)
  715. for i := 0; i < int(initialStreamWindow); i++ {
  716. _, err := stream.Write(buf)
  717. if err != nil && err == ErrTimeout {
  718. return
  719. } else if err != nil {
  720. t.Fatalf("err: %v", err)
  721. }
  722. }
  723. t.Fatalf("Expected timeout")
  724. }
  725. func TestWriteDeadline_BlockedWrite(t *testing.T) {
  726. client, server := testClientServer()
  727. defer client.Close()
  728. defer server.Close()
  729. stream, err := client.Open()
  730. if err != nil {
  731. t.Fatalf("err: %v", err)
  732. }
  733. defer stream.Close()
  734. stream2, err := server.Accept()
  735. if err != nil {
  736. t.Fatalf("err: %v", err)
  737. }
  738. defer stream2.Close()
  739. // Start a goroutine making writes that will block
  740. errCh := make(chan error, 1)
  741. go func() {
  742. buf := make([]byte, 512)
  743. for i := 0; i < int(initialStreamWindow); i++ {
  744. _, err := stream.Write(buf)
  745. if err == nil {
  746. continue
  747. }
  748. errCh <- err
  749. close(errCh)
  750. return
  751. }
  752. close(errCh)
  753. }()
  754. // Wait to ensure the write has started.
  755. time.Sleep(5 * time.Millisecond)
  756. // Update the write deadline
  757. if err := stream.SetWriteDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  758. t.Fatalf("err: %v", err)
  759. }
  760. select {
  761. case <-time.After(1 * time.Second):
  762. t.Fatal("expected write timeout")
  763. case err := <-errCh:
  764. if err != ErrTimeout {
  765. t.Fatalf("expected ErrTimeout; got %v", err)
  766. }
  767. }
  768. }
  769. func TestBacklogExceeded(t *testing.T) {
  770. client, server := testClientServer()
  771. defer client.Close()
  772. defer server.Close()
  773. // Fill the backlog
  774. max := client.config.AcceptBacklog
  775. for i := 0; i < max; i++ {
  776. stream, err := client.Open()
  777. if err != nil {
  778. t.Fatalf("err: %v", err)
  779. }
  780. defer stream.Close()
  781. if _, err := stream.Write([]byte("foo")); err != nil {
  782. t.Fatalf("err: %v", err)
  783. }
  784. }
  785. // Attempt to open a new stream
  786. errCh := make(chan error, 1)
  787. go func() {
  788. _, err := client.Open()
  789. errCh <- err
  790. }()
  791. // Shutdown the server
  792. go func() {
  793. time.Sleep(10 * time.Millisecond)
  794. server.Close()
  795. }()
  796. select {
  797. case err := <-errCh:
  798. if err == nil {
  799. t.Fatalf("open should fail")
  800. }
  801. case <-time.After(time.Second):
  802. t.Fatalf("timeout")
  803. }
  804. }
  805. func TestKeepAlive(t *testing.T) {
  806. client, server := testClientServer()
  807. defer client.Close()
  808. defer server.Close()
  809. time.Sleep(200 * time.Millisecond)
  810. // Ping value should increase
  811. client.pingLock.Lock()
  812. defer client.pingLock.Unlock()
  813. if client.pingID == 0 {
  814. t.Fatalf("should ping")
  815. }
  816. server.pingLock.Lock()
  817. defer server.pingLock.Unlock()
  818. if server.pingID == 0 {
  819. t.Fatalf("should ping")
  820. }
  821. }
  822. func TestKeepAlive_Timeout(t *testing.T) {
  823. conn1, conn2 := testConn()
  824. clientConf := testConf()
  825. clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
  826. clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
  827. client, _ := Client(conn1, clientConf)
  828. defer client.Close()
  829. server, _ := Server(conn2, testConf())
  830. defer server.Close()
  831. _ = captureLogs(client) // Client logs aren't part of the test
  832. serverLogs := captureLogs(server)
  833. errCh := make(chan error, 1)
  834. go func() {
  835. _, err := server.Accept() // Wait until server closes
  836. errCh <- err
  837. }()
  838. // Prevent the client from responding
  839. clientConn := client.conn.(*pipeConn)
  840. clientConn.writeBlocker.Lock()
  841. select {
  842. case err := <-errCh:
  843. if err != ErrKeepAliveTimeout {
  844. t.Fatalf("unexpected error: %v", err)
  845. }
  846. case <-time.After(1 * time.Second):
  847. t.Fatalf("timeout waiting for timeout")
  848. }
  849. if !server.IsClosed() {
  850. t.Fatalf("server should have closed")
  851. }
  852. if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
  853. t.Fatalf("server log incorect: %v", serverLogs.logs())
  854. }
  855. }
  856. func TestLargeWindow(t *testing.T) {
  857. conf := DefaultConfig()
  858. conf.MaxStreamWindowSize *= 2
  859. client, server := testClientServerConfig(conf)
  860. defer client.Close()
  861. defer server.Close()
  862. stream, err := client.Open()
  863. if err != nil {
  864. t.Fatalf("err: %v", err)
  865. }
  866. defer stream.Close()
  867. stream2, err := server.Accept()
  868. if err != nil {
  869. t.Fatalf("err: %v", err)
  870. }
  871. defer stream2.Close()
  872. stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
  873. buf := make([]byte, conf.MaxStreamWindowSize)
  874. n, err := stream.Write(buf)
  875. if err != nil {
  876. t.Fatalf("err: %v", err)
  877. }
  878. if n != len(buf) {
  879. t.Fatalf("short write: %d", n)
  880. }
  881. }
  // UnlimitedReader is an endless io.Reader. Each Read first yields the
// processor, then reports the whole buffer as filled without touching its
// contents and without ever returning an error.
type UnlimitedReader struct{}

// Read yields to the scheduler and then claims len(p) bytes were read.
func (u *UnlimitedReader) Read(p []byte) (int, error) {
	runtime.Gosched()
	n := len(p)
	return n, nil
}
  887. func TestSendData_VeryLarge(t *testing.T) {
  888. client, server := testClientServer()
  889. defer client.Close()
  890. defer server.Close()
  891. var n int64 = 1 * 1024 * 1024 * 1024
  892. var workers int = 16
  893. wg := &sync.WaitGroup{}
  894. wg.Add(workers * 2)
  895. for i := 0; i < workers; i++ {
  896. go func() {
  897. defer wg.Done()
  898. stream, err := server.AcceptStream()
  899. if err != nil {
  900. t.Fatalf("err: %v", err)
  901. }
  902. defer stream.Close()
  903. buf := make([]byte, 4)
  904. _, err = stream.Read(buf)
  905. if err != nil {
  906. t.Fatalf("err: %v", err)
  907. }
  908. if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
  909. t.Fatalf("bad header")
  910. }
  911. recv, err := io.Copy(ioutil.Discard, stream)
  912. if err != nil {
  913. t.Fatalf("err: %v", err)
  914. }
  915. if recv != n {
  916. t.Fatalf("bad: %v", recv)
  917. }
  918. }()
  919. }
  920. for i := 0; i < workers; i++ {
  921. go func() {
  922. defer wg.Done()
  923. stream, err := client.Open()
  924. if err != nil {
  925. t.Fatalf("err: %v", err)
  926. }
  927. defer stream.Close()
  928. _, err = stream.Write([]byte{0, 1, 2, 3})
  929. if err != nil {
  930. t.Fatalf("err: %v", err)
  931. }
  932. unlimited := &UnlimitedReader{}
  933. sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
  934. if err != nil {
  935. t.Fatalf("err: %v", err)
  936. }
  937. if sent != n {
  938. t.Fatalf("bad: %v", sent)
  939. }
  940. }()
  941. }
  942. doneCh := make(chan struct{})
  943. go func() {
  944. wg.Wait()
  945. close(doneCh)
  946. }()
  947. select {
  948. case <-doneCh:
  949. case <-time.After(20 * time.Second):
  950. panic("timeout")
  951. }
  952. }
  953. func TestBacklogExceeded_Accept(t *testing.T) {
  954. client, server := testClientServer()
  955. defer client.Close()
  956. defer server.Close()
  957. max := 5 * client.config.AcceptBacklog
  958. go func() {
  959. for i := 0; i < max; i++ {
  960. stream, err := server.Accept()
  961. if err != nil {
  962. t.Fatalf("err: %v", err)
  963. }
  964. defer stream.Close()
  965. }
  966. }()
  967. // Fill the backlog
  968. for i := 0; i < max; i++ {
  969. stream, err := client.Open()
  970. if err != nil {
  971. t.Fatalf("err: %v", err)
  972. }
  973. defer stream.Close()
  974. if _, err := stream.Write([]byte("foo")); err != nil {
  975. t.Fatalf("err: %v", err)
  976. }
  977. }
  978. }
  979. func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
  980. client, server := testClientServerConfig(testConfNoKeepAlive())
  981. defer client.Close()
  982. defer server.Close()
  983. var wg sync.WaitGroup
  984. wg.Add(2)
  985. // Choose a huge flood size that we know will result in a window update.
  986. flood := int64(client.config.MaxStreamWindowSize) - 1
  987. // The server will accept a new stream and then flood data to it.
  988. go func() {
  989. defer wg.Done()
  990. stream, err := server.AcceptStream()
  991. if err != nil {
  992. t.Fatalf("err: %v", err)
  993. }
  994. defer stream.Close()
  995. n, err := stream.Write(make([]byte, flood))
  996. if err != nil {
  997. t.Fatalf("err: %v", err)
  998. }
  999. if int64(n) != flood {
  1000. t.Fatalf("short write: %d", n)
  1001. }
  1002. }()
  1003. // The client will open a stream, block outbound writes, and then
  1004. // listen to the flood from the server, which should time out since
  1005. // it won't be able to send the window update.
  1006. go func() {
  1007. defer wg.Done()
  1008. stream, err := client.OpenStream()
  1009. if err != nil {
  1010. t.Fatalf("err: %v", err)
  1011. }
  1012. defer stream.Close()
  1013. conn := client.conn.(*pipeConn)
  1014. conn.writeBlocker.Lock()
  1015. _, err = stream.Read(make([]byte, flood))
  1016. if err != ErrConnectionWriteTimeout {
  1017. t.Fatalf("err: %v", err)
  1018. }
  1019. }()
  1020. wg.Wait()
  1021. }
  1022. func TestSession_PartialReadWindowUpdate(t *testing.T) {
  1023. client, server := testClientServerConfig(testConfNoKeepAlive())
  1024. defer client.Close()
  1025. defer server.Close()
  1026. var wg sync.WaitGroup
  1027. wg.Add(1)
  1028. // Choose a huge flood size that we know will result in a window update.
  1029. flood := int64(client.config.MaxStreamWindowSize)
  1030. var wr *Stream
  1031. // The server will accept a new stream and then flood data to it.
  1032. go func() {
  1033. defer wg.Done()
  1034. var err error
  1035. wr, err = server.AcceptStream()
  1036. if err != nil {
  1037. t.Fatalf("err: %v", err)
  1038. }
  1039. defer wr.Close()
  1040. if wr.sendWindow != client.config.MaxStreamWindowSize {
  1041. t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
  1042. }
  1043. n, err := wr.Write(make([]byte, flood))
  1044. if err != nil {
  1045. t.Fatalf("err: %v", err)
  1046. }
  1047. if int64(n) != flood {
  1048. t.Fatalf("short write: %d", n)
  1049. }
  1050. if wr.sendWindow != 0 {
  1051. t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
  1052. }
  1053. }()
  1054. stream, err := client.OpenStream()
  1055. if err != nil {
  1056. t.Fatalf("err: %v", err)
  1057. }
  1058. defer stream.Close()
  1059. wg.Wait()
  1060. _, err = stream.Read(make([]byte, flood/2+1))
  1061. if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
  1062. t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
  1063. }
  1064. }
  1065. func TestSession_sendNoWait_Timeout(t *testing.T) {
  1066. client, server := testClientServerConfig(testConfNoKeepAlive())
  1067. defer client.Close()
  1068. defer server.Close()
  1069. var wg sync.WaitGroup
  1070. wg.Add(2)
  1071. go func() {
  1072. defer wg.Done()
  1073. stream, err := server.AcceptStream()
  1074. if err != nil {
  1075. t.Fatalf("err: %v", err)
  1076. }
  1077. defer stream.Close()
  1078. }()
  1079. // The client will open the stream and then block outbound writes, we'll
  1080. // probe sendNoWait once it gets into that state.
  1081. go func() {
  1082. defer wg.Done()
  1083. stream, err := client.OpenStream()
  1084. if err != nil {
  1085. t.Fatalf("err: %v", err)
  1086. }
  1087. defer stream.Close()
  1088. conn := client.conn.(*pipeConn)
  1089. conn.writeBlocker.Lock()
  1090. hdr := header(make([]byte, headerSize))
  1091. hdr.encode(typePing, flagACK, 0, 0)
  1092. for {
  1093. err = client.sendNoWait(hdr)
  1094. if err == nil {
  1095. continue
  1096. } else if err == ErrConnectionWriteTimeout {
  1097. break
  1098. } else {
  1099. t.Fatalf("err: %v", err)
  1100. }
  1101. }
  1102. }()
  1103. wg.Wait()
  1104. }
  1105. func TestSession_PingOfDeath(t *testing.T) {
  1106. client, server := testClientServerConfig(testConfNoKeepAlive())
  1107. defer client.Close()
  1108. defer server.Close()
  1109. var wg sync.WaitGroup
  1110. wg.Add(2)
  1111. var doPingOfDeath sync.Mutex
  1112. doPingOfDeath.Lock()
  1113. // This is used later to block outbound writes.
  1114. conn := server.conn.(*pipeConn)
  1115. // The server will accept a stream, block outbound writes, and then
  1116. // flood its send channel so that no more headers can be queued.
  1117. go func() {
  1118. defer wg.Done()
  1119. stream, err := server.AcceptStream()
  1120. if err != nil {
  1121. t.Fatalf("err: %v", err)
  1122. }
  1123. defer stream.Close()
  1124. conn.writeBlocker.Lock()
  1125. for {
  1126. hdr := header(make([]byte, headerSize))
  1127. hdr.encode(typePing, 0, 0, 0)
  1128. err = server.sendNoWait(hdr)
  1129. if err == nil {
  1130. continue
  1131. } else if err == ErrConnectionWriteTimeout {
  1132. break
  1133. } else {
  1134. t.Fatalf("err: %v", err)
  1135. }
  1136. }
  1137. doPingOfDeath.Unlock()
  1138. }()
  1139. // The client will open a stream and then send the server a ping once it
  1140. // can no longer write. This makes sure the server doesn't deadlock reads
  1141. // while trying to reply to the ping with no ability to write.
  1142. go func() {
  1143. defer wg.Done()
  1144. stream, err := client.OpenStream()
  1145. if err != nil {
  1146. t.Fatalf("err: %v", err)
  1147. }
  1148. defer stream.Close()
  1149. // This ping will never unblock because the ping id will never
  1150. // show up in a response.
  1151. doPingOfDeath.Lock()
  1152. go func() { client.Ping() }()
  1153. // Wait for a while to make sure the previous ping times out,
  1154. // then turn writes back on and make sure a ping works again.
  1155. time.Sleep(2 * server.config.ConnectionWriteTimeout)
  1156. conn.writeBlocker.Unlock()
  1157. if _, err = client.Ping(); err != nil {
  1158. t.Fatalf("err: %v", err)
  1159. }
  1160. }()
  1161. wg.Wait()
  1162. }
  1163. func TestSession_ConnectionWriteTimeout(t *testing.T) {
  1164. client, server := testClientServerConfig(testConfNoKeepAlive())
  1165. defer client.Close()
  1166. defer server.Close()
  1167. var wg sync.WaitGroup
  1168. wg.Add(2)
  1169. go func() {
  1170. defer wg.Done()
  1171. stream, err := server.AcceptStream()
  1172. if err != nil {
  1173. t.Fatalf("err: %v", err)
  1174. }
  1175. defer stream.Close()
  1176. }()
  1177. // The client will open the stream and then block outbound writes, we'll
  1178. // tee up a write and make sure it eventually times out.
  1179. go func() {
  1180. defer wg.Done()
  1181. stream, err := client.OpenStream()
  1182. if err != nil {
  1183. t.Fatalf("err: %v", err)
  1184. }
  1185. defer stream.Close()
  1186. conn := client.conn.(*pipeConn)
  1187. conn.writeBlocker.Lock()
  1188. // Since the write goroutine is blocked then this will return a
  1189. // timeout since it can't get feedback about whether the write
  1190. // worked.
  1191. n, err := stream.Write([]byte("hello"))
  1192. if err != ErrConnectionWriteTimeout {
  1193. t.Fatalf("err: %v", err)
  1194. }
  1195. if n != 0 {
  1196. t.Fatalf("lied about writes: %d", n)
  1197. }
  1198. }()
  1199. wg.Wait()
  1200. }