equal.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "bytes"
  7. "math"
  8. "reflect"
  9. "google.golang.org/protobuf/encoding/protowire"
  10. pref "google.golang.org/protobuf/reflect/protoreflect"
  11. )
  12. // Equal reports whether two messages are equal.
  13. // If two messages marshal to the same bytes under deterministic serialization,
  14. // then Equal is guaranteed to report true.
  15. //
  16. // Two messages are equal if they belong to the same message descriptor,
  17. // have the same set of populated known and extension field values,
  18. // and the same set of unknown fields values. If either of the top-level
  19. // messages are invalid, then Equal reports true only if both are invalid.
  20. //
  21. // Scalar values are compared with the equivalent of the == operator in Go,
  22. // except bytes values which are compared using bytes.Equal and
  23. // floating point values which specially treat NaNs as equal.
  24. // Message values are compared by recursively calling Equal.
  25. // Lists are equal if each element value is also equal.
  26. // Maps are equal if they have the same set of keys, where the pair of values
  27. // for each key is also equal.
  28. func Equal(x, y Message) bool {
  29. if x == nil || y == nil {
  30. return x == nil && y == nil
  31. }
  32. mx := x.ProtoReflect()
  33. my := y.ProtoReflect()
  34. if mx.IsValid() != my.IsValid() {
  35. return false
  36. }
  37. return equalMessage(mx, my)
  38. }
  39. // equalMessage compares two messages.
  40. func equalMessage(mx, my pref.Message) bool {
  41. if mx.Descriptor() != my.Descriptor() {
  42. return false
  43. }
  44. nx := 0
  45. equal := true
  46. mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
  47. nx++
  48. vy := my.Get(fd)
  49. equal = my.Has(fd) && equalField(fd, vx, vy)
  50. return equal
  51. })
  52. if !equal {
  53. return false
  54. }
  55. ny := 0
  56. my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
  57. ny++
  58. return true
  59. })
  60. if nx != ny {
  61. return false
  62. }
  63. return equalUnknown(mx.GetUnknown(), my.GetUnknown())
  64. }
  65. // equalField compares two fields.
  66. func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
  67. switch {
  68. case fd.IsList():
  69. return equalList(fd, x.List(), y.List())
  70. case fd.IsMap():
  71. return equalMap(fd, x.Map(), y.Map())
  72. default:
  73. return equalValue(fd, x, y)
  74. }
  75. }
  76. // equalMap compares two maps.
  77. func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
  78. if x.Len() != y.Len() {
  79. return false
  80. }
  81. equal := true
  82. x.Range(func(k pref.MapKey, vx pref.Value) bool {
  83. vy := y.Get(k)
  84. equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
  85. return equal
  86. })
  87. return equal
  88. }
  89. // equalList compares two lists.
  90. func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
  91. if x.Len() != y.Len() {
  92. return false
  93. }
  94. for i := x.Len() - 1; i >= 0; i-- {
  95. if !equalValue(fd, x.Get(i), y.Get(i)) {
  96. return false
  97. }
  98. }
  99. return true
  100. }
  101. // equalValue compares two singular values.
  102. func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
  103. switch fd.Kind() {
  104. case pref.BoolKind:
  105. return x.Bool() == y.Bool()
  106. case pref.EnumKind:
  107. return x.Enum() == y.Enum()
  108. case pref.Int32Kind, pref.Sint32Kind,
  109. pref.Int64Kind, pref.Sint64Kind,
  110. pref.Sfixed32Kind, pref.Sfixed64Kind:
  111. return x.Int() == y.Int()
  112. case pref.Uint32Kind, pref.Uint64Kind,
  113. pref.Fixed32Kind, pref.Fixed64Kind:
  114. return x.Uint() == y.Uint()
  115. case pref.FloatKind, pref.DoubleKind:
  116. fx := x.Float()
  117. fy := y.Float()
  118. if math.IsNaN(fx) || math.IsNaN(fy) {
  119. return math.IsNaN(fx) && math.IsNaN(fy)
  120. }
  121. return fx == fy
  122. case pref.StringKind:
  123. return x.String() == y.String()
  124. case pref.BytesKind:
  125. return bytes.Equal(x.Bytes(), y.Bytes())
  126. case pref.MessageKind, pref.GroupKind:
  127. return equalMessage(x.Message(), y.Message())
  128. default:
  129. return x.Interface() == y.Interface()
  130. }
  131. }
  132. // equalUnknown compares unknown fields by direct comparison on the raw bytes
  133. // of each individual field number.
  134. func equalUnknown(x, y pref.RawFields) bool {
  135. if len(x) != len(y) {
  136. return false
  137. }
  138. if bytes.Equal([]byte(x), []byte(y)) {
  139. return true
  140. }
  141. mx := make(map[pref.FieldNumber]pref.RawFields)
  142. my := make(map[pref.FieldNumber]pref.RawFields)
  143. for len(x) > 0 {
  144. fnum, _, n := protowire.ConsumeField(x)
  145. mx[fnum] = append(mx[fnum], x[:n]...)
  146. x = x[n:]
  147. }
  148. for len(y) > 0 {
  149. fnum, _, n := protowire.ConsumeField(y)
  150. my[fnum] = append(my[fnum], y[:n]...)
  151. y = y[n:]
  152. }
  153. return reflect.DeepEqual(mx, my)
  154. }