query.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. package callbacks
  2. import (
  3. "fmt"
  4. "reflect"
  5. "sort"
  6. "strings"
  7. "gorm.io/gorm"
  8. "gorm.io/gorm/clause"
  9. )
  10. func Query(db *gorm.DB) {
  11. if db.Error == nil {
  12. BuildQuerySQL(db)
  13. if !db.DryRun && db.Error == nil {
  14. rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
  15. if err != nil {
  16. db.AddError(err)
  17. return
  18. }
  19. defer rows.Close()
  20. gorm.Scan(rows, db, false)
  21. }
  22. }
  23. }
  24. func BuildQuerySQL(db *gorm.DB) {
  25. if db.Statement.Schema != nil && !db.Statement.Unscoped {
  26. for _, c := range db.Statement.Schema.QueryClauses {
  27. db.Statement.AddClause(c)
  28. }
  29. }
  30. if db.Statement.SQL.String() == "" {
  31. db.Statement.SQL.Grow(100)
  32. clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
  33. if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
  34. var conds []clause.Expression
  35. for _, primaryField := range db.Statement.Schema.PrimaryFields {
  36. if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
  37. conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
  38. }
  39. }
  40. if len(conds) > 0 {
  41. db.Statement.AddClause(clause.Where{Exprs: conds})
  42. }
  43. }
  44. if len(db.Statement.Selects) > 0 {
  45. clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
  46. for idx, name := range db.Statement.Selects {
  47. if db.Statement.Schema == nil {
  48. clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
  49. } else if f := db.Statement.Schema.LookUpField(name); f != nil {
  50. clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
  51. } else {
  52. clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
  53. }
  54. }
  55. } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
  56. selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
  57. clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
  58. for _, dbName := range db.Statement.Schema.DBNames {
  59. if v, ok := selectColumns[dbName]; (ok && v) || !ok {
  60. clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
  61. }
  62. }
  63. } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
  64. queryFields := db.QueryFields
  65. if !queryFields {
  66. switch db.Statement.ReflectValue.Kind() {
  67. case reflect.Struct:
  68. queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
  69. case reflect.Slice:
  70. queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
  71. }
  72. }
  73. if queryFields {
  74. stmt := gorm.Statement{DB: db}
  75. // smaller struct
  76. if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
  77. clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
  78. for idx, dbName := range stmt.Schema.DBNames {
  79. clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
  80. }
  81. }
  82. }
  83. }
  84. // inline joins
  85. if len(db.Statement.Joins) != 0 {
  86. if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
  87. clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
  88. for idx, dbName := range db.Statement.Schema.DBNames {
  89. clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
  90. }
  91. }
  92. joins := []clause.Join{}
  93. if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
  94. joins = fromClause.Joins
  95. }
  96. for _, join := range db.Statement.Joins {
  97. if db.Statement.Schema == nil {
  98. joins = append(joins, clause.Join{
  99. Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
  100. })
  101. } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
  102. tableAliasName := relation.Name
  103. for _, s := range relation.FieldSchema.DBNames {
  104. clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
  105. Table: tableAliasName,
  106. Name: s,
  107. Alias: tableAliasName + "__" + s,
  108. })
  109. }
  110. exprs := make([]clause.Expression, len(relation.References))
  111. for idx, ref := range relation.References {
  112. if ref.OwnPrimaryKey {
  113. exprs[idx] = clause.Eq{
  114. Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
  115. Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
  116. }
  117. } else {
  118. if ref.PrimaryValue == "" {
  119. exprs[idx] = clause.Eq{
  120. Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
  121. Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
  122. }
  123. } else {
  124. exprs[idx] = clause.Eq{
  125. Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
  126. Value: ref.PrimaryValue,
  127. }
  128. }
  129. }
  130. }
  131. if join.On != nil {
  132. onStmt := gorm.Statement{Table: tableAliasName, DB: db}
  133. join.On.Build(&onStmt)
  134. onSQL := onStmt.SQL.String()
  135. vars := onStmt.Vars
  136. for idx, v := range onStmt.Vars {
  137. bindvar := strings.Builder{}
  138. onStmt.Vars = vars[0 : idx+1]
  139. db.Dialector.BindVarTo(&bindvar, &onStmt, v)
  140. onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
  141. }
  142. exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
  143. }
  144. joins = append(joins, clause.Join{
  145. Type: clause.LeftJoin,
  146. Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
  147. ON: clause.Where{Exprs: exprs},
  148. })
  149. } else {
  150. joins = append(joins, clause.Join{
  151. Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
  152. })
  153. }
  154. }
  155. db.Statement.Joins = nil
  156. db.Statement.AddClause(clause.From{Joins: joins})
  157. } else {
  158. db.Statement.AddClauseIfNotExists(clause.From{})
  159. }
  160. db.Statement.AddClauseIfNotExists(clauseSelect)
  161. db.Statement.Build(db.Statement.BuildClauses...)
  162. }
  163. }
  164. func Preload(db *gorm.DB) {
  165. if db.Error == nil && len(db.Statement.Preloads) > 0 {
  166. preloadMap := map[string]map[string][]interface{}{}
  167. for name := range db.Statement.Preloads {
  168. preloadFields := strings.Split(name, ".")
  169. if preloadFields[0] == clause.Associations {
  170. for _, rel := range db.Statement.Schema.Relationships.Relations {
  171. if rel.Schema == db.Statement.Schema {
  172. if _, ok := preloadMap[rel.Name]; !ok {
  173. preloadMap[rel.Name] = map[string][]interface{}{}
  174. }
  175. if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
  176. preloadMap[rel.Name][value] = db.Statement.Preloads[name]
  177. }
  178. }
  179. }
  180. } else {
  181. if _, ok := preloadMap[preloadFields[0]]; !ok {
  182. preloadMap[preloadFields[0]] = map[string][]interface{}{}
  183. }
  184. if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
  185. preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
  186. }
  187. }
  188. }
  189. preloadNames := make([]string, 0, len(preloadMap))
  190. for key := range preloadMap {
  191. preloadNames = append(preloadNames, key)
  192. }
  193. sort.Strings(preloadNames)
  194. for _, name := range preloadNames {
  195. if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
  196. preload(db, rel, db.Statement.Preloads[name], preloadMap[name])
  197. } else {
  198. db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
  199. }
  200. }
  201. }
  202. }
  203. func AfterQuery(db *gorm.DB) {
  204. if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
  205. callMethod(db, func(value interface{}, tx *gorm.DB) bool {
  206. if i, ok := value.(AfterFindInterface); ok {
  207. db.AddError(i.AfterFind(tx))
  208. return true
  209. }
  210. return false
  211. })
  212. }
  213. }