unmarshal.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. package rest
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "io/ioutil"
  9. "net/http"
  10. "reflect"
  11. "strconv"
  12. "strings"
  13. "time"
  14. "github.com/aws/aws-sdk-go/aws"
  15. "github.com/aws/aws-sdk-go/aws/awserr"
  16. "github.com/aws/aws-sdk-go/aws/request"
  17. )
  18. // UnmarshalHandler is a named request handler for unmarshaling rest protocol requests
  19. var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal}
  20. // UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata
  21. var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta}
  22. // Unmarshal unmarshals the REST component of a response in a REST service.
  23. func Unmarshal(r *request.Request) {
  24. if r.DataFilled() {
  25. v := reflect.Indirect(reflect.ValueOf(r.Data))
  26. unmarshalBody(r, v)
  27. }
  28. }
  29. // UnmarshalMeta unmarshals the REST metadata of a response in a REST service
  30. func UnmarshalMeta(r *request.Request) {
  31. r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
  32. if r.RequestID == "" {
  33. // Alternative version of request id in the header
  34. r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
  35. }
  36. if r.DataFilled() {
  37. v := reflect.Indirect(reflect.ValueOf(r.Data))
  38. unmarshalLocationElements(r, v)
  39. }
  40. }
  41. func unmarshalBody(r *request.Request, v reflect.Value) {
  42. if field, ok := v.Type().FieldByName("_"); ok {
  43. if payloadName := field.Tag.Get("payload"); payloadName != "" {
  44. pfield, _ := v.Type().FieldByName(payloadName)
  45. if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
  46. payload := v.FieldByName(payloadName)
  47. if payload.IsValid() {
  48. switch payload.Interface().(type) {
  49. case []byte:
  50. defer r.HTTPResponse.Body.Close()
  51. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  52. if err != nil {
  53. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  54. } else {
  55. payload.Set(reflect.ValueOf(b))
  56. }
  57. case *string:
  58. defer r.HTTPResponse.Body.Close()
  59. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  60. if err != nil {
  61. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  62. } else {
  63. str := string(b)
  64. payload.Set(reflect.ValueOf(&str))
  65. }
  66. default:
  67. switch payload.Type().String() {
  68. case "io.ReadCloser":
  69. payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
  70. case "io.ReadSeeker":
  71. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  72. if err != nil {
  73. r.Error = awserr.New("SerializationError",
  74. "failed to read response body", err)
  75. return
  76. }
  77. payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b))))
  78. default:
  79. io.Copy(ioutil.Discard, r.HTTPResponse.Body)
  80. defer r.HTTPResponse.Body.Close()
  81. r.Error = awserr.New("SerializationError",
  82. "failed to decode REST response",
  83. fmt.Errorf("unknown payload type %s", payload.Type()))
  84. }
  85. }
  86. }
  87. }
  88. }
  89. }
  90. }
  91. func unmarshalLocationElements(r *request.Request, v reflect.Value) {
  92. for i := 0; i < v.NumField(); i++ {
  93. m, field := v.Field(i), v.Type().Field(i)
  94. if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
  95. continue
  96. }
  97. if m.IsValid() {
  98. name := field.Tag.Get("locationName")
  99. if name == "" {
  100. name = field.Name
  101. }
  102. switch field.Tag.Get("location") {
  103. case "statusCode":
  104. unmarshalStatusCode(m, r.HTTPResponse.StatusCode)
  105. case "header":
  106. err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name), field.Tag)
  107. if err != nil {
  108. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  109. break
  110. }
  111. case "headers":
  112. prefix := field.Tag.Get("locationName")
  113. err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
  114. if err != nil {
  115. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  116. break
  117. }
  118. }
  119. }
  120. if r.Error != nil {
  121. return
  122. }
  123. }
  124. }
  125. func unmarshalStatusCode(v reflect.Value, statusCode int) {
  126. if !v.IsValid() {
  127. return
  128. }
  129. switch v.Interface().(type) {
  130. case *int64:
  131. s := int64(statusCode)
  132. v.Set(reflect.ValueOf(&s))
  133. }
  134. }
  135. func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
  136. switch r.Interface().(type) {
  137. case map[string]*string: // we only support string map value types
  138. out := map[string]*string{}
  139. for k, v := range headers {
  140. k = http.CanonicalHeaderKey(k)
  141. if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) {
  142. out[k[len(prefix):]] = &v[0]
  143. }
  144. }
  145. r.Set(reflect.ValueOf(out))
  146. }
  147. return nil
  148. }
  149. func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error {
  150. isJSONValue := tag.Get("type") == "jsonvalue"
  151. if isJSONValue {
  152. if len(header) == 0 {
  153. return nil
  154. }
  155. } else if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
  156. return nil
  157. }
  158. switch v.Interface().(type) {
  159. case *string:
  160. v.Set(reflect.ValueOf(&header))
  161. case []byte:
  162. b, err := base64.StdEncoding.DecodeString(header)
  163. if err != nil {
  164. return err
  165. }
  166. v.Set(reflect.ValueOf(&b))
  167. case *bool:
  168. b, err := strconv.ParseBool(header)
  169. if err != nil {
  170. return err
  171. }
  172. v.Set(reflect.ValueOf(&b))
  173. case *int64:
  174. i, err := strconv.ParseInt(header, 10, 64)
  175. if err != nil {
  176. return err
  177. }
  178. v.Set(reflect.ValueOf(&i))
  179. case *float64:
  180. f, err := strconv.ParseFloat(header, 64)
  181. if err != nil {
  182. return err
  183. }
  184. v.Set(reflect.ValueOf(&f))
  185. case *time.Time:
  186. t, err := time.Parse(RFC822, header)
  187. if err != nil {
  188. return err
  189. }
  190. v.Set(reflect.ValueOf(&t))
  191. case aws.JSONValue:
  192. b := []byte(header)
  193. var err error
  194. if tag.Get("location") == "header" {
  195. b, err = base64.StdEncoding.DecodeString(header)
  196. if err != nil {
  197. return err
  198. }
  199. }
  200. m := aws.JSONValue{}
  201. err = json.Unmarshal(b, &m)
  202. if err != nil {
  203. return err
  204. }
  205. v.Set(reflect.ValueOf(m))
  206. default:
  207. err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
  208. return err
  209. }
  210. return nil
  211. }