handler.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package testutil
import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"sort"
	"strings"
	"sync"
)
  12. // RequestResponseMap is an ordered mapping from Requests to Responses
  13. type RequestResponseMap []RequestResponseMapping
  14. // RequestResponseMapping defines a Response to be sent in response to a given
  15. // Request
  16. type RequestResponseMapping struct {
  17. Request Request
  18. Response Response
  19. }
  20. // Request is a simplified http.Request object
  21. type Request struct {
  22. // Method is the http method of the request, for example GET
  23. Method string
  24. // Route is the http route of this request
  25. Route string
  26. // QueryParams are the query parameters of this request
  27. QueryParams map[string][]string
  28. // Body is the byte contents of the http request
  29. Body []byte
  30. // Headers are the header for this request
  31. Headers http.Header
  32. }
  33. func (r Request) String() string {
  34. queryString := ""
  35. if len(r.QueryParams) > 0 {
  36. keys := make([]string, 0, len(r.QueryParams))
  37. queryParts := make([]string, 0, len(r.QueryParams))
  38. for k := range r.QueryParams {
  39. keys = append(keys, k)
  40. }
  41. sort.Strings(keys)
  42. for _, k := range keys {
  43. for _, val := range r.QueryParams[k] {
  44. queryParts = append(queryParts, fmt.Sprintf("%s=%s", k, url.QueryEscape(val)))
  45. }
  46. }
  47. queryString = "?" + strings.Join(queryParts, "&")
  48. }
  49. var headers []string
  50. if len(r.Headers) > 0 {
  51. var headerKeys []string
  52. for k := range r.Headers {
  53. headerKeys = append(headerKeys, k)
  54. }
  55. sort.Strings(headerKeys)
  56. for _, k := range headerKeys {
  57. for _, val := range r.Headers[k] {
  58. headers = append(headers, fmt.Sprintf("%s:%s", k, val))
  59. }
  60. }
  61. }
  62. return fmt.Sprintf("%s %s%s\n%s\n%s", r.Method, r.Route, queryString, headers, r.Body)
  63. }
  64. // Response is a simplified http.Response object
  65. type Response struct {
  66. // Statuscode is the http status code of the Response
  67. StatusCode int
  68. // Headers are the http headers of this Response
  69. Headers http.Header
  70. // Body is the response body
  71. Body []byte
  72. }
  73. // testHandler is an http.Handler with a defined mapping from Request to an
  74. // ordered list of Response objects
  75. type testHandler struct {
  76. responseMap map[string][]Response
  77. }
  78. // NewHandler returns a new test handler that responds to defined requests
  79. // with specified responses
  80. // Each time a Request is received, the next Response is returned in the
  81. // mapping, until no Responses are defined, at which point a 404 is sent back
  82. func NewHandler(requestResponseMap RequestResponseMap) http.Handler {
  83. responseMap := make(map[string][]Response)
  84. for _, mapping := range requestResponseMap {
  85. responses, ok := responseMap[mapping.Request.String()]
  86. if ok {
  87. responseMap[mapping.Request.String()] = append(responses, mapping.Response)
  88. } else {
  89. responseMap[mapping.Request.String()] = []Response{mapping.Response}
  90. }
  91. }
  92. return &testHandler{responseMap: responseMap}
  93. }
  94. func (app *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  95. defer r.Body.Close()
  96. requestBody, _ := ioutil.ReadAll(r.Body)
  97. request := Request{
  98. Method: r.Method,
  99. Route: r.URL.Path,
  100. QueryParams: r.URL.Query(),
  101. Body: requestBody,
  102. Headers: make(map[string][]string),
  103. }
  104. // Add headers of interest here
  105. for k, v := range r.Header {
  106. if k == "If-None-Match" {
  107. request.Headers[k] = v
  108. }
  109. }
  110. responses, ok := app.responseMap[request.String()]
  111. if !ok || len(responses) == 0 {
  112. http.NotFound(w, r)
  113. return
  114. }
  115. response := responses[0]
  116. app.responseMap[request.String()] = responses[1:]
  117. responseHeader := w.Header()
  118. for k, v := range response.Headers {
  119. responseHeader[k] = v
  120. }
  121. w.WriteHeader(response.StatusCode)
  122. io.Copy(w, bytes.NewReader(response.Body))
  123. }