123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393 |
- package gorm
- import (
- "context"
- "database/sql"
- "errors"
- "fmt"
- "sync"
- "time"
- "gorm.io/gorm/clause"
- "gorm.io/gorm/logger"
- "gorm.io/gorm/schema"
- )
- // Config GORM config
- type Config struct {
- // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
- // You can disable it by setting `SkipDefaultTransaction` to true
- SkipDefaultTransaction bool
- // NamingStrategy tables, columns naming strategy
- NamingStrategy schema.Namer
- // FullSaveAssociations full save associations
- FullSaveAssociations bool
- // Logger
- Logger logger.Interface
- // NowFunc the function to be used when creating a new timestamp
- NowFunc func() time.Time
- // DryRun generate sql without execute
- DryRun bool
- // PrepareStmt executes the given query in cached statement
- PrepareStmt bool
- // DisableAutomaticPing
- DisableAutomaticPing bool
- // DisableForeignKeyConstraintWhenMigrating
- DisableForeignKeyConstraintWhenMigrating bool
- // DisableNestedTransaction disable nested transaction
- DisableNestedTransaction bool
- // AllowGlobalUpdate allow global update
- AllowGlobalUpdate bool
- // QueryFields executes the SQL query with all fields of the table
- QueryFields bool
- // CreateBatchSize default create batch size
- CreateBatchSize int
- // ClauseBuilders clause builder
- ClauseBuilders map[string]clause.ClauseBuilder
- // ConnPool db conn pool
- ConnPool ConnPool
- // Dialector database dialector
- Dialector
- // Plugins registered plugins
- Plugins map[string]Plugin
- callbacks *callbacks
- cacheStore *sync.Map
- }
- // DB GORM DB definition
- type DB struct {
- *Config
- Error error
- RowsAffected int64
- Statement *Statement
- clone int
- }
- // Session session config when create session with Session() method
- type Session struct {
- DryRun bool
- PrepareStmt bool
- NewDB bool
- SkipHooks bool
- SkipDefaultTransaction bool
- DisableNestedTransaction bool
- AllowGlobalUpdate bool
- FullSaveAssociations bool
- QueryFields bool
- Context context.Context
- Logger logger.Interface
- NowFunc func() time.Time
- CreateBatchSize int
- }
- // Open initialize db session based on dialector
- func Open(dialector Dialector, config *Config) (db *DB, err error) {
- if config == nil {
- config = &Config{}
- }
- if config.NamingStrategy == nil {
- config.NamingStrategy = schema.NamingStrategy{}
- }
- if config.Logger == nil {
- config.Logger = logger.Default
- }
- if config.NowFunc == nil {
- config.NowFunc = func() time.Time { return time.Now().Local() }
- }
- if dialector != nil {
- config.Dialector = dialector
- }
- if config.Plugins == nil {
- config.Plugins = map[string]Plugin{}
- }
- if config.cacheStore == nil {
- config.cacheStore = &sync.Map{}
- }
- db = &DB{Config: config, clone: 1}
- db.callbacks = initializeCallbacks(db)
- if config.ClauseBuilders == nil {
- config.ClauseBuilders = map[string]clause.ClauseBuilder{}
- }
- if config.Dialector != nil {
- err = config.Dialector.Initialize(db)
- }
- preparedStmt := &PreparedStmtDB{
- ConnPool: db.ConnPool,
- Stmts: map[string]Stmt{},
- Mux: &sync.RWMutex{},
- PreparedSQL: make([]string, 0, 100),
- }
- db.cacheStore.Store("preparedStmt", preparedStmt)
- if config.PrepareStmt {
- db.ConnPool = preparedStmt
- }
- db.Statement = &Statement{
- DB: db,
- ConnPool: db.ConnPool,
- Context: context.Background(),
- Clauses: map[string]clause.Clause{},
- }
- if err == nil && !config.DisableAutomaticPing {
- if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
- err = pinger.Ping()
- }
- }
- if err != nil {
- config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
- }
- return
- }
- // Session create new db session
- func (db *DB) Session(config *Session) *DB {
- var (
- txConfig = *db.Config
- tx = &DB{
- Config: &txConfig,
- Statement: db.Statement,
- Error: db.Error,
- clone: 1,
- }
- )
- if config.CreateBatchSize > 0 {
- tx.Config.CreateBatchSize = config.CreateBatchSize
- }
- if config.SkipDefaultTransaction {
- tx.Config.SkipDefaultTransaction = true
- }
- if config.AllowGlobalUpdate {
- txConfig.AllowGlobalUpdate = true
- }
- if config.FullSaveAssociations {
- txConfig.FullSaveAssociations = true
- }
- if config.Context != nil || config.PrepareStmt || config.SkipHooks {
- tx.Statement = tx.Statement.clone()
- tx.Statement.DB = tx
- }
- if config.Context != nil {
- tx.Statement.Context = config.Context
- }
- if config.PrepareStmt {
- if v, ok := db.cacheStore.Load("preparedStmt"); ok {
- preparedStmt := v.(*PreparedStmtDB)
- tx.Statement.ConnPool = &PreparedStmtDB{
- ConnPool: db.Config.ConnPool,
- Mux: preparedStmt.Mux,
- Stmts: preparedStmt.Stmts,
- }
- txConfig.ConnPool = tx.Statement.ConnPool
- txConfig.PrepareStmt = true
- }
- }
- if config.SkipHooks {
- tx.Statement.SkipHooks = true
- }
- if config.DisableNestedTransaction {
- txConfig.DisableNestedTransaction = true
- }
- if !config.NewDB {
- tx.clone = 2
- }
- if config.DryRun {
- tx.Config.DryRun = true
- }
- if config.QueryFields {
- tx.Config.QueryFields = true
- }
- if config.Logger != nil {
- tx.Config.Logger = config.Logger
- }
- if config.NowFunc != nil {
- tx.Config.NowFunc = config.NowFunc
- }
- return tx
- }
- // WithContext change current instance db's context to ctx
- func (db *DB) WithContext(ctx context.Context) *DB {
- return db.Session(&Session{Context: ctx})
- }
- // Debug start debug mode
- func (db *DB) Debug() (tx *DB) {
- return db.Session(&Session{
- Logger: db.Logger.LogMode(logger.Info),
- })
- }
- // Set store value with key into current db instance's context
- func (db *DB) Set(key string, value interface{}) *DB {
- tx := db.getInstance()
- tx.Statement.Settings.Store(key, value)
- return tx
- }
- // Get get value with key from current db instance's context
- func (db *DB) Get(key string) (interface{}, bool) {
- return db.Statement.Settings.Load(key)
- }
- // InstanceSet store value with key into current db instance's context
- func (db *DB) InstanceSet(key string, value interface{}) *DB {
- tx := db.getInstance()
- tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
- return tx
- }
- // InstanceGet get value with key from current db instance's context
- func (db *DB) InstanceGet(key string) (interface{}, bool) {
- return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
- }
- // Callback returns callback manager
- func (db *DB) Callback() *callbacks {
- return db.callbacks
- }
- // AddError add error to db
- func (db *DB) AddError(err error) error {
- if db.Error == nil {
- db.Error = err
- } else if err != nil {
- db.Error = fmt.Errorf("%v; %w", db.Error, err)
- }
- return db.Error
- }
- // DB returns `*sql.DB`
- func (db *DB) DB() (*sql.DB, error) {
- connPool := db.ConnPool
- if stmtDB, ok := connPool.(*PreparedStmtDB); ok {
- connPool = stmtDB.ConnPool
- }
- if sqldb, ok := connPool.(*sql.DB); ok {
- return sqldb, nil
- }
- return nil, errors.New("invalid db")
- }
- func (db *DB) getInstance() *DB {
- if db.clone > 0 {
- tx := &DB{Config: db.Config}
- if db.clone == 1 {
- // clone with new statement
- tx.Statement = &Statement{
- DB: tx,
- ConnPool: db.Statement.ConnPool,
- Context: db.Statement.Context,
- Clauses: map[string]clause.Clause{},
- Vars: make([]interface{}, 0, 8),
- }
- } else {
- // with clone statement
- tx.Statement = db.Statement.clone()
- tx.Statement.DB = tx
- }
- return tx
- }
- return db
- }
- func Expr(expr string, args ...interface{}) clause.Expr {
- return clause.Expr{SQL: expr, Vars: args}
- }
- func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
- var (
- tx = db.getInstance()
- stmt = tx.Statement
- modelSchema, joinSchema *schema.Schema
- )
- if err := stmt.Parse(model); err == nil {
- modelSchema = stmt.Schema
- } else {
- return err
- }
- if err := stmt.Parse(joinTable); err == nil {
- joinSchema = stmt.Schema
- } else {
- return err
- }
- if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
- for _, ref := range relation.References {
- if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
- f.DataType = ref.ForeignKey.DataType
- f.GORMDataType = ref.ForeignKey.GORMDataType
- if f.Size == 0 {
- f.Size = ref.ForeignKey.Size
- }
- ref.ForeignKey = f
- } else {
- return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
- }
- }
- for name, rel := range relation.JoinTable.Relationships.Relations {
- if _, ok := joinSchema.Relationships.Relations[name]; !ok {
- rel.Schema = joinSchema
- joinSchema.Relationships.Relations[name] = rel
- }
- }
- relation.JoinTable = joinSchema
- } else {
- return fmt.Errorf("failed to found relation: %v", field)
- }
- return nil
- }
- func (db *DB) Use(plugin Plugin) error {
- name := plugin.Name()
- if _, ok := db.Plugins[name]; ok {
- return ErrRegistered
- }
- if err := plugin.Initialize(db); err != nil {
- return err
- }
- db.Plugins[name] = plugin
- return nil
- }
|