mysql.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. package mysql
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "math"
  7. "strings"
  8. "time"
  9. _ "github.com/go-sql-driver/mysql"
  10. "gorm.io/gorm"
  11. "gorm.io/gorm/callbacks"
  12. "gorm.io/gorm/clause"
  13. "gorm.io/gorm/logger"
  14. "gorm.io/gorm/migrator"
  15. "gorm.io/gorm/schema"
  16. )
// Config holds the settings for the MySQL dialector. Initialize fills in some
// defaults and, unless SkipInitializeWithVersion is set, may turn on the
// feature flags below after probing the server version.
type Config struct {
	// DriverName is the database/sql driver name passed to sql.Open;
	// Initialize sets it to "mysql" when empty.
	DriverName string
	// DSN is the data source name passed to sql.Open when Conn is nil.
	DSN string
	// Conn, when non-nil, is used directly as the connection pool and no new
	// pool is opened.
	Conn gorm.ConnPool
	// SkipInitializeWithVersion skips the SELECT VERSION() probe in
	// Initialize, leaving the Dont*/DisableDatetimePrecision flags as given.
	SkipInitializeWithVersion bool
	// DefaultStringSize, when non-zero, is the varchar length used for string
	// fields that have no explicit size.
	DefaultStringSize uint
	// DefaultDatetimePrecision is the fractional-second precision used for
	// datetime columns without an explicit precision and for NowFunc rounding.
	// When nil, Apply/Initialize point it at the package default (3).
	DefaultDatetimePrecision *int
	// DisableDatetimePrecision stops datetime columns from receiving
	// DefaultDatetimePrecision. Initialize sets it for "5." servers other
	// than 5.6 and 5.7.
	DisableDatetimePrecision bool
	// DontSupportRenameIndex is set by Initialize for MariaDB and for
	// "5." servers other than 5.7; presumably consulted by the Migrator
	// (defined elsewhere) — confirm.
	DontSupportRenameIndex bool
	// DontSupportRenameColumn is set by Initialize for MariaDB and all "5."
	// servers; presumably consulted by the Migrator — confirm.
	DontSupportRenameColumn bool
	// DontSupportForShareClause makes FOR SHARE locking render as
	// LOCK IN SHARE MODE in the clause builders.
	DontSupportForShareClause bool
}
// Dialector is the MySQL gorm.Dialector returned by Open and New. It embeds
// *Config, so configuration fields are accessed directly, and assignments made
// by its value-receiver methods land in the shared Config they point to.
type Dialector struct {
	*Config
}
  32. var (
  33. // UpdateClauses update clauses setting
  34. UpdateClauses = []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"}
  35. // DeleteClauses delete clauses setting
  36. DeleteClauses = []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"}
  37. defaultDatetimePrecision = 3
  38. )
  39. func Open(dsn string) gorm.Dialector {
  40. return &Dialector{Config: &Config{DSN: dsn}}
  41. }
  42. func New(config Config) gorm.Dialector {
  43. return &Dialector{Config: &config}
  44. }
  45. func (dialector Dialector) Name() string {
  46. return "mysql"
  47. }
  48. // NowFunc return now func
  49. func (dialector Dialector) NowFunc(n int) func() time.Time {
  50. return func() time.Time {
  51. round := time.Second / time.Duration(math.Pow10(n))
  52. return time.Now().Local().Round(round)
  53. }
  54. }
  55. func (dialector Dialector) Apply(config *gorm.Config) error {
  56. if config.NowFunc == nil {
  57. if dialector.DefaultDatetimePrecision == nil {
  58. dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
  59. }
  60. // while maintaining the readability of the code, separate the business logic from
  61. // the general part and leave it to the function to do it here.
  62. config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
  63. }
  64. return nil
  65. }
  66. func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
  67. ctx := context.Background()
  68. // register callbacks
  69. callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
  70. UpdateClauses: UpdateClauses,
  71. DeleteClauses: DeleteClauses,
  72. })
  73. if dialector.DriverName == "" {
  74. dialector.DriverName = "mysql"
  75. }
  76. if dialector.DefaultDatetimePrecision == nil {
  77. dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
  78. }
  79. if dialector.Conn != nil {
  80. db.ConnPool = dialector.Conn
  81. } else {
  82. db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN)
  83. if err != nil {
  84. return err
  85. }
  86. }
  87. if !dialector.Config.SkipInitializeWithVersion {
  88. var version string
  89. err = db.ConnPool.QueryRowContext(ctx, "SELECT VERSION()").Scan(&version)
  90. if err != nil {
  91. return err
  92. }
  93. if strings.Contains(version, "MariaDB") {
  94. dialector.Config.DontSupportRenameIndex = true
  95. dialector.Config.DontSupportRenameColumn = true
  96. dialector.Config.DontSupportForShareClause = true
  97. } else if strings.HasPrefix(version, "5.6.") {
  98. dialector.Config.DontSupportRenameIndex = true
  99. dialector.Config.DontSupportRenameColumn = true
  100. dialector.Config.DontSupportForShareClause = true
  101. } else if strings.HasPrefix(version, "5.7.") {
  102. dialector.Config.DontSupportRenameColumn = true
  103. dialector.Config.DontSupportForShareClause = true
  104. } else if strings.HasPrefix(version, "5.") {
  105. dialector.Config.DisableDatetimePrecision = true
  106. dialector.Config.DontSupportRenameIndex = true
  107. dialector.Config.DontSupportRenameColumn = true
  108. dialector.Config.DontSupportForShareClause = true
  109. }
  110. }
  111. for k, v := range dialector.ClauseBuilders() {
  112. db.ClauseBuilders[k] = v
  113. }
  114. return
  115. }
  116. const (
  117. // ClauseOnConflict for clause.ClauseBuilder ON CONFLICT key
  118. ClauseOnConflict = "ON CONFLICT"
  119. // ClauseValues for clause.ClauseBuilder VALUES key
  120. ClauseValues = "VALUES"
  121. // ClauseValues for clause.ClauseBuilder FOR key
  122. ClauseFor = "FOR"
  123. )
  124. func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
  125. clauseBuilders := map[string]clause.ClauseBuilder{
  126. ClauseOnConflict: func(c clause.Clause, builder clause.Builder) {
  127. onConflict, ok := c.Expression.(clause.OnConflict)
  128. if !ok {
  129. c.Build(builder)
  130. return
  131. }
  132. builder.WriteString("ON DUPLICATE KEY UPDATE ")
  133. if len(onConflict.DoUpdates) == 0 {
  134. if s := builder.(*gorm.Statement).Schema; s != nil {
  135. var column clause.Column
  136. onConflict.DoNothing = false
  137. if s.PrioritizedPrimaryField != nil {
  138. column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
  139. } else if len(s.DBNames) > 0 {
  140. column = clause.Column{Name: s.DBNames[0]}
  141. }
  142. if column.Name != "" {
  143. onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
  144. }
  145. }
  146. }
  147. for idx, assignment := range onConflict.DoUpdates {
  148. if idx > 0 {
  149. builder.WriteByte(',')
  150. }
  151. builder.WriteQuoted(assignment.Column)
  152. builder.WriteByte('=')
  153. if column, ok := assignment.Value.(clause.Column); ok && column.Table == "excluded" {
  154. column.Table = ""
  155. builder.WriteString("VALUES(")
  156. builder.WriteQuoted(column)
  157. builder.WriteByte(')')
  158. } else {
  159. builder.AddVar(builder, assignment.Value)
  160. }
  161. }
  162. },
  163. ClauseValues: func(c clause.Clause, builder clause.Builder) {
  164. if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 {
  165. builder.WriteString("VALUES()")
  166. return
  167. }
  168. c.Build(builder)
  169. },
  170. }
  171. if dialector.Config.DontSupportForShareClause {
  172. clauseBuilders[ClauseFor] = func(c clause.Clause, builder clause.Builder) {
  173. if values, ok := c.Expression.(clause.Locking); ok && strings.EqualFold(values.Strength, "SHARE") {
  174. builder.WriteString("LOCK IN SHARE MODE")
  175. return
  176. }
  177. c.Build(builder)
  178. }
  179. }
  180. return clauseBuilders
  181. }
  182. func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
  183. return clause.Expr{SQL: "DEFAULT"}
  184. }
  185. func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
  186. return Migrator{
  187. Migrator: migrator.Migrator{
  188. Config: migrator.Config{
  189. DB: db,
  190. Dialector: dialector,
  191. },
  192. },
  193. Dialector: dialector,
  194. }
  195. }
  196. func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
  197. writer.WriteByte('?')
  198. }
// QuoteTo writes str to writer as a backtick-quoted MySQL identifier.
//
// Each '.'-separated part is quoted on its own (users.name -> `users`.`name`).
// A part that already starts with a backtick is treated as self-quoted and is
// not wrapped again (`users`.name -> `users`.`name`). A lone backtick inside a
// part is escaped by doubling (a`b -> `a``b`).
//
// The scan is a byte-level state machine whose output depends on exact
// statement order; unusual inputs (e.g. leading doubled backticks) may not
// round-trip cleanly — NOTE(review): confirm before relying on them.
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
	var (
		// underQuoted: an opening backtick has been written for the current part.
		// selfQuoted: the current part began with a backtick in the input.
		underQuoted, selfQuoted bool
		// continuousBacktick: input backticks seen in a row and not yet written.
		continuousBacktick int8
		// shiftDelimiter: bytes consumed since the start or the last emitted '.'.
		// NOTE(review): int8 wraps past 127 bytes in one part; it appears harmless
		// once underQuoted is set — confirm.
		shiftDelimiter int8
	)
	for _, v := range []byte(str) {
		switch v {
		case '`':
			continuousBacktick++
			// Two backticks in a row are written out as an escaped pair.
			if continuousBacktick == 2 {
				writer.WriteString("``")
				continuousBacktick = 0
			}
		case '.':
			// Close the current part (unless it is self-quoted and its closing
			// backtick has not been seen yet) and reset state for the next one.
			if continuousBacktick > 0 || !selfQuoted {
				shiftDelimiter = 0
				underQuoted = false
				continuousBacktick = 0
				writer.WriteString("`")
			}
			writer.WriteByte(v)
			// Skip the shiftDelimiter++ below so the new part starts at 0.
			continue
		default:
			// First ordinary byte of a part (only backticks may precede it):
			// open the quote. A single leading backtick is consumed as the
			// opening quote and marks the part self-quoted.
			if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
				writer.WriteByte('`')
				underQuoted = true
				if selfQuoted = continuousBacktick > 0; selfQuoted {
					continuousBacktick -= 1
				}
			}
			// Any pending backticks inside the part are escaped by doubling.
			for ; continuousBacktick > 0; continuousBacktick -= 1 {
				writer.WriteString("``")
			}
			writer.WriteByte(v)
		}
		shiftDelimiter++
	}
	// A trailing unmatched backtick in a part that was not self-quoted is
	// escaped; in a self-quoted part it is the closing quote written below.
	if continuousBacktick > 0 && !selfQuoted {
		writer.WriteString("``")
	}
	writer.WriteString("`")
}
  242. func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
  243. return logger.ExplainSQL(sql, nil, `'`, vars...)
  244. }
  245. func (dialector Dialector) DataTypeOf(field *schema.Field) string {
  246. switch field.DataType {
  247. case schema.Bool:
  248. return "boolean"
  249. case schema.Int, schema.Uint:
  250. return dialector.getSchemaIntAndUnitType(field)
  251. case schema.Float:
  252. return dialector.getSchemaFloatType(field)
  253. case schema.String:
  254. return dialector.getSchemaStringType(field)
  255. case schema.Time:
  256. return dialector.getSchemaTimeType(field)
  257. case schema.Bytes:
  258. return dialector.getSchemaBytesType(field)
  259. }
  260. return string(field.DataType)
  261. }
  262. func (dialector Dialector) getSchemaFloatType(field *schema.Field) string {
  263. if field.Precision > 0 {
  264. return fmt.Sprintf("decimal(%d, %d)", field.Precision, field.Scale)
  265. }
  266. if field.Size <= 32 {
  267. return "float"
  268. }
  269. return "double"
  270. }
  271. func (dialector Dialector) getSchemaStringType(field *schema.Field) string {
  272. size := field.Size
  273. if size == 0 {
  274. if dialector.DefaultStringSize > 0 {
  275. size = int(dialector.DefaultStringSize)
  276. } else {
  277. hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
  278. // TEXT, GEOMETRY or JSON column can't have a default value
  279. if field.PrimaryKey || field.HasDefaultValue || hasIndex {
  280. size = 191 // utf8mb4
  281. }
  282. }
  283. }
  284. if size >= 65536 && size <= int(math.Pow(2, 24)) {
  285. return "mediumtext"
  286. }
  287. if size > int(math.Pow(2, 24)) || size <= 0 {
  288. return "longtext"
  289. }
  290. return fmt.Sprintf("varchar(%d)", size)
  291. }
  292. func (dialector Dialector) getSchemaTimeType(field *schema.Field) string {
  293. precision := ""
  294. if !dialector.DisableDatetimePrecision && field.Precision == 0 {
  295. field.Precision = *dialector.DefaultDatetimePrecision
  296. }
  297. if field.Precision > 0 {
  298. precision = fmt.Sprintf("(%d)", field.Precision)
  299. }
  300. if field.NotNull || field.PrimaryKey {
  301. return "datetime" + precision
  302. }
  303. return "datetime" + precision + " NULL"
  304. }
  305. func (dialector Dialector) getSchemaBytesType(field *schema.Field) string {
  306. if field.Size > 0 && field.Size < 65536 {
  307. return fmt.Sprintf("varbinary(%d)", field.Size)
  308. }
  309. if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) {
  310. return "mediumblob"
  311. }
  312. return "longblob"
  313. }
  314. func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string {
  315. sqlType := "bigint"
  316. switch {
  317. case field.Size <= 8:
  318. sqlType = "tinyint"
  319. case field.Size <= 16:
  320. sqlType = "smallint"
  321. case field.Size <= 24:
  322. sqlType = "mediumint"
  323. case field.Size <= 32:
  324. sqlType = "int"
  325. }
  326. if field.DataType == schema.Uint {
  327. sqlType += " unsigned"
  328. }
  329. if field.AutoIncrement {
  330. sqlType += " AUTO_INCREMENT"
  331. }
  332. return sqlType
  333. }
  334. func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
  335. tx.Exec("SAVEPOINT " + name)
  336. return nil
  337. }
  338. func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
  339. tx.Exec("ROLLBACK TO SAVEPOINT " + name)
  340. return nil
  341. }