decode.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "math/bits"
  7. "google.golang.org/protobuf/encoding/protowire"
  8. "google.golang.org/protobuf/internal/errors"
  9. "google.golang.org/protobuf/internal/flags"
  10. "google.golang.org/protobuf/proto"
  11. "google.golang.org/protobuf/reflect/protoreflect"
  12. preg "google.golang.org/protobuf/reflect/protoregistry"
  13. "google.golang.org/protobuf/runtime/protoiface"
  14. piface "google.golang.org/protobuf/runtime/protoiface"
  15. )
// errDecode is the error returned for malformed wire-format input, e.g. an
// unparsable tag varint, an out-of-range field number, a mismatched or missing
// end-group tag, or field data that cannot be consumed.
var errDecode = errors.New("cannot parse invalid wire-format data")
// unmarshalOptions is the internal form of the options carried by a
// piface.UnmarshalInput: the caller's input flags and the resolver used to
// look up extension types that are not already known to the message.
type unmarshalOptions struct {
	// flags holds the piface.UnmarshalInputFlags supplied by the caller
	// (e.g. UnmarshalDiscardUnknown).
	flags protoiface.UnmarshalInputFlags
	// resolver finds extension types by full name or by (message, field number).
	resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
}
  24. func (o unmarshalOptions) Options() proto.UnmarshalOptions {
  25. return proto.UnmarshalOptions{
  26. Merge: true,
  27. AllowPartial: true,
  28. DiscardUnknown: o.DiscardUnknown(),
  29. Resolver: o.resolver,
  30. }
  31. }
  32. func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 }
  33. func (o unmarshalOptions) IsDefault() bool {
  34. return o.flags == 0 && o.resolver == preg.GlobalTypes
  35. }
// lazyUnmarshalOptions has zero flags and the global type registry as its
// resolver, so IsDefault reports true for it.
// NOTE(review): not referenced in this chunk; presumably used when lazily
// decoded extensions are later materialized — confirm against callers.
var lazyUnmarshalOptions = unmarshalOptions{
	resolver: preg.GlobalTypes,
}
// unmarshalOutput is the internal result of decoding a message, field, or
// extension.
type unmarshalOutput struct {
	n           int  // number of bytes consumed
	initialized bool // whether the decoded data was reported as fully initialized
}
  43. // unmarshal is protoreflect.Methods.Unmarshal.
  44. func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
  45. var p pointer
  46. if ms, ok := in.Message.(*messageState); ok {
  47. p = ms.pointer()
  48. } else {
  49. p = in.Message.(*messageReflectWrapper).pointer()
  50. }
  51. out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
  52. flags: in.Flags,
  53. resolver: in.Resolver,
  54. })
  55. var flags piface.UnmarshalOutputFlags
  56. if out.initialized {
  57. flags |= piface.UnmarshalInitialized
  58. }
  59. return piface.UnmarshalOutput{
  60. Flags: flags,
  61. }, err
  62. }
// errUnknown is returned during unmarshaling to indicate a parse error that
// should result in a field being placed in the unknown fields section (for example,
// when the wire type doesn't match) as opposed to the entire unmarshal operation
// failing (for example, when a field extends past the available input).
//
// unmarshalPointer reacts to it by skipping the field's bytes and, unless
// unknown fields are being discarded, appending them to the unknown fields.
//
// This is a sentinel error which should never be visible to the user.
var errUnknown = errors.New("unknown")
// unmarshalPointer decodes the wire-format bytes b into the message at p,
// whose layout is described by mi.
//
// When groupTag is non-zero, b is expected to hold the body of a group field
// with that number: decoding stops right after the matching end-group tag,
// while an end-group tag for a different number, or running out of input
// before the end-group tag, yields errDecode. The top-level call passes
// groupTag == 0, so any end-group tag there is an error.
//
// Fields whose data no decoder accepts (errUnknown) are skipped and, unless
// opts.DiscardUnknown() is set or the message has no unknown-field storage,
// appended verbatim to the message's unknown fields. Any other error aborts
// decoding and is returned with a zero out.
//
// On success out.n is the number of bytes consumed (including the end-group
// tag, if any) and out.initialized is true only if no decoded field or
// extension reported itself uninitialized and the count of required-field
// bits seen equals mi.numRequiredFields.
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	mi.init()
	// Legacy MessageSet messages are decoded by a dedicated routine.
	if flags.ProtoLegacy && mi.isMessageSet {
		return unmarshalMessageSet(mi, b, p, opts)
	}
	initialized := true
	// Accumulates f.validation.requiredBit of every successfully decoded field.
	var requiredMask uint64
	// The message's extension map; fetched (and allocated if nil) only when the
	// first field that might be an extension is encountered.
	var exts *map[int32]ExtensionField
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// One- and two-byte varint tags are decoded inline; longer or
		// truncated tags fall back to protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errDecode
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)
		if wtyp == protowire.EndGroupType {
			if num != groupTag {
				return out, errDecode
			}
			// Mark the group as closed so the post-loop check passes.
			groupTag = 0
			// This break is not inside a switch: it leaves the for loop.
			break
		}
		// Small field numbers index the dense slice; others go through
		// mi.coderFields (presumably a map keyed by field number).
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// err stays errUnknown unless a decoder below handles the field.
		err := errUnknown
		switch {
		case f != nil:
			// A known field without an unmarshal function is kept as unknown.
			if f.funcs.unmarshal == nil {
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			// The message has no extension storage: treat as unknown.
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			// Skip the field's value and, if retained, record the tag and
			// raw value bytes in the unknown fields.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	// Input ended inside a group that was never closed.
	if groupTag != 0 {
		return out, errDecode
	}
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}
  182. func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
  183. x := exts[int32(num)]
  184. xt := x.Type()
  185. if xt == nil {
  186. var err error
  187. xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
  188. if err != nil {
  189. if err == preg.NotFound {
  190. return out, errUnknown
  191. }
  192. return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
  193. }
  194. }
  195. xi := getExtensionFieldInfo(xt)
  196. if xi.funcs.unmarshal == nil {
  197. return out, errUnknown
  198. }
  199. if flags.LazyUnmarshalExtensions {
  200. if opts.IsDefault() && x.canLazy(xt) {
  201. out, valid := skipExtension(b, xi, num, wtyp, opts)
  202. switch valid {
  203. case ValidationValid:
  204. if out.initialized {
  205. x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
  206. exts[int32(num)] = x
  207. return out, nil
  208. }
  209. case ValidationInvalid:
  210. return out, errDecode
  211. case ValidationUnknown:
  212. }
  213. }
  214. }
  215. ival := x.Value()
  216. if !ival.IsValid() && xi.unmarshalNeedsValue {
  217. // Create a new message, list, or map value to fill in.
  218. // For enums, create a prototype value to let the unmarshal func know the
  219. // concrete type.
  220. ival = xt.New()
  221. }
  222. v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
  223. if err != nil {
  224. return out, err
  225. }
  226. if xi.funcs.isInit == nil {
  227. out.initialized = true
  228. }
  229. x.Set(xt, v)
  230. exts[int32(num)] = x
  231. return out, nil
  232. }
  233. func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
  234. if xi.validation.mi == nil {
  235. return out, ValidationUnknown
  236. }
  237. xi.validation.mi.init()
  238. switch xi.validation.typ {
  239. case validationTypeMessage:
  240. if wtyp != protowire.BytesType {
  241. return out, ValidationUnknown
  242. }
  243. v, n := protowire.ConsumeBytes(b)
  244. if n < 0 {
  245. return out, ValidationUnknown
  246. }
  247. out, st := xi.validation.mi.validate(v, 0, opts)
  248. out.n = n
  249. return out, st
  250. case validationTypeGroup:
  251. if wtyp != protowire.StartGroupType {
  252. return out, ValidationUnknown
  253. }
  254. out, st := xi.validation.mi.validate(b, num, opts)
  255. return out, st
  256. default:
  257. return out, ValidationUnknown
  258. }
  259. }