session_test.go 10 KB


  1. package yamux
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "sync"
  7. "testing"
  8. "time"
  9. )
  // pipeConn joins two halves of in-memory pipes into one io.ReadWriteCloser:
  // Read pulls from reader and Write pushes into writer.
  type pipeConn struct {
  	reader *io.PipeReader // inbound half
  	writer *io.PipeWriter // outbound half
  }

  // Compile-time check that pipeConn can stand in for a network connection.
  var _ io.ReadWriteCloser = (*pipeConn)(nil)

  // Read reads from the inbound pipe half.
  func (c *pipeConn) Read(buf []byte) (int, error) {
  	n, err := c.reader.Read(buf)
  	return n, err
  }

  // Write writes to the outbound pipe half; it blocks until the other end
  // of that pipe consumes the data.
  func (c *pipeConn) Write(buf []byte) (int, error) {
  	n, err := c.writer.Write(buf)
  	return n, err
  }

  // Close shuts both pipe halves. The io package documents that
  // PipeReader.Close always returns nil, so only the writer's result is
  // reported.
  func (c *pipeConn) Close() error {
  	_ = c.reader.Close()
  	return c.writer.Close()
  }
  24. func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
  25. read1, write1 := io.Pipe()
  26. read2, write2 := io.Pipe()
  27. return &pipeConn{read1, write2}, &pipeConn{read2, write1}
  28. }
  29. func testClientServer() (*Session, *Session) {
  30. conf := DefaultConfig()
  31. conf.AcceptBacklog = 64
  32. conf.KeepAliveInterval = 100 * time.Millisecond
  33. return testClientServerConfig(conf)
  34. }
  35. func testClientServerConfig(conf *Config) (*Session, *Session) {
  36. conn1, conn2 := testConn()
  37. client, _ := Client(conn1, conf)
  38. server, _ := Server(conn2, conf)
  39. return client, server
  40. }
  41. func TestPing(t *testing.T) {
  42. client, server := testClientServer()
  43. defer client.Close()
  44. defer server.Close()
  45. rtt, err := client.Ping()
  46. if err != nil {
  47. t.Fatalf("err: %v", err)
  48. }
  49. if rtt == 0 {
  50. t.Fatalf("bad: %v", rtt)
  51. }
  52. rtt, err = server.Ping()
  53. if err != nil {
  54. t.Fatalf("err: %v", err)
  55. }
  56. if rtt == 0 {
  57. t.Fatalf("bad: %v", rtt)
  58. }
  59. }
  60. func TestAccept(t *testing.T) {
  61. client, server := testClientServer()
  62. defer client.Close()
  63. defer server.Close()
  64. wg := &sync.WaitGroup{}
  65. wg.Add(4)
  66. go func() {
  67. defer wg.Done()
  68. stream, err := server.AcceptStream()
  69. if err != nil {
  70. t.Fatalf("err: %v", err)
  71. }
  72. if id := stream.StreamID(); id != 1 {
  73. t.Fatalf("bad: %v", id)
  74. }
  75. if err := stream.Close(); err != nil {
  76. t.Fatalf("err: %v", err)
  77. }
  78. }()
  79. go func() {
  80. defer wg.Done()
  81. stream, err := client.AcceptStream()
  82. if err != nil {
  83. t.Fatalf("err: %v", err)
  84. }
  85. if id := stream.StreamID(); id != 2 {
  86. t.Fatalf("bad: %v", id)
  87. }
  88. if err := stream.Close(); err != nil {
  89. t.Fatalf("err: %v", err)
  90. }
  91. }()
  92. go func() {
  93. defer wg.Done()
  94. stream, err := server.Open()
  95. if err != nil {
  96. t.Fatalf("err: %v", err)
  97. }
  98. if id := stream.StreamID(); id != 2 {
  99. t.Fatalf("bad: %v", id)
  100. }
  101. if err := stream.Close(); err != nil {
  102. t.Fatalf("err: %v", err)
  103. }
  104. }()
  105. go func() {
  106. defer wg.Done()
  107. stream, err := client.Open()
  108. if err != nil {
  109. t.Fatalf("err: %v", err)
  110. }
  111. if id := stream.StreamID(); id != 1 {
  112. t.Fatalf("bad: %v", id)
  113. }
  114. if err := stream.Close(); err != nil {
  115. t.Fatalf("err: %v", err)
  116. }
  117. }()
  118. doneCh := make(chan struct{})
  119. go func() {
  120. wg.Wait()
  121. close(doneCh)
  122. }()
  123. select {
  124. case <-doneCh:
  125. case <-time.After(time.Second):
  126. panic("timeout")
  127. }
  128. }
  129. func TestSendData_Small(t *testing.T) {
  130. client, server := testClientServer()
  131. defer client.Close()
  132. defer server.Close()
  133. wg := &sync.WaitGroup{}
  134. wg.Add(2)
  135. go func() {
  136. defer wg.Done()
  137. stream, err := server.AcceptStream()
  138. if err != nil {
  139. t.Fatalf("err: %v", err)
  140. }
  141. buf := make([]byte, 4)
  142. for i := 0; i < 1000; i++ {
  143. n, err := stream.Read(buf)
  144. if err != nil {
  145. t.Fatalf("err: %v", err)
  146. }
  147. if n != 4 {
  148. t.Fatalf("short read: %d", n)
  149. }
  150. if string(buf) != "test" {
  151. t.Fatalf("bad: %s", buf)
  152. }
  153. }
  154. if err := stream.Close(); err != nil {
  155. t.Fatalf("err: %v", err)
  156. }
  157. }()
  158. go func() {
  159. defer wg.Done()
  160. stream, err := client.Open()
  161. if err != nil {
  162. t.Fatalf("err: %v", err)
  163. }
  164. for i := 0; i < 1000; i++ {
  165. n, err := stream.Write([]byte("test"))
  166. if err != nil {
  167. t.Fatalf("err: %v", err)
  168. }
  169. if n != 4 {
  170. t.Fatalf("short write %d", n)
  171. }
  172. }
  173. if err := stream.Close(); err != nil {
  174. t.Fatalf("err: %v", err)
  175. }
  176. }()
  177. doneCh := make(chan struct{})
  178. go func() {
  179. wg.Wait()
  180. close(doneCh)
  181. }()
  182. select {
  183. case <-doneCh:
  184. case <-time.After(time.Second):
  185. panic("timeout")
  186. }
  187. }
  188. func TestSendData_Large(t *testing.T) {
  189. client, server := testClientServer()
  190. defer client.Close()
  191. defer server.Close()
  192. data := make([]byte, 512*1024)
  193. for idx := range data {
  194. data[idx] = byte(idx % 256)
  195. }
  196. wg := &sync.WaitGroup{}
  197. wg.Add(2)
  198. go func() {
  199. defer wg.Done()
  200. stream, err := server.AcceptStream()
  201. if err != nil {
  202. t.Fatalf("err: %v", err)
  203. }
  204. buf := make([]byte, 4*1024)
  205. for i := 0; i < 128; i++ {
  206. n, err := stream.Read(buf)
  207. if err != nil {
  208. t.Fatalf("err: %v", err)
  209. }
  210. if n != 4*1024 {
  211. t.Fatalf("short read: %d", n)
  212. }
  213. for idx := range buf {
  214. if buf[idx] != byte(idx%256) {
  215. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  216. }
  217. }
  218. }
  219. if err := stream.Close(); err != nil {
  220. t.Fatalf("err: %v", err)
  221. }
  222. }()
  223. go func() {
  224. defer wg.Done()
  225. stream, err := client.Open()
  226. if err != nil {
  227. t.Fatalf("err: %v", err)
  228. }
  229. n, err := stream.Write(data)
  230. if err != nil {
  231. t.Fatalf("err: %v", err)
  232. }
  233. if n != len(data) {
  234. t.Fatalf("short write %d", n)
  235. }
  236. if err := stream.Close(); err != nil {
  237. t.Fatalf("err: %v", err)
  238. }
  239. }()
  240. doneCh := make(chan struct{})
  241. go func() {
  242. wg.Wait()
  243. close(doneCh)
  244. }()
  245. select {
  246. case <-doneCh:
  247. case <-time.After(time.Second):
  248. panic("timeout")
  249. }
  250. }
  251. func TestGoAway(t *testing.T) {
  252. client, server := testClientServer()
  253. defer client.Close()
  254. defer server.Close()
  255. if err := server.GoAway(); err != nil {
  256. t.Fatalf("err: %v", err)
  257. }
  258. _, err := client.Open()
  259. if err != ErrRemoteGoAway {
  260. t.Fatalf("err: %v", err)
  261. }
  262. }
  263. func TestManyStreams(t *testing.T) {
  264. client, server := testClientServer()
  265. defer client.Close()
  266. defer server.Close()
  267. wg := &sync.WaitGroup{}
  268. acceptor := func(i int) {
  269. defer wg.Done()
  270. stream, err := server.AcceptStream()
  271. if err != nil {
  272. t.Fatalf("err: %v", err)
  273. }
  274. defer stream.Close()
  275. buf := make([]byte, 512)
  276. for {
  277. n, err := stream.Read(buf)
  278. if err == io.EOF {
  279. return
  280. }
  281. if err != nil {
  282. t.Fatalf("err: %v", err)
  283. }
  284. if n == 0 {
  285. t.Fatalf("err: %v", err)
  286. }
  287. }
  288. }
  289. sender := func(i int) {
  290. defer wg.Done()
  291. stream, err := client.Open()
  292. if err != nil {
  293. t.Fatalf("err: %v", err)
  294. }
  295. defer stream.Close()
  296. msg := fmt.Sprintf("%08d", i)
  297. for i := 0; i < 1000; i++ {
  298. n, err := stream.Write([]byte(msg))
  299. if err != nil {
  300. t.Fatalf("err: %v", err)
  301. }
  302. if n != len(msg) {
  303. t.Fatalf("short write %d", n)
  304. }
  305. }
  306. }
  307. for i := 0; i < 50; i++ {
  308. wg.Add(2)
  309. go acceptor(i)
  310. go sender(i)
  311. }
  312. wg.Wait()
  313. }
  314. func TestManyStreams_PingPong(t *testing.T) {
  315. client, server := testClientServer()
  316. defer client.Close()
  317. defer server.Close()
  318. wg := &sync.WaitGroup{}
  319. ping := []byte("ping")
  320. pong := []byte("pong")
  321. acceptor := func(i int) {
  322. defer wg.Done()
  323. stream, err := server.AcceptStream()
  324. if err != nil {
  325. t.Fatalf("err: %v", err)
  326. }
  327. defer stream.Close()
  328. buf := make([]byte, 4)
  329. for {
  330. n, err := stream.Read(buf)
  331. if err == io.EOF {
  332. return
  333. }
  334. if err != nil {
  335. t.Fatalf("err: %v", err)
  336. }
  337. if n != 4 {
  338. t.Fatalf("err: %v", err)
  339. }
  340. if !bytes.Equal(buf, ping) {
  341. t.Fatalf("bad: %s", buf)
  342. }
  343. n, err = stream.Write(pong)
  344. if err != nil {
  345. t.Fatalf("err: %v", err)
  346. }
  347. if n != 4 {
  348. t.Fatalf("err: %v", err)
  349. }
  350. }
  351. }
  352. sender := func(i int) {
  353. defer wg.Done()
  354. stream, err := client.Open()
  355. if err != nil {
  356. t.Fatalf("err: %v", err)
  357. }
  358. defer stream.Close()
  359. buf := make([]byte, 4)
  360. for i := 0; i < 1000; i++ {
  361. n, err := stream.Write(ping)
  362. if err != nil {
  363. t.Fatalf("err: %v", err)
  364. }
  365. if n != 4 {
  366. t.Fatalf("short write %d", n)
  367. }
  368. n, err = stream.Read(buf)
  369. if err != nil {
  370. t.Fatalf("err: %v", err)
  371. }
  372. if n != 4 {
  373. t.Fatalf("err: %v", err)
  374. }
  375. if !bytes.Equal(buf, pong) {
  376. t.Fatalf("bad: %s", buf)
  377. }
  378. }
  379. }
  380. for i := 0; i < 50; i++ {
  381. wg.Add(2)
  382. go acceptor(i)
  383. go sender(i)
  384. }
  385. wg.Wait()
  386. }
  387. func TestHalfClose(t *testing.T) {
  388. client, server := testClientServer()
  389. defer client.Close()
  390. defer server.Close()
  391. stream, err := client.Open()
  392. if err != nil {
  393. t.Fatalf("err: %v", err)
  394. }
  395. if _, err := stream.Write([]byte("a")); err != nil {
  396. t.Fatalf("err: %v", err)
  397. }
  398. stream2, err := server.Accept()
  399. if err != nil {
  400. t.Fatalf("err: %v", err)
  401. }
  402. stream2.Close() // Half close
  403. buf := make([]byte, 4)
  404. n, err := stream2.Read(buf)
  405. if err != nil {
  406. t.Fatalf("err: %v", err)
  407. }
  408. if n != 1 {
  409. t.Fatalf("bad: %v", n)
  410. }
  411. // Send more
  412. if _, err := stream.Write([]byte("bcd")); err != nil {
  413. t.Fatalf("err: %v", err)
  414. }
  415. stream.Close()
  416. // Read after close
  417. n, err = stream2.Read(buf)
  418. if err != nil {
  419. t.Fatalf("err: %v", err)
  420. }
  421. if n != 3 {
  422. t.Fatalf("bad: %v", n)
  423. }
  424. // EOF after close
  425. n, err = stream2.Read(buf)
  426. if err != io.EOF {
  427. t.Fatalf("err: %v", err)
  428. }
  429. if n != 0 {
  430. t.Fatalf("bad: %v", n)
  431. }
  432. }
  433. func TestReadDeadline(t *testing.T) {
  434. client, server := testClientServer()
  435. defer client.Close()
  436. defer server.Close()
  437. stream, err := client.Open()
  438. if err != nil {
  439. t.Fatalf("err: %v", err)
  440. }
  441. defer stream.Close()
  442. stream2, err := server.Accept()
  443. if err != nil {
  444. t.Fatalf("err: %v", err)
  445. }
  446. defer stream2.Close()
  447. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  448. t.Fatalf("err: %v", err)
  449. }
  450. buf := make([]byte, 4)
  451. if _, err := stream.Read(buf); err != ErrTimeout {
  452. t.Fatalf("err: %v", err)
  453. }
  454. }
  455. func TestWriteDeadline(t *testing.T) {
  456. client, server := testClientServer()
  457. defer client.Close()
  458. defer server.Close()
  459. stream, err := client.Open()
  460. if err != nil {
  461. t.Fatalf("err: %v", err)
  462. }
  463. defer stream.Close()
  464. stream2, err := server.Accept()
  465. if err != nil {
  466. t.Fatalf("err: %v", err)
  467. }
  468. defer stream2.Close()
  469. if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
  470. t.Fatalf("err: %v", err)
  471. }
  472. buf := make([]byte, 512)
  473. for i := 0; i < int(initialStreamWindow); i++ {
  474. _, err := stream.Write(buf)
  475. if err != nil && err == ErrTimeout {
  476. return
  477. } else if err != nil {
  478. t.Fatalf("err: %v", err)
  479. }
  480. }
  481. t.Fatalf("Expected timeout")
  482. }
  483. func TestBacklogExceeded(t *testing.T) {
  484. client, server := testClientServer()
  485. defer client.Close()
  486. defer server.Close()
  487. // Fill the backlog
  488. max := client.config.AcceptBacklog
  489. for i := 0; i < max; i++ {
  490. stream, err := client.Open()
  491. if err != nil {
  492. t.Fatalf("err: %v", err)
  493. }
  494. defer stream.Close()
  495. if _, err := stream.Write([]byte("foo")); err != nil {
  496. t.Fatalf("err: %v", err)
  497. }
  498. }
  499. // Exceed the backlog!
  500. stream, err := client.Open()
  501. if err != nil {
  502. t.Fatalf("err: %v", err)
  503. }
  504. defer stream.Close()
  505. if _, err := stream.Write([]byte("foo")); err != nil {
  506. t.Fatalf("err: %v", err)
  507. }
  508. buf := make([]byte, 4)
  509. stream.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
  510. if _, err := stream.Read(buf); err != ErrConnectionReset {
  511. t.Fatalf("err: %v", err)
  512. }
  513. }