validate.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. package plugins
  2. import (
  3. "context"
  4. "fmt"
  5. "git.nspix.com/golang/rest/v3"
  6. errpkg "git.nspix.com/golang/rest/v3/error"
  7. "git.nspix.com/golang/rest/v3/utils"
  8. validator "github.com/go-playground/validator/v10"
  9. "gorm.io/gorm"
  10. "gorm.io/gorm/schema"
  11. "reflect"
  12. "regexp"
  13. "strconv"
  14. "strings"
  15. )
  16. const (
  17. SkipValidations = "validations:skip_validations"
  18. )
  19. type (
  20. validateRule struct {
  21. Rule string
  22. Value string
  23. Valid bool
  24. }
  25. validateScope struct{}
  26. validScope struct {
  27. DB *gorm.DB
  28. Column string
  29. Model interface{}
  30. }
  31. validation struct {
  32. crud *rest.CRUD
  33. }
  34. )
  35. var (
  36. validate = validator.New()
  37. validateScopeKey = validateScope{}
  38. telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
  39. )
  40. func init() {
  41. _ = validate.RegisterValidationCtx("telephone", func(ctx context.Context, fl validator.FieldLevel) bool {
  42. val := fmt.Sprint(fl.Field().Interface())
  43. return telephoneRegex.MatchString(val)
  44. })
  45. _ = validate.RegisterValidationCtx("db_unique", func(ctx context.Context, fl validator.FieldLevel) bool {
  46. var (
  47. scope *validScope
  48. ok bool
  49. count int64
  50. field *schema.Field
  51. primaryKeyValue reflect.Value
  52. )
  53. val := fl.Field().Interface()
  54. if scope, ok = ctx.Value(validateScopeKey).(*validScope); !ok {
  55. return true
  56. }
  57. if len(scope.DB.Statement.Schema.PrimaryFields) > 0 {
  58. field = scope.DB.Statement.Schema.PrimaryFields[0]
  59. primaryKeyValue = reflect.Indirect(reflect.ValueOf(scope.Model))
  60. for _, n := range field.BindNames {
  61. primaryKeyValue = primaryKeyValue.FieldByName(n)
  62. }
  63. }
  64. sess := scope.DB.Session(&gorm.Session{NewDB: true})
  65. if primaryKeyValue.IsValid() && !primaryKeyValue.IsZero() && field != nil {
  66. sess.Model(scope.Model).Where(scope.Column+"=? AND "+field.Name+" != ?", val, primaryKeyValue.Interface()).Count(&count)
  67. } else {
  68. sess.Model(scope.Model).Where(scope.Column+"=?", val).Count(&count)
  69. }
  70. if count > 0 {
  71. return false
  72. }
  73. return true
  74. })
  75. }
  76. func newRule(ss ...string) *validateRule {
  77. v := &validateRule{
  78. Valid: true,
  79. }
  80. if len(ss) == 1 {
  81. v.Rule = ss[0]
  82. } else if len(ss) >= 2 {
  83. v.Rule = ss[0]
  84. v.Value = ss[1]
  85. }
  86. return v
  87. }
  88. func generateRules(scm *rest.Schema, scenario string, rule rest.Rule) []*validateRule {
  89. rules := make([]*validateRule, 0, 5)
  90. if rule.Min != 0 {
  91. rules = append(rules, newRule("min", strconv.Itoa(rule.Min)))
  92. }
  93. if rule.Max != 0 {
  94. rules = append(rules, newRule("max", strconv.Itoa(rule.Max)))
  95. }
  96. //主键不做唯一判断
  97. if rule.Unique && !scm.Attribute.PrimaryKey {
  98. rules = append(rules, newRule("db_unique"))
  99. }
  100. if rule.Type != "" {
  101. rules = append(rules, newRule(rule.Type))
  102. }
  103. if rule.Required != nil && len(rule.Required) > 0 {
  104. for _, v := range rule.Required {
  105. if v == scenario {
  106. rules = append(rules, newRule("required"))
  107. break
  108. }
  109. }
  110. }
  111. return rules
  112. }
  113. func buildRules(rs []*validateRule) string {
  114. var sb strings.Builder
  115. for _, r := range rs {
  116. if !r.Valid {
  117. continue
  118. }
  119. if sb.Len() > 0 {
  120. sb.WriteString(",")
  121. }
  122. if r.Value == "" {
  123. sb.WriteString(r.Rule)
  124. } else {
  125. sb.WriteString(r.Rule + "=" + r.Value)
  126. }
  127. }
  128. return sb.String()
  129. }
  130. func getRule(name string, rules []*validateRule) *validateRule {
  131. for _, r := range rules {
  132. if r.Rule == name {
  133. return r
  134. }
  135. }
  136. return nil
  137. }
  138. func formatError(rule rest.Rule, scm *rest.Schema, tag string) string {
  139. var s string
  140. switch tag {
  141. case "db_unique":
  142. s = scm.Label + "值已经存在."
  143. break
  144. case "required":
  145. s = scm.Label + "值不能为空."
  146. case "max":
  147. if scm.Type == "string" {
  148. s = scm.Label + "长度不能大于" + strconv.Itoa(rule.Max)
  149. } else {
  150. s = scm.Label + "值不能大于" + strconv.Itoa(rule.Max)
  151. }
  152. case "min":
  153. if scm.Type == "string" {
  154. s = scm.Label + "长度不能小于" + strconv.Itoa(rule.Max)
  155. } else {
  156. s = scm.Label + "值不能小于" + strconv.Itoa(rule.Max)
  157. }
  158. }
  159. return s
  160. }
  161. func (vv *validation) validation(db *gorm.DB) {
  162. if result, ok := db.Get(SkipValidations); ok && result.(bool) {
  163. return
  164. }
  165. var (
  166. ok bool
  167. err error
  168. rules []*validateRule
  169. stmt *gorm.Statement
  170. model rest.Model
  171. scenario string
  172. skipValidate bool
  173. value reflect.Value
  174. schemas []*rest.Schema
  175. )
  176. stmt = db.Statement
  177. if stmt.Model == nil {
  178. return
  179. }
  180. if model, ok = stmt.Model.(rest.Model); !ok {
  181. return
  182. }
  183. scenario = rest.ScenarioUpdate
  184. for _, pk := range stmt.Schema.PrimaryFields {
  185. if utils.IsEmpty(stmt.ReflectValue.FieldByName(pk.Name).Interface()) {
  186. scenario = rest.ScenarioCreate
  187. break
  188. }
  189. }
  190. schemas = vv.crud.VisibleSchemas(context.Background(), stmt.ReflectValue.FieldByName("Namespace").String(), model.ModuleName(), model.TableName(), scenario)
  191. for _, scm := range schemas {
  192. if rules = generateRules(scm, scenario, scm.Rule); len(rules) <= 0 {
  193. continue
  194. }
  195. value = stmt.ReflectValue.FieldByName(stmt.Schema.LookUpField(scm.Column).Name)
  196. if !value.IsValid() {
  197. continue
  198. }
  199. skipValidate = false
  200. if r := getRule("required", rules); r != nil {
  201. if value.Interface() != nil {
  202. vType := reflect.ValueOf(value.Interface())
  203. switch vType.Kind() {
  204. case reflect.Bool:
  205. skipValidate = true
  206. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  207. skipValidate = true
  208. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  209. skipValidate = true
  210. case reflect.Float32, reflect.Float64:
  211. skipValidate = true
  212. }
  213. }
  214. if skipValidate {
  215. r.Valid = false
  216. }
  217. } else {
  218. if utils.IsEmpty(value.Interface()) {
  219. continue
  220. }
  221. }
  222. ctx := context.WithValue(context.Background(), validateScopeKey, &validScope{
  223. DB: db,
  224. Column: scm.Column,
  225. Model: stmt.Model,
  226. })
  227. if err = validate.VarCtx(ctx, value.Interface(), buildRules(rules)); err != nil {
  228. if errs, ok := err.(validator.ValidationErrors); ok {
  229. for _, e := range errs {
  230. _ = db.AddError(&errpkg.StructError{
  231. Tag: e.Tag(),
  232. Column: scm.Column,
  233. Message: formatError(scm.Rule, scm, e.Tag()),
  234. })
  235. }
  236. } else {
  237. _ = db.AddError(err)
  238. }
  239. break
  240. }
  241. }
  242. }
  243. func RegisterValidationCallback(db *gorm.DB, crud *rest.CRUD) (err error) {
  244. callback := db.Callback()
  245. vv := &validation{crud: crud}
  246. if callback.Create().Get("validations:validate") == nil {
  247. err = callback.Create().Before("gorm:before_create").Register("validations:validate", vv.validation)
  248. }
  249. if callback.Update().Get("validations:validate") == nil {
  250. err = callback.Update().Before("gorm:before_update").Register("validations:validate", vv.validation)
  251. }
  252. return
  253. }