sql.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. package logger
import (
	"database/sql/driver"
	"fmt"
	"reflect"
	"regexp"
	"strconv"
	"strings"
	"time"
	"unicode"
	"unicode/utf8"

	"gorm.io/gorm/utils"
)
  13. const (
  14. tmFmtWithMS = "2006-01-02 15:04:05.999"
  15. tmFmtZero = "0000-00-00 00:00:00"
  16. nullStr = "NULL"
  17. )
  18. func isPrintable(s []byte) bool {
  19. for _, r := range s {
  20. if !unicode.IsPrint(rune(r)) {
  21. return false
  22. }
  23. }
  24. return true
  25. }
  26. var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
  27. func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
  28. var convertParams func(interface{}, int)
  29. var vars = make([]string, len(avars))
  30. convertParams = func(v interface{}, idx int) {
  31. switch v := v.(type) {
  32. case bool:
  33. vars[idx] = strconv.FormatBool(v)
  34. case time.Time:
  35. if v.IsZero() {
  36. vars[idx] = escaper + tmFmtZero + escaper
  37. } else {
  38. vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
  39. }
  40. case *time.Time:
  41. if v != nil {
  42. if v.IsZero() {
  43. vars[idx] = escaper + tmFmtZero + escaper
  44. } else {
  45. vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
  46. }
  47. } else {
  48. vars[idx] = nullStr
  49. }
  50. case driver.Valuer:
  51. reflectValue := reflect.ValueOf(v)
  52. if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
  53. r, _ := v.Value()
  54. convertParams(r, idx)
  55. } else {
  56. vars[idx] = nullStr
  57. }
  58. case fmt.Stringer:
  59. reflectValue := reflect.ValueOf(v)
  60. if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
  61. vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
  62. } else {
  63. vars[idx] = nullStr
  64. }
  65. case []byte:
  66. if isPrintable(v) {
  67. vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
  68. } else {
  69. vars[idx] = escaper + "<binary>" + escaper
  70. }
  71. case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
  72. vars[idx] = utils.ToString(v)
  73. case float64, float32:
  74. vars[idx] = fmt.Sprintf("%.6f", v)
  75. case string:
  76. vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
  77. default:
  78. rv := reflect.ValueOf(v)
  79. if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
  80. vars[idx] = nullStr
  81. } else if valuer, ok := v.(driver.Valuer); ok {
  82. v, _ = valuer.Value()
  83. convertParams(v, idx)
  84. } else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
  85. convertParams(reflect.Indirect(rv).Interface(), idx)
  86. } else {
  87. for _, t := range convertibleTypes {
  88. if rv.Type().ConvertibleTo(t) {
  89. convertParams(rv.Convert(t).Interface(), idx)
  90. return
  91. }
  92. }
  93. vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
  94. }
  95. }
  96. }
  97. for idx, v := range avars {
  98. convertParams(v, idx)
  99. }
  100. if numericPlaceholder == nil {
  101. var idx int
  102. var newSQL strings.Builder
  103. for _, v := range []byte(sql) {
  104. if v == '?' {
  105. if len(vars) > idx {
  106. newSQL.WriteString(vars[idx])
  107. idx++
  108. continue
  109. }
  110. }
  111. newSQL.WriteByte(v)
  112. }
  113. sql = newSQL.String()
  114. } else {
  115. sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
  116. for idx, v := range vars {
  117. sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
  118. }
  119. }
  120. return sql
  121. }