gorm.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. package gorm
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "sync"
  8. "time"
  9. "gorm.io/gorm/clause"
  10. "gorm.io/gorm/logger"
  11. "gorm.io/gorm/schema"
  12. )
  13. // Config GORM config
  14. type Config struct {
  15. // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
  16. // You can disable it by setting `SkipDefaultTransaction` to true
  17. SkipDefaultTransaction bool
  18. // NamingStrategy tables, columns naming strategy
  19. NamingStrategy schema.Namer
  20. // FullSaveAssociations full save associations
  21. FullSaveAssociations bool
  22. // Logger
  23. Logger logger.Interface
  24. // NowFunc the function to be used when creating a new timestamp
  25. NowFunc func() time.Time
  26. // DryRun generate sql without execute
  27. DryRun bool
  28. // PrepareStmt executes the given query in cached statement
  29. PrepareStmt bool
  30. // DisableAutomaticPing
  31. DisableAutomaticPing bool
  32. // DisableForeignKeyConstraintWhenMigrating
  33. DisableForeignKeyConstraintWhenMigrating bool
  34. // DisableNestedTransaction disable nested transaction
  35. DisableNestedTransaction bool
  36. // AllowGlobalUpdate allow global update
  37. AllowGlobalUpdate bool
  38. // QueryFields executes the SQL query with all fields of the table
  39. QueryFields bool
  40. // CreateBatchSize default create batch size
  41. CreateBatchSize int
  42. // ClauseBuilders clause builder
  43. ClauseBuilders map[string]clause.ClauseBuilder
  44. // ConnPool db conn pool
  45. ConnPool ConnPool
  46. // Dialector database dialector
  47. Dialector
  48. // Plugins registered plugins
  49. Plugins map[string]Plugin
  50. callbacks *callbacks
  51. cacheStore *sync.Map
  52. }
  53. // DB GORM DB definition
  54. type DB struct {
  55. *Config
  56. Error error
  57. RowsAffected int64
  58. Statement *Statement
  59. clone int
  60. }
  61. // Session session config when create session with Session() method
  62. type Session struct {
  63. DryRun bool
  64. PrepareStmt bool
  65. NewDB bool
  66. SkipHooks bool
  67. SkipDefaultTransaction bool
  68. DisableNestedTransaction bool
  69. AllowGlobalUpdate bool
  70. FullSaveAssociations bool
  71. QueryFields bool
  72. Context context.Context
  73. Logger logger.Interface
  74. NowFunc func() time.Time
  75. CreateBatchSize int
  76. }
  77. // Open initialize db session based on dialector
  78. func Open(dialector Dialector, config *Config) (db *DB, err error) {
  79. if config == nil {
  80. config = &Config{}
  81. }
  82. if config.NamingStrategy == nil {
  83. config.NamingStrategy = schema.NamingStrategy{}
  84. }
  85. if config.Logger == nil {
  86. config.Logger = logger.Default
  87. }
  88. if config.NowFunc == nil {
  89. config.NowFunc = func() time.Time { return time.Now().Local() }
  90. }
  91. if dialector != nil {
  92. config.Dialector = dialector
  93. }
  94. if config.Plugins == nil {
  95. config.Plugins = map[string]Plugin{}
  96. }
  97. if config.cacheStore == nil {
  98. config.cacheStore = &sync.Map{}
  99. }
  100. db = &DB{Config: config, clone: 1}
  101. db.callbacks = initializeCallbacks(db)
  102. if config.ClauseBuilders == nil {
  103. config.ClauseBuilders = map[string]clause.ClauseBuilder{}
  104. }
  105. if config.Dialector != nil {
  106. err = config.Dialector.Initialize(db)
  107. }
  108. preparedStmt := &PreparedStmtDB{
  109. ConnPool: db.ConnPool,
  110. Stmts: map[string]Stmt{},
  111. Mux: &sync.RWMutex{},
  112. PreparedSQL: make([]string, 0, 100),
  113. }
  114. db.cacheStore.Store("preparedStmt", preparedStmt)
  115. if config.PrepareStmt {
  116. db.ConnPool = preparedStmt
  117. }
  118. db.Statement = &Statement{
  119. DB: db,
  120. ConnPool: db.ConnPool,
  121. Context: context.Background(),
  122. Clauses: map[string]clause.Clause{},
  123. }
  124. if err == nil && !config.DisableAutomaticPing {
  125. if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
  126. err = pinger.Ping()
  127. }
  128. }
  129. if err != nil {
  130. config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
  131. }
  132. return
  133. }
  134. // Session create new db session
  135. func (db *DB) Session(config *Session) *DB {
  136. var (
  137. txConfig = *db.Config
  138. tx = &DB{
  139. Config: &txConfig,
  140. Statement: db.Statement,
  141. Error: db.Error,
  142. clone: 1,
  143. }
  144. )
  145. if config.CreateBatchSize > 0 {
  146. tx.Config.CreateBatchSize = config.CreateBatchSize
  147. }
  148. if config.SkipDefaultTransaction {
  149. tx.Config.SkipDefaultTransaction = true
  150. }
  151. if config.AllowGlobalUpdate {
  152. txConfig.AllowGlobalUpdate = true
  153. }
  154. if config.FullSaveAssociations {
  155. txConfig.FullSaveAssociations = true
  156. }
  157. if config.Context != nil || config.PrepareStmt || config.SkipHooks {
  158. tx.Statement = tx.Statement.clone()
  159. tx.Statement.DB = tx
  160. }
  161. if config.Context != nil {
  162. tx.Statement.Context = config.Context
  163. }
  164. if config.PrepareStmt {
  165. if v, ok := db.cacheStore.Load("preparedStmt"); ok {
  166. preparedStmt := v.(*PreparedStmtDB)
  167. tx.Statement.ConnPool = &PreparedStmtDB{
  168. ConnPool: db.Config.ConnPool,
  169. Mux: preparedStmt.Mux,
  170. Stmts: preparedStmt.Stmts,
  171. }
  172. txConfig.ConnPool = tx.Statement.ConnPool
  173. txConfig.PrepareStmt = true
  174. }
  175. }
  176. if config.SkipHooks {
  177. tx.Statement.SkipHooks = true
  178. }
  179. if config.DisableNestedTransaction {
  180. txConfig.DisableNestedTransaction = true
  181. }
  182. if !config.NewDB {
  183. tx.clone = 2
  184. }
  185. if config.DryRun {
  186. tx.Config.DryRun = true
  187. }
  188. if config.QueryFields {
  189. tx.Config.QueryFields = true
  190. }
  191. if config.Logger != nil {
  192. tx.Config.Logger = config.Logger
  193. }
  194. if config.NowFunc != nil {
  195. tx.Config.NowFunc = config.NowFunc
  196. }
  197. return tx
  198. }
  199. // WithContext change current instance db's context to ctx
  200. func (db *DB) WithContext(ctx context.Context) *DB {
  201. return db.Session(&Session{Context: ctx})
  202. }
  203. // Debug start debug mode
  204. func (db *DB) Debug() (tx *DB) {
  205. return db.Session(&Session{
  206. Logger: db.Logger.LogMode(logger.Info),
  207. })
  208. }
  209. // Set store value with key into current db instance's context
  210. func (db *DB) Set(key string, value interface{}) *DB {
  211. tx := db.getInstance()
  212. tx.Statement.Settings.Store(key, value)
  213. return tx
  214. }
  215. // Get get value with key from current db instance's context
  216. func (db *DB) Get(key string) (interface{}, bool) {
  217. return db.Statement.Settings.Load(key)
  218. }
  219. // InstanceSet store value with key into current db instance's context
  220. func (db *DB) InstanceSet(key string, value interface{}) *DB {
  221. tx := db.getInstance()
  222. tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
  223. return tx
  224. }
  225. // InstanceGet get value with key from current db instance's context
  226. func (db *DB) InstanceGet(key string) (interface{}, bool) {
  227. return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
  228. }
  229. // Callback returns callback manager
  230. func (db *DB) Callback() *callbacks {
  231. return db.callbacks
  232. }
  233. // AddError add error to db
  234. func (db *DB) AddError(err error) error {
  235. if db.Error == nil {
  236. db.Error = err
  237. } else if err != nil {
  238. db.Error = fmt.Errorf("%v; %w", db.Error, err)
  239. }
  240. return db.Error
  241. }
  242. // DB returns `*sql.DB`
  243. func (db *DB) DB() (*sql.DB, error) {
  244. connPool := db.ConnPool
  245. if stmtDB, ok := connPool.(*PreparedStmtDB); ok {
  246. connPool = stmtDB.ConnPool
  247. }
  248. if sqldb, ok := connPool.(*sql.DB); ok {
  249. return sqldb, nil
  250. }
  251. return nil, errors.New("invalid db")
  252. }
  253. func (db *DB) getInstance() *DB {
  254. if db.clone > 0 {
  255. tx := &DB{Config: db.Config}
  256. if db.clone == 1 {
  257. // clone with new statement
  258. tx.Statement = &Statement{
  259. DB: tx,
  260. ConnPool: db.Statement.ConnPool,
  261. Context: db.Statement.Context,
  262. Clauses: map[string]clause.Clause{},
  263. Vars: make([]interface{}, 0, 8),
  264. }
  265. } else {
  266. // with clone statement
  267. tx.Statement = db.Statement.clone()
  268. tx.Statement.DB = tx
  269. }
  270. return tx
  271. }
  272. return db
  273. }
  274. func Expr(expr string, args ...interface{}) clause.Expr {
  275. return clause.Expr{SQL: expr, Vars: args}
  276. }
  277. func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
  278. var (
  279. tx = db.getInstance()
  280. stmt = tx.Statement
  281. modelSchema, joinSchema *schema.Schema
  282. )
  283. if err := stmt.Parse(model); err == nil {
  284. modelSchema = stmt.Schema
  285. } else {
  286. return err
  287. }
  288. if err := stmt.Parse(joinTable); err == nil {
  289. joinSchema = stmt.Schema
  290. } else {
  291. return err
  292. }
  293. if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
  294. for _, ref := range relation.References {
  295. if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
  296. f.DataType = ref.ForeignKey.DataType
  297. f.GORMDataType = ref.ForeignKey.GORMDataType
  298. if f.Size == 0 {
  299. f.Size = ref.ForeignKey.Size
  300. }
  301. ref.ForeignKey = f
  302. } else {
  303. return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
  304. }
  305. }
  306. for name, rel := range relation.JoinTable.Relationships.Relations {
  307. if _, ok := joinSchema.Relationships.Relations[name]; !ok {
  308. rel.Schema = joinSchema
  309. joinSchema.Relationships.Relations[name] = rel
  310. }
  311. }
  312. relation.JoinTable = joinSchema
  313. } else {
  314. return fmt.Errorf("failed to found relation: %v", field)
  315. }
  316. return nil
  317. }
  318. func (db *DB) Use(plugin Plugin) error {
  319. name := plugin.Name()
  320. if _, ok := db.Plugins[name]; ok {
  321. return ErrRegistered
  322. }
  323. if err := plugin.Initialize(db); err != nil {
  324. return err
  325. }
  326. db.Plugins[name] = plugin
  327. return nil
  328. }