123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- package plugins
- import (
- "context"
- "encoding/json"
- "fmt"
- "reflect"
- "regexp"
- "strconv"
- "strings"
- "git.nspix.com/golang/micro/helper/utils"
- "git.nspix.com/golang/rest/v2"
- errors2 "git.nspix.com/golang/rest/v2/errors"
- "github.com/go-playground/validator/v10"
- "gorm.io/gorm"
- "gorm.io/gorm/clause"
- "gorm.io/gorm/schema"
- )
- const (
- SkipValidations = "validations:skip_validations"
- )
- type (
- validateScope struct{}
- validScope struct {
- Db *gorm.DB
- Column string
- Model interface{}
- }
- )
- var (
- validate = validator.New()
- validateScopeKey = validateScope{}
- telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
- )
- func init() {
- _ = validate.RegisterValidationCtx("telephone", func(ctx context.Context, fl validator.FieldLevel) bool {
- val := fmt.Sprint(fl.Field().Interface())
- return telephoneRegex.MatchString(val)
- })
- _ = validate.RegisterValidationCtx("db_unique", func(ctx context.Context, fl validator.FieldLevel) bool {
- val := fl.Field().Interface()
- var (
- sp *validScope
- ok bool
- count int64
- err error
- sess *gorm.DB
- field *schema.Field
- refValue reflect.Value
- )
- if sp, ok = ctx.Value(validateScopeKey).(*validScope); !ok {
- return true
- }
- sess = sp.Db.Scopes(func(db *gorm.DB) *gorm.DB {
- s := db.Session(&gorm.Session{})
- s.Statement = &gorm.Statement{
- DB: db,
- ConnPool: db.Statement.ConnPool,
- Context: db.Statement.Context,
- Clauses: map[string]clause.Clause{},
- }
- return s
- })
- if err = sess.Statement.Parse(sp.Model); err == nil {
- if len(sess.Statement.Schema.PrimaryFields) > 0 {
- field = sess.Statement.Schema.PrimaryFields[0]
- refValue = reflect.Indirect(reflect.ValueOf(sp.Model))
- for _, n := range field.BindNames {
- refValue = refValue.FieldByName(n)
- }
- }
- }
- if refValue.IsValid() && !refValue.IsZero() && field != nil {
- sess.Model(sp.Model).Where(sp.Column+"=? AND "+field.Name+" != ?", val, refValue.Interface()).Count(&count)
- } else {
- sess.Model(sp.Model).Where(sp.Column+"=?", val).Count(&count)
- }
- if count > 0 {
- return false
- }
- return true
- })
- }
- func generateTag(scm *rest.Schema, scenario string, rule rest.Rule) string {
- var s string
- if rule.Min != 0 {
- s += ",min=" + strconv.Itoa(rule.Min)
- }
- if rule.Max != 0 {
- s += ",max=" + strconv.Itoa(rule.Max)
- }
- //主键不做判断
- if rule.Unique && scm.PrimaryKey == 0 {
- s += ",db_unique"
- }
- if rule.Type != "" {
- s += "," + rule.Type
- }
- if rule.Required != nil && len(rule.Required) > 0 {
- for _, v := range rule.Required {
- if v == scenario {
- s += ",required"
- }
- }
- }
- if s != "" {
- return s[1:]
- } else {
- return s
- }
- }
- func formatError(rule rest.Rule, scm *rest.Schema, tag string) string {
- var s string
- switch tag {
- case "db_unique":
- s = scm.Label + "值已经存在."
- break
- case "required":
- s = scm.Label + "值不能为空."
- case "max":
- if scm.Type == "string" {
- s = scm.Label + "长度不能大于" + strconv.Itoa(rule.Max)
- } else {
- s = scm.Label + "值不能大于" + strconv.Itoa(rule.Max)
- }
- case "min":
- if scm.Type == "string" {
- s = scm.Label + "长度不能小于" + strconv.Itoa(rule.Max)
- } else {
- s = scm.Label + "值不能小于" + strconv.Itoa(rule.Max)
- }
- }
- return s
- }
- func validation(db *gorm.DB) {
- if result, ok := db.Get(SkipValidations); ok && result.(bool) {
- return
- }
- var (
- ok bool
- err error
- ruleString string
- stmt *gorm.Statement
- model rest.Model
- rule rest.Rule
- scenario string
- skipValidate bool
- value reflect.Value
- schemas []*rest.Schema
- )
- stmt = db.Statement
- if stmt.Model == nil {
- return
- }
- if model, ok = stmt.Model.(rest.Model); !ok {
- return
- }
- scenario = rest.ScenarioUpdate
- for _, pk := range stmt.Schema.PrimaryFields {
- if utils.IsEmpty(stmt.ReflectValue.FieldByName(pk.Name).Interface()) {
- scenario = rest.ScenarioCreate
- break
- }
- }
- schemas = rest.VisibleSchemas(stmt.ReflectValue.FieldByName("Namespace").String(), model.ModuleName(), model.TableName(), scenario)
- for _, scm := range schemas {
- if scm.Rules == "" {
- continue
- }
- if err = json.Unmarshal([]byte(scm.Rules), &rule); err != nil {
- continue
- }
- if ruleString = generateTag(scm, scenario, rule); ruleString == "" {
- continue
- }
- value = stmt.ReflectValue.FieldByName(stmt.Schema.LookUpField(scm.Column).Name)
- if !value.IsValid() {
- continue
- }
- skipValidate = false
- if strings.Contains(ruleString, "required") {
- //如果数值为整形,小数,Bool跳过验证
- if value.Interface() != nil {
- vType := reflect.ValueOf(value.Interface())
- switch vType.Kind() {
- case reflect.Bool:
- skipValidate = true
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- skipValidate = true
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- skipValidate = true
- case reflect.Float32, reflect.Float64:
- skipValidate = true
- }
- }
- if skipValidate {
- continue
- }
- } else {
- if utils.IsEmpty(value.Interface()) {
- continue
- }
- }
- ctx := context.WithValue(context.Background(), validateScopeKey, &validScope{
- Db: db,
- Column: scm.Column,
- Model: stmt.Model,
- })
- if err = validate.VarCtx(ctx, value.Interface(), ruleString); err != nil {
- if errors, ok := err.(validator.ValidationErrors); ok {
- for _, e := range errors {
- _ = db.AddError(&errors2.StructError{
- Tag: e.Tag(),
- Column: scm.Column,
- Message: formatError(rule, scm, e.Tag()),
- })
- }
- } else {
- _ = db.AddError(err)
- }
- break
- }
- }
- }
- func RegisterValidationCallback(db *gorm.DB) {
- callback := db.Callback()
- if callback.Create().Get("validations:validate") == nil {
- _ = callback.Create().Before("gorm:before_create").Register("validations:validate", validation)
- }
- if callback.Update().Get("validations:validate") == nil {
- _ = callback.Update().Before("gorm:before_update").Register("validations:validate", validation)
- }
- }
|