callbacks.go 7.9 KB


  1. package gorm
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "reflect"
  7. "sort"
  8. "time"
  9. "gorm.io/gorm/schema"
  10. "gorm.io/gorm/utils"
  11. )
  12. func initializeCallbacks(db *DB) *callbacks {
  13. return &callbacks{
  14. processors: map[string]*processor{
  15. "create": {db: db},
  16. "query": {db: db},
  17. "update": {db: db},
  18. "delete": {db: db},
  19. "row": {db: db},
  20. "raw": {db: db},
  21. },
  22. }
  23. }
  24. // callbacks gorm callbacks manager
  25. type callbacks struct {
  26. processors map[string]*processor
  27. }
  28. type processor struct {
  29. db *DB
  30. Clauses []string
  31. fns []func(*DB)
  32. callbacks []*callback
  33. }
  34. type callback struct {
  35. name string
  36. before string
  37. after string
  38. remove bool
  39. replace bool
  40. match func(*DB) bool
  41. handler func(*DB)
  42. processor *processor
  43. }
  44. func (cs *callbacks) Create() *processor {
  45. return cs.processors["create"]
  46. }
  47. func (cs *callbacks) Query() *processor {
  48. return cs.processors["query"]
  49. }
  50. func (cs *callbacks) Update() *processor {
  51. return cs.processors["update"]
  52. }
  53. func (cs *callbacks) Delete() *processor {
  54. return cs.processors["delete"]
  55. }
  56. func (cs *callbacks) Row() *processor {
  57. return cs.processors["row"]
  58. }
  59. func (cs *callbacks) Raw() *processor {
  60. return cs.processors["raw"]
  61. }
  62. func (p *processor) Execute(db *DB) *DB {
  63. // call scopes
  64. for len(db.Statement.scopes) > 0 {
  65. scopes := db.Statement.scopes
  66. db.Statement.scopes = nil
  67. for _, scope := range scopes {
  68. db = scope(db)
  69. }
  70. }
  71. var (
  72. curTime = time.Now()
  73. stmt = db.Statement
  74. resetBuildClauses bool
  75. )
  76. if len(stmt.BuildClauses) == 0 {
  77. stmt.BuildClauses = p.Clauses
  78. resetBuildClauses = true
  79. }
  80. // assign model values
  81. if stmt.Model == nil {
  82. stmt.Model = stmt.Dest
  83. } else if stmt.Dest == nil {
  84. stmt.Dest = stmt.Model
  85. }
  86. // parse model values
  87. if stmt.Model != nil {
  88. if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
  89. if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
  90. db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
  91. } else {
  92. db.AddError(err)
  93. }
  94. }
  95. }
  96. // assign stmt.ReflectValue
  97. if stmt.Dest != nil {
  98. stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
  99. for stmt.ReflectValue.Kind() == reflect.Ptr {
  100. if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
  101. stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
  102. }
  103. stmt.ReflectValue = stmt.ReflectValue.Elem()
  104. }
  105. if !stmt.ReflectValue.IsValid() {
  106. db.AddError(ErrInvalidValue)
  107. }
  108. }
  109. for _, f := range p.fns {
  110. f(db)
  111. }
  112. db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
  113. return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
  114. }, db.Error)
  115. if !stmt.DB.DryRun {
  116. stmt.SQL.Reset()
  117. stmt.Vars = nil
  118. }
  119. if resetBuildClauses {
  120. stmt.BuildClauses = nil
  121. }
  122. return db
  123. }
  124. func (p *processor) Get(name string) func(*DB) {
  125. for i := len(p.callbacks) - 1; i >= 0; i-- {
  126. if v := p.callbacks[i]; v.name == name && !v.remove {
  127. return v.handler
  128. }
  129. }
  130. return nil
  131. }
  132. func (p *processor) Before(name string) *callback {
  133. return &callback{before: name, processor: p}
  134. }
  135. func (p *processor) After(name string) *callback {
  136. return &callback{after: name, processor: p}
  137. }
  138. func (p *processor) Match(fc func(*DB) bool) *callback {
  139. return &callback{match: fc, processor: p}
  140. }
  141. func (p *processor) Register(name string, fn func(*DB)) error {
  142. return (&callback{processor: p}).Register(name, fn)
  143. }
  144. func (p *processor) Remove(name string) error {
  145. return (&callback{processor: p}).Remove(name)
  146. }
  147. func (p *processor) Replace(name string, fn func(*DB)) error {
  148. return (&callback{processor: p}).Replace(name, fn)
  149. }
  150. func (p *processor) compile() (err error) {
  151. var callbacks []*callback
  152. for _, callback := range p.callbacks {
  153. if callback.match == nil || callback.match(p.db) {
  154. callbacks = append(callbacks, callback)
  155. }
  156. }
  157. p.callbacks = callbacks
  158. if p.fns, err = sortCallbacks(p.callbacks); err != nil {
  159. p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
  160. }
  161. return
  162. }
  163. func (c *callback) Before(name string) *callback {
  164. c.before = name
  165. return c
  166. }
  167. func (c *callback) After(name string) *callback {
  168. c.after = name
  169. return c
  170. }
  171. func (c *callback) Register(name string, fn func(*DB)) error {
  172. c.name = name
  173. c.handler = fn
  174. c.processor.callbacks = append(c.processor.callbacks, c)
  175. return c.processor.compile()
  176. }
  177. func (c *callback) Remove(name string) error {
  178. c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
  179. c.name = name
  180. c.remove = true
  181. c.processor.callbacks = append(c.processor.callbacks, c)
  182. return c.processor.compile()
  183. }
  184. func (c *callback) Replace(name string, fn func(*DB)) error {
  185. c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
  186. c.name = name
  187. c.handler = fn
  188. c.replace = true
  189. c.processor.callbacks = append(c.processor.callbacks, c)
  190. return c.processor.compile()
  191. }
  // getRIndex returns the index of the last element of strs equal to str, or
// -1 when str does not occur (including when strs is nil or empty).
func getRIndex(strs []string, str string) int {
	idx := len(strs)
	for idx > 0 {
		idx--
		if strs[idx] == str {
			return idx
		}
	}
	return -1
}
  201. func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
  202. var (
  203. names, sorted []string
  204. sortCallback func(*callback) error
  205. )
  206. sort.Slice(cs, func(i, j int) bool {
  207. return cs[j].before == "*" || cs[j].after == "*"
  208. })
  209. for _, c := range cs {
  210. // show warning message the callback name already exists
  211. if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
  212. c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
  213. }
  214. names = append(names, c.name)
  215. }
  216. sortCallback = func(c *callback) error {
  217. if c.before != "" { // if defined before callback
  218. if c.before == "*" && len(sorted) > 0 {
  219. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  220. sorted = append([]string{c.name}, sorted...)
  221. }
  222. } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
  223. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  224. // if before callback already sorted, append current callback just after it
  225. sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
  226. } else if curIdx > sortedIdx {
  227. return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
  228. }
  229. } else if idx := getRIndex(names, c.before); idx != -1 {
  230. // if before callback exists
  231. cs[idx].after = c.name
  232. }
  233. }
  234. if c.after != "" { // if defined after callback
  235. if c.after == "*" && len(sorted) > 0 {
  236. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  237. sorted = append(sorted, c.name)
  238. }
  239. } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
  240. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  241. // if after callback sorted, append current callback to last
  242. sorted = append(sorted, c.name)
  243. } else if curIdx < sortedIdx {
  244. return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
  245. }
  246. } else if idx := getRIndex(names, c.after); idx != -1 {
  247. // if after callback exists but haven't sorted
  248. // set after callback's before callback to current callback
  249. after := cs[idx]
  250. if after.before == "" {
  251. after.before = c.name
  252. }
  253. if err := sortCallback(after); err != nil {
  254. return err
  255. }
  256. if err := sortCallback(c); err != nil {
  257. return err
  258. }
  259. }
  260. }
  261. // if current callback haven't been sorted, append it to last
  262. if getRIndex(sorted, c.name) == -1 {
  263. sorted = append(sorted, c.name)
  264. }
  265. return nil
  266. }
  267. for _, c := range cs {
  268. if err = sortCallback(c); err != nil {
  269. return
  270. }
  271. }
  272. for _, name := range sorted {
  273. if idx := getRIndex(names, name); !cs[idx].remove {
  274. fns = append(fns, cs[idx].handler)
  275. }
  276. }
  277. return
  278. }