message_reflect.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "reflect"
  8. "google.golang.org/protobuf/internal/detrand"
  9. "google.golang.org/protobuf/internal/pragma"
  10. pref "google.golang.org/protobuf/reflect/protoreflect"
  11. )
// reflectMessageInfo holds the tables and accessor functions that back
// protobuf reflection for a single message type. Its fields are assigned by
// the make*Func / makeFieldTypes methods on MessageInfo (accessed as mi.fields,
// mi.oneofs, ...; presumably this struct is embedded in MessageInfo — confirm).
type reflectMessageInfo struct {
	// fields maps every field number declared in the message descriptor,
	// including members of oneofs, to its accessor info.
	fields map[pref.FieldNumber]*fieldInfo
	// oneofs maps every oneof name in the message descriptor (synthetic
	// oneofs included) to its accessor info.
	oneofs map[pref.Name]*oneofInfo

	// fieldTypes contains the zero value of an enum or message field.
	// For lists, it contains the element type.
	// For maps, it contains the entry value type.
	// It is only allocated once an entry is added, so it may be nil.
	fieldTypes map[pref.FieldNumber]interface{}

	// denseFields is a subset of fields where:
	//	0 < fieldDesc.Number() < len(denseFields)
	// It provides faster access to the fieldInfo, but may be incomplete.
	denseFields []*fieldInfo

	// rangeInfos is a list of all fields (not belonging to a non-synthetic
	// oneof) and oneofs. Its order is deliberately, but deterministically,
	// perturbed when it is built.
	rangeInfos []interface{} // either *fieldInfo or *oneofInfo

	// getUnknown and setUnknown read and write the raw unknown-field bytes of
	// a message. getUnknown returns nil for a nil message; setUnknown panics.
	getUnknown func(pointer) pref.RawFields
	setUnknown func(pointer, pref.RawFields)
	// extensionMap returns the message's extension storage, or a nil
	// *extensionMap when the message is nil or has no extension field.
	extensionMap func(pointer) *extensionMap

	// nilMessage provides the reflective view that MessageOf returns for a
	// nil message pointer.
	nilMessage atomicNilMessage
}
// makeReflectFuncs generates the set of functions to support reflection.
//
// It fills in, in order: the known-field and oneof tables, the unknown-field
// accessors, the extension-map accessor, and the per-field zero-value types,
// all derived from si. NOTE(review): confirm whether later steps depend on
// makeKnownFieldsFunc having run first before reordering these calls.
func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
	mi.makeKnownFieldsFunc(si)
	mi.makeUnknownFieldsFunc(t, si)
	mi.makeExtensionFieldsFunc(t, si)
	mi.makeFieldTypes(si)
}
  37. // makeKnownFieldsFunc generates functions for operations that can be performed
  38. // on each protobuf message field. It takes in a reflect.Type representing the
  39. // Go struct and matches message fields with struct fields.
  40. //
  41. // This code assumes that the struct is well-formed and panics if there are
  42. // any discrepancies.
  43. func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
  44. mi.fields = map[pref.FieldNumber]*fieldInfo{}
  45. md := mi.Desc
  46. fds := md.Fields()
  47. for i := 0; i < fds.Len(); i++ {
  48. fd := fds.Get(i)
  49. fs := si.fieldsByNumber[fd.Number()]
  50. isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
  51. if isOneof {
  52. fs = si.oneofsByName[fd.ContainingOneof().Name()]
  53. }
  54. var fi fieldInfo
  55. switch {
  56. case fs.Type == nil:
  57. fi = fieldInfoForMissing(fd) // never occurs for officially generated message types
  58. case isOneof:
  59. fi = fieldInfoForOneof(fd, fs, mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
  60. case fd.IsMap():
  61. fi = fieldInfoForMap(fd, fs, mi.Exporter)
  62. case fd.IsList():
  63. fi = fieldInfoForList(fd, fs, mi.Exporter)
  64. case fd.IsWeak():
  65. fi = fieldInfoForWeakMessage(fd, si.weakOffset)
  66. case fd.Message() != nil:
  67. fi = fieldInfoForMessage(fd, fs, mi.Exporter)
  68. default:
  69. fi = fieldInfoForScalar(fd, fs, mi.Exporter)
  70. }
  71. mi.fields[fd.Number()] = &fi
  72. }
  73. mi.oneofs = map[pref.Name]*oneofInfo{}
  74. for i := 0; i < md.Oneofs().Len(); i++ {
  75. od := md.Oneofs().Get(i)
  76. mi.oneofs[od.Name()] = makeOneofInfo(od, si, mi.Exporter)
  77. }
  78. mi.denseFields = make([]*fieldInfo, fds.Len()*2)
  79. for i := 0; i < fds.Len(); i++ {
  80. if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
  81. mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
  82. }
  83. }
  84. for i := 0; i < fds.Len(); {
  85. fd := fds.Get(i)
  86. if od := fd.ContainingOneof(); od != nil && !od.IsSynthetic() {
  87. mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
  88. i += od.Fields().Len()
  89. } else {
  90. mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
  91. i++
  92. }
  93. }
  94. // Introduce instability to iteration order, but keep it deterministic.
  95. if len(mi.rangeInfos) > 1 && detrand.Bool() {
  96. i := detrand.Intn(len(mi.rangeInfos) - 1)
  97. mi.rangeInfos[i], mi.rangeInfos[i+1] = mi.rangeInfos[i+1], mi.rangeInfos[i]
  98. }
  99. }
  100. func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
  101. switch {
  102. case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsAType:
  103. // Handle as []byte.
  104. mi.getUnknown = func(p pointer) pref.RawFields {
  105. if p.IsNil() {
  106. return nil
  107. }
  108. return *p.Apply(mi.unknownOffset).Bytes()
  109. }
  110. mi.setUnknown = func(p pointer, b pref.RawFields) {
  111. if p.IsNil() {
  112. panic("invalid SetUnknown on nil Message")
  113. }
  114. *p.Apply(mi.unknownOffset).Bytes() = b
  115. }
  116. case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsBType:
  117. // Handle as *[]byte.
  118. mi.getUnknown = func(p pointer) pref.RawFields {
  119. if p.IsNil() {
  120. return nil
  121. }
  122. bp := p.Apply(mi.unknownOffset).BytesPtr()
  123. if *bp == nil {
  124. return nil
  125. }
  126. return **bp
  127. }
  128. mi.setUnknown = func(p pointer, b pref.RawFields) {
  129. if p.IsNil() {
  130. panic("invalid SetUnknown on nil Message")
  131. }
  132. bp := p.Apply(mi.unknownOffset).BytesPtr()
  133. if *bp == nil {
  134. *bp = new([]byte)
  135. }
  136. **bp = b
  137. }
  138. default:
  139. mi.getUnknown = func(pointer) pref.RawFields {
  140. return nil
  141. }
  142. mi.setUnknown = func(p pointer, _ pref.RawFields) {
  143. if p.IsNil() {
  144. panic("invalid SetUnknown on nil Message")
  145. }
  146. }
  147. }
  148. }
  149. func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
  150. if si.extensionOffset.IsValid() {
  151. mi.extensionMap = func(p pointer) *extensionMap {
  152. if p.IsNil() {
  153. return (*extensionMap)(nil)
  154. }
  155. v := p.Apply(si.extensionOffset).AsValueOf(extensionFieldsType)
  156. return (*extensionMap)(v.Interface().(*map[int32]ExtensionField))
  157. }
  158. } else {
  159. mi.extensionMap = func(pointer) *extensionMap {
  160. return (*extensionMap)(nil)
  161. }
  162. }
  163. }
  164. func (mi *MessageInfo) makeFieldTypes(si structInfo) {
  165. md := mi.Desc
  166. fds := md.Fields()
  167. for i := 0; i < fds.Len(); i++ {
  168. var ft reflect.Type
  169. fd := fds.Get(i)
  170. fs := si.fieldsByNumber[fd.Number()]
  171. isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
  172. if isOneof {
  173. fs = si.oneofsByName[fd.ContainingOneof().Name()]
  174. }
  175. var isMessage bool
  176. switch {
  177. case fs.Type == nil:
  178. continue // never occurs for officially generated message types
  179. case isOneof:
  180. if fd.Enum() != nil || fd.Message() != nil {
  181. ft = si.oneofWrappersByNumber[fd.Number()].Field(0).Type
  182. }
  183. case fd.IsMap():
  184. if fd.MapValue().Enum() != nil || fd.MapValue().Message() != nil {
  185. ft = fs.Type.Elem()
  186. }
  187. isMessage = fd.MapValue().Message() != nil
  188. case fd.IsList():
  189. if fd.Enum() != nil || fd.Message() != nil {
  190. ft = fs.Type.Elem()
  191. }
  192. isMessage = fd.Message() != nil
  193. case fd.Enum() != nil:
  194. ft = fs.Type
  195. if fd.HasPresence() && ft.Kind() == reflect.Ptr {
  196. ft = ft.Elem()
  197. }
  198. case fd.Message() != nil:
  199. ft = fs.Type
  200. if fd.IsWeak() {
  201. ft = nil
  202. }
  203. isMessage = true
  204. }
  205. if isMessage && ft != nil && ft.Kind() != reflect.Ptr {
  206. ft = reflect.PtrTo(ft) // never occurs for officially generated message types
  207. }
  208. if ft != nil {
  209. if mi.fieldTypes == nil {
  210. mi.fieldTypes = make(map[pref.FieldNumber]interface{})
  211. }
  212. mi.fieldTypes[fd.Number()] = reflect.Zero(ft).Interface()
  213. }
  214. }
  215. }
// extensionMap stores a message's extension fields, keyed by extension field
// number. Range, Has, and Get accept a nil *extensionMap. Clear, Set, and
// Mutable dereference it, so they must be called on a non-nil pointer.
type extensionMap map[int32]ExtensionField
  217. func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
  218. if m != nil {
  219. for _, x := range *m {
  220. xd := x.Type().TypeDescriptor()
  221. v := x.Value()
  222. if xd.IsList() && v.List().Len() == 0 {
  223. continue
  224. }
  225. if !f(xd, v) {
  226. return
  227. }
  228. }
  229. }
  230. }
  231. func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
  232. if m == nil {
  233. return false
  234. }
  235. xd := xt.TypeDescriptor()
  236. x, ok := (*m)[int32(xd.Number())]
  237. if !ok {
  238. return false
  239. }
  240. switch {
  241. case xd.IsList():
  242. return x.Value().List().Len() > 0
  243. case xd.IsMap():
  244. return x.Value().Map().Len() > 0
  245. case xd.Message() != nil:
  246. return x.Value().Message().IsValid()
  247. }
  248. return true
  249. }
  250. func (m *extensionMap) Clear(xt pref.ExtensionType) {
  251. delete(*m, int32(xt.TypeDescriptor().Number()))
  252. }
  253. func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
  254. xd := xt.TypeDescriptor()
  255. if m != nil {
  256. if x, ok := (*m)[int32(xd.Number())]; ok {
  257. return x.Value()
  258. }
  259. }
  260. return xt.Zero()
  261. }
  262. func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
  263. xd := xt.TypeDescriptor()
  264. isValid := true
  265. switch {
  266. case !xt.IsValidValue(v):
  267. isValid = false
  268. case xd.IsList():
  269. isValid = v.List().IsValid()
  270. case xd.IsMap():
  271. isValid = v.Map().IsValid()
  272. case xd.Message() != nil:
  273. isValid = v.Message().IsValid()
  274. }
  275. if !isValid {
  276. panic(fmt.Sprintf("%v: assigning invalid value", xt.TypeDescriptor().FullName()))
  277. }
  278. if *m == nil {
  279. *m = make(map[int32]ExtensionField)
  280. }
  281. var x ExtensionField
  282. x.Set(xt, v)
  283. (*m)[int32(xd.Number())] = x
  284. }
  285. func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
  286. xd := xt.TypeDescriptor()
  287. if xd.Kind() != pref.MessageKind && xd.Kind() != pref.GroupKind && !xd.IsList() && !xd.IsMap() {
  288. panic("invalid Mutable on field with non-composite type")
  289. }
  290. if x, ok := (*m)[int32(xd.Number())]; ok {
  291. return x.Value()
  292. }
  293. v := xt.New()
  294. m.Set(xt, v)
  295. return v
  296. }
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
// Example generated code:
//
//	type M struct {
//		state protoimpl.MessageState
//
//		Field1 int32
//		Field2 string
//		Field3 *BarMessage
//		...
//	}
//
//	func (m *M) ProtoReflect() protoreflect.Message {
//		mi := &file_fizz_buzz_proto_msgInfos[5]
//		if protoimpl.UnsafeEnabled && m != nil {
//			ms := protoimpl.X.MessageStateOf(Pointer(m))
//			if ms.LoadMessageInfo() == nil {
//				ms.StoreMessageInfo(mi)
//			}
//			return ms
//		}
//		return mi.MessageOf(m)
//	}
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// It has access to the message info as its first field, and a pointer to the
// MessageState is identical to a pointer to the concrete message value.
//
// Requirements:
//   - The type M must implement protoreflect.ProtoMessage.
//   - The address of m must not be nil.
//   - The address of m and the address of m.state must be equal,
//     even though they are different Go types.
type MessageState struct {
	// Marker fields from package pragma. Judging by their names, they forbid
	// unkeyed literals, comparison, and copying (confirm in package pragma).
	pragma.NoUnkeyedLiterals
	pragma.DoNotCompare
	pragma.DoNotCopy

	// atomicMessageInfo is the message info for the enclosing message. As
	// stated above, it must be set atomically.
	// NOTE(review): do not reorder fields. If the pragma markers are
	// zero-sized (presumably), moving one to the end could add trailing
	// padding and change the struct's layout.
	atomicMessageInfo *MessageInfo
}
// messageState has the same underlying type as MessageState but its own
// method set. The assertions below check at compile time that *messageState
// implements protoreflect.Message and unwrapper.
type messageState MessageState

var (
	_ pref.Message = (*messageState)(nil)
	_ unwrapper    = (*messageState)(nil)
)
// messageDataType is a tuple of a pointer to the message data and
// a pointer to the message type. It is a generalized way of providing a
// reflective view over a message instance. The disadvantage of this approach
// is the need to allocate this tuple of 16B (NOTE(review): 16B assumes a
// 64-bit platform and a one-word pointer type).
type messageDataType struct {
	p  pointer
	mi *MessageInfo
}

// Both wrapper types share messageDataType's underlying type, so a pointer to
// one can be converted to a pointer to the other:
//   - messageReflectWrapper is the protoreflect.Message view that MessageOf
//     returns for a non-nil message.
//   - messageIfaceWrapper is the protoreflect.ProtoMessage view; its
//     ProtoReflect method converts it back to a *messageReflectWrapper.
type (
	messageReflectWrapper messageDataType
	messageIfaceWrapper   messageDataType
)

// Compile-time checks that the wrappers implement the expected interfaces.
var (
	_ pref.Message      = (*messageReflectWrapper)(nil)
	_ unwrapper         = (*messageReflectWrapper)(nil)
	_ pref.ProtoMessage = (*messageIfaceWrapper)(nil)
	_ unwrapper         = (*messageIfaceWrapper)(nil)
)
  368. // MessageOf returns a reflective view over a message. The input must be a
  369. // pointer to a named Go struct. If the provided type has a ProtoReflect method,
  370. // it must be implemented by calling this method.
  371. func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
  372. if reflect.TypeOf(m) != mi.GoReflectType {
  373. panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
  374. }
  375. p := pointerOfIface(m)
  376. if p.IsNil() {
  377. return mi.nilMessage.Init(mi)
  378. }
  379. return &messageReflectWrapper{p, mi}
  380. }
  381. func (m *messageReflectWrapper) pointer() pointer { return m.p }
  382. func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
  383. // Reset implements the v1 proto.Message.Reset method.
  384. func (m *messageIfaceWrapper) Reset() {
  385. if mr, ok := m.protoUnwrap().(interface{ Reset() }); ok {
  386. mr.Reset()
  387. return
  388. }
  389. rv := reflect.ValueOf(m.protoUnwrap())
  390. if rv.Kind() == reflect.Ptr && !rv.IsNil() {
  391. rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
  392. }
  393. }
  394. func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
  395. return (*messageReflectWrapper)(m)
  396. }
  397. func (m *messageIfaceWrapper) protoUnwrap() interface{} {
  398. return m.p.AsIfaceOf(m.mi.GoReflectType.Elem())
  399. }
  400. // checkField verifies that the provided field descriptor is valid.
  401. // Exactly one of the returned values is populated.
  402. func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
  403. var fi *fieldInfo
  404. if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
  405. fi = mi.denseFields[n]
  406. } else {
  407. fi = mi.fields[n]
  408. }
  409. if fi != nil {
  410. if fi.fieldDesc != fd {
  411. if got, want := fd.FullName(), fi.fieldDesc.FullName(); got != want {
  412. panic(fmt.Sprintf("mismatching field: got %v, want %v", got, want))
  413. }
  414. panic(fmt.Sprintf("mismatching field: %v", fd.FullName()))
  415. }
  416. return fi, nil
  417. }
  418. if fd.IsExtension() {
  419. if got, want := fd.ContainingMessage().FullName(), mi.Desc.FullName(); got != want {
  420. // TODO: Should this be exact containing message descriptor match?
  421. panic(fmt.Sprintf("extension %v has mismatching containing message: got %v, want %v", fd.FullName(), got, want))
  422. }
  423. if !mi.Desc.ExtensionRanges().Has(fd.Number()) {
  424. panic(fmt.Sprintf("extension %v extends %v outside the extension range", fd.FullName(), mi.Desc.FullName()))
  425. }
  426. xtd, ok := fd.(pref.ExtensionTypeDescriptor)
  427. if !ok {
  428. panic(fmt.Sprintf("extension %v does not implement protoreflect.ExtensionTypeDescriptor", fd.FullName()))
  429. }
  430. return nil, xtd.Type()
  431. }
  432. panic(fmt.Sprintf("field %v is invalid", fd.FullName()))
  433. }