package rest

import (
	"context"
	"encoding/csv"
	"encoding/json"
	"fmt"
	"path"
	"reflect"
	"strconv"
	"strings"
	"time"

	"git.nspix.com/golang/micro/gateway/http"
	"git.nspix.com/golang/rest/v2/errors"
	"git.nspix.com/golang/rest/v2/internal/inflector"
	lru "github.com/hashicorp/golang-lru"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// Error codes the entity handlers pass to ctx.Error (or emit as "errno").
const (
	HttpAccessDenied          = 8004
	HttpInvalidPayload        = 8002
	HttpRequestCallbackFailed = 8003
	HttpValidateFailed        = 8008
	HttpDatabaseQueryFailed   = 8010
	HttpDatabaseFindFailed    = 8011
	HttpDatabaseCreateFailed  = 8012
	HttpDatabaseUpdateFailed  = 8013
	HttpDatabaseDeleteFailed  = 8014
	HttpDatabaseExportFailed  = 8015
)

// DiffAttr describes the change of a single column between the stored and
// the submitted state of a model.
type DiffAttr struct {
	Column   string      `json:"column"`
	Label    string      `json:"label"`
	OldValue interface{} `json:"old_value"`
	NewValue interface{} `json:"new_value"`
}

// Entity binds a Model to its REST handlers (list, view, create, update,
// delete, export, mapping) and keeps the reflection and GORM schema metadata
// those handlers need.
type Entity struct {
	opts                 *Options
	model                Model
	primaryKey           string // database name of the first primary key column
	reflectValue         reflect.Value
	reflectType          reflect.Type
	statement            *gorm.Statement // parsed schema; nil when opts.DB is nil
	isImplementKvMapping bool
	mappingLabelField    string
	mappingValueField    string
	singularName         string
	pluralizeName        string
	scenarios            []string
	hooks                []Hook
	lruCache             *lru.Cache
	createdAt            time.Time
}

// ID returns the entity identifier in the form "<table>@<module>".
func (e *Entity) ID() string {
	return e.model.TableName() + "@" + e.model.ModuleName()
}

// hasScenario reports whether scenario s is enabled for this entity.
func (e *Entity) hasScenario(s string) bool {
	for _, scenario := range e.scenarios {
		if s == scenario {
			return true
		}
	}
	return false
}

// callMethod invokes the method called name on model through reflection,
// passing args as the method's parameters.
//
// The call is skipped (and nil returned) when model is not a pointer, the
// method does not exist, the method is variadic, the parameter count differs
// from len(args), or an argument is not assignable to its parameter. When the
// method runs, the first non-nil result that implements error is returned.
func (e *Entity) callMethod(model interface{}, name string, args ...interface{}) (err error) {
	refVal := reflect.ValueOf(model)
	if refVal.Kind() != reflect.Ptr {
		return nil
	}
	method := refVal.MethodByName(name)
	// MethodByName yields the zero Value when the method is missing. A method
	// value is never addressable, so CanAddr cannot be used for this check.
	if !method.IsValid() {
		return nil
	}
	methodType := method.Type()
	if methodType.IsVariadic() || methodType.NumIn() != len(args) {
		return nil
	}
	in := make([]reflect.Value, len(args))
	for i, arg := range args {
		paramType := methodType.In(i)
		if arg == nil {
			// reflect.ValueOf(nil) is the zero Value, which Call rejects;
			// substitute the parameter's zero value where nil is meaningful.
			switch paramType.Kind() {
			case reflect.Ptr, reflect.Interface, reflect.Slice, reflect.Map, reflect.Func, reflect.Chan:
				in[i] = reflect.Zero(paramType)
				continue
			}
			return nil
		}
		val := reflect.ValueOf(arg)
		if !val.Type().AssignableTo(paramType) {
			// Call would panic on a mismatched signature; treat it as "no hook".
			return nil
		}
		in[i] = val
	}
	for _, out := range method.Call(in) {
		if callErr, ok := out.Interface().(error); ok && callErr != nil {
			return callErr
		}
	}
	return nil
}

// lookupFieldName resolves column, given either as a database column name or
// as a Go field name, to the Go field name recorded in the parsed schema.
// It returns "" when there is no parsed schema or no field matches.
func (e *Entity) lookupFieldName(column string) string {
	if e.statement == nil || e.statement.Schema == nil {
		return ""
	}
	for _, field := range e.statement.Schema.Fields {
		if field.DBName == column || field.Name == column {
			return field.Name
		}
	}
	return ""
}

// namespacedContext returns the request context (or a background context when
// the request has none) carrying namespace under the "namespace" key.
func (e *Entity) namespacedContext(ctx *http.Context, namespace string) context.Context {
	requestCtx := ctx.Request().Context()
	if requestCtx == nil {
		requestCtx = context.Background()
	}
	// The plain string key is kept as-is: consumers outside this file look the
	// value up with the same key.
	return context.WithValue(requestCtx, "namespace", namespace)
}

// getPrimaryKeyValue returns the value of the first primary key field of
// model. It returns nil when no schema has been parsed or the field cannot be
// read, and 0 when the schema declares no primary key.
func (e *Entity) getPrimaryKeyValue(model interface{}) interface{} {
	if e.statement == nil || e.statement.Schema == nil {
		return nil
	}
	if len(e.statement.Schema.PrimaryFields) == 0 {
		return 0
	}
	refVal := reflect.Indirect(reflect.ValueOf(model))
	if refVal.Kind() != reflect.Struct {
		return nil
	}
	val := refVal.FieldByName(e.statement.Schema.PrimaryFields[0].Name)
	if !val.IsValid() || !val.CanInterface() {
		return nil
	}
	return val.Interface()
}

// getFieldValue returns the value of the field identified by column (database
// column name or Go field name) on model, or nil when the field is unknown or
// cannot be read.
func (e *Entity) getFieldValue(model reflect.Value, column string) interface{} {
	name := e.lookupFieldName(column)
	if name == "" {
		return nil
	}
	refVal := reflect.Indirect(model)
	if refVal.Kind() != reflect.Struct {
		return nil
	}
	fieldVal := refVal.FieldByName(name)
	if !fieldVal.IsValid() || !fieldVal.CanInterface() {
		return nil
	}
	return fieldVal.Interface()
}

// setFieldValue assigns value to the field identified by column (database
// column name or Go field name) on model. The call is a no-op when the field
// is unknown, cannot be set, or value's type is incompatible with the field.
func (e *Entity) setFieldValue(model reflect.Value, column string, value interface{}) {
	name := e.lookupFieldName(column)
	if name == "" {
		return
	}
	refVal := reflect.Indirect(model)
	if refVal.Kind() != reflect.Struct {
		return
	}
	fieldVal := refVal.FieldByName(name)
	if !fieldVal.IsValid() || !fieldVal.CanSet() {
		return
	}
	newVal := reflect.ValueOf(value)
	if !newVal.IsValid() {
		return
	}
	switch {
	case newVal.Type().AssignableTo(fieldVal.Type()):
		fieldVal.Set(newVal)
	case newVal.Kind() == fieldVal.Kind() && newVal.Type().ConvertibleTo(fieldVal.Type()):
		// Same underlying kind, e.g. a string stored into a named string type.
		fieldVal.Set(newVal.Convert(fieldVal.Type()))
	}
}

// getScenarioMethod returns the HTTP method used for scenario, or "" when the
// scenario has no method assigned here.
//
// NOTE(review): ScenarioMapping has a URL and a handler below but no method in
// this switch — confirm whether the route is registered elsewhere.
func (e *Entity) getScenarioMethod(scenario string) string {
	switch scenario {
	case ScenarioList, ScenarioView, ScenarioExport:
		return "GET"
	case ScenarioCreate:
		return "POST"
	case ScenarioUpdate:
		return "PUT"
	case ScenarioDelete:
		return "DELETE"
	}
	return ""
}

// getScenarioUrl returns the cleaned route path for scenario, built from the
// configured prefix, the model's module name and the singular or plural
// entity name.
func (e *Entity) getScenarioUrl(scenario string) string {
	base := e.opts.Prefix + "/" + e.model.ModuleName() + "/"
	var uri string
	switch scenario {
	case ScenarioList:
		uri = base + e.pluralizeName
	case ScenarioCreate:
		uri = base + e.singularName
	case ScenarioView, ScenarioUpdate, ScenarioDelete:
		uri = base + e.singularName + "/:id"
	case ScenarioExport:
		uri = base + e.singularName + "-export"
	case ScenarioMapping:
		uri = base + e.singularName + "-mapping"
	}
	return path.Clean(uri)
}

// getScenarioHandle returns the handler for scenario, or nil when the
// scenario is unknown.
func (e *Entity) getScenarioHandle(scenario string) http.HandleFunc {
	var handleFunc http.HandleFunc
	switch scenario {
	case ScenarioList:
		handleFunc = e.actionIndex
	case ScenarioView:
		handleFunc = e.actionView
	case ScenarioCreate:
		handleFunc = e.actionCreate
	case ScenarioUpdate:
		handleFunc = e.actionUpdate
	case ScenarioDelete:
		handleFunc = e.actionDelete
	case ScenarioExport:
		handleFunc = e.actionExport
	case ScenarioMapping:
		handleFunc = e.actionMapping
	}
	return handleFunc
}

// prepareConditions translates the request's form values into filter
// conditions on query, one per searchable schema, and applies the "sort"
// parameter (comma separated columns, "-" prefix for descending, optional "+"
// prefix for ascending).
func (e *Entity) prepareConditions(ctx *http.Context, query *Query, schemas []*Schema) {
	model := reflect.New(e.reflectType).Interface()
	activeModel, _ := model.(FilterColumnInterface)

	// textFilter applies an exact or LIKE match depending on the schema.
	textFilter := func(schema *Schema, value string) {
		if schema.getProperties().Match == MatchExactly {
			query.AndFilterWhere(NewQueryCondition(schema.Column, value))
			return
		}
		query.AndFilterWhere(NewQueryConditionWithOperator("LIKE", schema.Column, value))
	}

	// Default search handling.
	for _, schema := range schemas {
		if activeModel != nil {
			// A non-nil error from the model's own handler skips the default
			// handling for this column.
			if err := activeModel.OnSearchColumn(ctx, query, schema); err != nil {
				continue
			}
		}
		if schema.Native == 0 {
			continue
		}
		formValue := ctx.FormValue(schema.Column)
		switch schema.Format {
		case "string", "text", "textarea":
			textFilter(schema, formValue)
		case "date", "time", "datetime":
			// A value of the form "<from><sep><to>" is treated as a range.
			var sep string
			for _, s := range []byte{',', '/'} {
				if strings.IndexByte(formValue, s) > -1 {
					sep = string(s)
				}
			}
			var bounds []string
			if sep != "" {
				// Splitting on "" would break the value into characters and
				// turn any two-character input into a bogus range.
				bounds = strings.Split(formValue, sep)
			}
			if len(bounds) == 2 {
				query.AndFilterWhere(
					NewQueryConditionWithOperator(">=", schema.Column, strings.TrimSpace(bounds[0])),
					NewQueryConditionWithOperator("<=", schema.Column, strings.TrimSpace(bounds[1])),
				)
			} else {
				query.AndFilterWhere(NewQueryCondition(schema.Column, formValue))
			}
		case "duration", "number", "integer", "decimal":
			query.AndFilterWhere(NewQueryCondition(schema.Column, formValue))
		default:
			if schema.Type == "string" {
				textFilter(schema, formValue)
			} else {
				query.AndFilterWhere(NewQueryCondition(schema.Column, formValue))
			}
		}
	}

	// Sorting.
	sortPar := ctx.FormValue("sort")
	if sortPar == "" {
		return
	}
	for _, s := range strings.Split(sortPar, ",") {
		if s == "" {
			// Empty segment ("a,,b" or a trailing comma): nothing to sort by.
			continue
		}
		column, direction := s, "ASC"
		switch s[0] {
		case '-':
			column, direction = s[1:], "DESC"
		case '+':
			column = s[1:]
		}
		if column == "" {
			continue
		}
		query.OrderBy(column, direction)
	}
}

// isKvMapping reports whether the model implements KvMapping.
func (e *Entity) isKvMapping() bool {
	return e.isImplementKvMapping
}

// getMappingValue returns the label/value pairs of the entity for namespace,
// serving them from the LRU cache when present. It returns nil when the model
// does not implement KvMapping. Query results are cached only on success.
//
// NOTE(review): the namespace filter is applied regardless of
// opts.EnableNamespace, unlike the other handlers — confirm this is intended.
func (e *Entity) getMappingValue(namespace string) []mappingValue {
	if !e.isKvMapping() {
		return nil
	}
	cacheKey := namespace + ":mappingValue"
	if v, ok := e.lruCache.Get(cacheKey); ok {
		return v.([]mappingValue)
	}
	values := make([]mappingValue, 0)
	err := e.opts.DB.
		Select(e.mappingLabelField+" AS label", e.mappingValueField+" AS value").
		Where("namespace=?", namespace).
		Table(e.model.TableName()).
		Scan(&values).Error
	if err == nil {
		e.lruCache.Add(cacheKey, values)
	}
	return values
}

// invalidMappingValue removes the cached mapping values of namespace.
func (e *Entity) invalidMappingValue(namespace string) {
	e.lruCache.Remove(namespace + ":mappingValue")
}

// invalidCache drops every cache entry kept for namespace.
func (e *Entity) invalidCache(namespace string) {
	e.invalidMappingValue(namespace)
}

// actionIndex handles the list scenario: it applies the search conditions and
// paging ("page" is 1-based, "pagesize" defaults to 15) and responds with the
// formatted rows together with the total row count.
func (e *Entity) actionIndex(ctx *http.Context) (err error) {
	if !e.hasScenario(ScenarioList) {
		return ctx.Error(HttpAccessDenied, "access denied")
	}
	namespace := ctx.ParamValue("@namespace")
	page, _ := strconv.Atoi(ctx.FormValue("page"))
	pageSize, _ := strconv.Atoi(ctx.FormValue("pagesize"))
	if pageSize <= 0 {
		pageSize = 15
	}
	// Convert the 1-based page into a 0-based index; never go negative, which
	// would produce a negative offset.
	pageIndex := page
	if pageIndex > 0 {
		pageIndex--
	} else {
		pageIndex = 0
	}

	// models is a pointer to an empty slice of the model type.
	sliceValue := reflect.MakeSlice(reflect.SliceOf(e.reflectType), 0, 0)
	models := reflect.New(sliceValue.Type())
	models.Elem().Set(sliceValue)

	query := NewQuery(e.opts.DB)
	searchSchemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioSearch)
	indexSchemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioList)
	e.prepareConditions(ctx, query, searchSchemas)
	if e.opts.EnableNamespace {
		query.AndFilterWhere(NewQueryCondition("namespace", namespace))
	}
	query.Offset(pageIndex * pageSize).Limit(pageSize)
	if err = query.All(models.Interface()); err != nil {
		return ctx.Error(HttpDatabaseQueryFailed, err.Error())
	}
	requestCtx := e.namespacedContext(ctx, namespace)
	return ctx.Success(map[string]interface{}{
		"page":       page,
		"pageSize":   pageSize,
		"totalCount": query.Limit(0).Offset(0).Count(e.model),
		"data":       e.opts.Formatter.formatModels(requestCtx, models.Interface(), indexSchemas, e.statement),
	})
}

// actionView handles the view scenario: it loads one record by primary key
// (and namespace when enabled). With a non-empty "format" form value the
// record is rendered through the formatter using the schemas of the
// "scenario" form value (the view scenario by default); otherwise the raw
// model is returned.
func (e *Entity) actionView(ctx *http.Context) (err error) {
	if !e.hasScenario(ScenarioView) {
		return ctx.Error(HttpAccessDenied, "access denied")
	}
	namespace := ctx.ParamValue("@namespace")
	scenario := ctx.FormValue("scenario")
	idStr := ctx.ParamValue("id")
	model := reflect.New(e.reflectType).Interface()
	conditions := map[string]interface{}{
		e.primaryKey: idStr,
	}
	if e.opts.EnableNamespace {
		conditions["namespace"] = namespace
	}
	if err = e.opts.DB.Where(conditions).First(model).Error; err != nil {
		return ctx.Error(HttpDatabaseFindFailed, err.Error())
	}
	if ctx.FormValue("format") == "" {
		return ctx.Success(model)
	}
	// Render only the fields visible in the requested scenario.
	if scenario == "" {
		scenario = ScenarioView
	}
	schemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), scenario)
	requestCtx := e.namespacedContext(ctx, namespace)
	return ctx.Success(e.opts.Formatter.formatModel(requestCtx, model, schemas, e.statement))
}

// actionExport handles the export scenario: it runs the same filtered query
// as the list scenario without paging and streams the formatted rows as a CSV
// attachment named "<singular name>.csv". The header row holds the schema
// labels.
func (e *Entity) actionExport(ctx *http.Context) (err error) {
	if !e.hasScenario(ScenarioExport) {
		return ctx.Error(HttpAccessDenied, "access denied")
	}
	namespace := ctx.ParamValue("@namespace")

	sliceValue := reflect.MakeSlice(reflect.SliceOf(e.reflectType), 0, 0)
	models := reflect.New(sliceValue.Type())
	models.Elem().Set(sliceValue)

	query := NewQuery(e.opts.DB)
	searchSchemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioSearch)
	exportSchemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioList)
	e.prepareConditions(ctx, query, searchSchemas)
	if e.opts.EnableNamespace {
		query.AndFilterWhere(NewQueryCondition("namespace", namespace))
	}
	if err = query.All(models.Interface()); err != nil {
		return ctx.Error(HttpDatabaseExportFailed, err.Error())
	}

	header := ctx.Response().Header()
	header.Set("Content-Type", "text/csv")
	header.Set("Access-Control-Expose-Headers", "Content-Disposition")
	header.Set("Content-Disposition", fmt.Sprintf("attachment;filename=%s.csv", e.singularName))

	requestCtx := e.namespacedContext(ctx, namespace)
	value := e.opts.Formatter.formatModels(requestCtx, models.Interface(), exportSchemas, e.statement)

	writer := csv.NewWriter(ctx.Response())
	record := make([]string, len(exportSchemas))
	for i, field := range exportSchemas {
		record[i] = field.Label
	}
	// Write errors are ignored on purpose: the response has already started,
	// so there is no way left to report a failure to the client.
	_ = writer.Write(record)
	if rows, ok := value.([]interface{}); ok {
		for _, item := range rows {
			row, isMap := item.(map[string]interface{})
			if !isMap {
				continue
			}
			for i, field := range exportSchemas {
				if v, found := row[field.Column]; found {
					record[i] = fmt.Sprint(v)
				} else {
					record[i] = ""
				}
			}
			_ = writer.Write(record)
		}
	}
	writer.Flush()
	return nil
}

// actionCreate handles the create scenario: it binds the payload to a new
// model, stamps the namespace and audit fields, and inside one transaction
// runs the OnBeforeCreateRequest/OnBeforeSaveRequest hooks, inserts the row
// and runs the OnAfterCreateRequest/OnAfterSaveRequest hooks. A failing hook
// rolls the transaction back and is reported as HttpRequestCallbackFailed.
func (e *Entity) actionCreate(ctx *http.Context) (err error) {
	var (
		callbackFailed bool
		diffAttrs      = make([]*DiffAttr, 0)
	)
	if !e.hasScenario(ScenarioCreate) {
		return ctx.Error(HttpAccessDenied, "access denied")
	}
	namespace := ctx.ParamValue("@namespace")
	refModel := reflect.New(e.reflectType)
	model := refModel.Interface()
	if err = ctx.Bind(model); err != nil {
		return ctx.Error(HttpInvalidPayload, err.Error())
	}
	schemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioCreate)

	// Fields controlled by the server, applied after binding so the payload
	// cannot override them.
	e.setFieldValue(refModel, "namespace", namespace)
	e.setFieldValue(refModel, "CreatedBy", ctx.ParamValue("@uid"))
	e.setFieldValue(refModel, "CreatedDept", ctx.ParamValue("@department"))
	e.setFieldValue(refModel, "UpdatedBy", ctx.ParamValue("@uid"))
	e.setFieldValue(refModel, "UpdatedDept", ctx.ParamValue("@department"))

	err = e.opts.DB.Transaction(func(tx *gorm.DB) error {
		// Hooks return their own error so that the transaction is rolled
		// back; the response is written once, after the transaction ends.
		if errTx := e.callMethod(model, "OnBeforeCreateRequest", ctx, tx, model); errTx != nil {
			callbackFailed = true
			return errTx
		}
		if errTx := e.callMethod(model, "OnBeforeSaveRequest", ctx, tx, model); errTx != nil {
			callbackFailed = true
			return errTx
		}
		// Insert the row.
		if errTx := tx.Create(model).Error; errTx != nil {
			return errTx
		}
		// Every visible column counts as changed from nil on creation.
		for _, scm := range schemas {
			diffAttrs = append(diffAttrs, &DiffAttr{
				Column:   scm.Column,
				Label:    scm.Label,
				OldValue: nil,
				NewValue: e.getFieldValue(refModel, scm.Column),
			})
		}
		if errTx := e.callMethod(model, "OnAfterCreateRequest", ctx, tx, model, diffAttrs); errTx != nil {
			callbackFailed = true
			return errTx
		}
		if errTx := e.callMethod(model, "OnAfterSaveRequest", ctx, tx, model, diffAttrs); errTx != nil {
			callbackFailed = true
			return errTx
		}
		return nil
	})
	if err == nil {
		e.invalidCache(namespace)
		pkVal := e.getPrimaryKeyValue(model)
		for _, hook := range e.hooks {
			hook.HookAfterCreate(ctx, e.opts.DB, pkVal, model, diffAttrs)
		}
		return ctx.Success(map[string]interface{}{
			"id":    pkVal,
			"table": e.model.TableName(),
			"state": "created",
		})
	}
	if callbackFailed {
		return ctx.Error(HttpRequestCallbackFailed, err.Error())
	}
	// Form validation errors are returned in structured form.
	if validateError, ok := err.(*errors.StructError); ok {
		ctx.Response().Header().Set("Content-Type", "application/json")
		return json.NewEncoder(ctx.Response()).Encode(map[string]interface{}{
			"errno":  HttpValidateFailed,
			"result": validateError,
		})
	}
	return ctx.Error(HttpDatabaseCreateFailed, err.Error())
}

// actionUpdate handles the update scenario: it loads the record, binds the
// payload over it, and inside one transaction runs the
// OnBeforeUpdateRequest/OnBeforeSaveRequest hooks, writes only the columns
// whose value changed and runs the OnAfterUpdateRequest/OnAfterSaveRequest
// hooks. A failing hook rolls the transaction back and is reported as
// HttpRequestCallbackFailed.
func (e *Entity) actionUpdate(ctx *http.Context) (err error) {
	var (
		callbackFailed bool
		oldValues      = make(map[string]interface{})
		diffs          = make(map[string]interface{})
		diffAttrs      = make([]*DiffAttr, 0)
	)
	if !e.hasScenario(ScenarioUpdate) {
		return ctx.Error(HttpAccessDenied, "access denied")
	}
	namespace := ctx.ParamValue("@namespace")
	idStr := ctx.ParamValue("id")
	refModel := reflect.New(e.reflectType)
	model := refModel.Interface()

	conditions := map[string]interface{}{
		e.primaryKey: idStr,
	}
	if e.opts.EnableNamespace {
		conditions["namespace"] = namespace
	}
	if err = e.opts.DB.Where(conditions).First(model).Error; err != nil {
		return ctx.Error(HttpDatabaseFindFailed, err.Error())
	}

	// Snapshot the stored state before the payload is applied.
	pkBefore := e.getPrimaryKeyValue(model)
	schemas := visibleSchemas(namespace, e.model.ModuleName(), e.model.TableName(), ScenarioUpdate)
	for _, scm := range schemas {
		oldValues[scm.Column] = e.getFieldValue(refModel, scm.Column)
	}
	if err = ctx.Bind(model); err != nil {
		return ctx.Error(HttpInvalidPayload, err.Error())
	}

	// The payload must not redirect the update to another row or namespace:
	// restore the identity that was used to load the record.
	if e.primaryKey != "" {
		e.setFieldValue(refModel, e.primaryKey, pkBefore)
	}
	if e.opts.EnableNamespace {
		e.setFieldValue(refModel, "namespace", namespace)
	}
	// Stamp the updating user. This has to happen after the record is loaded
	// and bound, otherwise the loaded values overwrite the stamp.
	e.setFieldValue(refModel, "UpdatedBy", ctx.ParamValue("@uid"))
	e.setFieldValue(refModel, "UpdatedDept", ctx.ParamValue("@department"))

	err = e.opts.DB.Transaction(func(tx *gorm.DB) error {
		// Hooks return their own error so that the transaction is rolled
		// back; the response is written once, after the transaction ends.
		if errTx := e.callMethod(model, "OnBeforeUpdateRequest", ctx, tx, model); errTx != nil {
			callbackFailed = true
			return errTx
		}
		if errTx := e.callMethod(model, "OnBeforeSaveRequest", ctx, tx, model); errTx != nil {
			callbackFailed = true
			return errTx
		}
		// Collect the columns whose value changed. DeepEqual is used because
		// comparing interface values with != panics for uncomparable dynamic
		// types such as slices and maps.
		for _, scm := range schemas {
			v := e.getFieldValue(refModel, scm.Column)
			if reflect.DeepEqual(oldValues[scm.Column], v) {
				continue
			}
			diffs[scm.Column] = v
			diffAttrs = append(diffAttrs, &DiffAttr{
				Column:   scm.Column,
				Label:    scm.Label,
				OldValue: oldValues[scm.Column],
				NewValue: v,
			})
		}
		// Partial update of the changed columns only.
		if len(diffs) > 0 {
			if errTx := tx.Model(model).Updates(diffs).Error; errTx != nil {
				return errTx
			}
		}
		if errTx := e.callMethod(model, "OnAfterUpdateRequest", ctx, tx, model, diffAttrs); errTx != nil {
			callbackFailed = true
			return errTx
		}
		if errTx := e.callMethod(model, "OnAfterSaveRequest", ctx, tx, model, diffAttrs); errTx != nil {
			callbackFailed = true
			return errTx
		}
		return nil
	})
	if err == nil {
		e.invalidCache(namespace)
		pkVal := e.getPrimaryKeyValue(model)
		for _, hook := range e.hooks {
			hook.HookAfterUpdate(ctx, e.opts.DB, pkVal, model, diffAttrs)
		}
		return ctx.Success(map[string]interface{}{
			"id":    pkVal,
			"table": e.model.TableName(),
			"state": "updated",
		})
	}
	if callbackFailed {
		return ctx.Error(HttpRequestCallbackFailed, err.Error())
	}
	// Form validation errors are returned in structured form.
	if validateError, ok := err.(*errors.StructError); ok {
		ctx.Response().Header().Set("Content-Type", "application/json")
		return json.NewEncoder(ctx.Response()).Encode(map[string]interface{}{
			"errno":  HttpValidateFailed,
			"result": validateError,
		})
	}
	return ctx.Error(HttpDatabaseUpdateFailed, err.Error())
}

// actionDelete handles the delete scenario: it loads the record and inside
// one transaction runs the OnBeforeDeleteRequest hook, deletes the row and
// runs the OnAfterDeleteRequest hook. A failing hook rolls the transaction
// back and is reported as HttpRequestCallbackFailed.
func (e *Entity) actionDelete(ctx *http.Context) (err error) {
	var callbackFailed bool
	if !e.hasScenario(ScenarioDelete) {
		return ctx.Error(HttpAccessDenied, "access denied")
	}
	idStr := ctx.ParamValue("id")
	namespace := ctx.ParamValue("@namespace")
	model := reflect.New(e.reflectType).Interface()
	conditions := map[string]interface{}{
		e.primaryKey: idStr,
	}
	if e.opts.EnableNamespace {
		conditions["namespace"] = namespace
	}
	if err = e.opts.DB.Where(conditions).First(model).Error; err != nil {
		return ctx.Error(HttpDatabaseFindFailed, err.Error())
	}
	err = e.opts.DB.Transaction(func(tx *gorm.DB) error {
		if errTx := e.callMethod(model, "OnBeforeDeleteRequest", ctx, tx, model); errTx != nil {
			callbackFailed = true
			return errTx
		}
		// Delete the row.
		if errTx := tx.Delete(model).Error; errTx != nil {
			return errTx
		}
		if errTx := e.callMethod(model, "OnAfterDeleteRequest", ctx, tx, model); errTx != nil {
			callbackFailed = true
			return errTx
		}
		return nil
	})
	if err != nil {
		if callbackFailed {
			return ctx.Error(HttpRequestCallbackFailed, err.Error())
		}
		return ctx.Error(HttpDatabaseDeleteFailed, err.Error())
	}
	e.invalidCache(namespace)
	return ctx.Success(map[string]interface{}{
		"id":    e.getPrimaryKeyValue(model),
		"table": e.model.TableName(),
		"state": "deleted",
	})
}

// actionMapping responds with the label/value mapping of the entity for the
// request's namespace.
func (e *Entity) actionMapping(ctx *http.Context) (err error) {
	namespace := ctx.ParamValue("@namespace")
	return ctx.Success(e.getMappingValue(namespace))
}

// newEntity builds the Entity for model: it derives the singular and plural
// names from the table name, detects KvMapping support and, when opts.DB is
// set, parses the GORM schema to resolve the primary key column and the
// database names of the mapping fields. It panics when the schema cannot be
// parsed.
func newEntity(model Model, opts *Options) *Entity {
	entity := &Entity{
		model:        model,
		opts:         opts,
		createdAt:    time.Now(),
		reflectValue: reflect.Indirect(reflect.ValueOf(model)),
	}
	// lru.New only fails for a non-positive size, so the error can be ignored.
	entity.lruCache, _ = lru.New(50)
	entity.reflectType = entity.reflectValue.Type()
	entity.singularName = inflector.Singularize(model.TableName())
	entity.pluralizeName = inflector.Pluralize(model.TableName())

	val := reflect.New(entity.reflectType).Interface()
	if kvMapping, ok := val.(KvMapping); ok {
		entity.isImplementKvMapping = true
		entity.mappingLabelField = kvMapping.LabelField()
		entity.mappingValueField = kvMapping.ValueField()
	}
	if opts.DB == nil {
		return entity
	}

	entity.statement = &gorm.Statement{
		DB:       opts.DB,
		ConnPool: opts.DB.ConnPool,
		Clauses:  map[string]clause.Clause{},
	}
	if err := entity.statement.Parse(model); err != nil {
		panic(err)
	}
	schema := entity.statement.Schema
	if schema == nil {
		return entity
	}
	if len(schema.PrimaryFieldDBNames) > 0 {
		entity.primaryKey = schema.PrimaryFieldDBNames[0]
	}
	// Translate the mapping fields from Go field names to database columns.
	for _, field := range schema.Fields {
		if field.Name == entity.mappingValueField {
			entity.mappingValueField = field.DBName
		}
		if field.Name == entity.mappingLabelField {
			entity.mappingLabelField = field.DBName
		}
	}
	return entity
}