123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
- package callbacks
- import (
- "reflect"
- "strings"
- "gorm.io/gorm"
- "gorm.io/gorm/clause"
- "gorm.io/gorm/schema"
- )
- func SaveBeforeAssociations(db *gorm.DB) {
- if db.Error == nil && db.Statement.Schema != nil {
- selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
- // Save Belongs To associations
- for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
- if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
- continue
- }
- setupReferences := func(obj reflect.Value, elem reflect.Value) {
- for _, ref := range rel.References {
- if !ref.OwnPrimaryKey {
- pv, _ := ref.PrimaryKey.ValueOf(elem)
- db.AddError(ref.ForeignKey.Set(obj, pv))
- if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
- dest[ref.ForeignKey.DBName] = pv
- if _, ok := dest[rel.Name]; ok {
- dest[rel.Name] = elem.Interface()
- }
- }
- }
- }
- }
- switch db.Statement.ReflectValue.Kind() {
- case reflect.Slice, reflect.Array:
- var (
- objs []reflect.Value
- fieldType = rel.Field.FieldType
- isPtr = fieldType.Kind() == reflect.Ptr
- )
- if !isPtr {
- fieldType = reflect.PtrTo(fieldType)
- }
- elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
- for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
- obj := db.Statement.ReflectValue.Index(i)
- if reflect.Indirect(obj).Kind() == reflect.Struct {
- if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
- rv := rel.Field.ReflectValueOf(obj) // relation reflect value
- objs = append(objs, obj)
- if isPtr {
- elems = reflect.Append(elems, rv)
- } else {
- elems = reflect.Append(elems, rv.Addr())
- }
- }
- } else {
- break
- }
- }
- if elems.Len() > 0 {
- if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
- for i := 0; i < elems.Len(); i++ {
- setupReferences(objs[i], elems.Index(i))
- }
- }
- }
- case reflect.Struct:
- if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
- rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
- if rv.Kind() != reflect.Ptr {
- rv = rv.Addr()
- }
- if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
- setupReferences(db.Statement.ReflectValue, rv)
- }
- }
- }
- }
- }
- }
- func SaveAfterAssociations(db *gorm.DB) {
- if db.Error == nil && db.Statement.Schema != nil {
- selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
- // Save Has One associations
- for _, rel := range db.Statement.Schema.Relationships.HasOne {
- if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
- continue
- }
- switch db.Statement.ReflectValue.Kind() {
- case reflect.Slice, reflect.Array:
- var (
- fieldType = rel.Field.FieldType
- isPtr = fieldType.Kind() == reflect.Ptr
- )
- if !isPtr {
- fieldType = reflect.PtrTo(fieldType)
- }
- elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
- for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
- obj := db.Statement.ReflectValue.Index(i)
- if reflect.Indirect(obj).Kind() == reflect.Struct {
- if _, zero := rel.Field.ValueOf(obj); !zero {
- rv := rel.Field.ReflectValueOf(obj)
- if rv.Kind() != reflect.Ptr {
- rv = rv.Addr()
- }
- for _, ref := range rel.References {
- if ref.OwnPrimaryKey {
- fv, _ := ref.PrimaryKey.ValueOf(obj)
- db.AddError(ref.ForeignKey.Set(rv, fv))
- } else if ref.PrimaryValue != "" {
- db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
- }
- }
- elems = reflect.Append(elems, rv)
- }
- }
- }
- if elems.Len() > 0 {
- assignmentColumns := []string{}
- for _, ref := range rel.References {
- assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
- }
- saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
- }
- case reflect.Struct:
- if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
- f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
- if f.Kind() != reflect.Ptr {
- f = f.Addr()
- }
- assignmentColumns := []string{}
- for _, ref := range rel.References {
- if ref.OwnPrimaryKey {
- fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
- ref.ForeignKey.Set(f, fv)
- } else if ref.PrimaryValue != "" {
- ref.ForeignKey.Set(f, ref.PrimaryValue)
- }
- assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
- }
- saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
- }
- }
- }
- // Save Has Many associations
- for _, rel := range db.Statement.Schema.Relationships.HasMany {
- if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
- continue
- }
- fieldType := rel.Field.IndirectFieldType.Elem()
- isPtr := fieldType.Kind() == reflect.Ptr
- if !isPtr {
- fieldType = reflect.PtrTo(fieldType)
- }
- elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
- appendToElems := func(v reflect.Value) {
- if _, zero := rel.Field.ValueOf(v); !zero {
- f := reflect.Indirect(rel.Field.ReflectValueOf(v))
- for i := 0; i < f.Len(); i++ {
- elem := f.Index(i)
- for _, ref := range rel.References {
- if ref.OwnPrimaryKey {
- pv, _ := ref.PrimaryKey.ValueOf(v)
- ref.ForeignKey.Set(elem, pv)
- } else if ref.PrimaryValue != "" {
- ref.ForeignKey.Set(elem, ref.PrimaryValue)
- }
- }
- if isPtr {
- elems = reflect.Append(elems, elem)
- } else {
- elems = reflect.Append(elems, elem.Addr())
- }
- }
- }
- }
- switch db.Statement.ReflectValue.Kind() {
- case reflect.Slice, reflect.Array:
- for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
- obj := db.Statement.ReflectValue.Index(i)
- if reflect.Indirect(obj).Kind() == reflect.Struct {
- appendToElems(obj)
- }
- }
- case reflect.Struct:
- appendToElems(db.Statement.ReflectValue)
- }
- if elems.Len() > 0 {
- assignmentColumns := []string{}
- for _, ref := range rel.References {
- assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
- }
- saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
- }
- }
- // Save Many2Many associations
- for _, rel := range db.Statement.Schema.Relationships.Many2Many {
- if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
- continue
- }
- fieldType := rel.Field.IndirectFieldType.Elem()
- isPtr := fieldType.Kind() == reflect.Ptr
- if !isPtr {
- fieldType = reflect.PtrTo(fieldType)
- }
- elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
- joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
- objs := []reflect.Value{}
- appendToJoins := func(obj reflect.Value, elem reflect.Value) {
- joinValue := reflect.New(rel.JoinTable.ModelType)
- for _, ref := range rel.References {
- if ref.OwnPrimaryKey {
- fv, _ := ref.PrimaryKey.ValueOf(obj)
- ref.ForeignKey.Set(joinValue, fv)
- } else if ref.PrimaryValue != "" {
- ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
- } else {
- fv, _ := ref.PrimaryKey.ValueOf(elem)
- ref.ForeignKey.Set(joinValue, fv)
- }
- }
- joins = reflect.Append(joins, joinValue)
- }
- appendToElems := func(v reflect.Value) {
- if _, zero := rel.Field.ValueOf(v); !zero {
- f := reflect.Indirect(rel.Field.ReflectValueOf(v))
- for i := 0; i < f.Len(); i++ {
- elem := f.Index(i)
- objs = append(objs, v)
- if isPtr {
- elems = reflect.Append(elems, elem)
- } else {
- elems = reflect.Append(elems, elem.Addr())
- }
- }
- }
- }
- switch db.Statement.ReflectValue.Kind() {
- case reflect.Slice, reflect.Array:
- for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
- obj := db.Statement.ReflectValue.Index(i)
- if reflect.Indirect(obj).Kind() == reflect.Struct {
- appendToElems(obj)
- }
- }
- case reflect.Struct:
- appendToElems(db.Statement.ReflectValue)
- }
- if elems.Len() > 0 {
- if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
- saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
- }
- for i := 0; i < elems.Len(); i++ {
- appendToJoins(objs[i], elems.Index(i))
- }
- }
- if joins.Len() > 0 {
- db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
- SkipHooks: db.Statement.SkipHooks,
- DisableNestedTransaction: true,
- }).Create(joins.Interface()).Error)
- }
- }
- }
- }
- func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict {
- if stmt.DB.FullSaveAssociations {
- defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
- for _, dbName := range s.DBNames {
- if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) {
- continue
- }
- if !s.LookUpField(dbName).PrimaryKey {
- defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
- }
- }
- }
- if len(defaultUpdatingColumns) > 0 {
- var columns []clause.Column
- for _, dbName := range s.PrimaryFieldDBNames {
- columns = append(columns, clause.Column{Name: dbName})
- }
- return clause.OnConflict{
- Columns: columns,
- DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
- }
- }
- return clause.OnConflict{DoNothing: true}
- }
- func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
- var (
- selects, omits []string
- onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
- refName = rel.Name + "."
- )
- for name, ok := range selectColumns {
- columnName := ""
- if strings.HasPrefix(name, refName) {
- columnName = strings.TrimPrefix(name, refName)
- } else if strings.HasPrefix(name, clause.Associations) {
- columnName = name
- }
- if columnName != "" {
- if ok {
- selects = append(selects, columnName)
- } else {
- omits = append(omits, columnName)
- }
- }
- }
- tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
- SkipHooks: db.Statement.SkipHooks,
- DisableNestedTransaction: true,
- })
- db.Statement.Settings.Range(func(k, v interface{}) bool {
- tx.Statement.Settings.Store(k, v)
- return true
- })
- if len(selects) > 0 {
- tx = tx.Select(selects)
- }
- if len(omits) > 0 {
- tx = tx.Omit(omits...)
- }
- return db.AddError(tx.Create(values).Error)
- }
|