relationship.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. package schema
  2. import (
  3. "fmt"
  4. "reflect"
  5. "regexp"
  6. "strings"
  7. "github.com/jinzhu/inflection"
  8. "gorm.io/gorm/clause"
  9. )
  10. // RelationshipType relationship type
  11. type RelationshipType string
  12. const (
  13. HasOne RelationshipType = "has_one" // HasOneRel has one relationship
  14. HasMany RelationshipType = "has_many" // HasManyRel has many relationship
  15. BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship
  16. Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship
  17. )
  18. type Relationships struct {
  19. HasOne []*Relationship
  20. BelongsTo []*Relationship
  21. HasMany []*Relationship
  22. Many2Many []*Relationship
  23. Relations map[string]*Relationship
  24. }
  25. type Relationship struct {
  26. Name string
  27. Type RelationshipType
  28. Field *Field
  29. Polymorphic *Polymorphic
  30. References []*Reference
  31. Schema *Schema
  32. FieldSchema *Schema
  33. JoinTable *Schema
  34. foreignKeys, primaryKeys []string
  35. }
  36. type Polymorphic struct {
  37. PolymorphicID *Field
  38. PolymorphicType *Field
  39. Value string
  40. }
  41. type Reference struct {
  42. PrimaryKey *Field
  43. PrimaryValue string
  44. ForeignKey *Field
  45. OwnPrimaryKey bool
  46. }
  47. func (schema *Schema) parseRelation(field *Field) *Relationship {
  48. var (
  49. err error
  50. fieldValue = reflect.New(field.IndirectFieldType).Interface()
  51. relation = &Relationship{
  52. Name: field.Name,
  53. Field: field,
  54. Schema: schema,
  55. foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]),
  56. primaryKeys: toColumns(field.TagSettings["REFERENCES"]),
  57. }
  58. )
  59. cacheStore := schema.cacheStore
  60. if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
  61. schema.err = err
  62. return nil
  63. }
  64. if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
  65. schema.buildPolymorphicRelation(relation, field, polymorphic)
  66. } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
  67. schema.buildMany2ManyRelation(relation, field, many2many)
  68. } else {
  69. switch field.IndirectFieldType.Kind() {
  70. case reflect.Struct:
  71. schema.guessRelation(relation, field, guessBelongs)
  72. case reflect.Slice:
  73. schema.guessRelation(relation, field, guessHas)
  74. default:
  75. schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
  76. }
  77. }
  78. if relation.Type == "has" {
  79. // don't add relations to embeded schema, which might be shared
  80. if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
  81. relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
  82. }
  83. switch field.IndirectFieldType.Kind() {
  84. case reflect.Struct:
  85. relation.Type = HasOne
  86. case reflect.Slice:
  87. relation.Type = HasMany
  88. }
  89. }
  90. if schema.err == nil {
  91. schema.Relationships.Relations[relation.Name] = relation
  92. switch relation.Type {
  93. case HasOne:
  94. schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
  95. case HasMany:
  96. schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation)
  97. case BelongsTo:
  98. schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation)
  99. case Many2Many:
  100. schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation)
  101. }
  102. }
  103. return relation
  104. }
  105. // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
  106. // type User struct {
  107. // Toys []Toy `gorm:"polymorphic:Owner;"`
  108. // }
  109. // type Pet struct {
  110. // Toy Toy `gorm:"polymorphic:Owner;"`
  111. // }
  112. // type Toy struct {
  113. // OwnerID int
  114. // OwnerType string
  115. // }
  116. func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
  117. relation.Polymorphic = &Polymorphic{
  118. Value: schema.Table,
  119. PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
  120. PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
  121. }
  122. if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
  123. relation.Polymorphic.Value = strings.TrimSpace(value)
  124. }
  125. if relation.Polymorphic.PolymorphicType == nil {
  126. schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
  127. }
  128. if relation.Polymorphic.PolymorphicID == nil {
  129. schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
  130. }
  131. if schema.err == nil {
  132. relation.References = append(relation.References, &Reference{
  133. PrimaryValue: relation.Polymorphic.Value,
  134. ForeignKey: relation.Polymorphic.PolymorphicType,
  135. })
  136. primaryKeyField := schema.PrioritizedPrimaryField
  137. if len(relation.foreignKeys) > 0 {
  138. if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
  139. schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
  140. }
  141. }
  142. // use same data type for foreign keys
  143. relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
  144. relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
  145. if relation.Polymorphic.PolymorphicID.Size == 0 {
  146. relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
  147. }
  148. relation.References = append(relation.References, &Reference{
  149. PrimaryKey: primaryKeyField,
  150. ForeignKey: relation.Polymorphic.PolymorphicID,
  151. OwnPrimaryKey: true,
  152. })
  153. }
  154. relation.Type = "has"
  155. }
  156. func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
  157. relation.Type = Many2Many
  158. var (
  159. err error
  160. joinTableFields []reflect.StructField
  161. fieldsMap = map[string]*Field{}
  162. ownFieldsMap = map[string]bool{} // fix self join many2many
  163. joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
  164. joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
  165. )
  166. ownForeignFields := schema.PrimaryFields
  167. refForeignFields := relation.FieldSchema.PrimaryFields
  168. if len(relation.foreignKeys) > 0 {
  169. ownForeignFields = []*Field{}
  170. for _, foreignKey := range relation.foreignKeys {
  171. if field := schema.LookUpField(foreignKey); field != nil {
  172. ownForeignFields = append(ownForeignFields, field)
  173. } else {
  174. schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey)
  175. return
  176. }
  177. }
  178. }
  179. if len(relation.primaryKeys) > 0 {
  180. refForeignFields = []*Field{}
  181. for _, foreignKey := range relation.primaryKeys {
  182. if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
  183. refForeignFields = append(refForeignFields, field)
  184. } else {
  185. schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey)
  186. return
  187. }
  188. }
  189. }
  190. for idx, ownField := range ownForeignFields {
  191. joinFieldName := strings.Title(schema.Name) + ownField.Name
  192. if len(joinForeignKeys) > idx {
  193. joinFieldName = strings.Title(joinForeignKeys[idx])
  194. }
  195. ownFieldsMap[joinFieldName] = true
  196. fieldsMap[joinFieldName] = ownField
  197. joinTableFields = append(joinTableFields, reflect.StructField{
  198. Name: joinFieldName,
  199. PkgPath: ownField.StructField.PkgPath,
  200. Type: ownField.StructField.Type,
  201. Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
  202. })
  203. }
  204. for idx, relField := range refForeignFields {
  205. joinFieldName := relation.FieldSchema.Name + relField.Name
  206. if len(joinReferences) > idx {
  207. joinFieldName = strings.Title(joinReferences[idx])
  208. }
  209. if _, ok := ownFieldsMap[joinFieldName]; ok {
  210. if field.Name != relation.FieldSchema.Name {
  211. joinFieldName = inflection.Singular(field.Name) + relField.Name
  212. } else {
  213. joinFieldName += "Reference"
  214. }
  215. }
  216. fieldsMap[joinFieldName] = relField
  217. joinTableFields = append(joinTableFields, reflect.StructField{
  218. Name: joinFieldName,
  219. PkgPath: relField.StructField.PkgPath,
  220. Type: relField.StructField.Type,
  221. Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
  222. })
  223. }
  224. joinTableFields = append(joinTableFields, reflect.StructField{
  225. Name: strings.Title(schema.Name) + field.Name,
  226. Type: schema.ModelType,
  227. Tag: `gorm:"-"`,
  228. })
  229. if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
  230. schema.err = err
  231. }
  232. relation.JoinTable.Name = many2many
  233. relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
  234. relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))
  235. relName := relation.Schema.Name
  236. relRefName := relation.FieldSchema.Name
  237. if relName == relRefName {
  238. relRefName = relation.Field.Name
  239. }
  240. if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok {
  241. relation.JoinTable.Relationships.Relations[relName] = &Relationship{
  242. Name: relName,
  243. Type: BelongsTo,
  244. Schema: relation.JoinTable,
  245. FieldSchema: relation.Schema,
  246. }
  247. } else {
  248. relation.JoinTable.Relationships.Relations[relName].References = []*Reference{}
  249. }
  250. if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok {
  251. relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{
  252. Name: relRefName,
  253. Type: BelongsTo,
  254. Schema: relation.JoinTable,
  255. FieldSchema: relation.FieldSchema,
  256. }
  257. } else {
  258. relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{}
  259. }
  260. // build references
  261. for _, f := range relation.JoinTable.Fields {
  262. if f.Creatable || f.Readable || f.Updatable {
  263. // use same data type for foreign keys
  264. f.DataType = fieldsMap[f.Name].DataType
  265. f.GORMDataType = fieldsMap[f.Name].GORMDataType
  266. if f.Size == 0 {
  267. f.Size = fieldsMap[f.Name].Size
  268. }
  269. relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
  270. ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
  271. if ownPriamryField {
  272. joinRel := relation.JoinTable.Relationships.Relations[relName]
  273. joinRel.Field = relation.Field
  274. joinRel.References = append(joinRel.References, &Reference{
  275. PrimaryKey: fieldsMap[f.Name],
  276. ForeignKey: f,
  277. })
  278. } else {
  279. joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
  280. if joinRefRel.Field == nil {
  281. joinRefRel.Field = relation.Field
  282. }
  283. joinRefRel.References = append(joinRefRel.References, &Reference{
  284. PrimaryKey: fieldsMap[f.Name],
  285. ForeignKey: f,
  286. })
  287. }
  288. relation.References = append(relation.References, &Reference{
  289. PrimaryKey: fieldsMap[f.Name],
  290. ForeignKey: f,
  291. OwnPrimaryKey: ownPriamryField,
  292. })
  293. }
  294. }
  295. }
  296. type guessLevel int
  297. const (
  298. guessBelongs guessLevel = iota
  299. guessEmbeddedBelongs
  300. guessHas
  301. guessEmbeddedHas
  302. )
  303. func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
  304. var (
  305. primaryFields, foreignFields []*Field
  306. primarySchema, foreignSchema = schema, relation.FieldSchema
  307. )
  308. reguessOrErr := func() {
  309. switch gl {
  310. case guessBelongs:
  311. schema.guessRelation(relation, field, guessEmbeddedBelongs)
  312. case guessEmbeddedBelongs:
  313. schema.guessRelation(relation, field, guessHas)
  314. case guessHas:
  315. schema.guessRelation(relation, field, guessEmbeddedHas)
  316. // case guessEmbeddedHas:
  317. default:
  318. schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a valid foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name)
  319. }
  320. }
  321. switch gl {
  322. case guessBelongs:
  323. primarySchema, foreignSchema = relation.FieldSchema, schema
  324. case guessEmbeddedBelongs:
  325. if field.OwnerSchema != nil {
  326. primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
  327. } else {
  328. reguessOrErr()
  329. return
  330. }
  331. case guessHas:
  332. case guessEmbeddedHas:
  333. if field.OwnerSchema != nil {
  334. primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
  335. } else {
  336. reguessOrErr()
  337. return
  338. }
  339. }
  340. if len(relation.foreignKeys) > 0 {
  341. for _, foreignKey := range relation.foreignKeys {
  342. if f := foreignSchema.LookUpField(foreignKey); f != nil {
  343. foreignFields = append(foreignFields, f)
  344. } else {
  345. reguessOrErr()
  346. return
  347. }
  348. }
  349. } else {
  350. var primaryFields []*Field
  351. if len(relation.primaryKeys) > 0 {
  352. for _, primaryKey := range relation.primaryKeys {
  353. if f := primarySchema.LookUpField(primaryKey); f != nil {
  354. primaryFields = append(primaryFields, f)
  355. }
  356. }
  357. } else {
  358. primaryFields = primarySchema.PrimaryFields
  359. }
  360. for _, primaryField := range primaryFields {
  361. lookUpName := primarySchema.Name + primaryField.Name
  362. if gl == guessBelongs {
  363. lookUpName = field.Name + primaryField.Name
  364. }
  365. lookUpNames := []string{lookUpName}
  366. if len(primaryFields) == 1 {
  367. lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")
  368. lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"Id")
  369. lookUpNames = append(lookUpNames, schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
  370. }
  371. for _, name := range lookUpNames {
  372. if f := foreignSchema.LookUpField(name); f != nil {
  373. foreignFields = append(foreignFields, f)
  374. primaryFields = append(primaryFields, primaryField)
  375. break
  376. }
  377. }
  378. }
  379. }
  380. if len(foreignFields) == 0 {
  381. reguessOrErr()
  382. return
  383. } else if len(relation.primaryKeys) > 0 {
  384. for idx, primaryKey := range relation.primaryKeys {
  385. if f := primarySchema.LookUpField(primaryKey); f != nil {
  386. if len(primaryFields) < idx+1 {
  387. primaryFields = append(primaryFields, f)
  388. } else if f != primaryFields[idx] {
  389. reguessOrErr()
  390. return
  391. }
  392. } else {
  393. reguessOrErr()
  394. return
  395. }
  396. }
  397. } else if len(primaryFields) == 0 {
  398. if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
  399. primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
  400. } else if len(primarySchema.PrimaryFields) == len(foreignFields) {
  401. primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
  402. } else {
  403. reguessOrErr()
  404. return
  405. }
  406. }
  407. // build references
  408. for idx, foreignField := range foreignFields {
  409. // use same data type for foreign keys
  410. foreignField.DataType = primaryFields[idx].DataType
  411. foreignField.GORMDataType = primaryFields[idx].GORMDataType
  412. if foreignField.Size == 0 {
  413. foreignField.Size = primaryFields[idx].Size
  414. }
  415. relation.References = append(relation.References, &Reference{
  416. PrimaryKey: primaryFields[idx],
  417. ForeignKey: foreignField,
  418. OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
  419. })
  420. }
  421. if gl == guessHas || gl == guessEmbeddedHas {
  422. relation.Type = "has"
  423. } else {
  424. relation.Type = BelongsTo
  425. }
  426. }
  427. type Constraint struct {
  428. Name string
  429. Field *Field
  430. Schema *Schema
  431. ForeignKeys []*Field
  432. ReferenceSchema *Schema
  433. References []*Field
  434. OnDelete string
  435. OnUpdate string
  436. }
  437. func (rel *Relationship) ParseConstraint() *Constraint {
  438. str := rel.Field.TagSettings["CONSTRAINT"]
  439. if str == "-" {
  440. return nil
  441. }
  442. var (
  443. name string
  444. idx = strings.Index(str, ",")
  445. settings = ParseTagSetting(str, ",")
  446. )
  447. if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) {
  448. name = str[0:idx]
  449. } else {
  450. name = rel.Schema.namer.RelationshipFKName(*rel)
  451. }
  452. constraint := Constraint{
  453. Name: name,
  454. Field: rel.Field,
  455. OnUpdate: settings["ONUPDATE"],
  456. OnDelete: settings["ONDELETE"],
  457. }
  458. for _, ref := range rel.References {
  459. if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) {
  460. constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey)
  461. constraint.References = append(constraint.References, ref.PrimaryKey)
  462. if ref.OwnPrimaryKey {
  463. constraint.Schema = ref.ForeignKey.Schema
  464. constraint.ReferenceSchema = rel.Schema
  465. } else {
  466. constraint.Schema = rel.Schema
  467. constraint.ReferenceSchema = ref.PrimaryKey.Schema
  468. }
  469. }
  470. }
  471. return &constraint
  472. }
  473. func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) {
  474. table := rel.FieldSchema.Table
  475. foreignFields := []*Field{}
  476. relForeignKeys := []string{}
  477. if rel.JoinTable != nil {
  478. table = rel.JoinTable.Table
  479. for _, ref := range rel.References {
  480. if ref.OwnPrimaryKey {
  481. foreignFields = append(foreignFields, ref.PrimaryKey)
  482. relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
  483. } else if ref.PrimaryValue != "" {
  484. conds = append(conds, clause.Eq{
  485. Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
  486. Value: ref.PrimaryValue,
  487. })
  488. } else {
  489. conds = append(conds, clause.Eq{
  490. Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
  491. Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName},
  492. })
  493. }
  494. }
  495. } else {
  496. for _, ref := range rel.References {
  497. if ref.OwnPrimaryKey {
  498. relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
  499. foreignFields = append(foreignFields, ref.PrimaryKey)
  500. } else if ref.PrimaryValue != "" {
  501. conds = append(conds, clause.Eq{
  502. Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName},
  503. Value: ref.PrimaryValue,
  504. })
  505. } else {
  506. relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
  507. foreignFields = append(foreignFields, ref.ForeignKey)
  508. }
  509. }
  510. }
  511. _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
  512. column, values := ToQueryValues(table, relForeignKeys, foreignValues)
  513. conds = append(conds, clause.IN{Column: column, Values: values})
  514. return
  515. }