123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298 |
- package gorm
- import (
- "context"
- "errors"
- "fmt"
- "reflect"
- "sort"
- "time"
- "gorm.io/gorm/schema"
- "gorm.io/gorm/utils"
- )
- func initializeCallbacks(db *DB) *callbacks {
- return &callbacks{
- processors: map[string]*processor{
- "create": {db: db},
- "query": {db: db},
- "update": {db: db},
- "delete": {db: db},
- "row": {db: db},
- "raw": {db: db},
- },
- }
- }
- // callbacks gorm callbacks manager
- type callbacks struct {
- processors map[string]*processor
- }
- type processor struct {
- db *DB
- fns []func(*DB)
- callbacks []*callback
- }
- type callback struct {
- name string
- before string
- after string
- remove bool
- replace bool
- match func(*DB) bool
- handler func(*DB)
- processor *processor
- }
- func (cs *callbacks) Create() *processor {
- return cs.processors["create"]
- }
- func (cs *callbacks) Query() *processor {
- return cs.processors["query"]
- }
- func (cs *callbacks) Update() *processor {
- return cs.processors["update"]
- }
- func (cs *callbacks) Delete() *processor {
- return cs.processors["delete"]
- }
- func (cs *callbacks) Row() *processor {
- return cs.processors["row"]
- }
- func (cs *callbacks) Raw() *processor {
- return cs.processors["raw"]
- }
- func (p *processor) Execute(db *DB) {
- curTime := time.Now()
- stmt := db.Statement
- if stmt.Model == nil {
- stmt.Model = stmt.Dest
- } else if stmt.Dest == nil {
- stmt.Dest = stmt.Model
- }
- if stmt.Model != nil {
- if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
- if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
- db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
- } else {
- db.AddError(err)
- }
- }
- }
- if stmt.Dest != nil {
- stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
- for stmt.ReflectValue.Kind() == reflect.Ptr {
- stmt.ReflectValue = stmt.ReflectValue.Elem()
- }
- if !stmt.ReflectValue.IsValid() {
- db.AddError(fmt.Errorf("invalid value"))
- }
- }
- for _, f := range p.fns {
- f(db)
- }
- db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
- return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
- }, db.Error)
- if !stmt.DB.DryRun {
- stmt.SQL.Reset()
- stmt.Vars = nil
- }
- }
- func (p *processor) Get(name string) func(*DB) {
- for i := len(p.callbacks) - 1; i >= 0; i-- {
- if v := p.callbacks[i]; v.name == name && !v.remove {
- return v.handler
- }
- }
- return nil
- }
- func (p *processor) Before(name string) *callback {
- return &callback{before: name, processor: p}
- }
- func (p *processor) After(name string) *callback {
- return &callback{after: name, processor: p}
- }
- func (p *processor) Match(fc func(*DB) bool) *callback {
- return &callback{match: fc, processor: p}
- }
- func (p *processor) Register(name string, fn func(*DB)) error {
- return (&callback{processor: p}).Register(name, fn)
- }
- func (p *processor) Remove(name string) error {
- return (&callback{processor: p}).Remove(name)
- }
- func (p *processor) Replace(name string, fn func(*DB)) error {
- return (&callback{processor: p}).Replace(name, fn)
- }
- func (p *processor) compile() (err error) {
- var callbacks []*callback
- for _, callback := range p.callbacks {
- if callback.match == nil || callback.match(p.db) {
- callbacks = append(callbacks, callback)
- }
- }
- p.callbacks = callbacks
- if p.fns, err = sortCallbacks(p.callbacks); err != nil {
- p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
- }
- return
- }
- func (c *callback) Before(name string) *callback {
- c.before = name
- return c
- }
- func (c *callback) After(name string) *callback {
- c.after = name
- return c
- }
- func (c *callback) Register(name string, fn func(*DB)) error {
- c.name = name
- c.handler = fn
- c.processor.callbacks = append(c.processor.callbacks, c)
- return c.processor.compile()
- }
- func (c *callback) Remove(name string) error {
- c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
- c.name = name
- c.remove = true
- c.processor.callbacks = append(c.processor.callbacks, c)
- return c.processor.compile()
- }
- func (c *callback) Replace(name string, fn func(*DB)) error {
- c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
- c.name = name
- c.handler = fn
- c.replace = true
- c.processor.callbacks = append(c.processor.callbacks, c)
- return c.processor.compile()
- }
- // getRIndex get right index from string slice
- func getRIndex(strs []string, str string) int {
- for i := len(strs) - 1; i >= 0; i-- {
- if strs[i] == str {
- return i
- }
- }
- return -1
- }
- func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
- var (
- names, sorted []string
- sortCallback func(*callback) error
- )
- sort.Slice(cs, func(i, j int) bool {
- return cs[j].before == "*" || cs[j].after == "*"
- })
- for _, c := range cs {
- // show warning message the callback name already exists
- if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
- c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
- }
- names = append(names, c.name)
- }
- sortCallback = func(c *callback) error {
- if c.before != "" { // if defined before callback
- if c.before == "*" && len(sorted) > 0 {
- if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
- sorted = append([]string{c.name}, sorted...)
- }
- } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
- if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
- // if before callback already sorted, append current callback just after it
- sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
- } else if curIdx > sortedIdx {
- return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
- }
- } else if idx := getRIndex(names, c.before); idx != -1 {
- // if before callback exists
- cs[idx].after = c.name
- }
- }
- if c.after != "" { // if defined after callback
- if c.after == "*" && len(sorted) > 0 {
- if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
- sorted = append(sorted, c.name)
- }
- } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
- if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
- // if after callback sorted, append current callback to last
- sorted = append(sorted, c.name)
- } else if curIdx < sortedIdx {
- return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
- }
- } else if idx := getRIndex(names, c.after); idx != -1 {
- // if after callback exists but haven't sorted
- // set after callback's before callback to current callback
- after := cs[idx]
- if after.before == "" {
- after.before = c.name
- }
- if err := sortCallback(after); err != nil {
- return err
- }
- if err := sortCallback(c); err != nil {
- return err
- }
- }
- }
- // if current callback haven't been sorted, append it to last
- if getRIndex(sorted, c.name) == -1 {
- sorted = append(sorted, c.name)
- }
- return nil
- }
- for _, c := range cs {
- if err = sortCallback(c); err != nil {
- return
- }
- }
- for _, name := range sorted {
- if idx := getRIndex(names, name); !cs[idx].remove {
- fns = append(fns, cs[idx].handler)
- }
- }
- return
- }
|