ssh.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. // Copyright (c) 2017 Gorillalabs. All rights reserved.
  2. package backend
  3. import (
  4. "fmt"
  5. "io"
  6. "regexp"
  7. "strings"
  8. "github.com/juju/errors"
  9. )
  10. // sshSession exists so we don't create a hard dependency on crypto/ssh.
  11. type sshSession interface {
  12. Waiter
  13. StdinPipe() (io.WriteCloser, error)
  14. StdoutPipe() (io.Reader, error)
  15. StderrPipe() (io.Reader, error)
  16. Start(string) error
  17. }
  18. type SSH struct {
  19. Session sshSession
  20. }
  21. func (b *SSH) StartProcess(cmd string, args ...string) (Waiter, io.Writer, io.Reader, io.Reader, error) {
  22. stdin, err := b.Session.StdinPipe()
  23. if err != nil {
  24. return nil, nil, nil, nil, errors.Annotate(err, "Could not get hold of the SSH session's stdin stream")
  25. }
  26. stdout, err := b.Session.StdoutPipe()
  27. if err != nil {
  28. return nil, nil, nil, nil, errors.Annotate(err, "Could not get hold of the SSH session's stdout stream")
  29. }
  30. stderr, err := b.Session.StderrPipe()
  31. if err != nil {
  32. return nil, nil, nil, nil, errors.Annotate(err, "Could not get hold of the SSH session's stderr stream")
  33. }
  34. err = b.Session.Start(b.createCmd(cmd, args))
  35. if err != nil {
  36. return nil, nil, nil, nil, errors.Annotate(err, "Could not spawn process via SSH")
  37. }
  38. return b.Session, stdin, stdout, stderr, nil
  39. }
  40. func (b *SSH) createCmd(cmd string, args []string) string {
  41. parts := []string{cmd}
  42. simple := regexp.MustCompile(`^[a-z0-9_/.~+-]+$`)
  43. for _, arg := range args {
  44. if !simple.MatchString(arg) {
  45. arg = b.quote(arg)
  46. }
  47. parts = append(parts, arg)
  48. }
  49. return strings.Join(parts, " ")
  50. }
  51. func (b *SSH) quote(s string) string {
  52. return fmt.Sprintf(`"%s"`, s)
  53. }