package cli import ( "context" "encoding/json" "errors" "fmt" "git.nspix.com/golang/micro/helper/utils" "reflect" "strconv" "strings" "sync" ) var ( ErrInvalidStruct = errors.New("invalid struct") ErrInvalidArgument = errors.New("invalid argument") ) type encoder interface { Marshal() ([]byte, error) } type Context struct { ID int32 CmdStr string Args []string //所有参数 params []string locker sync.RWMutex Values map[string]interface{} response *Response ctx context.Context } func (ctx *Context) Context() context.Context { return ctx.ctx } func (ctx *Context) WithContext(c context.Context) { ctx.ctx = c } //HasArgument 是否有指定的参数 func (ctx *Context) HasArgument(i int) bool { return len(ctx.Args) > i } //ParamValue 获取参数的值 func (ctx *Context) ParamValue(name string) string { idx := -1 if ctx.params == nil { return "" } for i, s := range ctx.params { if strings.ToLower(s) == strings.ToLower(name) { idx = i break } } if idx > -1 { if ctx.HasArgument(idx) { return ctx.Argument(idx) } } return "" } //Argument 获取指定参数 func (ctx *Context) Argument(i int) string { if ctx.HasArgument(i) { return ctx.Args[i] } return "" } func (ctx *Context) reset(s string) { ctx.response = &Response{} ctx.Args = nil ctx.params = nil ctx.CmdStr = s } //Get 获取一个session变量 func (ctx *Context) Get(s string) interface{} { ctx.locker.RLock() defer ctx.locker.RUnlock() if ctx.Values == nil { return "" } return ctx.Values[s] } //Set 设置一个session变量 func (ctx *Context) Set(key string, value interface{}) { ctx.locker.Lock() defer ctx.locker.Unlock() if ctx.Values == nil { ctx.Values = make(map[string]interface{}) } ctx.Values[key] = value } //Bind 绑定一个变量 func (ctx *Context) Bind(i interface{}) (err error) { refVal := reflect.Indirect(reflect.ValueOf(i)) refType := refVal.Type() if refVal.Kind() != reflect.Struct { return ErrInvalidStruct } numOfField := refVal.Type().NumField() if numOfField != len(ctx.Args) { var usage string usage = "Usage: " + ctx.CmdStr + " " for i := 0; i < numOfField; i++ { usage += "{" + utils.LowerFirst(refType.Field(i).Name) + "|" + refVal.Field(i).Type().Kind().String() + "} " } ctx.Set("usage", usage) return ErrInvalidArgument } for i := 0; i < numOfField; i++ { switch refVal.Field(i).Kind() { case reflect.String: refVal.Field(i).SetString(ctx.Args[i]) case reflect.Int, reflect.Int32, reflect.Int64: n, _ := strconv.ParseInt(ctx.Args[i], 10, 64) refVal.Field(i).SetInt(n) case reflect.Float32, reflect.Float64: n, _ := strconv.ParseFloat(ctx.Args[i], 64) refVal.Field(i).SetFloat(n) default: err = fmt.Errorf("unsupported argument %d kind %s", i, refVal.Field(i).Kind()) return } } return } func (ctx *Context) Error(code int, msg string) (err error) { ctx.response.Code = code ctx.response.Error = msg return } func (ctx *Context) Success(i interface{}) (err error) { if v, ok := i.(encoder); ok { ctx.response.Data, err = v.Marshal() return } refVal := reflect.Indirect(reflect.ValueOf(i)) switch refVal.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: ctx.response.Data = []byte(strconv.FormatInt(refVal.Int(), 10)) case reflect.Float32, reflect.Float64: ctx.response.Data = []byte(strconv.FormatFloat(refVal.Float(), 'f', -1, 64)) case reflect.String: ctx.response.Data = []byte(refVal.String()) case reflect.Slice: if refVal.Type().Elem().Kind() == reflect.Uint8 { ctx.response.Data = refVal.Bytes() } else { ctx.response.Data, err = json.MarshalIndent(refVal.Interface(), "", "\t") } case reflect.Struct, reflect.Map: ctx.response.Data, err = json.MarshalIndent(refVal.Interface(), "", "\t") default: ctx.response.Data, err = json.MarshalIndent(refVal.Interface(), "", "\t") } return }