diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index 43e681e5b76b6bd74519d912493219ec387b45c2..fca583b1637c23cfbb7b05e9348b69b8155aacae 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -85,135 +85,86 @@ func TestUTF8(t *testing.T) { assertBodyContains(t, w, `Hello world, Καλημέρα κόσμε, コンニチハ`) } -func TestGet__Basic(t *testing.T) { - r, _ := http.NewRequest("GET", "/get", nil) - r.Host = "localhost" - r.Header.Set("User-Agent", "test") - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) - - var resp *getResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } - - if resp.Args.Encode() != "" { - t.Fatalf("expected empty args, got %s", resp.Args.Encode()) - } - if resp.Origin != "" { - t.Fatalf("expected empty origin, got %#v", resp.Origin) - } - if resp.URL != "http://localhost/get" { - t.Fatalf("unexpected url: %#v", resp.URL) - } - - var headerTests = []struct { - key string - expected string - }{ - {"Content-Type", ""}, - {"User-Agent", "test"}, - } - for _, test := range headerTests { - if resp.Headers.Get(test.key) != test.expected { - t.Fatalf("expected %s = %#v, got %#v", test.key, test.expected, resp.Headers.Get(test.key)) +func TestGet(t *testing.T) { + makeGetRequest := func(params *url.Values, headers *http.Header, expectedStatus int) (*getResponse, *httptest.ResponseRecorder) { + urlStr := "/get" + if params != nil { + urlStr = fmt.Sprintf("%s?%s", urlStr, params.Encode()) } - } -} - -func TestGet__WithParams(t *testing.T) { - params := url.Values{} - params.Set("foo", "foo") - params.Add("bar", "bar1") - params.Add("bar", "bar2") - - r, _ := http.NewRequest("GET", fmt.Sprintf("/get?%s", params.Encode()), nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + r, _ := http.NewRequest("GET", urlStr, nil) + r.Host = "localhost" + r.Header.Set("User-Agent", "test") + if headers != nil { + for k, vs := range *headers { + for _, v := range vs { + r.Header.Set(k, v) + } + } + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) - var resp *getResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } + assertStatusCode(t, w, expectedStatus) - if resp.Args.Encode() != params.Encode() { - t.Fatalf("args mismatch: %s != %s", resp.Args.Encode(), params.Encode()) + var resp *getResponse + if expectedStatus == http.StatusOK { + err := json.Unmarshal(w.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) + } + } + return resp, w } -} - -func TestGet__OnlyAllowsGets(t *testing.T) { - r, _ := http.NewRequest("POST", "/get", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusMethodNotAllowed) - assertContentType(t, w, "text/plain; charset=utf-8") -} - -func TestGet__CORSHeadersWithoutRequestOrigin(t *testing.T) { - r, _ := http.NewRequest("GET", "/get", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - - assertHeader(t, w, "Access-Control-Allow-Origin", "*") -} + t.Run("basic", func(t *testing.T) { + resp, _ := makeGetRequest(nil, nil, http.StatusOK) -func TestGet__CORSHeadersWithRequestOrigin(t *testing.T) { - r, _ := http.NewRequest("GET", "/get", nil) - r.Header.Set("Origin", "origin") - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) + if resp.Args.Encode() != "" { + t.Fatalf("expected empty args, got %s", resp.Args.Encode()) + } + if resp.Origin != "" { + t.Fatalf("expected empty origin, got %#v", resp.Origin) + } + if resp.URL != "http://localhost/get" { + t.Fatalf("unexpected url: %#v", resp.URL) + } - assertHeader(t, w, "Access-Control-Allow-Origin", "origin") -} + var headerTests = []struct { + key string + expected string + }{ + {"Content-Type", ""}, + {"User-Agent", "test"}, + } + for _, test := range headerTests { + if resp.Headers.Get(test.key) != test.expected { + t.Fatalf("expected %s = %#v, got %#v", test.key, test.expected, resp.Headers.Get(test.key)) + } + } + }) -func TestGet__CORSHeadersWithOptionsVerb(t *testing.T) { - r, _ := http.NewRequest("OPTIONS", "/get", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) + t.Run("with_query_params", func(t *testing.T) { + params := &url.Values{} + params.Set("foo", "foo") + params.Add("bar", "bar1") + params.Add("bar", "bar2") - var headerTests = []struct { - key string - expected string - }{ - {"Access-Control-Allow-Origin", "*"}, - {"Access-Control-Allow-Credentials", "true"}, - {"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"}, - {"Access-Control-Max-Age", "3600"}, - {"Access-Control-Allow-Headers", ""}, - } - for _, test := range headerTests { - assertHeader(t, w, test.key, test.expected) - } -} + resp, _ := makeGetRequest(params, nil, http.StatusOK) + if resp.Args.Encode() != params.Encode() { + t.Fatalf("args mismatch: %s != %s", resp.Args.Encode(), params.Encode()) + } + }) -func TestGet__CORSAllowHeaders(t *testing.T) { - r, _ := http.NewRequest("OPTIONS", "/get", nil) - r.Header.Set("Access-Control-Request-Headers", "X-Test-Header") - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) + t.Run("only_allows_gets", func(t *testing.T) { + r, _ := http.NewRequest("POST", "/get", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) - var headerTests = []struct { - key string - expected string - }{ - {"Access-Control-Allow-Headers", "X-Test-Header"}, - } - for _, test := range headerTests { - assertHeader(t, w, test.key, test.expected) - } -} + assertStatusCode(t, w, http.StatusMethodNotAllowed) + assertContentType(t, w, "text/plain; charset=utf-8") + }) -func TestGet__XForwardedProto(t *testing.T) { - var tests = []struct { + var protoTests = []struct { key string value string }{ @@ -221,23 +172,70 @@ func TestGet__XForwardedProto(t *testing.T) { {"X-Forwarded-Protocol", "https"}, {"X-Forwarded-Ssl", "on"}, } + for _, test := range protoTests { + t.Run(test.key, func(t *testing.T) { + headers := &http.Header{} + headers.Set(test.key, test.value) + resp, _ := makeGetRequest(nil, headers, http.StatusOK) + if !strings.HasPrefix(resp.URL, "https://") { + t.Fatalf("%s=%s should result in https URL", test.key, test.value) + } + }) + } +} - for _, test := range tests { +func TestCORS(t *testing.T) { + t.Run("CORS/no_request_origin", func(t *testing.T) { r, _ := http.NewRequest("GET", "/get", nil) - r.Header.Set(test.key, test.value) w := httptest.NewRecorder() handler.ServeHTTP(w, r) + assertHeader(t, w, "Access-Control-Allow-Origin", "*") + }) - var resp *getResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) + t.Run("CORS/with_request_origin", func(t *testing.T) { + r, _ := http.NewRequest("GET", "/get", nil) + r.Header.Set("Origin", "origin") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + assertHeader(t, w, "Access-Control-Allow-Origin", "origin") + }) + + t.Run("CORS/options_request", func(t *testing.T) { + r, _ := http.NewRequest("OPTIONS", "/get", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + var headerTests = []struct { + key string + expected string + }{ + {"Access-Control-Allow-Origin", "*"}, + {"Access-Control-Allow-Credentials", "true"}, + {"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"}, + {"Access-Control-Max-Age", "3600"}, + {"Access-Control-Allow-Headers", ""}, + } + for _, test := range headerTests { + assertHeader(t, w, test.key, test.expected) } + }) - if !strings.HasPrefix(resp.URL, "https://") { - t.Fatalf("%s=%s should result in https URL", test.key, test.value) + t.Run("CORS/allow_headers", func(t *testing.T) { + r, _ := http.NewRequest("OPTIONS", "/get", nil) + r.Header.Set("Access-Control-Request-Headers", "X-Test-Header") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + var headerTests = []struct { + key string + expected string + }{ + {"Access-Control-Allow-Headers", "X-Test-Header"}, } - } + for _, test := range headerTests { + assertHeader(t, w, test.key, test.expected) + } + }) } func TestIP(t *testing.T) {