Sfoglia il codice sorgente

处理cli模式兼容到handle方法

lxg 3 anni fa
parent
commit
4b89800a49

+ 19 - 1
cmd/main.go

@@ -4,6 +4,14 @@ import (
 	"git.nspix.com/golang/micro"
 )
 
+type (
+	Request struct {
+		Name string
+		Age  int
+		Co   float64
+	}
+)
+
 func main() {
 	svr := micro.New(
 		micro.WithName("git.nspix.com/test", "0.0.01"),
@@ -12,5 +20,15 @@ func main() {
 		micro.WithCli(),
 		micro.WithPort(6567),
 	)
+
+	svr.Handle("getUserList", func(ctx micro.Context) (err error) {
+		var req Request
+		if err = ctx.Bind(&req); err != nil {
+			return err
+		}
+		return ctx.Success(req)
+	}, func(o *micro.HandleOptions) {
+		o.DisableCli = false
+	})
 	svr.Run()
-}
+}

+ 1 - 1
cmd/mock/cli.go

@@ -7,7 +7,7 @@ import (
 )
 
 func main() {
-	if conn, err := net.Dial("tcp", "10.9.1.93:6567"); err == nil {
+	if conn, err := net.Dial("tcp", "192.168.6.76:6567"); err == nil {
 		cli.OpenInteractive(context.Background(), conn)
 	}
 }

+ 77 - 5
gateway/cli/context.go

@@ -1,12 +1,33 @@
 package cli
 
-import "sync"
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"git.nspix.com/golang/micro/utils/helper"
+	"reflect"
+	"strconv"
+	"sync"
+)
+
+var (
+	ErrInvalidStruct   = errors.New("invalid struct")
+	ErrInvalidArgument = errors.New("invalid argument")
+)
 
 type Context struct {
-	ID     int32
-	locker sync.RWMutex
-	CmdStr string
-	Values map[string]interface{}
+	ID       int32
+	CmdStr   string
+	Args     []string //所有参数
+	locker   sync.RWMutex
+	Values   map[string]interface{}
+	response *Response
+}
+
+func (ctx *Context) reset(s string) {
+	ctx.response = &Response{}
+	ctx.Args = nil
+	ctx.CmdStr = s
 }
 
 func (ctx *Context) Get(s string) interface{} {
@@ -28,13 +49,64 @@ func (ctx *Context) Set(key string, value interface{}) {
 }
 
 func (ctx *Context) Bind(i interface{}) (err error) {
+	refVal := reflect.Indirect(reflect.ValueOf(i))
+	refType := refVal.Type()
+	if refVal.Kind() != reflect.Struct {
+		return ErrInvalidStruct
+	}
+	numOfField := refVal.Type().NumField()
+	if numOfField != len(ctx.Args) {
+		var usage string
+		usage = "Usage: " + ctx.CmdStr + " "
+		for i := 0; i < numOfField; i++ {
+			usage += "{" + helper.LowerFirst(refType.Field(i).Name) + "|" + refVal.Field(i).Type().Kind().String() + "} "
+		}
+		ctx.Set("usage", usage)
+		return ErrInvalidArgument
+	}
+	for i := 0; i < numOfField; i++ {
+		switch refVal.Field(i).Kind() {
+		case reflect.String:
+			refVal.Field(i).SetString(ctx.Args[i])
+		case reflect.Int, reflect.Int32, reflect.Int64:
+			n, _ := strconv.ParseInt(ctx.Args[i], 10, 64)
+			refVal.Field(i).SetInt(n)
+		case reflect.Float32, reflect.Float64:
+			n, _ := strconv.ParseFloat(ctx.Args[i], 64)
+			refVal.Field(i).SetFloat(n)
+		default:
+			err = fmt.Errorf("unsupported argument %d kind %s", i, refVal.Field(i).Kind())
+			return
+		}
+	}
 	return
 }
 
 func (ctx *Context) Error(code int, msg string) (err error) {
+	ctx.response.Code = code
+	ctx.response.Error = msg
 	return
 }
 
 func (ctx *Context) Success(i interface{}) (err error) {
+	refVal := reflect.Indirect(reflect.ValueOf(i))
+	switch refVal.Kind() {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		ctx.response.Data = []byte(strconv.FormatInt(refVal.Int(), 10))
+	case reflect.Float32, reflect.Float64:
+		ctx.response.Data = []byte(strconv.FormatFloat(refVal.Float(), 'f', -1, 64))
+	case reflect.String:
+		ctx.response.Data = []byte(refVal.String())
+	case reflect.Slice:
+		if refVal.Type().Elem().Kind() == reflect.Uint8 {
+			ctx.response.Data = refVal.Bytes()
+		} else {
+			ctx.response.Data, err = json.MarshalIndent(refVal.Interface(), "", "\t")
+		}
+	case reflect.Struct, reflect.Map:
+		ctx.response.Data, err = json.MarshalIndent(refVal.Interface(), "", "\t")
+	default:
+		ctx.response.Data, err = json.MarshalIndent(refVal.Interface(), "", "\t")
+	}
 	return
 }

+ 129 - 0
gateway/cli/context_test.go

@@ -0,0 +1,129 @@
+package cli
+
+import (
+	"reflect"
+	"sync"
+	"testing"
+)
+
+func TestContext_Bind(t *testing.T) {
+	type fields struct {
+		ID       int32
+		CmdStr   string
+		Args     []string
+		locker   sync.RWMutex
+		Values   map[string]interface{}
+		response *Response
+	}
+	type args struct {
+		i interface{}
+	}
+	tests := []struct {
+		name    string
+		fields  fields
+		args    args
+		wantErr bool
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ctx := &Context{
+				ID:       tt.fields.ID,
+				CmdStr:   tt.fields.CmdStr,
+				Args:     tt.fields.Args,
+				locker:   tt.fields.locker,
+				Values:   tt.fields.Values,
+				response: tt.fields.response,
+			}
+			if err := ctx.Bind(tt.args.i); (err != nil) != tt.wantErr {
+				t.Errorf("Bind() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestContext_Error(t *testing.T) {
+	type fields struct {
+		ID       int32
+		CmdStr   string
+		Args     []string
+		locker   sync.RWMutex
+		Values   map[string]interface{}
+		response *Response
+	}
+	type args struct {
+		code int
+		msg  string
+	}
+	tests := []struct {
+		name    string
+		fields  fields
+		args    args
+		wantErr bool
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ctx := &Context{
+				ID:       tt.fields.ID,
+				CmdStr:   tt.fields.CmdStr,
+				Args:     tt.fields.Args,
+				locker:   tt.fields.locker,
+				Values:   tt.fields.Values,
+				response: tt.fields.response,
+			}
+			if err := ctx.Error(tt.args.code, tt.args.msg); (err != nil) != tt.wantErr {
+				t.Errorf("Error() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+
+func TestContext_Success(t *testing.T) {
+	type fields struct {
+		ID       int32
+		CmdStr   string
+		Args     []string
+		locker   sync.RWMutex
+		Values   map[string]interface{}
+		response *Response
+	}
+	type args struct {
+		i interface{}
+	}
+	tests := []struct {
+		name    string
+		fields  fields
+		args    args
+		wantErr bool
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ctx := &Context{
+				ID:       tt.fields.ID,
+				CmdStr:   tt.fields.CmdStr,
+				Args:     tt.fields.Args,
+				locker:   tt.fields.locker,
+				Values:   tt.fields.Values,
+				response: tt.fields.response,
+			}
+			if err := ctx.Success(tt.args.i); (err != nil) != tt.wantErr {
+				t.Errorf("Success() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestRef(t *testing.T)  {
+	type x struct {
+		A string
+	}
+	refVal := reflect.Indirect(reflect.ValueOf(&x{A: "xxx"}))
+	t.Log(refVal.Field(0).Kind())
+	t.Log(refVal.Type().Field(0).Name)
+}

+ 20 - 68
gateway/cli/executor.go

@@ -6,7 +6,6 @@ import (
 	"git.nspix.com/golang/micro/utils/console"
 	"reflect"
 	"sort"
-	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -28,12 +27,12 @@ type Executor struct {
 	parent         *Executor
 	usage          string
 	description    string
-	handle         interface{}
+	handleFunc     HandleFunc
 	children       map[string]*Executor
 	createdTime    time.Time
 	seq            int64
 	locker         sync.RWMutex
-	NotFoundHandle func(ctx *Context, cmd string) ([]byte, error)
+	NotFoundHandle func(ctx *Context, cmd string) (*Response, error)
 }
 
 type (
@@ -137,9 +136,13 @@ func (exec *Executor) String() string {
 		sort.Sort(vs)
 		for _, v := range vs {
 			if prefix == "" {
-				values = append(values, []interface{}{v.name, v.getDescription()})
+				if v.handleFunc != nil {
+					values = append(values, []interface{}{v.name, v.getDescription()})
+				}
 			} else {
-				values = append(values, []interface{}{prefix + " " + v.name, v.getDescription()})
+				if v.handleFunc != nil {
+					values = append(values, []interface{}{prefix + " " + v.name, v.getDescription()})
+				}
 			}
 			if prefix == "" {
 				loop(v.name, v)
@@ -162,12 +165,12 @@ func (exec *Executor) Append(child ...*Executor) *Executor {
 	return exec
 }
 
-func (exec *Executor) WithHandle(v interface{}) *Executor  {
-	exec.handle = v
+func (exec *Executor) WithHandle(cb HandleFunc) *Executor {
+	exec.handleFunc = cb
 	return exec
 }
 
-func (exec *Executor) Do(ctx *Context, args ...string) (b []byte, err error) {
+func (exec *Executor) Do(ctx *Context, args ...string) (res *Response, err error) {
 	var (
 		root interface{}
 	)
@@ -182,72 +185,21 @@ func (exec *Executor) Do(ctx *Context, args ...string) (b []byte, err error) {
 			return exec.children[args[0]].Do(ctx)
 		}
 	}
-	if exec.handle == nil {
+	if exec.handleFunc == nil {
 		if exec.NotFoundHandle != nil {
-			b, err = exec.NotFoundHandle(ctx, ctx.CmdStr)
+			res, err = exec.NotFoundHandle(ctx, ctx.CmdStr)
 		} else {
 			err = fmt.Errorf("%s not found", ctx.CmdStr)
 		}
 		return
 	}
-	refVal := reflect.ValueOf(exec.handle)
-	refType := refVal.Type()
-	if refType.Kind() != reflect.Func {
-		err = ErrInvalidHandle
-		return
-	}
-	//checking args
-	if refType.NumIn() > 0 && len(args)+1 < refType.NumIn() {
-		var usage string
-		if exec.usage == "" {
-			usage = "Usage: " + ctx.CmdStr + " ["
-			for i := 1; i < refType.NumIn(); i++ {
-				usage += " " + refType.In(i).String()
-			}
-			usage += " ]"
-		} else {
-			usage = "Usage: " + exec.usage
-		}
-		err = errors.New(usage)
-		return
-	}
-	arguments := make([]reflect.Value, refType.NumIn())
-	for i := 0; i < refType.NumIn(); i++ {
-		if i == 0 {
-			arguments[i] = reflect.ValueOf(ctx)
-			continue
-		}
-		switch refType.In(i).Kind() {
-		case reflect.String:
-			arguments[i] = reflect.ValueOf(args[i-1])
-		case reflect.Int:
-			n, _ := strconv.ParseInt(args[i-1], 10, 64)
-			arguments[i] = reflect.ValueOf(int(n))
-		case reflect.Int32:
-			n, _ := strconv.ParseInt(args[i-1], 10, 32)
-			arguments[i] = reflect.ValueOf(int32(n))
-		case reflect.Int64:
-			n, _ := strconv.ParseInt(args[i-1], 10, 64)
-			arguments[i] = reflect.ValueOf(n)
-		case reflect.Float32:
-			n, _ := strconv.ParseFloat(args[i-1], 32)
-			arguments[i] = reflect.ValueOf(float32(n))
-		case reflect.Float64:
-			n, _ := strconv.ParseFloat(args[i-1], 64)
-			arguments[i] = reflect.ValueOf(n)
-		case stringSliceKind:
-			arguments[i] = reflect.ValueOf(args[i-1:])
-		default:
-			err = fmt.Errorf("unsupported argument %d kind %s", i-1, refType.In(i).Kind().String())
-			return
-		}
-	}
-	values := refVal.Call(arguments)
-	for _, v := range values {
-		if v.Type().Implements(errInterface) {
-			err = v.Interface().(error)
-		} else if v.Kind() == reflect.Slice {
-			b = v.Bytes()
+	ctx.Args = args
+	if err = exec.handleFunc(ctx); err == nil {
+		res = ctx.response
+	} else {
+		if err == ErrInvalidArgument {
+			res = &Response{Code: 1000, Error: fmt.Sprint(ctx.Get("usage"))}
+			err = nil
 		}
 	}
 	return

+ 7 - 0
gateway/cli/response.go

@@ -0,0 +1,7 @@
+package cli
+
+type Response struct {
+	Code  int    `json:"code"`
+	Error string `json:"error"`
+	Data  []byte `json:"data"`
+}

+ 44 - 6
gateway/cli/server.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"git.nspix.com/golang/micro/utils/bytepool"
+	"git.nspix.com/golang/micro/utils/helper"
 	"net"
 	"os"
 	"runtime"
@@ -19,6 +20,8 @@ const (
 	EOL = "\n"
 )
 
+type HandleFunc func(ctx *Context) (err error)
+
 type Server struct {
 	seq        int32
 	locker     sync.RWMutex
@@ -37,6 +40,7 @@ func (svr *Server) writePack(conn net.Conn, packet *Frame) (err error) {
 func (svr *Server) process(id int32, conn net.Conn) (err error) {
 	var (
 		buf []byte
+		res *Response
 	)
 	buffer := bytepool.Get(MaxReadBufferLength)
 	defer func() {
@@ -82,9 +86,13 @@ func (svr *Server) process(id int32, conn net.Conn) (err error) {
 				svr.contextMap[id] = ctx
 			}
 			svr.locker.Unlock()
-			ctx.CmdStr = strings.TrimSpace(string(reqPacket.Data))
-			if buf, err = svr.executor.Do(ctx, tokens...); err == nil {
-				err = svr.writePack(conn, &Frame{Type: PacketTypeData, Data: buf})
+			ctx.reset(strings.TrimSpace(string(reqPacket.Data)))
+			if res, err = svr.executor.Do(ctx, tokens...); err == nil {
+				if res.Code == 0 {
+					err = svr.writePack(conn, &Frame{Type: PacketTypeData, Data: res.Data})
+				} else {
+					err = svr.writePack(conn, &Frame{Type: PacketTypeData, Error: res.Error})
+				}
 			} else {
 				err = svr.writePack(conn, &Frame{Type: PacketTypeData, Error: err.Error()})
 			}
@@ -95,13 +103,43 @@ func (svr *Server) process(id int32, conn net.Conn) (err error) {
 	return
 }
 
+func (svr *Server) Handle(path string, cb HandleFunc) {
+	tokens := helper.BreakUp(path)
+	svr.locker.Lock()
+	defer svr.locker.Unlock()
+	var (
+		err    error
+		length int
+		p      *Executor
+		q      *Executor
+	)
+	length = len(tokens)
+	p = svr.executor
+	for i, token := range tokens {
+		token = strings.TrimSpace(strings.ToLower(token))
+		if q, err = p.Children(token); err == nil {
+			if i == length-1 {
+				panic(path + " already exists")
+			}
+			p = q
+		} else {
+			q = NewExecutor(token, "", strings.Title(strings.Join(tokens, " ")))
+			if i == length-1 {
+				q.handleFunc = cb
+			}
+			p.Append(q)
+			p = q
+		}
+	}
+}
+
 func (svr *Server) Serve(listener net.Listener) (err error) {
 	var (
 		conn net.Conn
 	)
-	svr.executor.Append(NewExecutor("help", "help", "Display this help").WithHandle(func(ctx *Context) []byte {
-		return []byte(svr.executor.String())
-	}))
+	svr.Handle("help", func(ctx *Context) (err error) {
+		return ctx.Success(svr.executor.String())
+	})
 	for {
 		if conn, err = listener.Accept(); err != nil {
 			break

+ 7 - 4
micro.go

@@ -2,6 +2,7 @@ package micro
 
 import (
 	"context"
+	"git.nspix.com/golang/micro/gateway/cli"
 
 	"git.nspix.com/golang/micro/gateway/http"
 	"git.nspix.com/golang/micro/gateway/rpc"
@@ -18,10 +19,11 @@ type (
 	}
 
 	HandleOptions struct {
-		DisableRpc  bool   //禁用RPC功能
-		DisableHttp bool   //禁用HTTP功能
-		HttpPath    string //重定向HTTP路由
-		HttpMethod  string //HTTP路径
+		DisableRpc     bool   //禁用RPC功能
+		DisableHttp    bool   //禁用HTTP功能
+		DisableCli     bool   //禁用CLI功能
+		HttpPath       string //重定向HTTP路由
+		HttpMethod     string //HTTP路径
 	}
 
 	HandleOption func(o *HandleOptions)
@@ -32,6 +34,7 @@ type (
 		Node() *registry.ServiceNode                                                //获取节点信息
 		HttpServe() *http.Server                                                    //获取HTTP实例
 		RPCServe() *rpc.Server                                                      //获取RPC实例
+		CliServe() *cli.Server                                                      //获取cli服务端
 		PeekService(name string) ([]*registry.ServiceNode, error)                   //选择一个服务
 		Handle(method string, cb HandleFunc, opts ...HandleOption)                  //注册一个处理器
 		NewRequest(name, method string, body interface{}) (req *Request, err error) //创建一个rpc请求

+ 9 - 1
service.go

@@ -83,8 +83,10 @@ func (svr *Service) eventLoop() {
 	}
 }
 
+//Handle 处理函数
 func (svr *Service) Handle(method string, cb HandleFunc, opts ...HandleOption) {
-	opt := &HandleOptions{HttpMethod: "POST"}
+	//disable cli default
+	opt := &HandleOptions{HttpMethod: "POST", DisableCli: true}
 	for _, f := range opts {
 		f(opt)
 	}
@@ -106,6 +108,12 @@ func (svr *Service) Handle(method string, cb HandleFunc, opts ...HandleOption) {
 			return cb(ctx)
 		})
 	}
+	//启用CLI模式
+	if svr.opts.EnableCli && !opt.DisableCli {
+		svr.cliSvr.Handle(method, func(ctx *cli.Context) (err error) {
+			return cb(ctx)
+		})
+	}
 	return
 }
 

+ 3 - 1
utils/console/console.go

@@ -7,6 +7,8 @@ import (
 	"unicode/utf8"
 )
 
+type Table [][]interface{}
+
 type Options struct {
 	Bordered bool
 	Header   bool
@@ -38,7 +40,7 @@ func calcCharsetWidth(s string) int {
 	}
 }
 
-func Pretty(values [][]interface{}, opts *Options) []byte {
+func Pretty(values Table, opts *Options) []byte {
 	columns := make([][]string, 0)
 	var widths []int
 	var width int

+ 41 - 0
utils/helper/helper.go

@@ -0,0 +1,41 @@
+package helper
+
+import "strings"
+
+func LowerFirst(s string) string {
+	isFirst := true
+	return strings.Map(func(r rune) rune {
+		if isFirst && r >= 'A' && r <= 'Z' {
+			return r + 32
+		}
+		isFirst = false
+		return r
+	}, s)
+}
+
+func BreakUp(s string) []string {
+	length := len(s)
+	b := make([]byte, length)
+	ss := make([]string, 0)
+	var p int
+	for i := 0; i < length; i++ {
+		if s[i] >= 'A' && s[i] <= 'Z' {
+			if p > 0 {
+				ss = append(ss, string(b[:p]))
+			}
+			p = 0
+			b[p] = s[i] + 32
+		} else {
+			b[p] = s[i]
+		}
+		p++
+	}
+	if p > 0 {
+		ss = append(ss, string(b[:p]))
+	}
+	return ss
+}
+
+func Camel2id(s string) string {
+	return strings.Join(BreakUp(s), "-")
+}

+ 43 - 0
utils/helper/helper_test.go

@@ -0,0 +1,43 @@
+package helper
+
+import "testing"
+
+func TestCamel2id(t *testing.T) {
+	type args struct {
+		s string
+	}
+	tests := []struct {
+		name string
+		args args
+		want string
+	}{
+		{"v1", args{s: "getServerName"}, "get server name"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := Camel2id(tt.args.s); got != tt.want {
+				t.Errorf("Camel2id() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestUcFirst(t *testing.T) {
+	type args struct {
+		s string
+	}
+	tests := []struct {
+		name string
+		args args
+		want string
+	}{
+		{"1",args{s:"GetName"},"getName"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := LowerFirst(tt.args.s); got != tt.want {
+				t.Errorf("UcFirst() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}