association.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. package gorm
  2. import (
  3. "fmt"
  4. "reflect"
  5. "strings"
  6. "gorm.io/gorm/clause"
  7. "gorm.io/gorm/schema"
  8. "gorm.io/gorm/utils"
  9. )
  10. // Association Mode contains some helper methods to handle relationship things easily.
  11. type Association struct {
  12. DB *DB
  13. Relationship *schema.Relationship
  14. Error error
  15. }
  16. func (db *DB) Association(column string) *Association {
  17. association := &Association{DB: db}
  18. table := db.Statement.Table
  19. if err := db.Statement.Parse(db.Statement.Model); err == nil {
  20. db.Statement.Table = table
  21. association.Relationship = db.Statement.Schema.Relationships.Relations[column]
  22. if association.Relationship == nil {
  23. association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
  24. }
  25. db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
  26. for db.Statement.ReflectValue.Kind() == reflect.Ptr {
  27. db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
  28. }
  29. } else {
  30. association.Error = err
  31. }
  32. return association
  33. }
  34. func (association *Association) Find(out interface{}, conds ...interface{}) error {
  35. if association.Error == nil {
  36. association.Error = association.buildCondition().Find(out, conds...).Error
  37. }
  38. return association.Error
  39. }
  40. func (association *Association) Append(values ...interface{}) error {
  41. if association.Error == nil {
  42. switch association.Relationship.Type {
  43. case schema.HasOne, schema.BelongsTo:
  44. if len(values) > 0 {
  45. association.Error = association.Replace(values...)
  46. }
  47. default:
  48. association.saveAssociation( /*clear*/ false, values...)
  49. }
  50. }
  51. return association.Error
  52. }
  53. func (association *Association) Replace(values ...interface{}) error {
  54. if association.Error == nil {
  55. // save associations
  56. if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
  57. return association.Error
  58. }
  59. // set old associations's foreign key to null
  60. reflectValue := association.DB.Statement.ReflectValue
  61. rel := association.Relationship
  62. switch rel.Type {
  63. case schema.BelongsTo:
  64. if len(values) == 0 {
  65. updateMap := map[string]interface{}{}
  66. switch reflectValue.Kind() {
  67. case reflect.Slice, reflect.Array:
  68. for i := 0; i < reflectValue.Len(); i++ {
  69. association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
  70. }
  71. case reflect.Struct:
  72. association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
  73. }
  74. for _, ref := range rel.References {
  75. updateMap[ref.ForeignKey.DBName] = nil
  76. }
  77. association.Error = association.DB.UpdateColumns(updateMap).Error
  78. }
  79. case schema.HasOne, schema.HasMany:
  80. var (
  81. primaryFields []*schema.Field
  82. foreignKeys []string
  83. updateMap = map[string]interface{}{}
  84. relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
  85. modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
  86. tx = association.DB.Model(modelValue)
  87. )
  88. if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
  89. if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
  90. tx.Not(clause.IN{Column: column, Values: values})
  91. }
  92. }
  93. for _, ref := range rel.References {
  94. if ref.OwnPrimaryKey {
  95. primaryFields = append(primaryFields, ref.PrimaryKey)
  96. foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
  97. updateMap[ref.ForeignKey.DBName] = nil
  98. } else if ref.PrimaryValue != "" {
  99. tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  100. }
  101. }
  102. if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
  103. column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
  104. association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
  105. }
  106. case schema.Many2Many:
  107. var (
  108. primaryFields, relPrimaryFields []*schema.Field
  109. joinPrimaryKeys, joinRelPrimaryKeys []string
  110. modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
  111. tx = association.DB.Model(modelValue)
  112. )
  113. for _, ref := range rel.References {
  114. if ref.PrimaryValue == "" {
  115. if ref.OwnPrimaryKey {
  116. primaryFields = append(primaryFields, ref.PrimaryKey)
  117. joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
  118. } else {
  119. relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
  120. joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
  121. }
  122. } else {
  123. tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  124. }
  125. }
  126. _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
  127. if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
  128. tx.Where(clause.IN{Column: column, Values: values})
  129. } else {
  130. return ErrPrimaryKeyRequired
  131. }
  132. _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
  133. if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
  134. tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
  135. }
  136. association.Error = tx.Delete(modelValue).Error
  137. }
  138. }
  139. return association.Error
  140. }
  141. func (association *Association) Delete(values ...interface{}) error {
  142. if association.Error == nil {
  143. var (
  144. reflectValue = association.DB.Statement.ReflectValue
  145. rel = association.Relationship
  146. primaryFields []*schema.Field
  147. foreignKeys []string
  148. updateAttrs = map[string]interface{}{}
  149. conds []clause.Expression
  150. )
  151. for _, ref := range rel.References {
  152. if ref.PrimaryValue == "" {
  153. primaryFields = append(primaryFields, ref.PrimaryKey)
  154. foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
  155. updateAttrs[ref.ForeignKey.DBName] = nil
  156. } else {
  157. conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  158. }
  159. }
  160. switch rel.Type {
  161. case schema.BelongsTo:
  162. tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
  163. _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
  164. pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
  165. conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
  166. _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
  167. relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
  168. conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
  169. association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
  170. case schema.HasOne, schema.HasMany:
  171. tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
  172. _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
  173. pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
  174. conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
  175. _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
  176. relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
  177. conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
  178. association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
  179. case schema.Many2Many:
  180. var (
  181. primaryFields, relPrimaryFields []*schema.Field
  182. joinPrimaryKeys, joinRelPrimaryKeys []string
  183. joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
  184. )
  185. for _, ref := range rel.References {
  186. if ref.PrimaryValue == "" {
  187. if ref.OwnPrimaryKey {
  188. primaryFields = append(primaryFields, ref.PrimaryKey)
  189. joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
  190. } else {
  191. relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
  192. joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
  193. }
  194. } else {
  195. conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  196. }
  197. }
  198. _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
  199. pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
  200. conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
  201. _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
  202. relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
  203. conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
  204. association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
  205. }
  206. if association.Error == nil {
  207. // clean up deleted values's foreign key
  208. relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
  209. cleanUpDeletedRelations := func(data reflect.Value) {
  210. if _, zero := rel.Field.ValueOf(data); !zero {
  211. fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
  212. primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
  213. switch fieldValue.Kind() {
  214. case reflect.Slice, reflect.Array:
  215. validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
  216. for i := 0; i < fieldValue.Len(); i++ {
  217. for idx, field := range rel.FieldSchema.PrimaryFields {
  218. primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
  219. }
  220. if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
  221. validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
  222. }
  223. }
  224. association.Error = rel.Field.Set(data, validFieldValues.Interface())
  225. case reflect.Struct:
  226. for idx, field := range rel.FieldSchema.PrimaryFields {
  227. primaryValues[idx], _ = field.ValueOf(fieldValue)
  228. }
  229. if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
  230. if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
  231. break
  232. }
  233. if rel.JoinTable == nil {
  234. for _, ref := range rel.References {
  235. if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
  236. association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
  237. } else {
  238. association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
  239. }
  240. }
  241. }
  242. }
  243. }
  244. }
  245. }
  246. switch reflectValue.Kind() {
  247. case reflect.Slice, reflect.Array:
  248. for i := 0; i < reflectValue.Len(); i++ {
  249. cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
  250. }
  251. case reflect.Struct:
  252. cleanUpDeletedRelations(reflectValue)
  253. }
  254. }
  255. }
  256. return association.Error
  257. }
  258. func (association *Association) Clear() error {
  259. return association.Replace()
  260. }
  261. func (association *Association) Count() (count int64) {
  262. if association.Error == nil {
  263. association.Error = association.buildCondition().Count(&count).Error
  264. }
  265. return
  266. }
  267. type assignBack struct {
  268. Source reflect.Value
  269. Index int
  270. Dest reflect.Value
  271. }
  272. func (association *Association) saveAssociation(clear bool, values ...interface{}) {
  273. var (
  274. reflectValue = association.DB.Statement.ReflectValue
  275. assignBacks []assignBack // assign association values back to arguments after save
  276. )
  277. appendToRelations := func(source, rv reflect.Value, clear bool) {
  278. switch association.Relationship.Type {
  279. case schema.HasOne, schema.BelongsTo:
  280. switch rv.Kind() {
  281. case reflect.Slice, reflect.Array:
  282. if rv.Len() > 0 {
  283. association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
  284. if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
  285. assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
  286. }
  287. }
  288. case reflect.Struct:
  289. association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
  290. if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
  291. assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
  292. }
  293. }
  294. case schema.HasMany, schema.Many2Many:
  295. elemType := association.Relationship.Field.IndirectFieldType.Elem()
  296. fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
  297. if clear {
  298. fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
  299. }
  300. appendToFieldValues := func(ev reflect.Value) {
  301. if ev.Type().AssignableTo(elemType) {
  302. fieldValue = reflect.Append(fieldValue, ev)
  303. } else if ev.Type().Elem().AssignableTo(elemType) {
  304. fieldValue = reflect.Append(fieldValue, ev.Elem())
  305. } else {
  306. association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
  307. }
  308. if elemType.Kind() == reflect.Struct {
  309. assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()})
  310. }
  311. }
  312. switch rv.Kind() {
  313. case reflect.Slice, reflect.Array:
  314. for i := 0; i < rv.Len(); i++ {
  315. appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
  316. }
  317. case reflect.Struct:
  318. appendToFieldValues(rv.Addr())
  319. }
  320. if association.Error == nil {
  321. association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
  322. }
  323. }
  324. }
  325. selectedSaveColumns := []string{association.Relationship.Name}
  326. omitColumns := []string{}
  327. selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
  328. for name, ok := range selectColumns {
  329. columnName := ""
  330. if strings.HasPrefix(name, association.Relationship.Name) {
  331. if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" {
  332. columnName = name
  333. }
  334. } else if strings.HasPrefix(name, clause.Associations) {
  335. columnName = name
  336. }
  337. if columnName != "" {
  338. if ok {
  339. selectedSaveColumns = append(selectedSaveColumns, columnName)
  340. } else {
  341. omitColumns = append(omitColumns, columnName)
  342. }
  343. }
  344. }
  345. for _, ref := range association.Relationship.References {
  346. if !ref.OwnPrimaryKey {
  347. selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
  348. }
  349. }
  350. associationDB := association.DB.Session(&Session{}).Model(nil)
  351. if !association.DB.FullSaveAssociations {
  352. associationDB.Select(selectedSaveColumns)
  353. }
  354. if len(omitColumns) > 0 {
  355. associationDB.Omit(omitColumns...)
  356. }
  357. associationDB = associationDB.Session(&Session{})
  358. switch reflectValue.Kind() {
  359. case reflect.Slice, reflect.Array:
  360. if len(values) != reflectValue.Len() {
  361. // clear old data
  362. if clear && len(values) == 0 {
  363. for i := 0; i < reflectValue.Len(); i++ {
  364. if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
  365. association.Error = err
  366. break
  367. }
  368. if association.Relationship.JoinTable == nil {
  369. for _, ref := range association.Relationship.References {
  370. if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
  371. if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
  372. association.Error = err
  373. break
  374. }
  375. }
  376. }
  377. }
  378. }
  379. break
  380. }
  381. association.Error = ErrInvalidValueOfLength
  382. return
  383. }
  384. for i := 0; i < reflectValue.Len(); i++ {
  385. appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
  386. // TODO support save slice data, sql with case?
  387. association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
  388. }
  389. case reflect.Struct:
  390. // clear old data
  391. if clear && len(values) == 0 {
  392. association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
  393. if association.Relationship.JoinTable == nil && association.Error == nil {
  394. for _, ref := range association.Relationship.References {
  395. if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
  396. association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
  397. }
  398. }
  399. }
  400. }
  401. for idx, value := range values {
  402. rv := reflect.Indirect(reflect.ValueOf(value))
  403. appendToRelations(reflectValue, rv, clear && idx == 0)
  404. }
  405. if len(values) > 0 {
  406. association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
  407. }
  408. }
  409. for _, assignBack := range assignBacks {
  410. fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
  411. if assignBack.Index > 0 {
  412. reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
  413. } else {
  414. reflect.Indirect(assignBack.Dest).Set(fieldValue)
  415. }
  416. }
  417. }
  418. func (association *Association) buildCondition() *DB {
  419. var (
  420. queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
  421. modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
  422. tx = association.DB.Model(modelValue)
  423. )
  424. if association.Relationship.JoinTable != nil {
  425. if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
  426. joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
  427. for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
  428. joinStmt.AddClause(queryClause)
  429. }
  430. joinStmt.Build("WHERE")
  431. tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
  432. }
  433. tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
  434. Table: clause.Table{Name: association.Relationship.JoinTable.Table},
  435. ON: clause.Where{Exprs: queryConds},
  436. }}})
  437. } else {
  438. tx.Clauses(clause.Where{Exprs: queryConds})
  439. }
  440. return tx
  441. }