gorm.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. package gorm
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "sort"
  7. "sync"
  8. "time"
  9. "gorm.io/gorm/clause"
  10. "gorm.io/gorm/logger"
  11. "gorm.io/gorm/schema"
  12. )
  // preparedStmtDBKey is the Config.cacheStore key under which Open caches the
  // shared *PreparedStmtDB that DB.Session reuses when PrepareStmt is requested.
  const (
  	preparedStmtDBKey = "preparedStmt"
  )
  // Config holds GORM's settings. A *DB embeds it, and *Config also implements
  // Option, so it can be passed to Open, which copies it into the new DB's config.
  type Config struct {
  	// SkipDefaultTransaction: by default GORM performs single create, update
  	// and delete operations inside a transaction to ensure data integrity.
  	// Set it to true to skip that default transaction.
  	SkipDefaultTransaction bool
  	// NamingStrategy controls table and column naming.
  	// Open falls back to schema.NamingStrategy{} when it is nil.
  	NamingStrategy schema.Namer
  	// FullSaveAssociations enables full saving of associations.
  	// NOTE(review): the exact semantics are implemented outside this file.
  	FullSaveAssociations bool
  	// Logger receives GORM's log output. Open falls back to logger.Default
  	// when it is nil.
  	Logger logger.Interface
  	// NowFunc is used whenever a new timestamp is needed.
  	// Open defaults it to time.Now().Local().
  	NowFunc func() time.Time
  	// DryRun generates SQL without executing it.
  	DryRun bool
  	// PrepareStmt makes Open route the connection pool through a
  	// PreparedStmtDB that caches prepared statements.
  	PrepareStmt bool
  	// DisableAutomaticPing stops Open from pinging the connection pool after
  	// the dialector has been initialized.
  	DisableAutomaticPing bool
  	// DisableForeignKeyConstraintWhenMigrating is not read in this file;
  	// presumably it skips creating foreign key constraints during migration.
  	// TODO confirm against the migrator.
  	DisableForeignKeyConstraintWhenMigrating bool
  	// DisableNestedTransaction disables nested transactions.
  	DisableNestedTransaction bool
  	// AllowGlobalUpdate allows global updates. It is not enforced in this
  	// file; presumably it permits updates/deletes without conditions. Verify.
  	AllowGlobalUpdate bool
  	// QueryFields executes SQL queries with all fields of the table.
  	QueryFields bool
  	// CreateBatchSize is the default batch size for creates.
  	CreateBatchSize int
  	// ClauseBuilders maps a clause name to a custom builder. Open initializes
  	// it to an empty map when it is nil.
  	ClauseBuilders map[string]clause.ClauseBuilder
  	// ConnPool is the database connection pool. Open reads it after calling
  	// Dialector.Initialize (presumably the dialector sets it). When
  	// PrepareStmt is true, Open replaces it with a PreparedStmtDB wrapper.
  	ConnPool ConnPool
  	// Dialector is the database dialector. Open stores its non-nil dialector
  	// argument here.
  	Dialector
  	// Plugins holds the registered plugins. Use keys them by Plugin.Name().
  	// Open initializes the map when it is nil.
  	Plugins map[string]Plugin
  	// callbacks is the callback manager created by Open.
  	callbacks *callbacks
  	// cacheStore is shared by every copy of this Config (it is a pointer);
  	// Open stores the PreparedStmtDB in it under preparedStmtDBKey.
  	cacheStore *sync.Map
  }
  55. func (c *Config) Apply(config *Config) error {
  56. if config != c {
  57. *config = *c
  58. }
  59. return nil
  60. }
  61. func (c *Config) AfterInitialize(db *DB) error {
  62. if db != nil {
  63. for _, plugin := range c.Plugins {
  64. if err := plugin.Initialize(db); err != nil {
  65. return err
  66. }
  67. }
  68. }
  69. return nil
  70. }
  // Option customizes the DB built by Open.
  type Option interface {
  	// Apply is called by Open, *Config options first, to modify the config
  	// being built. A non-nil error aborts Open.
  	Apply(*Config) error
  	// AfterInitialize is deferred by Open once Apply succeeds, so hooks run
  	// when Open returns, in reverse order. The *DB may be nil if Open failed
  	// before creating it. A returned error replaces Open's error.
  	AfterInitialize(*DB) error
  }
  // DB is GORM's database handle: the embedded *Config plus per-chain state.
  type DB struct {
  	*Config
  	// Error accumulates errors recorded through AddError.
  	Error error
  	// RowsAffected is not set in this file; presumably the operation's
  	// callbacks fill it with the affected row count.
  	RowsAffected int64
  	// Statement is the statement being built by the current method chain.
  	Statement *Statement
  	// clone controls getInstance:
  	//   0: db is already a per-chain instance and is mutated directly;
  	//   1: chained calls start from a fresh Statement (Open, NewDB sessions);
  	//   other positive values (Session uses 2): chained calls start from a
  	//   clone of Statement.
  	clone int
  }
  // Session configures the *DB returned by DB.Session. Boolean fields can only
  // switch a setting on; false keeps the parent DB's value. Zero-valued
  // Context, Logger, NowFunc and CreateBatchSize also keep the parent's values.
  type Session struct {
  	// DryRun enables Config.DryRun for the session.
  	DryRun bool
  	// PrepareStmt routes the session through the prepared-statement cache that
  	// Open stored under preparedStmtDBKey.
  	PrepareStmt bool
  	// NewDB makes chains started from the session begin with a fresh
  	// Statement instead of a clone of the current one.
  	NewDB bool
  	// SkipHooks sets SkipHooks on the session's own copy of the Statement.
  	SkipHooks bool
  	// The next five fields enable the Config setting of the same name.
  	SkipDefaultTransaction bool
  	DisableNestedTransaction bool
  	AllowGlobalUpdate bool
  	FullSaveAssociations bool
  	QueryFields bool
  	// Context, when non-nil, replaces the statement's context.
  	Context context.Context
  	// Logger, when non-nil, replaces Config.Logger.
  	Logger logger.Interface
  	// NowFunc, when non-nil, replaces Config.NowFunc.
  	NowFunc func() time.Time
  	// CreateBatchSize, when greater than zero, replaces Config.CreateBatchSize.
  	CreateBatchSize int
  }
  99. // Open initialize db session based on dialector
  100. func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
  101. config := &Config{}
  102. sort.Slice(opts, func(i, j int) bool {
  103. _, isConfig := opts[i].(*Config)
  104. _, isConfig2 := opts[j].(*Config)
  105. return isConfig && !isConfig2
  106. })
  107. for _, opt := range opts {
  108. if opt != nil {
  109. if err := opt.Apply(config); err != nil {
  110. return nil, err
  111. }
  112. defer func(opt Option) {
  113. if errr := opt.AfterInitialize(db); errr != nil {
  114. err = errr
  115. }
  116. }(opt)
  117. }
  118. }
  119. if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
  120. if err = d.Apply(config); err != nil {
  121. return
  122. }
  123. }
  124. if config.NamingStrategy == nil {
  125. config.NamingStrategy = schema.NamingStrategy{}
  126. }
  127. if config.Logger == nil {
  128. config.Logger = logger.Default
  129. }
  130. if config.NowFunc == nil {
  131. config.NowFunc = func() time.Time { return time.Now().Local() }
  132. }
  133. if dialector != nil {
  134. config.Dialector = dialector
  135. }
  136. if config.Plugins == nil {
  137. config.Plugins = map[string]Plugin{}
  138. }
  139. if config.cacheStore == nil {
  140. config.cacheStore = &sync.Map{}
  141. }
  142. db = &DB{Config: config, clone: 1}
  143. db.callbacks = initializeCallbacks(db)
  144. if config.ClauseBuilders == nil {
  145. config.ClauseBuilders = map[string]clause.ClauseBuilder{}
  146. }
  147. if config.Dialector != nil {
  148. err = config.Dialector.Initialize(db)
  149. }
  150. preparedStmt := &PreparedStmtDB{
  151. ConnPool: db.ConnPool,
  152. Stmts: map[string]Stmt{},
  153. Mux: &sync.RWMutex{},
  154. PreparedSQL: make([]string, 0, 100),
  155. }
  156. db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
  157. if config.PrepareStmt {
  158. db.ConnPool = preparedStmt
  159. }
  160. db.Statement = &Statement{
  161. DB: db,
  162. ConnPool: db.ConnPool,
  163. Context: context.Background(),
  164. Clauses: map[string]clause.Clause{},
  165. }
  166. if err == nil && !config.DisableAutomaticPing {
  167. if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
  168. err = pinger.Ping()
  169. }
  170. }
  171. if err != nil {
  172. config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
  173. }
  174. return
  175. }
  176. // Session create new db session
  177. func (db *DB) Session(config *Session) *DB {
  178. var (
  179. txConfig = *db.Config
  180. tx = &DB{
  181. Config: &txConfig,
  182. Statement: db.Statement,
  183. Error: db.Error,
  184. clone: 1,
  185. }
  186. )
  187. if config.CreateBatchSize > 0 {
  188. tx.Config.CreateBatchSize = config.CreateBatchSize
  189. }
  190. if config.SkipDefaultTransaction {
  191. tx.Config.SkipDefaultTransaction = true
  192. }
  193. if config.AllowGlobalUpdate {
  194. txConfig.AllowGlobalUpdate = true
  195. }
  196. if config.FullSaveAssociations {
  197. txConfig.FullSaveAssociations = true
  198. }
  199. if config.Context != nil || config.PrepareStmt || config.SkipHooks {
  200. tx.Statement = tx.Statement.clone()
  201. tx.Statement.DB = tx
  202. }
  203. if config.Context != nil {
  204. tx.Statement.Context = config.Context
  205. }
  206. if config.PrepareStmt {
  207. if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
  208. preparedStmt := v.(*PreparedStmtDB)
  209. tx.Statement.ConnPool = &PreparedStmtDB{
  210. ConnPool: db.Config.ConnPool,
  211. Mux: preparedStmt.Mux,
  212. Stmts: preparedStmt.Stmts,
  213. }
  214. txConfig.ConnPool = tx.Statement.ConnPool
  215. txConfig.PrepareStmt = true
  216. }
  217. }
  218. if config.SkipHooks {
  219. tx.Statement.SkipHooks = true
  220. }
  221. if config.DisableNestedTransaction {
  222. txConfig.DisableNestedTransaction = true
  223. }
  224. if !config.NewDB {
  225. tx.clone = 2
  226. }
  227. if config.DryRun {
  228. tx.Config.DryRun = true
  229. }
  230. if config.QueryFields {
  231. tx.Config.QueryFields = true
  232. }
  233. if config.Logger != nil {
  234. tx.Config.Logger = config.Logger
  235. }
  236. if config.NowFunc != nil {
  237. tx.Config.NowFunc = config.NowFunc
  238. }
  239. return tx
  240. }
  241. // WithContext change current instance db's context to ctx
  242. func (db *DB) WithContext(ctx context.Context) *DB {
  243. return db.Session(&Session{Context: ctx})
  244. }
  245. // Debug start debug mode
  246. func (db *DB) Debug() (tx *DB) {
  247. return db.Session(&Session{
  248. Logger: db.Logger.LogMode(logger.Info),
  249. })
  250. }
  251. // Set store value with key into current db instance's context
  252. func (db *DB) Set(key string, value interface{}) *DB {
  253. tx := db.getInstance()
  254. tx.Statement.Settings.Store(key, value)
  255. return tx
  256. }
  257. // Get get value with key from current db instance's context
  258. func (db *DB) Get(key string) (interface{}, bool) {
  259. return db.Statement.Settings.Load(key)
  260. }
  261. // InstanceSet store value with key into current db instance's context
  262. func (db *DB) InstanceSet(key string, value interface{}) *DB {
  263. tx := db.getInstance()
  264. tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
  265. return tx
  266. }
  267. // InstanceGet get value with key from current db instance's context
  268. func (db *DB) InstanceGet(key string) (interface{}, bool) {
  269. return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
  270. }
  271. // Callback returns callback manager
  272. func (db *DB) Callback() *callbacks {
  273. return db.callbacks
  274. }
  275. // AddError add error to db
  276. func (db *DB) AddError(err error) error {
  277. if db.Error == nil {
  278. db.Error = err
  279. } else if err != nil {
  280. db.Error = fmt.Errorf("%v; %w", db.Error, err)
  281. }
  282. return db.Error
  283. }
  284. // DB returns `*sql.DB`
  285. func (db *DB) DB() (*sql.DB, error) {
  286. connPool := db.ConnPool
  287. if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
  288. return dbConnector.GetDBConn()
  289. }
  290. if sqldb, ok := connPool.(*sql.DB); ok {
  291. return sqldb, nil
  292. }
  293. return nil, ErrInvalidDB
  294. }
  295. func (db *DB) getInstance() *DB {
  296. if db.clone > 0 {
  297. tx := &DB{Config: db.Config, Error: db.Error}
  298. if db.clone == 1 {
  299. // clone with new statement
  300. tx.Statement = &Statement{
  301. DB: tx,
  302. ConnPool: db.Statement.ConnPool,
  303. Context: db.Statement.Context,
  304. Clauses: map[string]clause.Clause{},
  305. Vars: make([]interface{}, 0, 8),
  306. }
  307. } else {
  308. // with clone statement
  309. tx.Statement = db.Statement.clone()
  310. tx.Statement.DB = tx
  311. }
  312. return tx
  313. }
  314. return db
  315. }
  316. func Expr(expr string, args ...interface{}) clause.Expr {
  317. return clause.Expr{SQL: expr, Vars: args}
  318. }
  319. func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
  320. var (
  321. tx = db.getInstance()
  322. stmt = tx.Statement
  323. modelSchema, joinSchema *schema.Schema
  324. )
  325. if err := stmt.Parse(model); err == nil {
  326. modelSchema = stmt.Schema
  327. } else {
  328. return err
  329. }
  330. if err := stmt.Parse(joinTable); err == nil {
  331. joinSchema = stmt.Schema
  332. } else {
  333. return err
  334. }
  335. if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
  336. for _, ref := range relation.References {
  337. if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
  338. f.DataType = ref.ForeignKey.DataType
  339. f.GORMDataType = ref.ForeignKey.GORMDataType
  340. if f.Size == 0 {
  341. f.Size = ref.ForeignKey.Size
  342. }
  343. ref.ForeignKey = f
  344. } else {
  345. return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
  346. }
  347. }
  348. for name, rel := range relation.JoinTable.Relationships.Relations {
  349. if _, ok := joinSchema.Relationships.Relations[name]; !ok {
  350. rel.Schema = joinSchema
  351. joinSchema.Relationships.Relations[name] = rel
  352. }
  353. }
  354. relation.JoinTable = joinSchema
  355. } else {
  356. return fmt.Errorf("failed to found relation: %s", field)
  357. }
  358. return nil
  359. }
  360. func (db *DB) Use(plugin Plugin) error {
  361. name := plugin.Name()
  362. if _, ok := db.Plugins[name]; ok {
  363. return ErrRegistered
  364. }
  365. if err := plugin.Initialize(db); err != nil {
  366. return err
  367. }
  368. db.Plugins[name] = plugin
  369. return nil
  370. }