migrator.go 22 KB


  1. package migrator
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "reflect"
  7. "regexp"
  8. "strings"
  9. "gorm.io/gorm"
  10. "gorm.io/gorm/clause"
  11. "gorm.io/gorm/schema"
  12. )
  13. var (
  14. regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
  15. regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`)
  16. )
  17. // Migrator m struct
  18. type Migrator struct {
  19. Config
  20. }
  21. // Config schema config
  22. type Config struct {
  23. CreateIndexAfterCreateTable bool
  24. DB *gorm.DB
  25. gorm.Dialector
  26. }
  27. type GormDataTypeInterface interface {
  28. GormDBDataType(*gorm.DB, *schema.Field) string
  29. }
  30. func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
  31. stmt := &gorm.Statement{DB: m.DB}
  32. if m.DB.Statement != nil {
  33. stmt.Table = m.DB.Statement.Table
  34. stmt.TableExpr = m.DB.Statement.TableExpr
  35. }
  36. if table, ok := value.(string); ok {
  37. stmt.Table = table
  38. } else if err := stmt.Parse(value); err != nil {
  39. return err
  40. }
  41. return fc(stmt)
  42. }
  43. func (m Migrator) DataTypeOf(field *schema.Field) string {
  44. fieldValue := reflect.New(field.IndirectFieldType)
  45. if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
  46. if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" {
  47. return dataType
  48. }
  49. }
  50. return m.Dialector.DataTypeOf(field)
  51. }
  52. func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
  53. expr.SQL = m.DataTypeOf(field)
  54. if field.NotNull {
  55. expr.SQL += " NOT NULL"
  56. }
  57. if field.Unique {
  58. expr.SQL += " UNIQUE"
  59. }
  60. if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
  61. if field.DefaultValueInterface != nil {
  62. defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
  63. m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
  64. expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
  65. } else if field.DefaultValue != "(-)" {
  66. expr.SQL += " DEFAULT " + field.DefaultValue
  67. }
  68. }
  69. return
  70. }
  71. // AutoMigrate
  72. func (m Migrator) AutoMigrate(values ...interface{}) error {
  73. for _, value := range m.ReorderModels(values, true) {
  74. tx := m.DB.Session(&gorm.Session{})
  75. if !tx.Migrator().HasTable(value) {
  76. if err := tx.Migrator().CreateTable(value); err != nil {
  77. return err
  78. }
  79. } else {
  80. if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
  81. columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
  82. for _, field := range stmt.Schema.FieldsByDBName {
  83. var foundColumn gorm.ColumnType
  84. for _, columnType := range columnTypes {
  85. if columnType.Name() == field.DBName {
  86. foundColumn = columnType
  87. break
  88. }
  89. }
  90. if foundColumn == nil {
  91. // not found, add column
  92. if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
  93. return err
  94. }
  95. } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
  96. // found, smart migrate
  97. return err
  98. }
  99. }
  100. for _, rel := range stmt.Schema.Relationships.Relations {
  101. if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
  102. if constraint := rel.ParseConstraint(); constraint != nil &&
  103. constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) {
  104. if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
  105. return err
  106. }
  107. }
  108. }
  109. for _, chk := range stmt.Schema.ParseCheckConstraints() {
  110. if !tx.Migrator().HasConstraint(value, chk.Name) {
  111. if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
  112. return err
  113. }
  114. }
  115. }
  116. }
  117. for _, idx := range stmt.Schema.ParseIndexes() {
  118. if !tx.Migrator().HasIndex(value, idx.Name) {
  119. if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
  120. return err
  121. }
  122. }
  123. }
  124. return nil
  125. }); err != nil {
  126. return err
  127. }
  128. }
  129. }
  130. return nil
  131. }
  132. func (m Migrator) CreateTable(values ...interface{}) error {
  133. for _, value := range m.ReorderModels(values, false) {
  134. tx := m.DB.Session(&gorm.Session{})
  135. if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
  136. var (
  137. createTableSQL = "CREATE TABLE ? ("
  138. values = []interface{}{m.CurrentTable(stmt)}
  139. hasPrimaryKeyInDataType bool
  140. )
  141. for _, dbName := range stmt.Schema.DBNames {
  142. field := stmt.Schema.FieldsByDBName[dbName]
  143. if !field.IgnoreMigration {
  144. createTableSQL += "? ?"
  145. hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
  146. values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
  147. createTableSQL += ","
  148. }
  149. }
  150. if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
  151. createTableSQL += "PRIMARY KEY ?,"
  152. primaryKeys := []interface{}{}
  153. for _, field := range stmt.Schema.PrimaryFields {
  154. primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
  155. }
  156. values = append(values, primaryKeys)
  157. }
  158. for _, idx := range stmt.Schema.ParseIndexes() {
  159. if m.CreateIndexAfterCreateTable {
  160. defer func(value interface{}, name string) {
  161. if errr == nil {
  162. errr = tx.Migrator().CreateIndex(value, name)
  163. }
  164. }(value, idx.Name)
  165. } else {
  166. if idx.Class != "" {
  167. createTableSQL += idx.Class + " "
  168. }
  169. createTableSQL += "INDEX ? ?"
  170. if idx.Comment != "" {
  171. createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
  172. }
  173. if idx.Option != "" {
  174. createTableSQL += " " + idx.Option
  175. }
  176. createTableSQL += ","
  177. values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
  178. }
  179. }
  180. for _, rel := range stmt.Schema.Relationships.Relations {
  181. if !m.DB.DisableForeignKeyConstraintWhenMigrating {
  182. if constraint := rel.ParseConstraint(); constraint != nil {
  183. if constraint.Schema == stmt.Schema {
  184. sql, vars := buildConstraint(constraint)
  185. createTableSQL += sql + ","
  186. values = append(values, vars...)
  187. }
  188. }
  189. }
  190. }
  191. for _, chk := range stmt.Schema.ParseCheckConstraints() {
  192. createTableSQL += "CONSTRAINT ? CHECK (?),"
  193. values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
  194. }
  195. createTableSQL = strings.TrimSuffix(createTableSQL, ",")
  196. createTableSQL += ")"
  197. if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
  198. createTableSQL += fmt.Sprint(tableOption)
  199. }
  200. errr = tx.Exec(createTableSQL, values...).Error
  201. return errr
  202. }); err != nil {
  203. return err
  204. }
  205. }
  206. return nil
  207. }
  208. func (m Migrator) DropTable(values ...interface{}) error {
  209. values = m.ReorderModels(values, false)
  210. for i := len(values) - 1; i >= 0; i-- {
  211. tx := m.DB.Session(&gorm.Session{})
  212. if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
  213. return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error
  214. }); err != nil {
  215. return err
  216. }
  217. }
  218. return nil
  219. }
  220. func (m Migrator) HasTable(value interface{}) bool {
  221. var count int64
  222. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  223. currentDatabase := m.DB.Migrator().CurrentDatabase()
  224. return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)
  225. })
  226. return count > 0
  227. }
  228. func (m Migrator) RenameTable(oldName, newName interface{}) error {
  229. var oldTable, newTable interface{}
  230. if v, ok := oldName.(string); ok {
  231. oldTable = clause.Table{Name: v}
  232. } else {
  233. stmt := &gorm.Statement{DB: m.DB}
  234. if err := stmt.Parse(oldName); err == nil {
  235. oldTable = m.CurrentTable(stmt)
  236. } else {
  237. return err
  238. }
  239. }
  240. if v, ok := newName.(string); ok {
  241. newTable = clause.Table{Name: v}
  242. } else {
  243. stmt := &gorm.Statement{DB: m.DB}
  244. if err := stmt.Parse(newName); err == nil {
  245. newTable = m.CurrentTable(stmt)
  246. } else {
  247. return err
  248. }
  249. }
  250. return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
  251. }
  252. func (m Migrator) AddColumn(value interface{}, field string) error {
  253. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  254. // avoid using the same name field
  255. f := stmt.Schema.LookUpField(field)
  256. if f == nil {
  257. return fmt.Errorf("failed to look up field with name: %s", field)
  258. }
  259. if !f.IgnoreMigration {
  260. return m.DB.Exec(
  261. "ALTER TABLE ? ADD ? ?",
  262. m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f),
  263. ).Error
  264. }
  265. return nil
  266. })
  267. }
  268. func (m Migrator) DropColumn(value interface{}, name string) error {
  269. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  270. if field := stmt.Schema.LookUpField(name); field != nil {
  271. name = field.DBName
  272. }
  273. return m.DB.Exec(
  274. "ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name},
  275. ).Error
  276. })
  277. }
  278. func (m Migrator) AlterColumn(value interface{}, field string) error {
  279. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  280. if field := stmt.Schema.LookUpField(field); field != nil {
  281. fileType := clause.Expr{SQL: m.DataTypeOf(field)}
  282. return m.DB.Exec(
  283. "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
  284. m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
  285. ).Error
  286. }
  287. return fmt.Errorf("failed to look up field with name: %s", field)
  288. })
  289. }
  290. func (m Migrator) HasColumn(value interface{}, field string) bool {
  291. var count int64
  292. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  293. currentDatabase := m.DB.Migrator().CurrentDatabase()
  294. name := field
  295. if field := stmt.Schema.LookUpField(field); field != nil {
  296. name = field.DBName
  297. }
  298. return m.DB.Raw(
  299. "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
  300. currentDatabase, stmt.Table, name,
  301. ).Row().Scan(&count)
  302. })
  303. return count > 0
  304. }
  305. func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
  306. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  307. if field := stmt.Schema.LookUpField(oldName); field != nil {
  308. oldName = field.DBName
  309. }
  310. if field := stmt.Schema.LookUpField(newName); field != nil {
  311. newName = field.DBName
  312. }
  313. return m.DB.Exec(
  314. "ALTER TABLE ? RENAME COLUMN ? TO ?",
  315. m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
  316. ).Error
  317. })
  318. }
  319. func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
  320. // found, smart migrate
  321. fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
  322. realDataType := strings.ToLower(columnType.DatabaseTypeName())
  323. alterColumn := false
  324. // check size
  325. if length, _ := columnType.Length(); length != int64(field.Size) {
  326. if length > 0 && field.Size > 0 {
  327. alterColumn = true
  328. } else {
  329. // has size in data type and not equal
  330. // Since the following code is frequently called in the for loop, reg optimization is needed here
  331. matches := regRealDataType.FindAllStringSubmatch(realDataType, -1)
  332. matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
  333. if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) &&
  334. (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
  335. alterColumn = true
  336. }
  337. }
  338. }
  339. // check precision
  340. if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
  341. if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
  342. alterColumn = true
  343. }
  344. }
  345. // check nullable
  346. if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
  347. // not primary key & database is nullable
  348. if !field.PrimaryKey && nullable {
  349. alterColumn = true
  350. }
  351. }
  352. if alterColumn && !field.IgnoreMigration {
  353. return m.DB.Migrator().AlterColumn(value, field.Name)
  354. }
  355. return nil
  356. }
  357. // ColumnTypes return columnTypes []gorm.ColumnType and execErr error
  358. func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
  359. columnTypes := make([]gorm.ColumnType, 0)
  360. execErr := m.RunWithValue(value, func(stmt *gorm.Statement) error {
  361. rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
  362. if err != nil {
  363. return err
  364. }
  365. defer rows.Close()
  366. var rawColumnTypes []*sql.ColumnType
  367. rawColumnTypes, err = rows.ColumnTypes()
  368. if err != nil {
  369. return err
  370. }
  371. for _, c := range rawColumnTypes {
  372. columnTypes = append(columnTypes, c)
  373. }
  374. return nil
  375. })
  376. return columnTypes, execErr
  377. }
  378. func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
  379. return gorm.ErrNotImplemented
  380. }
  381. func (m Migrator) DropView(name string) error {
  382. return gorm.ErrNotImplemented
  383. }
  384. func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
  385. sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
  386. if constraint.OnDelete != "" {
  387. sql += " ON DELETE " + constraint.OnDelete
  388. }
  389. if constraint.OnUpdate != "" {
  390. sql += " ON UPDATE " + constraint.OnUpdate
  391. }
  392. var foreignKeys, references []interface{}
  393. for _, field := range constraint.ForeignKeys {
  394. foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
  395. }
  396. for _, field := range constraint.References {
  397. references = append(references, clause.Column{Name: field.DBName})
  398. }
  399. results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
  400. return
  401. }
  402. func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
  403. if stmt.Schema == nil {
  404. return nil, nil, stmt.Table
  405. }
  406. checkConstraints := stmt.Schema.ParseCheckConstraints()
  407. if chk, ok := checkConstraints[name]; ok {
  408. return nil, &chk, stmt.Table
  409. }
  410. getTable := func(rel *schema.Relationship) string {
  411. switch rel.Type {
  412. case schema.HasOne, schema.HasMany:
  413. return rel.FieldSchema.Table
  414. case schema.Many2Many:
  415. return rel.JoinTable.Table
  416. }
  417. return stmt.Table
  418. }
  419. for _, rel := range stmt.Schema.Relationships.Relations {
  420. if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
  421. return constraint, nil, getTable(rel)
  422. }
  423. }
  424. if field := stmt.Schema.LookUpField(name); field != nil {
  425. for k := range checkConstraints {
  426. if checkConstraints[k].Field == field {
  427. v := checkConstraints[k]
  428. return nil, &v, stmt.Table
  429. }
  430. }
  431. for _, rel := range stmt.Schema.Relationships.Relations {
  432. if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
  433. return constraint, nil, getTable(rel)
  434. }
  435. }
  436. }
  437. return nil, nil, stmt.Schema.Table
  438. }
  439. func (m Migrator) CreateConstraint(value interface{}, name string) error {
  440. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  441. constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
  442. if chk != nil {
  443. return m.DB.Exec(
  444. "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
  445. m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
  446. ).Error
  447. }
  448. if constraint != nil {
  449. var vars = []interface{}{clause.Table{Name: table}}
  450. if stmt.TableExpr != nil {
  451. vars[0] = stmt.TableExpr
  452. }
  453. sql, values := buildConstraint(constraint)
  454. return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
  455. }
  456. return nil
  457. })
  458. }
  459. func (m Migrator) DropConstraint(value interface{}, name string) error {
  460. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  461. constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
  462. if constraint != nil {
  463. name = constraint.Name
  464. } else if chk != nil {
  465. name = chk.Name
  466. }
  467. return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
  468. })
  469. }
  470. func (m Migrator) HasConstraint(value interface{}, name string) bool {
  471. var count int64
  472. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  473. currentDatabase := m.DB.Migrator().CurrentDatabase()
  474. constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
  475. if constraint != nil {
  476. name = constraint.Name
  477. } else if chk != nil {
  478. name = chk.Name
  479. }
  480. return m.DB.Raw(
  481. "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
  482. currentDatabase, table, name,
  483. ).Row().Scan(&count)
  484. })
  485. return count > 0
  486. }
  487. func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
  488. for _, opt := range opts {
  489. str := stmt.Quote(opt.DBName)
  490. if opt.Expression != "" {
  491. str = opt.Expression
  492. } else if opt.Length > 0 {
  493. str += fmt.Sprintf("(%d)", opt.Length)
  494. }
  495. if opt.Collate != "" {
  496. str += " COLLATE " + opt.Collate
  497. }
  498. if opt.Sort != "" {
  499. str += " " + opt.Sort
  500. }
  501. results = append(results, clause.Expr{SQL: str})
  502. }
  503. return
  504. }
  505. type BuildIndexOptionsInterface interface {
  506. BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
  507. }
  508. func (m Migrator) CreateIndex(value interface{}, name string) error {
  509. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  510. if idx := stmt.Schema.LookIndex(name); idx != nil {
  511. opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
  512. values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
  513. createIndexSQL := "CREATE "
  514. if idx.Class != "" {
  515. createIndexSQL += idx.Class + " "
  516. }
  517. createIndexSQL += "INDEX ? ON ??"
  518. if idx.Type != "" {
  519. createIndexSQL += " USING " + idx.Type
  520. }
  521. if idx.Comment != "" {
  522. createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
  523. }
  524. if idx.Option != "" {
  525. createIndexSQL += " " + idx.Option
  526. }
  527. return m.DB.Exec(createIndexSQL, values...).Error
  528. }
  529. return fmt.Errorf("failed to create index with name %s", name)
  530. })
  531. }
  532. func (m Migrator) DropIndex(value interface{}, name string) error {
  533. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  534. if idx := stmt.Schema.LookIndex(name); idx != nil {
  535. name = idx.Name
  536. }
  537. return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
  538. })
  539. }
  540. func (m Migrator) HasIndex(value interface{}, name string) bool {
  541. var count int64
  542. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  543. currentDatabase := m.DB.Migrator().CurrentDatabase()
  544. if idx := stmt.Schema.LookIndex(name); idx != nil {
  545. name = idx.Name
  546. }
  547. return m.DB.Raw(
  548. "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
  549. currentDatabase, stmt.Table, name,
  550. ).Row().Scan(&count)
  551. })
  552. return count > 0
  553. }
  554. func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
  555. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  556. return m.DB.Exec(
  557. "ALTER TABLE ? RENAME INDEX ? TO ?",
  558. m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
  559. ).Error
  560. })
  561. }
  562. func (m Migrator) CurrentDatabase() (name string) {
  563. m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
  564. return
  565. }
// ReorderModels reorder models according to constraint dependencies
//
// Result layout:
//   - string values (raw table names) come first, unchanged and in their
//     original order;
//   - every other value is parsed and appended afterwards, identified by its
//     table name, as the statement Dest it was parsed from.
//
// When autoAdd is true, each model is preceded by the schemas its owned
// foreign-key constraints reference. Referenced schemas the caller did not pass
// are parsed from freshly allocated zero values (reflect.New) and included, and
// so are many2many join tables. When autoAdd is false, no dependencies are
// followed and models keep the order in which they were passed.
//
// A value whose schema was already parsed is skipped, and each table name is
// emitted at most once.
func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) {
	type Dependency struct {
		*gorm.Statement
		// Depends lists the schemas that must be ordered before this one.
		Depends []*schema.Schema
	}

	var (
		// modelNames: tables to place, in discovery order; orderedModelNames: final order.
		modelNames, orderedModelNames []string
		// orderedModelNamesMap marks tables already placed (also breaks cycles).
		orderedModelNamesMap = map[string]bool{}
		// parsedSchemas prevents parsing the same schema twice.
		parsedSchemas = map[*schema.Schema]bool{}
		// valuesMap maps table name to its parsed dependency record.
		valuesMap             = map[string]Dependency{}
		insertIntoOrderedList func(name string)
		parseDependence       func(value interface{}, addToList bool)
	)

	// parseDependence parses value, records its dependencies in valuesMap and,
	// if addToList, queues its table name in modelNames.
	parseDependence = func(value interface{}, addToList bool) {
		dep := Dependency{
			Statement: &gorm.Statement{DB: m.DB, Dest: value},
		}
		beDependedOn := map[*schema.Schema]bool{}
		// A parse error is only logged.
		// NOTE(review): dep.Schema is presumably nil after a failed Parse, and the
		// loop below would then dereference it — confirm.
		if err := dep.Parse(value); err != nil {
			m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
		}
		if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
			return
		}
		parsedSchemas[dep.Statement.Schema] = true

		for _, rel := range dep.Schema.Relationships.Relations {
			// Constraints owned by this schema make it depend on the referenced
			// schema (self-references excluded).
			if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
				dep.Depends = append(dep.Depends, c.ReferenceSchema)
			}

			// HasOne/HasMany targets depend on this schema, not the other way round.
			if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
				beDependedOn[rel.FieldSchema] = true
			}

			if rel.JoinTable != nil {
				// append join value
				// Deferred: runs when this call returns, i.e. after valuesMap and
				// modelNames have been updated below, in LIFO order across relations.
				// NOTE(review): dep was already copied into valuesMap by then, so the
				// dep.Depends append here appears not to reach the stored record —
				// confirm whether that is intended.
				defer func(rel *schema.Relationship, joinValue interface{}) {
					if !beDependedOn[rel.FieldSchema] {
						dep.Depends = append(dep.Depends, rel.FieldSchema)
					} else {
						fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
						parseDependence(fieldValue, autoAdd)
					}
					// The join table is queued only when autoAdd is true.
					parseDependence(joinValue, autoAdd)
				}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
			}
		}

		valuesMap[dep.Schema.Table] = dep

		if addToList {
			modelNames = append(modelNames, dep.Schema.Table)
		}
	}

	// insertIntoOrderedList appends name to orderedModelNames once; with autoAdd
	// it first places (parsing if needed) every schema name depends on.
	insertIntoOrderedList = func(name string) {
		if _, ok := orderedModelNamesMap[name]; ok {
			return // avoid loop
		}
		orderedModelNamesMap[name] = true

		if autoAdd {
			dep := valuesMap[name]
			for _, d := range dep.Depends {
				if _, ok := valuesMap[d.Table]; ok {
					insertIntoOrderedList(d.Table)
				} else {
					parseDependence(reflect.New(d.ModelType).Interface(), autoAdd)
					insertIntoOrderedList(d.Table)
				}
			}
		}

		orderedModelNames = append(orderedModelNames, name)
	}

	// Strings pass straight through; models are parsed and queued.
	for _, value := range values {
		if v, ok := value.(string); ok {
			results = append(results, v)
		} else {
			parseDependence(value, true)
		}
	}

	for _, name := range modelNames {
		insertIntoOrderedList(name)
	}

	for _, name := range orderedModelNames {
		results = append(results, valuesMap[name].Statement.Dest)
	}
	return
}
  650. func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
  651. if stmt.TableExpr != nil {
  652. return *stmt.TableExpr
  653. }
  654. return clause.Table{Name: stmt.Table}
  655. }