finisher_api.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637
  1. package gorm
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "reflect"
  7. "strings"
  8. "gorm.io/gorm/clause"
  9. "gorm.io/gorm/logger"
  10. "gorm.io/gorm/schema"
  11. "gorm.io/gorm/utils"
  12. )
  13. // Create insert the value into database
  14. func (db *DB) Create(value interface{}) (tx *DB) {
  15. if db.CreateBatchSize > 0 {
  16. return db.CreateInBatches(value, db.CreateBatchSize)
  17. }
  18. tx = db.getInstance()
  19. tx.Statement.Dest = value
  20. return tx.callbacks.Create().Execute(tx)
  21. }
  // CreateInBatches inserts value into the database in chunks of batchSize.
//
// For a slice or array, consecutive chunks of at most batchSize elements are
// each passed through the create callbacks. All chunks run inside one
// transaction unless SkipDefaultTransaction is set. The first failing chunk
// stops the loop, and its error is recorded on the returned *DB.
// RowsAffected on the returned *DB is the sum over the chunks that ran.
// A batchSize <= 0 inserts all elements as a single chunk.
//
// Any other kind of value is inserted with a single pass through the create
// callbacks.
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
	reflectValue := reflect.Indirect(reflect.ValueOf(value))

	switch reflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		var rowsAffected int64
		tx = db.getInstance()

		// reflect.Value.Slice panics on an unaddressable array (one passed by
		// value rather than through a pointer), so slice an addressable copy.
		if reflectValue.Kind() == reflect.Array && !reflectValue.CanAddr() {
			addressable := reflect.New(reflectValue.Type()).Elem()
			addressable.Set(reflectValue)
			reflectValue = addressable
		}

		// Computed once; the collection is not modified while batches run.
		reflectLen := reflectValue.Len()
		if batchSize <= 0 {
			// A zero step would never advance the loop below, and a negative
			// one would produce an invalid slice bound; use a single batch.
			batchSize = reflectLen
		}

		callFc := func(tx *DB) error {
			for i := 0; i < reflectLen; i += batchSize {
				ends := i + batchSize
				if ends > reflectLen {
					ends = reflectLen
				}

				subtx := tx.getInstance()
				subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
				subtx.callbacks.Create().Execute(subtx)
				if subtx.Error != nil {
					return subtx.Error
				}
				rowsAffected += subtx.RowsAffected
			}
			return nil
		}

		if tx.SkipDefaultTransaction {
			tx.AddError(callFc(tx.Session(&Session{})))
		} else {
			tx.AddError(tx.Transaction(callFc))
		}

		tx.RowsAffected = rowsAffected
	default:
		tx = db.getInstance()
		tx.Statement.Dest = value
		tx = tx.callbacks.Create().Execute(tx)
	}
	return
}
  // Save writes value to the database, inserting it when it has no primary key.
//
//   - A slice or array is written through the create callbacks with an
//     ON CONFLICT ... UPDATE ALL clause, unless one is already attached.
//   - A struct with any zero-valued primary key field is created.
//   - Anything else is updated with all fields selected ("*"), so zero values
//     are written too. If that update touched no rows and a lookup of the
//     model finds nothing (ErrRecordNotFound), the value is created instead.
func (db *DB) Save(value interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Dest = value

	reflectValue := reflect.Indirect(reflect.ValueOf(value))
	switch reflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
			tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
		}
		tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
	case reflect.Struct:
		if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
			for _, pf := range tx.Statement.Schema.PrimaryFields {
				if _, isZero := pf.ValueOf(reflectValue); isZero {
					return tx.callbacks.Create().Execute(tx)
				}
			}
		}
		fallthrough
	default:
		selectedUpdate := len(tx.Statement.Selects) != 0
		// when updating, use all fields including those zero-value fields
		if !selectedUpdate {
			tx.Statement.Selects = append(tx.Statement.Selects, "*")
		}

		tx = tx.callbacks.Update().Execute(tx)

		// The existence probe scans into a fresh model value, which needs a
		// parsed schema. For non-model values (e.g. a map saved via Table)
		// Schema is nil, and dereferencing it would panic.
		if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate && tx.Statement.Schema != nil {
			result := reflect.New(tx.Statement.Schema.ModelType).Interface()
			if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
				return tx.Create(value)
			}
		}
	}
	return
}
  96. // First find first record that match given conditions, order by primary key
  97. func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
  98. tx = db.Limit(1).Order(clause.OrderByColumn{
  99. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  100. })
  101. if len(conds) > 0 {
  102. if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
  103. tx.Statement.AddClause(clause.Where{Exprs: exprs})
  104. }
  105. }
  106. tx.Statement.RaiseErrorOnNotFound = true
  107. tx.Statement.Dest = dest
  108. return tx.callbacks.Query().Execute(tx)
  109. }
  110. // Take return a record that match given conditions, the order will depend on the database implementation
  111. func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
  112. tx = db.Limit(1)
  113. if len(conds) > 0 {
  114. if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
  115. tx.Statement.AddClause(clause.Where{Exprs: exprs})
  116. }
  117. }
  118. tx.Statement.RaiseErrorOnNotFound = true
  119. tx.Statement.Dest = dest
  120. return tx.callbacks.Query().Execute(tx)
  121. }
  122. // Last find last record that match given conditions, order by primary key
  123. func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
  124. tx = db.Limit(1).Order(clause.OrderByColumn{
  125. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  126. Desc: true,
  127. })
  128. if len(conds) > 0 {
  129. if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
  130. tx.Statement.AddClause(clause.Where{Exprs: exprs})
  131. }
  132. }
  133. tx.Statement.RaiseErrorOnNotFound = true
  134. tx.Statement.Dest = dest
  135. return tx.callbacks.Query().Execute(tx)
  136. }
  137. // Find find records that match given conditions
  138. func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
  139. tx = db.getInstance()
  140. if len(conds) > 0 {
  141. if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
  142. tx.Statement.AddClause(clause.Where{Exprs: exprs})
  143. }
  144. }
  145. tx.Statement.Dest = dest
  146. return tx.callbacks.Query().Execute(tx)
  147. }
  // FindInBatches queries records in pages of batchSize and calls fc once per
// non-empty page, passing the page's result and its 1-based batch number.
//
// Pages are walked by primary key (keyset pagination): after each full page,
// the next query asks for rows whose primary key is greater than that of the
// page's last row. dest must therefore be a slice of a model with a primary
// key. Iteration stops when:
//   - a page comes back shorter than batchSize,
//   - a query fails,
//   - fc returns an error, or
//   - no usable primary key is available.
//
// Any error is recorded on the returned *DB, whose RowsAffected is the total
// number of rows fetched.
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
	var (
		tx = db.Order(clause.OrderByColumn{
			Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
		}).Session(&Session{})
		queryDB      = tx
		rowsAffected int64
		batch        int
	)

	for {
		result := queryDB.Limit(batchSize).Find(dest)
		rowsAffected += result.RowsAffected
		batch++

		if result.Error == nil && result.RowsAffected != 0 {
			tx.AddError(fc(result, batch))
		} else if result.Error != nil {
			tx.AddError(result.Error)
		}

		if tx.Error != nil || int(result.RowsAffected) < batchSize {
			break
		}

		// Build the keyset condition for the next page from the last row seen.
		resultsValue := reflect.Indirect(reflect.ValueOf(dest))
		if resultsValue.Len() == 0 {
			// Possible e.g. with batchSize <= 0: there is no last row to page
			// from, and Index(-1) below would panic.
			break
		}

		if result.Statement.Schema == nil || result.Statement.Schema.PrioritizedPrimaryField == nil {
			tx.AddError(ErrPrimaryKeyRequired)
			break
		}

		primaryValue, isZero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
		if isZero {
			// A zero key (for example, when the key column was not selected)
			// cannot advance the cursor; querying "pk > 0" again could return
			// the same page forever.
			tx.AddError(ErrPrimaryKeyRequired)
			break
		}

		queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
	}

	tx.RowsAffected = rowsAffected
	return tx
}
  // assignInterfacesToValue copies values onto the fields of
// tx.Statement.ReflectValue, matching them against tx.Statement.Schema.
//
// Each element of values is handled by its type:
//   - []clause.Expression: every clause.Eq sets the named column's field, and
//     clause.AndConditions are walked recursively;
//   - clause.Expression or one of the supported map types: converted with
//     BuildCondition, then handled as above;
//   - anything else: if it parses as a model struct, its non-zero readable
//     fields are copied by name. Otherwise the whole values list is
//     reinterpreted as one query condition and processing stops.
//
// Errors from setting fields are recorded on tx.
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
	// setField stores val into the destination field called name, if the
	// destination schema has such a field.
	setField := func(name string, val interface{}) {
		if field := tx.Statement.Schema.LookUpField(name); field != nil {
			tx.AddError(field.Set(tx.Statement.ReflectValue, val))
		}
	}

	for _, value := range values {
		switch v := value.(type) {
		case []clause.Expression:
			for _, expr := range v {
				switch e := expr.(type) {
				case clause.Eq:
					switch column := e.Column.(type) {
					case string:
						setField(column, e.Value)
					case clause.Column:
						setField(column.Name, e.Value)
					}
				case clause.AndConditions:
					tx.assignInterfacesToValue(e.Exprs)
				}
			}
		case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
			if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 {
				tx.assignInterfacesToValue(exprs)
			}
		default:
			s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy)
			if err != nil {
				// Not a model: treat the complete argument list as a single
				// query condition. values is non-empty here since we are
				// iterating over it.
				if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
					tx.assignInterfacesToValue(exprs)
				}
				return
			}

			source := reflect.Indirect(reflect.ValueOf(value))
			if source.Kind() != reflect.Struct {
				continue
			}
			for _, f := range s.Fields {
				if !f.Readable {
					continue
				}
				if fieldValue, isZero := f.ValueOf(source); !isZero {
					setField(f.Name, fieldValue)
				}
			}
		}
	}
}
  230. func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
  231. queryTx := db.Limit(1).Order(clause.OrderByColumn{
  232. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  233. })
  234. if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
  235. if c, ok := tx.Statement.Clauses["WHERE"]; ok {
  236. if where, ok := c.Expression.(clause.Where); ok {
  237. tx.assignInterfacesToValue(where.Exprs)
  238. }
  239. }
  240. // initialize with attrs, conds
  241. if len(tx.Statement.attrs) > 0 {
  242. tx.assignInterfacesToValue(tx.Statement.attrs...)
  243. }
  244. }
  245. // initialize with attrs, conds
  246. if len(tx.Statement.assigns) > 0 {
  247. tx.assignInterfacesToValue(tx.Statement.assigns...)
  248. }
  249. return
  250. }
  251. func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
  252. queryTx := db.Limit(1).Order(clause.OrderByColumn{
  253. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  254. })
  255. if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
  256. if c, ok := tx.Statement.Clauses["WHERE"]; ok {
  257. if where, ok := c.Expression.(clause.Where); ok {
  258. tx.assignInterfacesToValue(where.Exprs)
  259. }
  260. }
  261. // initialize with attrs, conds
  262. if len(tx.Statement.attrs) > 0 {
  263. tx.assignInterfacesToValue(tx.Statement.attrs...)
  264. }
  265. // initialize with attrs, conds
  266. if len(tx.Statement.assigns) > 0 {
  267. tx.assignInterfacesToValue(tx.Statement.assigns...)
  268. }
  269. return tx.Create(dest)
  270. } else if len(db.Statement.assigns) > 0 {
  271. exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
  272. assigns := map[string]interface{}{}
  273. for _, expr := range exprs {
  274. if eq, ok := expr.(clause.Eq); ok {
  275. switch column := eq.Column.(type) {
  276. case string:
  277. assigns[column] = eq.Value
  278. case clause.Column:
  279. assigns[column.Name] = eq.Value
  280. default:
  281. }
  282. }
  283. }
  284. return tx.Model(dest).Updates(assigns)
  285. }
  286. return db
  287. }
  288. // Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
  289. func (db *DB) Update(column string, value interface{}) (tx *DB) {
  290. tx = db.getInstance()
  291. tx.Statement.Dest = map[string]interface{}{column: value}
  292. return tx.callbacks.Update().Execute(tx)
  293. }
  294. // Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
  295. func (db *DB) Updates(values interface{}) (tx *DB) {
  296. tx = db.getInstance()
  297. tx.Statement.Dest = values
  298. return tx.callbacks.Update().Execute(tx)
  299. }
  300. func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
  301. tx = db.getInstance()
  302. tx.Statement.Dest = map[string]interface{}{column: value}
  303. tx.Statement.SkipHooks = true
  304. return tx.callbacks.Update().Execute(tx)
  305. }
  306. func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
  307. tx = db.getInstance()
  308. tx.Statement.Dest = values
  309. tx.Statement.SkipHooks = true
  310. return tx.callbacks.Update().Execute(tx)
  311. }
  312. // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
  313. func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
  314. tx = db.getInstance()
  315. if len(conds) > 0 {
  316. if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
  317. tx.Statement.AddClause(clause.Where{Exprs: exprs})
  318. }
  319. }
  320. tx.Statement.Dest = value
  321. return tx.callbacks.Delete().Execute(tx)
  322. }
  // Count stores in *count the number of records matched by the current
// statement.
//
// Selection rules:
//   - With no Select, count(*) is used.
//   - With a single selected column (optionally "col AS alias"), that column
//     is counted, as COUNT(DISTINCT(col)) when Distinct is set.
//   - A selection that already starts with "count(" is left untouched.
//
// ORDER BY is dropped for the count query unless a GROUP BY is present. With a
// GROUP BY, *count is the number of groups (rows returned).
//
// The statement's Model, SELECT and ORDER BY clauses are restored when Count
// returns, so the same chain can be reused afterwards.
func (db *DB) Count(count *int64) (tx *DB) {
	tx = db.getInstance()
	if tx.Statement.Model == nil {
		tx.Statement.Model = tx.Statement.Dest
		defer func() {
			tx.Statement.Model = nil
		}()
	}

	if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
		defer func() {
			tx.Statement.Clauses["SELECT"] = selectClause
		}()
	} else {
		defer delete(tx.Statement.Clauses, "SELECT")
	}

	if len(tx.Statement.Selects) == 0 {
		tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
	} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
		expr := clause.Expr{SQL: "count(*)"}

		if len(tx.Statement.Selects) == 1 {
			dbName := tx.Statement.Selects[0]
			fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
			if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
				// Map a Go field name to its column name when the model knows it.
				if tx.Statement.Parse(tx.Statement.Model) == nil {
					if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
						dbName = f.DBName
					}
				}

				if tx.Statement.Distinct {
					expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
				} else {
					expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
				}
			}
		}

		tx.Statement.AddClause(clause.Select{Expression: expr})
	}

	_, grouped := db.Statement.Clauses["GROUP BY"]
	if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok && !grouped {
		delete(tx.Statement.Clauses, "ORDER BY")
		defer func() {
			tx.Statement.Clauses["ORDER BY"] = orderByClause
		}()
	}

	tx.Statement.Dest = count
	tx = tx.callbacks.Query().Execute(tx)

	// Without GROUP BY, a single row holds the count, already scanned into
	// *count via Dest. With GROUP BY, each returned row is one group, so the row
	// count is the answer. This includes exactly one group, where the scanned
	// value would be that group's size rather than 1. Any other row count
	// (e.g. 0) is reported as-is.
	if grouped || tx.RowsAffected != 1 {
		*count = tx.RowsAffected
	}
	return
}
  // Row executes the statement through the row callbacks and returns the
// resulting *sql.Row. If no *sql.Row was produced and the session is in
// dry-run mode, ErrDryRunModeUnsupported is logged and nil is returned.
func (db *DB) Row() *sql.Row {
	tx := db.getInstance().Set("rows", false)
	tx = tx.callbacks.Row().Execute(tx)

	if row, ok := tx.Statement.Dest.(*sql.Row); ok {
		return row
	}
	if tx.DryRun {
		db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
	}
	return nil
}
  // Rows executes the statement through the row callbacks and returns the
// resulting *sql.Rows together with the statement error. If no *sql.Rows was
// produced in dry-run mode and no other error is pending, the error is set to
// ErrDryRunModeUnsupported.
func (db *DB) Rows() (*sql.Rows, error) {
	tx := db.getInstance().Set("rows", true)
	tx = tx.callbacks.Row().Execute(tx)

	rows, isRows := tx.Statement.Dest.(*sql.Rows)
	if isRows || !tx.DryRun || tx.Error != nil {
		return rows, tx.Error
	}

	tx.Error = ErrDryRunModeUnsupported
	return rows, tx.Error
}
  393. // Scan scan value to a struct
  394. func (db *DB) Scan(dest interface{}) (tx *DB) {
  395. config := *db.Config
  396. currentLogger, newLogger := config.Logger, logger.Recorder.New()
  397. config.Logger = newLogger
  398. tx = db.getInstance()
  399. tx.Config = &config
  400. if rows, err := tx.Rows(); err != nil {
  401. tx.AddError(err)
  402. } else {
  403. defer rows.Close()
  404. if rows.Next() {
  405. tx.ScanRows(rows, dest)
  406. } else {
  407. tx.RowsAffected = 0
  408. }
  409. }
  410. currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
  411. return newLogger.SQL, tx.RowsAffected
  412. }, tx.Error)
  413. tx.Logger = currentLogger
  414. return
  415. }
  // Pluck queries a single column and scans its values into dest, e.g.
//
//	var ages []int64
//	db.Model(&users).Pluck("age", &ages)
//
// With a Model, a Go field name is translated to its column name. Without a
// Model, a Table must be set; otherwise ErrModelValueRequired is recorded.
// Unless exactly one column is already selected, a SELECT for column is added
// if none exists. Column expressions that are not a single identifier are
// emitted raw.
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
	tx = db.getInstance()

	switch {
	case tx.Statement.Model != nil:
		if tx.Statement.Parse(tx.Statement.Model) == nil {
			if f := tx.Statement.Schema.LookUpField(column); f != nil {
				column = f.DBName
			}
		}
	case tx.Statement.Table == "":
		tx.AddError(ErrModelValueRequired)
	}

	if len(tx.Statement.Selects) != 1 {
		isRaw := len(strings.FieldsFunc(column, utils.IsValidDBNameChar)) != 1
		tx.Statement.AddClauseIfNotExists(clause.Select{
			Distinct: tx.Statement.Distinct,
			Columns:  []clause.Column{{Name: column, Raw: isRaw}},
		})
	}

	tx.Statement.Dest = dest
	return tx.callbacks.Query().Execute(tx)
}
  // ScanRows scans from rows into dest and returns the resulting statement
// error. dest is parsed as a model first. A dest of an unsupported (non-model)
// type is tolerated, but any other parse error is recorded.
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
	tx := db.getInstance()

	parseErr := tx.Statement.Parse(dest)
	if !errors.Is(parseErr, schema.ErrUnsupportedDataType) {
		tx.AddError(parseErr)
	}

	// Point the statement at the fully dereferenced destination value.
	target := reflect.ValueOf(dest)
	for target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	tx.Statement.Dest = dest
	tx.Statement.ReflectValue = target

	Scan(rows, tx, true)
	return tx.Error
}
  // Transaction runs fc inside a transaction. If fc returns an error, or
// panics, the work is rolled back; otherwise it is committed.
//
// When db is already inside a transaction (its connection pool is a
// TxCommitter), the call is nested. Unless DisableNestedTransaction is set, a
// savepoint is created first and rolled back to on error or panic. No commit
// is issued in that case, since the enclosing transaction does it. Otherwise a
// new transaction is begun with opts and committed when fc succeeds.
//
// Panics from fc are not recovered: the rollback runs and the panic keeps
// propagating.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
	panicked := true

	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
		// nested transaction
		if !db.DisableNestedTransaction {
			savePointName := fmt.Sprintf("sp%p", fc)
			if err = db.SavePoint(savePointName).Error; err != nil {
				// The savepoint was not created, so there is nothing to roll back to.
				return
			}
			defer func() {
				// Make sure to rollback when panic or Block error
				if panicked || err != nil {
					db.RollbackTo(savePointName)
				}
			}()
		}

		err = fc(db.Session(&Session{}))
	} else {
		tx := db.Begin(opts...)
		if tx.Error != nil {
			// Nothing was begun, so there is nothing to roll back.
			return tx.Error
		}
		defer func() {
			// Make sure to rollback when panic, Block error or Commit error
			if panicked || err != nil {
				tx.Rollback()
			}
		}()

		if err = fc(tx); err == nil {
			err = tx.Commit().Error
		}
	}

	panicked = false
	return
}
  // Begin starts a transaction on a new session that carries db's context and
// returns that session. The transaction's connection replaces the session's
// ConnPool. If the pool cannot begin transactions, ErrInvalidTransaction is
// recorded; any BeginTx error is recorded on the returned *DB.
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
	// clone statement
	tx := db.getInstance().Session(&Session{Context: db.Statement.Context})

	var opt *sql.TxOptions
	if len(opts) > 0 {
		opt = opts[0]
	}

	var err error
	switch beginner := tx.Statement.ConnPool.(type) {
	case TxBeginner:
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
	case ConnPoolBeginner:
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
	default:
		err = ErrInvalidTransaction
	}

	if err != nil {
		tx.AddError(err)
	}
	return tx
}
  // Commit commits the transaction held in the statement's ConnPool. If there
// is no usable committer (not a TxCommitter, nil, or a typed nil),
// ErrInvalidTransaction is recorded instead.
func (db *DB) Commit() *DB {
	committer, ok := db.Statement.ConnPool.(TxCommitter)
	if !ok || committer == nil || reflect.ValueOf(committer).IsNil() {
		db.AddError(ErrInvalidTransaction)
		return db
	}

	db.AddError(committer.Commit())
	return db
}
  // Rollback rolls back the transaction held in the statement's ConnPool.
// A ConnPool that is not a TxCommitter (or is nil) records
// ErrInvalidTransaction. A typed-nil committer is silently ignored; it is
// presumably left there by a failed Begin.
func (db *DB) Rollback() *DB {
	committer, ok := db.Statement.ConnPool.(TxCommitter)
	switch {
	case !ok || committer == nil:
		db.AddError(ErrInvalidTransaction)
	case reflect.ValueOf(committer).IsNil():
		// nothing to roll back
	default:
		db.AddError(committer.Rollback())
	}
	return db
}
  531. func (db *DB) SavePoint(name string) *DB {
  532. if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
  533. db.AddError(savePointer.SavePoint(db, name))
  534. } else {
  535. db.AddError(ErrUnsupportedDriver)
  536. }
  537. return db
  538. }
  539. func (db *DB) RollbackTo(name string) *DB {
  540. if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
  541. db.AddError(savePointer.RollbackTo(db, name))
  542. } else {
  543. db.AddError(ErrUnsupportedDriver)
  544. }
  545. return db
  546. }
  // Exec executes raw SQL through the raw callbacks. When the SQL contains '@',
// values are bound as named parameters (clause.NamedExpr). Otherwise they are
// bound as positional parameters (clause.Expr).
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.SQL = strings.Builder{}

	switch {
	case strings.Contains(sql, "@"):
		clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
	default:
		clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
	}

	return tx.callbacks.Raw().Execute(tx)
}