decode.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "google.golang.org/protobuf/encoding/protowire"
  7. "google.golang.org/protobuf/internal/encoding/messageset"
  8. "google.golang.org/protobuf/internal/errors"
  9. "google.golang.org/protobuf/internal/flags"
  10. "google.golang.org/protobuf/internal/pragma"
  11. "google.golang.org/protobuf/reflect/protoreflect"
  12. "google.golang.org/protobuf/reflect/protoregistry"
  13. "google.golang.org/protobuf/runtime/protoiface"
  14. )
  15. // UnmarshalOptions configures the unmarshaler.
  16. //
  17. // Example usage:
  18. // err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
  19. type UnmarshalOptions struct {
  20. pragma.NoUnkeyedLiterals
  21. // Merge merges the input into the destination message.
  22. // The default behavior is to always reset the message before unmarshaling,
  23. // unless Merge is specified.
  24. Merge bool
  25. // AllowPartial accepts input for messages that will result in missing
  26. // required fields. If AllowPartial is false (the default), Unmarshal will
  27. // return an error if there are any missing required fields.
  28. AllowPartial bool
  29. // If DiscardUnknown is set, unknown fields are ignored.
  30. DiscardUnknown bool
  31. // Resolver is used for looking up types when unmarshaling extension fields.
  32. // If nil, this defaults to using protoregistry.GlobalTypes.
  33. Resolver interface {
  34. FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
  35. FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
  36. }
  37. }
  38. // Unmarshal parses the wire-format message in b and places the result in m.
  39. func Unmarshal(b []byte, m Message) error {
  40. _, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
  41. return err
  42. }
  43. // Unmarshal parses the wire-format message in b and places the result in m.
  44. func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
  45. _, err := o.unmarshal(b, m.ProtoReflect())
  46. return err
  47. }
  48. // UnmarshalState parses a wire-format message and places the result in m.
  49. //
  50. // This method permits fine-grained control over the unmarshaler.
  51. // Most users should use Unmarshal instead.
  52. func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
  53. return o.unmarshal(in.Buf, in.Message)
  54. }
  55. // unmarshal is a centralized function that all unmarshal operations go through.
  56. // For profiling purposes, avoid changing the name of this function or
  57. // introducing other code paths for unmarshal that do not go through this.
  58. func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
  59. if o.Resolver == nil {
  60. o.Resolver = protoregistry.GlobalTypes
  61. }
  62. if !o.Merge {
  63. Reset(m.Interface())
  64. }
  65. allowPartial := o.AllowPartial
  66. o.Merge = true
  67. o.AllowPartial = true
  68. methods := protoMethods(m)
  69. if methods != nil && methods.Unmarshal != nil &&
  70. !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
  71. in := protoiface.UnmarshalInput{
  72. Message: m,
  73. Buf: b,
  74. Resolver: o.Resolver,
  75. }
  76. if o.DiscardUnknown {
  77. in.Flags |= protoiface.UnmarshalDiscardUnknown
  78. }
  79. out, err = methods.Unmarshal(in)
  80. } else {
  81. err = o.unmarshalMessageSlow(b, m)
  82. }
  83. if err != nil {
  84. return out, err
  85. }
  86. if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
  87. return out, nil
  88. }
  89. return out, checkInitialized(m)
  90. }
  91. func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
  92. _, err := o.unmarshal(b, m)
  93. return err
  94. }
  95. func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
  96. md := m.Descriptor()
  97. if messageset.IsMessageSet(md) {
  98. return o.unmarshalMessageSet(b, m)
  99. }
  100. fields := md.Fields()
  101. for len(b) > 0 {
  102. // Parse the tag (field number and wire type).
  103. num, wtyp, tagLen := protowire.ConsumeTag(b)
  104. if tagLen < 0 {
  105. return protowire.ParseError(tagLen)
  106. }
  107. if num > protowire.MaxValidNumber {
  108. return errors.New("invalid field number")
  109. }
  110. // Find the field descriptor for this field number.
  111. fd := fields.ByNumber(num)
  112. if fd == nil && md.ExtensionRanges().Has(num) {
  113. extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
  114. if err != nil && err != protoregistry.NotFound {
  115. return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
  116. }
  117. if extType != nil {
  118. fd = extType.TypeDescriptor()
  119. }
  120. }
  121. var err error
  122. if fd == nil {
  123. err = errUnknown
  124. } else if flags.ProtoLegacy {
  125. if fd.IsWeak() && fd.Message().IsPlaceholder() {
  126. err = errUnknown // weak referent is not linked in
  127. }
  128. }
  129. // Parse the field value.
  130. var valLen int
  131. switch {
  132. case err != nil:
  133. case fd.IsList():
  134. valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
  135. case fd.IsMap():
  136. valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
  137. default:
  138. valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
  139. }
  140. if err != nil {
  141. if err != errUnknown {
  142. return err
  143. }
  144. valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
  145. if valLen < 0 {
  146. return protowire.ParseError(valLen)
  147. }
  148. if !o.DiscardUnknown {
  149. m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
  150. }
  151. }
  152. b = b[tagLen+valLen:]
  153. }
  154. return nil
  155. }
  156. func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
  157. v, n, err := o.unmarshalScalar(b, wtyp, fd)
  158. if err != nil {
  159. return 0, err
  160. }
  161. switch fd.Kind() {
  162. case protoreflect.GroupKind, protoreflect.MessageKind:
  163. m2 := m.Mutable(fd).Message()
  164. if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
  165. return n, err
  166. }
  167. default:
  168. // Non-message scalars replace the previous value.
  169. m.Set(fd, v)
  170. }
  171. return n, nil
  172. }
  173. func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
  174. if wtyp != protowire.BytesType {
  175. return 0, errUnknown
  176. }
  177. b, n = protowire.ConsumeBytes(b)
  178. if n < 0 {
  179. return 0, protowire.ParseError(n)
  180. }
  181. var (
  182. keyField = fd.MapKey()
  183. valField = fd.MapValue()
  184. key protoreflect.Value
  185. val protoreflect.Value
  186. haveKey bool
  187. haveVal bool
  188. )
  189. switch valField.Kind() {
  190. case protoreflect.GroupKind, protoreflect.MessageKind:
  191. val = mapv.NewValue()
  192. }
  193. // Map entries are represented as a two-element message with fields
  194. // containing the key and value.
  195. for len(b) > 0 {
  196. num, wtyp, n := protowire.ConsumeTag(b)
  197. if n < 0 {
  198. return 0, protowire.ParseError(n)
  199. }
  200. if num > protowire.MaxValidNumber {
  201. return 0, errors.New("invalid field number")
  202. }
  203. b = b[n:]
  204. err = errUnknown
  205. switch num {
  206. case 1:
  207. key, n, err = o.unmarshalScalar(b, wtyp, keyField)
  208. if err != nil {
  209. break
  210. }
  211. haveKey = true
  212. case 2:
  213. var v protoreflect.Value
  214. v, n, err = o.unmarshalScalar(b, wtyp, valField)
  215. if err != nil {
  216. break
  217. }
  218. switch valField.Kind() {
  219. case protoreflect.GroupKind, protoreflect.MessageKind:
  220. if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
  221. return 0, err
  222. }
  223. default:
  224. val = v
  225. }
  226. haveVal = true
  227. }
  228. if err == errUnknown {
  229. n = protowire.ConsumeFieldValue(num, wtyp, b)
  230. if n < 0 {
  231. return 0, protowire.ParseError(n)
  232. }
  233. } else if err != nil {
  234. return 0, err
  235. }
  236. b = b[n:]
  237. }
  238. // Every map entry should have entries for key and value, but this is not strictly required.
  239. if !haveKey {
  240. key = keyField.Default()
  241. }
  242. if !haveVal {
  243. switch valField.Kind() {
  244. case protoreflect.GroupKind, protoreflect.MessageKind:
  245. default:
  246. val = valField.Default()
  247. }
  248. }
  249. mapv.Set(key.MapKey(), val)
  250. return n, nil
  251. }
  252. // errUnknown is used internally to indicate fields which should be added
  253. // to the unknown field set of a message. It is never returned from an exported
  254. // function.
  255. var errUnknown = errors.New("BUG: internal error (unknown)")