unmarshal.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package rest
  2. import (
  3. "encoding/base64"
  4. "fmt"
  5. "io/ioutil"
  6. "net/http"
  7. "reflect"
  8. "strconv"
  9. "strings"
  10. "time"
  11. "github.com/aws/aws-sdk-go/aws"
  12. "github.com/aws/aws-sdk-go/aws/awserr"
  13. "github.com/aws/aws-sdk-go/aws/request"
  14. )
  15. // Unmarshal unmarshals the REST component of a response in a REST service.
  16. func Unmarshal(r *request.Request) {
  17. if r.DataFilled() {
  18. v := reflect.Indirect(reflect.ValueOf(r.Data))
  19. unmarshalBody(r, v)
  20. }
  21. }
  22. // UnmarshalMeta unmarshals the REST metadata of a response in a REST service
  23. func UnmarshalMeta(r *request.Request) {
  24. r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
  25. if r.DataFilled() {
  26. v := reflect.Indirect(reflect.ValueOf(r.Data))
  27. unmarshalLocationElements(r, v)
  28. }
  29. }
  30. func unmarshalBody(r *request.Request, v reflect.Value) {
  31. if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
  32. if payloadName := field.Tag.Get("payload"); payloadName != "" {
  33. pfield, _ := v.Type().FieldByName(payloadName)
  34. if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
  35. payload := v.FieldByName(payloadName)
  36. if payload.IsValid() {
  37. switch payload.Interface().(type) {
  38. case []byte:
  39. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  40. if err != nil {
  41. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  42. } else {
  43. payload.Set(reflect.ValueOf(b))
  44. }
  45. case *string:
  46. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  47. if err != nil {
  48. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  49. } else {
  50. str := string(b)
  51. payload.Set(reflect.ValueOf(&str))
  52. }
  53. default:
  54. switch payload.Type().String() {
  55. case "io.ReadSeeker":
  56. payload.Set(reflect.ValueOf(aws.ReadSeekCloser(r.HTTPResponse.Body)))
  57. case "aws.ReadSeekCloser", "io.ReadCloser":
  58. payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
  59. default:
  60. r.Error = awserr.New("SerializationError",
  61. "failed to decode REST response",
  62. fmt.Errorf("unknown payload type %s", payload.Type()))
  63. }
  64. }
  65. }
  66. }
  67. }
  68. }
  69. }
  70. func unmarshalLocationElements(r *request.Request, v reflect.Value) {
  71. for i := 0; i < v.NumField(); i++ {
  72. m, field := v.Field(i), v.Type().Field(i)
  73. if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
  74. continue
  75. }
  76. if m.IsValid() {
  77. name := field.Tag.Get("locationName")
  78. if name == "" {
  79. name = field.Name
  80. }
  81. switch field.Tag.Get("location") {
  82. case "statusCode":
  83. unmarshalStatusCode(m, r.HTTPResponse.StatusCode)
  84. case "header":
  85. err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name))
  86. if err != nil {
  87. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  88. break
  89. }
  90. case "headers":
  91. prefix := field.Tag.Get("locationName")
  92. err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
  93. if err != nil {
  94. r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
  95. break
  96. }
  97. }
  98. }
  99. if r.Error != nil {
  100. return
  101. }
  102. }
  103. }
  // unmarshalStatusCode stores statusCode into v when v is a valid *int64
  // field. Fields of any other type, and invalid values, are left unchanged.
  func unmarshalStatusCode(v reflect.Value, statusCode int) {
  	if !v.IsValid() {
  		return
  	}
  	if _, isInt64Ptr := v.Interface().(*int64); !isInt64Ptr {
  		return
  	}
  	code := int64(statusCode)
  	v.Set(reflect.ValueOf(&code))
  }
  // unmarshalHeaderMap collects every header whose canonical name starts with
  // prefix (compared case-insensitively) into a new map[string]*string. Each
  // map key is the canonical header name with the prefix removed, and each
  // value is the header's first value. The map is then assigned to r.
  //
  // Only map[string]*string targets are supported; any other type of r is
  // left unchanged. Headers with no values are skipped. Each map value points
  // at a private copy, so it does not alias the response's header storage.
  // The returned error is currently always nil.
  func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
  	switch r.Interface().(type) {
  	case map[string]*string: // we only support string map value types
  		out := map[string]*string{}
  		for k, vals := range headers {
  			if len(vals) == 0 {
  				// An empty value slice would make vals[0] panic.
  				continue
  			}
  			k = http.CanonicalHeaderKey(k)
  			// Compare the prefix in place rather than lowercasing both strings:
  			// it avoids per-iteration allocations and never slices k at a
  			// length that lowercasing could have changed.
  			if len(k) < len(prefix) || !strings.EqualFold(k[:len(prefix)], prefix) {
  				continue
  			}
  			val := vals[0] // per-iteration copy; safe to take its address
  			out[k[len(prefix):]] = &val
  		}
  		r.Set(reflect.ValueOf(out))
  	}
  	return nil
  }
  128. func unmarshalHeader(v reflect.Value, header string) error {
  129. if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
  130. return nil
  131. }
  132. switch v.Interface().(type) {
  133. case *string:
  134. v.Set(reflect.ValueOf(&header))
  135. case []byte:
  136. b, err := base64.StdEncoding.DecodeString(header)
  137. if err != nil {
  138. return err
  139. }
  140. v.Set(reflect.ValueOf(&b))
  141. case *bool:
  142. b, err := strconv.ParseBool(header)
  143. if err != nil {
  144. return err
  145. }
  146. v.Set(reflect.ValueOf(&b))
  147. case *int64:
  148. i, err := strconv.ParseInt(header, 10, 64)
  149. if err != nil {
  150. return err
  151. }
  152. v.Set(reflect.ValueOf(&i))
  153. case *float64:
  154. f, err := strconv.ParseFloat(header, 64)
  155. if err != nil {
  156. return err
  157. }
  158. v.Set(reflect.ValueOf(&f))
  159. case *time.Time:
  160. t, err := time.Parse(RFC822, header)
  161. if err != nil {
  162. return err
  163. }
  164. v.Set(reflect.ValueOf(&t))
  165. default:
  166. err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
  167. return err
  168. }
  169. return nil
  170. }