validate.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. package plugins
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "reflect"
  7. "regexp"
  8. "strconv"
  9. "strings"
  10. "git.nspix.com/golang/micro/helper/utils"
  11. "git.nspix.com/golang/rest/v2"
  12. errors2 "git.nspix.com/golang/rest/v2/errors"
  13. "github.com/go-playground/validator/v10"
  14. "gorm.io/gorm"
  15. "gorm.io/gorm/clause"
  16. "gorm.io/gorm/schema"
  17. )
  18. const (
  19. SkipValidations = "validations:skip_validations"
  20. )
  21. type (
  22. validateScope struct{}
  23. validScope struct {
  24. Db *gorm.DB
  25. Column string
  26. Model interface{}
  27. }
  28. )
  29. var (
  30. validate = validator.New()
  31. validateScopeKey = validateScope{}
  32. telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
  33. )
  34. func init() {
  35. _ = validate.RegisterValidationCtx("telephone", func(ctx context.Context, fl validator.FieldLevel) bool {
  36. val := fmt.Sprint(fl.Field().Interface())
  37. return telephoneRegex.MatchString(val)
  38. })
  39. _ = validate.RegisterValidationCtx("db_unique", func(ctx context.Context, fl validator.FieldLevel) bool {
  40. val := fl.Field().Interface()
  41. var (
  42. sp *validScope
  43. ok bool
  44. count int64
  45. err error
  46. sess *gorm.DB
  47. field *schema.Field
  48. refValue reflect.Value
  49. )
  50. if sp, ok = ctx.Value(validateScopeKey).(*validScope); !ok {
  51. return true
  52. }
  53. sess = sp.Db.Scopes(func(db *gorm.DB) *gorm.DB {
  54. s := db.Session(&gorm.Session{})
  55. s.Statement = &gorm.Statement{
  56. DB: db,
  57. ConnPool: db.Statement.ConnPool,
  58. Context: db.Statement.Context,
  59. Clauses: map[string]clause.Clause{},
  60. }
  61. return s
  62. })
  63. if err = sess.Statement.Parse(sp.Model); err == nil {
  64. if len(sess.Statement.Schema.PrimaryFields) > 0 {
  65. field = sess.Statement.Schema.PrimaryFields[0]
  66. refValue = reflect.Indirect(reflect.ValueOf(sp.Model))
  67. for _, n := range field.BindNames {
  68. refValue = refValue.FieldByName(n)
  69. }
  70. }
  71. }
  72. if refValue.IsValid() && !refValue.IsZero() && field != nil {
  73. sess.Model(sp.Model).Where(sp.Column+"=? AND "+field.Name+" != ?", val, refValue.Interface()).Count(&count)
  74. } else {
  75. sess.Model(sp.Model).Where(sp.Column+"=?", val).Count(&count)
  76. }
  77. if count > 0 {
  78. return false
  79. }
  80. return true
  81. })
  82. }
  83. func generateTag(scm *rest.Schema, scenario string, rule rest.Rule) string {
  84. var s string
  85. if rule.Min != 0 {
  86. s += ",min=" + strconv.Itoa(rule.Min)
  87. }
  88. if rule.Max != 0 {
  89. s += ",max=" + strconv.Itoa(rule.Max)
  90. }
  91. //主键不做判断
  92. if rule.Unique && scm.PrimaryKey == 0 {
  93. s += ",db_unique"
  94. }
  95. if rule.Type != "" {
  96. s += "," + rule.Type
  97. }
  98. if rule.Required != nil && len(rule.Required) > 0 {
  99. for _, v := range rule.Required {
  100. if v == scenario {
  101. s += ",required"
  102. }
  103. }
  104. }
  105. if s != "" {
  106. return s[1:]
  107. } else {
  108. return s
  109. }
  110. }
  111. func formatError(rule rest.Rule, scm *rest.Schema, tag string) string {
  112. var s string
  113. switch tag {
  114. case "db_unique":
  115. s = scm.Label + "值已经存在."
  116. break
  117. case "required":
  118. s = scm.Label + "值不能为空."
  119. case "max":
  120. if scm.Type == "string" {
  121. s = scm.Label + "长度不能大于" + strconv.Itoa(rule.Max)
  122. } else {
  123. s = scm.Label + "值不能大于" + strconv.Itoa(rule.Max)
  124. }
  125. case "min":
  126. if scm.Type == "string" {
  127. s = scm.Label + "长度不能小于" + strconv.Itoa(rule.Max)
  128. } else {
  129. s = scm.Label + "值不能小于" + strconv.Itoa(rule.Max)
  130. }
  131. }
  132. return s
  133. }
  134. func validation(db *gorm.DB) {
  135. if result, ok := db.Get(SkipValidations); ok && result.(bool) {
  136. return
  137. }
  138. var (
  139. ok bool
  140. err error
  141. ruleString string
  142. stmt *gorm.Statement
  143. model rest.Model
  144. rule rest.Rule
  145. scenario string
  146. skipValidate bool
  147. value reflect.Value
  148. schemas []*rest.Schema
  149. )
  150. stmt = db.Statement
  151. if stmt.Model == nil {
  152. return
  153. }
  154. if model, ok = stmt.Model.(rest.Model); !ok {
  155. return
  156. }
  157. scenario = rest.ScenarioUpdate
  158. for _, pk := range stmt.Schema.PrimaryFields {
  159. if utils.IsEmpty(stmt.ReflectValue.FieldByName(pk.Name).Interface()) {
  160. scenario = rest.ScenarioCreate
  161. break
  162. }
  163. }
  164. schemas = rest.VisibleSchemas(stmt.ReflectValue.FieldByName("Namespace").String(), model.ModuleName(), model.TableName(), scenario)
  165. for _, scm := range schemas {
  166. if scm.Rules == "" {
  167. continue
  168. }
  169. if err = json.Unmarshal([]byte(scm.Rules), &rule); err != nil {
  170. continue
  171. }
  172. if ruleString = generateTag(scm, scenario, rule); ruleString == "" {
  173. continue
  174. }
  175. value = stmt.ReflectValue.FieldByName(stmt.Schema.LookUpField(scm.Column).Name)
  176. if !value.IsValid() {
  177. continue
  178. }
  179. skipValidate = false
  180. if strings.Contains(ruleString, "required") {
  181. //如果数值为整形,小数,Bool跳过验证
  182. if value.Interface() != nil {
  183. vType := reflect.ValueOf(value.Interface())
  184. switch vType.Kind() {
  185. case reflect.Bool:
  186. skipValidate = true
  187. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  188. skipValidate = true
  189. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  190. skipValidate = true
  191. case reflect.Float32, reflect.Float64:
  192. skipValidate = true
  193. }
  194. }
  195. if skipValidate {
  196. continue
  197. }
  198. } else {
  199. if utils.IsEmpty(value.Interface()) {
  200. continue
  201. }
  202. }
  203. ctx := context.WithValue(context.Background(), validateScopeKey, &validScope{
  204. Db: db,
  205. Column: scm.Column,
  206. Model: stmt.Model,
  207. })
  208. if err = validate.VarCtx(ctx, value.Interface(), ruleString); err != nil {
  209. if errors, ok := err.(validator.ValidationErrors); ok {
  210. for _, e := range errors {
  211. _ = db.AddError(&errors2.StructError{
  212. Tag: e.Tag(),
  213. Column: scm.Column,
  214. Message: formatError(rule, scm, e.Tag()),
  215. })
  216. }
  217. } else {
  218. _ = db.AddError(err)
  219. }
  220. break
  221. }
  222. }
  223. }
  224. func RegisterValidationCallback(db *gorm.DB) {
  225. callback := db.Callback()
  226. if callback.Create().Get("validations:validate") == nil {
  227. _ = callback.Create().Before("gorm:before_create").Register("validations:validate", validation)
  228. }
  229. if callback.Update().Get("validations:validate") == nil {
  230. _ = callback.Update().Before("gorm:before_update").Register("validations:validate", validation)
  231. }
  232. }