session_test.go 30 KB


  1. package yamux
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "log"
  8. "net"
  9. "reflect"
  10. "runtime"
  11. "strings"
  12. "sync"
  13. "testing"
  14. "time"
  15. )
  // logCapture collects everything a log.Logger writes so that tests can
// compare the exact lines a session logged.
type logCapture struct{ bytes.Buffer }

// logs returns the captured output as lines. Leading and trailing whitespace
// (including the final newline) is trimmed before splitting, so an empty
// capture yields a single empty string.
func (c *logCapture) logs() []string {
	trimmed := strings.TrimSpace(c.String())
	return strings.Split(trimmed, "\n")
}

// match reports whether the captured lines are deeply equal to expect.
func (c *logCapture) match(expect []string) bool {
	got := c.logs()
	return reflect.DeepEqual(got, expect)
}
  23. func captureLogs(s *Session) *logCapture {
  24. buf := new(logCapture)
  25. s.logger = log.New(buf, "", 0)
  26. return buf
  27. }
  // pipeConn is an in-memory io.ReadWriteCloser made of two io.Pipe halves:
// it reads from one pipe and writes into another. Every Write holds
// writeBlocker, so a test can lock writeBlocker to stall all outbound
// traffic on this end of the connection.
type pipeConn struct {
	reader       *io.PipeReader
	writer       *io.PipeWriter
	writeBlocker sync.Mutex
}

// Read reads from the inbound pipe.
func (p *pipeConn) Read(b []byte) (int, error) {
	return p.reader.Read(b)
}

// Write writes to the outbound pipe once writeBlocker can be acquired.
func (p *pipeConn) Write(b []byte) (int, error) {
	p.writeBlocker.Lock()
	defer p.writeBlocker.Unlock()
	return p.writer.Write(b)
}

// Close closes both halves. The read side's close result is discarded; the
// write side's close error is returned.
func (p *pipeConn) Close() error {
	_ = p.reader.Close()
	return p.writer.Close()
}

// testConn returns two cross-wired pipeConns: whatever one end writes, the
// other end reads, in both directions.
func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
	r1, w1 := io.Pipe()
	r2, w2 := io.Pipe()
	return &pipeConn{reader: r1, writer: w2}, &pipeConn{reader: r2, writer: w1}
}
  52. func testConf() *Config {
  53. conf := DefaultConfig()
  54. conf.AcceptBacklog = 64
  55. conf.KeepAliveInterval = 100 * time.Millisecond
  56. conf.ConnectionWriteTimeout = 250 * time.Millisecond
  57. return conf
  58. }
  59. func testConfNoKeepAlive() *Config {
  60. conf := testConf()
  61. conf.EnableKeepAlive = false
  62. return conf
  63. }
  64. func testClientServer() (*Session, *Session) {
  65. return testClientServerConfig(testConf())
  66. }
  67. func testClientServerConfig(conf *Config) (*Session, *Session) {
  68. conn1, conn2 := testConn()
  69. client, _ := Client(conn1, conf)
  70. server, _ := Server(conn2, conf)
  71. return client, server
  72. }
  73. func TestPing(t *testing.T) {
  74. client, server := testClientServer()
  75. defer client.Close()
  76. defer server.Close()
  77. rtt, err := client.Ping()
  78. if err != nil {
  79. t.Fatalf("err: %v", err)
  80. }
  81. if rtt == 0 {
  82. t.Fatalf("bad: %v", rtt)
  83. }
  84. rtt, err = server.Ping()
  85. if err != nil {
  86. t.Fatalf("err: %v", err)
  87. }
  88. if rtt == 0 {
  89. t.Fatalf("bad: %v", rtt)
  90. }
  91. }
  92. func TestPing_Timeout(t *testing.T) {
  93. client, server := testClientServerConfig(testConfNoKeepAlive())
  94. defer client.Close()
  95. defer server.Close()
  96. // Prevent the client from responding
  97. clientConn := client.conn.(*pipeConn)
  98. clientConn.writeBlocker.Lock()
  99. errCh := make(chan error, 1)
  100. go func() {
  101. _, err := server.Ping() // Ping via the server session
  102. errCh <- err
  103. }()
  104. select {
  105. case err := <-errCh:
  106. if err != ErrTimeout {
  107. t.Fatalf("err: %v", err)
  108. }
  109. case <-time.After(client.config.ConnectionWriteTimeout * 2):
  110. t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
  111. }
  112. // Verify that we recover, even if we gave up
  113. clientConn.writeBlocker.Unlock()
  114. go func() {
  115. _, err := server.Ping() // Ping via the server session
  116. errCh <- err
  117. }()
  118. select {
  119. case err := <-errCh:
  120. if err != nil {
  121. t.Fatalf("err: %v", err)
  122. }
  123. case <-time.After(client.config.ConnectionWriteTimeout):
  124. t.Fatalf("timeout")
  125. }
  126. }
  127. func TestCloseBeforeAck(t *testing.T) {
  128. cfg := testConf()
  129. cfg.AcceptBacklog = 8
  130. client, server := testClientServerConfig(cfg)
  131. defer client.Close()
  132. defer server.Close()
  133. for i := 0; i < 8; i++ {
  134. s, err := client.OpenStream()
  135. if err != nil {
  136. t.Fatal(err)
  137. }
  138. s.Close()
  139. }
  140. for i := 0; i < 8; i++ {
  141. s, err := server.AcceptStream()
  142. if err != nil {
  143. t.Fatal(err)
  144. }
  145. s.Close()
  146. }
  147. done := make(chan struct{})
  148. go func() {
  149. defer close(done)
  150. s, err := client.OpenStream()
  151. if err != nil {
  152. t.Fatal(err)
  153. }
  154. s.Close()
  155. }()
  156. select {
  157. case <-done:
  158. case <-time.After(time.Second * 5):
  159. t.Fatal("timed out trying to open stream")
  160. }
  161. }
  162. func TestAccept(t *testing.T) {
  163. client, server := testClientServer()
  164. defer client.Close()
  165. defer server.Close()
  166. if client.NumStreams() != 0 {
  167. t.Fatalf("bad")
  168. }
  169. if server.NumStreams() != 0 {
  170. t.Fatalf("bad")
  171. }
  172. wg := &sync.WaitGroup{}
  173. wg.Add(4)
  174. go func() {
  175. defer wg.Done()
  176. stream, err := server.AcceptStream()
  177. if err != nil {
  178. t.Fatalf("err: %v", err)
  179. }
  180. if id := stream.StreamID(); id != 1 {
  181. t.Fatalf("bad: %v", id)
  182. }
  183. if err := stream.Close(); err != nil {
  184. t.Fatalf("err: %v", err)
  185. }
  186. }()
  187. go func() {
  188. defer wg.Done()
  189. stream, err := client.AcceptStream()
  190. if err != nil {
  191. t.Fatalf("err: %v", err)
  192. }
  193. if id := stream.StreamID(); id != 2 {
  194. t.Fatalf("bad: %v", id)
  195. }
  196. if err := stream.Close(); err != nil {
  197. t.Fatalf("err: %v", err)
  198. }
  199. }()
  200. go func() {
  201. defer wg.Done()
  202. stream, err := server.OpenStream()
  203. if err != nil {
  204. t.Fatalf("err: %v", err)
  205. }
  206. if id := stream.StreamID(); id != 2 {
  207. t.Fatalf("bad: %v", id)
  208. }
  209. if err := stream.Close(); err != nil {
  210. t.Fatalf("err: %v", err)
  211. }
  212. }()
  213. go func() {
  214. defer wg.Done()
  215. stream, err := client.OpenStream()
  216. if err != nil {
  217. t.Fatalf("err: %v", err)
  218. }
  219. if id := stream.StreamID(); id != 1 {
  220. t.Fatalf("bad: %v", id)
  221. }
  222. if err := stream.Close(); err != nil {
  223. t.Fatalf("err: %v", err)
  224. }
  225. }()
  226. doneCh := make(chan struct{})
  227. go func() {
  228. wg.Wait()
  229. close(doneCh)
  230. }()
  231. select {
  232. case <-doneCh:
  233. case <-time.After(time.Second):
  234. panic("timeout")
  235. }
  236. }
  237. func TestOpenStreamTimeout(t *testing.T) {
  238. const timeout = 25 * time.Millisecond
  239. cfg := testConf()
  240. cfg.StreamOpenTimeout = timeout
  241. client, server := testClientServerConfig(cfg)
  242. defer client.Close()
  243. defer server.Close()
  244. clientLogs := captureLogs(client)
  245. // Open a single stream without a server to acknowledge it.
  246. s, err := client.OpenStream()
  247. if err != nil {
  248. t.Fatal(err)
  249. }
  250. // Sleep for longer than the stream open timeout.
  251. // Since no ACKs are received, the stream and session should be closed.
  252. time.Sleep(timeout * 5)
  253. if !clientLogs.match([]string{"[ERR] yamux: aborted stream open (destination=yamux:remote): i/o deadline reached"}) {
  254. t.Fatalf("server log incorect: %v", clientLogs.logs())
  255. }
  256. if s.state != streamClosed {
  257. t.Fatalf("stream should have been closed")
  258. }
  259. if !client.IsClosed() {
  260. t.Fatalf("session should have been closed")
  261. }
  262. }
  263. func TestClose_closeTimeout(t *testing.T) {
  264. conf := testConf()
  265. conf.StreamCloseTimeout = 10 * time.Millisecond
  266. client, server := testClientServerConfig(conf)
  267. defer client.Close()
  268. defer server.Close()
  269. if client.NumStreams() != 0 {
  270. t.Fatalf("bad")
  271. }
  272. if server.NumStreams() != 0 {
  273. t.Fatalf("bad")
  274. }
  275. wg := &sync.WaitGroup{}
  276. wg.Add(2)
  277. // Open a stream on the client but only close it on the server.
  278. // We want to see if the stream ever gets cleaned up on the client.
  279. var clientStream *Stream
  280. go func() {
  281. defer wg.Done()
  282. var err error
  283. clientStream, err = client.OpenStream()
  284. if err != nil {
  285. t.Fatalf("err: %v", err)
  286. }
  287. }()
  288. go func() {
  289. defer wg.Done()
  290. stream, err := server.AcceptStream()
  291. if err != nil {
  292. t.Fatalf("err: %v", err)
  293. }
  294. if err := stream.Close(); err != nil {
  295. t.Fatalf("err: %v", err)
  296. }
  297. }()
  298. doneCh := make(chan struct{})
  299. go func() {
  300. wg.Wait()
  301. close(doneCh)
  302. }()
  303. select {
  304. case <-doneCh:
  305. case <-time.After(time.Second):
  306. panic("timeout")
  307. }
  308. // We should have zero streams after our timeout period
  309. time.Sleep(100 * time.Millisecond)
  310. if v := server.NumStreams(); v > 0 {
  311. t.Fatalf("should have zero streams: %d", v)
  312. }
  313. if v := client.NumStreams(); v > 0 {
  314. t.Fatalf("should have zero streams: %d", v)
  315. }
  316. if _, err := clientStream.Write([]byte("hello")); err == nil {
  317. t.Fatal("should error on write")
  318. } else if err.Error() != "connection reset" {
  319. t.Fatalf("expected connection reset, got %q", err)
  320. }
  321. }
  322. func TestNonNilInterface(t *testing.T) {
  323. _, server := testClientServer()
  324. server.Close()
  325. conn, err := server.Accept()
  326. if err != nil && conn != nil {
  327. t.Error("bad: accept should return a connection of nil value")
  328. }
  329. conn, err = server.Open()
  330. if err != nil && conn != nil {
  331. t.Error("bad: open should return a connection of nil value")
  332. }
  333. }
  334. func TestSendData_Small(t *testing.T) {
  335. client, server := testClientServer()
  336. defer client.Close()
  337. defer server.Close()
  338. wg := &sync.WaitGroup{}
  339. wg.Add(2)
  340. go func() {
  341. defer wg.Done()
  342. stream, err := server.AcceptStream()
  343. if err != nil {
  344. t.Fatalf("err: %v", err)
  345. }
  346. if server.NumStreams() != 1 {
  347. t.Fatalf("bad")
  348. }
  349. buf := make([]byte, 4)
  350. for i := 0; i < 1000; i++ {
  351. n, err := stream.Read(buf)
  352. if err != nil {
  353. t.Fatalf("err: %v", err)
  354. }
  355. if n != 4 {
  356. t.Fatalf("short read: %d", n)
  357. }
  358. if string(buf) != "test" {
  359. t.Fatalf("bad: %s", buf)
  360. }
  361. }
  362. if err := stream.Close(); err != nil {
  363. t.Fatalf("err: %v", err)
  364. }
  365. }()
  366. go func() {
  367. defer wg.Done()
  368. stream, err := client.Open()
  369. if err != nil {
  370. t.Fatalf("err: %v", err)
  371. }
  372. if client.NumStreams() != 1 {
  373. t.Fatalf("bad")
  374. }
  375. for i := 0; i < 1000; i++ {
  376. n, err := stream.Write([]byte("test"))
  377. if err != nil {
  378. t.Fatalf("err: %v", err)
  379. }
  380. if n != 4 {
  381. t.Fatalf("short write %d", n)
  382. }
  383. }
  384. if err := stream.Close(); err != nil {
  385. t.Fatalf("err: %v", err)
  386. }
  387. }()
  388. doneCh := make(chan struct{})
  389. go func() {
  390. wg.Wait()
  391. close(doneCh)
  392. }()
  393. select {
  394. case <-doneCh:
  395. if client.NumStreams() != 0 {
  396. t.Fatalf("bad")
  397. }
  398. if server.NumStreams() != 0 {
  399. t.Fatalf("bad")
  400. }
  401. return
  402. case <-time.After(time.Second):
  403. panic("timeout")
  404. }
  405. }
  406. func TestSendData_Large(t *testing.T) {
  407. client, server := testClientServer()
  408. defer client.Close()
  409. defer server.Close()
  410. const (
  411. sendSize = 250 * 1024 * 1024
  412. recvSize = 4 * 1024
  413. )
  414. data := make([]byte, sendSize)
  415. for idx := range data {
  416. data[idx] = byte(idx % 256)
  417. }
  418. wg := &sync.WaitGroup{}
  419. wg.Add(2)
  420. go func() {
  421. defer wg.Done()
  422. stream, err := server.AcceptStream()
  423. if err != nil {
  424. t.Fatalf("err: %v", err)
  425. }
  426. var sz int
  427. buf := make([]byte, recvSize)
  428. for i := 0; i < sendSize/recvSize; i++ {
  429. n, err := stream.Read(buf)
  430. if err != nil {
  431. t.Fatalf("err: %v", err)
  432. }
  433. if n != recvSize {
  434. t.Fatalf("short read: %d", n)
  435. }
  436. sz += n
  437. for idx := range buf {
  438. if buf[idx] != byte(idx%256) {
  439. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  440. }
  441. }
  442. }
  443. if err := stream.Close(); err != nil {
  444. t.Fatalf("err: %v", err)
  445. }
  446. t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
  447. }()
  448. go func() {
  449. defer wg.Done()
  450. stream, err := client.Open()
  451. if err != nil {
  452. t.Fatalf("err: %v", err)
  453. }
  454. n, err := stream.Write(data)
  455. if err != nil {
  456. t.Fatalf("err: %v", err)
  457. }
  458. if n != len(data) {
  459. t.Fatalf("short write %d", n)
  460. }
  461. if err := stream.Close(); err != nil {
  462. t.Fatalf("err: %v", err)
  463. }
  464. }()
  465. doneCh := make(chan struct{})
  466. go func() {
  467. wg.Wait()
  468. close(doneCh)
  469. }()
  470. select {
  471. case <-doneCh:
  472. return
  473. case <-time.After(5 * time.Second):
  474. panic("timeout")
  475. }
  476. }
  477. func TestGoAway(t *testing.T) {
  478. client, server := testClientServer()
  479. defer client.Close()
  480. defer server.Close()
  481. if err := server.GoAway(); err != nil {
  482. t.Fatalf("err: %v", err)
  483. }
  484. _, err := client.Open()
  485. if err != ErrRemoteGoAway {
  486. t.Fatalf("err: %v", err)
  487. }
  488. }
  489. func TestManyStreams(t *testing.T) {
  490. client, server := testClientServer()
  491. defer client.Close()
  492. defer server.Close()
  493. wg := &sync.WaitGroup{}
  494. acceptor := func(i int) {
  495. defer wg.Done()
  496. stream, err := server.AcceptStream()
  497. if err != nil {
  498. t.Fatalf("err: %v", err)
  499. }
  500. defer stream.Close()
  501. buf := make([]byte, 512)
  502. for {
  503. n, err := stream.Read(buf)
  504. if err == io.EOF {
  505. return
  506. }
  507. if err != nil {
  508. t.Fatalf("err: %v", err)
  509. }
  510. if n == 0 {
  511. t.Fatalf("err: %v", err)
  512. }
  513. }
  514. }
  515. sender := func(i int) {
  516. defer wg.Done()
  517. stream, err := client.Open()
  518. if err != nil {
  519. t.Fatalf("err: %v", err)
  520. }
  521. defer stream.Close()
  522. msg := fmt.Sprintf("%08d", i)
  523. for i := 0; i < 1000; i++ {
  524. n, err := stream.Write([]byte(msg))
  525. if err != nil {
  526. t.Fatalf("err: %v", err)
  527. }
  528. if n != len(msg) {
  529. t.Fatalf("short write %d", n)
  530. }
  531. }
  532. }
  533. for i := 0; i < 50; i++ {
  534. wg.Add(2)
  535. go acceptor(i)
  536. go sender(i)
  537. }
  538. wg.Wait()
  539. }
  540. func TestManyStreams_PingPong(t *testing.T) {
  541. client, server := testClientServer()
  542. defer client.Close()
  543. defer server.Close()
  544. wg := &sync.WaitGroup{}
  545. ping := []byte("ping")
  546. pong := []byte("pong")
  547. acceptor := func(i int) {
  548. defer wg.Done()
  549. stream, err := server.AcceptStream()
  550. if err != nil {
  551. t.Fatalf("err: %v", err)
  552. }
  553. defer stream.Close()
  554. buf := make([]byte, 4)
  555. for {
  556. // Read the 'ping'
  557. n, err := stream.Read(buf)
  558. if err == io.EOF {
  559. return
  560. }
  561. if err != nil {
  562. t.Fatalf("err: %v", err)
  563. }
  564. if n != 4 {
  565. t.Fatalf("err: %v", err)
  566. }
  567. if !bytes.Equal(buf, ping) {
  568. t.Fatalf("bad: %s", buf)
  569. }
  570. // Shrink the internal buffer!
  571. stream.Shrink()
  572. // Write out the 'pong'
  573. n, err = stream.Write(pong)
  574. if err != nil {
  575. t.Fatalf("err: %v", err)
  576. }
  577. if n != 4 {
  578. t.Fatalf("err: %v", err)
  579. }
  580. }
  581. }
  582. sender := func(i int) {
  583. defer wg.Done()
  584. stream, err := client.OpenStream()
  585. if err != nil {
  586. t.Fatalf("err: %v", err)
  587. }
  588. defer stream.Close()
  589. buf := make([]byte, 4)
  590. for i := 0; i < 1000; i++ {
  591. // Send the 'ping'
  592. n, err := stream.Write(ping)
  593. if err != nil {
  594. t.Fatalf("err: %v", err)
  595. }
  596. if n != 4 {
  597. t.Fatalf("short write %d", n)
  598. }
  599. // Read the 'pong'
  600. n, err = stream.Read(buf)
  601. if err != nil {
  602. t.Fatalf("err: %v", err)
  603. }
  604. if n != 4 {
  605. t.Fatalf("err: %v", err)
  606. }
  607. if !bytes.Equal(buf, pong) {
  608. t.Fatalf("bad: %s", buf)
  609. }
  610. // Shrink the buffer
  611. stream.Shrink()
  612. }
  613. }
  614. for i := 0; i < 50; i++ {
  615. wg.Add(2)
  616. go acceptor(i)
  617. go sender(i)
  618. }
  619. wg.Wait()
  620. }
  621. func TestHalfClose(t *testing.T) {
  622. client, server := testClientServer()
  623. defer client.Close()
  624. defer server.Close()
  625. stream, err := client.Open()
  626. if err != nil {
  627. t.Fatalf("err: %v", err)
  628. }
  629. if _, err = stream.Write([]byte("a")); err != nil {
  630. t.Fatalf("err: %v", err)
  631. }
  632. stream2, err := server.Accept()
  633. if err != nil {
  634. t.Fatalf("err: %v", err)
  635. }
  636. stream2.Close() // Half close
  637. buf := make([]byte, 4)
  638. n, err := stream2.Read(buf)
  639. if err != nil {
  640. t.Fatalf("err: %v", err)
  641. }
  642. if n != 1 {
  643. t.Fatalf("bad: %v", n)
  644. }
  645. // Send more
  646. if _, err = stream.Write([]byte("bcd")); err != nil {
  647. t.Fatalf("err: %v", err)
  648. }
  649. stream.Close()
  650. // Read after close
  651. n, err = stream2.Read(buf)
  652. if err != nil {
  653. t.Fatalf("err: %v", err)
  654. }
  655. if n != 3 {
  656. t.Fatalf("bad: %v", n)
  657. }
  658. // EOF after close
  659. n, err = stream2.Read(buf)
  660. if err != io.EOF {
  661. t.Fatalf("err: %v", err)
  662. }
  663. if n != 0 {
  664. t.Fatalf("bad: %v", n)
  665. }
  666. }
  667. func TestHalfCloseSessionShutdown(t *testing.T) {
  668. client, server := testClientServer()
  669. defer client.Close()
  670. defer server.Close()
  671. // dataSize must be large enough to ensure the server will send a window
  672. // update
  673. dataSize := int64(server.config.MaxStreamWindowSize)
  674. data := make([]byte, dataSize)
  675. for idx := range data {
  676. data[idx] = byte(idx % 256)
  677. }
  678. stream, err := client.Open()
  679. if err != nil {
  680. t.Fatalf("err: %v", err)
  681. }
  682. if _, err = stream.Write(data); err != nil {
  683. t.Fatalf("err: %v", err)
  684. }
  685. stream2, err := server.Accept()
  686. if err != nil {
  687. t.Fatalf("err: %v", err)
  688. }
  689. if err := stream.Close(); err != nil {
  690. t.Fatalf("err: %v", err)
  691. }
  692. // Shut down the session of the sending side. This should not cause reads
  693. // to fail on the receiving side.
  694. if err := client.Close(); err != nil {
  695. t.Fatalf("err: %v", err)
  696. }
  697. buf := make([]byte, dataSize)
  698. n, err := stream2.Read(buf)
  699. if err != nil {
  700. t.Fatalf("err: %v", err)
  701. }
  702. if int64(n) != dataSize {
  703. t.Fatalf("bad: %v", n)
  704. }
  705. // EOF after close
  706. n, err = stream2.Read(buf)
  707. if err != io.EOF {
  708. t.Fatalf("err: %v", err)
  709. }
  710. if n != 0 {
  711. t.Fatalf("bad: %v", n)
  712. }
  713. }
  714. func TestReadDeadline(t *testing.T) {
  715. client, server := testClientServer()
  716. defer client.Close()
  717. defer server.Close()
  718. stream, err := client.Open()
  719. if err != nil {
  720. t.Fatalf("err: %v", err)
  721. }
  722. defer stream.Close()
  723. stream2, err := server.Accept()
  724. if err != nil {
  725. t.Fatalf("err: %v", err)
  726. }
  727. defer stream2.Close()
  728. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  729. t.Fatalf("err: %v", err)
  730. }
  731. buf := make([]byte, 4)
  732. _, err = stream.Read(buf)
  733. if err != ErrTimeout {
  734. t.Fatalf("err: %v", err)
  735. }
  736. // See https://github.com/hashicorp/yamux/issues/90
  737. // The standard library's http server package will read from connections in
  738. // the background to detect if they are alive.
  739. //
  740. // It sets a read deadline on connections and detect if the returned error
  741. // is a network timeout error which implements net.Error.
  742. //
  743. // The HTTP server will cancel all server requests if it isn't timeout error
  744. // from the connection.
  745. //
  746. // We assert that we return an error meeting the interface to avoid
  747. // accidently breaking yamux session compatability with the standard
  748. // library's http server implementation.
  749. if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
  750. t.Fatalf("reading timeout error is expected to implement net.Error and return true when calling Timeout()")
  751. }
  752. }
  753. func TestReadDeadline_BlockedRead(t *testing.T) {
  754. client, server := testClientServer()
  755. defer client.Close()
  756. defer server.Close()
  757. stream, err := client.Open()
  758. if err != nil {
  759. t.Fatalf("err: %v", err)
  760. }
  761. defer stream.Close()
  762. stream2, err := server.Accept()
  763. if err != nil {
  764. t.Fatalf("err: %v", err)
  765. }
  766. defer stream2.Close()
  767. // Start a read that will block
  768. errCh := make(chan error, 1)
  769. go func() {
  770. buf := make([]byte, 4)
  771. _, err := stream.Read(buf)
  772. errCh <- err
  773. close(errCh)
  774. }()
  775. // Wait to ensure the read has started.
  776. time.Sleep(5 * time.Millisecond)
  777. // Update the read deadline
  778. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  779. t.Fatalf("err: %v", err)
  780. }
  781. select {
  782. case <-time.After(100 * time.Millisecond):
  783. t.Fatal("expected read timeout")
  784. case err := <-errCh:
  785. if err != ErrTimeout {
  786. t.Fatalf("expected ErrTimeout; got %v", err)
  787. }
  788. }
  789. }
  790. func TestWriteDeadline(t *testing.T) {
  791. client, server := testClientServer()
  792. defer client.Close()
  793. defer server.Close()
  794. stream, err := client.Open()
  795. if err != nil {
  796. t.Fatalf("err: %v", err)
  797. }
  798. defer stream.Close()
  799. stream2, err := server.Accept()
  800. if err != nil {
  801. t.Fatalf("err: %v", err)
  802. }
  803. defer stream2.Close()
  804. if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
  805. t.Fatalf("err: %v", err)
  806. }
  807. buf := make([]byte, 512)
  808. for i := 0; i < int(initialStreamWindow); i++ {
  809. _, err := stream.Write(buf)
  810. if err != nil && err == ErrTimeout {
  811. return
  812. } else if err != nil {
  813. t.Fatalf("err: %v", err)
  814. }
  815. }
  816. t.Fatalf("Expected timeout")
  817. }
  818. func TestWriteDeadline_BlockedWrite(t *testing.T) {
  819. client, server := testClientServer()
  820. defer client.Close()
  821. defer server.Close()
  822. stream, err := client.Open()
  823. if err != nil {
  824. t.Fatalf("err: %v", err)
  825. }
  826. defer stream.Close()
  827. stream2, err := server.Accept()
  828. if err != nil {
  829. t.Fatalf("err: %v", err)
  830. }
  831. defer stream2.Close()
  832. // Start a goroutine making writes that will block
  833. errCh := make(chan error, 1)
  834. go func() {
  835. buf := make([]byte, 512)
  836. for i := 0; i < int(initialStreamWindow); i++ {
  837. _, err := stream.Write(buf)
  838. if err == nil {
  839. continue
  840. }
  841. errCh <- err
  842. close(errCh)
  843. return
  844. }
  845. close(errCh)
  846. }()
  847. // Wait to ensure the write has started.
  848. time.Sleep(5 * time.Millisecond)
  849. // Update the write deadline
  850. if err := stream.SetWriteDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  851. t.Fatalf("err: %v", err)
  852. }
  853. select {
  854. case <-time.After(1 * time.Second):
  855. t.Fatal("expected write timeout")
  856. case err := <-errCh:
  857. if err != ErrTimeout {
  858. t.Fatalf("expected ErrTimeout; got %v", err)
  859. }
  860. }
  861. }
  862. func TestBacklogExceeded(t *testing.T) {
  863. client, server := testClientServer()
  864. defer client.Close()
  865. defer server.Close()
  866. // Fill the backlog
  867. max := client.config.AcceptBacklog
  868. for i := 0; i < max; i++ {
  869. stream, err := client.Open()
  870. if err != nil {
  871. t.Fatalf("err: %v", err)
  872. }
  873. defer stream.Close()
  874. if _, err := stream.Write([]byte("foo")); err != nil {
  875. t.Fatalf("err: %v", err)
  876. }
  877. }
  878. // Attempt to open a new stream
  879. errCh := make(chan error, 1)
  880. go func() {
  881. _, err := client.Open()
  882. errCh <- err
  883. }()
  884. // Shutdown the server
  885. go func() {
  886. time.Sleep(10 * time.Millisecond)
  887. server.Close()
  888. }()
  889. select {
  890. case err := <-errCh:
  891. if err == nil {
  892. t.Fatalf("open should fail")
  893. }
  894. case <-time.After(time.Second):
  895. t.Fatalf("timeout")
  896. }
  897. }
  898. func TestKeepAlive(t *testing.T) {
  899. client, server := testClientServer()
  900. defer client.Close()
  901. defer server.Close()
  902. time.Sleep(200 * time.Millisecond)
  903. // Ping value should increase
  904. client.pingLock.Lock()
  905. defer client.pingLock.Unlock()
  906. if client.pingID == 0 {
  907. t.Fatalf("should ping")
  908. }
  909. server.pingLock.Lock()
  910. defer server.pingLock.Unlock()
  911. if server.pingID == 0 {
  912. t.Fatalf("should ping")
  913. }
  914. }
  915. func TestKeepAlive_Timeout(t *testing.T) {
  916. conn1, conn2 := testConn()
  917. clientConf := testConf()
  918. clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
  919. clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
  920. client, _ := Client(conn1, clientConf)
  921. defer client.Close()
  922. server, _ := Server(conn2, testConf())
  923. defer server.Close()
  924. _ = captureLogs(client) // Client logs aren't part of the test
  925. serverLogs := captureLogs(server)
  926. errCh := make(chan error, 1)
  927. go func() {
  928. _, err := server.Accept() // Wait until server closes
  929. errCh <- err
  930. }()
  931. // Prevent the client from responding
  932. clientConn := client.conn.(*pipeConn)
  933. clientConn.writeBlocker.Lock()
  934. select {
  935. case err := <-errCh:
  936. if err != ErrKeepAliveTimeout {
  937. t.Fatalf("unexpected error: %v", err)
  938. }
  939. case <-time.After(1 * time.Second):
  940. t.Fatalf("timeout waiting for timeout")
  941. }
  942. clientConn.writeBlocker.Unlock()
  943. if !server.IsClosed() {
  944. t.Fatalf("server should have closed")
  945. }
  946. if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
  947. t.Fatalf("server log incorect: %v", serverLogs.logs())
  948. }
  949. }
  950. func TestLargeWindow(t *testing.T) {
  951. conf := DefaultConfig()
  952. conf.MaxStreamWindowSize *= 2
  953. client, server := testClientServerConfig(conf)
  954. defer client.Close()
  955. defer server.Close()
  956. stream, err := client.Open()
  957. if err != nil {
  958. t.Fatalf("err: %v", err)
  959. }
  960. defer stream.Close()
  961. stream2, err := server.Accept()
  962. if err != nil {
  963. t.Fatalf("err: %v", err)
  964. }
  965. defer stream2.Close()
  966. stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
  967. buf := make([]byte, conf.MaxStreamWindowSize)
  968. n, err := stream.Write(buf)
  969. if err != nil {
  970. t.Fatalf("err: %v", err)
  971. }
  972. if n != len(buf) {
  973. t.Fatalf("short write: %d", n)
  974. }
  975. }
  // UnlimitedReader is an io.Reader that never runs dry. Each Read first yields
// the processor (runtime.Gosched) and then reports len(p) bytes read with a
// nil error, leaving p's contents untouched.
type UnlimitedReader struct{}

// Read yields to the scheduler and claims to have filled all of p.
func (u *UnlimitedReader) Read(p []byte) (int, error) {
	runtime.Gosched()
	filled := len(p)
	return filled, nil
}
  981. func TestSendData_VeryLarge(t *testing.T) {
  982. client, server := testClientServer()
  983. defer client.Close()
  984. defer server.Close()
  985. var n int64 = 1 * 1024 * 1024 * 1024
  986. var workers int = 16
  987. wg := &sync.WaitGroup{}
  988. wg.Add(workers * 2)
  989. for i := 0; i < workers; i++ {
  990. go func() {
  991. defer wg.Done()
  992. stream, err := server.AcceptStream()
  993. if err != nil {
  994. t.Fatalf("err: %v", err)
  995. }
  996. defer stream.Close()
  997. buf := make([]byte, 4)
  998. _, err = stream.Read(buf)
  999. if err != nil {
  1000. t.Fatalf("err: %v", err)
  1001. }
  1002. if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
  1003. t.Fatalf("bad header")
  1004. }
  1005. recv, err := io.Copy(ioutil.Discard, stream)
  1006. if err != nil {
  1007. t.Fatalf("err: %v", err)
  1008. }
  1009. if recv != n {
  1010. t.Fatalf("bad: %v", recv)
  1011. }
  1012. }()
  1013. }
  1014. for i := 0; i < workers; i++ {
  1015. go func() {
  1016. defer wg.Done()
  1017. stream, err := client.Open()
  1018. if err != nil {
  1019. t.Fatalf("err: %v", err)
  1020. }
  1021. defer stream.Close()
  1022. _, err = stream.Write([]byte{0, 1, 2, 3})
  1023. if err != nil {
  1024. t.Fatalf("err: %v", err)
  1025. }
  1026. unlimited := &UnlimitedReader{}
  1027. sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
  1028. if err != nil {
  1029. t.Fatalf("err: %v", err)
  1030. }
  1031. if sent != n {
  1032. t.Fatalf("bad: %v", sent)
  1033. }
  1034. }()
  1035. }
  1036. doneCh := make(chan struct{})
  1037. go func() {
  1038. wg.Wait()
  1039. close(doneCh)
  1040. }()
  1041. select {
  1042. case <-doneCh:
  1043. case <-time.After(20 * time.Second):
  1044. panic("timeout")
  1045. }
  1046. }
  1047. func TestBacklogExceeded_Accept(t *testing.T) {
  1048. client, server := testClientServer()
  1049. defer client.Close()
  1050. defer server.Close()
  1051. max := 5 * client.config.AcceptBacklog
  1052. go func() {
  1053. for i := 0; i < max; i++ {
  1054. stream, err := server.Accept()
  1055. if err != nil {
  1056. t.Fatalf("err: %v", err)
  1057. }
  1058. defer stream.Close()
  1059. }
  1060. }()
  1061. // Fill the backlog
  1062. for i := 0; i < max; i++ {
  1063. stream, err := client.Open()
  1064. if err != nil {
  1065. t.Fatalf("err: %v", err)
  1066. }
  1067. defer stream.Close()
  1068. if _, err := stream.Write([]byte("foo")); err != nil {
  1069. t.Fatalf("err: %v", err)
  1070. }
  1071. }
  1072. }
  1073. func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
  1074. client, server := testClientServerConfig(testConfNoKeepAlive())
  1075. defer client.Close()
  1076. defer server.Close()
  1077. var wg sync.WaitGroup
  1078. wg.Add(2)
  1079. // Choose a huge flood size that we know will result in a window update.
  1080. flood := int64(client.config.MaxStreamWindowSize) - 1
  1081. // The server will accept a new stream and then flood data to it.
  1082. go func() {
  1083. defer wg.Done()
  1084. stream, err := server.AcceptStream()
  1085. if err != nil {
  1086. t.Fatalf("err: %v", err)
  1087. }
  1088. defer stream.Close()
  1089. n, err := stream.Write(make([]byte, flood))
  1090. if err != nil {
  1091. t.Fatalf("err: %v", err)
  1092. }
  1093. if int64(n) != flood {
  1094. t.Fatalf("short write: %d", n)
  1095. }
  1096. }()
  1097. // The client will open a stream, block outbound writes, and then
  1098. // listen to the flood from the server, which should time out since
  1099. // it won't be able to send the window update.
  1100. go func() {
  1101. defer wg.Done()
  1102. stream, err := client.OpenStream()
  1103. if err != nil {
  1104. t.Fatalf("err: %v", err)
  1105. }
  1106. defer stream.Close()
  1107. conn := client.conn.(*pipeConn)
  1108. conn.writeBlocker.Lock()
  1109. defer conn.writeBlocker.Unlock()
  1110. _, err = stream.Read(make([]byte, flood))
  1111. if err != ErrConnectionWriteTimeout {
  1112. t.Fatalf("err: %v", err)
  1113. }
  1114. }()
  1115. wg.Wait()
  1116. }
  1117. func TestSession_PartialReadWindowUpdate(t *testing.T) {
  1118. client, server := testClientServerConfig(testConfNoKeepAlive())
  1119. defer client.Close()
  1120. defer server.Close()
  1121. var wg sync.WaitGroup
  1122. wg.Add(1)
  1123. // Choose a huge flood size that we know will result in a window update.
  1124. flood := int64(client.config.MaxStreamWindowSize)
  1125. var wr *Stream
  1126. // The server will accept a new stream and then flood data to it.
  1127. go func() {
  1128. defer wg.Done()
  1129. var err error
  1130. wr, err = server.AcceptStream()
  1131. if err != nil {
  1132. t.Fatalf("err: %v", err)
  1133. }
  1134. defer wr.Close()
  1135. if wr.sendWindow != client.config.MaxStreamWindowSize {
  1136. t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
  1137. }
  1138. n, err := wr.Write(make([]byte, flood))
  1139. if err != nil {
  1140. t.Fatalf("err: %v", err)
  1141. }
  1142. if int64(n) != flood {
  1143. t.Fatalf("short write: %d", n)
  1144. }
  1145. if wr.sendWindow != 0 {
  1146. t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
  1147. }
  1148. }()
  1149. stream, err := client.OpenStream()
  1150. if err != nil {
  1151. t.Fatalf("err: %v", err)
  1152. }
  1153. defer stream.Close()
  1154. wg.Wait()
  1155. _, err = stream.Read(make([]byte, flood/2+1))
  1156. if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
  1157. t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
  1158. }
  1159. }
  1160. func TestSession_sendNoWait_Timeout(t *testing.T) {
  1161. client, server := testClientServerConfig(testConfNoKeepAlive())
  1162. defer client.Close()
  1163. defer server.Close()
  1164. var wg sync.WaitGroup
  1165. wg.Add(2)
  1166. go func() {
  1167. defer wg.Done()
  1168. stream, err := server.AcceptStream()
  1169. if err != nil {
  1170. t.Fatalf("err: %v", err)
  1171. }
  1172. defer stream.Close()
  1173. }()
  1174. // The client will open the stream and then block outbound writes, we'll
  1175. // probe sendNoWait once it gets into that state.
  1176. go func() {
  1177. defer wg.Done()
  1178. stream, err := client.OpenStream()
  1179. if err != nil {
  1180. t.Fatalf("err: %v", err)
  1181. }
  1182. defer stream.Close()
  1183. conn := client.conn.(*pipeConn)
  1184. conn.writeBlocker.Lock()
  1185. defer conn.writeBlocker.Unlock()
  1186. hdr := header(make([]byte, headerSize))
  1187. hdr.encode(typePing, flagACK, 0, 0)
  1188. for {
  1189. err = client.sendNoWait(hdr)
  1190. if err == nil {
  1191. continue
  1192. } else if err == ErrConnectionWriteTimeout {
  1193. break
  1194. } else {
  1195. t.Fatalf("err: %v", err)
  1196. }
  1197. }
  1198. }()
  1199. wg.Wait()
  1200. }
  1201. func TestSession_PingOfDeath(t *testing.T) {
  1202. client, server := testClientServerConfig(testConfNoKeepAlive())
  1203. defer client.Close()
  1204. defer server.Close()
  1205. var wg sync.WaitGroup
  1206. wg.Add(2)
  1207. var doPingOfDeath sync.Mutex
  1208. doPingOfDeath.Lock()
  1209. // This is used later to block outbound writes.
  1210. conn := server.conn.(*pipeConn)
  1211. // The server will accept a stream, block outbound writes, and then
  1212. // flood its send channel so that no more headers can be queued.
  1213. go func() {
  1214. defer wg.Done()
  1215. stream, err := server.AcceptStream()
  1216. if err != nil {
  1217. t.Fatalf("err: %v", err)
  1218. }
  1219. defer stream.Close()
  1220. conn.writeBlocker.Lock()
  1221. for {
  1222. hdr := header(make([]byte, headerSize))
  1223. hdr.encode(typePing, 0, 0, 0)
  1224. err = server.sendNoWait(hdr)
  1225. if err == nil {
  1226. continue
  1227. } else if err == ErrConnectionWriteTimeout {
  1228. break
  1229. } else {
  1230. t.Fatalf("err: %v", err)
  1231. }
  1232. }
  1233. doPingOfDeath.Unlock()
  1234. }()
  1235. // The client will open a stream and then send the server a ping once it
  1236. // can no longer write. This makes sure the server doesn't deadlock reads
  1237. // while trying to reply to the ping with no ability to write.
  1238. go func() {
  1239. defer wg.Done()
  1240. stream, err := client.OpenStream()
  1241. if err != nil {
  1242. t.Fatalf("err: %v", err)
  1243. }
  1244. defer stream.Close()
  1245. // This ping will never unblock because the ping id will never
  1246. // show up in a response.
  1247. doPingOfDeath.Lock()
  1248. go func() { client.Ping() }()
  1249. // Wait for a while to make sure the previous ping times out,
  1250. // then turn writes back on and make sure a ping works again.
  1251. time.Sleep(2 * server.config.ConnectionWriteTimeout)
  1252. conn.writeBlocker.Unlock()
  1253. if _, err = client.Ping(); err != nil {
  1254. t.Fatalf("err: %v", err)
  1255. }
  1256. }()
  1257. wg.Wait()
  1258. }
  1259. func TestSession_ConnectionWriteTimeout(t *testing.T) {
  1260. client, server := testClientServerConfig(testConfNoKeepAlive())
  1261. defer client.Close()
  1262. defer server.Close()
  1263. var wg sync.WaitGroup
  1264. wg.Add(2)
  1265. go func() {
  1266. defer wg.Done()
  1267. stream, err := server.AcceptStream()
  1268. if err != nil {
  1269. t.Fatalf("err: %v", err)
  1270. }
  1271. defer stream.Close()
  1272. }()
  1273. // The client will open the stream and then block outbound writes, we'll
  1274. // tee up a write and make sure it eventually times out.
  1275. go func() {
  1276. defer wg.Done()
  1277. stream, err := client.OpenStream()
  1278. if err != nil {
  1279. t.Fatalf("err: %v", err)
  1280. }
  1281. defer stream.Close()
  1282. conn := client.conn.(*pipeConn)
  1283. conn.writeBlocker.Lock()
  1284. defer conn.writeBlocker.Unlock()
  1285. // Since the write goroutine is blocked then this will return a
  1286. // timeout since it can't get feedback about whether the write
  1287. // worked.
  1288. n, err := stream.Write([]byte("hello"))
  1289. if err != ErrConnectionWriteTimeout {
  1290. t.Fatalf("err: %v", err)
  1291. }
  1292. if n != 0 {
  1293. t.Fatalf("lied about writes: %d", n)
  1294. }
  1295. }()
  1296. wg.Wait()
  1297. }