unmarshal.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package rest
import (
	"bytes"
	"encoding/base64"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"reflect"
	"strconv"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
)
  16. // UnmarshalHandler is a named request handler for unmarshaling rest protocol requests
  17. var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal}
  18. // UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata
  19. var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta}
  20. // Unmarshal unmarshals the REST component of a response in a REST service.
  21. func Unmarshal(r *request.Request) {
  22. if r.DataFilled() {
  23. v := reflect.Indirect(reflect.ValueOf(r.Data))
  24. unmarshalBody(r, v)
  25. }
  26. }
  27. // UnmarshalMeta unmarshals the REST metadata of a response in a REST service
  28. func UnmarshalMeta(r *request.Request) {
  29. r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
  30. if r.RequestID == "" {
  31. // Alternative version of request id in the header
  32. r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
  33. }
  34. if r.DataFilled() {
  35. v := reflect.Indirect(reflect.ValueOf(r.Data))
  36. unmarshalLocationElements(r, v)
  37. }
  38. }
  39. func unmarshalBody(r *request.Request, v reflect.Value) {
  40. if field, ok := v.Type().FieldByName("_"); ok {
  41. if payloadName := field.Tag.Get("payload"); payloadName != "" {
  42. pfield, _ := v.Type().FieldByName(payloadName)
  43. if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
  44. payload := v.FieldByName(payloadName)
  45. if payload.IsValid() {
  46. switch payload.Interface().(type) {
  47. case []byte:
  48. defer r.HTTPResponse.Body.Close()
  49. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  50. if err != nil {
  51. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  52. } else {
  53. payload.Set(reflect.ValueOf(b))
  54. }
  55. case *string:
  56. defer r.HTTPResponse.Body.Close()
  57. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  58. if err != nil {
  59. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  60. } else {
  61. str := string(b)
  62. payload.Set(reflect.ValueOf(&str))
  63. }
  64. default:
  65. switch payload.Type().String() {
  66. case "io.ReadSeeker":
  67. payload.Set(reflect.ValueOf(aws.ReadSeekCloser(r.HTTPResponse.Body)))
  68. case "aws.ReadSeekCloser", "io.ReadCloser":
  69. payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
  70. default:
  71. io.Copy(ioutil.Discard, r.HTTPResponse.Body)
  72. defer r.HTTPResponse.Body.Close()
  73. r.Error = awserr.New("SerializationError",
  74. "failed to decode REST response",
  75. fmt.Errorf("unknown payload type %s", payload.Type()))
  76. }
  77. }
  78. }
  79. }
  80. }
  81. }
  82. }
  83. func unmarshalLocationElements(r *request.Request, v reflect.Value) {
  84. for i := 0; i < v.NumField(); i++ {
  85. m, field := v.Field(i), v.Type().Field(i)
  86. if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
  87. continue
  88. }
  89. if m.IsValid() {
  90. name := field.Tag.Get("locationName")
  91. if name == "" {
  92. name = field.Name
  93. }
  94. switch field.Tag.Get("location") {
  95. case "statusCode":
  96. unmarshalStatusCode(m, r.HTTPResponse.StatusCode)
  97. case "header":
  98. err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name))
  99. if err != nil {
  100. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  101. break
  102. }
  103. case "headers":
  104. prefix := field.Tag.Get("locationName")
  105. err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
  106. if err != nil {
  107. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  108. break
  109. }
  110. }
  111. }
  112. if r.Error != nil {
  113. return
  114. }
  115. }
  116. }
  117. func unmarshalStatusCode(v reflect.Value, statusCode int) {
  118. if !v.IsValid() {
  119. return
  120. }
  121. switch v.Interface().(type) {
  122. case *int64:
  123. s := int64(statusCode)
  124. v.Set(reflect.ValueOf(&s))
  125. }
  126. }
  127. func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
  128. switch r.Interface().(type) {
  129. case map[string]*string: // we only support string map value types
  130. out := map[string]*string{}
  131. for k, v := range headers {
  132. k = http.CanonicalHeaderKey(k)
  133. if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) {
  134. out[k[len(prefix):]] = &v[0]
  135. }
  136. }
  137. r.Set(reflect.ValueOf(out))
  138. }
  139. return nil
  140. }
  141. func unmarshalHeader(v reflect.Value, header string) error {
  142. if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
  143. return nil
  144. }
  145. switch v.Interface().(type) {
  146. case *string:
  147. v.Set(reflect.ValueOf(&header))
  148. case []byte:
  149. b, err := base64.StdEncoding.DecodeString(header)
  150. if err != nil {
  151. return err
  152. }
  153. v.Set(reflect.ValueOf(&b))
  154. case *bool:
  155. b, err := strconv.ParseBool(header)
  156. if err != nil {
  157. return err
  158. }
  159. v.Set(reflect.ValueOf(&b))
  160. case *int64:
  161. i, err := strconv.ParseInt(header, 10, 64)
  162. if err != nil {
  163. return err
  164. }
  165. v.Set(reflect.ValueOf(&i))
  166. case *float64:
  167. f, err := strconv.ParseFloat(header, 64)
  168. if err != nil {
  169. return err
  170. }
  171. v.Set(reflect.ValueOf(&f))
  172. case *time.Time:
  173. t, err := time.Parse(RFC822, header)
  174. if err != nil {
  175. return err
  176. }
  177. v.Set(reflect.ValueOf(&t))
  178. default:
  179. err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
  180. return err
  181. }
  182. return nil
  183. }