remotecommand_test.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. /*
  2. Copyright 2015 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package remotecommand
  14. import (
  15. "bytes"
  16. "errors"
  17. "fmt"
  18. "io"
  19. "io/ioutil"
  20. "net/http"
  21. "net/http/httptest"
  22. "net/url"
  23. "strings"
  24. "testing"
  25. "time"
  26. "k8s.io/kubernetes/pkg/api"
  27. "k8s.io/kubernetes/pkg/api/testapi"
  28. "k8s.io/kubernetes/pkg/api/unversioned"
  29. "k8s.io/kubernetes/pkg/client/restclient"
  30. "k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
  31. "k8s.io/kubernetes/pkg/types"
  32. "k8s.io/kubernetes/pkg/util/httpstream"
  33. "k8s.io/kubernetes/pkg/util/term"
  34. )
  // fakeExecutor is the test double handed to remotecommand.ServeExec and
// remotecommand.ServeAttach by fakeServer. It holds the canned behavior the
// test configured and records what the server passed to it.
type fakeExecutor struct {
	t        *testing.T
	testName string

	// Canned behavior configured by fakeServer.
	errorData    string // when non-empty, run fails with this text
	stdoutData   string // written messageCount times to stdout
	stderrData   string // written messageCount times to stderr
	messageCount int
	expectStdin  bool // when true, run drains stdin into stdinReceived
	exec         bool // true: expect command "ls /"; false: expect no command

	// Values written by run. tty is also seeded by fakeServer, but run
	// overwrites it with the value the server passes in.
	command       []string
	tty           bool
	stdinReceived bytes.Buffer
}
  48. func (ex *fakeExecutor) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan term.Size) error {
  49. return ex.run(name, uid, container, cmd, in, out, err, tty)
  50. }
  51. func (ex *fakeExecutor) AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan term.Size) error {
  52. return ex.run(name, uid, container, nil, in, out, err, tty)
  53. }
  54. func (ex *fakeExecutor) run(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
  55. ex.command = cmd
  56. ex.tty = tty
  57. if e, a := "pod", name; e != a {
  58. ex.t.Errorf("%s: pod: expected %q, got %q", ex.testName, e, a)
  59. }
  60. if e, a := "uid", uid; e != string(a) {
  61. ex.t.Errorf("%s: uid: expected %q, got %q", ex.testName, e, a)
  62. }
  63. if ex.exec {
  64. if e, a := "ls /", strings.Join(ex.command, " "); e != a {
  65. ex.t.Errorf("%s: command: expected %q, got %q", ex.testName, e, a)
  66. }
  67. } else {
  68. if len(ex.command) > 0 {
  69. ex.t.Errorf("%s: command: expected nothing, got %v", ex.testName, ex.command)
  70. }
  71. }
  72. if len(ex.errorData) > 0 {
  73. return errors.New(ex.errorData)
  74. }
  75. if len(ex.stdoutData) > 0 {
  76. for i := 0; i < ex.messageCount; i++ {
  77. fmt.Fprint(out, ex.stdoutData)
  78. }
  79. }
  80. if len(ex.stderrData) > 0 {
  81. for i := 0; i < ex.messageCount; i++ {
  82. fmt.Fprint(err, ex.stderrData)
  83. }
  84. }
  85. if ex.expectStdin {
  86. io.Copy(&ex.stdinReceived, in)
  87. }
  88. return nil
  89. }
  90. func fakeServer(t *testing.T, testName string, exec bool, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int, serverProtocols []string) http.HandlerFunc {
  91. return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  92. executor := &fakeExecutor{
  93. t: t,
  94. testName: testName,
  95. errorData: errorData,
  96. stdoutData: stdoutData,
  97. stderrData: stderrData,
  98. expectStdin: len(stdinData) > 0,
  99. tty: tty,
  100. messageCount: messageCount,
  101. exec: exec,
  102. }
  103. if exec {
  104. remotecommand.ServeExec(w, req, executor, "pod", "uid", "container", 0, 10*time.Second, serverProtocols)
  105. } else {
  106. remotecommand.ServeAttach(w, req, executor, "pod", "uid", "container", 0, 10*time.Second, serverProtocols)
  107. }
  108. if e, a := strings.Repeat(stdinData, messageCount), executor.stdinReceived.String(); e != a {
  109. t.Errorf("%s: stdin: expected %q, got %q", testName, e, a)
  110. }
  111. })
  112. }
  113. func TestStream(t *testing.T) {
  114. testCases := []struct {
  115. TestName string
  116. Stdin string
  117. Stdout string
  118. Stderr string
  119. Error string
  120. Tty bool
  121. MessageCount int
  122. ClientProtocols []string
  123. ServerProtocols []string
  124. }{
  125. {
  126. TestName: "error",
  127. Error: "bail",
  128. Stdout: "a",
  129. ClientProtocols: []string{remotecommand.StreamProtocolV2Name},
  130. ServerProtocols: []string{remotecommand.StreamProtocolV2Name},
  131. },
  132. {
  133. TestName: "in/out/err",
  134. Stdin: "a",
  135. Stdout: "b",
  136. Stderr: "c",
  137. MessageCount: 100,
  138. ClientProtocols: []string{remotecommand.StreamProtocolV2Name},
  139. ServerProtocols: []string{remotecommand.StreamProtocolV2Name},
  140. },
  141. {
  142. TestName: "in/out/tty",
  143. Stdin: "a",
  144. Stdout: "b",
  145. Tty: true,
  146. MessageCount: 100,
  147. ClientProtocols: []string{remotecommand.StreamProtocolV2Name},
  148. ServerProtocols: []string{remotecommand.StreamProtocolV2Name},
  149. },
  150. {
  151. // 1.0 kubectl, 1.0 kubelet
  152. TestName: "unversioned client, unversioned server",
  153. Stdout: "b",
  154. Stderr: "c",
  155. MessageCount: 1,
  156. ClientProtocols: []string{},
  157. ServerProtocols: []string{},
  158. },
  159. {
  160. // 1.0 kubectl, 1.1+ kubelet
  161. TestName: "unversioned client, versioned server",
  162. Stdout: "b",
  163. Stderr: "c",
  164. MessageCount: 1,
  165. ClientProtocols: []string{},
  166. ServerProtocols: []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name},
  167. },
  168. {
  169. // 1.1+ kubectl, 1.0 kubelet
  170. TestName: "versioned client, unversioned server",
  171. Stdout: "b",
  172. Stderr: "c",
  173. MessageCount: 1,
  174. ClientProtocols: []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name},
  175. ServerProtocols: []string{},
  176. },
  177. }
  178. for _, testCase := range testCases {
  179. for _, exec := range []bool{true, false} {
  180. var name string
  181. if exec {
  182. name = testCase.TestName + " (exec)"
  183. } else {
  184. name = testCase.TestName + " (attach)"
  185. }
  186. var (
  187. streamIn io.Reader
  188. streamOut, streamErr io.Writer
  189. )
  190. localOut := &bytes.Buffer{}
  191. localErr := &bytes.Buffer{}
  192. server := httptest.NewServer(fakeServer(t, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))
  193. url, _ := url.ParseRequestURI(server.URL)
  194. config := restclient.ContentConfig{
  195. GroupVersion: &unversioned.GroupVersion{Group: "x"},
  196. NegotiatedSerializer: testapi.Default.NegotiatedSerializer(),
  197. }
  198. c, err := restclient.NewRESTClient(url, "", config, -1, -1, nil, nil)
  199. if err != nil {
  200. t.Fatalf("failed to create a client: %v", err)
  201. }
  202. req := c.Post().Resource("testing")
  203. if exec {
  204. req.Param("command", "ls")
  205. req.Param("command", "/")
  206. }
  207. if len(testCase.Stdin) > 0 {
  208. req.Param(api.ExecStdinParam, "1")
  209. streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
  210. }
  211. if len(testCase.Stdout) > 0 {
  212. req.Param(api.ExecStdoutParam, "1")
  213. streamOut = localOut
  214. }
  215. if testCase.Tty {
  216. req.Param(api.ExecTTYParam, "1")
  217. } else if len(testCase.Stderr) > 0 {
  218. req.Param(api.ExecStderrParam, "1")
  219. streamErr = localErr
  220. }
  221. conf := &restclient.Config{
  222. Host: server.URL,
  223. }
  224. e, err := NewExecutor(conf, "POST", req.URL())
  225. if err != nil {
  226. t.Errorf("%s: unexpected error: %v", name, err)
  227. continue
  228. }
  229. err = e.Stream(StreamOptions{
  230. SupportedProtocols: testCase.ClientProtocols,
  231. Stdin: streamIn,
  232. Stdout: streamOut,
  233. Stderr: streamErr,
  234. Tty: testCase.Tty,
  235. })
  236. hasErr := err != nil
  237. if len(testCase.Error) > 0 {
  238. if !hasErr {
  239. t.Errorf("%s: expected an error", name)
  240. } else {
  241. if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
  242. t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
  243. }
  244. }
  245. server.Close()
  246. continue
  247. }
  248. if hasErr {
  249. t.Errorf("%s: unexpected error: %v", name, err)
  250. server.Close()
  251. continue
  252. }
  253. if len(testCase.Stdout) > 0 {
  254. if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
  255. t.Errorf("%s: expected stdout data %q, got %q", name, e, a)
  256. }
  257. }
  258. if testCase.Stderr != "" {
  259. if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
  260. t.Errorf("%s: expected stderr data %q, got %q", name, e, a)
  261. }
  262. }
  263. server.Close()
  264. }
  265. }
  266. }
  267. type fakeUpgrader struct {
  268. req *http.Request
  269. resp *http.Response
  270. conn httpstream.Connection
  271. err, connErr error
  272. checkResponse bool
  273. t *testing.T
  274. }
  275. func (u *fakeUpgrader) RoundTrip(req *http.Request) (*http.Response, error) {
  276. u.req = req
  277. return u.resp, u.err
  278. }
  279. func (u *fakeUpgrader) NewConnection(resp *http.Response) (httpstream.Connection, error) {
  280. if u.checkResponse && u.resp != resp {
  281. u.t.Errorf("response objects passed did not match: %#v", resp)
  282. }
  283. return u.conn, u.connErr
  284. }
  285. type fakeConnection struct {
  286. httpstream.Connection
  287. }
  288. // Dial is the common functionality between any stream based upgrader, regardless of protocol.
  289. // This method ensures that someone can use a generic stream executor without being dependent
  290. // on the core Kube client config behavior.
  291. func TestDial(t *testing.T) {
  292. upgrader := &fakeUpgrader{
  293. t: t,
  294. checkResponse: true,
  295. conn: &fakeConnection{},
  296. resp: &http.Response{
  297. StatusCode: http.StatusSwitchingProtocols,
  298. Body: ioutil.NopCloser(&bytes.Buffer{}),
  299. },
  300. }
  301. var called bool
  302. testFn := func(rt http.RoundTripper) http.RoundTripper {
  303. if rt != upgrader {
  304. t.Fatalf("unexpected round tripper: %#v", rt)
  305. }
  306. called = true
  307. return rt
  308. }
  309. exec, err := NewStreamExecutor(upgrader, testFn, "POST", &url.URL{Host: "something.com", Scheme: "https"})
  310. if err != nil {
  311. t.Fatal(err)
  312. }
  313. conn, protocol, err := exec.Dial("protocol1")
  314. if err != nil {
  315. t.Fatal(err)
  316. }
  317. if conn != upgrader.conn {
  318. t.Errorf("unexpected connection: %#v", conn)
  319. }
  320. if !called {
  321. t.Errorf("wrapper not called")
  322. }
  323. _ = protocol
  324. }