validate.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "math"
  8. "math/bits"
  9. "reflect"
  10. "unicode/utf8"
  11. "google.golang.org/protobuf/encoding/protowire"
  12. "google.golang.org/protobuf/internal/encoding/messageset"
  13. "google.golang.org/protobuf/internal/flags"
  14. "google.golang.org/protobuf/internal/strs"
  15. pref "google.golang.org/protobuf/reflect/protoreflect"
  16. preg "google.golang.org/protobuf/reflect/protoregistry"
  17. piface "google.golang.org/protobuf/runtime/protoiface"
  18. )
  // ValidationStatus reports the outcome of checking whether a buffer holds a
  // well-formed wire encoding of some message type.
  type ValidationStatus int

  // The statuses are numbered from 1, so the zero ValidationStatus matches
  // none of them.
  const (
  	// ValidationUnknown means the validator could not decide whether
  	// unmarshaling would succeed or fail.
  	//
  	// It is produced only when an aberrant message type appears somewhere in
  	// the message, or when the extension resolver fails.
  	ValidationUnknown ValidationStatus = 1

  	// ValidationInvalid means unmarshaling is certain to fail.
  	ValidationInvalid ValidationStatus = 2

  	// ValidationValid means unmarshaling is certain to succeed.
  	ValidationValid ValidationStatus = 3
  )

  // String returns the identifier of a known status, or "ValidationStatus(N)"
  // for any other value.
  func (v ValidationStatus) String() string {
  	names := [...]string{
  		ValidationUnknown: "ValidationUnknown",
  		ValidationInvalid: "ValidationInvalid",
  		ValidationValid:   "ValidationValid",
  	}
  	if v < ValidationUnknown || int(v) >= len(names) {
  		return fmt.Sprintf("ValidationStatus(%d)", int(v))
  	}
  	return names[v]
  }
  45. // Validate determines whether the contents of the buffer are a valid wire encoding
  46. // of the message type.
  47. //
  48. // This function is exposed for testing.
  49. func Validate(mt pref.MessageType, in piface.UnmarshalInput) (out piface.UnmarshalOutput, _ ValidationStatus) {
  50. mi, ok := mt.(*MessageInfo)
  51. if !ok {
  52. return out, ValidationUnknown
  53. }
  54. if in.Resolver == nil {
  55. in.Resolver = preg.GlobalTypes
  56. }
  57. o, st := mi.validate(in.Buf, 0, unmarshalOptions{
  58. flags: in.Flags,
  59. resolver: in.Resolver,
  60. })
  61. if o.initialized {
  62. out.Flags |= piface.UnmarshalInitialized
  63. }
  64. return out, st
  65. }
  66. type validationInfo struct {
  67. mi *MessageInfo
  68. typ validationType
  69. keyType, valType validationType
  70. // For non-required fields, requiredBit is 0.
  71. //
  72. // For required fields, requiredBit's nth bit is set, where n is a
  73. // unique index in the range [0, MessageInfo.numRequiredFields).
  74. //
  75. // If there are more than 64 required fields, requiredBit is 0.
  76. requiredBit uint64
  77. }
  // validationType enumerates the checks the validator can apply to an encoded
  // field value.
  type validationType uint8

  const (
  	validationTypeOther          = validationType(iota) // no kind-specific check
  	validationTypeMessage                               // length-delimited nested message
  	validationTypeGroup                                 // nested message delimited by group tags
  	validationTypeMap                                   // map entry with key/value checks
  	validationTypeRepeatedVarint                        // repeated varint scalar (packed form is checked)
  	validationTypeRepeatedFixed32                       // repeated 4-byte scalar (packed length must be a multiple of 4)
  	validationTypeRepeatedFixed64                       // repeated 8-byte scalar (packed length must be a multiple of 8)
  	validationTypeVarint                                // singular varint scalar
  	validationTypeFixed32                               // singular 4-byte scalar
  	validationTypeFixed64                               // singular 8-byte scalar
  	validationTypeBytes                                 // length-delimited bytes
  	validationTypeUTF8String                            // length-delimited bytes that must be valid UTF-8
  	validationTypeMessageSetItem                        // legacy MessageSet item group
  )
  94. func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
  95. var vi validationInfo
  96. switch {
  97. case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
  98. switch fd.Kind() {
  99. case pref.MessageKind:
  100. vi.typ = validationTypeMessage
  101. if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
  102. vi.mi = getMessageInfo(ot.Field(0).Type)
  103. }
  104. case pref.GroupKind:
  105. vi.typ = validationTypeGroup
  106. if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
  107. vi.mi = getMessageInfo(ot.Field(0).Type)
  108. }
  109. case pref.StringKind:
  110. if strs.EnforceUTF8(fd) {
  111. vi.typ = validationTypeUTF8String
  112. }
  113. }
  114. default:
  115. vi = newValidationInfo(fd, ft)
  116. }
  117. if fd.Cardinality() == pref.Required {
  118. // Avoid overflow. The required field check is done with a 64-bit mask, with
  119. // any message containing more than 64 required fields always reported as
  120. // potentially uninitialized, so it is not important to get a precise count
  121. // of the required fields past 64.
  122. if mi.numRequiredFields < math.MaxUint8 {
  123. mi.numRequiredFields++
  124. vi.requiredBit = 1 << (mi.numRequiredFields - 1)
  125. }
  126. }
  127. return vi
  128. }
  129. func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
  130. var vi validationInfo
  131. switch {
  132. case fd.IsList():
  133. switch fd.Kind() {
  134. case pref.MessageKind:
  135. vi.typ = validationTypeMessage
  136. if ft.Kind() == reflect.Slice {
  137. vi.mi = getMessageInfo(ft.Elem())
  138. }
  139. case pref.GroupKind:
  140. vi.typ = validationTypeGroup
  141. if ft.Kind() == reflect.Slice {
  142. vi.mi = getMessageInfo(ft.Elem())
  143. }
  144. case pref.StringKind:
  145. vi.typ = validationTypeBytes
  146. if strs.EnforceUTF8(fd) {
  147. vi.typ = validationTypeUTF8String
  148. }
  149. default:
  150. switch wireTypes[fd.Kind()] {
  151. case protowire.VarintType:
  152. vi.typ = validationTypeRepeatedVarint
  153. case protowire.Fixed32Type:
  154. vi.typ = validationTypeRepeatedFixed32
  155. case protowire.Fixed64Type:
  156. vi.typ = validationTypeRepeatedFixed64
  157. }
  158. }
  159. case fd.IsMap():
  160. vi.typ = validationTypeMap
  161. switch fd.MapKey().Kind() {
  162. case pref.StringKind:
  163. if strs.EnforceUTF8(fd) {
  164. vi.keyType = validationTypeUTF8String
  165. }
  166. }
  167. switch fd.MapValue().Kind() {
  168. case pref.MessageKind:
  169. vi.valType = validationTypeMessage
  170. if ft.Kind() == reflect.Map {
  171. vi.mi = getMessageInfo(ft.Elem())
  172. }
  173. case pref.StringKind:
  174. if strs.EnforceUTF8(fd) {
  175. vi.valType = validationTypeUTF8String
  176. }
  177. }
  178. default:
  179. switch fd.Kind() {
  180. case pref.MessageKind:
  181. vi.typ = validationTypeMessage
  182. if !fd.IsWeak() {
  183. vi.mi = getMessageInfo(ft)
  184. }
  185. case pref.GroupKind:
  186. vi.typ = validationTypeGroup
  187. vi.mi = getMessageInfo(ft)
  188. case pref.StringKind:
  189. vi.typ = validationTypeBytes
  190. if strs.EnforceUTF8(fd) {
  191. vi.typ = validationTypeUTF8String
  192. }
  193. default:
  194. switch wireTypes[fd.Kind()] {
  195. case protowire.VarintType:
  196. vi.typ = validationTypeVarint
  197. case protowire.Fixed32Type:
  198. vi.typ = validationTypeFixed32
  199. case protowire.Fixed64Type:
  200. vi.typ = validationTypeFixed64
  201. case protowire.BytesType:
  202. vi.typ = validationTypeBytes
  203. }
  204. }
  205. }
  206. return vi
  207. }
  208. func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
  209. mi.init()
  210. type validationState struct {
  211. typ validationType
  212. keyType, valType validationType
  213. endGroup protowire.Number
  214. mi *MessageInfo
  215. tail []byte
  216. requiredMask uint64
  217. }
  218. // Pre-allocate some slots to avoid repeated slice reallocation.
  219. states := make([]validationState, 0, 16)
  220. states = append(states, validationState{
  221. typ: validationTypeMessage,
  222. mi: mi,
  223. })
  224. if groupTag > 0 {
  225. states[0].typ = validationTypeGroup
  226. states[0].endGroup = groupTag
  227. }
  228. initialized := true
  229. start := len(b)
  230. State:
  231. for len(states) > 0 {
  232. st := &states[len(states)-1]
  233. for len(b) > 0 {
  234. // Parse the tag (field number and wire type).
  235. var tag uint64
  236. if b[0] < 0x80 {
  237. tag = uint64(b[0])
  238. b = b[1:]
  239. } else if len(b) >= 2 && b[1] < 128 {
  240. tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
  241. b = b[2:]
  242. } else {
  243. var n int
  244. tag, n = protowire.ConsumeVarint(b)
  245. if n < 0 {
  246. return out, ValidationInvalid
  247. }
  248. b = b[n:]
  249. }
  250. var num protowire.Number
  251. if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
  252. return out, ValidationInvalid
  253. } else {
  254. num = protowire.Number(n)
  255. }
  256. wtyp := protowire.Type(tag & 7)
  257. if wtyp == protowire.EndGroupType {
  258. if st.endGroup == num {
  259. goto PopState
  260. }
  261. return out, ValidationInvalid
  262. }
  263. var vi validationInfo
  264. switch {
  265. case st.typ == validationTypeMap:
  266. switch num {
  267. case 1:
  268. vi.typ = st.keyType
  269. case 2:
  270. vi.typ = st.valType
  271. vi.mi = st.mi
  272. vi.requiredBit = 1
  273. }
  274. case flags.ProtoLegacy && st.mi.isMessageSet:
  275. switch num {
  276. case messageset.FieldItem:
  277. vi.typ = validationTypeMessageSetItem
  278. }
  279. default:
  280. var f *coderFieldInfo
  281. if int(num) < len(st.mi.denseCoderFields) {
  282. f = st.mi.denseCoderFields[num]
  283. } else {
  284. f = st.mi.coderFields[num]
  285. }
  286. if f != nil {
  287. vi = f.validation
  288. if vi.typ == validationTypeMessage && vi.mi == nil {
  289. // Probable weak field.
  290. //
  291. // TODO: Consider storing the results of this lookup somewhere
  292. // rather than recomputing it on every validation.
  293. fd := st.mi.Desc.Fields().ByNumber(num)
  294. if fd == nil || !fd.IsWeak() {
  295. break
  296. }
  297. messageName := fd.Message().FullName()
  298. messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
  299. switch err {
  300. case nil:
  301. vi.mi, _ = messageType.(*MessageInfo)
  302. case preg.NotFound:
  303. vi.typ = validationTypeBytes
  304. default:
  305. return out, ValidationUnknown
  306. }
  307. }
  308. break
  309. }
  310. // Possible extension field.
  311. //
  312. // TODO: We should return ValidationUnknown when:
  313. // 1. The resolver is not frozen. (More extensions may be added to it.)
  314. // 2. The resolver returns preg.NotFound.
  315. // In this case, a type added to the resolver in the future could cause
  316. // unmarshaling to begin failing. Supporting this requires some way to
  317. // determine if the resolver is frozen.
  318. xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
  319. if err != nil && err != preg.NotFound {
  320. return out, ValidationUnknown
  321. }
  322. if err == nil {
  323. vi = getExtensionFieldInfo(xt).validation
  324. }
  325. }
  326. if vi.requiredBit != 0 {
  327. // Check that the field has a compatible wire type.
  328. // We only need to consider non-repeated field types,
  329. // since repeated fields (and maps) can never be required.
  330. ok := false
  331. switch vi.typ {
  332. case validationTypeVarint:
  333. ok = wtyp == protowire.VarintType
  334. case validationTypeFixed32:
  335. ok = wtyp == protowire.Fixed32Type
  336. case validationTypeFixed64:
  337. ok = wtyp == protowire.Fixed64Type
  338. case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
  339. ok = wtyp == protowire.BytesType
  340. case validationTypeGroup:
  341. ok = wtyp == protowire.StartGroupType
  342. }
  343. if ok {
  344. st.requiredMask |= vi.requiredBit
  345. }
  346. }
  347. switch wtyp {
  348. case protowire.VarintType:
  349. if len(b) >= 10 {
  350. switch {
  351. case b[0] < 0x80:
  352. b = b[1:]
  353. case b[1] < 0x80:
  354. b = b[2:]
  355. case b[2] < 0x80:
  356. b = b[3:]
  357. case b[3] < 0x80:
  358. b = b[4:]
  359. case b[4] < 0x80:
  360. b = b[5:]
  361. case b[5] < 0x80:
  362. b = b[6:]
  363. case b[6] < 0x80:
  364. b = b[7:]
  365. case b[7] < 0x80:
  366. b = b[8:]
  367. case b[8] < 0x80:
  368. b = b[9:]
  369. case b[9] < 0x80 && b[9] < 2:
  370. b = b[10:]
  371. default:
  372. return out, ValidationInvalid
  373. }
  374. } else {
  375. switch {
  376. case len(b) > 0 && b[0] < 0x80:
  377. b = b[1:]
  378. case len(b) > 1 && b[1] < 0x80:
  379. b = b[2:]
  380. case len(b) > 2 && b[2] < 0x80:
  381. b = b[3:]
  382. case len(b) > 3 && b[3] < 0x80:
  383. b = b[4:]
  384. case len(b) > 4 && b[4] < 0x80:
  385. b = b[5:]
  386. case len(b) > 5 && b[5] < 0x80:
  387. b = b[6:]
  388. case len(b) > 6 && b[6] < 0x80:
  389. b = b[7:]
  390. case len(b) > 7 && b[7] < 0x80:
  391. b = b[8:]
  392. case len(b) > 8 && b[8] < 0x80:
  393. b = b[9:]
  394. case len(b) > 9 && b[9] < 2:
  395. b = b[10:]
  396. default:
  397. return out, ValidationInvalid
  398. }
  399. }
  400. continue State
  401. case protowire.BytesType:
  402. var size uint64
  403. if len(b) >= 1 && b[0] < 0x80 {
  404. size = uint64(b[0])
  405. b = b[1:]
  406. } else if len(b) >= 2 && b[1] < 128 {
  407. size = uint64(b[0]&0x7f) + uint64(b[1])<<7
  408. b = b[2:]
  409. } else {
  410. var n int
  411. size, n = protowire.ConsumeVarint(b)
  412. if n < 0 {
  413. return out, ValidationInvalid
  414. }
  415. b = b[n:]
  416. }
  417. if size > uint64(len(b)) {
  418. return out, ValidationInvalid
  419. }
  420. v := b[:size]
  421. b = b[size:]
  422. switch vi.typ {
  423. case validationTypeMessage:
  424. if vi.mi == nil {
  425. return out, ValidationUnknown
  426. }
  427. vi.mi.init()
  428. fallthrough
  429. case validationTypeMap:
  430. if vi.mi != nil {
  431. vi.mi.init()
  432. }
  433. states = append(states, validationState{
  434. typ: vi.typ,
  435. keyType: vi.keyType,
  436. valType: vi.valType,
  437. mi: vi.mi,
  438. tail: b,
  439. })
  440. b = v
  441. continue State
  442. case validationTypeRepeatedVarint:
  443. // Packed field.
  444. for len(v) > 0 {
  445. _, n := protowire.ConsumeVarint(v)
  446. if n < 0 {
  447. return out, ValidationInvalid
  448. }
  449. v = v[n:]
  450. }
  451. case validationTypeRepeatedFixed32:
  452. // Packed field.
  453. if len(v)%4 != 0 {
  454. return out, ValidationInvalid
  455. }
  456. case validationTypeRepeatedFixed64:
  457. // Packed field.
  458. if len(v)%8 != 0 {
  459. return out, ValidationInvalid
  460. }
  461. case validationTypeUTF8String:
  462. if !utf8.Valid(v) {
  463. return out, ValidationInvalid
  464. }
  465. }
  466. case protowire.Fixed32Type:
  467. if len(b) < 4 {
  468. return out, ValidationInvalid
  469. }
  470. b = b[4:]
  471. case protowire.Fixed64Type:
  472. if len(b) < 8 {
  473. return out, ValidationInvalid
  474. }
  475. b = b[8:]
  476. case protowire.StartGroupType:
  477. switch {
  478. case vi.typ == validationTypeGroup:
  479. if vi.mi == nil {
  480. return out, ValidationUnknown
  481. }
  482. vi.mi.init()
  483. states = append(states, validationState{
  484. typ: validationTypeGroup,
  485. mi: vi.mi,
  486. endGroup: num,
  487. })
  488. continue State
  489. case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
  490. typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
  491. if err != nil {
  492. return out, ValidationInvalid
  493. }
  494. xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
  495. switch {
  496. case err == preg.NotFound:
  497. b = b[n:]
  498. case err != nil:
  499. return out, ValidationUnknown
  500. default:
  501. xvi := getExtensionFieldInfo(xt).validation
  502. if xvi.mi != nil {
  503. xvi.mi.init()
  504. }
  505. states = append(states, validationState{
  506. typ: xvi.typ,
  507. mi: xvi.mi,
  508. tail: b[n:],
  509. })
  510. b = v
  511. continue State
  512. }
  513. default:
  514. n := protowire.ConsumeFieldValue(num, wtyp, b)
  515. if n < 0 {
  516. return out, ValidationInvalid
  517. }
  518. b = b[n:]
  519. }
  520. default:
  521. return out, ValidationInvalid
  522. }
  523. }
  524. if st.endGroup != 0 {
  525. return out, ValidationInvalid
  526. }
  527. if len(b) != 0 {
  528. return out, ValidationInvalid
  529. }
  530. b = st.tail
  531. PopState:
  532. numRequiredFields := 0
  533. switch st.typ {
  534. case validationTypeMessage, validationTypeGroup:
  535. numRequiredFields = int(st.mi.numRequiredFields)
  536. case validationTypeMap:
  537. // If this is a map field with a message value that contains
  538. // required fields, require that the value be present.
  539. if st.mi != nil && st.mi.numRequiredFields > 0 {
  540. numRequiredFields = 1
  541. }
  542. }
  543. // If there are more than 64 required fields, this check will
  544. // always fail and we will report that the message is potentially
  545. // uninitialized.
  546. if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
  547. initialized = false
  548. }
  549. states = states[:len(states)-1]
  550. }
  551. out.n = start - len(b)
  552. if initialized {
  553. out.initialized = true
  554. }
  555. return out, ValidationValid
  556. }