ssh.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. /*
  2. Copyright 2015 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package ssh
  14. import (
  15. "bytes"
  16. "crypto/rand"
  17. "crypto/rsa"
  18. "crypto/tls"
  19. "crypto/x509"
  20. "encoding/pem"
  21. "errors"
  22. "fmt"
  23. "io"
  24. "io/ioutil"
  25. mathrand "math/rand"
  26. "net"
  27. "net/http"
  28. "net/url"
  29. "os"
  30. "sync"
  31. "time"
  32. "github.com/golang/glog"
  33. "github.com/prometheus/client_golang/prometheus"
  34. "golang.org/x/crypto/ssh"
  35. utilnet "k8s.io/kubernetes/pkg/util/net"
  36. "k8s.io/kubernetes/pkg/util/runtime"
  37. "k8s.io/kubernetes/pkg/util/wait"
  38. )
  39. var (
  40. tunnelOpenCounter = prometheus.NewCounter(
  41. prometheus.CounterOpts{
  42. Name: "ssh_tunnel_open_count",
  43. Help: "Counter of ssh tunnel total open attempts",
  44. },
  45. )
  46. tunnelOpenFailCounter = prometheus.NewCounter(
  47. prometheus.CounterOpts{
  48. Name: "ssh_tunnel_open_fail_count",
  49. Help: "Counter of ssh tunnel failed open attempts",
  50. },
  51. )
  52. )
  53. func init() {
  54. prometheus.MustRegister(tunnelOpenCounter)
  55. prometheus.MustRegister(tunnelOpenFailCounter)
  56. }
// TODO: Unit tests for this code, we can spin up a test SSH server with instructions here:
// https://godoc.org/golang.org/x/crypto/ssh#ServerConn

// SSHTunnel is an SSH connection to Host:SSHPort through which TCP
// connections can be dialed. Open must succeed before Dial or Close can be
// used; both return an error while the client is unset.
type SSHTunnel struct {
	// Config is the SSH client configuration (user and auth methods) passed
	// to the dialer by Open.
	Config *ssh.ClientConfig
	// Host is the remote host to SSH into.
	Host string
	// SSHPort is the remote SSH port; makeSSHTunnel sets it to "22".
	SSHPort string
	// NOTE(review): running and sock are not referenced anywhere in this
	// file; confirm whether other files in the package still use them.
	running bool
	sock    net.Listener
	// client is the SSH connection stored by Open; nil before Open is called.
	client *ssh.Client
}
  67. func (s *SSHTunnel) copyBytes(out io.Writer, in io.Reader) {
  68. if _, err := io.Copy(out, in); err != nil {
  69. glog.Errorf("Error in SSH tunnel: %v", err)
  70. }
  71. }
  72. func NewSSHTunnel(user, keyfile, host string) (*SSHTunnel, error) {
  73. signer, err := MakePrivateKeySignerFromFile(keyfile)
  74. if err != nil {
  75. return nil, err
  76. }
  77. return makeSSHTunnel(user, signer, host)
  78. }
  79. func NewSSHTunnelFromBytes(user string, privateKey []byte, host string) (*SSHTunnel, error) {
  80. signer, err := MakePrivateKeySignerFromBytes(privateKey)
  81. if err != nil {
  82. return nil, err
  83. }
  84. return makeSSHTunnel(user, signer, host)
  85. }
  86. func makeSSHTunnel(user string, signer ssh.Signer, host string) (*SSHTunnel, error) {
  87. config := ssh.ClientConfig{
  88. User: user,
  89. Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
  90. }
  91. return &SSHTunnel{
  92. Config: &config,
  93. Host: host,
  94. SSHPort: "22",
  95. }, nil
  96. }
  97. func (s *SSHTunnel) Open() error {
  98. var err error
  99. s.client, err = realTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
  100. tunnelOpenCounter.Inc()
  101. if err != nil {
  102. tunnelOpenFailCounter.Inc()
  103. }
  104. return err
  105. }
  106. func (s *SSHTunnel) Dial(network, address string) (net.Conn, error) {
  107. if s.client == nil {
  108. return nil, errors.New("tunnel is not opened.")
  109. }
  110. return s.client.Dial(network, address)
  111. }
  112. func (s *SSHTunnel) tunnel(conn net.Conn, remoteHost, remotePort string) error {
  113. if s.client == nil {
  114. return errors.New("tunnel is not opened.")
  115. }
  116. tunnel, err := s.client.Dial("tcp", net.JoinHostPort(remoteHost, remotePort))
  117. if err != nil {
  118. return err
  119. }
  120. go s.copyBytes(tunnel, conn)
  121. go s.copyBytes(conn, tunnel)
  122. return nil
  123. }
  124. func (s *SSHTunnel) Close() error {
  125. if s.client == nil {
  126. return errors.New("Cannot close tunnel. Tunnel was not opened.")
  127. }
  128. if err := s.client.Close(); err != nil {
  129. return err
  130. }
  131. return nil
  132. }
// sshDialer abstracts ssh.Dial so that a fake can be substituted when testing
// SSH code. realSSHDialer is the implementation that calls ssh.Dial.
type sshDialer interface {
	// Dial connects to addr over network and returns an SSH client created
	// with config.
	Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error)
}
  137. // Real implementation of sshDialer
  138. type realSSHDialer struct{}
  139. var _ sshDialer = &realSSHDialer{}
  140. func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
  141. return ssh.Dial(network, addr, config)
  142. }
  143. // timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang
  144. // ssh library can hang indefinitely inside the Dial() call (see issue #23835).
  145. // Wrapping all Dial() calls with a conservative timeout provides safety against
  146. // getting stuck on that.
  147. type timeoutDialer struct {
  148. dialer sshDialer
  149. timeout time.Duration
  150. }
  151. // 150 seconds is longer than the underlying default TCP backoff delay (127
  152. // seconds). This timeout is only intended to catch otherwise uncaught hangs.
  153. const sshDialTimeout = 150 * time.Second
  154. var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout}
  155. func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
  156. var client *ssh.Client
  157. errCh := make(chan error, 1)
  158. go func() {
  159. defer runtime.HandleCrash()
  160. var err error
  161. client, err = d.dialer.Dial(network, addr, config)
  162. errCh <- err
  163. }()
  164. select {
  165. case err := <-errCh:
  166. return client, err
  167. case <-time.After(d.timeout):
  168. return nil, fmt.Errorf("timed out dialing %s:%s", network, addr)
  169. }
  170. }
  171. // RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
  172. // host as specific user, along with any SSH-level error.
  173. // If user=="", it will default (like SSH) to os.Getenv("USER")
  174. func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
  175. return runSSHCommand(realTimeoutDialer, cmd, user, host, signer, true)
  176. }
  177. // Internal implementation of runSSHCommand, for testing
  178. func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer, retry bool) (string, string, int, error) {
  179. if user == "" {
  180. user = os.Getenv("USER")
  181. }
  182. // Setup the config, dial the server, and open a session.
  183. config := &ssh.ClientConfig{
  184. User: user,
  185. Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
  186. }
  187. client, err := dialer.Dial("tcp", host, config)
  188. if err != nil && retry {
  189. err = wait.Poll(5*time.Second, 20*time.Second, func() (bool, error) {
  190. fmt.Printf("error dialing %s@%s: '%v', retrying\n", user, host, err)
  191. if client, err = dialer.Dial("tcp", host, config); err != nil {
  192. return false, nil
  193. }
  194. return true, nil
  195. })
  196. }
  197. if err != nil {
  198. return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", user, host, err)
  199. }
  200. session, err := client.NewSession()
  201. if err != nil {
  202. return "", "", 0, fmt.Errorf("error creating session to %s@%s: '%v'", user, host, err)
  203. }
  204. defer session.Close()
  205. // Run the command.
  206. code := 0
  207. var bout, berr bytes.Buffer
  208. session.Stdout, session.Stderr = &bout, &berr
  209. if err = session.Run(cmd); err != nil {
  210. // Check whether the command failed to run or didn't complete.
  211. if exiterr, ok := err.(*ssh.ExitError); ok {
  212. // If we got an ExitError and the exit code is nonzero, we'll
  213. // consider the SSH itself successful (just that the command run
  214. // errored on the host).
  215. if code = exiterr.ExitStatus(); code != 0 {
  216. err = nil
  217. }
  218. } else {
  219. // Some other kind of error happened (e.g. an IOError); consider the
  220. // SSH unsuccessful.
  221. err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, user, host, err)
  222. }
  223. }
  224. return bout.String(), berr.String(), code, err
  225. }
  226. func MakePrivateKeySignerFromFile(key string) (ssh.Signer, error) {
  227. // Create an actual signer.
  228. buffer, err := ioutil.ReadFile(key)
  229. if err != nil {
  230. return nil, fmt.Errorf("error reading SSH key %s: '%v'", key, err)
  231. }
  232. return MakePrivateKeySignerFromBytes(buffer)
  233. }
  234. func MakePrivateKeySignerFromBytes(buffer []byte) (ssh.Signer, error) {
  235. signer, err := ssh.ParsePrivateKey(buffer)
  236. if err != nil {
  237. return nil, fmt.Errorf("error parsing SSH key: '%v'", err)
  238. }
  239. return signer, nil
  240. }
  241. func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) {
  242. buffer, err := ioutil.ReadFile(keyFile)
  243. if err != nil {
  244. return nil, fmt.Errorf("error reading SSH key %s: '%v'", keyFile, err)
  245. }
  246. keyBlock, _ := pem.Decode(buffer)
  247. key, err := x509.ParsePKIXPublicKey(keyBlock.Bytes)
  248. if err != nil {
  249. return nil, fmt.Errorf("error parsing SSH key %s: '%v'", keyFile, err)
  250. }
  251. rsaKey, ok := key.(*rsa.PublicKey)
  252. if !ok {
  253. return nil, fmt.Errorf("SSH key could not be parsed as rsa public key")
  254. }
  255. return rsaKey, nil
  256. }
// tunnel is the interface SSHTunnelList uses for each of its tunnels.
// *SSHTunnel satisfies it (realTunnelCreator returns one from NewSSHTunnel).
type tunnel interface {
	// Open establishes the tunnel's connection.
	Open() error
	// Close tears the tunnel down.
	Close() error
	// Dial opens a connection to address through the tunnel.
	Dial(network, address string) (net.Conn, error)
}
// sshTunnelEntry pairs an opened tunnel with the address it was created for.
// NOTE: constructed with an unkeyed literal in createAndAddTunnel, so field
// order matters.
type sshTunnelEntry struct {
	// Address is the node address that was passed to the tunnel creator.
	Address string
	// Tunnel is the tunnel to Address; added to the list only after Open succeeded.
	Tunnel tunnel
}
  266. type sshTunnelCreator interface {
  267. NewSSHTunnel(user, keyFile, healthCheckURL string) (tunnel, error)
  268. }
  269. type realTunnelCreator struct{}
  270. func (*realTunnelCreator) NewSSHTunnel(user, keyFile, healthCheckURL string) (tunnel, error) {
  271. return NewSSHTunnel(user, keyFile, healthCheckURL)
  272. }
// SSHTunnelList maintains a set of SSH tunnels, at most one per address given
// to Update. It periodically health-checks them and serves Dial through a
// randomly chosen tunnel. Create it with NewSSHTunnelList.
type SSHTunnelList struct {
	// entries are the successfully opened tunnels; guarded by tunnelsLock.
	entries []sshTunnelEntry
	// adding records addresses whose tunnels are currently being created, so
	// Update does not start duplicate creations; guarded by tunnelsLock.
	adding map[string]bool
	// tunnelCreator builds new tunnels; NewSSHTunnelList sets realTunnelCreator.
	tunnelCreator sshTunnelCreator
	// tunnelsLock guards entries and adding.
	tunnelsLock sync.Mutex
	// user and keyfile are the SSH credentials passed to tunnelCreator for
	// every tunnel.
	user    string
	keyfile string
	// healthCheckURL is fetched through each tunnel to verify it still works.
	healthCheckURL *url.URL
}
  282. func NewSSHTunnelList(user, keyfile string, healthCheckURL *url.URL, stopChan chan struct{}) *SSHTunnelList {
  283. l := &SSHTunnelList{
  284. adding: make(map[string]bool),
  285. tunnelCreator: &realTunnelCreator{},
  286. user: user,
  287. keyfile: keyfile,
  288. healthCheckURL: healthCheckURL,
  289. }
  290. healthCheckPoll := 1 * time.Minute
  291. go wait.Until(func() {
  292. l.tunnelsLock.Lock()
  293. defer l.tunnelsLock.Unlock()
  294. // Healthcheck each tunnel every minute
  295. numTunnels := len(l.entries)
  296. for i, entry := range l.entries {
  297. // Stagger healthchecks evenly across duration of healthCheckPoll.
  298. delay := healthCheckPoll * time.Duration(i) / time.Duration(numTunnels)
  299. l.delayedHealthCheck(entry, delay)
  300. }
  301. }, healthCheckPoll, stopChan)
  302. return l
  303. }
  304. func (l *SSHTunnelList) delayedHealthCheck(e sshTunnelEntry, delay time.Duration) {
  305. go func() {
  306. defer runtime.HandleCrash()
  307. time.Sleep(delay)
  308. if err := l.healthCheck(e); err != nil {
  309. glog.Errorf("Healthcheck failed for tunnel to %q: %v", e.Address, err)
  310. glog.Infof("Attempting once to re-establish tunnel to %q", e.Address)
  311. l.removeAndReAdd(e)
  312. }
  313. }()
  314. }
  315. func (l *SSHTunnelList) healthCheck(e sshTunnelEntry) error {
  316. // GET the healthcheck path using the provided tunnel's dial function.
  317. transport := utilnet.SetTransportDefaults(&http.Transport{
  318. Dial: e.Tunnel.Dial,
  319. // TODO(cjcullen): Plumb real TLS options through.
  320. TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
  321. })
  322. client := &http.Client{Transport: transport}
  323. resp, err := client.Get(l.healthCheckURL.String())
  324. if err != nil {
  325. return err
  326. }
  327. resp.Body.Close()
  328. return nil
  329. }
  330. func (l *SSHTunnelList) removeAndReAdd(e sshTunnelEntry) {
  331. // Find the entry to replace.
  332. l.tunnelsLock.Lock()
  333. defer l.tunnelsLock.Unlock()
  334. for i, entry := range l.entries {
  335. if entry.Tunnel == e.Tunnel {
  336. l.entries = append(l.entries[:i], l.entries[i+1:]...)
  337. l.adding[e.Address] = true
  338. go l.createAndAddTunnel(e.Address)
  339. return
  340. }
  341. }
  342. }
  343. func (l *SSHTunnelList) Dial(net, addr string) (net.Conn, error) {
  344. start := time.Now()
  345. id := mathrand.Int63() // So you can match begins/ends in the log.
  346. glog.Infof("[%x: %v] Dialing...", id, addr)
  347. defer func() {
  348. glog.Infof("[%x: %v] Dialed in %v.", id, addr, time.Now().Sub(start))
  349. }()
  350. tunnel, err := l.pickRandomTunnel()
  351. if err != nil {
  352. return nil, err
  353. }
  354. return tunnel.Dial(net, addr)
  355. }
  356. func (l *SSHTunnelList) pickRandomTunnel() (tunnel, error) {
  357. l.tunnelsLock.Lock()
  358. defer l.tunnelsLock.Unlock()
  359. if len(l.entries) == 0 {
  360. return nil, fmt.Errorf("No SSH tunnels currently open. Were the targets able to accept an ssh-key for user %q?", l.user)
  361. }
  362. n := mathrand.Intn(len(l.entries))
  363. return l.entries[n].Tunnel, nil
  364. }
  365. // Update reconciles the list's entries with the specified addresses. Existing
  366. // tunnels that are not in addresses are removed from entries and closed in a
  367. // background goroutine. New tunnels specified in addresses are opened in a
  368. // background goroutine and then added to entries.
  369. func (l *SSHTunnelList) Update(addrs []string) {
  370. haveAddrsMap := make(map[string]bool)
  371. wantAddrsMap := make(map[string]bool)
  372. func() {
  373. l.tunnelsLock.Lock()
  374. defer l.tunnelsLock.Unlock()
  375. // Build a map of what we currently have.
  376. for i := range l.entries {
  377. haveAddrsMap[l.entries[i].Address] = true
  378. }
  379. // Determine any necessary additions.
  380. for i := range addrs {
  381. // Add tunnel if it is not in l.entries or l.adding
  382. if _, ok := haveAddrsMap[addrs[i]]; !ok {
  383. if _, ok := l.adding[addrs[i]]; !ok {
  384. l.adding[addrs[i]] = true
  385. addr := addrs[i]
  386. go func() {
  387. defer runtime.HandleCrash()
  388. // Actually adding tunnel to list will block until lock
  389. // is released after deletions.
  390. l.createAndAddTunnel(addr)
  391. }()
  392. }
  393. }
  394. wantAddrsMap[addrs[i]] = true
  395. }
  396. // Determine any necessary deletions.
  397. var newEntries []sshTunnelEntry
  398. for i := range l.entries {
  399. if _, ok := wantAddrsMap[l.entries[i].Address]; !ok {
  400. tunnelEntry := l.entries[i]
  401. glog.Infof("Removing tunnel to deleted node at %q", tunnelEntry.Address)
  402. go func() {
  403. defer runtime.HandleCrash()
  404. if err := tunnelEntry.Tunnel.Close(); err != nil {
  405. glog.Errorf("Failed to close tunnel to %q: %v", tunnelEntry.Address, err)
  406. }
  407. }()
  408. } else {
  409. newEntries = append(newEntries, l.entries[i])
  410. }
  411. }
  412. l.entries = newEntries
  413. }()
  414. }
  415. func (l *SSHTunnelList) createAndAddTunnel(addr string) {
  416. glog.Infof("Trying to add tunnel to %q", addr)
  417. tunnel, err := l.tunnelCreator.NewSSHTunnel(l.user, l.keyfile, addr)
  418. if err != nil {
  419. glog.Errorf("Failed to create tunnel for %q: %v", addr, err)
  420. return
  421. }
  422. if err := tunnel.Open(); err != nil {
  423. glog.Errorf("Failed to open tunnel to %q: %v", addr, err)
  424. l.tunnelsLock.Lock()
  425. delete(l.adding, addr)
  426. l.tunnelsLock.Unlock()
  427. return
  428. }
  429. l.tunnelsLock.Lock()
  430. l.entries = append(l.entries, sshTunnelEntry{addr, tunnel})
  431. delete(l.adding, addr)
  432. l.tunnelsLock.Unlock()
  433. glog.Infof("Successfully added tunnel for %q", addr)
  434. }
  435. func EncodePrivateKey(private *rsa.PrivateKey) []byte {
  436. return pem.EncodeToMemory(&pem.Block{
  437. Bytes: x509.MarshalPKCS1PrivateKey(private),
  438. Type: "RSA PRIVATE KEY",
  439. })
  440. }
  441. func EncodePublicKey(public *rsa.PublicKey) ([]byte, error) {
  442. publicBytes, err := x509.MarshalPKIXPublicKey(public)
  443. if err != nil {
  444. return nil, err
  445. }
  446. return pem.EncodeToMemory(&pem.Block{
  447. Bytes: publicBytes,
  448. Type: "PUBLIC KEY",
  449. }), nil
  450. }
  451. func EncodeSSHKey(public *rsa.PublicKey) ([]byte, error) {
  452. publicKey, err := ssh.NewPublicKey(public)
  453. if err != nil {
  454. return nil, err
  455. }
  456. return ssh.MarshalAuthorizedKey(publicKey), nil
  457. }
  458. func GenerateKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
  459. private, err := rsa.GenerateKey(rand.Reader, bits)
  460. if err != nil {
  461. return nil, nil, err
  462. }
  463. return private, &private.PublicKey, nil
  464. }