|
- package validator
- import (
- "context"
- "encoding/json"
- "fmt"
- "git.nspix.com/golang/rest/internal/empty"
- "git.nspix.com/golang/rest/orm/schema"
- "git.nspix.com/golang/rest/scenario"
- "github.com/go-playground/locales/en"
- "github.com/go-playground/locales/zh"
- "github.com/go-playground/universal-translator"
- "github.com/go-playground/validator/v10"
- translation "github.com/go-playground/validator/v10/translations/zh"
- "gorm.io/gorm"
- "gorm.io/gorm/clause"
- shm "gorm.io/gorm/schema"
- "reflect"
- "regexp"
- "strconv"
- "strings"
- )
- type validateScope struct{}
- type scope struct {
- Db *gorm.DB
- Column string
- Model interface{}
- }
- var (
- validate = validator.New()
- SkipValidations = "validations:skip_validations"
- translator ut.Translator
- validateScopeKey = validateScope{}
- telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
- )
// init configures the shared validator used by the validation callback:
//   - selects the Chinese translator and installs the stock Chinese messages;
//   - adds messages for the custom "db_unique" and "telephone" tags;
//   - reports struct fields under their JSON name (json:"-" yields "");
//   - registers the "telephone" and "db_unique" validation functions.
//
// A failed registration is a programming error, so init panics rather than
// leaving the validator silently half-configured.
func init() {
	universal := ut.New(en.New(), zh.New())
	// Every message installed below is Chinese (the stock ones come from the
	// zh translations package), so the zh translator must be used. With the
	// en translator, plural-aware stock messages (e.g. "min" with a value of
	// 1) look up an English plural form that the zh package never registers;
	// translation then fails and the raw validator error text is returned.
	var found bool
	if translator, found = universal.GetTranslator("zh"); !found {
		panic("validator: zh translator is not registered")
	}
	mustInit("default translations", translation.RegisterDefaultTranslations(validate, translator))
	mustInit("db_unique translation", registerTagMessage("db_unique", "{0}值已经存在."))
	mustInit("telephone translation", registerTagMessage("telephone", "{0}号码不合法."))
	validate.RegisterTagNameFunc(jsonFieldName)
	mustInit("telephone validation", validate.RegisterValidationCtx("telephone", isTelephone))
	mustInit("db_unique validation", validate.RegisterValidationCtx("db_unique", isDBUnique))
}

// mustInit panics with context when a registration made during package
// initialisation fails.
func mustInit(what string, err error) {
	if err != nil {
		panic(fmt.Errorf("validator: register %s: %w", what, err))
	}
}

// registerTagMessage installs text as the message for the validation tag in
// the package translator, overriding any existing message for that tag.
// In text, {0} is replaced by the field name.
func registerTagMessage(tag, text string) error {
	return validate.RegisterTranslation(tag, translator,
		func(trans ut.Translator) error {
			return trans.Add(tag, text, true)
		},
		translateTagMessage,
	)
}

// translateTagMessage renders the message registered for fe's tag, passing the
// field name as its only parameter. When no message can be produced, it falls
// back to the raw error text.
func translateTagMessage(trans ut.Translator, fe validator.FieldError) string {
	msg, err := trans.T(fe.Tag(), fe.Field())
	if err != nil {
		return fe.(error).Error()
	}
	return msg
}

// jsonFieldName reports a struct field under the name from its json tag;
// fields tagged json:"-" get an empty name.
func jsonFieldName(field reflect.StructField) string {
	name := strings.SplitN(field.Tag.Get("json"), ",", 2)[0]
	if name == "-" {
		return ""
	}
	return name
}

// isTelephone implements the "telephone" tag. The value, formatted with
// fmt.Sprint, must consist of 5 to 20 decimal digits.
func isTelephone(_ context.Context, fl validator.FieldLevel) bool {
	return telephoneRegex.MatchString(fmt.Sprint(fl.Field().Interface()))
}

// isDBUnique implements the "db_unique" tag. It reports whether no row of the
// model's table already holds the value in scope.Column.
//
// The *scope must be stored in ctx under validateScopeKey; without one the
// check passes. When the model already has a non-zero primary key (an update),
// that row is excluded so a record never conflicts with itself. A failed count
// query leaves the count at zero and therefore accepts the value.
func isDBUnique(ctx context.Context, fl validator.FieldLevel) bool {
	sp, ok := ctx.Value(validateScopeKey).(*scope)
	if !ok || sp == nil || sp.Db == nil {
		return true
	}
	val := fl.Field().Interface()
	sess := uniqueCheckSession(sp.Db)
	pk, pkValue, hasPK := persistedPrimaryKey(sess, sp.Model)
	query := sess.Model(sp.Model)
	if hasPK {
		// Use the column name (DBName), not the Go field name: the two differ
		// whenever the naming strategy rewrites it (e.g. UserID -> user_id).
		query = query.Where(sp.Column+"=? AND "+pk.DBName+" != ?", val, pkValue)
	} else {
		query = query.Where(sp.Column+"=?", val)
	}
	var count int64
	query.Count(&count)
	return count == 0
}

// uniqueCheckSession returns a session whose statement is empty, so the lookup
// does not inherit the clauses of the create/update in progress. It keeps db's
// ConnPool and Context (presumably the surrounding transaction, if any).
//
// The statement is replaced directly instead of through db.Scopes. Newer gorm
// releases defer scopes until the query executes, and the lookup would then be
// built on the caller's in-flight statement.
func uniqueCheckSession(db *gorm.DB) *gorm.DB {
	sess := db.Session(&gorm.Session{})
	sess.Statement = &gorm.Statement{
		DB:       db,
		ConnPool: db.Statement.ConnPool,
		Context:  db.Statement.Context,
		Clauses:  map[string]clause.Clause{},
	}
	return sess
}

// persistedPrimaryKey parses model into sess's statement and returns its first
// primary-key field together with model's value for that field.
//
// The final result is false when any of these holds:
//   - the schema cannot be parsed;
//   - it has no primary key;
//   - model is not a (pointer to a) struct;
//   - the key value is missing or zero, e.g. a record not yet inserted.
func persistedPrimaryKey(sess *gorm.DB, model interface{}) (*shm.Field, interface{}, bool) {
	if err := sess.Statement.Parse(model); err != nil || sess.Statement.Schema == nil {
		return nil, nil, false
	}
	primary := sess.Statement.Schema.PrimaryFields
	if len(primary) == 0 {
		return nil, nil, false
	}
	pk := primary[0]
	rv := reflect.Indirect(reflect.ValueOf(model))
	for _, name := range pk.BindNames {
		// Follow embedded pointers; stop on anything that is not a struct
		// (e.g. a slice model in batch operations) instead of panicking.
		rv = reflect.Indirect(rv)
		if rv.Kind() != reflect.Struct {
			return nil, nil, false
		}
		rv = rv.FieldByName(name)
	}
	if !rv.IsValid() || rv.IsZero() || !rv.CanInterface() {
		return nil, nil, false
	}
	return pk, rv.Interface(), true
}
// StructError describes one failed validation rule for one column. The
// validation callback adds it to the gorm statement's errors; its JSON form
// has the keys "tag", "column" and "message".
type StructError struct {
	Tag     string `json:"tag"`     // validator tag that failed
	Column  string `json:"column"`  // column the failing rule belongs to
	Message string `json:"message"` // translated, human-readable message
}

// Error implements the error interface: the column name immediately
// followed by the message, with no separator.
func (e *StructError) Error() string {
	text := e.Column
	text += e.Message
	return text
}
- func generateTag(scene string, rule schema.Rule) string {
- var s string
- if rule.Min != 0 {
- s += ",min=" + strconv.Itoa(rule.Min)
- }
- if rule.Max != 0 {
- s += ",max=" + strconv.Itoa(rule.Max)
- }
- if rule.Unique {
- s += ",db_unique"
- }
- if rule.Type != "" {
- s += "," + rule.Type
- }
- if rule.Required != nil && len(rule.Required) > 0 {
- for _, v := range rule.Required {
- if v == scene {
- s += ",required"
- }
- }
- }
- if s != "" {
- return s[1:]
- } else {
- return s
- }
- }
// validation is the gorm create/update callback. It checks the values being
// saved against the per-field rules that schema.VisibleField returns for the
// statement's table and current scene. The scene is scenario.Create for new
// records and scenario.Update otherwise.
//
// It does nothing in any of these cases:
//   - the statement carries "gorm:update_column";
//   - the SkipValidations setting holds true;
//   - the statement already has an error;
//   - there is no schema or model.
//
// Struct values are validated directly and slices/arrays (batch operations)
// element by element; other kinds (e.g. maps) are not validated. Validation
// stops at the first failing field. Its errors are added to db as *StructError
// values, or as the raw error when the validator does not return
// ValidationErrors.
func validation(db *gorm.DB) {
	if _, ok := db.Get("gorm:update_column"); ok {
		return
	}
	if skip, ok := db.Get(SkipValidations); ok {
		if skipped, _ := skip.(bool); skipped {
			return
		}
	}
	stmt := db.Statement
	if db.Error != nil || stmt.Schema == nil || stmt.Model == nil {
		return
	}
	scene := scenario.Create
	if !schema.IsNewRecord(stmt.ReflectValue, stmt) {
		scene = scenario.Update
	}
	fields := schema.VisibleField("organize", stmt.Table, scene, db)
	parent := stmt.Context
	if parent == nil {
		parent = context.Background()
	}

	// validateRecord checks one struct value. model is what the db_unique rule
	// queries with. It returns false once an error has been recorded on db.
	validateRecord := func(record reflect.Value, model interface{}) bool {
		for _, field := range fields {
			if field.Rules == "" {
				continue
			}
			// Decode into a fresh Rule every time: json.Unmarshal leaves
			// fields absent from the JSON untouched, so a reused value would
			// carry one field's rules over to the next.
			var rule schema.Rule
			if err := json.Unmarshal([]byte(field.Rules), &rule); err != nil {
				continue // malformed rules are ignored
			}
			tag := generateTag(scene, rule)
			if tag == "" {
				continue
			}
			// The rule's column may not correspond to a field of this model.
			modelField := stmt.Schema.LookUpField(field.Column)
			if modelField == nil {
				continue
			}
			value := record.FieldByName(modelField.Name)
			if !value.IsValid() || !value.CanInterface() {
				continue
			}
			// Optional fields holding an empty value are not validated.
			if !strings.Contains(tag, "required") && empty.Is(value.Interface()) {
				continue
			}
			ctx := context.WithValue(parent, validateScopeKey, &scope{
				Db:     db,
				Column: field.Column,
				Model:  model,
			})
			if err := validate.VarCtx(ctx, value.Interface(), tag); err != nil {
				if fieldErrs, ok := err.(validator.ValidationErrors); ok {
					for _, fe := range fieldErrs {
						_ = db.AddError(&StructError{
							Tag:     fe.Tag(),
							Column:  field.Column,
							Message: fe.Translate(translator),
						})
					}
				} else {
					_ = db.AddError(err)
				}
				return false
			}
		}
		return true
	}

	switch rv := stmt.ReflectValue; rv.Kind() {
	case reflect.Struct:
		validateRecord(rv, stmt.Model)
	case reflect.Slice, reflect.Array:
		for i := 0; i < rv.Len(); i++ {
			elem := reflect.Indirect(rv.Index(i))
			if elem.Kind() != reflect.Struct {
				continue
			}
			// Hand db_unique the element itself so its own primary key is used.
			var model interface{} = stmt.Model
			if elem.CanAddr() {
				model = elem.Addr().Interface()
			}
			if !validateRecord(elem, model) {
				return
			}
		}
	}
}
- // RegisterCallbacks register callback into GORM DB
- // BeforeSave and BeforeCreate is called on before_create
- // so this is called just after them
- func RegisterCallbacks(db *gorm.DB) {
- callback := db.Callback()
- if callback.Create().Get("validations:validate") == nil {
- _ = callback.Create().Before("gorm:before_create").Register("validations:validate", validation)
- }
- if callback.Update().Get("validations:validate") == nil {
- _ = callback.Update().Before("gorm:before_update").Register("validations:validate", validation)
- }
- }
|