validate.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. package plugins
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "reflect"
  7. "regexp"
  8. "strconv"
  9. "strings"
  10. "git.nspix.com/golang/micro/helper/utils"
  11. "git.nspix.com/golang/rest/v2"
  12. errors2 "git.nspix.com/golang/rest/v2/errors"
  13. "github.com/go-playground/validator/v10"
  14. "gorm.io/gorm"
  15. "gorm.io/gorm/clause"
  16. "gorm.io/gorm/schema"
  17. )
  18. const (
  19. SkipValidations = "validations:skip_validations"
  20. )
  21. type (
  22. validateScope struct{}
  23. validScope struct {
  24. Db *gorm.DB
  25. Column string
  26. Model interface{}
  27. }
  28. validation struct {
  29. curd *rest.CRUD
  30. }
  31. )
  32. var (
  33. validate = validator.New()
  34. validateScopeKey = validateScope{}
  35. telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
  36. )
  37. func init() {
  38. _ = validate.RegisterValidationCtx("telephone", func(ctx context.Context, fl validator.FieldLevel) bool {
  39. val := fmt.Sprint(fl.Field().Interface())
  40. return telephoneRegex.MatchString(val)
  41. })
  42. _ = validate.RegisterValidationCtx("db_unique", func(ctx context.Context, fl validator.FieldLevel) bool {
  43. val := fl.Field().Interface()
  44. var (
  45. sp *validScope
  46. ok bool
  47. count int64
  48. err error
  49. sess *gorm.DB
  50. field *schema.Field
  51. refValue reflect.Value
  52. )
  53. if sp, ok = ctx.Value(validateScopeKey).(*validScope); !ok {
  54. return true
  55. }
  56. sess = sp.Db.Scopes(func(db *gorm.DB) *gorm.DB {
  57. s := db.Session(&gorm.Session{})
  58. s.Statement = &gorm.Statement{
  59. DB: db,
  60. ConnPool: db.Statement.ConnPool,
  61. Context: db.Statement.Context,
  62. Clauses: map[string]clause.Clause{},
  63. }
  64. return s
  65. })
  66. if err = sess.Statement.Parse(sp.Model); err == nil {
  67. if len(sess.Statement.Schema.PrimaryFields) > 0 {
  68. field = sess.Statement.Schema.PrimaryFields[0]
  69. refValue = reflect.Indirect(reflect.ValueOf(sp.Model))
  70. for _, n := range field.BindNames {
  71. refValue = refValue.FieldByName(n)
  72. }
  73. }
  74. }
  75. if refValue.IsValid() && !refValue.IsZero() && field != nil {
  76. sess.Model(sp.Model).Where(sp.Column+"=? AND "+field.Name+" != ?", val, refValue.Interface()).Count(&count)
  77. } else {
  78. sess.Model(sp.Model).Where(sp.Column+"=?", val).Count(&count)
  79. }
  80. if count > 0 {
  81. return false
  82. }
  83. return true
  84. })
  85. }
  86. func generateTag(scm *rest.Schema, scenario string, rule rest.Rule) string {
  87. var s string
  88. if rule.Min != 0 {
  89. s += ",min=" + strconv.Itoa(rule.Min)
  90. }
  91. if rule.Max != 0 {
  92. s += ",max=" + strconv.Itoa(rule.Max)
  93. }
  94. //主键不做判断
  95. if rule.Unique && scm.PrimaryKey == 0 {
  96. s += ",db_unique"
  97. }
  98. if rule.Type != "" {
  99. s += "," + rule.Type
  100. }
  101. if rule.Required != nil && len(rule.Required) > 0 {
  102. for _, v := range rule.Required {
  103. if v == scenario {
  104. s += ",required"
  105. }
  106. }
  107. }
  108. if s != "" {
  109. return s[1:]
  110. } else {
  111. return s
  112. }
  113. }
  114. func formatError(rule rest.Rule, scm *rest.Schema, tag string) string {
  115. var s string
  116. switch tag {
  117. case "db_unique":
  118. s = scm.Label + "值已经存在."
  119. break
  120. case "required":
  121. s = scm.Label + "值不能为空."
  122. case "max":
  123. if scm.Type == "string" {
  124. s = scm.Label + "长度不能大于" + strconv.Itoa(rule.Max)
  125. } else {
  126. s = scm.Label + "值不能大于" + strconv.Itoa(rule.Max)
  127. }
  128. case "min":
  129. if scm.Type == "string" {
  130. s = scm.Label + "长度不能小于" + strconv.Itoa(rule.Max)
  131. } else {
  132. s = scm.Label + "值不能小于" + strconv.Itoa(rule.Max)
  133. }
  134. }
  135. return s
  136. }
  137. func (vv *validation) validation(db *gorm.DB) {
  138. if result, ok := db.Get(SkipValidations); ok && result.(bool) {
  139. return
  140. }
  141. var (
  142. ok bool
  143. err error
  144. ruleString string
  145. stmt *gorm.Statement
  146. model rest.Model
  147. rule rest.Rule
  148. scenario string
  149. skipValidate bool
  150. value reflect.Value
  151. schemas []*rest.Schema
  152. )
  153. stmt = db.Statement
  154. if stmt.Model == nil {
  155. return
  156. }
  157. if model, ok = stmt.Model.(rest.Model); !ok {
  158. return
  159. }
  160. scenario = rest.ScenarioUpdate
  161. for _, pk := range stmt.Schema.PrimaryFields {
  162. if utils.IsEmpty(stmt.ReflectValue.FieldByName(pk.Name).Interface()) {
  163. scenario = rest.ScenarioCreate
  164. break
  165. }
  166. }
  167. schemas = vv.curd.Schemas(stmt.ReflectValue.FieldByName("Namespace").String(), model.ModuleName(), model.TableName(), scenario)
  168. for _, scm := range schemas {
  169. if scm.Rules == "" {
  170. continue
  171. }
  172. if err = json.Unmarshal([]byte(scm.Rules), &rule); err != nil {
  173. continue
  174. }
  175. if ruleString = generateTag(scm, scenario, rule); ruleString == "" {
  176. continue
  177. }
  178. value = stmt.ReflectValue.FieldByName(stmt.Schema.LookUpField(scm.Column).Name)
  179. if !value.IsValid() {
  180. continue
  181. }
  182. skipValidate = false
  183. if strings.Contains(ruleString, "required") {
  184. //如果数值为整形,小数,Bool跳过验证
  185. if value.Interface() != nil {
  186. vType := reflect.ValueOf(value.Interface())
  187. switch vType.Kind() {
  188. case reflect.Bool:
  189. skipValidate = true
  190. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  191. skipValidate = true
  192. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  193. skipValidate = true
  194. case reflect.Float32, reflect.Float64:
  195. skipValidate = true
  196. }
  197. }
  198. if skipValidate {
  199. continue
  200. }
  201. } else {
  202. if utils.IsEmpty(value.Interface()) {
  203. continue
  204. }
  205. }
  206. ctx := context.WithValue(context.Background(), validateScopeKey, &validScope{
  207. Db: db,
  208. Column: scm.Column,
  209. Model: stmt.Model,
  210. })
  211. if err = validate.VarCtx(ctx, value.Interface(), ruleString); err != nil {
  212. if errors, ok := err.(validator.ValidationErrors); ok {
  213. for _, e := range errors {
  214. _ = db.AddError(&errors2.StructError{
  215. Tag: e.Tag(),
  216. Column: scm.Column,
  217. Message: formatError(rule, scm, e.Tag()),
  218. })
  219. }
  220. } else {
  221. _ = db.AddError(err)
  222. }
  223. break
  224. }
  225. }
  226. }
  227. func RegisterValidationCallback(curd *rest.CRUD) {
  228. callback := curd.DB().Callback()
  229. vv := &validation{curd: curd}
  230. if callback.Create().Get("validations:validate") == nil {
  231. _ = callback.Create().Before("gorm:before_create").Register("validations:validate", vv.validation)
  232. }
  233. if callback.Update().Get("validations:validate") == nil {
  234. _ = callback.Update().Before("gorm:before_update").Register("validations:validate", vv.validation)
  235. }
  236. }