123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- /*
- Copyright 2015 The Kubernetes Authors.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package wsstream
- import (
- "bytes"
- "encoding/base64"
- "fmt"
- "io"
- "io/ioutil"
- "net/http"
- "reflect"
- "strings"
- "testing"
- "time"
- "golang.org/x/net/websocket"
- )
- func TestStream(t *testing.T) {
- input := "some random text"
- r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
- r.SetIdleTimeout(time.Second)
- data, err := readWebSocket(r, t, nil)
- if !reflect.DeepEqual(data, []byte(input)) {
- t.Errorf("unexpected server read: %v", data)
- }
- if err != nil {
- t.Fatal(err)
- }
- }
- func TestStreamPing(t *testing.T) {
- input := "some random text"
- r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
- r.SetIdleTimeout(time.Second)
- err := expectWebSocketFrames(r, t, nil, [][]byte{
- {},
- []byte(input),
- })
- if err != nil {
- t.Fatal(err)
- }
- }
- func TestStreamBase64(t *testing.T) {
- input := "some random text"
- encoded := base64.StdEncoding.EncodeToString([]byte(input))
- r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
- data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
- if !reflect.DeepEqual(data, []byte(encoded)) {
- t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
- }
- if err != nil {
- t.Fatal(err)
- }
- }
- func TestStreamVersionedBase64(t *testing.T) {
- input := "some random text"
- encoded := base64.StdEncoding.EncodeToString([]byte(input))
- r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
- "": {Binary: true},
- "binary.k8s.io": {Binary: true},
- "base64.binary.k8s.io": {Binary: false},
- "v1.binary.k8s.io": {Binary: true},
- "v1.base64.binary.k8s.io": {Binary: false},
- "v2.binary.k8s.io": {Binary: true},
- "v2.base64.binary.k8s.io": {Binary: false},
- })
- data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
- if !reflect.DeepEqual(data, []byte(encoded)) {
- t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
- }
- if err != nil {
- t.Fatal(err)
- }
- }
- func TestStreamVersionedCopy(t *testing.T) {
- for i, test := range versionTests() {
- func() {
- supportedProtocols := map[string]ReaderProtocolConfig{}
- for p, binary := range test.supported {
- supportedProtocols[p] = ReaderProtocolConfig{
- Binary: binary,
- }
- }
- input := "some random text"
- r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
- s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
- err := r.Copy(w, req)
- if err != nil {
- w.WriteHeader(503)
- }
- }))
- defer s.Close()
- config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
- if err != nil {
- t.Error(err)
- return
- }
- config.Protocol = test.requested
- client, err := websocket.DialConfig(config)
- if err != nil {
- if !test.error {
- t.Errorf("test %d: didn't expect error: %v", i, err)
- }
- return
- }
- defer client.Close()
- if test.error && err == nil {
- t.Errorf("test %d: expected an error", i)
- return
- }
- <-r.err
- if got, expected := r.selectedProtocol, test.expected; got != expected {
- t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
- }
- }()
- }
- }
- func TestStreamError(t *testing.T) {
- input := "some random text"
- errs := &errorReader{
- reads: [][]byte{
- []byte("some random"),
- []byte(" text"),
- },
- err: fmt.Errorf("bad read"),
- }
- r := NewReader(errs, false, NewDefaultReaderProtocols())
- data, err := readWebSocket(r, t, nil)
- if !reflect.DeepEqual(data, []byte(input)) {
- t.Errorf("unexpected server read: %v", data)
- }
- if err == nil || err.Error() != "bad read" {
- t.Fatal(err)
- }
- }
- func TestStreamSurvivesPanic(t *testing.T) {
- input := "some random text"
- errs := &errorReader{
- reads: [][]byte{
- []byte("some random"),
- []byte(" text"),
- },
- panicMessage: "bad read",
- }
- r := NewReader(errs, false, NewDefaultReaderProtocols())
- // do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
- r.handleCrash = func() { recover() }
- data, err := readWebSocket(r, t, nil)
- if !reflect.DeepEqual(data, []byte(input)) {
- t.Errorf("unexpected server read: %v", data)
- }
- if err != nil {
- t.Fatal(err)
- }
- }
- func TestStreamClosedDuringRead(t *testing.T) {
- for i := 0; i < 25; i++ {
- ch := make(chan struct{})
- input := "some random text"
- errs := &errorReader{
- reads: [][]byte{
- []byte("some random"),
- []byte(" text"),
- },
- err: fmt.Errorf("stuff"),
- pause: ch,
- }
- r := NewReader(errs, false, NewDefaultReaderProtocols())
- data, err := readWebSocket(r, t, func(c *websocket.Conn) {
- c.Close()
- close(ch)
- })
- // verify that the data returned by the server on an early close always has a specific error
- if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
- t.Fatal(err)
- }
- // verify that the data returned is a strict subset of the input
- if !bytes.HasPrefix([]byte(input), data) && len(data) != 0 {
- t.Fatalf("unexpected server read: %q", string(data))
- }
- }
- }
// errorReader is a scripted reader for tests. It hands out the chunks in
// reads one after another and, once they are exhausted, optionally blocks,
// then either panics or returns err.
type errorReader struct {
	// reads holds the chunks still to be returned, in order.
	reads [][]byte
	// err is returned (with n == 0) once reads is exhausted.
	err error
	// panicMessage, when non-empty, makes Read panic with this value instead
	// of returning err once reads is exhausted.
	panicMessage string
	// pause, when non-nil, is received from before the terminal error or
	// panic, which lets a test hold the reader until it sends or closes.
	pause chan struct{}
}

// Read copies the next scripted chunk into p and returns the number of bytes
// copied. If p is smaller than the chunk, only len(p) bytes are returned and
// the rest of the chunk is kept for the next call, so n never exceeds len(p)
// and no bytes are dropped. Note that this shortens the entry in r.reads in
// place. After all chunks are consumed, Read waits on pause (if set), then
// panics with panicMessage (if set) or returns 0 and err.
func (r *errorReader) Read(p []byte) (int, error) {
	if len(r.reads) == 0 {
		if r.pause != nil {
			<-r.pause
		}
		if len(r.panicMessage) != 0 {
			panic(r.panicMessage)
		}
		return 0, r.err
	}

	next := r.reads[0]
	n := copy(p, next)
	if n < len(next) {
		// p was too small: keep the uncopied tail for the following Read.
		r.reads[0] = next[n:]
		return n, nil
	}
	r.reads = r.reads[1:]
	return n, nil
}
- func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
- errCh := make(chan error, 1)
- s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
- errCh <- r.Copy(w, req)
- }))
- defer s.Close()
- config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
- config.Protocol = protocols
- client, err := websocket.DialConfig(config)
- if err != nil {
- return nil, err
- }
- defer client.Close()
- if fn != nil {
- fn(client)
- }
- data, err := ioutil.ReadAll(client)
- if err != nil {
- return data, err
- }
- return data, <-errCh
- }
- func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
- errCh := make(chan error, 1)
- s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
- errCh <- r.Copy(w, req)
- }))
- defer s.Close()
- config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
- config.Protocol = protocols
- ws, err := websocket.DialConfig(config)
- if err != nil {
- return err
- }
- defer ws.Close()
- if fn != nil {
- fn(ws)
- }
- for i := range frames {
- var data []byte
- if err := websocket.Message.Receive(ws, &data); err != nil {
- return err
- }
- if !reflect.DeepEqual(frames[i], data) {
- return fmt.Errorf("frame %d did not match expected: %v", data, err)
- }
- }
- var data []byte
- if err := websocket.Message.Receive(ws, &data); err != io.EOF {
- return fmt.Errorf("expected no more frames: %v (%v)", err, data)
- }
- return <-errCh
- }
|