sqlite.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. package sqlite
  2. import (
  3. "database/sql"
  4. "strconv"
  5. "strings"
  6. _ "github.com/mattn/go-sqlite3"
  7. "gorm.io/gorm"
  8. "gorm.io/gorm/callbacks"
  9. "gorm.io/gorm/clause"
  10. "gorm.io/gorm/logger"
  11. "gorm.io/gorm/migrator"
  12. "gorm.io/gorm/schema"
  13. )
  // DriverName is the database/sql driver name that Initialize hands to
  // sql.Open whenever Dialector.DriverName is left empty. It is expected to
  // match the driver registered by the blank-imported
  // github.com/mattn/go-sqlite3 package.
  const (
  	DriverName = "sqlite3"
  )
  16. type Dialector struct {
  17. DriverName string
  18. DSN string
  19. Conn gorm.ConnPool
  20. }
  21. func Open(dsn string) gorm.Dialector {
  22. return &Dialector{DSN: dsn}
  23. }
  24. func (dialector Dialector) Name() string {
  25. return "sqlite"
  26. }
  // Initialize wires the dialector into db. It does three things, in order:
  // it registers GORM's default callbacks, installs a connection pool, and
  // merges the SQLite clause builders from ClauseBuilders into
  // db.ClauseBuilders.
  //
  // If Conn is set, it is used as the pool as-is. Otherwise a pool is opened
  // with sql.Open(DriverName, DSN), where an empty DriverName falls back to
  // the package-level DriverName constant. Any error from sql.Open is
  // returned unchanged, and in that case db.ConnPool is left untouched.
  func (dialector Dialector) Initialize(db *gorm.DB) error {
  	driverName := dialector.DriverName
  	if driverName == "" {
  		driverName = DriverName
  	}

  	// NOTE(review): LastInsertIDReversed appears to tell the create callback
  	// that SQLite's LastInsertId refers to the last row of a batch insert;
  	// confirm against gorm's callbacks package.
  	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
  		LastInsertIDReversed: true,
  	})

  	if dialector.Conn != nil {
  		db.ConnPool = dialector.Conn
  	} else {
  		// Assign only on success. On failure sql.Open returns a nil *sql.DB,
  		// and storing that in the gorm.ConnPool interface would leave a
  		// non-nil interface that wraps a nil pointer.
  		sqlDB, err := sql.Open(driverName, dialector.DSN)
  		if err != nil {
  			return err
  		}
  		db.ConnPool = sqlDB
  	}

  	for name, builder := range dialector.ClauseBuilders() {
  		db.ClauseBuilders[name] = builder
  	}
  	return nil
  }
  48. func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
  49. return map[string]clause.ClauseBuilder{
  50. "INSERT": func(c clause.Clause, builder clause.Builder) {
  51. if insert, ok := c.Expression.(clause.Insert); ok {
  52. if stmt, ok := builder.(*gorm.Statement); ok {
  53. stmt.WriteString("INSERT ")
  54. if insert.Modifier != "" {
  55. stmt.WriteString(insert.Modifier)
  56. stmt.WriteByte(' ')
  57. }
  58. stmt.WriteString("INTO ")
  59. if insert.Table.Name == "" {
  60. stmt.WriteQuoted(stmt.Table)
  61. } else {
  62. stmt.WriteQuoted(insert.Table)
  63. }
  64. return
  65. }
  66. }
  67. c.Build(builder)
  68. },
  69. "LIMIT": func(c clause.Clause, builder clause.Builder) {
  70. if limit, ok := c.Expression.(clause.Limit); ok {
  71. if limit.Limit > 0 || limit.Offset > 0 {
  72. if limit.Limit <= 0 {
  73. limit.Limit = -1
  74. }
  75. builder.WriteString("LIMIT " + strconv.Itoa(limit.Limit))
  76. }
  77. if limit.Offset > 0 {
  78. builder.WriteString(" OFFSET " + strconv.Itoa(limit.Offset))
  79. }
  80. }
  81. },
  82. "FOR": func(c clause.Clause, builder clause.Builder) {
  83. if _, ok := c.Expression.(clause.Locking); ok {
  84. // SQLite3 does not support row-level locking.
  85. return
  86. }
  87. c.Build(builder)
  88. },
  89. }
  90. }
  91. func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
  92. if field.AutoIncrement {
  93. return clause.Expr{SQL: "NULL"}
  94. }
  95. // doesn't work, will raise error
  96. return clause.Expr{SQL: "DEFAULT"}
  97. }
  98. func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
  99. return Migrator{migrator.Migrator{Config: migrator.Config{
  100. DB: db,
  101. Dialector: dialector,
  102. CreateIndexAfterCreateTable: true,
  103. }}}
  104. }
  105. func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
  106. writer.WriteByte('?')
  107. }
  // QuoteTo writes str to writer as a backtick-quoted SQLite identifier. A
  // dotted name is split on "." and each part is quoted on its own, so
  // "users.name" becomes `users`.`name`.
  //
  // A backtick inside a part is doubled. SQLite, like MySQL, reads "``"
  // inside a backtick-quoted identifier as a literal backtick. Without this,
  // such a name would end the quoted identifier early, yielding malformed
  // SQL and allowing identifier injection.
  func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
  	for idx, part := range strings.Split(str, ".") {
  		if idx > 0 {
  			writer.WriteByte('.')
  		}
  		writer.WriteByte('`')
  		writer.WriteString(strings.ReplaceAll(part, "`", "``"))
  		writer.WriteByte('`')
  	}
  }
  123. func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
  124. return logger.ExplainSQL(sql, nil, `"`, vars...)
  125. }
  126. func (dialector Dialector) DataTypeOf(field *schema.Field) string {
  127. switch field.DataType {
  128. case schema.Bool:
  129. return "numeric"
  130. case schema.Int, schema.Uint:
  131. if field.AutoIncrement && !field.PrimaryKey {
  132. // https://www.sqlite.org/autoinc.html
  133. return "integer PRIMARY KEY AUTOINCREMENT"
  134. } else {
  135. return "integer"
  136. }
  137. case schema.Float:
  138. return "real"
  139. case schema.String:
  140. return "text"
  141. case schema.Time:
  142. return "datetime"
  143. case schema.Bytes:
  144. return "blob"
  145. }
  146. return string(field.DataType)
  147. }
  148. func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
  149. tx.Exec("SAVEPOINT " + name)
  150. return nil
  151. }
  152. func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
  153. tx.Exec("ROLLBACK TO SAVEPOINT " + name)
  154. return nil
  155. }