scan.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. package gorm
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "reflect"
  6. "strings"
  7. "time"
  8. "gorm.io/gorm/schema"
  9. )
  10. func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
  11. if db.Statement.Schema != nil {
  12. for idx, name := range columns {
  13. if field := db.Statement.Schema.LookUpField(name); field != nil {
  14. values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
  15. continue
  16. }
  17. values[idx] = new(interface{})
  18. }
  19. } else if len(columnTypes) > 0 {
  20. for idx, columnType := range columnTypes {
  21. if columnType.ScanType() != nil {
  22. values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
  23. } else {
  24. values[idx] = new(interface{})
  25. }
  26. }
  27. } else {
  28. for idx := range columns {
  29. values[idx] = new(interface{})
  30. }
  31. }
  32. }
  // scanIntoMap stores one scanned row into mapValue, using columns[i] as the
// key for values[i].
//
// Each values[i] is dereferenced up to twice, since scan destinations are
// pointers such as **T or *interface{}. If a nil pointer is met along the
// way, the column is stored as nil.
//
// A resulting value that implements driver.Valuer is replaced by the result
// of its Value method; that error is discarded. An sql.RawBytes value is
// copied into a string, so the map does not keep a reference to
// driver-owned memory.
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
	for i, column := range columns {
		deref := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[i])))
		if !deref.IsValid() {
			mapValue[column] = nil
			continue
		}

		raw := deref.Interface()
		mapValue[column] = raw
		switch v := raw.(type) {
		case driver.Valuer:
			mapValue[column], _ = v.Value()
		case sql.RawBytes:
			mapValue[column] = string(v)
		}
	}
}
  47. func Scan(rows *sql.Rows, db *DB, initialized bool) {
  48. columns, _ := rows.Columns()
  49. values := make([]interface{}, len(columns))
  50. db.RowsAffected = 0
  51. switch dest := db.Statement.Dest.(type) {
  52. case map[string]interface{}, *map[string]interface{}:
  53. if initialized || rows.Next() {
  54. columnTypes, _ := rows.ColumnTypes()
  55. prepareValues(values, db, columnTypes, columns)
  56. db.RowsAffected++
  57. db.AddError(rows.Scan(values...))
  58. mapValue, ok := dest.(map[string]interface{})
  59. if !ok {
  60. if v, ok := dest.(*map[string]interface{}); ok {
  61. mapValue = *v
  62. }
  63. }
  64. scanIntoMap(mapValue, values, columns)
  65. }
  66. case *[]map[string]interface{}:
  67. columnTypes, _ := rows.ColumnTypes()
  68. for initialized || rows.Next() {
  69. prepareValues(values, db, columnTypes, columns)
  70. initialized = false
  71. db.RowsAffected++
  72. db.AddError(rows.Scan(values...))
  73. mapValue := map[string]interface{}{}
  74. scanIntoMap(mapValue, values, columns)
  75. *dest = append(*dest, mapValue)
  76. }
  77. case *int, *int8, *int16, *int32, *int64,
  78. *uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
  79. *float32, *float64,
  80. *bool, *string, *time.Time,
  81. *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
  82. *sql.NullBool, *sql.NullString, *sql.NullTime:
  83. for initialized || rows.Next() {
  84. initialized = false
  85. db.RowsAffected++
  86. db.AddError(rows.Scan(dest))
  87. }
  88. default:
  89. Schema := db.Statement.Schema
  90. switch db.Statement.ReflectValue.Kind() {
  91. case reflect.Slice, reflect.Array:
  92. var (
  93. reflectValueType = db.Statement.ReflectValue.Type().Elem()
  94. isPtr = reflectValueType.Kind() == reflect.Ptr
  95. fields = make([]*schema.Field, len(columns))
  96. joinFields [][2]*schema.Field
  97. )
  98. if isPtr {
  99. reflectValueType = reflectValueType.Elem()
  100. }
  101. db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20))
  102. if Schema != nil {
  103. if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
  104. Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
  105. }
  106. for idx, column := range columns {
  107. if field := Schema.LookUpField(column); field != nil && field.Readable {
  108. fields[idx] = field
  109. } else if names := strings.Split(column, "__"); len(names) > 1 {
  110. if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
  111. if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
  112. fields[idx] = field
  113. if len(joinFields) == 0 {
  114. joinFields = make([][2]*schema.Field, len(columns))
  115. }
  116. joinFields[idx] = [2]*schema.Field{rel.Field, field}
  117. continue
  118. }
  119. }
  120. values[idx] = &sql.RawBytes{}
  121. } else {
  122. values[idx] = &sql.RawBytes{}
  123. }
  124. }
  125. }
  126. // pluck values into slice of data
  127. isPluck := false
  128. if len(fields) == 1 {
  129. if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
  130. reflectValueType.Kind() != reflect.Struct || // is not struct
  131. Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
  132. isPluck = true
  133. }
  134. }
  135. for initialized || rows.Next() {
  136. initialized = false
  137. db.RowsAffected++
  138. elem := reflect.New(reflectValueType)
  139. if isPluck {
  140. db.AddError(rows.Scan(elem.Interface()))
  141. } else {
  142. for idx, field := range fields {
  143. if field != nil {
  144. values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
  145. }
  146. }
  147. db.AddError(rows.Scan(values...))
  148. for idx, field := range fields {
  149. if len(joinFields) != 0 && joinFields[idx][0] != nil {
  150. value := reflect.ValueOf(values[idx]).Elem()
  151. relValue := joinFields[idx][0].ReflectValueOf(elem)
  152. if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
  153. if value.IsNil() {
  154. continue
  155. }
  156. relValue.Set(reflect.New(relValue.Type().Elem()))
  157. }
  158. field.Set(relValue, values[idx])
  159. } else if field != nil {
  160. field.Set(elem, values[idx])
  161. }
  162. }
  163. }
  164. if isPtr {
  165. db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
  166. } else {
  167. db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem()))
  168. }
  169. }
  170. case reflect.Struct, reflect.Ptr:
  171. if db.Statement.ReflectValue.Type() != Schema.ModelType {
  172. Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
  173. }
  174. if initialized || rows.Next() {
  175. for idx, column := range columns {
  176. if field := Schema.LookUpField(column); field != nil && field.Readable {
  177. values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
  178. } else if names := strings.Split(column, "__"); len(names) > 1 {
  179. if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
  180. if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
  181. values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
  182. continue
  183. }
  184. }
  185. values[idx] = &sql.RawBytes{}
  186. } else if len(columns) == 1 {
  187. values[idx] = dest
  188. } else {
  189. values[idx] = &sql.RawBytes{}
  190. }
  191. }
  192. db.RowsAffected++
  193. db.AddError(rows.Scan(values...))
  194. for idx, column := range columns {
  195. if field := Schema.LookUpField(column); field != nil && field.Readable {
  196. field.Set(db.Statement.ReflectValue, values[idx])
  197. } else if names := strings.Split(column, "__"); len(names) > 1 {
  198. if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
  199. if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
  200. relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
  201. value := reflect.ValueOf(values[idx]).Elem()
  202. if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
  203. if value.IsNil() {
  204. continue
  205. }
  206. relValue.Set(reflect.New(relValue.Type().Elem()))
  207. }
  208. field.Set(relValue, values[idx])
  209. }
  210. }
  211. }
  212. }
  213. }
  214. default:
  215. db.AddError(rows.Scan(dest))
  216. }
  217. }
  218. if err := rows.Err(); err != nil && err != db.Error {
  219. db.AddError(err)
  220. }
  221. if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
  222. db.AddError(ErrRecordNotFound)
  223. }
  224. }