soft_delete.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package gorm
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "encoding/json"
  6. "reflect"
  7. "gorm.io/gorm/clause"
  8. "gorm.io/gorm/schema"
  9. )
  // DeletedAt is a nullable timestamp used as the soft-delete marker. It has
  // the same layout as sql.NullTime: Valid reports whether Time holds a
  // deletion timestamp.
  type DeletedAt sql.NullTime

  // Scan implements the sql.Scanner interface by delegating to sql.NullTime.
  func (n *DeletedAt) Scan(value interface{}) error {
  	return (*sql.NullTime)(n).Scan(value)
  }

  // Value implements the driver.Valuer interface. An invalid DeletedAt is
  // written as SQL NULL; a valid one is written as its time.Time.
  func (n DeletedAt) Value() (driver.Value, error) {
  	if !n.Valid {
  		return nil, nil
  	}
  	return n.Time, nil
  }

  // MarshalJSON implements json.Marshaler. An invalid DeletedAt encodes as
  // JSON null; a valid one encodes exactly like its time.Time.
  func (n DeletedAt) MarshalJSON() ([]byte, error) {
  	if n.Valid {
  		return json.Marshal(n.Time)
  	}
  	return json.Marshal(nil)
  }

  // UnmarshalJSON implements json.Unmarshaler.
  //
  // JSON null resets n to its zero (invalid) value, mirroring
  // sql.NullTime.Scan(nil). Any other input is decoded as a time.Time. The
  // decode goes into a temporary so that, on failure, n is left completely
  // unchanged. Decoding straight into n.Time could overwrite Time while Valid
  // stayed true.
  func (n *DeletedAt) UnmarshalJSON(b []byte) error {
  	if string(b) == "null" {
  		*n = DeletedAt{}
  		return nil
  	}
  	var decoded DeletedAt
  	if err := json.Unmarshal(b, &decoded.Time); err != nil {
  		return err
  	}
  	decoded.Valid = true
  	*n = decoded
  	return nil
  }
  39. func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
  40. return []clause.Interface{SoftDeleteQueryClause{Field: f}}
  41. }
  42. type SoftDeleteQueryClause struct {
  43. Field *schema.Field
  44. }
  45. func (sd SoftDeleteQueryClause) Name() string {
  46. return ""
  47. }
  48. func (sd SoftDeleteQueryClause) Build(clause.Builder) {
  49. }
  50. func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
  51. }
  52. func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
  53. if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
  54. if c, ok := stmt.Clauses["WHERE"]; ok {
  55. if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 {
  56. for _, expr := range where.Exprs {
  57. if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
  58. where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
  59. c.Expression = where
  60. stmt.Clauses["WHERE"] = c
  61. break
  62. }
  63. }
  64. }
  65. }
  66. stmt.AddClause(clause.Where{Exprs: []clause.Expression{
  67. clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
  68. }})
  69. stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
  70. }
  71. }
  72. func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
  73. return []clause.Interface{SoftDeleteUpdateClause{Field: f}}
  74. }
  75. type SoftDeleteUpdateClause struct {
  76. Field *schema.Field
  77. }
  78. func (sd SoftDeleteUpdateClause) Name() string {
  79. return ""
  80. }
  81. func (sd SoftDeleteUpdateClause) Build(clause.Builder) {
  82. }
  83. func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
  84. }
  85. func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
  86. if stmt.SQL.String() == "" {
  87. if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok {
  88. SoftDeleteQueryClause(sd).ModifyStatement(stmt)
  89. }
  90. }
  91. }
  92. func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
  93. return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
  94. }
  95. type SoftDeleteDeleteClause struct {
  96. Field *schema.Field
  97. }
  98. func (sd SoftDeleteDeleteClause) Name() string {
  99. return ""
  100. }
  101. func (sd SoftDeleteDeleteClause) Build(clause.Builder) {
  102. }
  103. func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
  104. }
  105. func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
  106. if stmt.SQL.String() == "" {
  107. curTime := stmt.DB.NowFunc()
  108. stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
  109. stmt.SetColumn(sd.Field.DBName, curTime, true)
  110. if stmt.Schema != nil {
  111. _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
  112. column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
  113. if len(values) > 0 {
  114. stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
  115. }
  116. if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
  117. _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
  118. column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
  119. if len(values) > 0 {
  120. stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
  121. }
  122. }
  123. }
  124. if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
  125. stmt.DB.AddError(ErrMissingWhereClause)
  126. } else {
  127. SoftDeleteQueryClause(sd).ModifyStatement(stmt)
  128. }
  129. stmt.AddClauseIfNotExists(clause.Update{})
  130. stmt.Build("UPDATE", "SET", "WHERE")
  131. }
  132. }