server.go 4.6 KB


  1. package cli
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "math"
  7. "net"
  8. "path"
  9. "runtime"
  10. "strings"
  11. "sync"
  12. "time"
  13. "git.nspix.com/golang/kos/util/env"
  14. "github.com/sourcegraph/conc"
  15. )
  // ctxPool recycles *Context values between connections (see applyContext
  // and releaseContext). It is a zero sync.Pool with no New func, so Get
  // returns nil when the pool is empty.
  var ctxPool sync.Pool
  19. type Server struct {
  20. ctx context.Context
  21. sequenceLocker sync.Mutex
  22. sequence int64
  23. ctxMap sync.Map
  24. waitGroup conc.WaitGroup
  25. middleware []Middleware
  26. router *Router
  27. l net.Listener
  28. }
  29. func (svr *Server) applyContext() *Context {
  30. if v := ctxPool.Get(); v != nil {
  31. if ctx, ok := v.(*Context); ok {
  32. return ctx
  33. }
  34. }
  35. return &Context{}
  36. }
  37. func (svr *Server) releaseContext(ctx *Context) {
  38. ctxPool.Put(ctx)
  39. }
  40. func (svr *Server) handle(ctx *Context, frame *Frame) (err error) {
  41. var (
  42. params map[string]string
  43. tokens []string
  44. args []string
  45. r *Router
  46. )
  47. cmd := string(frame.Data)
  48. tokens = strings.Fields(cmd)
  49. if frame.Timeout > 0 {
  50. childCtx, cancelFunc := context.WithTimeout(svr.ctx, time.Duration(frame.Timeout))
  51. ctx.setContext(childCtx)
  52. defer func() {
  53. cancelFunc()
  54. }()
  55. } else {
  56. ctx.setContext(svr.ctx)
  57. }
  58. if r, args, err = svr.router.Lookup(tokens); err != nil {
  59. if errors.Is(err, ErrNotFound) {
  60. err = ctx.Error(errNotFound, fmt.Sprintf("Command %s not found", cmd))
  61. } else {
  62. err = ctx.Error(errExecuteFailed, err.Error())
  63. }
  64. } else {
  65. if len(r.params) > len(args) {
  66. err = ctx.Error(errExecuteFailed, r.Usage())
  67. return
  68. }
  69. if len(r.params) > 0 {
  70. params = make(map[string]string)
  71. for i, s := range r.params {
  72. params[s] = args[i]
  73. }
  74. }
  75. ctx.setArgs(args)
  76. ctx.setParam(params)
  77. err = r.command.Handle(ctx)
  78. }
  79. return
  80. }
  81. func (svr *Server) nextSequence() int64 {
  82. svr.sequenceLocker.Lock()
  83. defer svr.sequenceLocker.Unlock()
  84. if svr.sequence >= math.MaxInt64 {
  85. svr.sequence = 1
  86. }
  87. svr.sequence++
  88. return svr.sequence
  89. }
  90. func (svr *Server) process(conn net.Conn) {
  91. var (
  92. err error
  93. ctx *Context
  94. frame *Frame
  95. )
  96. ctx = svr.applyContext()
  97. ctx.reset(svr.nextSequence(), conn)
  98. svr.ctxMap.Store(ctx.Id, ctx)
  99. defer func() {
  100. _ = conn.Close()
  101. svr.ctxMap.Delete(ctx.Id)
  102. svr.releaseContext(ctx)
  103. }()
  104. for {
  105. if frame, err = readFrame(conn); err != nil {
  106. break
  107. }
  108. //reset frame
  109. ctx.seq = frame.Seq
  110. switch frame.Type {
  111. case PacketTypeHandshake:
  112. if err = ctx.send(responsePayload{
  113. Type: PacketTypeHandshake,
  114. Data: &Info{
  115. ID: ctx.Id,
  116. Name: env.Get("VOX_NAME", ""),
  117. Version: env.Get("VOX_VERSION", ""),
  118. OS: runtime.GOOS,
  119. ServerTime: time.Now(),
  120. RemoteAddr: conn.RemoteAddr().String(),
  121. },
  122. }); err != nil {
  123. break
  124. }
  125. case PacketTypeCompleter:
  126. if err = ctx.send(responsePayload{
  127. Type: PacketTypeCompleter,
  128. Data: svr.router.Completer(strings.Fields(string(frame.Data))...),
  129. }); err != nil {
  130. break
  131. }
  132. case PacketTypeCommand:
  133. if err = svr.handle(ctx, frame); err != nil {
  134. break
  135. }
  136. default:
  137. break
  138. }
  139. }
  140. }
  141. func (svr *Server) serve() {
  142. for {
  143. conn, err := svr.l.Accept()
  144. if err != nil {
  145. break
  146. }
  147. svr.waitGroup.Go(func() {
  148. svr.process(conn)
  149. })
  150. }
  151. }
  152. func (svr *Server) wrapHandle(pathname, desc string, cb HandleFunc, middleware ...Middleware) Command {
  153. h := func(ctx *Context) (err error) {
  154. for i := len(svr.middleware) - 1; i >= 0; i-- {
  155. cb = svr.middleware[i](cb)
  156. }
  157. for i := len(middleware) - 1; i >= 0; i-- {
  158. cb = middleware[i](cb)
  159. }
  160. return cb(ctx)
  161. }
  162. if desc == "" {
  163. desc = strings.Join(strings.Split(strings.TrimPrefix(pathname, "/"), "/"), " ")
  164. }
  165. return Command{
  166. Path: pathname,
  167. Handle: h,
  168. Description: desc,
  169. }
  170. }
  171. func (svr *Server) Use(middleware ...Middleware) {
  172. svr.middleware = append(svr.middleware, middleware...)
  173. }
  174. func (svr *Server) Group(prefix string, commands []Command, middleware ...Middleware) {
  175. for _, cmd := range commands {
  176. svr.Handle(path.Join(prefix, cmd.Path), cmd.Description, cmd.Handle, middleware...)
  177. }
  178. }
  179. func (svr *Server) Handle(pathname string, desc string, cb HandleFunc, middleware ...Middleware) {
  180. svr.router.Handle(pathname, svr.wrapHandle(pathname, desc, cb, middleware...))
  181. }
  182. func (svr *Server) Serve(l net.Listener) (err error) {
  183. svr.l = l
  184. svr.Handle("/help", "Display help information", func(ctx *Context) (err error) {
  185. return ctx.Success(svr.router.String())
  186. })
  187. svr.serve()
  188. return
  189. }
  190. func (svr *Server) Shutdown() (err error) {
  191. err = svr.l.Close()
  192. svr.ctxMap.Range(func(key, value any) bool {
  193. if ctx, ok := value.(*Context); ok {
  194. err = ctx.Close()
  195. }
  196. return true
  197. })
  198. svr.waitGroup.Wait()
  199. return
  200. }
  201. func New(ctx context.Context) *Server {
  202. return &Server{
  203. ctx: ctx,
  204. router: newRouter(""),
  205. middleware: make([]Middleware, 0, 10),
  206. }
  207. }