schema.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. package schema
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "go/ast"
  7. "reflect"
  8. "sync"
  9. "gorm.io/gorm/clause"
  10. "gorm.io/gorm/logger"
  11. )
  12. // ErrUnsupportedDataType unsupported data type
  13. var ErrUnsupportedDataType = errors.New("unsupported data type")
// Schema describes a parsed model struct: its table, columns and primary
// keys, relationships, extra clauses contributed by field types, and which
// lifecycle hook methods the model defines. Schemas are built by Parse and
// shared through the cache store passed to it.
type Schema struct {
	// Name is the Go type name of the model, ModelType is the struct type
	// Parse reached after unwrapping pointers, slices and arrays, and Table
	// is the table name resolved for it.
	Name      string
	ModelType reflect.Type
	Table     string

	// PrioritizedPrimaryField is the primary key Parse prefers: the "id"/"ID"
	// field when it is (or is promoted to) a primary key, otherwise the only
	// primary field when there is exactly one; it may be nil. DBNames lists
	// each column name once, in first-seen order, and FieldsByDBName maps
	// each of those names to the field that won the column. Fields holds
	// every parsed field, with embedded structs flattened in.
	PrioritizedPrimaryField  *Field
	DBNames                  []string
	PrimaryFields            []*Field
	PrimaryFieldDBNames      []string
	Fields                   []*Field
	FieldsByName             map[string]*Field
	FieldsByDBName           map[string]*Field
	FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
	Relationships            Relationships

	// Extra clauses collected by Parse from fields whose type implements
	// Create/Query/Update/DeleteClausesInterface (Parse appends them through
	// field.Schema, presumably the owning schema).
	CreateClauses []clause.Interface
	QueryClauses  []clause.Interface
	UpdateClauses []clause.Interface
	DeleteClauses []clause.Interface

	// Hook flags: Parse sets each one to true when the model has a method of
	// the same name whose type is func(*gorm.DB) error.
	BeforeCreate, AfterCreate bool
	BeforeUpdate, AfterUpdate bool
	BeforeDelete, AfterDelete bool
	BeforeSave, AfterSave     bool
	AfterFind                 bool

	// err is the parse error handed back to callers (Parse checks it after
	// relationship parsing); initialized is closed when Parse is done with
	// this schema so concurrent callers can wait on it; namer and cacheStore
	// are the arguments Parse was called with.
	err         error
	initialized chan struct{}
	namer       Namer
	cacheStore  *sync.Map
}
  41. func (schema Schema) String() string {
  42. if schema.ModelType.Name() == "" {
  43. return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
  44. }
  45. return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
  46. }
  47. func (schema Schema) MakeSlice() reflect.Value {
  48. slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
  49. results := reflect.New(slice.Type())
  50. results.Elem().Set(slice)
  51. return results
  52. }
  53. func (schema Schema) LookUpField(name string) *Field {
  54. if field, ok := schema.FieldsByDBName[name]; ok {
  55. return field
  56. }
  57. if field, ok := schema.FieldsByName[name]; ok {
  58. return field
  59. }
  60. return nil
  61. }
// Tabler is implemented by models that choose their own table name. When
// Parse finds that a pointer to the model implements Tabler, the value
// returned by TableName replaces the name produced by the Namer (an embedded
// namer, if one is in use, still takes precedence over both).
type Tabler interface {
	TableName() string
}
  65. // Parse get data type from dialector
  66. func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
  67. if dest == nil {
  68. return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
  69. }
  70. modelType := reflect.ValueOf(dest).Type()
  71. for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
  72. modelType = modelType.Elem()
  73. }
  74. if modelType.Kind() != reflect.Struct {
  75. if modelType.PkgPath() == "" {
  76. return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
  77. }
  78. return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
  79. }
  80. if v, ok := cacheStore.Load(modelType); ok {
  81. s := v.(*Schema)
  82. // Wait for the initialization of other goroutines to complete
  83. <-s.initialized
  84. return s, s.err
  85. }
  86. modelValue := reflect.New(modelType)
  87. tableName := namer.TableName(modelType.Name())
  88. if tabler, ok := modelValue.Interface().(Tabler); ok {
  89. tableName = tabler.TableName()
  90. }
  91. if en, ok := namer.(embeddedNamer); ok {
  92. tableName = en.Table
  93. }
  94. schema := &Schema{
  95. Name: modelType.Name(),
  96. ModelType: modelType,
  97. Table: tableName,
  98. FieldsByName: map[string]*Field{},
  99. FieldsByDBName: map[string]*Field{},
  100. Relationships: Relationships{Relations: map[string]*Relationship{}},
  101. cacheStore: cacheStore,
  102. namer: namer,
  103. initialized: make(chan struct{}),
  104. }
  105. // When the schema initialization is completed, the channel will be closed
  106. defer close(schema.initialized)
  107. if v, loaded := cacheStore.Load(modelType); loaded {
  108. s := v.(*Schema)
  109. // Wait for the initialization of other goroutines to complete
  110. <-s.initialized
  111. return s, s.err
  112. }
  113. for i := 0; i < modelType.NumField(); i++ {
  114. if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
  115. if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
  116. schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
  117. } else {
  118. schema.Fields = append(schema.Fields, field)
  119. }
  120. }
  121. }
  122. for _, field := range schema.Fields {
  123. if field.DBName == "" && field.DataType != "" {
  124. field.DBName = namer.ColumnName(schema.Table, field.Name)
  125. }
  126. if field.DBName != "" {
  127. // nonexistence or shortest path or first appear prioritized if has permission
  128. if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
  129. if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
  130. schema.DBNames = append(schema.DBNames, field.DBName)
  131. }
  132. schema.FieldsByDBName[field.DBName] = field
  133. schema.FieldsByName[field.Name] = field
  134. if v != nil && v.PrimaryKey {
  135. for idx, f := range schema.PrimaryFields {
  136. if f == v {
  137. schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
  138. }
  139. }
  140. }
  141. if field.PrimaryKey {
  142. schema.PrimaryFields = append(schema.PrimaryFields, field)
  143. }
  144. }
  145. }
  146. if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
  147. schema.FieldsByName[field.Name] = field
  148. }
  149. field.setupValuerAndSetter()
  150. }
  151. prioritizedPrimaryField := schema.LookUpField("id")
  152. if prioritizedPrimaryField == nil {
  153. prioritizedPrimaryField = schema.LookUpField("ID")
  154. }
  155. if prioritizedPrimaryField != nil {
  156. if prioritizedPrimaryField.PrimaryKey {
  157. schema.PrioritizedPrimaryField = prioritizedPrimaryField
  158. } else if len(schema.PrimaryFields) == 0 {
  159. prioritizedPrimaryField.PrimaryKey = true
  160. schema.PrioritizedPrimaryField = prioritizedPrimaryField
  161. schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
  162. }
  163. }
  164. if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
  165. schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
  166. }
  167. for _, field := range schema.PrimaryFields {
  168. schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
  169. }
  170. for _, field := range schema.FieldsByDBName {
  171. if field.HasDefaultValue && field.DefaultValueInterface == nil {
  172. schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
  173. }
  174. }
  175. if field := schema.PrioritizedPrimaryField; field != nil {
  176. switch field.GORMDataType {
  177. case Int, Uint:
  178. if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
  179. if !field.HasDefaultValue || field.DefaultValueInterface != nil {
  180. schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
  181. }
  182. field.HasDefaultValue = true
  183. field.AutoIncrement = true
  184. }
  185. }
  186. }
  187. callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
  188. for _, name := range callbacks {
  189. if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
  190. switch methodValue.Type().String() {
  191. case "func(*gorm.DB) error": // TODO hack
  192. reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
  193. default:
  194. logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
  195. }
  196. }
  197. }
  198. if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded {
  199. s := v.(*Schema)
  200. // Wait for the initialization of other goroutines to complete
  201. <-s.initialized
  202. return s, s.err
  203. }
  204. defer func() {
  205. if schema.err != nil {
  206. logger.Default.Error(context.Background(), schema.err.Error())
  207. cacheStore.Delete(modelType)
  208. }
  209. }()
  210. if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
  211. for _, field := range schema.Fields {
  212. if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
  213. if schema.parseRelation(field); schema.err != nil {
  214. return schema, schema.err
  215. } else {
  216. schema.FieldsByName[field.Name] = field
  217. }
  218. }
  219. fieldValue := reflect.New(field.IndirectFieldType)
  220. fieldInterface := fieldValue.Interface()
  221. if fc, ok := fieldInterface.(CreateClausesInterface); ok {
  222. field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
  223. }
  224. if fc, ok := fieldInterface.(QueryClausesInterface); ok {
  225. field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
  226. }
  227. if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
  228. field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
  229. }
  230. if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
  231. field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
  232. }
  233. }
  234. }
  235. return schema, schema.err
  236. }
  237. func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
  238. modelType := reflect.ValueOf(dest).Type()
  239. for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
  240. modelType = modelType.Elem()
  241. }
  242. if modelType.Kind() != reflect.Struct {
  243. if modelType.PkgPath() == "" {
  244. return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
  245. }
  246. return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
  247. }
  248. if v, ok := cacheStore.Load(modelType); ok {
  249. return v.(*Schema), nil
  250. }
  251. return Parse(dest, cacheStore, namer)
  252. }