123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- /*
- Copyright 2015 The Kubernetes Authors.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package wsstream
- import (
- "encoding/base64"
- "io"
- "io/ioutil"
- "net/http"
- "net/http/httptest"
- "reflect"
- "sync"
- "testing"
- "golang.org/x/net/websocket"
- )
- func newServer(handler http.Handler) (*httptest.Server, string) {
- server := httptest.NewServer(handler)
- serverAddr := server.Listener.Addr().String()
- return server, serverAddr
- }
- func TestRawConn(t *testing.T) {
- channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
- conn := NewConn(NewDefaultChannelProtocols(channels))
- s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
- conn.Open(w, req)
- }))
- defer s.Close()
- client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
- if err != nil {
- t.Fatal(err)
- }
- defer client.Close()
- <-conn.ready
- wg := sync.WaitGroup{}
- // verify we can read a client write
- wg.Add(1)
- go func() {
- defer wg.Done()
- data, err := ioutil.ReadAll(conn.channels[0])
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(data, []byte("client")) {
- t.Errorf("unexpected server read: %v", data)
- }
- }()
- if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 {
- t.Fatalf("%d: %v", n, err)
- }
- // verify we can read a server write
- wg.Add(1)
- go func() {
- defer wg.Done()
- if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
- t.Fatalf("%d: %v", n, err)
- }
- }()
- data := make([]byte, 1024)
- if n, err := io.ReadAtLeast(client, data, 6); n != 7 || err != nil {
- t.Fatalf("%d: %v", n, err)
- }
- if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) {
- t.Errorf("unexpected client read: %v", data[:7])
- }
- // verify that an ignore channel is empty in both directions.
- if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil {
- t.Errorf("writes should be ignored")
- }
- data = make([]byte, 1024)
- if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF {
- t.Errorf("reads should be ignored")
- }
- // verify that a write to a Read channel doesn't block
- if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil {
- t.Errorf("writes should be ignored")
- }
- // verify that a read from a Write channel doesn't block
- data = make([]byte, 1024)
- if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF {
- t.Errorf("reads should be ignored")
- }
- // verify that a client write to a Write channel doesn't block (is dropped)
- if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 {
- t.Fatalf("%d: %v", n, err)
- }
- client.Close()
- wg.Wait()
- }
- func TestBase64Conn(t *testing.T) {
- conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
- s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
- conn.Open(w, req)
- }))
- defer s.Close()
- config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
- if err != nil {
- t.Fatal(err)
- }
- config.Protocol = []string{"base64.channel.k8s.io"}
- client, err := websocket.DialConfig(config)
- if err != nil {
- t.Fatal(err)
- }
- defer client.Close()
- <-conn.ready
- wg := sync.WaitGroup{}
- wg.Add(1)
- go func() {
- defer wg.Done()
- data, err := ioutil.ReadAll(conn.channels[0])
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(data, []byte("client")) {
- t.Errorf("unexpected server read: %s", string(data))
- }
- }()
- clientData := base64.StdEncoding.EncodeToString([]byte("client"))
- if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 {
- t.Fatalf("%d: %v", n, err)
- }
- wg.Add(1)
- go func() {
- defer wg.Done()
- if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
- t.Fatalf("%d: %v", n, err)
- }
- }()
- data := make([]byte, 1024)
- if n, err := io.ReadAtLeast(client, data, 9); n != 9 || err != nil {
- t.Fatalf("%d: %v", n, err)
- }
- expect := []byte(base64.StdEncoding.EncodeToString([]byte("server")))
- if !reflect.DeepEqual(data[:9], append([]byte{'1'}, expect...)) {
- t.Errorf("unexpected client read: %v", data[:9])
- }
- client.Close()
- wg.Wait()
- }
// versionTest describes a single protocol-negotiation case for
// TestVersionedConn.
type versionTest struct {
	// supported maps each protocol name the server accepts to whether that
	// protocol is binary (true) or base64 (false).
	supported map[string]bool
	// requested lists the protocols the client offers during the handshake.
	requested []string
	// error is true when the client handshake is expected to fail.
	error bool
	// expected is the protocol the server should select when error is false.
	expected string
}
- func versionTests() []versionTest {
- const (
- binary = true
- base64 = false
- )
- return []versionTest{
- {
- supported: nil,
- requested: []string{"raw"},
- error: true,
- },
- {
- supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
- requested: nil,
- expected: "",
- },
- {
- supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
- requested: []string{"v1.raw"},
- error: true,
- },
- {
- supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
- requested: []string{"v1.raw", "v1.base64"},
- error: true,
- }, {
- supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
- requested: []string{"v1.raw", "raw"},
- expected: "raw",
- },
- {
- supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
- requested: []string{"v1.raw"},
- expected: "v1.raw",
- },
- {
- supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
- requested: []string{"v2.base64"},
- expected: "v2.base64",
- },
- }
- }
- func TestVersionedConn(t *testing.T) {
- for i, test := range versionTests() {
- func() {
- supportedProtocols := map[string]ChannelProtocolConfig{}
- for p, binary := range test.supported {
- supportedProtocols[p] = ChannelProtocolConfig{
- Binary: binary,
- Channels: []ChannelType{ReadWriteChannel},
- }
- }
- conn := NewConn(supportedProtocols)
- // note that it's not enough to wait for conn.ready to avoid a race here. Hence,
- // we use a channel.
- selectedProtocol := make(chan string, 0)
- s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
- p, _, _ := conn.Open(w, req)
- selectedProtocol <- p
- }))
- defer s.Close()
- config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
- if err != nil {
- t.Fatal(err)
- }
- config.Protocol = test.requested
- client, err := websocket.DialConfig(config)
- if err != nil {
- if !test.error {
- t.Fatalf("test %d: didn't expect error: %v", i, err)
- } else {
- return
- }
- }
- defer client.Close()
- if test.error && err == nil {
- t.Fatalf("test %d: expected an error", i)
- }
- <-conn.ready
- if got, expected := <-selectedProtocol, test.expected; got != expected {
- t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
- }
- }()
- }
- }
|