messageset.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package messageset encodes and decodes the obsolete MessageSet wire format.
  5. package messageset
  6. import (
  7. "math"
  8. "google.golang.org/protobuf/encoding/protowire"
  9. "google.golang.org/protobuf/internal/errors"
  10. pref "google.golang.org/protobuf/reflect/protoreflect"
  11. preg "google.golang.org/protobuf/reflect/protoregistry"
  12. )
  13. // The MessageSet wire format is equivalent to a message defined as follows,
  14. // where each Item defines an extension field with a field number of 'type_id'
  15. // and content of 'message'. MessageSet extensions must be non-repeated message
  16. // fields.
  17. //
  18. // message MessageSet {
  19. // repeated group Item = 1 {
  20. // required int32 type_id = 2;
  21. // required string message = 3;
  22. // }
  23. // }
  24. const (
  25. FieldItem = protowire.Number(1)
  26. FieldTypeID = protowire.Number(2)
  27. FieldMessage = protowire.Number(3)
  28. )
  29. // ExtensionName is the field name for extensions of MessageSet.
      // IsMessageSetExtension rejects any field whose name differs from it, and
      // FindMessageSetExtension appends it to a message name to build the full
      // name of the extension it looks up.
  30. //
  31. // A valid MessageSet extension must be of the form:
  32. //	message MyMessage {
  33. //		extend proto2.bridge.MessageSet {
  34. //			optional MyMessage message_set_extension = 1234;
  35. //		}
  36. //		...
  37. //	}
  38. const ExtensionName = "message_set_extension"
  39. // IsMessageSet returns whether the message uses the MessageSet wire format.
  40. func IsMessageSet(md pref.MessageDescriptor) bool {
  41. xmd, ok := md.(interface{ IsMessageSet() bool })
  42. return ok && xmd.IsMessageSet()
  43. }
  44. // IsMessageSetExtension reports this field extends a MessageSet.
  45. func IsMessageSetExtension(fd pref.FieldDescriptor) bool {
  46. if fd.Name() != ExtensionName {
  47. return false
  48. }
  49. if fd.FullName().Parent() != fd.Message().FullName() {
  50. return false
  51. }
  52. return IsMessageSet(fd.ContainingMessage())
  53. }
  54. // FindMessageSetExtension locates a MessageSet extension field by name.
  55. // In text and JSON formats, the extension name used is the message itself.
  56. // The extension field name is derived by appending ExtensionName.
  57. func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) {
  58. name := s.Append(ExtensionName)
  59. xt, err := r.FindExtensionByName(name)
  60. if err != nil {
  61. if err == preg.NotFound {
  62. return nil, err
  63. }
  64. return nil, errors.Wrap(err, "%q", name)
  65. }
  66. if !IsMessageSetExtension(xt.TypeDescriptor()) {
  67. return nil, preg.NotFound
  68. }
  69. return xt, nil
  70. }
  71. // SizeField returns the size of a MessageSet item field containing an extension
  72. // with the given field number, not counting the contents of the message subfield.
  73. func SizeField(num protowire.Number) int {
  74. return 2*protowire.SizeTag(FieldItem) + protowire.SizeTag(FieldTypeID) + protowire.SizeVarint(uint64(num))
  75. }
  76. // Unmarshal parses a MessageSet.
  77. //
  78. // It calls fn with the type ID and value of each item in the MessageSet.
  79. // Unknown fields are discarded.
  80. //
  81. // If wantLen is true, the item values include the varint length prefix.
  82. // This is ugly, but simplifies the fast-path decoder in internal/impl.
  83. func Unmarshal(b []byte, wantLen bool, fn func(typeID protowire.Number, value []byte) error) error {
  84. for len(b) > 0 {
  85. num, wtyp, n := protowire.ConsumeTag(b)
  86. if n < 0 {
  87. return protowire.ParseError(n)
  88. }
  89. b = b[n:]
  90. if num != FieldItem || wtyp != protowire.StartGroupType {
  91. n := protowire.ConsumeFieldValue(num, wtyp, b)
  92. if n < 0 {
  93. return protowire.ParseError(n)
  94. }
  95. b = b[n:]
  96. continue
  97. }
  98. typeID, value, n, err := ConsumeFieldValue(b, wantLen)
  99. if err != nil {
  100. return err
  101. }
  102. b = b[n:]
  103. if typeID == 0 {
  104. continue
  105. }
  106. if err := fn(typeID, value); err != nil {
  107. return err
  108. }
  109. }
  110. return nil
  111. }
  112. // ConsumeFieldValue parses b as a MessageSet item field value until and including
  113. // the trailing end group marker. It assumes the start group tag has already been parsed.
  114. // It returns the contents of the type_id and message subfields and the total
  115. // item length.
  116. //
  117. // If wantLen is true, the returned message value includes the length prefix.
  118. func ConsumeFieldValue(b []byte, wantLen bool) (typeid protowire.Number, message []byte, n int, err error) {
  119. ilen := len(b)
  120. for {
  121. num, wtyp, n := protowire.ConsumeTag(b)
  122. if n < 0 {
  123. return 0, nil, 0, protowire.ParseError(n)
  124. }
  125. b = b[n:]
  126. switch {
  127. case num == FieldItem && wtyp == protowire.EndGroupType:
  128. if wantLen && len(message) == 0 {
  129. // The message field was missing, which should never happen.
  130. // Be prepared for this case anyway.
  131. message = protowire.AppendVarint(message, 0)
  132. }
  133. return typeid, message, ilen - len(b), nil
  134. case num == FieldTypeID && wtyp == protowire.VarintType:
  135. v, n := protowire.ConsumeVarint(b)
  136. if n < 0 {
  137. return 0, nil, 0, protowire.ParseError(n)
  138. }
  139. b = b[n:]
  140. if v < 1 || v > math.MaxInt32 {
  141. return 0, nil, 0, errors.New("invalid type_id in message set")
  142. }
  143. typeid = protowire.Number(v)
  144. case num == FieldMessage && wtyp == protowire.BytesType:
  145. m, n := protowire.ConsumeBytes(b)
  146. if n < 0 {
  147. return 0, nil, 0, protowire.ParseError(n)
  148. }
  149. if message == nil {
  150. if wantLen {
  151. message = b[:n:n]
  152. } else {
  153. message = m[:len(m):len(m)]
  154. }
  155. } else {
  156. // This case should never happen in practice, but handle it for
  157. // correctness: The MessageSet item contains multiple message
  158. // fields, which need to be merged.
  159. //
  160. // In the case where we're returning the length, this becomes
  161. // quite inefficient since we need to strip the length off
  162. // the existing data and reconstruct it with the combined length.
  163. if wantLen {
  164. _, nn := protowire.ConsumeVarint(message)
  165. m0 := message[nn:]
  166. message = nil
  167. message = protowire.AppendVarint(message, uint64(len(m0)+len(m)))
  168. message = append(message, m0...)
  169. message = append(message, m...)
  170. } else {
  171. message = append(message, m...)
  172. }
  173. }
  174. b = b[n:]
  175. default:
  176. // We have no place to put it, so we just ignore unknown fields.
  177. n := protowire.ConsumeFieldValue(num, wtyp, b)
  178. if n < 0 {
  179. return 0, nil, 0, protowire.ParseError(n)
  180. }
  181. b = b[n:]
  182. }
  183. }
  184. }
  185. // AppendFieldStart appends the start of a MessageSet item field containing
  186. // an extension with the given number. The caller must add the message
  187. // subfield (including the tag).
  188. func AppendFieldStart(b []byte, num protowire.Number) []byte {
  189. b = protowire.AppendTag(b, FieldItem, protowire.StartGroupType)
  190. b = protowire.AppendTag(b, FieldTypeID, protowire.VarintType)
  191. b = protowire.AppendVarint(b, uint64(num))
  192. return b
  193. }
  194. // AppendFieldEnd appends the trailing end group marker for a MessageSet item field.
  195. func AppendFieldEnd(b []byte) []byte {
  196. return protowire.AppendTag(b, FieldItem, protowire.EndGroupType)
  197. }
  198. // SizeUnknown returns the size of an unknown fields section in MessageSet format.
  199. //
  200. // See AppendUnknown.
  201. func SizeUnknown(unknown []byte) (size int) {
  202. for len(unknown) > 0 {
  203. num, typ, n := protowire.ConsumeTag(unknown)
  204. if n < 0 || typ != protowire.BytesType {
  205. return 0
  206. }
  207. unknown = unknown[n:]
  208. _, n = protowire.ConsumeBytes(unknown)
  209. if n < 0 {
  210. return 0
  211. }
  212. unknown = unknown[n:]
  213. size += SizeField(num) + protowire.SizeTag(FieldMessage) + n
  214. }
  215. return size
  216. }
  217. // AppendUnknown appends unknown fields to b in MessageSet format.
  218. //
  219. // For historic reasons, unresolved items in a MessageSet are stored in a
  220. // message's unknown fields section in non-MessageSet format. That is, an
  221. // unknown item with typeID T and value V appears in the unknown fields as
  222. // a field with number T and value V.
  223. //
  224. // This function converts the unknown fields back into MessageSet form.
  225. func AppendUnknown(b, unknown []byte) ([]byte, error) {
  226. for len(unknown) > 0 {
  227. num, typ, n := protowire.ConsumeTag(unknown)
  228. if n < 0 || typ != protowire.BytesType {
  229. return nil, errors.New("invalid data in message set unknown fields")
  230. }
  231. unknown = unknown[n:]
  232. _, n = protowire.ConsumeBytes(unknown)
  233. if n < 0 {
  234. return nil, errors.New("invalid data in message set unknown fields")
  235. }
  236. b = AppendFieldStart(b, num)
  237. b = protowire.AppendTag(b, FieldMessage, protowire.BytesType)
  238. b = append(b, unknown[:n]...)
  239. b = AppendFieldEnd(b)
  240. unknown = unknown[n:]
  241. }
  242. return b, nil
  243. }