shell.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. // Copyright (c) 2017 Gorillalabs. All rights reserved.
  2. package powershell
  3. import (
  4. "fmt"
  5. "io"
  6. "strings"
  7. "sync"
  8. "github.com/bhendo/go-powershell/backend"
  9. "github.com/bhendo/go-powershell/utils"
  10. "github.com/juju/errors"
  11. )
  12. const newline = "\r\n"
  13. type Shell interface {
  14. Execute(cmd string) (string, string, error)
  15. Exit()
  16. }
  17. type shell struct {
  18. handle backend.Waiter
  19. stdin io.Writer
  20. stdout io.Reader
  21. stderr io.Reader
  22. }
  23. func New(backend backend.Starter) (Shell, error) {
  24. handle, stdin, stdout, stderr, err := backend.StartProcess("powershell.exe", "-NoExit", "-Command", "-")
  25. if err != nil {
  26. return nil, err
  27. }
  28. return &shell{handle, stdin, stdout, stderr}, nil
  29. }
  30. func (s *shell) Execute(cmd string) (string, string, error) {
  31. if s.handle == nil {
  32. return "", "", errors.Annotate(errors.New(cmd), "Cannot execute commands on closed shells.")
  33. }
  34. outBoundary := createBoundary()
  35. errBoundary := createBoundary()
  36. // wrap the command in special markers so we know when to stop reading from the pipes
  37. full := fmt.Sprintf("%s; echo '%s'; [Console]::Error.WriteLine('%s')%s", cmd, outBoundary, errBoundary, newline)
  38. _, err := s.stdin.Write([]byte(full))
  39. if err != nil {
  40. return "", "", errors.Annotate(errors.Annotate(err, cmd), "Could not send PowerShell command")
  41. }
  42. // read stdout and stderr
  43. sout := ""
  44. serr := ""
  45. waiter := &sync.WaitGroup{}
  46. waiter.Add(2)
  47. go streamReader(s.stdout, outBoundary, &sout, waiter)
  48. go streamReader(s.stderr, errBoundary, &serr, waiter)
  49. waiter.Wait()
  50. if len(serr) > 0 {
  51. return sout, serr, errors.Annotate(errors.New(cmd), serr)
  52. }
  53. return sout, serr, nil
  54. }
  55. func (s *shell) Exit() {
  56. s.stdin.Write([]byte("exit" + newline))
  57. // if it's possible to close stdin, do so (some backends, like the local one,
  58. // do support it)
  59. closer, ok := s.stdin.(io.Closer)
  60. if ok {
  61. closer.Close()
  62. }
  63. s.handle.Wait()
  64. s.handle = nil
  65. s.stdin = nil
  66. s.stdout = nil
  67. s.stderr = nil
  68. }
  69. func streamReader(stream io.Reader, boundary string, buffer *string, signal *sync.WaitGroup) error {
  70. // read all output until we have found our boundary token
  71. output := ""
  72. bufsize := 64
  73. marker := boundary + newline
  74. for {
  75. buf := make([]byte, bufsize)
  76. read, err := stream.Read(buf)
  77. if err != nil {
  78. return err
  79. }
  80. output = output + string(buf[:read])
  81. if strings.HasSuffix(output, marker) {
  82. break
  83. }
  84. }
  85. *buffer = strings.TrimSuffix(output, marker)
  86. signal.Done()
  87. return nil
  88. }
  89. func createBoundary() string {
  90. return "$gorilla" + utils.CreateRandomString(12) + "$"
  91. }