// Package validator hooks go-playground/validator checks into GORM's create
// and update callbacks. The rules applied to each column are the JSON rule
// definitions returned by schema.VisibleField for the statement's table.
package validator

import (
	"context"
	"encoding/json"
	"fmt"
	"reflect"
	"regexp"
	"strconv"
	"strings"

	"git.nspix.com/golang/rest/internal/empty"
	"git.nspix.com/golang/rest/orm/schema"
	"git.nspix.com/golang/rest/scenario"
	"github.com/go-playground/locales/en"
	"github.com/go-playground/locales/zh"
	ut "github.com/go-playground/universal-translator"
	"github.com/go-playground/validator/v10"
	translation "github.com/go-playground/validator/v10/translations/zh"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	shm "gorm.io/gorm/schema"
)

// validateScope is the private context-key type used to pass a *scope to
// context-aware validators.
type validateScope struct{}

// scope carries what the db_unique validator needs to query the database:
// the active GORM handle, the column being checked and the model being saved.
type scope struct {
	Db     *gorm.DB
	Column string
	Model  interface{}
}

var (
	validate = validator.New()

	// SkipValidations is the GORM setting key that, when set to true on a
	// *gorm.DB, makes the validation callback return without checking anything.
	SkipValidations = "validations:skip_validations"

	translator       ut.Translator
	validateScopeKey = validateScope{}
	telephoneRegex   = regexp.MustCompile("^\\d{5,20}$")
)

func init() {
	enLang := en.New()
	zhLang := zh.New()
	universal := ut.New(enLang, zhLang)
	translator, _ = universal.GetTranslator("en")

	// Registration errors are discarded throughout this function.
	_ = translation.RegisterDefaultTranslations(validate, translator)
	registerFieldTranslation("db_unique", "{0}值已经存在.")
	registerFieldTranslation("telephone", "{0}号码不合法.")

	// Report fields under their JSON name; a "-" tag yields an empty name.
	validate.RegisterTagNameFunc(func(field reflect.StructField) string {
		name := strings.SplitN(field.Tag.Get("json"), ",", 2)[0]
		if name == "-" {
			return ""
		}
		return name
	})

	_ = validate.RegisterValidationCtx("telephone", func(ctx context.Context, fl validator.FieldLevel) bool {
		val := fmt.Sprint(fl.Field().Interface())
		return telephoneRegex.MatchString(val)
	})
	_ = validate.RegisterValidationCtx("db_unique", validateDBUnique)
}

// registerFieldTranslation registers text as the message for tag on the
// package translator. The message is rendered with the failing field's name as
// {0}; if rendering fails, the field error's own Error() text is used instead.
func registerFieldTranslation(tag, text string) {
	_ = validate.RegisterTranslation(tag, translator, func(t ut.Translator) error {
		return t.Add(tag, text, true)
	}, func(t ut.Translator, fe validator.FieldError) string {
		msg, err := t.T(fe.Tag(), fe.Field())
		if err != nil {
			return fe.(error).Error()
		}
		return msg
	})
}

// validateDBUnique implements the "db_unique" tag. It counts rows of the
// scope's model whose column equals the validated value and reports true when
// there are none. If the model's first primary field holds a non-zero value,
// rows with that primary key are excluded from the count. When the context
// carries no *scope the check passes.
func validateDBUnique(ctx context.Context, fl validator.FieldLevel) bool {
	val := fl.Field().Interface()

	sp, ok := ctx.Value(validateScopeKey).(*scope)
	if !ok {
		return true
	}

	sess := sp.Db.Scopes(func(db *gorm.DB) *gorm.DB {
		s := db.Session(&gorm.Session{})
		s.Statement = &gorm.Statement{
			DB:       db,
			ConnPool: db.Statement.ConnPool,
			Context:  db.Statement.Context,
			Clauses:  map[string]clause.Clause{},
		}
		return s
	})

	var (
		count    int64
		field    *shm.Field
		refValue reflect.Value
	)
	if err := sess.Statement.Parse(sp.Model); err == nil {
		if len(sess.Statement.Schema.PrimaryFields) > 0 {
			field = sess.Statement.Schema.PrimaryFields[0]
			refValue = reflect.Indirect(reflect.ValueOf(sp.Model))
			for _, n := range field.BindNames {
				refValue = refValue.FieldByName(n)
			}
		}
	}

	if refValue.IsValid() && !refValue.IsZero() && field != nil {
		// The WHERE clause needs the database column name, not the Go struct
		// field name; fall back to the field name only if no column is set.
		pk := field.DBName
		if pk == "" {
			pk = field.Name
		}
		sess.Model(sp.Model).Where(sp.Column+"=? AND "+pk+" != ?", val, refValue.Interface()).Count(&count)
	} else {
		sess.Model(sp.Model).Where(sp.Column+"=?", val).Count(&count)
	}
	return count == 0
}

// StructError describes one failed validation rule on a column.
type StructError struct {
	Tag     string `json:"tag"`
	Column  string `json:"column"`
	Message string `json:"message"`
}

// Error returns the column name immediately followed by the message.
func (err *StructError) Error() string {
	return err.Column + err.Message
}

// generateTag converts rule into a comma-separated validator tag for scene.
// Parts are emitted in the order min, max, db_unique, rule.Type, required;
// "required" is added once when scene is listed in rule.Required. An empty
// string is returned when the rule contributes nothing.
func generateTag(scene string, rule schema.Rule) string {
	parts := make([]string, 0, 5)
	if rule.Min != 0 {
		parts = append(parts, "min="+strconv.Itoa(rule.Min))
	}
	if rule.Max != 0 {
		parts = append(parts, "max="+strconv.Itoa(rule.Max))
	}
	if rule.Unique {
		parts = append(parts, "db_unique")
	}
	if rule.Type != "" {
		parts = append(parts, rule.Type)
	}
	for _, v := range rule.Required {
		if v == scene {
			parts = append(parts, "required")
			break
		}
	}
	return strings.Join(parts, ",")
}

// validation is the GORM callback that checks the statement's model against
// the rules of its visible fields. It does nothing when the
// "gorm:update_column" setting is present, when SkipValidations is set to
// true, when db already carries an error, or when no schema/model is
// available. Failures are attached to db with AddError, and checking stops
// at the first field that fails.
func validation(db *gorm.DB) {
	if _, ok := db.Get("gorm:update_column"); ok {
		return
	}
	if result, ok := db.Get(SkipValidations); ok {
		// A non-bool setting is treated as "not skipped" instead of panicking.
		if skip, isBool := result.(bool); isBool && skip {
			return
		}
	}
	if db.Error != nil || db.Statement.Schema == nil {
		return
	}
	stmt := db.Statement
	if stmt.Model == nil {
		return
	}

	scene := scenario.Create
	if !schema.IsNewRecord(stmt.ReflectValue, stmt) {
		scene = scenario.Update
	}
	schemas := schema.VisibleField("organize", stmt.Table, scene)

	// FieldByName panics on anything but a struct (e.g. slice or map values),
	// so per-field validation is only attempted for struct values.
	if stmt.ReflectValue.Kind() != reflect.Struct {
		return
	}

	for _, field := range schemas {
		if field.Rules == "" {
			continue
		}
		// Decode into a fresh Rule each time: json.Unmarshal leaves fields
		// missing from the JSON untouched, so a reused value would carry
		// settings over from the previous column.
		var rule schema.Rule
		if err := json.Unmarshal([]byte(field.Rules), &rule); err != nil {
			continue
		}
		tag := generateTag(scene, rule)
		if tag == "" {
			continue
		}
		schemaField := stmt.Schema.LookUpField(field.Column)
		if schemaField == nil {
			// The column is not part of the model's GORM schema.
			continue
		}
		fieldValue := stmt.ReflectValue.FieldByName(schemaField.Name)
		if !fieldValue.IsValid() {
			continue
		}
		// Skip validation when the field is not required and its value is empty.
		if !strings.Contains(tag, "required") && empty.Is(fieldValue.Interface()) {
			continue
		}
		ctx := context.WithValue(context.Background(), validateScopeKey, &scope{
			Db:     db,
			Column: field.Column,
			Model:  stmt.Model,
		})
		err := validate.VarCtx(ctx, fieldValue.Interface(), tag)
		if err == nil {
			continue
		}
		if verrs, ok := err.(validator.ValidationErrors); ok {
			for _, e := range verrs {
				_ = db.AddError(&StructError{
					Tag:     e.Tag(),
					Column:  field.Column,
					Message: e.Translate(translator),
				})
			}
		} else {
			_ = db.AddError(err)
		}
		break
	}
}

// RegisterCallbacks registers the validation callback on db under the name
// "validations:validate", placed before "gorm:before_create" for creates and
// before "gorm:before_update" for updates. A callback already registered
// under that name is left in place.
func RegisterCallbacks(db *gorm.DB) {
	callback := db.Callback()
	if callback.Create().Get("validations:validate") == nil {
		_ = callback.Create().Before("gorm:before_create").Register("validations:validate", validation)
	}
	if callback.Update().Get("validations:validate") == nil {
		_ = callback.Update().Before("gorm:before_update").Register("validations:validate", validation)
	}
}