http_test.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. package context
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "net/http/httputil"
  6. "net/url"
  7. "reflect"
  8. "testing"
  9. "time"
  10. )
  11. func TestWithRequest(t *testing.T) {
  12. var req http.Request
  13. start := time.Now()
  14. req.Method = "GET"
  15. req.Host = "example.com"
  16. req.RequestURI = "/test-test"
  17. req.Header = make(http.Header)
  18. req.Header.Set("Referer", "foo.com/referer")
  19. req.Header.Set("User-Agent", "test/0.1")
  20. ctx := WithRequest(Background(), &req)
  21. for _, testcase := range []struct {
  22. key string
  23. expected interface{}
  24. }{
  25. {
  26. key: "http.request",
  27. expected: &req,
  28. },
  29. {
  30. key: "http.request.id",
  31. },
  32. {
  33. key: "http.request.method",
  34. expected: req.Method,
  35. },
  36. {
  37. key: "http.request.host",
  38. expected: req.Host,
  39. },
  40. {
  41. key: "http.request.uri",
  42. expected: req.RequestURI,
  43. },
  44. {
  45. key: "http.request.referer",
  46. expected: req.Referer(),
  47. },
  48. {
  49. key: "http.request.useragent",
  50. expected: req.UserAgent(),
  51. },
  52. {
  53. key: "http.request.remoteaddr",
  54. expected: req.RemoteAddr,
  55. },
  56. {
  57. key: "http.request.startedat",
  58. },
  59. } {
  60. v := ctx.Value(testcase.key)
  61. if v == nil {
  62. t.Fatalf("value not found for %q", testcase.key)
  63. }
  64. if testcase.expected != nil && v != testcase.expected {
  65. t.Fatalf("%s: %v != %v", testcase.key, v, testcase.expected)
  66. }
  67. // Key specific checks!
  68. switch testcase.key {
  69. case "http.request.id":
  70. if _, ok := v.(string); !ok {
  71. t.Fatalf("request id not a string: %v", v)
  72. }
  73. case "http.request.startedat":
  74. vt, ok := v.(time.Time)
  75. if !ok {
  76. t.Fatalf("value not a time: %v", v)
  77. }
  78. now := time.Now()
  79. if vt.After(now) {
  80. t.Fatalf("time generated too late: %v > %v", vt, now)
  81. }
  82. if vt.Before(start) {
  83. t.Fatalf("time generated too early: %v < %v", vt, start)
  84. }
  85. }
  86. }
  87. }
  88. type testResponseWriter struct {
  89. flushed bool
  90. status int
  91. written int64
  92. header http.Header
  93. }
  94. func (trw *testResponseWriter) Header() http.Header {
  95. if trw.header == nil {
  96. trw.header = make(http.Header)
  97. }
  98. return trw.header
  99. }
  100. func (trw *testResponseWriter) Write(p []byte) (n int, err error) {
  101. if trw.status == 0 {
  102. trw.status = http.StatusOK
  103. }
  104. n = len(p)
  105. trw.written += int64(n)
  106. return
  107. }
  108. func (trw *testResponseWriter) WriteHeader(status int) {
  109. trw.status = status
  110. }
  111. func (trw *testResponseWriter) Flush() {
  112. trw.flushed = true
  113. }
  114. func TestWithResponseWriter(t *testing.T) {
  115. trw := testResponseWriter{}
  116. ctx, rw := WithResponseWriter(Background(), &trw)
  117. if ctx.Value("http.response") != rw {
  118. t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), rw)
  119. }
  120. grw, err := GetResponseWriter(ctx)
  121. if err != nil {
  122. t.Fatalf("error getting response writer: %v", err)
  123. }
  124. if grw != rw {
  125. t.Fatalf("unexpected response writer returned: %#v != %#v", grw, rw)
  126. }
  127. if ctx.Value("http.response.status") != 0 {
  128. t.Fatalf("response status should always be a number and should be zero here: %v != 0", ctx.Value("http.response.status"))
  129. }
  130. if n, err := rw.Write(make([]byte, 1024)); err != nil {
  131. t.Fatalf("unexpected error writing: %v", err)
  132. } else if n != 1024 {
  133. t.Fatalf("unexpected number of bytes written: %v != %v", n, 1024)
  134. }
  135. if ctx.Value("http.response.status") != http.StatusOK {
  136. t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusOK)
  137. }
  138. if ctx.Value("http.response.written") != int64(1024) {
  139. t.Fatalf("unexpected number reported bytes written: %v != %v", ctx.Value("http.response.written"), 1024)
  140. }
  141. // Make sure flush propagates
  142. rw.(http.Flusher).Flush()
  143. if !trw.flushed {
  144. t.Fatalf("response writer not flushed")
  145. }
  146. // Write another status and make sure context is correct. This normally
  147. // wouldn't work except for in this contrived testcase.
  148. rw.WriteHeader(http.StatusBadRequest)
  149. if ctx.Value("http.response.status") != http.StatusBadRequest {
  150. t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusBadRequest)
  151. }
  152. }
  153. func TestWithVars(t *testing.T) {
  154. var req http.Request
  155. vars := map[string]string{
  156. "foo": "asdf",
  157. "bar": "qwer",
  158. }
  159. getVarsFromRequest = func(r *http.Request) map[string]string {
  160. if r != &req {
  161. t.Fatalf("unexpected request: %v != %v", r, req)
  162. }
  163. return vars
  164. }
  165. ctx := WithVars(Background(), &req)
  166. for _, testcase := range []struct {
  167. key string
  168. expected interface{}
  169. }{
  170. {
  171. key: "vars",
  172. expected: vars,
  173. },
  174. {
  175. key: "vars.foo",
  176. expected: "asdf",
  177. },
  178. {
  179. key: "vars.bar",
  180. expected: "qwer",
  181. },
  182. } {
  183. v := ctx.Value(testcase.key)
  184. if !reflect.DeepEqual(v, testcase.expected) {
  185. t.Fatalf("%q: %v != %v", testcase.key, v, testcase.expected)
  186. }
  187. }
  188. }
  189. // SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
  190. // RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
  191. // at the transport layer to 127.0.0.1:<port> . However, as the X-Forwarded-For header
  192. // just contains the IP address, it is different enough for testing.
  193. func TestRemoteAddr(t *testing.T) {
  194. var expectedRemote string
  195. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  196. defer r.Body.Close()
  197. if r.RemoteAddr == expectedRemote {
  198. t.Errorf("Unexpected matching remote addresses")
  199. }
  200. actualRemote := RemoteAddr(r)
  201. if expectedRemote != actualRemote {
  202. t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
  203. }
  204. w.WriteHeader(200)
  205. }))
  206. defer backend.Close()
  207. backendURL, err := url.Parse(backend.URL)
  208. if err != nil {
  209. t.Fatal(err)
  210. }
  211. proxy := httputil.NewSingleHostReverseProxy(backendURL)
  212. frontend := httptest.NewServer(proxy)
  213. defer frontend.Close()
  214. // X-Forwarded-For set by proxy
  215. expectedRemote = "127.0.0.1"
  216. proxyReq, err := http.NewRequest("GET", frontend.URL, nil)
  217. if err != nil {
  218. t.Fatal(err)
  219. }
  220. _, err = http.DefaultClient.Do(proxyReq)
  221. if err != nil {
  222. t.Fatal(err)
  223. }
  224. // RemoteAddr in X-Real-Ip
  225. getReq, err := http.NewRequest("GET", backend.URL, nil)
  226. if err != nil {
  227. t.Fatal(err)
  228. }
  229. expectedRemote = "1.2.3.4"
  230. getReq.Header["X-Real-ip"] = []string{expectedRemote}
  231. _, err = http.DefaultClient.Do(getReq)
  232. if err != nil {
  233. t.Fatal(err)
  234. }
  235. // Valid X-Real-Ip and invalid X-Forwarded-For
  236. getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
  237. _, err = http.DefaultClient.Do(getReq)
  238. if err != nil {
  239. t.Fatal(err)
  240. }
  241. }