migrator.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. package mysql
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "gorm.io/gorm"
  6. "gorm.io/gorm/clause"
  7. "gorm.io/gorm/migrator"
  8. "gorm.io/gorm/schema"
  9. )
  10. type Migrator struct {
  11. migrator.Migrator
  12. Dialector
  13. }
  // Column holds one row of information_schema.columns as scanned by
  // ColumnTypes. The sql.Null* fields remain invalid when the database
  // returned NULL for the corresponding attribute.
  type Column struct {
  	name              string
  	nullable          sql.NullString
  	datatype          string
  	maxLen            sql.NullInt64
  	precision         sql.NullInt64
  	scale             sql.NullInt64
  	datetimePrecision sql.NullInt64
  }

  // Name returns the column name.
  func (c Column) Name() string { return c.name }

  // DatabaseTypeName returns the column's data type as reported by the database.
  func (c Column) DatabaseTypeName() string { return c.datatype }

  // Length returns the maximum character length and true, or (0, false) when
  // no length was reported for the column.
  func (c Column) Length() (int64, bool) {
  	if !c.maxLen.Valid {
  		return 0, false
  	}
  	return c.maxLen.Int64, true
  }

  // Nullable reports whether the column accepts NULL (is_nullable == "YES").
  // The second result is false when is_nullable itself was not available.
  func (c Column) Nullable() (bool, bool) {
  	if !c.nullable.Valid {
  		return false, false
  	}
  	return c.nullable.String == "YES", true
  }

  // DecimalSize returns (precision, scale, ok). Numeric precision takes
  // priority, with scale defaulting to 0 when it was not reported. Otherwise
  // datetime precision is returned with a scale of 0. ok is false when neither
  // precision is known.
  func (c Column) DecimalSize() (int64, int64, bool) {
  	switch {
  	case c.precision.Valid:
  		var scale int64
  		if c.scale.Valid {
  			scale = c.scale.Int64
  		}
  		return c.precision.Int64, scale, true
  	case c.datetimePrecision.Valid:
  		return c.datetimePrecision.Int64, 0, true
  	default:
  		return 0, 0, false
  	}
  }
  54. func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
  55. expr := m.Migrator.FullDataTypeOf(field)
  56. if value, ok := field.TagSettings["COMMENT"]; ok {
  57. expr.SQL += " COMMENT " + m.Dialector.Explain("?", value)
  58. }
  59. return expr
  60. }
  61. func (m Migrator) AlterColumn(value interface{}, field string) error {
  62. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  63. if field := stmt.Schema.LookUpField(field); field != nil {
  64. return m.DB.Exec(
  65. "ALTER TABLE ? MODIFY COLUMN ? ?",
  66. clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
  67. ).Error
  68. }
  69. return fmt.Errorf("failed to look up field with name: %s", field)
  70. })
  71. }
  72. func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
  73. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  74. if !m.Dialector.DontSupportRenameColumn {
  75. return m.Migrator.RenameColumn(value, oldName, newName)
  76. }
  77. var field *schema.Field
  78. if f := stmt.Schema.LookUpField(oldName); f != nil {
  79. oldName = f.DBName
  80. field = f
  81. }
  82. if f := stmt.Schema.LookUpField(newName); f != nil {
  83. newName = f.DBName
  84. field = f
  85. }
  86. if field != nil {
  87. return m.DB.Exec(
  88. "ALTER TABLE ? CHANGE ? ? ?",
  89. clause.Table{Name: stmt.Table}, clause.Column{Name: oldName},
  90. clause.Column{Name: newName}, m.FullDataTypeOf(field),
  91. ).Error
  92. }
  93. return fmt.Errorf("failed to look up field with name: %s", newName)
  94. })
  95. }
  96. func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
  97. if !m.Dialector.DontSupportRenameIndex {
  98. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  99. return m.DB.Exec(
  100. "ALTER TABLE ? RENAME INDEX ? TO ?",
  101. clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
  102. ).Error
  103. })
  104. }
  105. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  106. err := m.DropIndex(value, oldName)
  107. if err != nil {
  108. return err
  109. }
  110. if idx := stmt.Schema.LookIndex(newName); idx == nil {
  111. if idx = stmt.Schema.LookIndex(oldName); idx != nil {
  112. opts := m.BuildIndexOptions(idx.Fields, stmt)
  113. values := []interface{}{clause.Column{Name: newName}, clause.Table{Name: stmt.Table}, opts}
  114. createIndexSQL := "CREATE "
  115. if idx.Class != "" {
  116. createIndexSQL += idx.Class + " "
  117. }
  118. createIndexSQL += "INDEX ? ON ??"
  119. if idx.Type != "" {
  120. createIndexSQL += " USING " + idx.Type
  121. }
  122. return m.DB.Exec(createIndexSQL, values...).Error
  123. }
  124. }
  125. return m.CreateIndex(value, newName)
  126. })
  127. }
  128. func (m Migrator) DropTable(values ...interface{}) error {
  129. values = m.ReorderModels(values, false)
  130. tx := m.DB.Session(&gorm.Session{})
  131. tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
  132. for i := len(values) - 1; i >= 0; i-- {
  133. if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
  134. return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
  135. }); err != nil {
  136. return err
  137. }
  138. }
  139. tx.Exec("SET FOREIGN_KEY_CHECKS = 1;")
  140. return nil
  141. }
  142. func (m Migrator) DropConstraint(value interface{}, name string) error {
  143. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  144. constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
  145. if chk != nil {
  146. return m.DB.Exec("ALTER TABLE ? DROP CHECK ?", clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}).Error
  147. }
  148. if constraint != nil {
  149. name = constraint.Name
  150. }
  151. return m.DB.Exec(
  152. "ALTER TABLE ? DROP FOREIGN KEY ?", clause.Table{Name: table}, clause.Column{Name: name},
  153. ).Error
  154. })
  155. }
  156. // ColumnTypes column types return columnTypes,error
  157. func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
  158. columnTypes := make([]gorm.ColumnType, 0)
  159. err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
  160. var (
  161. currentDatabase = m.DB.Migrator().CurrentDatabase()
  162. columnTypeSQL = "SELECT column_name, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_scale "
  163. )
  164. if !m.DisableDatetimePrecision {
  165. columnTypeSQL += ", datetime_precision "
  166. }
  167. columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ?"
  168. columns, rowErr := m.DB.Raw(columnTypeSQL, currentDatabase, stmt.Table).Rows()
  169. if rowErr != nil {
  170. return rowErr
  171. }
  172. defer columns.Close()
  173. for columns.Next() {
  174. var column Column
  175. var values = []interface{}{&column.name, &column.nullable, &column.datatype,
  176. &column.maxLen, &column.precision, &column.scale}
  177. if !m.DisableDatetimePrecision {
  178. values = append(values, &column.datetimePrecision)
  179. }
  180. if scanErr := columns.Scan(values...); scanErr != nil {
  181. return scanErr
  182. }
  183. columnTypes = append(columnTypes, column)
  184. }
  185. return nil
  186. })
  187. return columnTypes, err
  188. }