session_test.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. package yamux
  2. import (
  3. "fmt"
  4. "io"
  5. "sync"
  6. "testing"
  7. "time"
  8. )
  9. type pipeConn struct {
  10. reader *io.PipeReader
  11. writer *io.PipeWriter
  12. }
  13. func (p *pipeConn) Read(b []byte) (int, error) {
  14. return p.reader.Read(b)
  15. }
  16. func (p *pipeConn) Write(b []byte) (int, error) {
  17. return p.writer.Write(b)
  18. }
  19. func (p *pipeConn) Close() error {
  20. p.reader.Close()
  21. return p.writer.Close()
  22. }
  23. func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
  24. read1, write1 := io.Pipe()
  25. read2, write2 := io.Pipe()
  26. return &pipeConn{read1, write2}, &pipeConn{read2, write1}
  27. }
  28. func testClientServer() (*Session, *Session) {
  29. conn1, conn2 := testConn()
  30. client, _ := Client(conn1, nil)
  31. server, _ := Server(conn2, nil)
  32. return client, server
  33. }
  34. func TestPing(t *testing.T) {
  35. client, server := testClientServer()
  36. defer client.Close()
  37. defer server.Close()
  38. rtt, err := client.Ping()
  39. if err != nil {
  40. t.Fatalf("err: %v", err)
  41. }
  42. if rtt == 0 {
  43. t.Fatalf("bad: %v", rtt)
  44. }
  45. rtt, err = server.Ping()
  46. if err != nil {
  47. t.Fatalf("err: %v", err)
  48. }
  49. if rtt == 0 {
  50. t.Fatalf("bad: %v", rtt)
  51. }
  52. }
  53. func TestAccept(t *testing.T) {
  54. client, server := testClientServer()
  55. defer client.Close()
  56. defer server.Close()
  57. wg := &sync.WaitGroup{}
  58. wg.Add(4)
  59. go func() {
  60. defer wg.Done()
  61. stream, err := server.AcceptStream()
  62. if err != nil {
  63. t.Fatalf("err: %v", err)
  64. }
  65. if id := stream.StreamID(); id != 1 {
  66. t.Fatalf("bad: %v", id)
  67. }
  68. if err := stream.Close(); err != nil {
  69. t.Fatalf("err: %v", err)
  70. }
  71. }()
  72. go func() {
  73. defer wg.Done()
  74. stream, err := client.AcceptStream()
  75. if err != nil {
  76. t.Fatalf("err: %v", err)
  77. }
  78. if id := stream.StreamID(); id != 2 {
  79. t.Fatalf("bad: %v", id)
  80. }
  81. if err := stream.Close(); err != nil {
  82. t.Fatalf("err: %v", err)
  83. }
  84. }()
  85. go func() {
  86. defer wg.Done()
  87. stream, err := server.Open()
  88. if err != nil {
  89. t.Fatalf("err: %v", err)
  90. }
  91. if id := stream.StreamID(); id != 2 {
  92. t.Fatalf("bad: %v", id)
  93. }
  94. if err := stream.Close(); err != nil {
  95. t.Fatalf("err: %v", err)
  96. }
  97. }()
  98. go func() {
  99. defer wg.Done()
  100. stream, err := client.Open()
  101. if err != nil {
  102. t.Fatalf("err: %v", err)
  103. }
  104. if id := stream.StreamID(); id != 1 {
  105. t.Fatalf("bad: %v", id)
  106. }
  107. if err := stream.Close(); err != nil {
  108. t.Fatalf("err: %v", err)
  109. }
  110. }()
  111. doneCh := make(chan struct{})
  112. go func() {
  113. wg.Wait()
  114. close(doneCh)
  115. }()
  116. select {
  117. case <-doneCh:
  118. case <-time.After(time.Second):
  119. panic("timeout")
  120. }
  121. }
  122. func TestSendData_Small(t *testing.T) {
  123. client, server := testClientServer()
  124. defer client.Close()
  125. defer server.Close()
  126. wg := &sync.WaitGroup{}
  127. wg.Add(2)
  128. go func() {
  129. defer wg.Done()
  130. stream, err := server.AcceptStream()
  131. if err != nil {
  132. t.Fatalf("err: %v", err)
  133. }
  134. buf := make([]byte, 4)
  135. for i := 0; i < 1000; i++ {
  136. n, err := stream.Read(buf)
  137. if err != nil {
  138. t.Fatalf("err: %v", err)
  139. }
  140. if n != 4 {
  141. t.Fatalf("short read: %d", n)
  142. }
  143. if string(buf) != "test" {
  144. t.Fatalf("bad: %s", buf)
  145. }
  146. }
  147. if err := stream.Close(); err != nil {
  148. t.Fatalf("err: %v", err)
  149. }
  150. }()
  151. go func() {
  152. defer wg.Done()
  153. stream, err := client.Open()
  154. if err != nil {
  155. t.Fatalf("err: %v", err)
  156. }
  157. for i := 0; i < 1000; i++ {
  158. n, err := stream.Write([]byte("test"))
  159. if err != nil {
  160. t.Fatalf("err: %v", err)
  161. }
  162. if n != 4 {
  163. t.Fatalf("short write %d", n)
  164. }
  165. }
  166. if err := stream.Close(); err != nil {
  167. t.Fatalf("err: %v", err)
  168. }
  169. }()
  170. doneCh := make(chan struct{})
  171. go func() {
  172. wg.Wait()
  173. close(doneCh)
  174. }()
  175. select {
  176. case <-doneCh:
  177. case <-time.After(time.Second):
  178. panic("timeout")
  179. }
  180. }
  181. func TestSendData_Large(t *testing.T) {
  182. client, server := testClientServer()
  183. defer client.Close()
  184. defer server.Close()
  185. data := make([]byte, 512*1024)
  186. for idx := range data {
  187. data[idx] = byte(idx % 256)
  188. }
  189. wg := &sync.WaitGroup{}
  190. wg.Add(2)
  191. go func() {
  192. defer wg.Done()
  193. stream, err := server.AcceptStream()
  194. if err != nil {
  195. t.Fatalf("err: %v", err)
  196. }
  197. buf := make([]byte, 4*1024)
  198. for i := 0; i < 128; i++ {
  199. n, err := stream.Read(buf)
  200. if err != nil {
  201. t.Fatalf("err: %v", err)
  202. }
  203. if n != 4*1024 {
  204. t.Fatalf("short read: %d", n)
  205. }
  206. for idx := range buf {
  207. if buf[idx] != byte(idx%256) {
  208. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  209. }
  210. }
  211. }
  212. if err := stream.Close(); err != nil {
  213. t.Fatalf("err: %v", err)
  214. }
  215. }()
  216. go func() {
  217. defer wg.Done()
  218. stream, err := client.Open()
  219. if err != nil {
  220. t.Fatalf("err: %v", err)
  221. }
  222. n, err := stream.Write(data)
  223. if err != nil {
  224. t.Fatalf("err: %v", err)
  225. }
  226. if n != len(data) {
  227. t.Fatalf("short write %d", n)
  228. }
  229. if err := stream.Close(); err != nil {
  230. t.Fatalf("err: %v", err)
  231. }
  232. }()
  233. doneCh := make(chan struct{})
  234. go func() {
  235. wg.Wait()
  236. close(doneCh)
  237. }()
  238. select {
  239. case <-doneCh:
  240. case <-time.After(time.Second):
  241. panic("timeout")
  242. }
  243. }
  244. func TestGoAway(t *testing.T) {
  245. client, server := testClientServer()
  246. defer client.Close()
  247. defer server.Close()
  248. if err := server.GoAway(); err != nil {
  249. t.Fatalf("err: %v", err)
  250. }
  251. _, err := client.Open()
  252. if err != ErrRemoteGoAway {
  253. t.Fatalf("err: %v", err)
  254. }
  255. }
  256. func TestManyStreams(t *testing.T) {
  257. client, server := testClientServer()
  258. defer client.Close()
  259. defer server.Close()
  260. wg := &sync.WaitGroup{}
  261. acceptor := func(i int) {
  262. defer wg.Done()
  263. stream, err := server.AcceptStream()
  264. if err != nil {
  265. t.Fatalf("err: %v", err)
  266. }
  267. defer stream.Close()
  268. buf := make([]byte, 512)
  269. for {
  270. n, err := stream.Read(buf)
  271. println("read")
  272. if err == io.EOF {
  273. return
  274. }
  275. if err != nil {
  276. t.Fatalf("err: %v", err)
  277. }
  278. if n == 0 {
  279. t.Fatalf("err: %v", err)
  280. }
  281. }
  282. }
  283. sender := func(i int) {
  284. defer wg.Done()
  285. stream, err := client.Open()
  286. if err != nil {
  287. t.Fatalf("err: %v", err)
  288. }
  289. defer stream.Close()
  290. msg := fmt.Sprintf("%08d", i)
  291. for i := 0; i < 1000; i++ {
  292. n, err := stream.Write([]byte(msg))
  293. println("write")
  294. if err != nil {
  295. t.Fatalf("err: %v", err)
  296. }
  297. if n != len(msg) {
  298. t.Fatalf("short write %d", n)
  299. }
  300. }
  301. }
  302. for i := 0; i < 50; i++ {
  303. wg.Add(2)
  304. go acceptor(i)
  305. go sender(i)
  306. }
  307. wg.Wait()
  308. }