chainable_api.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. package gorm
  2. import (
  3. "fmt"
  4. "regexp"
  5. "strings"
  6. "gorm.io/gorm/clause"
  7. "gorm.io/gorm/utils"
  8. )
  9. // Model specify the model you would like to run db operations
  10. // // update all users's name to `hello`
  11. // db.Model(&User{}).Update("name", "hello")
  12. // // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
  13. // db.Model(&user).Update("name", "hello")
  14. func (db *DB) Model(value interface{}) (tx *DB) {
  15. tx = db.getInstance()
  16. tx.Statement.Model = value
  17. return
  18. }
  19. // Clauses Add clauses
  20. func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
  21. tx = db.getInstance()
  22. var whereConds []interface{}
  23. for _, cond := range conds {
  24. if c, ok := cond.(clause.Interface); ok {
  25. tx.Statement.AddClause(c)
  26. } else if optimizer, ok := cond.(StatementModifier); ok {
  27. optimizer.ModifyStatement(tx.Statement)
  28. } else {
  29. whereConds = append(whereConds, cond)
  30. }
  31. }
  32. if len(whereConds) > 0 {
  33. tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
  34. }
  35. return
  36. }
  37. var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
  38. // Table specify the table you would like to run db operations
  39. func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
  40. tx = db.getInstance()
  41. if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
  42. tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
  43. if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
  44. tx.Statement.Table = results[1]
  45. }
  46. } else if tables := strings.Split(name, "."); len(tables) == 2 {
  47. tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
  48. tx.Statement.Table = tables[1]
  49. } else {
  50. tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
  51. tx.Statement.Table = name
  52. }
  53. return
  54. }
  55. // Distinct specify distinct fields that you want querying
  56. func (db *DB) Distinct(args ...interface{}) (tx *DB) {
  57. tx = db.getInstance()
  58. tx.Statement.Distinct = true
  59. if len(args) > 0 {
  60. tx = tx.Select(args[0], args[1:]...)
  61. }
  62. return
  63. }
  64. // Select specify fields that you want when querying, creating, updating
  65. func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
  66. tx = db.getInstance()
  67. switch v := query.(type) {
  68. case []string:
  69. tx.Statement.Selects = v
  70. for _, arg := range args {
  71. switch arg := arg.(type) {
  72. case string:
  73. tx.Statement.Selects = append(tx.Statement.Selects, arg)
  74. case []string:
  75. tx.Statement.Selects = append(tx.Statement.Selects, arg...)
  76. default:
  77. tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
  78. return
  79. }
  80. }
  81. delete(tx.Statement.Clauses, "SELECT")
  82. case string:
  83. if strings.Count(v, "?") >= len(args) && len(args) > 0 {
  84. tx.Statement.AddClause(clause.Select{
  85. Distinct: db.Statement.Distinct,
  86. Expression: clause.Expr{SQL: v, Vars: args},
  87. })
  88. } else if strings.Count(v, "@") > 0 && len(args) > 0 {
  89. tx.Statement.AddClause(clause.Select{
  90. Distinct: db.Statement.Distinct,
  91. Expression: clause.NamedExpr{SQL: v, Vars: args},
  92. })
  93. } else {
  94. tx.Statement.Selects = []string{v}
  95. for _, arg := range args {
  96. switch arg := arg.(type) {
  97. case string:
  98. tx.Statement.Selects = append(tx.Statement.Selects, arg)
  99. case []string:
  100. tx.Statement.Selects = append(tx.Statement.Selects, arg...)
  101. default:
  102. tx.Statement.AddClause(clause.Select{
  103. Distinct: db.Statement.Distinct,
  104. Expression: clause.Expr{SQL: v, Vars: args},
  105. })
  106. return
  107. }
  108. }
  109. delete(tx.Statement.Clauses, "SELECT")
  110. }
  111. default:
  112. tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
  113. }
  114. return
  115. }
  116. // Omit specify fields that you want to ignore when creating, updating and querying
  117. func (db *DB) Omit(columns ...string) (tx *DB) {
  118. tx = db.getInstance()
  119. if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
  120. tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
  121. } else {
  122. tx.Statement.Omits = columns
  123. }
  124. return
  125. }
  126. // Where add conditions
  127. func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
  128. tx = db.getInstance()
  129. if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
  130. tx.Statement.AddClause(clause.Where{Exprs: conds})
  131. }
  132. return
  133. }
  134. // Not add NOT conditions
  135. func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
  136. tx = db.getInstance()
  137. if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
  138. tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
  139. }
  140. return
  141. }
  142. // Or add OR conditions
  143. func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
  144. tx = db.getInstance()
  145. if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
  146. tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
  147. }
  148. return
  149. }
  150. // Joins specify Joins conditions
  151. // db.Joins("Account").Find(&user)
  152. // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
  153. // db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
  154. func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
  155. tx = db.getInstance()
  156. if len(args) > 0 {
  157. if db, ok := args[0].(*DB); ok {
  158. if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
  159. tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: &where})
  160. }
  161. return
  162. }
  163. }
  164. tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
  165. return
  166. }
  167. // Group specify the group method on the find
  168. func (db *DB) Group(name string) (tx *DB) {
  169. tx = db.getInstance()
  170. fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
  171. tx.Statement.AddClause(clause.GroupBy{
  172. Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
  173. })
  174. return
  175. }
  176. // Having specify HAVING conditions for GROUP BY
  177. func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
  178. tx = db.getInstance()
  179. tx.Statement.AddClause(clause.GroupBy{
  180. Having: tx.Statement.BuildCondition(query, args...),
  181. })
  182. return
  183. }
  184. // Order specify order when retrieve records from database
  185. // db.Order("name DESC")
  186. // db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
  187. func (db *DB) Order(value interface{}) (tx *DB) {
  188. tx = db.getInstance()
  189. switch v := value.(type) {
  190. case clause.OrderByColumn:
  191. tx.Statement.AddClause(clause.OrderBy{
  192. Columns: []clause.OrderByColumn{v},
  193. })
  194. case string:
  195. if v != "" {
  196. tx.Statement.AddClause(clause.OrderBy{
  197. Columns: []clause.OrderByColumn{{
  198. Column: clause.Column{Name: v, Raw: true},
  199. }},
  200. })
  201. }
  202. }
  203. return
  204. }
  205. // Limit specify the number of records to be retrieved
  206. func (db *DB) Limit(limit int) (tx *DB) {
  207. tx = db.getInstance()
  208. tx.Statement.AddClause(clause.Limit{Limit: limit})
  209. return
  210. }
  211. // Offset specify the number of records to skip before starting to return the records
  212. func (db *DB) Offset(offset int) (tx *DB) {
  213. tx = db.getInstance()
  214. tx.Statement.AddClause(clause.Limit{Offset: offset})
  215. return
  216. }
  217. // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
  218. // func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
  219. // return db.Where("amount > ?", 1000)
  220. // }
  221. //
  222. // func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
  223. // return func (db *gorm.DB) *gorm.DB {
  224. // return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
  225. // }
  226. // }
  227. //
  228. // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
  229. func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
  230. tx = db.getInstance()
  231. tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
  232. return tx
  233. }
  234. // Preload preload associations with given conditions
  235. // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
  236. func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
  237. tx = db.getInstance()
  238. if tx.Statement.Preloads == nil {
  239. tx.Statement.Preloads = map[string][]interface{}{}
  240. }
  241. tx.Statement.Preloads[query] = args
  242. return
  243. }
  244. func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
  245. tx = db.getInstance()
  246. tx.Statement.attrs = attrs
  247. return
  248. }
  249. func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
  250. tx = db.getInstance()
  251. tx.Statement.assigns = attrs
  252. return
  253. }
  254. func (db *DB) Unscoped() (tx *DB) {
  255. tx = db.getInstance()
  256. tx.Statement.Unscoped = true
  257. return
  258. }
  259. func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
  260. tx = db.getInstance()
  261. tx.Statement.SQL = strings.Builder{}
  262. if strings.Contains(sql, "@") {
  263. clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
  264. } else {
  265. clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
  266. }
  267. return
  268. }