stream_test.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. /*
  2. Copyright 2015 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package wsstream
  14. import (
  15. "bytes"
  16. "encoding/base64"
  17. "fmt"
  18. "io"
  19. "io/ioutil"
  20. "net/http"
  21. "reflect"
  22. "strings"
  23. "testing"
  24. "time"
  25. "golang.org/x/net/websocket"
  26. )
  27. func TestStream(t *testing.T) {
  28. input := "some random text"
  29. r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
  30. r.SetIdleTimeout(time.Second)
  31. data, err := readWebSocket(r, t, nil)
  32. if !reflect.DeepEqual(data, []byte(input)) {
  33. t.Errorf("unexpected server read: %v", data)
  34. }
  35. if err != nil {
  36. t.Fatal(err)
  37. }
  38. }
  39. func TestStreamPing(t *testing.T) {
  40. input := "some random text"
  41. r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
  42. r.SetIdleTimeout(time.Second)
  43. err := expectWebSocketFrames(r, t, nil, [][]byte{
  44. {},
  45. []byte(input),
  46. })
  47. if err != nil {
  48. t.Fatal(err)
  49. }
  50. }
  51. func TestStreamBase64(t *testing.T) {
  52. input := "some random text"
  53. encoded := base64.StdEncoding.EncodeToString([]byte(input))
  54. r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
  55. data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
  56. if !reflect.DeepEqual(data, []byte(encoded)) {
  57. t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
  58. }
  59. if err != nil {
  60. t.Fatal(err)
  61. }
  62. }
  63. func TestStreamVersionedBase64(t *testing.T) {
  64. input := "some random text"
  65. encoded := base64.StdEncoding.EncodeToString([]byte(input))
  66. r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
  67. "": {Binary: true},
  68. "binary.k8s.io": {Binary: true},
  69. "base64.binary.k8s.io": {Binary: false},
  70. "v1.binary.k8s.io": {Binary: true},
  71. "v1.base64.binary.k8s.io": {Binary: false},
  72. "v2.binary.k8s.io": {Binary: true},
  73. "v2.base64.binary.k8s.io": {Binary: false},
  74. })
  75. data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
  76. if !reflect.DeepEqual(data, []byte(encoded)) {
  77. t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
  78. }
  79. if err != nil {
  80. t.Fatal(err)
  81. }
  82. }
  83. func TestStreamVersionedCopy(t *testing.T) {
  84. for i, test := range versionTests() {
  85. func() {
  86. supportedProtocols := map[string]ReaderProtocolConfig{}
  87. for p, binary := range test.supported {
  88. supportedProtocols[p] = ReaderProtocolConfig{
  89. Binary: binary,
  90. }
  91. }
  92. input := "some random text"
  93. r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
  94. s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  95. err := r.Copy(w, req)
  96. if err != nil {
  97. w.WriteHeader(503)
  98. }
  99. }))
  100. defer s.Close()
  101. config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
  102. if err != nil {
  103. t.Error(err)
  104. return
  105. }
  106. config.Protocol = test.requested
  107. client, err := websocket.DialConfig(config)
  108. if err != nil {
  109. if !test.error {
  110. t.Errorf("test %d: didn't expect error: %v", i, err)
  111. }
  112. return
  113. }
  114. defer client.Close()
  115. if test.error && err == nil {
  116. t.Errorf("test %d: expected an error", i)
  117. return
  118. }
  119. <-r.err
  120. if got, expected := r.selectedProtocol, test.expected; got != expected {
  121. t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
  122. }
  123. }()
  124. }
  125. }
  126. func TestStreamError(t *testing.T) {
  127. input := "some random text"
  128. errs := &errorReader{
  129. reads: [][]byte{
  130. []byte("some random"),
  131. []byte(" text"),
  132. },
  133. err: fmt.Errorf("bad read"),
  134. }
  135. r := NewReader(errs, false, NewDefaultReaderProtocols())
  136. data, err := readWebSocket(r, t, nil)
  137. if !reflect.DeepEqual(data, []byte(input)) {
  138. t.Errorf("unexpected server read: %v", data)
  139. }
  140. if err == nil || err.Error() != "bad read" {
  141. t.Fatal(err)
  142. }
  143. }
  144. func TestStreamSurvivesPanic(t *testing.T) {
  145. input := "some random text"
  146. errs := &errorReader{
  147. reads: [][]byte{
  148. []byte("some random"),
  149. []byte(" text"),
  150. },
  151. panicMessage: "bad read",
  152. }
  153. r := NewReader(errs, false, NewDefaultReaderProtocols())
  154. // do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
  155. r.handleCrash = func() { recover() }
  156. data, err := readWebSocket(r, t, nil)
  157. if !reflect.DeepEqual(data, []byte(input)) {
  158. t.Errorf("unexpected server read: %v", data)
  159. }
  160. if err != nil {
  161. t.Fatal(err)
  162. }
  163. }
  164. func TestStreamClosedDuringRead(t *testing.T) {
  165. for i := 0; i < 25; i++ {
  166. ch := make(chan struct{})
  167. input := "some random text"
  168. errs := &errorReader{
  169. reads: [][]byte{
  170. []byte("some random"),
  171. []byte(" text"),
  172. },
  173. err: fmt.Errorf("stuff"),
  174. pause: ch,
  175. }
  176. r := NewReader(errs, false, NewDefaultReaderProtocols())
  177. data, err := readWebSocket(r, t, func(c *websocket.Conn) {
  178. c.Close()
  179. close(ch)
  180. })
  181. // verify that the data returned by the server on an early close always has a specific error
  182. if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
  183. t.Fatal(err)
  184. }
  185. // verify that the data returned is a strict subset of the input
  186. if !bytes.HasPrefix([]byte(input), data) && len(data) != 0 {
  187. t.Fatalf("unexpected server read: %q", string(data))
  188. }
  189. }
  190. }
  // errorReader is a scripted reader for tests. It hands out the chunks in
  // reads one at a time and, once they are exhausted, optionally blocks, then
  // either panics or returns err.
  type errorReader struct {
  	// reads holds the chunks still to be returned, in order.
  	reads [][]byte
  	// err is returned by Read once reads is empty (when no panic is configured).
  	err error
  	// panicMessage, when non-empty, makes Read panic with this value once
  	// reads is empty instead of returning err.
  	panicMessage string
  	// pause, when non-nil, is received from before the terminal error or
  	// panic, letting a test hold the reader until it sends or closes it.
  	pause chan struct{}
  }

  // Read copies the next scripted chunk into p and returns the number of bytes
  // copied. If p is too small for the whole chunk, the remainder stays queued
  // and is returned by later calls, so n never exceeds len(p) and no data is
  // dropped. After the last chunk has been consumed, Read waits on pause (if
  // set), then panics with panicMessage (if set) or returns (0, err).
  func (r *errorReader) Read(p []byte) (int, error) {
  	if len(r.reads) == 0 {
  		if r.pause != nil {
  			<-r.pause
  		}
  		if len(r.panicMessage) != 0 {
  			panic(r.panicMessage)
  		}
  		return 0, r.err
  	}

  	next := r.reads[0]
  	n := copy(p, next)
  	if n < len(next) {
  		// p was too small: keep the unread tail at the head of the queue.
  		r.reads[0] = next[n:]
  	} else {
  		r.reads = r.reads[1:]
  	}
  	return n, nil
  }
  212. func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
  213. errCh := make(chan error, 1)
  214. s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  215. errCh <- r.Copy(w, req)
  216. }))
  217. defer s.Close()
  218. config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
  219. config.Protocol = protocols
  220. client, err := websocket.DialConfig(config)
  221. if err != nil {
  222. return nil, err
  223. }
  224. defer client.Close()
  225. if fn != nil {
  226. fn(client)
  227. }
  228. data, err := ioutil.ReadAll(client)
  229. if err != nil {
  230. return data, err
  231. }
  232. return data, <-errCh
  233. }
  234. func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
  235. errCh := make(chan error, 1)
  236. s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  237. errCh <- r.Copy(w, req)
  238. }))
  239. defer s.Close()
  240. config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
  241. config.Protocol = protocols
  242. ws, err := websocket.DialConfig(config)
  243. if err != nil {
  244. return err
  245. }
  246. defer ws.Close()
  247. if fn != nil {
  248. fn(ws)
  249. }
  250. for i := range frames {
  251. var data []byte
  252. if err := websocket.Message.Receive(ws, &data); err != nil {
  253. return err
  254. }
  255. if !reflect.DeepEqual(frames[i], data) {
  256. return fmt.Errorf("frame %d did not match expected: %v", data, err)
  257. }
  258. }
  259. var data []byte
  260. if err := websocket.Message.Receive(ws, &data); err != io.EOF {
  261. return fmt.Errorf("expected no more frames: %v (%v)", err, data)
  262. }
  263. return <-errCh
  264. }