association.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. "strings"
  7. "gorm.io/gorm/clause"
  8. "gorm.io/gorm/schema"
  9. "gorm.io/gorm/utils"
  10. )
// Association is the handle used in association mode: it bundles a DB, one
// relationship of the current model and an error state, and its methods
// (Find, Append, Replace, Delete, Clear, Count) operate on that relationship.
type Association struct {
	DB           *DB                  // database handle the association was created from
	Relationship *schema.Relationship // relationship looked up by name; nil when the lookup failed
	Error        error                // error state; most methods skip their work when it is already non-nil
}
  17. func (db *DB) Association(column string) *Association {
  18. association := &Association{DB: db}
  19. table := db.Statement.Table
  20. if err := db.Statement.Parse(db.Statement.Model); err == nil {
  21. db.Statement.Table = table
  22. association.Relationship = db.Statement.Schema.Relationships.Relations[column]
  23. if association.Relationship == nil {
  24. association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
  25. }
  26. db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
  27. for db.Statement.ReflectValue.Kind() == reflect.Ptr {
  28. db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
  29. }
  30. } else {
  31. association.Error = err
  32. }
  33. return association
  34. }
  35. func (association *Association) Find(out interface{}, conds ...interface{}) error {
  36. if association.Error == nil {
  37. association.Error = association.buildCondition().Find(out, conds...).Error
  38. }
  39. return association.Error
  40. }
  41. func (association *Association) Append(values ...interface{}) error {
  42. if association.Error == nil {
  43. switch association.Relationship.Type {
  44. case schema.HasOne, schema.BelongsTo:
  45. if len(values) > 0 {
  46. association.Error = association.Replace(values...)
  47. }
  48. default:
  49. association.saveAssociation( /*clear*/ false, values...)
  50. }
  51. }
  52. return association.Error
  53. }
// Replace saves values as the new content of the association and then
// detaches whatever was associated before: for belongs-to / has-one /
// has-many it issues an update setting foreign-key columns to nil, for
// many2many it issues a delete on the join table.
//
// It is a no-op when the association already carries an error. The returned
// error is association.Error, except for the many2many primary-key check,
// which returns ErrPrimaryKeyRequired directly without storing it.
func (association *Association) Replace(values ...interface{}) error {
	if association.Error == nil {
		// Save the new associations first (clear=true); stop if that failed.
		if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
			return association.Error
		}
		// Then set the old associations' foreign keys to null.
		reflectValue := association.DB.Statement.ReflectValue
		rel := association.Relationship
		switch rel.Type {
		case schema.BelongsTo:
			// Only an empty Replace (i.e. a clear) needs extra work here.
			if len(values) == 0 {
				updateMap := map[string]interface{}{}
				// Reset the relation field on the in-memory model(s) to its zero value.
				switch reflectValue.Kind() {
				case reflect.Slice, reflect.Array:
					for i := 0; i < reflectValue.Len(); i++ {
						association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
					}
				case reflect.Struct:
					association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
				}
				for _, ref := range rel.References {
					updateMap[ref.ForeignKey.DBName] = nil
				}
				// NOTE(review): this assignment overwrites any error recorded by
				// the Field.Set calls above.
				association.Error = association.DB.UpdateColumns(updateMap).Error
			}
		case schema.HasOne, schema.HasMany:
			var (
				primaryFields []*schema.Field
				foreignKeys   []string
				updateMap     = map[string]interface{}{}
				relValues     = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
				modelValue    = reflect.New(rel.FieldSchema.ModelType).Interface()
				tx            = association.DB.Model(modelValue)
			)
			// Exclude the records currently held in the relation field (the ones
			// just saved) by their primary keys. The result of tx.Not is
			// discarded, so this relies on the call modifying tx in place.
			if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
				if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
					tx.Not(clause.IN{Column: column, Values: values})
				}
			}
			// References owned by this model supply the foreign keys to null
			// out; references with a fixed PrimaryValue become equality filters.
			for _, ref := range rel.References {
				if ref.OwnPrimaryKey {
					primaryFields = append(primaryFields, ref.PrimaryKey)
					foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
					updateMap[ref.ForeignKey.DBName] = nil
				} else if ref.PrimaryValue != "" {
					tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
				}
			}
			// Restrict the update to rows whose foreign keys match the owner(s);
			// skipped entirely when no owner key values were found.
			if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
				column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
				association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
			}
		case schema.Many2Many:
			var (
				primaryFields, relPrimaryFields     []*schema.Field
				joinPrimaryKeys, joinRelPrimaryKeys []string
				modelValue                          = reflect.New(rel.JoinTable.ModelType).Interface()
				tx                                  = association.DB.Model(modelValue)
			)
			// Split the references into the owner side and the related side of
			// the join table; references with a fixed PrimaryValue become
			// equality conditions.
			for _, ref := range rel.References {
				if ref.PrimaryValue == "" {
					if ref.OwnPrimaryKey {
						primaryFields = append(primaryFields, ref.PrimaryKey)
						joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
					} else {
						relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
						joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
					}
				} else {
					tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
				}
			}
			// The owner's key values are mandatory: without them the delete
			// below would not be limited to this owner.
			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
			if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
				tx.Where(clause.IN{Column: column, Values: values})
			} else {
				// Returned directly; association.Error is left untouched.
				return ErrPrimaryKeyRequired
			}
			// Keep the join rows that point at the replacement values.
			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
			if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
				tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
			}
			association.Error = tx.Delete(modelValue).Error
		}
	}
	return association.Error
}
// Delete removes the relationship between the model(s) and the given values.
// For belongs-to / has-one / has-many it issues an update that sets the
// foreign-key columns to nil; for many2many it issues a delete on the join
// table. When the database call succeeds, the in-memory relation fields of
// the model(s) are cleaned up to match.
//
// It is a no-op when the association already carries an error; the resulting
// error state is stored on the association and returned.
func (association *Association) Delete(values ...interface{}) error {
	if association.Error == nil {
		var (
			reflectValue  = association.DB.Statement.ReflectValue
			rel           = association.Relationship
			primaryFields []*schema.Field
			foreignKeys   []string
			updateAttrs   = map[string]interface{}{}
			conds         []clause.Expression
		)
		// References without a fixed PrimaryValue supply the key/foreign-key
		// pairs (and the columns to set to nil); the others become equality
		// conditions.
		for _, ref := range rel.References {
			if ref.PrimaryValue == "" {
				primaryFields = append(primaryFields, ref.PrimaryKey)
				foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
				updateAttrs[ref.ForeignKey.DBName] = nil
			} else {
				conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
			}
		}
		switch rel.Type {
		case schema.BelongsTo:
			// The foreign key lives on the owning model's table (rel.Schema):
			// match the owner rows by primary key and by a foreign key pointing
			// at one of values.
			tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
			pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
			conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
			relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
			association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
		case schema.HasOne, schema.HasMany:
			// The foreign key lives on the related table (rel.FieldSchema):
			// match related rows whose foreign key points at the owner and whose
			// primary key is one of values.
			tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
			pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
			conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
			relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
			association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
		case schema.Many2Many:
			// Note: primaryFields declared here shadows the outer variable of
			// the same name for the rest of this case.
			var (
				primaryFields, relPrimaryFields     []*schema.Field
				joinPrimaryKeys, joinRelPrimaryKeys []string
				joinValue                           = reflect.New(rel.JoinTable.ModelType).Interface()
			)
			// Split the references into owner side and related side of the join
			// table. References with a fixed PrimaryValue were already added to
			// conds by the loop above and are appended a second time here.
			for _, ref := range rel.References {
				if ref.PrimaryValue == "" {
					if ref.OwnPrimaryKey {
						primaryFields = append(primaryFields, ref.PrimaryKey)
						joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
					} else {
						relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
						joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
					}
				} else {
					conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
				}
			}
			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
			pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
			conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
			relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
			association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
		}
		if association.Error == nil {
			// Clean up the in-memory models: drop the deleted values from the
			// relation field and reset the matching foreign keys.
			relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
			// cleanUpDeletedRelations updates the relation field of one model value.
			cleanUpDeletedRelations := func(data reflect.Value) {
				if _, zero := rel.Field.ValueOf(data); !zero {
					fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
					primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
					switch fieldValue.Kind() {
					case reflect.Slice, reflect.Array:
						// Rebuild the collection keeping only elements whose primary
						// key is not among the deleted values.
						validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
						for i := 0; i < fieldValue.Len(); i++ {
							for idx, field := range rel.FieldSchema.PrimaryFields {
								primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
							}
							if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
								validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
							}
						}
						association.Error = rel.Field.Set(data, validFieldValues.Interface())
					case reflect.Struct:
						for idx, field := range rel.FieldSchema.PrimaryFields {
							primaryValues[idx], _ = field.ValueOf(fieldValue)
						}
						// The single related value was deleted: zero the relation field.
						if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
							if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
								// Leaves the enclosing switch.
								break
							}
							// Without a join table the foreign key is a plain field:
							// zero it on the related value or on the model, depending
							// on which side the reference says holds it.
							if rel.JoinTable == nil {
								for _, ref := range rel.References {
									if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
										association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
									} else {
										association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
									}
								}
							}
						}
					}
				}
			}
			// Apply the cleanup to every model in a slice/array, or to the single struct.
			switch reflectValue.Kind() {
			case reflect.Slice, reflect.Array:
				for i := 0; i < reflectValue.Len(); i++ {
					cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
				}
			case reflect.Struct:
				cleanUpDeletedRelations(reflectValue)
			}
		}
	}
	return association.Error
}
// Clear removes every value from the association. It is Replace called with
// no values and returns whatever Replace returns.
func (association *Association) Clear() error {
	return association.Replace()
}
  262. func (association *Association) Count() (count int64) {
  263. if association.Error == nil {
  264. association.Error = association.buildCondition().Count(&count).Error
  265. }
  266. return
  267. }
// assignBack records one value that must be copied from a model's relation
// field back into the caller-supplied argument after saveAssociation has
// saved the model.
type assignBack struct {
	Source reflect.Value // model value whose relation field is read
	Index  int           // 1-based position inside a collection relation field; 0 means the whole field value is used
	Dest   reflect.Value // caller-supplied value that receives the copy
}
  273. func (association *Association) saveAssociation(clear bool, values ...interface{}) {
  274. var (
  275. reflectValue = association.DB.Statement.ReflectValue
  276. assignBacks []assignBack // assign association values back to arguments after save
  277. )
  278. appendToRelations := func(source, rv reflect.Value, clear bool) {
  279. switch association.Relationship.Type {
  280. case schema.HasOne, schema.BelongsTo:
  281. switch rv.Kind() {
  282. case reflect.Slice, reflect.Array:
  283. if rv.Len() > 0 {
  284. association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
  285. if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
  286. assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
  287. }
  288. }
  289. case reflect.Struct:
  290. association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
  291. if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
  292. assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
  293. }
  294. }
  295. case schema.HasMany, schema.Many2Many:
  296. elemType := association.Relationship.Field.IndirectFieldType.Elem()
  297. fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
  298. if clear {
  299. fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
  300. }
  301. appendToFieldValues := func(ev reflect.Value) {
  302. if ev.Type().AssignableTo(elemType) {
  303. fieldValue = reflect.Append(fieldValue, ev)
  304. } else if ev.Type().Elem().AssignableTo(elemType) {
  305. fieldValue = reflect.Append(fieldValue, ev.Elem())
  306. } else {
  307. association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name)
  308. }
  309. if elemType.Kind() == reflect.Struct {
  310. assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()})
  311. }
  312. }
  313. switch rv.Kind() {
  314. case reflect.Slice, reflect.Array:
  315. for i := 0; i < rv.Len(); i++ {
  316. appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
  317. }
  318. case reflect.Struct:
  319. appendToFieldValues(rv.Addr())
  320. }
  321. if association.Error == nil {
  322. association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
  323. }
  324. }
  325. }
  326. selectedSaveColumns := []string{association.Relationship.Name}
  327. omitColumns := []string{}
  328. selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
  329. for name, ok := range selectColumns {
  330. columnName := ""
  331. if strings.HasPrefix(name, association.Relationship.Name) {
  332. columnName = strings.TrimPrefix(name, association.Relationship.Name)
  333. } else if strings.HasPrefix(name, clause.Associations) {
  334. columnName = name
  335. }
  336. if columnName != "" {
  337. if ok {
  338. selectedSaveColumns = append(selectedSaveColumns, columnName)
  339. } else {
  340. omitColumns = append(omitColumns, columnName)
  341. }
  342. }
  343. }
  344. for _, ref := range association.Relationship.References {
  345. if !ref.OwnPrimaryKey {
  346. selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
  347. }
  348. }
  349. associationDB := association.DB.Session(&Session{}).Model(nil).Select(selectedSaveColumns).Session(&Session{})
  350. switch reflectValue.Kind() {
  351. case reflect.Slice, reflect.Array:
  352. if len(values) != reflectValue.Len() {
  353. // clear old data
  354. if clear && len(values) == 0 {
  355. for i := 0; i < reflectValue.Len(); i++ {
  356. if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
  357. association.Error = err
  358. break
  359. }
  360. if association.Relationship.JoinTable == nil {
  361. for _, ref := range association.Relationship.References {
  362. if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
  363. if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
  364. association.Error = err
  365. break
  366. }
  367. }
  368. }
  369. }
  370. }
  371. break
  372. }
  373. association.Error = errors.New("invalid association values, length doesn't match")
  374. return
  375. }
  376. for i := 0; i < reflectValue.Len(); i++ {
  377. appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
  378. // TODO support save slice data, sql with case?
  379. association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
  380. }
  381. case reflect.Struct:
  382. // clear old data
  383. if clear && len(values) == 0 {
  384. association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
  385. if association.Relationship.JoinTable == nil && association.Error == nil {
  386. for _, ref := range association.Relationship.References {
  387. if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
  388. association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
  389. }
  390. }
  391. }
  392. }
  393. for idx, value := range values {
  394. rv := reflect.Indirect(reflect.ValueOf(value))
  395. appendToRelations(reflectValue, rv, clear && idx == 0)
  396. }
  397. if len(values) > 0 {
  398. association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
  399. }
  400. }
  401. for _, assignBack := range assignBacks {
  402. fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
  403. if assignBack.Index > 0 {
  404. reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
  405. } else {
  406. reflect.Indirect(assignBack.Dest).Set(fieldValue)
  407. }
  408. }
  409. }
  410. func (association *Association) buildCondition() *DB {
  411. var (
  412. queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
  413. modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
  414. tx = association.DB.Model(modelValue)
  415. )
  416. if association.Relationship.JoinTable != nil {
  417. if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
  418. joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
  419. for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
  420. joinStmt.AddClause(queryClause)
  421. }
  422. joinStmt.Build("WHERE")
  423. tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
  424. }
  425. tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
  426. Table: clause.Table{Name: association.Relationship.JoinTable.Table},
  427. ON: clause.Where{Exprs: queryConds},
  428. }}})
  429. } else {
  430. tx.Clauses(clause.Where{Exprs: queryConds})
  431. }
  432. return tx
  433. }