conn_test.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. /*
  2. Copyright 2015 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package wsstream
  14. import (
  15. "encoding/base64"
  16. "io"
  17. "io/ioutil"
  18. "net/http"
  19. "net/http/httptest"
  20. "reflect"
  21. "sync"
  22. "testing"
  23. "golang.org/x/net/websocket"
  24. )
  25. func newServer(handler http.Handler) (*httptest.Server, string) {
  26. server := httptest.NewServer(handler)
  27. serverAddr := server.Listener.Addr().String()
  28. return server, serverAddr
  29. }
  30. func TestRawConn(t *testing.T) {
  31. channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
  32. conn := NewConn(NewDefaultChannelProtocols(channels))
  33. s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  34. conn.Open(w, req)
  35. }))
  36. defer s.Close()
  37. client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
  38. if err != nil {
  39. t.Fatal(err)
  40. }
  41. defer client.Close()
  42. <-conn.ready
  43. wg := sync.WaitGroup{}
  44. // verify we can read a client write
  45. wg.Add(1)
  46. go func() {
  47. defer wg.Done()
  48. data, err := ioutil.ReadAll(conn.channels[0])
  49. if err != nil {
  50. t.Fatal(err)
  51. }
  52. if !reflect.DeepEqual(data, []byte("client")) {
  53. t.Errorf("unexpected server read: %v", data)
  54. }
  55. }()
  56. if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 {
  57. t.Fatalf("%d: %v", n, err)
  58. }
  59. // verify we can read a server write
  60. wg.Add(1)
  61. go func() {
  62. defer wg.Done()
  63. if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
  64. t.Fatalf("%d: %v", n, err)
  65. }
  66. }()
  67. data := make([]byte, 1024)
  68. if n, err := io.ReadAtLeast(client, data, 6); n != 7 || err != nil {
  69. t.Fatalf("%d: %v", n, err)
  70. }
  71. if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) {
  72. t.Errorf("unexpected client read: %v", data[:7])
  73. }
  74. // verify that an ignore channel is empty in both directions.
  75. if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil {
  76. t.Errorf("writes should be ignored")
  77. }
  78. data = make([]byte, 1024)
  79. if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF {
  80. t.Errorf("reads should be ignored")
  81. }
  82. // verify that a write to a Read channel doesn't block
  83. if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil {
  84. t.Errorf("writes should be ignored")
  85. }
  86. // verify that a read from a Write channel doesn't block
  87. data = make([]byte, 1024)
  88. if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF {
  89. t.Errorf("reads should be ignored")
  90. }
  91. // verify that a client write to a Write channel doesn't block (is dropped)
  92. if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 {
  93. t.Fatalf("%d: %v", n, err)
  94. }
  95. client.Close()
  96. wg.Wait()
  97. }
  98. func TestBase64Conn(t *testing.T) {
  99. conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
  100. s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  101. conn.Open(w, req)
  102. }))
  103. defer s.Close()
  104. config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
  105. if err != nil {
  106. t.Fatal(err)
  107. }
  108. config.Protocol = []string{"base64.channel.k8s.io"}
  109. client, err := websocket.DialConfig(config)
  110. if err != nil {
  111. t.Fatal(err)
  112. }
  113. defer client.Close()
  114. <-conn.ready
  115. wg := sync.WaitGroup{}
  116. wg.Add(1)
  117. go func() {
  118. defer wg.Done()
  119. data, err := ioutil.ReadAll(conn.channels[0])
  120. if err != nil {
  121. t.Fatal(err)
  122. }
  123. if !reflect.DeepEqual(data, []byte("client")) {
  124. t.Errorf("unexpected server read: %s", string(data))
  125. }
  126. }()
  127. clientData := base64.StdEncoding.EncodeToString([]byte("client"))
  128. if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 {
  129. t.Fatalf("%d: %v", n, err)
  130. }
  131. wg.Add(1)
  132. go func() {
  133. defer wg.Done()
  134. if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
  135. t.Fatalf("%d: %v", n, err)
  136. }
  137. }()
  138. data := make([]byte, 1024)
  139. if n, err := io.ReadAtLeast(client, data, 9); n != 9 || err != nil {
  140. t.Fatalf("%d: %v", n, err)
  141. }
  142. expect := []byte(base64.StdEncoding.EncodeToString([]byte("server")))
  143. if !reflect.DeepEqual(data[:9], append([]byte{'1'}, expect...)) {
  144. t.Errorf("unexpected client read: %v", data[:9])
  145. }
  146. client.Close()
  147. wg.Wait()
  148. }
  149. type versionTest struct {
  150. supported map[string]bool // protocol -> binary
  151. requested []string
  152. error bool
  153. expected string
  154. }
  155. func versionTests() []versionTest {
  156. const (
  157. binary = true
  158. base64 = false
  159. )
  160. return []versionTest{
  161. {
  162. supported: nil,
  163. requested: []string{"raw"},
  164. error: true,
  165. },
  166. {
  167. supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
  168. requested: nil,
  169. expected: "",
  170. },
  171. {
  172. supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
  173. requested: []string{"v1.raw"},
  174. error: true,
  175. },
  176. {
  177. supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
  178. requested: []string{"v1.raw", "v1.base64"},
  179. error: true,
  180. }, {
  181. supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
  182. requested: []string{"v1.raw", "raw"},
  183. expected: "raw",
  184. },
  185. {
  186. supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
  187. requested: []string{"v1.raw"},
  188. expected: "v1.raw",
  189. },
  190. {
  191. supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
  192. requested: []string{"v2.base64"},
  193. expected: "v2.base64",
  194. },
  195. }
  196. }
  197. func TestVersionedConn(t *testing.T) {
  198. for i, test := range versionTests() {
  199. func() {
  200. supportedProtocols := map[string]ChannelProtocolConfig{}
  201. for p, binary := range test.supported {
  202. supportedProtocols[p] = ChannelProtocolConfig{
  203. Binary: binary,
  204. Channels: []ChannelType{ReadWriteChannel},
  205. }
  206. }
  207. conn := NewConn(supportedProtocols)
  208. // note that it's not enough to wait for conn.ready to avoid a race here. Hence,
  209. // we use a channel.
  210. selectedProtocol := make(chan string, 0)
  211. s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  212. p, _, _ := conn.Open(w, req)
  213. selectedProtocol <- p
  214. }))
  215. defer s.Close()
  216. config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
  217. if err != nil {
  218. t.Fatal(err)
  219. }
  220. config.Protocol = test.requested
  221. client, err := websocket.DialConfig(config)
  222. if err != nil {
  223. if !test.error {
  224. t.Fatalf("test %d: didn't expect error: %v", i, err)
  225. } else {
  226. return
  227. }
  228. }
  229. defer client.Close()
  230. if test.error && err == nil {
  231. t.Fatalf("test %d: expected an error", i)
  232. }
  233. <-conn.ready
  234. if got, expected := <-selectedProtocol, test.expected; got != expected {
  235. t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
  236. }
  237. }()
  238. }
  239. }