update.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. package callbacks
  2. import (
  3. "reflect"
  4. "sort"
  5. "gorm.io/gorm"
  6. "gorm.io/gorm/clause"
  7. "gorm.io/gorm/schema"
  8. )
  9. func SetupUpdateReflectValue(db *gorm.DB) {
  10. if db.Error == nil && db.Statement.Schema != nil {
  11. if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
  12. db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
  13. for db.Statement.ReflectValue.Kind() == reflect.Ptr {
  14. db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
  15. }
  16. if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
  17. for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
  18. if _, ok := dest[rel.Name]; ok {
  19. rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
  20. }
  21. }
  22. }
  23. }
  24. }
  25. }
  26. func BeforeUpdate(db *gorm.DB) {
  27. if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
  28. callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
  29. if db.Statement.Schema.BeforeSave {
  30. if i, ok := value.(BeforeSaveInterface); ok {
  31. called = true
  32. db.AddError(i.BeforeSave(tx))
  33. }
  34. }
  35. if db.Statement.Schema.BeforeUpdate {
  36. if i, ok := value.(BeforeUpdateInterface); ok {
  37. called = true
  38. db.AddError(i.BeforeUpdate(tx))
  39. }
  40. }
  41. return called
  42. })
  43. }
  44. }
  45. func Update(db *gorm.DB) {
  46. if db.Error != nil {
  47. return
  48. }
  49. if db.Statement.Schema != nil && !db.Statement.Unscoped {
  50. for _, c := range db.Statement.Schema.UpdateClauses {
  51. db.Statement.AddClause(c)
  52. }
  53. }
  54. if db.Statement.SQL.String() == "" {
  55. db.Statement.SQL.Grow(180)
  56. db.Statement.AddClauseIfNotExists(clause.Update{})
  57. if set := ConvertToAssignments(db.Statement); len(set) != 0 {
  58. db.Statement.AddClause(set)
  59. } else {
  60. return
  61. }
  62. db.Statement.Build(db.Statement.BuildClauses...)
  63. }
  64. if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
  65. db.AddError(gorm.ErrMissingWhereClause)
  66. return
  67. }
  68. if !db.DryRun && db.Error == nil {
  69. result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
  70. if err == nil {
  71. db.RowsAffected, _ = result.RowsAffected()
  72. } else {
  73. db.AddError(err)
  74. }
  75. }
  76. }
  77. func AfterUpdate(db *gorm.DB) {
  78. if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
  79. callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
  80. if db.Statement.Schema.AfterSave {
  81. if i, ok := value.(AfterSaveInterface); ok {
  82. called = true
  83. db.AddError(i.AfterSave(tx))
  84. }
  85. }
  86. if db.Statement.Schema.AfterUpdate {
  87. if i, ok := value.(AfterUpdateInterface); ok {
  88. called = true
  89. db.AddError(i.AfterUpdate(tx))
  90. }
  91. }
  92. return called
  93. })
  94. }
  95. }
  96. // ConvertToAssignments convert to update assignments
  97. func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
  98. var (
  99. selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
  100. assignValue func(field *schema.Field, value interface{})
  101. )
  102. switch stmt.ReflectValue.Kind() {
  103. case reflect.Slice, reflect.Array:
  104. assignValue = func(field *schema.Field, value interface{}) {
  105. for i := 0; i < stmt.ReflectValue.Len(); i++ {
  106. field.Set(stmt.ReflectValue.Index(i), value)
  107. }
  108. }
  109. case reflect.Struct:
  110. assignValue = func(field *schema.Field, value interface{}) {
  111. if stmt.ReflectValue.CanAddr() {
  112. field.Set(stmt.ReflectValue, value)
  113. }
  114. }
  115. default:
  116. assignValue = func(field *schema.Field, value interface{}) {
  117. }
  118. }
  119. updatingValue := reflect.ValueOf(stmt.Dest)
  120. for updatingValue.Kind() == reflect.Ptr {
  121. updatingValue = updatingValue.Elem()
  122. }
  123. if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
  124. switch stmt.ReflectValue.Kind() {
  125. case reflect.Slice, reflect.Array:
  126. var primaryKeyExprs []clause.Expression
  127. for i := 0; i < stmt.ReflectValue.Len(); i++ {
  128. var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
  129. var notZero bool
  130. for idx, field := range stmt.Schema.PrimaryFields {
  131. value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
  132. exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
  133. notZero = notZero || !isZero
  134. }
  135. if notZero {
  136. primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
  137. }
  138. }
  139. stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
  140. case reflect.Struct:
  141. for _, field := range stmt.Schema.PrimaryFields {
  142. if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
  143. stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
  144. }
  145. }
  146. }
  147. }
  148. switch value := updatingValue.Interface().(type) {
  149. case map[string]interface{}:
  150. set = make([]clause.Assignment, 0, len(value))
  151. keys := make([]string, 0, len(value))
  152. for k := range value {
  153. keys = append(keys, k)
  154. }
  155. sort.Strings(keys)
  156. for _, k := range keys {
  157. kv := value[k]
  158. if _, ok := kv.(*gorm.DB); ok {
  159. kv = []interface{}{kv}
  160. }
  161. if stmt.Schema != nil {
  162. if field := stmt.Schema.LookUpField(k); field != nil {
  163. if field.DBName != "" {
  164. if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
  165. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
  166. assignValue(field, value[k])
  167. }
  168. } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
  169. assignValue(field, value[k])
  170. }
  171. continue
  172. }
  173. }
  174. if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
  175. set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
  176. }
  177. }
  178. if !stmt.SkipHooks && stmt.Schema != nil {
  179. for _, dbName := range stmt.Schema.DBNames {
  180. field := stmt.Schema.LookUpField(dbName)
  181. if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
  182. if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
  183. now := stmt.DB.NowFunc()
  184. assignValue(field, now)
  185. if field.AutoUpdateTime == schema.UnixNanosecond {
  186. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
  187. } else if field.AutoUpdateTime == schema.UnixMillisecond {
  188. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
  189. } else if field.GORMDataType == schema.Time {
  190. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
  191. } else {
  192. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
  193. }
  194. }
  195. }
  196. }
  197. }
  198. default:
  199. var updatingSchema = stmt.Schema
  200. if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
  201. // different schema
  202. updatingStmt := &gorm.Statement{DB: stmt.DB}
  203. if err := updatingStmt.Parse(stmt.Dest); err == nil {
  204. updatingSchema = updatingStmt.Schema
  205. }
  206. }
  207. switch updatingValue.Kind() {
  208. case reflect.Struct:
  209. set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
  210. for _, dbName := range stmt.Schema.DBNames {
  211. if field := updatingSchema.LookUpField(dbName); field != nil && field.Updatable {
  212. if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
  213. if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
  214. value, isZero := field.ValueOf(updatingValue)
  215. if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
  216. if field.AutoUpdateTime == schema.UnixNanosecond {
  217. value = stmt.DB.NowFunc().UnixNano()
  218. } else if field.AutoUpdateTime == schema.UnixMillisecond {
  219. value = stmt.DB.NowFunc().UnixNano() / 1e6
  220. } else if field.GORMDataType == schema.Time {
  221. value = stmt.DB.NowFunc()
  222. } else {
  223. value = stmt.DB.NowFunc().Unix()
  224. }
  225. isZero = false
  226. }
  227. if ok || !isZero {
  228. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
  229. assignValue(field, value)
  230. }
  231. }
  232. } else {
  233. if value, isZero := field.ValueOf(updatingValue); !isZero {
  234. stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
  235. }
  236. }
  237. }
  238. }
  239. default:
  240. stmt.AddError(gorm.ErrInvalidData)
  241. }
  242. }
  243. return
  244. }