client.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. package cli
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/peterh/liner"
  7. "io"
  8. "math"
  9. "net"
  10. "os"
  11. "path/filepath"
  12. "strings"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. )
  17. type Client struct {
  18. name string
  19. ctx context.Context
  20. address string
  21. sequence uint16
  22. conn net.Conn
  23. liner *liner.State
  24. mutex sync.Mutex
  25. exitChan chan struct{}
  26. readyChan chan struct{}
  27. commandChan chan *Frame
  28. completerChan chan *Frame
  29. Timeout time.Duration
  30. exitFlag int32
  31. }
  32. func (client *Client) getSequence() uint16 {
  33. client.mutex.Lock()
  34. defer client.mutex.Unlock()
  35. if client.sequence >= math.MaxUint16 {
  36. client.sequence = 0
  37. }
  38. client.sequence++
  39. n := client.sequence
  40. return n
  41. }
  42. func (client *Client) dialContext(ctx context.Context, address string) (conn net.Conn, err error) {
  43. var (
  44. pos int
  45. network string
  46. dialer net.Dialer
  47. )
  48. if pos = strings.Index(address, "://"); pos > -1 {
  49. network = address[:pos]
  50. address = address[pos+3:]
  51. } else {
  52. network = "tcp"
  53. }
  54. if conn, err = dialer.DialContext(ctx, network, address); err != nil {
  55. return
  56. }
  57. return
  58. }
  59. func (client *Client) renderBanner(info *Info) {
  60. client.name = info.Name
  61. fmt.Printf("Welcome to the %s(%s) monitor\n", info.Name, info.Version)
  62. fmt.Printf("Your connection id is %d\n", info.ID)
  63. fmt.Printf("Last login: %s from %s\n", info.ServerTime.Format(time.RFC822), info.RemoteAddr)
  64. fmt.Printf("Type 'help' for help. Type 'exit' for quit. Type 'cls' to clear input statement.\n")
  65. }
  66. func (client *Client) ioLoop(r io.Reader) {
  67. defer func() {
  68. _ = client.Close()
  69. }()
  70. for {
  71. frame, err := readFrame(r)
  72. if err != nil {
  73. return
  74. }
  75. switch frame.Type {
  76. case PacketTypeHandshake:
  77. info := &Info{}
  78. if err = json.Unmarshal(frame.Data, info); err == nil {
  79. client.renderBanner(info)
  80. }
  81. select {
  82. case client.readyChan <- struct{}{}:
  83. case <-client.exitChan:
  84. return
  85. }
  86. case PacketTypeCompleter:
  87. select {
  88. case client.completerChan <- frame:
  89. case <-client.exitChan:
  90. return
  91. }
  92. case PacketTypeCommand:
  93. select {
  94. case client.commandChan <- frame:
  95. case <-client.exitChan:
  96. return
  97. }
  98. }
  99. }
  100. }
  101. func (client *Client) waitResponse(seq uint16, timeout time.Duration) {
  102. timer := time.NewTimer(timeout)
  103. defer timer.Stop()
  104. for {
  105. select {
  106. case <-timer.C:
  107. fmt.Println("timeout waiting for response")
  108. return
  109. case <-client.exitChan:
  110. return
  111. case res, ok := <-client.commandChan:
  112. if !ok {
  113. break
  114. }
  115. if res.Seq == seq {
  116. if res.Error != "" {
  117. fmt.Print(res.Error)
  118. } else {
  119. fmt.Print(string(res.Data))
  120. }
  121. if res.Flag == FlagComplete {
  122. fmt.Println("")
  123. return
  124. }
  125. }
  126. }
  127. }
  128. }
  129. func (client *Client) completer(str string) (ss []string) {
  130. var (
  131. err error
  132. seq uint16
  133. )
  134. ss = make([]string, 0)
  135. seq = client.getSequence()
  136. if err = writeFrame(client.conn, newFrame(PacketTypeCompleter, FlagComplete, seq, []byte(str))); err != nil {
  137. return
  138. }
  139. select {
  140. case <-time.After(time.Second * 5):
  141. case frame, ok := <-client.completerChan:
  142. if ok {
  143. err = json.Unmarshal(frame.Data, &ss)
  144. }
  145. }
  146. return
  147. }
  148. func (client *Client) Execute(s string) (err error) {
  149. var (
  150. seq uint16
  151. )
  152. if client.conn, err = client.dialContext(client.ctx, client.address); err != nil {
  153. return err
  154. }
  155. defer func() {
  156. _ = client.Close()
  157. }()
  158. go client.ioLoop(client.conn)
  159. seq = client.getSequence()
  160. if err = writeFrame(client.conn, newFrame(PacketTypeCommand, FlagComplete, seq, []byte(s))); err != nil {
  161. return err
  162. }
  163. client.waitResponse(seq, time.Second*30)
  164. return
  165. }
  166. func (client *Client) Shell() (err error) {
  167. var (
  168. seq uint16
  169. line string
  170. )
  171. client.liner.SetCtrlCAborts(true)
  172. if client.conn, err = client.dialContext(client.ctx, client.address); err != nil {
  173. return err
  174. }
  175. defer func() {
  176. _ = client.Close()
  177. }()
  178. if err = writeFrame(client.conn, newFrame(PacketTypeHandshake, FlagComplete, client.getSequence(), nil)); err != nil {
  179. return
  180. }
  181. go client.ioLoop(client.conn)
  182. select {
  183. case <-client.readyChan:
  184. case <-client.ctx.Done():
  185. return
  186. }
  187. client.liner.SetCompleter(client.completer)
  188. for {
  189. if line, err = client.liner.Prompt(client.name + "> "); err != nil {
  190. break
  191. }
  192. if atomic.LoadInt32(&client.exitFlag) == 1 {
  193. fmt.Println(Bye)
  194. break
  195. }
  196. line = strings.TrimSpace(line)
  197. if line == "" {
  198. continue
  199. }
  200. if strings.ToLower(line) == "exit" || strings.ToLower(line) == "quit" {
  201. fmt.Println(Bye)
  202. return
  203. }
  204. if strings.ToLower(line) == "clear" || strings.ToLower(line) == "cls" {
  205. fmt.Print("\033[2J")
  206. continue
  207. }
  208. seq = client.getSequence()
  209. if err = writeFrame(client.conn, newFrame(PacketTypeCommand, FlagComplete, seq, []byte(line))); err != nil {
  210. break
  211. }
  212. client.liner.AppendHistory(line)
  213. client.waitResponse(seq, client.Timeout)
  214. }
  215. return
  216. }
  217. func (client *Client) Close() (err error) {
  218. if !atomic.CompareAndSwapInt32(&client.exitFlag, 0, 1) {
  219. return
  220. }
  221. close(client.exitChan)
  222. if client.conn != nil {
  223. err = client.conn.Close()
  224. }
  225. if client.liner != nil {
  226. err = client.liner.Close()
  227. }
  228. return
  229. }
  230. func NewClient(ctx context.Context, addr string) *Client {
  231. if ctx == nil {
  232. ctx = context.Background()
  233. }
  234. return &Client{
  235. ctx: ctx,
  236. address: addr,
  237. name: filepath.Base(os.Args[0]),
  238. Timeout: time.Second * 30,
  239. liner: liner.NewLiner(),
  240. readyChan: make(chan struct{}, 1),
  241. exitChan: make(chan struct{}),
  242. commandChan: make(chan *Frame, 5),
  243. completerChan: make(chan *Frame, 5),
  244. }
  245. }