ssh_test.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. /*
  2. Copyright 2015 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package ssh
  14. import (
  15. "fmt"
  16. "io"
  17. "net"
  18. "os"
  19. "reflect"
  20. "strings"
  21. "testing"
  22. "time"
  23. "k8s.io/kubernetes/pkg/util/wait"
  24. "github.com/golang/glog"
  25. "golang.org/x/crypto/ssh"
  26. )
  27. type testSSHServer struct {
  28. Host string
  29. Port string
  30. Type string
  31. Data []byte
  32. PrivateKey []byte
  33. PublicKey []byte
  34. }
  35. func runTestSSHServer(user, password string) (*testSSHServer, error) {
  36. result := &testSSHServer{}
  37. // Largely derived from https://godoc.org/golang.org/x/crypto/ssh#example-NewServerConn
  38. config := &ssh.ServerConfig{
  39. PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
  40. if c.User() == user && string(pass) == password {
  41. return nil, nil
  42. }
  43. return nil, fmt.Errorf("password rejected for %s", c.User())
  44. },
  45. PublicKeyCallback: func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
  46. result.Type = key.Type()
  47. result.Data = ssh.MarshalAuthorizedKey(key)
  48. return nil, nil
  49. },
  50. }
  51. privateKey, publicKey, err := GenerateKey(2048)
  52. if err != nil {
  53. return nil, err
  54. }
  55. privateBytes := EncodePrivateKey(privateKey)
  56. signer, err := ssh.ParsePrivateKey(privateBytes)
  57. if err != nil {
  58. return nil, err
  59. }
  60. config.AddHostKey(signer)
  61. result.PrivateKey = privateBytes
  62. publicBytes, err := EncodePublicKey(publicKey)
  63. if err != nil {
  64. return nil, err
  65. }
  66. result.PublicKey = publicBytes
  67. listener, err := net.Listen("tcp", "127.0.0.1:0")
  68. if err != nil {
  69. return nil, err
  70. }
  71. host, port, err := net.SplitHostPort(listener.Addr().String())
  72. if err != nil {
  73. return nil, err
  74. }
  75. result.Host = host
  76. result.Port = port
  77. go func() {
  78. // TODO: return this port.
  79. defer listener.Close()
  80. conn, err := listener.Accept()
  81. if err != nil {
  82. glog.Errorf("Failed to accept: %v", err)
  83. }
  84. _, chans, reqs, err := ssh.NewServerConn(conn, config)
  85. if err != nil {
  86. glog.Errorf("Failed handshake: %v", err)
  87. }
  88. go ssh.DiscardRequests(reqs)
  89. for newChannel := range chans {
  90. if newChannel.ChannelType() != "direct-tcpip" {
  91. newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", newChannel.ChannelType()))
  92. continue
  93. }
  94. channel, requests, err := newChannel.Accept()
  95. if err != nil {
  96. glog.Errorf("Failed to accept channel: %v", err)
  97. }
  98. for req := range requests {
  99. glog.Infof("Got request: %v", req)
  100. }
  101. channel.Close()
  102. }
  103. }()
  104. return result, nil
  105. }
  106. func TestSSHTunnel(t *testing.T) {
  107. private, public, err := GenerateKey(2048)
  108. if err != nil {
  109. t.Errorf("unexpected error: %v", err)
  110. t.FailNow()
  111. }
  112. server, err := runTestSSHServer("foo", "bar")
  113. if err != nil {
  114. t.Errorf("unexpected error: %v", err)
  115. t.FailNow()
  116. }
  117. privateData := EncodePrivateKey(private)
  118. tunnel, err := NewSSHTunnelFromBytes("foo", privateData, server.Host)
  119. if err != nil {
  120. t.Errorf("unexpected error: %v", err)
  121. t.FailNow()
  122. }
  123. tunnel.SSHPort = server.Port
  124. if err := tunnel.Open(); err != nil {
  125. t.Errorf("unexpected error: %v", err)
  126. t.FailNow()
  127. }
  128. _, err = tunnel.Dial("tcp", "127.0.0.1:8080")
  129. if err != nil {
  130. t.Errorf("unexpected error: %v", err)
  131. }
  132. if server.Type != "ssh-rsa" {
  133. t.Errorf("expected %s, got %s", "ssh-rsa", server.Type)
  134. }
  135. publicData, err := EncodeSSHKey(public)
  136. if err != nil {
  137. t.Errorf("unexpected error: %v", err)
  138. }
  139. if !reflect.DeepEqual(server.Data, publicData) {
  140. t.Errorf("expected %s, got %s", string(server.Data), string(privateData))
  141. }
  142. if err := tunnel.Close(); err != nil {
  143. t.Errorf("unexpected error: %v", err)
  144. }
  145. }
  146. type fakeTunnel struct{}
  147. func (*fakeTunnel) Open() error {
  148. return nil
  149. }
  150. func (*fakeTunnel) Close() error {
  151. return nil
  152. }
  153. func (*fakeTunnel) Dial(network, address string) (net.Conn, error) {
  154. return nil, nil
  155. }
  156. type fakeTunnelCreator struct{}
  157. func (*fakeTunnelCreator) NewSSHTunnel(string, string, string) (tunnel, error) {
  158. return &fakeTunnel{}, nil
  159. }
  160. func TestSSHTunnelListUpdate(t *testing.T) {
  161. // Start with an empty tunnel list.
  162. l := &SSHTunnelList{
  163. adding: make(map[string]bool),
  164. tunnelCreator: &fakeTunnelCreator{},
  165. }
  166. // Start with 2 tunnels.
  167. addressStrings := []string{"1.2.3.4", "5.6.7.8"}
  168. l.Update(addressStrings)
  169. checkTunnelsCorrect(t, l, addressStrings)
  170. // Add another tunnel.
  171. addressStrings = append(addressStrings, "9.10.11.12")
  172. l.Update(addressStrings)
  173. checkTunnelsCorrect(t, l, addressStrings)
  174. // Go down to a single tunnel.
  175. addressStrings = []string{"1.2.3.4"}
  176. l.Update(addressStrings)
  177. checkTunnelsCorrect(t, l, addressStrings)
  178. // Replace w/ all new tunnels.
  179. addressStrings = []string{"21.22.23.24", "25.26.27.28"}
  180. l.Update(addressStrings)
  181. checkTunnelsCorrect(t, l, addressStrings)
  182. // Call update with the same tunnels.
  183. l.Update(addressStrings)
  184. checkTunnelsCorrect(t, l, addressStrings)
  185. }
  186. func checkTunnelsCorrect(t *testing.T, tunnelList *SSHTunnelList, addresses []string) {
  187. if err := wait.Poll(100*time.Millisecond, 2*time.Second, func() (bool, error) {
  188. return hasCorrectTunnels(tunnelList, addresses), nil
  189. }); err != nil {
  190. t.Errorf("Error waiting for tunnels to reach expected state: %v. Expected %v, had %v", err, addresses, tunnelList)
  191. }
  192. }
  193. func hasCorrectTunnels(tunnelList *SSHTunnelList, addresses []string) bool {
  194. tunnelList.tunnelsLock.Lock()
  195. defer tunnelList.tunnelsLock.Unlock()
  196. wantMap := make(map[string]bool)
  197. for _, addr := range addresses {
  198. wantMap[addr] = true
  199. }
  200. haveMap := make(map[string]bool)
  201. for _, entry := range tunnelList.entries {
  202. if wantMap[entry.Address] == false {
  203. return false
  204. }
  205. haveMap[entry.Address] = true
  206. }
  207. for _, addr := range addresses {
  208. if haveMap[addr] == false {
  209. return false
  210. }
  211. }
  212. return true
  213. }
  214. type mockSSHDialer struct {
  215. network string
  216. addr string
  217. config *ssh.ClientConfig
  218. }
  219. func (d *mockSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
  220. d.network = network
  221. d.addr = addr
  222. d.config = config
  223. return nil, fmt.Errorf("mock error from Dial")
  224. }
  225. type mockSigner struct {
  226. }
  227. func (s *mockSigner) PublicKey() ssh.PublicKey {
  228. panic("mockSigner.PublicKey not implemented")
  229. }
  230. func (s *mockSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
  231. panic("mockSigner.Sign not implemented")
  232. }
  233. func TestSSHUser(t *testing.T) {
  234. signer := &mockSigner{}
  235. table := []struct {
  236. title string
  237. user string
  238. host string
  239. signer ssh.Signer
  240. command string
  241. expectUser string
  242. }{
  243. {
  244. title: "all values provided",
  245. user: "testuser",
  246. host: "testhost",
  247. signer: signer,
  248. command: "uptime",
  249. expectUser: "testuser",
  250. },
  251. {
  252. title: "empty user defaults to GetEnv(USER)",
  253. user: "",
  254. host: "testhost",
  255. signer: signer,
  256. command: "uptime",
  257. expectUser: os.Getenv("USER"),
  258. },
  259. }
  260. for _, item := range table {
  261. dialer := &mockSSHDialer{}
  262. _, _, _, err := runSSHCommand(dialer, item.command, item.user, item.host, item.signer, false)
  263. if err == nil {
  264. t.Errorf("expected error (as mock returns error); did not get one")
  265. }
  266. errString := err.Error()
  267. if !strings.HasPrefix(errString, fmt.Sprintf("error getting SSH client to %s@%s:", item.expectUser, item.host)) {
  268. t.Errorf("unexpected error: %v", errString)
  269. }
  270. if dialer.network != "tcp" {
  271. t.Errorf("unexpected network: %v", dialer.network)
  272. }
  273. if dialer.config.User != item.expectUser {
  274. t.Errorf("unexpected user: %v", dialer.config.User)
  275. }
  276. if len(dialer.config.Auth) != 1 {
  277. t.Errorf("unexpected auth: %v", dialer.config.Auth)
  278. }
  279. // (No way to test Auth - nothing exported?)
  280. }
  281. }
  282. type slowDialer struct {
  283. delay time.Duration
  284. err error
  285. }
  286. func (s *slowDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
  287. time.Sleep(s.delay)
  288. if s.err != nil {
  289. return nil, s.err
  290. }
  291. return &ssh.Client{}, nil
  292. }
  293. func TestTimeoutDialer(t *testing.T) {
  294. testCases := []struct {
  295. delay time.Duration
  296. timeout time.Duration
  297. err error
  298. expectedErrString string
  299. }{
  300. // delay > timeout should cause ssh.Dial to timeout.
  301. {1 * time.Second, 0, nil, "timed out dialing"},
  302. // delay < timeout should return the result of the call to the dialer.
  303. {0, 1 * time.Second, nil, ""},
  304. {0, 1 * time.Second, fmt.Errorf("test dial error"), "test dial error"},
  305. }
  306. for _, tc := range testCases {
  307. dialer := &timeoutDialer{&slowDialer{tc.delay, tc.err}, tc.timeout}
  308. _, err := dialer.Dial("tcp", "addr:port", &ssh.ClientConfig{})
  309. if len(tc.expectedErrString) == 0 && err != nil ||
  310. !strings.Contains(fmt.Sprint(err), tc.expectedErrString) {
  311. t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err)
  312. }
  313. }
  314. }