protocol_test.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. package protocol_test
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/url"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/aws/aws-sdk-go/aws/client/metadata"
  9. "github.com/aws/aws-sdk-go/aws/request"
  10. "github.com/aws/aws-sdk-go/awstesting"
  11. "github.com/aws/aws-sdk-go/private/protocol"
  12. "github.com/aws/aws-sdk-go/private/protocol/ec2query"
  13. "github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
  14. "github.com/aws/aws-sdk-go/private/protocol/query"
  15. "github.com/aws/aws-sdk-go/private/protocol/rest"
  16. "github.com/aws/aws-sdk-go/private/protocol/restjson"
  17. "github.com/aws/aws-sdk-go/private/protocol/restxml"
  18. )
  // xmlData is a FillData callback for awstesting.ReadCloser that shapes the
  // fake response body into the XML document <B><A>...</A></B>.
  //
  // When set is false it writes the opening tags at the start of b. When size
  // is 0 it writes the closing tags so that they end exactly at b[delta].
  // Presumably set reports whether an earlier chunk was already filled, size is
  // the number of bytes still unread and delta the length of this chunk; verify
  // against awstesting.ReadCloser.
  //
  // The closing tags must mirror the opening ones. The strict encoding/xml
  // decoder rejects mismatched end tags with a syntax error, which would make
  // every XML-decoding protocol handler fail on this fixture.
  //
  // NOTE(review): assumes the final chunk holds at least len(closingTags)
  // bytes; a smaller delta makes the slice expression panic.
  func xmlData(set bool, b []byte, size, delta int) {
  	const (
  		openingTags = "<B><A>"
  		closingTags = "</A></B>"
  	)
  	if !set {
  		copy(b, openingTags)
  	}
  	if size == 0 {
  		copy(b[delta-len(closingTags):], closingTags)
  	}
  }
  // jsonData is a FillData callback for awstesting.ReadCloser that shapes the
  // fake response body into the JSON object {"A": "..."}.
  //
  // The prefix `{"A": "` is written at the start of b when set is false. The
  // suffix `"}` is written so that it ends exactly at b[delta] once size is 0.
  // Other calls leave b unchanged.
  func jsonData(set bool, b []byte, size, delta int) {
  	const (
  		prefix = `{"A": "`
  		suffix = `"}`
  	)
  	if !set {
  		copy(b, prefix)
  	}
  	if size != 0 {
  		return
  	}
  	copy(b[delta-len(suffix):], suffix)
  }
  35. func buildNewRequest(data interface{}) *request.Request {
  36. v := url.Values{}
  37. v.Set("test", "TEST")
  38. v.Add("test1", "TEST1")
  39. req := &request.Request{
  40. HTTPRequest: &http.Request{
  41. Header: make(http.Header),
  42. Body: &awstesting.ReadCloser{Size: 2048},
  43. URL: &url.URL{
  44. RawQuery: v.Encode(),
  45. },
  46. },
  47. Params: &struct {
  48. LocationName string `locationName:"test"`
  49. }{
  50. "Test",
  51. },
  52. ClientInfo: metadata.ClientInfo{
  53. ServiceName: "test",
  54. TargetPrefix: "test",
  55. JSONVersion: "test",
  56. APIVersion: "test",
  57. Endpoint: "test",
  58. SigningName: "test",
  59. SigningRegion: "test",
  60. },
  61. Operation: &request.Operation{
  62. Name: "test",
  63. },
  64. }
  65. req.HTTPResponse = &http.Response{
  66. Body: &awstesting.ReadCloser{Size: 2048},
  67. Header: http.Header{
  68. "X-Amzn-Requestid": []string{"1"},
  69. },
  70. StatusCode: http.StatusOK,
  71. }
  72. if data == nil {
  73. data = &struct {
  74. _ struct{} `type:"structure"`
  75. LocationName *string `locationName:"testName"`
  76. Location *string `location:"statusCode"`
  77. A *string `type:"string"`
  78. }{}
  79. }
  80. req.Data = data
  81. return req
  82. }
  // Fixture payload formats accepted by checkForLeak's dataType switch.
  const (
  	jsonType = iota // response body is shaped by jsonData
  	xmlType         // response body is shaped by xmlData
  )

  // expected is the outcome checkForLeak asserts after running a protocol
  // handler. Callers build it with positional literals, so the field order is
  // part of its contract.
  type expected struct {
  	dataType  int  // jsonType or xmlType; selects the body fill callback
  	closed    bool // whether the fake response body must have been closed
  	size      int  // remaining Size of the fake body (2048 = untouched, 0 = drained)
  	errExists bool // whether req.Error must be non-nil
  }
  93. func checkForLeak(data interface{}, build, fn func(*request.Request), t *testing.T, result expected) {
  94. req := buildNewRequest(data)
  95. reader := req.HTTPResponse.Body.(*awstesting.ReadCloser)
  96. switch result.dataType {
  97. case jsonType:
  98. reader.FillData = jsonData
  99. case xmlType:
  100. reader.FillData = xmlData
  101. }
  102. build(req)
  103. fn(req)
  104. if result.errExists {
  105. assert.NotNil(t, req.Error)
  106. } else {
  107. fmt.Println(req.Error)
  108. assert.Nil(t, req.Error)
  109. }
  110. assert.Equal(t, reader.Closed, result.closed)
  111. assert.Equal(t, reader.Size, result.size)
  112. }
  113. func TestJSONRpc(t *testing.T) {
  114. checkForLeak(nil, jsonrpc.Build, jsonrpc.Unmarshal, t, expected{jsonType, true, 0, false})
  115. checkForLeak(nil, jsonrpc.Build, jsonrpc.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
  116. checkForLeak(nil, jsonrpc.Build, jsonrpc.UnmarshalError, t, expected{jsonType, true, 0, true})
  117. }
  118. func TestQuery(t *testing.T) {
  119. checkForLeak(nil, query.Build, query.Unmarshal, t, expected{jsonType, true, 0, false})
  120. checkForLeak(nil, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
  121. checkForLeak(nil, query.Build, query.UnmarshalError, t, expected{jsonType, true, 0, true})
  122. }
  123. func TestRest(t *testing.T) {
  124. // case 1: Payload io.ReadSeeker
  125. checkForLeak(nil, rest.Build, rest.Unmarshal, t, expected{jsonType, false, 2048, false})
  126. checkForLeak(nil, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
  127. // case 2: Payload *string
  128. // should close the body
  129. dataStr := struct {
  130. _ struct{} `type:"structure" payload:"Payload"`
  131. LocationName *string `locationName:"testName"`
  132. Location *string `location:"statusCode"`
  133. A *string `type:"string"`
  134. Payload *string `locationName:"payload" type:"blob" required:"true"`
  135. }{}
  136. checkForLeak(&dataStr, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, false})
  137. checkForLeak(&dataStr, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
  138. // case 3: Payload []byte
  139. // should close the body
  140. dataBytes := struct {
  141. _ struct{} `type:"structure" payload:"Payload"`
  142. LocationName *string `locationName:"testName"`
  143. Location *string `location:"statusCode"`
  144. A *string `type:"string"`
  145. Payload []byte `locationName:"payload" type:"blob" required:"true"`
  146. }{}
  147. checkForLeak(&dataBytes, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, false})
  148. checkForLeak(&dataBytes, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
  149. // case 4: Payload unsupported type
  150. // should close the body
  151. dataUnsupported := struct {
  152. _ struct{} `type:"structure" payload:"Payload"`
  153. LocationName *string `locationName:"testName"`
  154. Location *string `location:"statusCode"`
  155. A *string `type:"string"`
  156. Payload string `locationName:"payload" type:"blob" required:"true"`
  157. }{}
  158. checkForLeak(&dataUnsupported, rest.Build, rest.Unmarshal, t, expected{jsonType, true, 0, true})
  159. checkForLeak(&dataUnsupported, query.Build, query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
  160. }
  161. func TestRestJSON(t *testing.T) {
  162. checkForLeak(nil, restjson.Build, restjson.Unmarshal, t, expected{jsonType, true, 0, false})
  163. checkForLeak(nil, restjson.Build, restjson.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
  164. checkForLeak(nil, restjson.Build, restjson.UnmarshalError, t, expected{jsonType, true, 0, true})
  165. }
  166. func TestRestXML(t *testing.T) {
  167. checkForLeak(nil, restxml.Build, restxml.Unmarshal, t, expected{xmlType, true, 0, false})
  168. checkForLeak(nil, restxml.Build, restxml.UnmarshalMeta, t, expected{xmlType, false, 2048, false})
  169. checkForLeak(nil, restxml.Build, restxml.UnmarshalError, t, expected{xmlType, true, 0, true})
  170. }
  171. func TestXML(t *testing.T) {
  172. checkForLeak(nil, ec2query.Build, ec2query.Unmarshal, t, expected{jsonType, true, 0, false})
  173. checkForLeak(nil, ec2query.Build, ec2query.UnmarshalMeta, t, expected{jsonType, false, 2048, false})
  174. checkForLeak(nil, ec2query.Build, ec2query.UnmarshalError, t, expected{jsonType, true, 0, true})
  175. }
  176. func TestProtocol(t *testing.T) {
  177. checkForLeak(nil, restxml.Build, protocol.UnmarshalDiscardBody, t, expected{xmlType, true, 0, false})
  178. }