callbacks.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. package gorm
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "reflect"
  7. "sort"
  8. "time"
  9. "gorm.io/gorm/schema"
  10. "gorm.io/gorm/utils"
  11. )
  12. func initializeCallbacks(db *DB) *callbacks {
  13. return &callbacks{
  14. processors: map[string]*processor{
  15. "create": {db: db},
  16. "query": {db: db},
  17. "update": {db: db},
  18. "delete": {db: db},
  19. "row": {db: db},
  20. "raw": {db: db},
  21. },
  22. }
  23. }
  24. // callbacks gorm callbacks manager
  25. type callbacks struct {
  26. processors map[string]*processor
  27. }
  28. type processor struct {
  29. db *DB
  30. fns []func(*DB)
  31. callbacks []*callback
  32. }
  33. type callback struct {
  34. name string
  35. before string
  36. after string
  37. remove bool
  38. replace bool
  39. match func(*DB) bool
  40. handler func(*DB)
  41. processor *processor
  42. }
  43. func (cs *callbacks) Create() *processor {
  44. return cs.processors["create"]
  45. }
  46. func (cs *callbacks) Query() *processor {
  47. return cs.processors["query"]
  48. }
  49. func (cs *callbacks) Update() *processor {
  50. return cs.processors["update"]
  51. }
  52. func (cs *callbacks) Delete() *processor {
  53. return cs.processors["delete"]
  54. }
  55. func (cs *callbacks) Row() *processor {
  56. return cs.processors["row"]
  57. }
  58. func (cs *callbacks) Raw() *processor {
  59. return cs.processors["raw"]
  60. }
  61. func (p *processor) Execute(db *DB) {
  62. curTime := time.Now()
  63. stmt := db.Statement
  64. if stmt.Model == nil {
  65. stmt.Model = stmt.Dest
  66. } else if stmt.Dest == nil {
  67. stmt.Dest = stmt.Model
  68. }
  69. if stmt.Model != nil {
  70. if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
  71. if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
  72. db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
  73. } else {
  74. db.AddError(err)
  75. }
  76. }
  77. }
  78. if stmt.Dest != nil {
  79. stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
  80. for stmt.ReflectValue.Kind() == reflect.Ptr {
  81. stmt.ReflectValue = stmt.ReflectValue.Elem()
  82. }
  83. if !stmt.ReflectValue.IsValid() {
  84. db.AddError(fmt.Errorf("invalid value"))
  85. }
  86. }
  87. for _, f := range p.fns {
  88. f(db)
  89. }
  90. db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
  91. return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
  92. }, db.Error)
  93. if !stmt.DB.DryRun {
  94. stmt.SQL.Reset()
  95. stmt.Vars = nil
  96. }
  97. }
  98. func (p *processor) Get(name string) func(*DB) {
  99. for i := len(p.callbacks) - 1; i >= 0; i-- {
  100. if v := p.callbacks[i]; v.name == name && !v.remove {
  101. return v.handler
  102. }
  103. }
  104. return nil
  105. }
  106. func (p *processor) Before(name string) *callback {
  107. return &callback{before: name, processor: p}
  108. }
  109. func (p *processor) After(name string) *callback {
  110. return &callback{after: name, processor: p}
  111. }
  112. func (p *processor) Match(fc func(*DB) bool) *callback {
  113. return &callback{match: fc, processor: p}
  114. }
  115. func (p *processor) Register(name string, fn func(*DB)) error {
  116. return (&callback{processor: p}).Register(name, fn)
  117. }
  118. func (p *processor) Remove(name string) error {
  119. return (&callback{processor: p}).Remove(name)
  120. }
  121. func (p *processor) Replace(name string, fn func(*DB)) error {
  122. return (&callback{processor: p}).Replace(name, fn)
  123. }
  124. func (p *processor) compile() (err error) {
  125. var callbacks []*callback
  126. for _, callback := range p.callbacks {
  127. if callback.match == nil || callback.match(p.db) {
  128. callbacks = append(callbacks, callback)
  129. }
  130. }
  131. p.callbacks = callbacks
  132. if p.fns, err = sortCallbacks(p.callbacks); err != nil {
  133. p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
  134. }
  135. return
  136. }
  137. func (c *callback) Before(name string) *callback {
  138. c.before = name
  139. return c
  140. }
  141. func (c *callback) After(name string) *callback {
  142. c.after = name
  143. return c
  144. }
  145. func (c *callback) Register(name string, fn func(*DB)) error {
  146. c.name = name
  147. c.handler = fn
  148. c.processor.callbacks = append(c.processor.callbacks, c)
  149. return c.processor.compile()
  150. }
  151. func (c *callback) Remove(name string) error {
  152. c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
  153. c.name = name
  154. c.remove = true
  155. c.processor.callbacks = append(c.processor.callbacks, c)
  156. return c.processor.compile()
  157. }
  158. func (c *callback) Replace(name string, fn func(*DB)) error {
  159. c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
  160. c.name = name
  161. c.handler = fn
  162. c.replace = true
  163. c.processor.callbacks = append(c.processor.callbacks, c)
  164. return c.processor.compile()
  165. }
  // getRIndex returns the index of the last element of strs that equals str,
  // or -1 when str does not occur in strs (including when strs is empty).
  func getRIndex(strs []string, str string) int {
  	idx := len(strs)
  	for idx > 0 {
  		idx--
  		if strs[idx] == str {
  			return idx
  		}
  	}
  	return -1
  }
  175. func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
  176. var (
  177. names, sorted []string
  178. sortCallback func(*callback) error
  179. )
  180. sort.Slice(cs, func(i, j int) bool {
  181. return cs[j].before == "*" || cs[j].after == "*"
  182. })
  183. for _, c := range cs {
  184. // show warning message the callback name already exists
  185. if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
  186. c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
  187. }
  188. names = append(names, c.name)
  189. }
  190. sortCallback = func(c *callback) error {
  191. if c.before != "" { // if defined before callback
  192. if c.before == "*" && len(sorted) > 0 {
  193. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  194. sorted = append([]string{c.name}, sorted...)
  195. }
  196. } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
  197. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  198. // if before callback already sorted, append current callback just after it
  199. sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
  200. } else if curIdx > sortedIdx {
  201. return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
  202. }
  203. } else if idx := getRIndex(names, c.before); idx != -1 {
  204. // if before callback exists
  205. cs[idx].after = c.name
  206. }
  207. }
  208. if c.after != "" { // if defined after callback
  209. if c.after == "*" && len(sorted) > 0 {
  210. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  211. sorted = append(sorted, c.name)
  212. }
  213. } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
  214. if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
  215. // if after callback sorted, append current callback to last
  216. sorted = append(sorted, c.name)
  217. } else if curIdx < sortedIdx {
  218. return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
  219. }
  220. } else if idx := getRIndex(names, c.after); idx != -1 {
  221. // if after callback exists but haven't sorted
  222. // set after callback's before callback to current callback
  223. after := cs[idx]
  224. if after.before == "" {
  225. after.before = c.name
  226. }
  227. if err := sortCallback(after); err != nil {
  228. return err
  229. }
  230. if err := sortCallback(c); err != nil {
  231. return err
  232. }
  233. }
  234. }
  235. // if current callback haven't been sorted, append it to last
  236. if getRIndex(sorted, c.name) == -1 {
  237. sorted = append(sorted, c.name)
  238. }
  239. return nil
  240. }
  241. for _, c := range cs {
  242. if err = sortCallback(c); err != nil {
  243. return
  244. }
  245. }
  246. for _, name := range sorted {
  247. if idx := getRIndex(names, name); !cs[idx].remove {
  248. fns = append(fns, cs[idx].handler)
  249. }
  250. }
  251. return
  252. }