package plugins

import (
	"context"
	"encoding/json"
	"fmt"
	"reflect"
	"regexp"
	"strconv"
	"strings"

	"git.nspix.com/golang/micro/helper/utils"
	"git.nspix.com/golang/rest/v2"
	errors2 "git.nspix.com/golang/rest/v2/errors"
	"github.com/go-playground/validator/v10"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
)

const (
	// SkipValidations is the gorm instance-setting key checked by the
	// validation callback; when it is set to true the callback does nothing.
	SkipValidations = "validations:skip_validations"
)

type (
	// validateScope is the unexported context-key type under which a
	// *validScope is stored for the "db_unique" validator.
	validateScope struct{}

	// validScope carries what the "db_unique" validator needs to run its
	// uniqueness query.
	validScope struct {
		Db     *gorm.DB    // db handle passed to the validation callback
		Column string      // column whose value must be unique
		Model  interface{} // model of the statement being validated
	}

	// validation is the gorm callback that checks models against the rules
	// stored in the CRUD schemas.
	validation struct {
		curd *rest.CRUD
	}
)

var (
	validate         = validator.New()
	validateScopeKey = validateScope{}
	// telephoneRegex matches a string of 5 to 20 ASCII digits.
	telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
)

func init() {
	// Registration errors are deliberately ignored: tag names and functions
	// are static and non-empty.

	// "telephone": the printed value of the field must be 5-20 digits.
	_ = validate.RegisterValidationCtx("telephone", func(ctx context.Context, fl validator.FieldLevel) bool {
		val := fmt.Sprint(fl.Field().Interface())
		return telephoneRegex.MatchString(val)
	})
	// "db_unique": no other row may hold the same value in the column.
	_ = validate.RegisterValidationCtx("db_unique", validateDBUnique)
}

// validateDBUnique implements the "db_unique" tag. It counts rows of the
// model's table whose validScope.Column equals the field value. When the
// model's first primary key is already set, that record's own row is
// excluded. The value is valid only if the count is zero. When ctx carries
// no *validScope, the check is skipped and the value is reported valid.
func validateDBUnique(ctx context.Context, fl validator.FieldLevel) bool {
	sp, ok := ctx.Value(validateScopeKey).(*validScope)
	if !ok {
		return true
	}
	val := fl.Field().Interface()

	// Build a session with a brand-new Statement (empty clauses) so the count
	// query does not inherit clauses of the in-flight create/update.
	// NOTE(review): newer gorm releases defer Scopes until execution; confirm
	// this still isolates the statement on the gorm version in use.
	sess := sp.Db.Scopes(func(db *gorm.DB) *gorm.DB {
		s := db.Session(&gorm.Session{})
		s.Statement = &gorm.Statement{
			DB:       db,
			ConnPool: db.Statement.ConnPool,
			Context:  db.Statement.Context,
			Clauses:  map[string]clause.Clause{},
		}
		return s
	})

	var (
		pkField *schema.Field
		pkValue reflect.Value
	)
	if err := sess.Statement.Parse(sp.Model); err == nil && len(sess.Statement.Schema.PrimaryFields) > 0 {
		pkField = sess.Statement.Schema.PrimaryFields[0]
		pkValue = reflect.Indirect(reflect.ValueOf(sp.Model))
		for _, name := range pkField.BindNames {
			pkValue = pkValue.FieldByName(name)
		}
	}

	var count int64
	if pkField != nil && pkValue.IsValid() && !pkValue.IsZero() {
		// Existing record: ignore its own row. The raw SQL must use the
		// column name (DBName), not the Go struct field name.
		sess.Model(sp.Model).Where(sp.Column+"=? AND "+pkField.DBName+" != ?", val, pkValue.Interface()).Count(&count)
	} else {
		sess.Model(sp.Model).Where(sp.Column+"=?", val).Count(&count)
	}
	// A failed count query leaves count at 0, so this check fails open.
	return count == 0
}

// generateTag builds the validator tag string for one schema column, e.g.
// "min=1,max=20,required". Parts are emitted in this order: min, max,
// db_unique, the rule's type, required.
//
// db_unique is skipped for primary-key columns. "required" is added only
// when scenario is listed in rule.Required. An empty result means the column
// has nothing to validate.
func generateTag(scm *rest.Schema, scenario string, rule rest.Rule) string {
	var tags []string
	if rule.Min != 0 {
		tags = append(tags, "min="+strconv.Itoa(rule.Min))
	}
	if rule.Max != 0 {
		tags = append(tags, "max="+strconv.Itoa(rule.Max))
	}
	// Primary keys are not checked for uniqueness.
	if rule.Unique && scm.PrimaryKey == 0 {
		tags = append(tags, "db_unique")
	}
	if rule.Type != "" {
		tags = append(tags, rule.Type)
	}
	for _, v := range rule.Required {
		if v == scenario {
			// A single "required" suffices even if the scenario is listed twice.
			tags = append(tags, "required")
			break
		}
	}
	return strings.Join(tags, ",")
}

// formatError returns the user-facing message for a failed validation tag:
//   - db_unique: the value already exists
//   - required:  the value must not be empty
//   - max / min: length (string columns) or value exceeds the bound
//
// Any other tag yields "".
func formatError(rule rest.Rule, scm *rest.Schema, tag string) string {
	switch tag {
	case "db_unique":
		return scm.Label + "值已经存在."
	case "required":
		return scm.Label + "值不能为空."
	case "max":
		if scm.Type == "string" {
			return scm.Label + "长度不能大于" + strconv.Itoa(rule.Max)
		}
		return scm.Label + "值不能大于" + strconv.Itoa(rule.Max)
	case "min":
		// The lower bound is rule.Min.
		if scm.Type == "string" {
			return scm.Label + "长度不能小于" + strconv.Itoa(rule.Min)
		}
		return scm.Label + "值不能小于" + strconv.Itoa(rule.Min)
	}
	return ""
}

// isScalarKind reports whether k is a bool, integer (including uintptr) or
// floating-point kind.
func isScalarKind(k reflect.Kind) bool {
	switch k {
	case reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Float32, reflect.Float64:
		return true
	}
	return false
}

// validation is the gorm create/update callback. It loads the rules for the
// statement's model from the CRUD schemas and checks each configured column
// in order.
//
// The scenario is "create" when any primary key is empty, otherwise "update".
// At the first column that fails, it adds one *errors2.StructError per failed
// tag to db (or the raw validator error) and stops checking further columns.
//
// The callback does nothing when SkipValidations is true, when the statement
// has no model or schema, when the model is not a rest.Model, or when the
// statement's reflected value is not a single struct.
func (vv *validation) validation(db *gorm.DB) {
	if v, ok := db.Get(SkipValidations); ok {
		if skip, _ := v.(bool); skip {
			return
		}
	}

	stmt := db.Statement
	if stmt.Model == nil || stmt.Schema == nil {
		return
	}
	model, ok := stmt.Model.(rest.Model)
	if !ok {
		return
	}
	// Batch creates or map-based updates reflect to a slice/map; the
	// FieldByName lookups below require a struct and would panic otherwise.
	if stmt.ReflectValue.Kind() != reflect.Struct {
		return
	}

	scenario := rest.ScenarioUpdate
	for _, pk := range stmt.Schema.PrimaryFields {
		if utils.IsEmpty(stmt.ReflectValue.FieldByName(pk.Name).Interface()) {
			scenario = rest.ScenarioCreate
			break
		}
	}

	// NOTE(review): if the model has no Namespace field, String() on the
	// invalid reflect.Value yields "<invalid Value>"; confirm that is intended.
	namespace := stmt.ReflectValue.FieldByName("Namespace").String()
	schemas := vv.curd.Schemas(namespace, model.ModuleName(), model.TableName(), scenario)

	parent := stmt.Context
	if parent == nil {
		parent = context.Background()
	}

	for _, scm := range schemas {
		if scm.Rules == "" {
			continue
		}
		// Decode into a fresh Rule for every column. json.Unmarshal leaves
		// fields absent from the JSON untouched, so a shared Rule would leak
		// the previous column's settings into this one.
		var rule rest.Rule
		if err := json.Unmarshal([]byte(scm.Rules), &rule); err != nil {
			continue
		}
		ruleString := generateTag(scm, scenario, rule)
		if ruleString == "" {
			continue
		}
		field := stmt.Schema.LookUpField(scm.Column)
		if field == nil {
			// The column is configured in the CRUD schema but not present on the model.
			continue
		}
		value := stmt.ReflectValue.FieldByName(field.Name)
		if !value.IsValid() {
			continue
		}

		if strings.Contains(ruleString, "required") {
			// Bool and numeric values skip validation entirely when "required"
			// is configured. Presumably this is because validator's "required"
			// would reject legitimate zero values (false / 0).
			if iv := value.Interface(); iv != nil && isScalarKind(reflect.ValueOf(iv).Kind()) {
				continue
			}
		} else if utils.IsEmpty(value.Interface()) {
			// An optional column left empty has nothing to validate.
			continue
		}

		ctx := context.WithValue(parent, validateScopeKey, &validScope{
			Db:     db,
			Column: scm.Column,
			Model:  stmt.Model,
		})
		if err := validate.VarCtx(ctx, value.Interface(), ruleString); err != nil {
			if verrs, ok := err.(validator.ValidationErrors); ok {
				for _, e := range verrs {
					_ = db.AddError(&errors2.StructError{
						Tag:     e.Tag(),
						Column:  scm.Column,
						Message: formatError(rule, scm, e.Tag()),
					})
				}
			} else {
				_ = db.AddError(err)
			}
			break
		}
	}
}

// RegisterValidationCallback installs the rule-based validation callback on
// curd's gorm instance. It runs before gorm:before_create for creates and
// before gorm:before_update for updates. The callback is registered under
// "validations:validate" only when that name is not already present, so
// repeated calls are harmless.
func RegisterValidationCallback(curd *rest.CRUD) {
	callback := curd.DB().Callback()
	vv := &validation{curd: curd}
	if callback.Create().Get("validations:validate") == nil {
		_ = callback.Create().Before("gorm:before_create").Register("validations:validate", vv.validation)
	}
	if callback.Update().Get("validations:validate") == nil {
		_ = callback.Update().Before("gorm:before_update").Register("validations:validate", vv.validation)
	}
}