package plugins

import (
	"context"
	"fmt"
	"reflect"
	"regexp"
	"strconv"
	"strings"

	"git.nspix.com/golang/rest/v3"
	errpkg "git.nspix.com/golang/rest/v3/error"
	"git.nspix.com/golang/rest/v3/utils"
	validator "github.com/go-playground/validator/v10"
	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)

const (
	// SkipValidations is the gorm.DB setting key (read via db.Get) that, when
	// set to the bool true, makes the validation callback return without
	// validating anything, e.g. db.Set(SkipValidations, true).Create(&m).
	SkipValidations = "validations:skip_validations"
)

type (
	// validateRule is a single go-playground/validator tag such as "min=3".
	validateRule struct {
		Rule  string // tag name, e.g. "min", "required", "db_unique"
		Value string // optional tag parameter; empty means no "=value" suffix
		Valid bool   // when false, buildRules leaves this rule out of the tag string
	}

	// validateScope is the unexported context-key type under which a
	// *validScope is stored for the "db_unique" validator.
	validateScope struct{}

	// validScope carries what the "db_unique" validator needs to query the
	// database: the running statement, the column being checked and the model.
	validScope struct {
		DB     *gorm.DB
		Column string
		Model  interface{}
	}

	// validation holds the dependencies of the GORM validation callback.
	validation struct {
		crud *rest.CRUD
	}
)

var (
	validate         = validator.New()
	validateScopeKey = validateScope{}
	// telephoneRegex accepts a string of 5 to 20 ASCII digits.
	telephoneRegex = regexp.MustCompile(`^\d{5,20}$`)
)

// init registers the custom "telephone" and "db_unique" validator tags.
func init() {
	// RegisterValidationCtx only returns an error for an empty tag name or a nil
	// function, neither of which can happen here, so the errors are ignored.
	_ = validate.RegisterValidationCtx("telephone", func(ctx context.Context, fl validator.FieldLevel) bool {
		return telephoneRegex.MatchString(fmt.Sprint(fl.Field().Interface()))
	})

	// db_unique reports false when another row already stores the field's value
	// in scope.Column. When the model has a non-zero primary key, the row with
	// that key is excluded from the count so an update can keep its own value.
	_ = validate.RegisterValidationCtx("db_unique", func(ctx context.Context, fl validator.FieldLevel) bool {
		scope, ok := ctx.Value(validateScopeKey).(*validScope)
		if !ok || scope == nil || scope.DB == nil || scope.DB.Statement.Schema == nil {
			// No database scope attached: the check cannot be performed.
			return true
		}
		val := fl.Field().Interface()

		var (
			pkField *schema.Field
			pkValue reflect.Value
		)
		if pks := scope.DB.Statement.Schema.PrimaryFields; len(pks) > 0 {
			pkField = pks[0]
			pkValue = reflect.Indirect(reflect.ValueOf(scope.Model))
			// BindNames is the path of struct field names leading to the key
			// (more than one entry when the key sits in an embedded struct).
			for _, name := range pkField.BindNames {
				if pkValue.Kind() != reflect.Struct {
					pkValue = reflect.Value{}
					break
				}
				pkValue = reflect.Indirect(pkValue.FieldByName(name))
			}
		}

		var count int64
		query := scope.DB.Session(&gorm.Session{NewDB: true}).Model(scope.Model)
		if pkField != nil && pkValue.IsValid() && !pkValue.IsZero() {
			// SQL needs the column name (DBName), not the Go field name.
			query = query.Where(scope.Column+"=? AND "+pkField.DBName+" != ?", val, pkValue.Interface())
		} else {
			query = query.Where(scope.Column+"=?", val)
		}
		// Best effort: if the count query fails, count stays 0 and the value
		// is treated as unique.
		query.Count(&count)
		return count == 0
	})
}

// newRule builds an enabled validateRule. ss[0] is the tag name and the
// optional ss[1] its parameter; any further elements are ignored.
func newRule(ss ...string) *validateRule {
	v := &validateRule{Valid: true}
	if len(ss) >= 1 {
		v.Rule = ss[0]
	}
	if len(ss) >= 2 {
		v.Value = ss[1]
	}
	return v
}

// generateRules translates a rest.Rule into validator rules for scm in the
// given scenario:
//   - a non-zero Min or Max becomes "min"/"max";
//   - Unique becomes "db_unique", except on primary keys;
//   - a non-empty Type is used as a tag name (e.g. "telephone");
//   - "required" is added when scenario is listed in rule.Required.
func generateRules(scm *rest.Schema, scenario string, rule rest.Rule) []*validateRule {
	rules := make([]*validateRule, 0, 5)
	if rule.Min != 0 {
		rules = append(rules, newRule("min", strconv.Itoa(rule.Min)))
	}
	if rule.Max != 0 {
		rules = append(rules, newRule("max", strconv.Itoa(rule.Max)))
	}
	// Primary keys are not checked for uniqueness.
	if rule.Unique && !scm.Attribute.PrimaryKey {
		rules = append(rules, newRule("db_unique"))
	}
	if rule.Type != "" {
		rules = append(rules, newRule(rule.Type))
	}
	for _, s := range rule.Required {
		if s == scenario {
			rules = append(rules, newRule("required"))
			break
		}
	}
	return rules
}

// buildRules joins the enabled rules into a validator tag string such as
// "min=1,max=10,required". Disabled rules (Valid == false) are skipped.
func buildRules(rs []*validateRule) string {
	var sb strings.Builder
	for _, r := range rs {
		if !r.Valid {
			continue
		}
		if sb.Len() > 0 {
			sb.WriteString(",")
		}
		sb.WriteString(r.Rule)
		if r.Value != "" {
			sb.WriteString("=" + r.Value)
		}
	}
	return sb.String()
}

// getRule returns the first rule named name, or nil if there is none.
func getRule(name string, rules []*validateRule) *validateRule {
	for _, r := range rules {
		if r.Rule == name {
			return r
		}
	}
	return nil
}

// formatError returns the user-facing message for a failed validation tag on
// the column described by scm. rule supplies the configured min/max bounds.
// Tags without a dedicated message yield "".
func formatError(rule rest.Rule, scm *rest.Schema, tag string) string {
	switch tag {
	case "db_unique": // "<label> value already exists."
		return scm.Label + "值已经存在."
	case "required": // "<label> value must not be empty."
		return scm.Label + "值不能为空."
	case "max": // "<label> length/value must not exceed <Max>"
		if scm.Type == "string" {
			return scm.Label + "长度不能大于" + strconv.Itoa(rule.Max)
		}
		return scm.Label + "值不能大于" + strconv.Itoa(rule.Max)
	case "min": // "<label> length/value must not be less than <Min>"
		if scm.Type == "string" {
			return scm.Label + "长度不能小于" + strconv.Itoa(rule.Min)
		}
		return scm.Label + "值不能小于" + strconv.Itoa(rule.Min)
	}
	return ""
}

// validation is the GORM create/update callback. It resolves the rest schemas
// visible for the model being written and validates each column's value
// against the rules generated from its schema. Only the first failing column
// is reported: each validator failure for it is added to db as an
// *errpkg.StructError, or the raw error is added when it is not a
// validator.ValidationErrors.
//
// The callback does nothing when any of these holds:
//   - SkipValidations is set to true on db;
//   - the statement has no model or schema;
//   - the model does not implement rest.Model;
//   - the statement's reflect value is not a single struct (e.g. batch inserts
//     or map-based updates), because the field lookups below need a struct.
func (vv *validation) validation(db *gorm.DB) {
	if v, ok := db.Get(SkipValidations); ok {
		if skip, _ := v.(bool); skip {
			return
		}
	}

	stmt := db.Statement
	if stmt.Model == nil || stmt.Schema == nil {
		return
	}
	model, ok := stmt.Model.(rest.Model)
	if !ok {
		return
	}
	if stmt.ReflectValue.Kind() != reflect.Struct {
		return
	}

	// A model with any empty primary key is validated as a create, otherwise
	// as an update.
	scenario := rest.ScenarioUpdate
	for _, pk := range stmt.Schema.PrimaryFields {
		if utils.IsEmpty(stmt.ReflectValue.FieldByName(pk.Name).Interface()) {
			scenario = rest.ScenarioCreate
			break
		}
	}

	// Models may carry a string Namespace field that scopes the schema lookup.
	namespace := ""
	if ns := stmt.ReflectValue.FieldByName("Namespace"); ns.IsValid() && ns.Kind() == reflect.String {
		namespace = ns.String()
	}

	ctx := stmt.Context
	if ctx == nil {
		ctx = context.Background()
	}

	schemas := vv.crud.VisibleSchemas(ctx, namespace, model.ModuleName(), model.TableName(), scenario)
	for _, scm := range schemas {
		rules := generateRules(scm, scenario, scm.Rule)
		if len(rules) == 0 {
			continue
		}
		field := stmt.Schema.LookUpField(scm.Column)
		if field == nil {
			// The schema names a column the GORM model does not have.
			continue
		}
		value := stmt.ReflectValue.FieldByName(field.Name)
		if !value.IsValid() {
			continue
		}

		if r := getRule("required", rules); r != nil {
			// validator's "required" rejects zero values, which would make
			// false/0 unusable; disable it for boolean and numeric fields.
			switch reflect.ValueOf(value.Interface()).Kind() {
			case reflect.Bool,
				reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
				reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
				reflect.Float32, reflect.Float64:
				r.Valid = false
			}
		} else if utils.IsEmpty(value.Interface()) {
			// Optional column left empty: nothing to validate.
			continue
		}

		vctx := context.WithValue(ctx, validateScopeKey, &validScope{
			DB:     db,
			Column: scm.Column,
			Model:  stmt.Model,
		})
		if err := validate.VarCtx(vctx, value.Interface(), buildRules(rules)); err != nil {
			if errs, ok := err.(validator.ValidationErrors); ok {
				for _, e := range errs {
					_ = db.AddError(&errpkg.StructError{
						Tag:     e.Tag(),
						Column:  scm.Column,
						Message: formatError(scm.Rule, scm, e.Tag()),
					})
				}
			} else {
				_ = db.AddError(err)
			}
			// Stop at the first failing column.
			break
		}
	}
}

// RegisterValidationCallback installs the validation callback on db's create
// and update chains, just before "gorm:before_create" / "gorm:before_update".
// A chain that already has a "validations:validate" callback is left alone.
// The first registration error is returned.
func RegisterValidationCallback(db *gorm.DB, crud *rest.CRUD) (err error) {
	callback := db.Callback()
	vv := &validation{crud: crud}
	if callback.Create().Get("validations:validate") == nil {
		if err = callback.Create().Before("gorm:before_create").Register("validations:validate", vv.validation); err != nil {
			return err
		}
	}
	if callback.Update().Get("validations:validate") == nil {
		if err = callback.Update().Before("gorm:before_update").Register("validations:validate", vv.validation); err != nil {
			return err
		}
	}
	return nil
}