preload.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. package callbacks
  2. import (
  3. "fmt"
  4. "reflect"
  5. "gorm.io/gorm"
  6. "gorm.io/gorm/clause"
  7. "gorm.io/gorm/schema"
  8. "gorm.io/gorm/utils"
  9. )
  10. func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
  11. var (
  12. reflectValue = db.Statement.ReflectValue
  13. tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
  14. relForeignKeys []string
  15. relForeignFields []*schema.Field
  16. foreignFields []*schema.Field
  17. foreignValues [][]interface{}
  18. identityMap = map[string][]reflect.Value{}
  19. inlineConds []interface{}
  20. )
  21. db.Statement.Settings.Range(func(k, v interface{}) bool {
  22. tx.Statement.Settings.Store(k, v)
  23. return true
  24. })
  25. if rel.JoinTable != nil {
  26. var (
  27. joinForeignFields = make([]*schema.Field, 0, len(rel.References))
  28. joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
  29. joinForeignKeys = make([]string, 0, len(rel.References))
  30. )
  31. for _, ref := range rel.References {
  32. if ref.OwnPrimaryKey {
  33. joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
  34. joinForeignFields = append(joinForeignFields, ref.ForeignKey)
  35. foreignFields = append(foreignFields, ref.PrimaryKey)
  36. } else if ref.PrimaryValue != "" {
  37. tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  38. } else {
  39. joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
  40. relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
  41. relForeignFields = append(relForeignFields, ref.PrimaryKey)
  42. }
  43. }
  44. joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
  45. if len(joinForeignValues) == 0 {
  46. return
  47. }
  48. joinResults := rel.JoinTable.MakeSlice().Elem()
  49. column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
  50. db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
  51. // convert join identity map to relation identity map
  52. fieldValues := make([]interface{}, len(joinForeignFields))
  53. joinFieldValues := make([]interface{}, len(joinRelForeignFields))
  54. for i := 0; i < joinResults.Len(); i++ {
  55. for idx, field := range joinForeignFields {
  56. fieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
  57. }
  58. for idx, field := range joinRelForeignFields {
  59. joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
  60. }
  61. if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
  62. joinKey := utils.ToStringKey(joinFieldValues...)
  63. identityMap[joinKey] = append(identityMap[joinKey], results...)
  64. }
  65. }
  66. _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
  67. } else {
  68. for _, ref := range rel.References {
  69. if ref.OwnPrimaryKey {
  70. relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
  71. relForeignFields = append(relForeignFields, ref.ForeignKey)
  72. foreignFields = append(foreignFields, ref.PrimaryKey)
  73. } else if ref.PrimaryValue != "" {
  74. tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  75. } else {
  76. relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
  77. relForeignFields = append(relForeignFields, ref.PrimaryKey)
  78. foreignFields = append(foreignFields, ref.ForeignKey)
  79. }
  80. }
  81. identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
  82. if len(foreignValues) == 0 {
  83. return
  84. }
  85. }
  86. // nested preload
  87. for p, pvs := range preloads {
  88. tx = tx.Preload(p, pvs...)
  89. }
  90. reflectResults := rel.FieldSchema.MakeSlice().Elem()
  91. column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
  92. if len(values) != 0 {
  93. for _, cond := range conds {
  94. if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
  95. tx = fc(tx)
  96. } else {
  97. inlineConds = append(inlineConds, cond)
  98. }
  99. }
  100. db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
  101. }
  102. fieldValues := make([]interface{}, len(relForeignFields))
  103. // clean up old values before preloading
  104. switch reflectValue.Kind() {
  105. case reflect.Struct:
  106. switch rel.Type {
  107. case schema.HasMany, schema.Many2Many:
  108. rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
  109. default:
  110. rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
  111. }
  112. case reflect.Slice, reflect.Array:
  113. for i := 0; i < reflectValue.Len(); i++ {
  114. switch rel.Type {
  115. case schema.HasMany, schema.Many2Many:
  116. rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
  117. default:
  118. rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
  119. }
  120. }
  121. }
  122. for i := 0; i < reflectResults.Len(); i++ {
  123. elem := reflectResults.Index(i)
  124. for idx, field := range relForeignFields {
  125. fieldValues[idx], _ = field.ValueOf(elem)
  126. }
  127. if datas, ok := identityMap[utils.ToStringKey(fieldValues...)]; ok {
  128. for _, data := range datas {
  129. reflectFieldValue := rel.Field.ReflectValueOf(data)
  130. if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
  131. reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
  132. }
  133. reflectFieldValue = reflect.Indirect(reflectFieldValue)
  134. switch reflectFieldValue.Kind() {
  135. case reflect.Struct:
  136. rel.Field.Set(data, reflectResults.Index(i).Interface())
  137. case reflect.Slice, reflect.Array:
  138. if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
  139. rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
  140. } else {
  141. rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
  142. }
  143. }
  144. }
  145. } else {
  146. db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()))
  147. }
  148. }
  149. }