callback.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package validator
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "git.nspix.com/golang/rest/internal/empty"
  7. "git.nspix.com/golang/rest/orm/schema"
  8. "git.nspix.com/golang/rest/scenario"
  9. "github.com/go-playground/locales/en"
  10. "github.com/go-playground/locales/zh"
  11. "github.com/go-playground/universal-translator"
  12. "github.com/go-playground/validator/v10"
  13. translation "github.com/go-playground/validator/v10/translations/zh"
  14. "gorm.io/gorm"
  15. "gorm.io/gorm/clause"
  16. shm "gorm.io/gorm/schema"
  17. "reflect"
  18. "regexp"
  19. "strconv"
  20. "strings"
  21. )
  22. type validateScope struct{}
  23. type scope struct {
  24. Db *gorm.DB
  25. Column string
  26. Model interface{}
  27. }
  28. var (
  29. validate = validator.New()
  30. SkipValidations = "validations:skip_validations"
  31. translator ut.Translator
  32. validateScopeKey = validateScope{}
  33. telephoneRegex = regexp.MustCompile("^\\d{5,20}$")
  34. )
  35. func init() {
  36. enLang := en.New()
  37. zhLang := zh.New()
  38. universal := ut.New(enLang, zhLang)
  39. translator, _ = universal.GetTranslator("en")
  40. _ = translation.RegisterDefaultTranslations(validate, translator)
  41. _ = validate.RegisterTranslation("db_unique", translator, func(ut ut.Translator) error {
  42. return ut.Add("db_unique", "{0}值已经存在.", true)
  43. }, func(ut ut.Translator, fe validator.FieldError) string {
  44. t, err := ut.T(fe.Tag(), fe.Field())
  45. if err != nil {
  46. return fe.(error).Error()
  47. }
  48. return t
  49. })
  50. _ = validate.RegisterTranslation("telephone", translator, func(ut ut.Translator) error {
  51. return ut.Add("telephone", "{0}号码不合法.", true)
  52. }, func(ut ut.Translator, fe validator.FieldError) string {
  53. t, err := ut.T(fe.Tag(), fe.Field())
  54. if err != nil {
  55. return fe.(error).Error()
  56. }
  57. return t
  58. })
  59. validate.RegisterTagNameFunc(func(field reflect.StructField) string {
  60. name := strings.SplitN(field.Tag.Get("json"), ",", 2)[0]
  61. if name == "-" {
  62. return ""
  63. } else {
  64. return name
  65. }
  66. })
  67. _ = validate.RegisterValidationCtx("telephone", func(ctx context.Context, fl validator.FieldLevel) bool {
  68. val := fmt.Sprint(fl.Field().Interface())
  69. return telephoneRegex.MatchString(val)
  70. })
  71. _ = validate.RegisterValidationCtx("db_unique", func(ctx context.Context, fl validator.FieldLevel) bool {
  72. val := fl.Field().Interface()
  73. var (
  74. sp *scope
  75. ok bool
  76. count int64
  77. err error
  78. sess *gorm.DB
  79. field *shm.Field
  80. refValue reflect.Value
  81. )
  82. if sp, ok = ctx.Value(validateScopeKey).(*scope); !ok {
  83. return true
  84. }
  85. sess = sp.Db.Scopes(func(db *gorm.DB) *gorm.DB {
  86. s := db.Session(&gorm.Session{})
  87. s.Statement = &gorm.Statement{
  88. DB: db,
  89. ConnPool: db.Statement.ConnPool,
  90. Context: db.Statement.Context,
  91. Clauses: map[string]clause.Clause{},
  92. }
  93. return s
  94. })
  95. if err = sess.Statement.Parse(sp.Model); err == nil {
  96. if len(sess.Statement.Schema.PrimaryFields) > 0 {
  97. field = sess.Statement.Schema.PrimaryFields[0]
  98. refValue = reflect.Indirect(reflect.ValueOf(sp.Model))
  99. for _, n := range field.BindNames {
  100. refValue = refValue.FieldByName(n)
  101. }
  102. }
  103. }
  104. if refValue.IsValid() && !refValue.IsZero() && field != nil {
  105. sess.Model(sp.Model).Where(sp.Column+"=? AND "+field.Name+" != ?", val, refValue.Interface()).Count(&count)
  106. } else {
  107. sess.Model(sp.Model).Where(sp.Column+"=?", val).Count(&count)
  108. }
  109. if count > 0 {
  110. return false
  111. }
  112. return true
  113. })
  114. }
  115. type StructError struct {
  116. Tag string `json:"tag"`
  117. Column string `json:"column"`
  118. Message string `json:"message"`
  119. }
  120. func (err *StructError) Error() string {
  121. return err.Column + err.Message
  122. }
  123. func generateTag(scene string, rule schema.Rule) string {
  124. var s string
  125. if rule.Min != 0 {
  126. s += ",min=" + strconv.Itoa(rule.Min)
  127. }
  128. if rule.Max != 0 {
  129. s += ",max=" + strconv.Itoa(rule.Max)
  130. }
  131. if rule.Unique {
  132. s += ",db_unique"
  133. }
  134. if rule.Type != "" {
  135. s += "," + rule.Type
  136. }
  137. if rule.Required != nil && len(rule.Required) > 0 {
  138. for _, v := range rule.Required {
  139. if v == scene {
  140. s += ",required"
  141. }
  142. }
  143. }
  144. if s != "" {
  145. return s[1:]
  146. } else {
  147. return s
  148. }
  149. }
  150. func validation(db *gorm.DB) {
  151. var (
  152. err error
  153. tag string
  154. rule schema.Rule
  155. scene = scenario.Create
  156. )
  157. if _, ok := db.Get("gorm:update_column"); ok {
  158. return
  159. }
  160. if result, ok := db.Get(SkipValidations); ok && result.(bool) {
  161. return
  162. }
  163. if db.Error == nil && db.Statement.Schema != nil {
  164. stmt := db.Statement
  165. if stmt.Model != nil {
  166. if !schema.IsNewRecord(stmt.ReflectValue, stmt) {
  167. scene = scenario.Update
  168. }
  169. schemas := schema.VisibleField("organize", stmt.Table, scene)
  170. for _, field := range schemas {
  171. if field.Rules == "" {
  172. continue
  173. }
  174. if err = json.Unmarshal([]byte(field.Rules), &rule); err != nil {
  175. continue
  176. }
  177. if tag = generateTag(scene, rule); tag == "" {
  178. continue
  179. }
  180. fieldValue := stmt.ReflectValue.FieldByName(stmt.Schema.LookUpField(field.Column).Name)
  181. if !fieldValue.IsValid() {
  182. continue
  183. }
  184. //如果没有必填 并且值为空跳过验证
  185. if !strings.Contains(tag, "required") && empty.Is(fieldValue.Interface()) {
  186. continue
  187. }
  188. ctx := context.WithValue(context.Background(), validateScopeKey, &scope{
  189. Db: db,
  190. Column: field.Column,
  191. Model: stmt.Model,
  192. })
  193. if err = validate.VarCtx(ctx, fieldValue.Interface(), tag); err != nil {
  194. if errors, ok := err.(validator.ValidationErrors); ok {
  195. for _, e := range errors {
  196. _ = db.AddError(&StructError{
  197. Tag: e.Tag(),
  198. Column: field.Column,
  199. Message: e.Translate(translator),
  200. })
  201. }
  202. } else {
  203. _ = db.AddError(err)
  204. }
  205. break
  206. }
  207. }
  208. }
  209. }
  210. }
  211. // RegisterCallbacks register callback into GORM DB
  212. // BeforeSave and BeforeCreate is called on before_create
  213. // so this is called just after them
  214. func RegisterCallbacks(db *gorm.DB) {
  215. callback := db.Callback()
  216. if callback.Create().Get("validations:validate") == nil {
  217. _ = callback.Create().Before("gorm:before_create").Register("validations:validate", validation)
  218. }
  219. if callback.Update().Get("validations:validate") == nil {
  220. _ = callback.Update().Before("gorm:before_update").Register("validations:validate", validation)
  221. }
  222. }