package cli

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"net"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/peterh/liner"
)

// Client is a one-shot connection to a remote command server. It can either
// run a single command (Execute) or drive an interactive prompt (Shell).
// Both entry points close the client before returning, so a Client is not
// reusable after either call.
type Client struct {
	name          string          // prompt prefix; replaced by the server name on handshake
	ctx           context.Context // governs dialing and the handshake wait
	address       string          // "host:port" or "network://address"
	sequence      uint16          // last sequence number handed out; guarded by mutex
	conn          net.Conn        // set by Execute/Shell before the reader goroutine starts
	liner         *liner.State    // line editor used by Shell
	mutex         sync.Mutex      // protects sequence
	exitChan      chan struct{}   // closed exactly once by Close
	readyChan     chan struct{}   // signalled when a handshake frame arrives
	commandChan   chan *Frame     // command response frames from ioLoop
	completerChan chan *Frame     // completion response frames from ioLoop
	Timeout       time.Duration   // how long to wait for a command's final frame
	exitFlag      int32           // 0 = open, 1 = closed; accessed atomically
}

// getSequence returns the next request sequence number. Values cycle through
// 1..math.MaxUint16; zero is never returned.
func (client *Client) getSequence() uint16 {
	client.mutex.Lock()
	defer client.mutex.Unlock()
	if client.sequence >= math.MaxUint16 {
		client.sequence = 0
	}
	client.sequence++
	return client.sequence
}

// dialContext connects to address. An address of the form "network://addr"
// selects the network explicitly; anything else is dialed over "tcp".
func (client *Client) dialContext(ctx context.Context, address string) (net.Conn, error) {
	network := "tcp"
	if pos := strings.Index(address, "://"); pos > -1 {
		network, address = address[:pos], address[pos+3:]
	}
	var dialer net.Dialer
	return dialer.DialContext(ctx, network, address)
}

// renderBanner adopts the server-reported name as the prompt prefix and
// prints the welcome banner to standard output.
func (client *Client) renderBanner(info *Info) {
	client.name = info.Name
	fmt.Printf("Welcome to the %s(%s) monitor\n", info.Name, info.Version)
	fmt.Printf("Your connection id is %d\n", info.ID)
	fmt.Printf("Last login: %s from %s\n", info.ServerTime.Format(time.RFC822), info.RemoteAddr)
	fmt.Println("Type 'help' for help. Type 'exit' for quit. Type 'cls' to clear input statement.")
}

// ioLoop reads frames from r until a read fails or the client is closed and
// routes each frame to the channel matching its type. Frames of any other
// type are dropped. The client is closed when the loop exits.
func (client *Client) ioLoop(r io.Reader) {
	defer func() { _ = client.Close() }()
	for {
		frame, err := readFrame(r)
		if err != nil {
			return
		}
		switch frame.Type {
		case PacketTypeHandshake:
			info := &Info{}
			// A malformed handshake payload only skips the banner; readiness
			// is still signalled below.
			if err = json.Unmarshal(frame.Data, info); err == nil {
				client.renderBanner(info)
			}
			select {
			case client.readyChan <- struct{}{}:
			case <-client.exitChan:
				return
			}
		case PacketTypeCompleter:
			select {
			case client.completerChan <- frame:
			case <-client.exitChan:
				return
			}
		case PacketTypeCommand:
			select {
			case client.commandChan <- frame:
			case <-client.exitChan:
				return
			}
		}
	}
}

// waitResponse prints response frames carrying sequence seq until the frame
// flagged FlagComplete arrives, the timeout elapses, or the client is closed.
// Frames with a different sequence number are discarded.
func (client *Client) waitResponse(seq uint16, timeout time.Duration) {
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	for {
		select {
		case <-timer.C:
			fmt.Println("timeout waiting for response")
			return
		case <-client.exitChan:
			return
		case res, ok := <-client.commandChan:
			if !ok {
				// A bare break would only leave the select and spin on the
				// closed channel until the timer fires.
				return
			}
			if res.Seq != seq {
				continue
			}
			if res.Error != "" {
				fmt.Print(res.Error)
			} else {
				fmt.Print(string(res.Data))
			}
			if res.Flag == FlagComplete {
				fmt.Println("")
				return
			}
		}
	}
}

// completer asks the server for completions of str and returns them. It
// returns an empty, non-nil slice when the request cannot be written, the
// client is closed, no reply arrives within five seconds, or the reply is
// not a JSON array of strings.
func (client *Client) completer(str string) []string {
	none := make([]string, 0)
	seq := client.getSequence()
	if err := writeFrame(client.conn, newFrame(PacketTypeCompleter, FlagComplete, seq, []byte(str))); err != nil {
		return none
	}
	select {
	case <-time.After(5 * time.Second):
	case <-client.exitChan:
		// The connection is gone; do not keep the prompt waiting.
	case frame, ok := <-client.completerChan:
		if !ok {
			break
		}
		var candidates []string
		if err := json.Unmarshal(frame.Data, &candidates); err != nil {
			// Never hand the line editor a partially decoded list.
			break
		}
		if candidates == nil {
			break
		}
		return candidates
	}
	return none
}

// Execute connects to the server, sends s as a single command, prints the
// response and closes the client. It waits up to Timeout for the final
// response frame, or 30 seconds when Timeout is not positive.
func (client *Client) Execute(s string) error {
	// Registered before dialing so the line editor is closed (and the
	// terminal restored) even when the connection cannot be established.
	defer func() { _ = client.Close() }()

	var err error
	if client.conn, err = client.dialContext(client.ctx, client.address); err != nil {
		return err
	}
	go client.ioLoop(client.conn)

	seq := client.getSequence()
	if err = writeFrame(client.conn, newFrame(PacketTypeCommand, FlagComplete, seq, []byte(s))); err != nil {
		return err
	}

	timeout := client.Timeout
	if timeout <= 0 {
		timeout = 30 * time.Second
	}
	client.waitResponse(seq, timeout)
	return nil
}

// Shell connects to the server, performs the handshake and runs an
// interactive prompt until the user exits, the prompt fails, or the
// connection is lost. The client is closed when Shell returns.
//
// The built-in commands "exit"/"quit" leave the shell and "clear"/"cls"
// clear the screen; every other non-empty line is sent to the server.
func (client *Client) Shell() error {
	client.liner.SetCtrlCAborts(true)

	// Registered before dialing so the line editor is closed (and the
	// terminal restored) even when the connection cannot be established.
	defer func() { _ = client.Close() }()

	var err error
	if client.conn, err = client.dialContext(client.ctx, client.address); err != nil {
		return err
	}
	if err = writeFrame(client.conn, newFrame(PacketTypeHandshake, FlagComplete, client.getSequence(), nil)); err != nil {
		return err
	}
	go client.ioLoop(client.conn)

	select {
	case <-client.readyChan:
	case <-client.exitChan:
		// ioLoop ended (read error or peer closed) before any handshake
		// frame arrived; without this case Shell would block forever on a
		// context that is never cancelled.
		return fmt.Errorf("connection closed before handshake completed: %w", io.ErrUnexpectedEOF)
	case <-client.ctx.Done():
		return client.ctx.Err()
	}

	client.liner.SetCompleter(client.completer)
	for {
		line, err := client.liner.Prompt(client.name + "> ")
		if err != nil {
			return err
		}
		// The connection may have been closed while the prompt was blocked.
		if atomic.LoadInt32(&client.exitFlag) == 1 {
			fmt.Println(Bye)
			return nil
		}
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		switch strings.ToLower(line) {
		case "exit", "quit":
			fmt.Println(Bye)
			return nil
		case "clear", "cls":
			fmt.Print("\033[2J")
			continue
		}
		seq := client.getSequence()
		if err := writeFrame(client.conn, newFrame(PacketTypeCommand, FlagComplete, seq, []byte(line))); err != nil {
			return err
		}
		client.liner.AppendHistory(line)
		client.waitResponse(seq, client.Timeout)
	}
}

// Close shuts the client down: it signals exitChan, closes the connection if
// one was established and closes the line editor. Only the first call does
// any work; later calls return nil. The first non-nil error encountered is
// returned.
func (client *Client) Close() error {
	if !atomic.CompareAndSwapInt32(&client.exitFlag, 0, 1) {
		return nil
	}
	close(client.exitChan)

	var err error
	if client.conn != nil {
		err = client.conn.Close()
	}
	if client.liner != nil {
		// Always close the line editor, but do not let its result mask an
		// earlier connection error.
		if lerr := client.liner.Close(); err == nil {
			err = lerr
		}
	}
	return err
}

// NewClient returns a Client that will connect to addr. A nil ctx is
// replaced by context.Background(). The prompt name defaults to the base
// name of the running executable and Timeout defaults to 30 seconds.
func NewClient(ctx context.Context, addr string) *Client {
	if ctx == nil {
		ctx = context.Background()
	}
	return &Client{
		ctx:           ctx,
		address:       addr,
		name:          filepath.Base(os.Args[0]),
		Timeout:       time.Second * 30,
		liner:         liner.NewLiner(),
		readyChan:     make(chan struct{}, 1),
		exitChan:      make(chan struct{}),
		commandChan:   make(chan *Frame, 5),
		completerChan: make(chan *Frame, 5),
	}
}