migrator.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. package sqlite
  2. import (
  3. "fmt"
  4. "regexp"
  5. "strings"
  6. "gorm.io/gorm"
  7. "gorm.io/gorm/clause"
  8. "gorm.io/gorm/migrator"
  9. "gorm.io/gorm/schema"
  10. )
  11. type Migrator struct {
  12. migrator.Migrator
  13. }
  14. func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
  15. var enabled int
  16. m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
  17. if enabled == 1 {
  18. m.DB.Exec("PRAGMA foreign_keys = OFF")
  19. defer m.DB.Exec("PRAGMA foreign_keys = ON")
  20. }
  21. return fc()
  22. }
  23. func (m Migrator) HasTable(value interface{}) bool {
  24. var count int
  25. m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  26. return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
  27. })
  28. return count > 0
  29. }
  30. func (m Migrator) DropTable(values ...interface{}) error {
  31. return m.RunWithoutForeignKey(func() error {
  32. values = m.ReorderModels(values, false)
  33. tx := m.DB.Session(&gorm.Session{})
  34. for i := len(values) - 1; i >= 0; i-- {
  35. if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
  36. return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
  37. }); err != nil {
  38. return err
  39. }
  40. }
  41. return nil
  42. })
  43. return nil
  44. }
  45. func (m Migrator) HasColumn(value interface{}, name string) bool {
  46. var count int
  47. m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  48. if field := stmt.Schema.LookUpField(name); field != nil {
  49. name = field.DBName
  50. }
  51. if name != "" {
  52. m.DB.Raw(
  53. "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
  54. "table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%",
  55. ).Row().Scan(&count)
  56. }
  57. return nil
  58. })
  59. return count > 0
  60. }
  61. func (m Migrator) AlterColumn(value interface{}, name string) error {
  62. return m.RunWithoutForeignKey(func() error {
  63. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  64. if field := stmt.Schema.LookUpField(name); field != nil {
  65. var (
  66. createSQL string
  67. newTableName = stmt.Table + "__temp"
  68. )
  69. m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL)
  70. if reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?,"); err == nil {
  71. tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ")
  72. if err != nil {
  73. return err
  74. }
  75. createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
  76. createSQL = reg.ReplaceAllString(createSQL, fmt.Sprintf("`%v` ?,", field.DBName))
  77. var columns []string
  78. columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
  79. for _, columnType := range columnTypes {
  80. columns = append(columns, fmt.Sprintf("`%v`", columnType.Name()))
  81. }
  82. return m.DB.Transaction(func(tx *gorm.DB) error {
  83. queries := []string{
  84. createSQL,
  85. fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table),
  86. fmt.Sprintf("DROP TABLE `%v`", stmt.Table),
  87. fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, stmt.Table),
  88. }
  89. for _, query := range queries {
  90. if err := tx.Exec(query, m.FullDataTypeOf(field)).Error; err != nil {
  91. return err
  92. }
  93. }
  94. return nil
  95. })
  96. } else {
  97. return err
  98. }
  99. } else {
  100. return fmt.Errorf("failed to alter field with name %v", name)
  101. }
  102. })
  103. })
  104. }
  105. func (m Migrator) DropColumn(value interface{}, name string) error {
  106. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  107. if field := stmt.Schema.LookUpField(name); field != nil {
  108. name = field.DBName
  109. }
  110. var (
  111. createSQL string
  112. newTableName = stmt.Table + "__temp"
  113. )
  114. m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL)
  115. if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil {
  116. tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ")
  117. if err != nil {
  118. return err
  119. }
  120. createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
  121. createSQL = reg.ReplaceAllString(createSQL, "")
  122. var columns []string
  123. columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
  124. for _, columnType := range columnTypes {
  125. if columnType.Name() != name {
  126. columns = append(columns, fmt.Sprintf("`%v`", columnType.Name()))
  127. }
  128. }
  129. return m.DB.Transaction(func(tx *gorm.DB) error {
  130. queries := []string{
  131. createSQL,
  132. fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table),
  133. fmt.Sprintf("DROP TABLE `%v`", stmt.Table),
  134. fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, stmt.Table),
  135. }
  136. for _, query := range queries {
  137. if err := tx.Exec(query).Error; err != nil {
  138. return err
  139. }
  140. }
  141. return nil
  142. })
  143. } else {
  144. return err
  145. }
  146. })
  147. }
  148. func (m Migrator) CreateConstraint(interface{}, string) error {
  149. return ErrConstraintsNotImplemented
  150. }
  151. func (m Migrator) DropConstraint(interface{}, string) error {
  152. return ErrConstraintsNotImplemented
  153. }
  154. func (m Migrator) HasConstraint(value interface{}, name string) bool {
  155. var count int64
  156. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  157. m.DB.Raw(
  158. "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
  159. "table", stmt.Table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%",
  160. ).Row().Scan(&count)
  161. return nil
  162. })
  163. return count > 0
  164. }
  165. func (m Migrator) CurrentDatabase() (name string) {
  166. var null interface{}
  167. m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
  168. return
  169. }
  170. func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
  171. for _, opt := range opts {
  172. str := stmt.Quote(opt.DBName)
  173. if opt.Expression != "" {
  174. str = opt.Expression
  175. }
  176. if opt.Collate != "" {
  177. str += " COLLATE " + opt.Collate
  178. }
  179. if opt.Sort != "" {
  180. str += " " + opt.Sort
  181. }
  182. results = append(results, clause.Expr{SQL: str})
  183. }
  184. return
  185. }
  186. func (m Migrator) CreateIndex(value interface{}, name string) error {
  187. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  188. if idx := stmt.Schema.LookIndex(name); idx != nil {
  189. opts := m.BuildIndexOptions(idx.Fields, stmt)
  190. values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
  191. createIndexSQL := "CREATE "
  192. if idx.Class != "" {
  193. createIndexSQL += idx.Class + " "
  194. }
  195. createIndexSQL += "INDEX ?"
  196. if idx.Type != "" {
  197. createIndexSQL += " USING " + idx.Type
  198. }
  199. createIndexSQL += " ON ??"
  200. if idx.Where != "" {
  201. createIndexSQL += " WHERE " + idx.Where
  202. }
  203. return m.DB.Exec(createIndexSQL, values...).Error
  204. }
  205. return fmt.Errorf("failed to create index with name %v", name)
  206. })
  207. }
  208. func (m Migrator) HasIndex(value interface{}, name string) bool {
  209. var count int
  210. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  211. if idx := stmt.Schema.LookIndex(name); idx != nil {
  212. name = idx.Name
  213. }
  214. if name != "" {
  215. m.DB.Raw(
  216. "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
  217. ).Row().Scan(&count)
  218. }
  219. return nil
  220. })
  221. return count > 0
  222. }
  223. func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
  224. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  225. var sql string
  226. m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
  227. if sql != "" {
  228. return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
  229. }
  230. return fmt.Errorf("failed to find index with name %v", oldName)
  231. })
  232. }
  233. func (m Migrator) DropIndex(value interface{}, name string) error {
  234. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  235. if idx := stmt.Schema.LookIndex(name); idx != nil {
  236. name = idx.Name
  237. }
  238. return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
  239. })
  240. }