migrator.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. package sqlite
  2. import (
  3. "fmt"
  4. "regexp"
  5. "strings"
  6. "gorm.io/gorm"
  7. "gorm.io/gorm/clause"
  8. "gorm.io/gorm/migrator"
  9. "gorm.io/gorm/schema"
  10. )
  11. type Migrator struct {
  12. migrator.Migrator
  13. }
  14. func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
  15. var enabled int
  16. m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
  17. if enabled == 1 {
  18. m.DB.Exec("PRAGMA foreign_keys = OFF")
  19. defer m.DB.Exec("PRAGMA foreign_keys = ON")
  20. }
  21. return fc()
  22. }
  23. func (m Migrator) HasTable(value interface{}) bool {
  24. var count int
  25. m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  26. return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
  27. })
  28. return count > 0
  29. }
  30. func (m Migrator) DropTable(values ...interface{}) error {
  31. return m.RunWithoutForeignKey(func() error {
  32. values = m.ReorderModels(values, false)
  33. tx := m.DB.Session(&gorm.Session{})
  34. for i := len(values) - 1; i >= 0; i-- {
  35. if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
  36. return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
  37. }); err != nil {
  38. return err
  39. }
  40. }
  41. return nil
  42. })
  43. }
  44. func (m Migrator) HasColumn(value interface{}, name string) bool {
  45. var count int
  46. m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  47. if stmt.Schema != nil {
  48. if field := stmt.Schema.LookUpField(name); field != nil {
  49. name = field.DBName
  50. }
  51. }
  52. if name != "" {
  53. m.DB.Raw(
  54. "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
  55. "table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%",
  56. ).Row().Scan(&count)
  57. }
  58. return nil
  59. })
  60. return count > 0
  61. }
  62. func (m Migrator) AlterColumn(value interface{}, name string) error {
  63. return m.RunWithoutForeignKey(func() error {
  64. return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
  65. if field := stmt.Schema.LookUpField(name); field != nil {
  66. reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?,")
  67. if err != nil {
  68. return "", nil, err
  69. }
  70. createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?,", field.DBName))
  71. return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
  72. } else {
  73. return "", nil, fmt.Errorf("failed to alter field with name %v", name)
  74. }
  75. })
  76. })
  77. }
  78. func (m Migrator) DropColumn(value interface{}, name string) error {
  79. return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
  80. if field := stmt.Schema.LookUpField(name); field != nil {
  81. name = field.DBName
  82. }
  83. reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,")
  84. if err != nil {
  85. return "", nil, err
  86. }
  87. createSQL := reg.ReplaceAllString(rawDDL, "")
  88. return createSQL, nil, nil
  89. })
  90. }
  91. func (m Migrator) CreateConstraint(value interface{}, name string) error {
  92. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  93. constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
  94. return m.recreateTable(value, &table,
  95. func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
  96. var (
  97. constraintName string
  98. constraintSql string
  99. constraintValues []interface{}
  100. )
  101. if constraint != nil {
  102. constraintName = constraint.Name
  103. constraintSql, constraintValues = buildConstraint(constraint)
  104. } else if chk != nil {
  105. constraintName = chk.Name
  106. constraintSql = "CONSTRAINT ? CHECK (?)"
  107. constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
  108. } else {
  109. return "", nil, nil
  110. }
  111. createDDL, err := parseDDL(rawDDL)
  112. if err != nil {
  113. return "", nil, err
  114. }
  115. createDDL.addConstraint(constraintName, constraintSql)
  116. createSQL := createDDL.compile()
  117. return createSQL, constraintValues, nil
  118. })
  119. })
  120. }
  121. func (m Migrator) DropConstraint(value interface{}, name string) error {
  122. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  123. constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
  124. if constraint != nil {
  125. name = constraint.Name
  126. } else if chk != nil {
  127. name = chk.Name
  128. }
  129. return m.recreateTable(value, &table,
  130. func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
  131. createDDL, err := parseDDL(rawDDL)
  132. if err != nil {
  133. return "", nil, err
  134. }
  135. createDDL.removeConstraint(name)
  136. createSQL := createDDL.compile()
  137. return createSQL, nil, nil
  138. })
  139. })
  140. }
  141. func (m Migrator) HasConstraint(value interface{}, name string) bool {
  142. var count int64
  143. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  144. constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
  145. if constraint != nil {
  146. name = constraint.Name
  147. } else if chk != nil {
  148. name = chk.Name
  149. }
  150. m.DB.Raw(
  151. "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
  152. "table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%",
  153. ).Row().Scan(&count)
  154. return nil
  155. })
  156. return count > 0
  157. }
  158. func (m Migrator) CurrentDatabase() (name string) {
  159. var null interface{}
  160. m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
  161. return
  162. }
  163. func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
  164. for _, opt := range opts {
  165. str := stmt.Quote(opt.DBName)
  166. if opt.Expression != "" {
  167. str = opt.Expression
  168. }
  169. if opt.Collate != "" {
  170. str += " COLLATE " + opt.Collate
  171. }
  172. if opt.Sort != "" {
  173. str += " " + opt.Sort
  174. }
  175. results = append(results, clause.Expr{SQL: str})
  176. }
  177. return
  178. }
  179. func (m Migrator) CreateIndex(value interface{}, name string) error {
  180. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  181. if idx := stmt.Schema.LookIndex(name); idx != nil {
  182. opts := m.BuildIndexOptions(idx.Fields, stmt)
  183. values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
  184. createIndexSQL := "CREATE "
  185. if idx.Class != "" {
  186. createIndexSQL += idx.Class + " "
  187. }
  188. createIndexSQL += "INDEX ?"
  189. if idx.Type != "" {
  190. createIndexSQL += " USING " + idx.Type
  191. }
  192. createIndexSQL += " ON ??"
  193. if idx.Where != "" {
  194. createIndexSQL += " WHERE " + idx.Where
  195. }
  196. return m.DB.Exec(createIndexSQL, values...).Error
  197. }
  198. return fmt.Errorf("failed to create index with name %v", name)
  199. })
  200. }
  201. func (m Migrator) HasIndex(value interface{}, name string) bool {
  202. var count int
  203. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  204. if idx := stmt.Schema.LookIndex(name); idx != nil {
  205. name = idx.Name
  206. }
  207. if name != "" {
  208. m.DB.Raw(
  209. "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
  210. ).Row().Scan(&count)
  211. }
  212. return nil
  213. })
  214. return count > 0
  215. }
  216. func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
  217. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  218. var sql string
  219. m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
  220. if sql != "" {
  221. return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
  222. }
  223. return fmt.Errorf("failed to find index with name %v", oldName)
  224. })
  225. }
  226. func (m Migrator) DropIndex(value interface{}, name string) error {
  227. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  228. if idx := stmt.Schema.LookIndex(name); idx != nil {
  229. name = idx.Name
  230. }
  231. return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
  232. })
  233. }
  234. func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
  235. sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
  236. if constraint.OnDelete != "" {
  237. sql += " ON DELETE " + constraint.OnDelete
  238. }
  239. if constraint.OnUpdate != "" {
  240. sql += " ON UPDATE " + constraint.OnUpdate
  241. }
  242. var foreignKeys, references []interface{}
  243. for _, field := range constraint.ForeignKeys {
  244. foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
  245. }
  246. for _, field := range constraint.References {
  247. references = append(references, clause.Column{Name: field.DBName})
  248. }
  249. results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
  250. return
  251. }
  252. func (m Migrator) getRawDDL(table string) (string, error) {
  253. var createSQL string
  254. m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
  255. if m.DB.Error != nil {
  256. return "", m.DB.Error
  257. }
  258. return createSQL, nil
  259. }
  260. func (m Migrator) recreateTable(value interface{}, tablePtr *string,
  261. getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
  262. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  263. table := stmt.Table
  264. if tablePtr != nil {
  265. table = *tablePtr
  266. }
  267. rawDDL, err := m.getRawDDL(table)
  268. if err != nil {
  269. return err
  270. }
  271. newTableName := table + "__temp"
  272. createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
  273. if err != nil {
  274. return err
  275. }
  276. if createSQL == "" {
  277. return nil
  278. }
  279. tableReg, err := regexp.Compile(" ('|`|\"| )" + table + "('|`|\"| ) ")
  280. if err != nil {
  281. return err
  282. }
  283. createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
  284. createDDL, err := parseDDL(createSQL)
  285. if err != nil {
  286. return err
  287. }
  288. columns := createDDL.getColumns()
  289. return m.DB.Transaction(func(tx *gorm.DB) error {
  290. queries := []string{
  291. createSQL,
  292. fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
  293. fmt.Sprintf("DROP TABLE `%v`", table),
  294. fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
  295. }
  296. for _, query := range queries {
  297. if err := tx.Exec(query, sqlArgs...).Error; err != nil {
  298. return err
  299. }
  300. }
  301. return nil
  302. })
  303. })
  304. }