123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271 |
- package plugins
- import (
- "context"
- "fmt"
- "git.nspix.com/golang/rest/v3"
- errpkg "git.nspix.com/golang/rest/v3/error"
- "git.nspix.com/golang/rest/v3/utils"
- validator "github.com/go-playground/validator/v10"
- "gorm.io/gorm"
- "gorm.io/gorm/schema"
- "reflect"
- "regexp"
- "strconv"
- "strings"
- )
// SkipValidations is the statement setting key checked (via db.Get) by the
// validation callback; when it holds the bool value true, validation is
// skipped for that operation.
const SkipValidations = "validations:skip_validations"
- type (
- validateRule struct {
- Rule string
- Value string
- Valid bool
- }
- validateScope struct{}
- validScope struct {
- DB *gorm.DB
- Column string
- Model interface{}
- }
- validation struct {
- crud *rest.CRUD
- }
- )
- var (
- validate = validator.New()
- validateScopeKey = validateScope{}
- telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
- )
- func init() {
- _ = validate.RegisterValidationCtx("telephone", func(ctx context.Context, fl validator.FieldLevel) bool {
- val := fmt.Sprint(fl.Field().Interface())
- return telephoneRegex.MatchString(val)
- })
- _ = validate.RegisterValidationCtx("db_unique", func(ctx context.Context, fl validator.FieldLevel) bool {
- var (
- scope *validScope
- ok bool
- count int64
- field *schema.Field
- primaryKeyValue reflect.Value
- )
- val := fl.Field().Interface()
- if scope, ok = ctx.Value(validateScopeKey).(*validScope); !ok {
- return true
- }
- if len(scope.DB.Statement.Schema.PrimaryFields) > 0 {
- field = scope.DB.Statement.Schema.PrimaryFields[0]
- primaryKeyValue = reflect.Indirect(reflect.ValueOf(scope.Model))
- for _, n := range field.BindNames {
- primaryKeyValue = primaryKeyValue.FieldByName(n)
- }
- }
- sess := scope.DB.Session(&gorm.Session{NewDB: true})
- if primaryKeyValue.IsValid() && !primaryKeyValue.IsZero() && field != nil {
- sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ?", val, primaryKeyValue.Interface()).Count(&count)
- } else {
- sess.Model(scope.Model).Where(scope.Column+"=?", val).Count(&count)
- }
- if count > 0 {
- return false
- }
- return true
- })
- }
- func newRule(ss ...string) *validateRule {
- v := &validateRule{
- Valid: true,
- }
- if len(ss) == 1 {
- v.Rule = ss[0]
- } else if len(ss) >= 2 {
- v.Rule = ss[0]
- v.Value = ss[1]
- }
- return v
- }
- func generateRules(scm *rest.Schema, scenario string, rule rest.Rule) []*validateRule {
- rules := make([]*validateRule, 0, 5)
- if rule.Min != 0 {
- rules = append(rules, newRule("min", strconv.Itoa(rule.Min)))
- }
- if rule.Max != 0 {
- rules = append(rules, newRule("max", strconv.Itoa(rule.Max)))
- }
- //主键不做唯一判断
- if rule.Unique && !scm.Attribute.PrimaryKey {
- rules = append(rules, newRule("db_unique"))
- }
- if rule.Type != "" {
- rules = append(rules, newRule(rule.Type))
- }
- if rule.Required != nil && len(rule.Required) > 0 {
- for _, v := range rule.Required {
- if v == scenario {
- rules = append(rules, newRule("required"))
- break
- }
- }
- }
- return rules
- }
- func buildRules(rs []*validateRule) string {
- var sb strings.Builder
- for _, r := range rs {
- if !r.Valid {
- continue
- }
- if sb.Len() > 0 {
- sb.WriteString(",")
- }
- if r.Value == "" {
- sb.WriteString(r.Rule)
- } else {
- sb.WriteString(r.Rule + "=" + r.Value)
- }
- }
- return sb.String()
- }
- func getRule(name string, rules []*validateRule) *validateRule {
- for _, r := range rules {
- if r.Rule == name {
- return r
- }
- }
- return nil
- }
- func formatError(rule rest.Rule, scm *rest.Schema, tag string) string {
- var s string
- switch tag {
- case "db_unique":
- s = scm.Label + "值已经存在."
- break
- case "required":
- s = scm.Label + "值不能为空."
- case "max":
- if scm.Type == "string" {
- s = scm.Label + "长度不能大于" + strconv.Itoa(rule.Max)
- } else {
- s = scm.Label + "值不能大于" + strconv.Itoa(rule.Max)
- }
- case "min":
- if scm.Type == "string" {
- s = scm.Label + "长度不能小于" + strconv.Itoa(rule.Max)
- } else {
- s = scm.Label + "值不能小于" + strconv.Itoa(rule.Max)
- }
- }
- return s
- }
- func (vv *validation) validation(db *gorm.DB) {
- if result, ok := db.Get(SkipValidations); ok && result.(bool) {
- return
- }
- var (
- ok bool
- err error
- rules []*validateRule
- stmt *gorm.Statement
- model rest.Model
- scenario string
- skipValidate bool
- value reflect.Value
- schemas []*rest.Schema
- )
- stmt = db.Statement
- if stmt.Model == nil {
- return
- }
- if model, ok = stmt.Model.(rest.Model); !ok {
- return
- }
- scenario = rest.ScenarioUpdate
- for _, pk := range stmt.Schema.PrimaryFields {
- if utils.IsEmpty(stmt.ReflectValue.FieldByName(pk.Name).Interface()) {
- scenario = rest.ScenarioCreate
- break
- }
- }
- schemas = vv.crud.VisibleSchemas(context.Background(), stmt.ReflectValue.FieldByName("Namespace").String(), model.ModuleName(), model.TableName(), scenario)
- for _, scm := range schemas {
- if rules = generateRules(scm, scenario, scm.Rule); len(rules) <= 0 {
- continue
- }
- value = stmt.ReflectValue.FieldByName(stmt.Schema.LookUpField(scm.Column).Name)
- if !value.IsValid() {
- continue
- }
- skipValidate = false
- if r := getRule("required", rules); r != nil {
- if value.Interface() != nil {
- vType := reflect.ValueOf(value.Interface())
- switch vType.Kind() {
- case reflect.Bool:
- skipValidate = true
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- skipValidate = true
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- skipValidate = true
- case reflect.Float32, reflect.Float64:
- skipValidate = true
- }
- }
- if skipValidate {
- r.Valid = false
- }
- } else {
- if utils.IsEmpty(value.Interface()) {
- continue
- }
- }
- ctx := context.WithValue(context.Background(), validateScopeKey, &validScope{
- DB: db,
- Column: scm.Column,
- Model: stmt.Model,
- })
- if err = validate.VarCtx(ctx, value.Interface(), buildRules(rules)); err != nil {
- if errs, ok := err.(validator.ValidationErrors); ok {
- for _, e := range errs {
- _ = db.AddError(&errpkg.StructError{
- Tag: e.Tag(),
- Column: scm.Column,
- Message: formatError(scm.Rule, scm, e.Tag()),
- })
- }
- } else {
- _ = db.AddError(err)
- }
- break
- }
- }
- }
- func RegisterValidationCallback(db *gorm.DB, crud *rest.CRUD) (err error) {
- callback := db.Callback()
- vv := &validation{crud: crud}
- if callback.Create().Get("validations:validate") == nil {
- err = callback.Create().Before("gorm:before_create").Register("validations:validate", vv.validation)
- }
- if callback.Update().Get("validations:validate") == nil {
- err = callback.Update().Before("gorm:before_update").Register("validations:validate", vv.validation)
- }
- return
- }
|