diff --git a/go.mod b/go.mod index 91b9af29f6e1bfb7674972084e618c574f5eee08..4327b57f0abaac259345fb34ed846082cbeb7d8f 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/mccutchen/go-httpbin/v2 -go 1.16 +go 1.18 diff --git a/httpbin/handlers.go b/httpbin/handlers.go index a15a8a6b9ef9bdb7bf63ae39cbe6030dbc8d45f6..9dc7888cd93968f09e0238614128ea592ab7f0cb 100644 --- a/httpbin/handlers.go +++ b/httpbin/handlers.go @@ -17,6 +17,8 @@ import ( "github.com/mccutchen/go-httpbin/v2/httpbin/digest" ) +var nilValues = url.Values{} + func notImplementedHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Not implemented", http.StatusNotImplemented) } @@ -71,6 +73,8 @@ func (h *HTTPBin) Anything(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) { resp := &bodyResponse{ Args: r.URL.Query(), + Files: nilValues, + Form: nilValues, Headers: getRequestHeaders(r), Method: r.Method, Origin: getClientIP(r), @@ -628,25 +632,38 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) { return } - pause := duration / time.Duration(numBytes) - flusher := w.(http.Flusher) + pause := duration + if numBytes > 1 { + // compensate for lack of pause after final write (i.e. if we're + // writing 10 bytes, we will only pause 9 times) + pause = duration / time.Duration(numBytes-1) + } w.Header().Set("Content-Type", "application/octet-stream") w.Header().Set("Content-Length", fmt.Sprintf("%d", numBytes)) w.WriteHeader(code) + + flusher := w.(http.Flusher) flusher.Flush() + // wait for initial delay before writing response body select { case <-r.Context().Done(): return case <-time.After(delay): } + // write response body byte-by-byte, pausing between each write b := []byte{'*'} for i := int64(0); i < numBytes; i++ { w.Write(b) flusher.Flush() + // don't pause after last byte + if i == numBytes-1 { + break + } + select { case <-r.Context().Done(): return @@ -829,9 +846,10 @@ func handleBytes(w http.ResponseWriter, r *http.Request, streaming bool) { } }() } else { + // if not streaming, we will write the whole response at once chunkSize = numBytes + w.Header().Set("Content-Length", strconv.Itoa(numBytes)) write = func(chunk []byte) { - w.Header().Set("Content-Length", strconv.Itoa(len(chunk))) w.Write(chunk) } } @@ -878,6 +896,7 @@ func (h *HTTPBin) Links(w http.ResponseWriter, r *http.Request) { offset, err := strconv.Atoi(parts[3]) if err != nil { http.Error(w, "Invalid offset", http.StatusBadRequest) + return } doLinksPage(w, r, n, offset) return @@ -937,6 +956,7 @@ func doImage(w http.ResponseWriter, kind string) { img, err := staticAsset("image." + kind) if err != nil { http.Error(w, "Not Found", http.StatusNotFound) + return } contentType := "image/" + kind if kind == "svg" { diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index e687ef62602313cce3c173b2e669c3b92cccc5e6..ebdce3e811efffae8bcce2d774bc1a420ba5916a 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -12,11 +12,13 @@ import ( "fmt" "io" "log" - "math/rand" "mime/multipart" + "net" "net/http" "net/http/httptest" + "net/http/httputil" "net/url" + "os" "reflect" "regexp" "runtime" @@ -24,140 +26,117 @@ import ( "strings" "testing" "time" + + "github.com/mccutchen/go-httpbin/v2/internal/testing/assert" + "github.com/mccutchen/go-httpbin/v2/internal/testing/must" ) const ( - maxBodySize int64 = 1024 - maxDuration time.Duration = 1 * time.Second - alphanumLetters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + maxBodySize int64 = 1024 + maxDuration time.Duration = 1 * time.Second ) -var testDefaultParams = DefaultParams{ - DripDelay: 0, - DripDuration: 100 * time.Millisecond, - DripNumBytes: 10, -} - -var app = New( - WithDefaultParams(testDefaultParams), - WithMaxBodySize(maxBodySize), - WithMaxDuration(maxDuration), - WithObserver(StdLogObserver(log.New(io.Discard, "", 0))), +// "Global" test app, server, & client to be reused across test cases. +// Initialized in TestMain. +var ( + app *HTTPBin + srv *httptest.Server + client *http.Client ) -func assertStatusCode(t *testing.T, w *httptest.ResponseRecorder, code int) { - t.Helper() - if w.Code != code { - t.Fatalf("expected status code %d, got %d", code, w.Code) - } -} - -// assertHeader asserts that a header key has a specific value in a -// response-like object. x must be *httptest.ResponseRecorder or *http.Response -func assertHeader(t *testing.T, x interface{}, key, want string) { - t.Helper() - - var got string - switch r := x.(type) { - case *httptest.ResponseRecorder: - got = r.Header().Get(key) - case *http.Response: - got = r.Header.Get(key) - default: - t.Fatalf("expected *httptest.ResponseRecorder or *http.Response, got %t", x) - } - if want != got { - t.Fatalf("expected header %s=%#v, got %#v", key, want, got) - } -} - -func assertContentType(t *testing.T, w *httptest.ResponseRecorder, contentType string) { - t.Helper() - assertHeader(t, w, "Content-Type", contentType) -} - -func assertBodyContains(t *testing.T, w *httptest.ResponseRecorder, needle string) { - t.Helper() - if !strings.Contains(w.Body.String(), needle) { - t.Fatalf("expected string %q in body %q", needle, w.Body.String()) - } -} - -func assertBodyEquals(t *testing.T, w *httptest.ResponseRecorder, want string) { - t.Helper() - have := w.Body.String() - if want != have { - t.Fatalf("expected body = %q, got %q", want, have) - } -} - -func randStringBytes(n int) string { - rand.New(rand.NewSource(time.Now().UnixNano())) - b := make([]byte, n) - for i := range b { - b[i] = alphanumLetters[rand.Intn(len(alphanumLetters))] - } - return string(b) +func TestMain(m *testing.M) { + app = New( + WithAllowedRedirectDomains([]string{ + "httpbingo.org", + "example.org", + "www.example.com", + }), + WithDefaultParams(DefaultParams{ + DripDelay: 0, + DripDuration: 100 * time.Millisecond, + DripNumBytes: 10, + }), + WithMaxBodySize(maxBodySize), + WithMaxDuration(maxDuration), + WithObserver(StdLogObserver(log.New(io.Discard, "", 0))), + ) + srv, client = newTestServer(app) + defer srv.Close() + os.Exit(m.Run()) } func TestIndex(t *testing.T) { - t.Parallel() - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + t.Run("ok", func(t *testing.T) { + t.Parallel() - assertContentType(t, w, htmlContentType) - assertHeader(t, w, "Content-Security-Policy", "default-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' camo.githubusercontent.com") - assertBodyContains(t, w, "go-httpbin") -} + req := newTestRequest(t, "GET", "/") + resp := must.DoReq(t, client, req) -func TestIndex__NotFound(t *testing.T) { - t.Parallel() - r, _ := http.NewRequest("GET", "/foo", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusNotFound) - assertBodyContains(t, w, "/foo") + assert.ContentType(t, resp, htmlContentType) + assert.Header(t, resp, "Content-Security-Policy", "default-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' camo.githubusercontent.com") + assert.BodyContains(t, resp, "go-httpbin") + }) + + t.Run("not found", func(t *testing.T) { + t.Parallel() + req := newTestRequest(t, "GET", "/foo") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusNotFound) + assert.BodyContains(t, resp, "/foo") + }) } func TestFormsPost(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/forms/post", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertContentType(t, w, htmlContentType) - assertBodyContains(t, w, `<form method="post" action="/post">`) + req := newTestRequest(t, "GET", "/forms/post") + resp := must.DoReq(t, client, req) + + assert.ContentType(t, resp, htmlContentType) + assert.BodyContains(t, resp, `<form method="post" action="/post">`) } func TestUTF8(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/encoding/utf8", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertContentType(t, w, htmlContentType) - assertBodyContains(t, w, `Hello world, Καλημέρα κόσμε, コンニチハ`) + req := newTestRequest(t, "GET", "/encoding/utf8") + resp := must.DoReq(t, client, req) + + assert.ContentType(t, resp, htmlContentType) + assert.BodyContains(t, resp, `Hello world, Καλημέρα κόσμε, コンニチハ`) } func TestGet(t *testing.T) { - t.Parallel() + doGetRequest := func(t *testing.T, path string, params url.Values, headers *http.Header) noBodyResponse { + t.Helper() + + if params != nil { + path = fmt.Sprintf("%s?%s", path, params.Encode()) + } + req := newTestRequest(t, "GET", path) + req.Header.Set("User-Agent", "test") + if headers != nil { + for k, vs := range *headers { + for _, v := range vs { + req.Header.Set(k, v) + } + } + } + + resp := must.DoReq(t, client, req) + return mustParseResponse[noBodyResponse](t, resp) + } t.Run("basic", func(t *testing.T) { t.Parallel() - resp, _ := testRequestWithoutBody(t, "/get", nil, nil, http.StatusOK) - if resp.Args.Encode() != "" { - t.Fatalf("expected empty args, got %s", resp.Args.Encode()) - } - if resp.Method != "GET" { - t.Fatalf("expected method to be GET, got %s", resp.Method) - } - if resp.Origin != "" { - t.Fatalf("expected empty origin, got %#v", resp.Origin) - } - if resp.URL != "http://localhost/get" { - t.Fatalf("unexpected url: %#v", resp.URL) + result := doGetRequest(t, "/get", nil, nil) + assert.Equal(t, result.Method, "GET", "method mismatch") + assert.Equal(t, result.Args.Encode(), "", "expected empty args") + assert.Equal(t, result.URL, srv.URL+"/get", "url mismatch") + + if !strings.HasPrefix(result.Origin, "127.0.0.1") { + t.Fatalf("expected 127.0.0.1 origin, got %q", result.Origin) } wantHeaders := map[string]string{ @@ -165,36 +144,31 @@ func TestGet(t *testing.T) { "User-Agent": "test", } for key, val := range wantHeaders { - if resp.Headers.Get(key) != val { - t.Fatalf("expected %s = %#v, got %#v", key, val, resp.Headers.Get(key)) - } + assert.Equal(t, result.Headers.Get(key), val, "header mismatch for key %q", key) } }) t.Run("with_query_params", func(t *testing.T) { t.Parallel() - params := &url.Values{} + + params := url.Values{} params.Set("foo", "foo") params.Add("bar", "bar1") params.Add("bar", "bar2") - resp, _ := testRequestWithoutBody(t, "/get", params, nil, http.StatusOK) - if resp.Args.Encode() != params.Encode() { - t.Fatalf("args mismatch: %s != %s", resp.Args.Encode(), params.Encode()) - } - if resp.Method != "GET" { - t.Fatalf("expected method to be GET, got %s", resp.Method) - } + result := doGetRequest(t, "/get", params, nil) + assert.Equal(t, result.Args.Encode(), params.Encode(), "args mismatch") + assert.Equal(t, result.Method, "GET", "method mismatch") }) t.Run("only_allows_gets", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("POST", "/get", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusMethodNotAllowed) - assertContentType(t, w, "text/plain; charset=utf-8") + req := newTestRequest(t, "POST", "/get") + resp := must.DoReq(t, client, req) + + assert.StatusCode(t, resp, http.StatusMethodNotAllowed) + assert.ContentType(t, resp, "text/plain; charset=utf-8") }) protoTests := []struct { @@ -211,47 +185,15 @@ func TestGet(t *testing.T) { t.Parallel() headers := &http.Header{} headers.Set(test.key, test.value) - resp, _ := testRequestWithoutBody(t, "/get", nil, headers, http.StatusOK) - if !strings.HasPrefix(resp.URL, "https://") { + result := doGetRequest(t, "/get", nil, headers) + if !strings.HasPrefix(result.URL, "https://") { t.Fatalf("%s=%s should result in https URL", test.key, test.value) } }) } } -func testRequestWithoutBody(t *testing.T, path string, params *url.Values, headers *http.Header, expectedStatus int) (*noBodyResponse, *httptest.ResponseRecorder) { - t.Helper() - - urlStr := path - if params != nil { - urlStr = fmt.Sprintf("%s?%s", urlStr, params.Encode()) - } - r, _ := http.NewRequest("GET", urlStr, nil) - r.Host = "localhost" - r.Header.Set("User-Agent", "test") - if headers != nil { - for k, vs := range *headers { - for _, v := range vs { - r.Header.Set(k, v) - } - } - } - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, expectedStatus) - - var resp *noBodyResponse - if expectedStatus == http.StatusOK { - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } - } - return resp, w -} - func TestHead(t *testing.T) { - t.Parallel() testCases := []struct { verb string path string @@ -267,53 +209,44 @@ func TestHead(t *testing.T) { tc := tc t.Run(fmt.Sprintf("%s %s", tc.verb, tc.path), func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest(tc.verb, tc.path, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, tc.wantCode) + req := newTestRequest(t, tc.verb, tc.path) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, tc.wantCode) // we only do further validation when we get an OK response if tc.wantCode != http.StatusOK { return } - assertStatusCode(t, w, http.StatusOK) - assertBodyEquals(t, w, "") - - if contentLength := w.Header().Get("Content-Length"); contentLength != "" { - t.Fatalf("did not expect Content-Length in response to HEAD request") - } + assert.StatusCode(t, resp, http.StatusOK) + assert.BodyEquals(t, resp, "") + assert.Header(t, resp, "Content-Length", "") // content-length should be empty }) } } func TestCORS(t *testing.T) { - t.Parallel() t.Run("CORS/no_request_origin", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/get", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertHeader(t, w, "Access-Control-Allow-Origin", "*") + req := newTestRequest(t, "GET", "/get") + resp := must.DoReq(t, client, req) + assert.Header(t, resp, "Access-Control-Allow-Origin", "*") }) t.Run("CORS/with_request_origin", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/get", nil) - r.Header.Set("Origin", "origin") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertHeader(t, w, "Access-Control-Allow-Origin", "origin") + req := newTestRequest(t, "GET", "/get") + req.Header.Set("Origin", "origin") + resp := must.DoReq(t, client, req) + assert.Header(t, resp, "Access-Control-Allow-Origin", "origin") }) t.Run("CORS/options_request", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("OPTIONS", "/get", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, 200) + req := newTestRequest(t, "OPTIONS", "/get") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, 200) headerTests := []struct { key string @@ -326,18 +259,17 @@ func TestCORS(t *testing.T) { {"Access-Control-Allow-Headers", ""}, } for _, test := range headerTests { - assertHeader(t, w, test.key, test.expected) + assert.Header(t, resp, test.key, test.expected) } }) t.Run("CORS/allow_headers", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("OPTIONS", "/get", nil) - r.Header.Set("Access-Control-Request-Headers", "X-Test-Header") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, 200) + req := newTestRequest(t, "OPTIONS", "/get") + req.Header.Set("Access-Control-Request-Headers", "X-Test-Header") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, 200) headerTests := []struct { key string @@ -346,13 +278,12 @@ func TestCORS(t *testing.T) { {"Access-Control-Allow-Headers", "X-Test-Header"}, } for _, test := range headerTests { - assertHeader(t, w, test.key, test.expected) + assert.Header(t, resp, test.key, test.expected) } }) } func TestIP(t *testing.T) { - t.Parallel() testCases := map[string]struct { remoteAddr string headers map[string]string @@ -383,113 +314,87 @@ func TestIP(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/ip", nil) - r.RemoteAddr = tc.remoteAddr + + req, _ := http.NewRequest("GET", "/ip", nil) + req.RemoteAddr = tc.remoteAddr for k, v := range tc.headers { - r.Header.Set(k, v) + req.Header.Set(k, v) } - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + // this test does not use a real server, because we need to control + // the RemoteAddr field on the request object to make the test + // deterministic. + w := httptest.NewRecorder() + app.ServeHTTP(w, req) - var resp *ipResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) + if w.Code != http.StatusOK { + t.Errorf("wanted status code %d, got %d", http.StatusOK, w.Code) } - if resp.Origin != tc.wantOrigin { - t.Fatalf("got %q, want %q", resp.Origin, tc.wantOrigin) + if ct := w.Header().Get("Content-Type"); ct != jsonContentType { + t.Errorf("expected content type %q, got %q", jsonContentType, ct) } + + result := must.Unmarshal[ipResponse](t, w.Body) + assert.Equal(t, result.Origin, tc.wantOrigin, "incorrect origin") }) } } func TestUserAgent(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/user-agent", nil) - r.Header.Set("User-Agent", "test") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + req := newTestRequest(t, "GET", "/user-agent") + req.Header.Set("User-Agent", "test") - var resp *userAgentResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } - - if resp.UserAgent != "test" { - t.Fatalf("%#v != \"test\"", resp.UserAgent) - } + resp := must.DoReq(t, client, req) + result := mustParseResponse[userAgentResponse](t, resp) + assert.Equal(t, "test", result.UserAgent, "incorrect user agent") } func TestHeaders(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/headers", nil) - r.Host = "test-host" - r.Header.Set("User-Agent", "test") - r.Header.Set("Foo-Header", "foo") - r.Header.Add("Bar-Header", "bar1") - r.Header.Add("Bar-Header", "bar2") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) - - var resp *headersResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } - // Host header requires special treatment, because its an attribute of the + req := newTestRequest(t, "GET", "/headers") + req.Host = "test-host" + req.Header.Set("User-Agent", "test") + req.Header.Set("Foo-Header", "foo") + req.Header.Add("Bar-Header", "bar1") + req.Header.Add("Bar-Header", "bar2") + + resp := must.DoReq(t, client, req) + result := mustParseResponse[headersResponse](t, resp) + + // Host header requires special treatment, because it's a field on the // http.Request struct itself, not part of its headers map - host := resp.Headers[http.CanonicalHeaderKey("Host")] - if host == nil || host[0] != "test-host" { - t.Fatalf("expected Host header \"test-host\", got %#v", host) - } + host := result.Headers.Get("Host") + assert.Equal(t, req.Host, host, "missing or incorrect Host header") - for k, expectedValues := range r.Header { - values, ok := resp.Headers[http.CanonicalHeaderKey(k)] - if !ok { - t.Fatalf("expected header %#v in response", k) - } - if !reflect.DeepEqual(expectedValues, values) { - t.Fatalf("header %s value mismatch: %#v != %#v", k, values, expectedValues) - } + for k, expectedValues := range req.Header { + values := result.Headers.Values(k) + assert.DeepEqual(t, expectedValues, values, "missing or incorrect header for key %q", k) } } func TestPost(t *testing.T) { - t.Parallel() testRequestWithBody(t, "POST", "/post") } func TestPut(t *testing.T) { - t.Parallel() testRequestWithBody(t, "PUT", "/put") } func TestDelete(t *testing.T) { - t.Parallel() testRequestWithBody(t, "DELETE", "/delete") } func TestPatch(t *testing.T) { - t.Parallel() testRequestWithBody(t, "PATCH", "/patch") } func TestAnything(t *testing.T) { - t.Parallel() var ( - verbsWithReqBodies = []string{ + verbs = []string{ "GET", "DELETE", "PATCH", @@ -502,62 +407,58 @@ func TestAnything(t *testing.T) { } ) for _, path := range paths { - for _, verb := range verbsWithReqBodies { + for _, verb := range verbs { testRequestWithBody(t, verb, path) } } t.Run("HEAD", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("HEAD", "/anything", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertBodyEquals(t, w, "") - if contentLength := w.Header().Get("Content-Length"); contentLength != "" { - t.Fatalf("did not expect Content-Length in response to HEAD request") - } + req := newTestRequest(t, "HEAD", "/anything") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + assert.BodyEquals(t, resp, "") + assert.Header(t, resp, "Content-Length", "") // responses to HEAD requests should not have a Content-Length header }) } -// getFuncName uses runtime type reflection to get the name of the given -// function. -// -// Cribbed from https://stackoverflow.com/a/70535822/151221 -func getFuncName(f interface{}) string { - parts := strings.Split((runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()), ".") - return parts[len(parts)-1] -} +func testRequestWithBody(t *testing.T, verb, path string) { + // getFuncName uses runtime type reflection to get the name of the given + // function. + // + // Cribbed from https://stackoverflow.com/a/70535822/151221 + getFuncName := func(f interface{}) string { + parts := strings.Split((runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()), ".") + return parts[len(parts)-1] + } -// getTestName expects a function named like testRequestWithBody__BodyTooBig -// and returns only the trailing BodyTooBig part. -func getTestName(prefix string, f interface{}) string { - name := strings.TrimPrefix(getFuncName(f), "testRequestWithBody") - return fmt.Sprintf("%s/%s", prefix, name) -} + // getTestName expects a function named like testRequestWithBody__BodyTooBig + // and returns only the trailing BodyTooBig part. + getTestName := func(prefix string, f interface{}) string { + name := strings.TrimPrefix(getFuncName(f), "testRequestWithBody") + return fmt.Sprintf("%s/%s", prefix, name) + } -func testRequestWithBody(t *testing.T, verb, path string) { type testFunc func(t *testing.T, verb, path string) testFuncs := []testFunc{ + testRequestWithBodyBinaryBody, testRequestWithBodyBodyTooBig, testRequestWithBodyEmptyBody, testRequestWithBodyFormEncodedBody, - testRequestWithBodyMultiPartBodyFiles, testRequestWithBodyFormEncodedBodyNoContentType, + testRequestWithBodyHTML, testRequestWithBodyInvalidFormEncodedBody, testRequestWithBodyInvalidJSON, testRequestWithBodyInvalidMultiPartBody, - testRequestWithBodyHTML, testRequestWithBodyJSON, testRequestWithBodyMultiPartBody, + testRequestWithBodyMultiPartBodyFiles, testRequestWithBodyQueryParams, testRequestWithBodyQueryParamsAndBody, - testRequestWithBodyBinaryBody, testRequestWithBodyTransferEncoding, } for _, testFunc := range testFuncs { testFunc := testFunc - t.Run(getTestName(verb, testFunc), func(t *testing.T) { t.Parallel() testFunc(t, verb, path) @@ -581,40 +482,20 @@ func testRequestWithBodyBinaryBody(t *testing.T, verb string, path string) { t.Run("content type/"+test.contentType, func(t *testing.T) { t.Parallel() - testBody := bytes.NewReader([]byte(test.requestBody)) + req := newTestRequestWithBody(t, verb, path, bytes.NewReader([]byte(test.requestBody))) + req.Header.Set("Content-Type", test.contentType) - r, _ := http.NewRequest(verb, path, testBody) - r.Header.Set("Content-Type", test.contentType) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + resp := must.DoReq(t, client, req) + result := mustParseResponse[bodyResponse](t, resp) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) - - var resp *bodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } + assert.Equal(t, result.Method, verb, "method mismatch") + assert.DeepEqual(t, result.Args, nilValues, "expected empty args") + assert.DeepEqual(t, result.Files, nilValues, "expected empty files") + assert.DeepEqual(t, result.Form, nilValues, "expected empty form") + assert.DeepEqual(t, result.JSON, nil, "expected nil json") expected := "data:" + test.contentType + ";base64," + base64.StdEncoding.EncodeToString([]byte(test.requestBody)) - - if resp.Data != expected { - t.Fatalf("expected binary encoded response data: %#v got %#v", expected, resp.Data) - } - if resp.JSON != nil { - t.Fatalf("expected nil response json, got %#v", resp.JSON) - } - - if len(resp.Args) > 0 { - t.Fatalf("expected no query params, got %#v", resp.Args) - } - if resp.Method != verb { - t.Fatalf("expected method to be %s, got %s", verb, resp.Method) - } - if len(resp.Form) > 0 { - t.Fatalf("expected no form data, got %#v", resp.Form) - } + assert.Equal(t, expected, result.Data, "expected binary encoded response data") }) } } @@ -632,36 +513,19 @@ func testRequestWithBodyEmptyBody(t *testing.T, verb string, path string) { test := test t.Run("content type/"+test.contentType, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest(verb, path, nil) - r.Header.Set("Content-Type", test.contentType) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + req := newTestRequest(t, verb, path) + req.Header.Set("Content-Type", test.contentType) - var resp *bodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } + resp := must.DoReq(t, client, req) + result := mustParseResponse[bodyResponse](t, resp) - if resp.Data != "" { - t.Fatalf("expected empty response data, got %#v", resp.Data) - } - if resp.JSON != nil { - t.Fatalf("expected nil response json, got %#v", resp.JSON) - } - - if len(resp.Args) > 0 { - t.Fatalf("expected no query params, got %#v", resp.Args) - } - if resp.Method != verb { - t.Fatalf("expected method to be %s, got %s", verb, resp.Method) - } - if len(resp.Form) > 0 { - t.Fatalf("expected no form data, got %#v", resp.Form) - } + assert.Equal(t, result.Data, "", "expected empty response data") + assert.Equal(t, result.Method, verb, "method mismatch") + assert.DeepEqual(t, result.Args, nilValues, "expected empty args") + assert.DeepEqual(t, result.Files, nilValues, "expected empty files") + assert.DeepEqual(t, result.Form, nilValues, "expected empty form") + assert.DeepEqual(t, result.JSON, nil, "expected nil JSON") }) } } @@ -672,58 +536,29 @@ func testRequestWithBodyFormEncodedBody(t *testing.T, verb, path string) { params.Add("bar", "bar1") params.Add("bar", "bar2") - r, _ := http.NewRequest(verb, path, strings.NewReader(params.Encode())) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req := newTestRequestWithBody(t, verb, path, strings.NewReader(params.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) - - var resp *bodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err) - } + resp := must.DoReq(t, client, req) + result := mustParseResponse[bodyResponse](t, resp) - if len(resp.Args) > 0 { - t.Fatalf("expected no query params, got %#v", resp.Args) - } - if resp.Method != verb { - t.Fatalf("expected method to be %s, got %s", verb, resp.Method) - } - if len(resp.Form) != len(params) { - t.Fatalf("expected %d form values, got %d", len(params), len(resp.Form)) - } - for k, expectedValues := range params { - values, ok := resp.Form[k] - if !ok { - t.Fatalf("expected form field %#v in response", k) - } - if !reflect.DeepEqual(expectedValues, values) { - t.Fatalf("form value mismatch: %#v != %#v", values, expectedValues) - } - } + assert.DeepEqual(t, params, result.Form, "form data mismatch") + assert.Equal(t, verb, result.Method, "method mismatch") + assert.DeepEqual(t, result.Args, nilValues, "expected empty args") + assert.DeepEqual(t, result.Files, nilValues, "expected empty files") + assert.DeepEqual(t, result.JSON, nil, "expected nil json") } func testRequestWithBodyHTML(t *testing.T, verb, path string) { data := "<html><body><h1>hello world</h1></body></html>" - r, _ := http.NewRequest(verb, path, strings.NewReader(data)) - r.Header.Set("Content-Type", htmlContentType) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req := newTestRequestWithBody(t, verb, path, strings.NewReader(data)) + req.Header.Set("Content-Type", htmlContentType) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) - - // We do not use json.Unmarshal here which would unescape any escaped characters. - // For httpbin compatibility, we need to verify the data is returned as-is without - // escaping. - respBody := w.Body.String() - if !strings.Contains(respBody, data) { - t.Fatalf("response data mismatch, %#v != %#v", respBody, data) - } + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + assert.ContentType(t, resp, jsonContentType) + assert.BodyContains(t, resp, data) } func testRequestWithBodyFormEncodedBodyNoContentType(t *testing.T, verb, path string) { @@ -732,37 +567,23 @@ func testRequestWithBodyFormEncodedBodyNoContentType(t *testing.T, verb, path st params.Add("bar", "bar1") params.Add("bar", "bar2") - r, _ := http.NewRequest(verb, path, strings.NewReader(params.Encode())) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req := newTestRequestWithBody(t, verb, path, strings.NewReader(params.Encode())) + resp := must.DoReq(t, client, req) + result := mustParseResponse[bodyResponse](t, resp) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + assert.Equal(t, result.Method, verb, "method mismatch") + assert.DeepEqual(t, result.Args, nilValues, "expected empty args") + assert.DeepEqual(t, result.Files, nilValues, "expected empty files") + assert.DeepEqual(t, result.Form, nilValues, "expected empty form") + assert.DeepEqual(t, result.JSON, nil, "expected nil JSON") - var resp *bodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } - - if len(resp.Args) > 0 { - t.Fatalf("expected no query params, got %#v", resp.Args) - } - if resp.Method != verb { - t.Fatalf("expected method to be %s, got %s", verb, resp.Method) - } - if len(resp.Form) != 0 { - t.Fatalf("expected no form values, got %d", len(resp.Form)) - } // Because we did not set an content type, httpbin will return the base64 encoded data. expectedBody := "data:application/octet-stream;base64," + base64.StdEncoding.EncodeToString([]byte(params.Encode())) - if string(resp.Data) != expectedBody { - t.Fatalf("response data mismatch, %#v != %#v", string(resp.Data), expectedBody) - } + assert.Equal(t, expectedBody, result.Data, "response data mismatch") } func testRequestWithBodyMultiPartBody(t *testing.T, verb, path string) { - params := map[string][]string{ + params := url.Values{ "foo": {"foo"}, "bar": {"bar1", "bar2"}, } @@ -784,38 +605,17 @@ func testRequestWithBodyMultiPartBody(t *testing.T, verb, path string) { } mw.Close() - r, _ := http.NewRequest(verb, path, bytes.NewReader(body.Bytes())) - r.Header.Set("Content-Type", mw.FormDataContentType()) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req := newTestRequestWithBody(t, verb, path, bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", mw.FormDataContentType()) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + resp := must.DoReq(t, client, req) + result := mustParseResponse[bodyResponse](t, resp) - var resp *bodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err) - } - - if len(resp.Args) > 0 { - t.Fatalf("expected no query params, got %#v", resp.Args) - } - if resp.Method != verb { - t.Fatalf("expected method to be %s, got %s", verb, resp.Method) - } - if len(resp.Form) != len(params) { - t.Fatalf("expected %d form values, got %d", len(params), len(resp.Form)) - } - for k, expectedValues := range params { - values, ok := resp.Form[k] - if !ok { - t.Fatalf("expected form field %#v in response", k) - } - if !reflect.DeepEqual(expectedValues, values) { - t.Fatalf("form value mismatch: %#v != %#v", values, expectedValues) - } - } + assert.Equal(t, result.Method, verb, "method mismatch") + assert.DeepEqual(t, result.Args, nilValues, "expected empty args") + assert.DeepEqual(t, result.Files, nilValues, "expected empty files") + assert.DeepEqual(t, result.Form, params, "form values mismatch") + assert.DeepEqual(t, result.JSON, nil, "expected nil JSON") } func testRequestWithBodyMultiPartBodyFiles(t *testing.T, verb, path string) { @@ -827,52 +627,37 @@ func testRequestWithBodyMultiPartBodyFiles(t *testing.T, verb, path string) { part.Write([]byte("hello world")) mw.Close() - r, _ := http.NewRequest(verb, path, bytes.NewReader(body.Bytes())) - r.Header.Set("Content-Type", mw.FormDataContentType()) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req := newTestRequestWithBody(t, verb, path, bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", mw.FormDataContentType()) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + resp := must.DoReq(t, client, req) + result := mustParseResponse[bodyResponse](t, resp) - var resp *bodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err) - } - - if len(resp.Args) > 0 { - t.Fatalf("expected no query params, got %#v", resp.Args) - } + assert.Equal(t, result.Method, verb, "method mismatch") + assert.DeepEqual(t, result.Args, nilValues, "expected empty args") + assert.DeepEqual(t, result.Form, nilValues, "expected empty form") + assert.DeepEqual(t, result.JSON, nil, "expected nil JSON") // verify that the file we added is present in the `files` attribute of the // response, with the field as key and content as value - wantFiles := map[string][]string{ + wantFiles := url.Values{ "fieldname": {"hello world"}, } - if !reflect.DeepEqual(resp.Files, wantFiles) { - t.Fatalf("want resp.Files = %#v, got %#v", wantFiles, resp.Files) - } - - if resp.Method != verb { - t.Fatalf("expected method to be %s, got %s", verb, resp.Method) - } + assert.DeepEqual(t, result.Files, wantFiles, "files mismatch") } func testRequestWithBodyInvalidFormEncodedBody(t *testing.T, verb, path string) { - r, _ := http.NewRequest(verb, path, strings.NewReader("%ZZ")) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusBadRequest) + req := newTestRequestWithBody(t, verb, path, strings.NewReader("%ZZ")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusBadRequest) } func testRequestWithBodyInvalidMultiPartBody(t *testing.T, verb, path string) { - r, _ := http.NewRequest(verb, path, strings.NewReader("%ZZ")) - r.Header.Set("Content-Type", "multipart/form-data; etc") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusBadRequest) + req := newTestRequestWithBody(t, verb, path, strings.NewReader("%ZZ")) + req.Header.Set("Content-Type", "multipart/form-data; etc") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusBadRequest) } func testRequestWithBodyJSON(t *testing.T, verb, path string) { @@ -882,7 +667,7 @@ func testRequestWithBodyJSON(t *testing.T, verb, path string) { Baz []float64 Quux map[int]string } - input := &testInput{ + input := testInput{ Foo: "foo", Bar: 123, Baz: []float64{1.0, 1.1, 1.2}, @@ -890,63 +675,42 @@ func testRequestWithBodyJSON(t *testing.T, verb, path string) { } inputBody, _ := json.Marshal(input) - r, _ := http.NewRequest(verb, path, bytes.NewReader(inputBody)) - r.Header.Set("Content-Type", "application/json; charset=utf-8") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + req := newTestRequestWithBody(t, verb, path, bytes.NewReader(inputBody)) + req.Header.Set("Content-Type", "application/json; charset=utf-8") - var resp *bodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } + resp := must.DoReq(t, client, req) + result := mustParseResponse[bodyResponse](t, resp) - if resp.Data != string(inputBody) { - t.Fatalf("expected data == %#v, got %#v", string(inputBody), resp.Data) - } - if len(resp.Args) > 0 { - t.Fatalf("expected no query params, got %#v", resp.Args) - } - if resp.Method != verb { - t.Fatalf("expected method to be %s, got %s", verb, resp.Method) - } - if len(resp.Form) != 0 { - t.Fatalf("expected no form values, got %d", len(resp.Form)) - } + assert.Equal(t, result.Data, string(inputBody), "response data mismatch") + assert.Equal(t, result.Method, verb, "method mismatch") + assert.DeepEqual(t, result.Args, nilValues, "expected empty args") + assert.DeepEqual(t, result.Files, nilValues, "expected empty files") + assert.DeepEqual(t, result.Form, nilValues, "form values mismatch") // Need to re-marshall just the JSON field from the response in order to // re-unmarshall it into our expected type - outputBodyBytes, _ := json.Marshal(resp.JSON) - output := &testInput{} - err = json.Unmarshal(outputBodyBytes, output) - if err != nil { - t.Fatalf("failed to round-trip JSON: coult not re-unmarshal JSON: %s", err) - } + roundTrippedInputBytes, err := json.Marshal(result.JSON) + assert.NilError(t, err) - if !reflect.DeepEqual(input, output) { - t.Fatalf("failed to round-trip JSON: %#v != %#v", output, input) + var roundTrippedInput testInput + if err := json.Unmarshal(roundTrippedInputBytes, &roundTrippedInput); err != nil { + t.Fatalf("failed to round-trip JSON: coult not re-unmarshal JSON: %s", err) } + assert.DeepEqual(t, input, roundTrippedInput, "round-tripped JSON mismatch") } func testRequestWithBodyInvalidJSON(t *testing.T, verb, path string) { - r, _ := http.NewRequest("POST", "/post", bytes.NewReader([]byte("foo"))) - r.Header.Set("Content-Type", "application/json; charset=utf-8") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusBadRequest) + req := newTestRequestWithBody(t, verb, path, strings.NewReader("foo")) + req.Header.Set("Content-Type", "application/json; charset=utf-8") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusBadRequest) } func testRequestWithBodyBodyTooBig(t *testing.T, verb, path string) { body := make([]byte, maxBodySize+1) - - r, _ := http.NewRequest("POST", "/post", bytes.NewReader(body)) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusBadRequest) + req := newTestRequestWithBody(t, verb, path, bytes.NewReader(body)) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusBadRequest) } func testRequestWithBodyQueryParams(t *testing.T, verb, path string) { @@ -955,30 +719,17 @@ func testRequestWithBodyQueryParams(t *testing.T, verb, path string) { params.Add("bar", "bar1") params.Add("bar", "bar2") - r, _ := http.NewRequest("POST", fmt.Sprintf("/post?%s", params.Encode()), nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + req := newTestRequest(t, verb, fmt.Sprintf("%s?%s", path, params.Encode())) + resp := must.DoReq(t, client, req) + result := mustParseResponse[bodyResponse](t, resp) - var resp *bodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err) - } + assert.DeepEqual(t, result.Args, params, "args mismatch") - if resp.Args.Encode() != params.Encode() { - t.Fatalf("expected args = %#v in response, got %#v", params.Encode(), resp.Args.Encode()) - } - - if resp.Method != "POST" { - t.Fatalf("expected method to be POST, got %s", resp.Method) - } - - if len(resp.Form) > 0 { - t.Fatalf("expected form data, got %#v", resp.Form) - } + // extra validation + assert.Equal(t, result.Method, verb, "method mismatch") + assert.DeepEqual(t, result.Files, nilValues, "expected empty files") + assert.DeepEqual(t, result.Form, nilValues, "form values mismatch") + assert.DeepEqual(t, result.JSON, nil, "expected nil JSON") } func testRequestWithBodyQueryParamsAndBody(t *testing.T, verb, path string) { @@ -992,46 +743,18 @@ func testRequestWithBodyQueryParamsAndBody(t *testing.T, verb, path string) { form.Add("form2", "bar1") form.Add("form2", "bar2") - url := fmt.Sprintf("/post?%s", args.Encode()) - body := strings.NewReader(form.Encode()) - - r, _ := http.NewRequest("POST", url, body) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) - - var resp *bodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err) - } + url := fmt.Sprintf("%s?%s", path, args.Encode()) + req := newTestRequestWithBody(t, verb, url, strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp := must.DoReq(t, client, req) - if resp.Args.Encode() != args.Encode() { - t.Fatalf("expected args = %#v in response, got %#v", args.Encode(), resp.Args.Encode()) - } - - if resp.Method != "POST" { - t.Fatalf("expected method to be POST, got %s", resp.Method) - } - - if len(resp.Form) != len(form) { - t.Fatalf("expected %d form values, got %d", len(form), len(resp.Form)) - } - for k, expectedValues := range form { - values, ok := resp.Form[k] - if !ok { - t.Fatalf("expected form field %#v in response", k) - } - if !reflect.DeepEqual(expectedValues, values) { - t.Fatalf("form value mismatch: %#v != %#v", values, expectedValues) - } - } + result := mustParseResponse[bodyResponse](t, resp) + assert.Equal(t, verb, result.Method, "method mismatch") + assert.Equal(t, args.Encode(), result.Args.Encode(), "args mismatch") + assert.Equal(t, form.Encode(), result.Form.Encode(), "form mismatch") } -func testRequestWithBodyTransferEncoding(t *testing.T, verb string, path string) { +func testRequestWithBodyTransferEncoding(t *testing.T, verb, path string) { testCases := []struct { given string want string @@ -1045,34 +768,22 @@ func testRequestWithBodyTransferEncoding(t *testing.T, verb string, path string) t.Run("transfer-encoding/"+tc.given, func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(app) - defer srv.Close() - - r, _ := http.NewRequest(verb, srv.URL+path, bytes.NewReader([]byte("{}"))) + req := newTestRequestWithBody(t, verb, path, bytes.NewReader([]byte("{}"))) if tc.given != "" { - r.TransferEncoding = []string{tc.given} + req.TransferEncoding = []string{tc.given} } - httpResp, err := srv.Client().Do(r) - assertNilError(t, err) - assertIntEqual(t, httpResp.StatusCode, http.StatusOK) - - var resp *bodyResponse - if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil { - t.Fatalf("failed to unmarshal body from JSON: %s", err) - } + resp := must.DoReq(t, client, req) + result := mustParseResponse[bodyResponse](t, resp) - got := resp.Headers.Get("Transfer-Encoding") - if got != tc.want { - t.Errorf("expected Transfer-Encoding %#v, got %#v", tc.want, got) - } + got := result.Headers.Get("Transfer-Encoding") + assert.Equal(t, got, tc.want, "Transfer-Encoding header mismatch") }) } } // TODO: implement and test more complex /status endpoint func TestStatus(t *testing.T) { - t.Parallel() redirectHeaders := map[string]string{ "Location": "/redirect/1", } @@ -1113,22 +824,12 @@ func TestStatus(t *testing.T) { test := test t.Run(fmt.Sprintf("ok/status/%d", test.code), func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", fmt.Sprintf("/status/%d", test.code), nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, test.code) - - if test.headers != nil { - for key, val := range test.headers { - assertHeader(t, w, key, val) - } - } - - if test.body != "" { - if w.Body.String() != test.body { - t.Fatalf("expected body %#v, got %#v", test.body, w.Body.String()) - } + req, _ := http.NewRequest("GET", srv.URL+fmt.Sprintf("/status/%d", test.code), nil) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.code) + assert.BodyEquals(t, resp, test.body) + for key, val := range test.headers { + assert.Header(t, resp, key, val) } }) } @@ -1148,31 +849,28 @@ func TestStatus(t *testing.T) { test := test t.Run("error"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.status) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.status) }) } } func TestUnstable(t *testing.T) { - t.Parallel() t.Run("ok_no_seed", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/unstable", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - if w.Code != 200 && w.Code != 500 { - t.Fatalf("expected status code 200 or 500, got %d", w.Code) + req := newTestRequest(t, "GET", "/unstable") + resp := must.DoReq(t, client, req) + if resp.StatusCode != 200 && resp.StatusCode != 500 { + t.Fatalf("expected status code 200 or 500, got %d", resp.StatusCode) } }) - // rand.NewSource(1234567890).Float64() => 0.08 tests := []struct { url string status int }{ + // rand.NewSource(1234567890).Float64() => 0.08 {"/unstable?seed=1234567890", 500}, {"/unstable?seed=1234567890&failure_rate=0.07", 200}, } @@ -1180,10 +878,9 @@ func TestUnstable(t *testing.T) { test := test t.Run("ok_"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.status) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.status) }) } @@ -1195,11 +892,10 @@ func TestUnstable(t *testing.T) { test := test t.Run("bad"+test, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - if w.Code != 200 && w.Code != 500 { - t.Fatalf("expected status code 200 or 500, got %d", w.Code) + req := newTestRequest(t, "GET", test) + resp := must.DoReq(t, client, req) + if resp.StatusCode != 200 && resp.StatusCode != 500 { + t.Fatalf("expected status code 200 or 500, got %d", resp.StatusCode) } }) } @@ -1217,79 +913,54 @@ func TestUnstable(t *testing.T) { test := test t.Run("bad"+test, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusBadRequest) + req := newTestRequest(t, "GET", test) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusBadRequest) }) } } -func TestResponseHeaders__OK(t *testing.T) { - t.Parallel() - headers := map[string][]string{ - "Foo": {"foo"}, - "Bar": {"bar1, bar2"}, - } +func TestResponseHeaders(t *testing.T) { + t.Run("ok", func(t *testing.T) { + t.Parallel() - params := url.Values{} - for k, vs := range headers { - for _, v := range vs { - params.Add(k, v) + wantHeaders := url.Values{ + "Foo": {"foo"}, + "Bar": {"bar1", "bar2"}, } - } - r, _ := http.NewRequest("GET", fmt.Sprintf("/response-headers?%s", params.Encode()), nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req, _ := http.NewRequest("GET", fmt.Sprintf("%s/response-headers?%s", srv.URL, wantHeaders.Encode()), nil) + resp := must.DoReq(t, client, req) + result := mustParseResponse[http.Header](t, resp) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + for k, expectedValues := range wantHeaders { + // expected headers should be present in the HTTP response itself + respValues := resp.Header[k] + assert.DeepEqual(t, expectedValues, respValues, "HTTP response headers mismatch") - for k, expectedValues := range headers { - values, ok := w.Header()[k] - if !ok { - t.Fatalf("expected header %s in response headers", k) + // they should also be reflected in the decoded JSON resposne + resultValues := result[k] + assert.DeepEqual(t, expectedValues, resultValues, "JSON response headers mismatch") } - if !reflect.DeepEqual(values, expectedValues) { - t.Fatalf("expected key values %#v for header %s, got %#v", expectedValues, k, values) - } - } - - resp := &http.Header{} - err := json.Unmarshal(w.Body.Bytes(), resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } + }) - for k, expectedValues := range headers { - values, ok := (*resp)[k] - if !ok { - t.Fatalf("expected header %s in response body", k) - } - if !reflect.DeepEqual(values, expectedValues) { - t.Fatalf("expected key values %#v for header %s, got %#v", expectedValues, k, values) - } - } -} + t.Run("override content-type", func(t *testing.T) { + t.Parallel() -func TestResponseHeaders__OverrideContentType(t *testing.T) { - t.Parallel() - contentType := "text/test" + contentType := "text/test" - params := url.Values{} - params.Set("Content-Type", contentType) + params := url.Values{} + params.Set("Content-Type", contentType) - r, _ := http.NewRequest("GET", fmt.Sprintf("/response-headers?%s", params.Encode()), nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req, _ := http.NewRequest("GET", fmt.Sprintf("%s/response-headers?%s", srv.URL, params.Encode()), nil) + resp := must.DoReq(t, client, req) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, contentType) + assert.StatusCode(t, resp, http.StatusOK) + assert.ContentType(t, resp, contentType) + }) } func TestRedirects(t *testing.T) { - t.Parallel() tests := []struct { requestURL string expectedLocation string @@ -1319,13 +990,13 @@ func TestRedirects(t *testing.T) { test := test t.Run("ok"+test.requestURL, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.requestURL, nil) - r.Host = "host" - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusFound) - assertHeader(t, w, "Location", test.expectedLocation) + req := newTestRequest(t, "GET", test.requestURL) + req.Host = "host" + resp := must.DoReq(t, client, req) + + assert.StatusCode(t, resp, http.StatusFound) + assert.Header(t, resp, "Location", test.expectedLocation) }) } @@ -1356,17 +1027,14 @@ func TestRedirects(t *testing.T) { test := test t.Run("error"+test.requestURL, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.requestURL, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, test.expectedStatus) + req := newTestRequest(t, "GET", test.requestURL) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) }) } } func TestRedirectTo(t *testing.T) { - t.Parallel() okTests := []struct { url string expectedLocation string @@ -1385,12 +1053,10 @@ func TestRedirectTo(t *testing.T) { test := test t.Run("ok"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, test.expectedStatus) - assertHeader(t, w, "Location", test.expectedLocation) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) + assert.Header(t, resp, "Location", test.expectedLocation) }) } @@ -1408,24 +1074,19 @@ func TestRedirectTo(t *testing.T) { test := test t.Run("bad"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, test.expectedStatus) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) }) } - allowListHandler := New( - WithAllowedRedirectDomains([]string{"httpbingo.org", "example.org"}), - WithObserver(StdLogObserver(log.New(io.Discard, "", 0))), - ).Handler() - + // error message matches redirect configuration in global shared test app allowedDomainsError := `Forbidden redirect URL. Please be careful with this link. Allowed redirect destinations: - example.org - httpbingo.org +- www.example.com ` allowListTests := []struct { @@ -1442,188 +1103,169 @@ Allowed redirect destinations: test := test t.Run("allowlist"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - allowListHandler.ServeHTTP(w, r) - assertStatusCode(t, w, test.expectedStatus) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) if test.expectedStatus >= 400 { - assertBodyEquals(t, w, allowedDomainsError) + assert.BodyEquals(t, resp, allowedDomainsError) } }) } } func TestCookies(t *testing.T) { - t.Parallel() - testCookies := func(t *testing.T, cookies cookiesResponse) { - r, _ := http.NewRequest("GET", "/cookies", nil) - for k, v := range cookies { - r.AddCookie(&http.Cookie{ - Name: k, - Value: v, - }) - } - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) - - resp := cookiesResponse{} - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } + t.Run("get", func(t *testing.T) { + testCases := map[string]struct { + cookies cookiesResponse + }{ + "ok/no cookies": { + cookies: cookiesResponse{}, + }, + "ok/one cookie": { + cookies: cookiesResponse{ + "k1": "v1", + }, + }, + "ok/many cookies": { + cookies: cookiesResponse{ + "k1": "v1", + "k2": "v2", + "k3": "v3", + }, + }, + } + + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + req := newTestRequest(t, "GET", "/cookies") + for k, v := range tc.cookies { + req.AddCookie(&http.Cookie{ + Name: k, + Value: v, + }) + } - if !reflect.DeepEqual(cookies, resp) { - t.Fatalf("expected cookies %#v, got %#v", cookies, resp) + resp := must.DoReq(t, client, req) + result := mustParseResponse[cookiesResponse](t, resp) + assert.DeepEqual(t, result, tc.cookies, "cookies mismatch") + }) } - } - - t.Run("ok/no cookies", func(t *testing.T) { - t.Parallel() - testCookies(t, cookiesResponse{}) }) - t.Run("ok/cookies", func(t *testing.T) { + t.Run("set", func(t *testing.T) { t.Parallel() - testCookies(t, cookiesResponse{ - "k1": "v1", - "k2": "v2", - }) - }) -} - -func TestSetCookies(t *testing.T) { - t.Parallel() - cookies := cookiesResponse{ - "k1": "v1", - "k2": "v2", - } - params := &url.Values{} - for k, v := range cookies { - params.Set(k, v) - } + cookies := cookiesResponse{ + "k1": "v1", + "k2": "v2", + } + params := &url.Values{} + for k, v := range cookies { + params.Set(k, v) + } - r, _ := http.NewRequest("GET", fmt.Sprintf("/cookies/set?%s", params.Encode()), nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req := newTestRequest(t, "GET", "/cookies/set?"+params.Encode()) + resp := must.DoReq(t, client, req) - assertStatusCode(t, w, http.StatusFound) - assertHeader(t, w, "Location", "/cookies") + assert.StatusCode(t, resp, http.StatusFound) + assert.Header(t, resp, "Location", "/cookies") - for _, c := range w.Result().Cookies() { - v, ok := cookies[c.Name] - if !ok { - t.Fatalf("got unexpected cookie %s=%s", c.Name, c.Value) - } - if v != c.Value { - t.Fatalf("got cookie %s=%s, expected value in %#v", c.Name, c.Value, v) + for _, c := range resp.Cookies() { + v, ok := cookies[c.Name] + if !ok { + t.Fatalf("got unexpected cookie %s=%s", c.Name, c.Value) + } + assert.Equal(t, c.Value, v, "value mismatch for cookie %q", c.Name) } - } -} + }) -func TestDeleteCookies(t *testing.T) { - t.Parallel() - cookies := cookiesResponse{ - "k1": "v1", - "k2": "v2", - } + t.Run("delete", func(t *testing.T) { + t.Parallel() + + cookies := cookiesResponse{ + "k1": "v1", + "k2": "v2", + } - toDelete := "k2" - params := &url.Values{} - params.Set(toDelete, "") + toDelete := "k2" + params := &url.Values{} + params.Set(toDelete, "") - r, _ := http.NewRequest("GET", fmt.Sprintf("/cookies/delete?%s", params.Encode()), nil) - for k, v := range cookies { - r.AddCookie(&http.Cookie{ - Name: k, - Value: v, - }) - } - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req := newTestRequest(t, "GET", "/cookies/delete?"+params.Encode()) + for k, v := range cookies { + req.AddCookie(&http.Cookie{ + Name: k, + Value: v, + }) + } - assertStatusCode(t, w, http.StatusFound) - assertHeader(t, w, "Location", "/cookies") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusFound) + assert.Header(t, resp, "Location", "/cookies") - for _, c := range w.Result().Cookies() { - if c.Name == toDelete { - if time.Since(c.Expires) < (24*365-1)*time.Hour { - t.Fatalf("expected cookie %s to be deleted; got %#v", toDelete, c) + for _, c := range resp.Cookies() { + if c.Name == toDelete { + if time.Since(c.Expires) < (24*365-1)*time.Hour { + t.Fatalf("expected cookie %s to be deleted; got %#v", toDelete, c) + } } } - } + }) } func TestBasicAuth(t *testing.T) { - t.Parallel() t.Run("ok", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/basic-auth/user/pass", nil) - r.SetBasicAuth("user", "pass") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + req := newTestRequest(t, "GET", "/basic-auth/user/pass") + req.SetBasicAuth("user", "pass") - resp := &authResponse{} - json.Unmarshal(w.Body.Bytes(), resp) - - expectedResp := &authResponse{ + resp := must.DoReq(t, client, req) + result := mustParseResponse[authResponse](t, resp) + expectedResult := authResponse{ Authorized: true, User: "user", } - if !reflect.DeepEqual(resp, expectedResp) { - t.Fatalf("expected response %#v, got %#v", expectedResp, resp) - } + assert.DeepEqual(t, result, expectedResult, "expected authorized user") }) t.Run("error/no auth", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/basic-auth/user/pass", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusUnauthorized) - assertContentType(t, w, jsonContentType) - assertHeader(t, w, "WWW-Authenticate", `Basic realm="Fake Realm"`) - resp := &authResponse{} - json.Unmarshal(w.Body.Bytes(), resp) + req := newTestRequest(t, "GET", "/basic-auth/user/pass") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusUnauthorized) + assert.ContentType(t, resp, jsonContentType) + assert.Header(t, resp, "WWW-Authenticate", `Basic realm="Fake Realm"`) - expectedResp := &authResponse{ + result := must.Unmarshal[authResponse](t, resp.Body) + expectedResult := authResponse{ Authorized: false, User: "", } - if !reflect.DeepEqual(resp, expectedResp) { - t.Fatalf("expected response %#v, got %#v", expectedResp, resp) - } + assert.DeepEqual(t, result, expectedResult, "expected unauthorized user") }) t.Run("error/bad auth", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/basic-auth/user/pass", nil) - r.SetBasicAuth("bad", "auth") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusUnauthorized) - assertContentType(t, w, jsonContentType) - assertHeader(t, w, "WWW-Authenticate", `Basic realm="Fake Realm"`) + req := newTestRequest(t, "GET", "/basic-auth/user/pass") + req.SetBasicAuth("bad", "auth") - resp := &authResponse{} - json.Unmarshal(w.Body.Bytes(), resp) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusUnauthorized) + assert.ContentType(t, resp, jsonContentType) + assert.Header(t, resp, "WWW-Authenticate", `Basic realm="Fake Realm"`) - expectedResp := &authResponse{ + result := must.Unmarshal[authResponse](t, resp.Body) + expectedResult := authResponse{ Authorized: false, User: "bad", } - if !reflect.DeepEqual(resp, expectedResp) { - t.Fatalf("expected response %#v, got %#v", expectedResp, resp) - } + assert.DeepEqual(t, result, expectedResult, "expected unauthorized user") }) errorTests := []struct { @@ -1638,62 +1280,45 @@ func TestBasicAuth(t *testing.T) { test := test t.Run("error"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - r.SetBasicAuth("foo", "bar") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.status) + req := newTestRequest(t, "GET", test.url) + req.SetBasicAuth("foo", "bar") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.status) }) } } func TestHiddenBasicAuth(t *testing.T) { - t.Parallel() t.Run("ok", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/hidden-basic-auth/user/pass", nil) - r.SetBasicAuth("user", "pass") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + req := newTestRequest(t, "GET", "/hidden-basic-auth/user/pass") + req.SetBasicAuth("user", "pass") - resp := &authResponse{} - json.Unmarshal(w.Body.Bytes(), resp) - - expectedResp := &authResponse{ + resp := must.DoReq(t, client, req) + result := mustParseResponse[authResponse](t, resp) + expectedResult := authResponse{ Authorized: true, User: "user", } - if !reflect.DeepEqual(resp, expectedResp) { - t.Fatalf("expected response %#v, got %#v", expectedResp, resp) - } + assert.DeepEqual(t, result, expectedResult, "expected authorized user") }) t.Run("error/no auth", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/hidden-basic-auth/user/pass", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusNotFound) - if w.Header().Get("WWW-Authenticate") != "" { - t.Fatal("did not expect WWW-Authenticate header") - } + req := newTestRequest(t, "GET", "/hidden-basic-auth/user/pass") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusNotFound) + assert.Header(t, resp, "WWW-Authenticate", "") }) t.Run("error/bad auth", func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/hidden-basic-auth/user/pass", nil) - r.SetBasicAuth("bad", "auth") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusNotFound) - if w.Header().Get("WWW-Authenticate") != "" { - t.Fatal("did not expect WWW-Authenticate header") - } + req := newTestRequest(t, "GET", "/hidden-basic-auth/user/pass") + req.SetBasicAuth("bad", "auth") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusNotFound) + assert.Header(t, resp, "WWW-Authenticate", "") }) errorTests := []struct { @@ -1708,17 +1333,16 @@ func TestHiddenBasicAuth(t *testing.T) { test := test t.Run("error"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - r.SetBasicAuth("foo", "bar") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.status) + + req := newTestRequest(t, "GET", test.url) + req.SetBasicAuth("foo", "bar") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.status) }) } } func TestDigestAuth(t *testing.T) { - t.Parallel() tests := []struct { url string status int @@ -1741,60 +1365,54 @@ func TestDigestAuth(t *testing.T) { test := test t.Run("ok"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.status) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.status) }) } t.Run("ok", func(t *testing.T) { t.Parallel() - // Example captured from a successful login in a browser - authorization := `Digest username="user", - realm="go-httpbin", - nonce="6fb213c6593975c877bb1247370527ad", - uri="/digest-auth/auth/user/pass/MD5", - algorithm=MD5, - response="9b7a05d78051b4f668356eedf32f55d6", - opaque="fd1c386a015a2bb7c41585f54329ce91", - qop=auth, - nc=00000001, - cnonce="aaab705226af5bd4"` - - url := "/digest-auth/auth/user/pass/MD5" - r, _ := http.NewRequest("GET", url, nil) - r.RequestURI = url - r.Header.Set("Authorization", authorization) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - - resp := &authResponse{} - json.Unmarshal(w.Body.Bytes(), resp) - - expectedResp := &authResponse{ + // Example captured from a successful login in a browser + authorization := strings.Join([]string{ + `Digest username="user"`, + `realm="go-httpbin"`, + `nonce="6fb213c6593975c877bb1247370527ad"`, + `uri="/digest-auth/auth/user/pass/MD5"`, + `algorithm=MD5`, + `response="9b7a05d78051b4f668356eedf32f55d6"`, + `opaque="fd1c386a015a2bb7c41585f54329ce91"`, + `qop=auth`, + `nc=00000001`, + `cnonce="aaab705226af5bd4"`, + }, ", ") + + req := newTestRequest(t, "GET", "/digest-auth/auth/user/pass/MD5") + req.Header.Set("Authorization", authorization) + + resp := must.DoReq(t, client, req) + result := mustParseResponse[authResponse](t, resp) + expectedResult := authResponse{ Authorized: true, User: "user", } - if !reflect.DeepEqual(resp, expectedResp) { - t.Fatalf("expected response %#v, got %#v", expectedResp, resp) - } + assert.DeepEqual(t, result, expectedResult, "expected authorized user") }) } func TestGzip(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/gzip", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertContentType(t, w, "application/json; encoding=utf-8") - assertHeader(t, w, "Content-Encoding", "gzip") - assertStatusCode(t, w, http.StatusOK) + req := newTestRequest(t, "GET", "/gzip") + req.Header.Set("Accept-Encoding", "none") // disable automagic gzip decompression in default http client - zippedContentLengthStr := w.Header().Get("Content-Length") + resp := must.DoReq(t, client, req) + assert.Header(t, resp, "Content-Encoding", "gzip") + assert.ContentType(t, resp, jsonContentType) + assert.StatusCode(t, resp, http.StatusOK) + + zippedContentLengthStr := resp.Header.Get("Content-Length") if zippedContentLengthStr == "" { t.Fatalf("missing Content-Length header in response") } @@ -1804,7 +1422,7 @@ func TestGzip(t *testing.T) { t.Fatalf("error converting Content-Lengh %v to integer: %s", zippedContentLengthStr, err) } - gzipReader, err := gzip.NewReader(w.Body) + gzipReader, err := gzip.NewReader(resp.Body) if err != nil { t.Fatalf("error creating gzip reader: %s", err) } @@ -1813,12 +1431,10 @@ func TestGzip(t *testing.T) { if err != nil { t.Fatalf("error reading gzipped body: %s", err) } - var resp *noBodyResponse - if err := json.Unmarshal(unzippedBody, &resp); err != nil { - t.Fatalf("error unmarshalling response: %s", err) - } - if resp.Gzipped != true { + result := must.Unmarshal[noBodyResponse](t, bytes.NewBuffer(unzippedBody)) + + if result.Gzipped != true { t.Fatalf("expected resp.Gzipped == true") } @@ -1829,15 +1445,15 @@ func TestGzip(t *testing.T) { func TestDeflate(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/deflate", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertContentType(t, w, "application/json; encoding=utf-8") - assertHeader(t, w, "Content-Encoding", "deflate") - assertStatusCode(t, w, http.StatusOK) + req := newTestRequest(t, "GET", "/deflate") + resp := must.DoReq(t, client, req) + + assert.ContentType(t, resp, "application/json; encoding=utf-8") + assert.Header(t, resp, "Content-Encoding", "deflate") + assert.StatusCode(t, resp, http.StatusOK) - contentLengthHeader := w.Header().Get("Content-Length") + contentLengthHeader := resp.Header.Get("Content-Length") if contentLengthHeader == "" { t.Fatalf("missing Content-Length header in response") } @@ -1847,7 +1463,7 @@ func TestDeflate(t *testing.T) { t.Fatal(err) } - reader, err := zlib.NewReader(w.Body) + reader, err := zlib.NewReader(resp.Body) if err != nil { t.Fatal(err) } @@ -1856,13 +1472,9 @@ func TestDeflate(t *testing.T) { t.Fatal(err) } - var resp *noBodyResponse - err = json.Unmarshal(body, &resp) - if err != nil { - t.Fatalf("error unmarshalling response: %s", err) - } + result := must.Unmarshal[noBodyResponse](t, bytes.NewBuffer(body)) - if resp.Deflated != true { + if result.Deflated != true { t.Fatalf("expected resp.Deflated == true") } @@ -1873,6 +1485,7 @@ func TestDeflate(t *testing.T) { func TestStream(t *testing.T) { t.Parallel() + okTests := []struct { url string expectedLines int @@ -1887,19 +1500,14 @@ func TestStream(t *testing.T) { test := test t.Run("ok"+test.url, func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(app) - defer srv.Close() - resp, err := http.Get(srv.URL + test.url) - assertNil(t, err) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) defer resp.Body.Close() // Expect empty content-length due to streaming response - assertHeader(t, resp, "Content-Length", "") - - if len(resp.TransferEncoding) != 1 || resp.TransferEncoding[0] != "chunked" { - t.Fatalf("expected Transfer-Encoding: chunked, got %#v", resp.TransferEncoding) - } + assert.Header(t, resp, "Content-Length", "") + assert.DeepEqual(t, resp.TransferEncoding, []string{"chunked"}, "expected Transfer-Encoding: chunked") var sr *streamResponse @@ -1934,16 +1542,17 @@ func TestStream(t *testing.T) { test := test t.Run("bad"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.code) + + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.code) }) } } func TestDelay(t *testing.T) { t.Parallel() + okTests := []struct { url string expectedDelay time.Duration @@ -1961,22 +1570,13 @@ func TestDelay(t *testing.T) { test := test t.Run("ok"+test.url, func(t *testing.T) { t.Parallel() - start := time.Now() - - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + start := time.Now() + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) elapsed := time.Since(start) - assertStatusCode(t, w, http.StatusOK) - assertHeader(t, w, "Content-Type", jsonContentType) - - var resp *bodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("error unmarshalling response: %s", err) - } + _ = mustParseResponse[bodyResponse](t, resp) if elapsed < test.expectedDelay { t.Fatalf("expected delay of %s, got %s", test.expectedDelay, elapsed) @@ -1986,29 +1586,30 @@ func TestDelay(t *testing.T) { t.Run("handle cancelation", func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(app) - defer srv.Close() - client := http.Client{ - Timeout: time.Duration(10 * time.Millisecond), - } - _, err := client.Get(srv.URL + "/delay/1") - if err == nil { - t.Fatal("expected timeout error") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + req := newTestRequest(t, "GET", "/delay/1").WithContext(ctx) + _, err := client.Do(req) + if !os.IsTimeout(err) { + t.Errorf("expected timeout error, got %v", err) } }) t.Run("cancelation causes 499", func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() - r, _ := http.NewRequestWithContext(ctx, "GET", "/delay/1s", nil) + // use httptest.NewRecorder rather than a live httptest.NewServer + // because only the former will let us inspect the status code. w := httptest.NewRecorder() - app.ServeHTTP(w, r) - + req, _ := http.NewRequestWithContext(ctx, "GET", "/delay/1s", nil) + app.ServeHTTP(w, req) if w.Code != 499 { - t.Errorf("expected 499 response, got %d", w.Code) + t.Errorf("expected 499, got %d", w.Code) } }) @@ -2031,16 +1632,17 @@ func TestDelay(t *testing.T) { test := test t.Run("bad"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.code) + + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.code) }) } } func TestDrip(t *testing.T) { t.Parallel() + okTests := []struct { params *url.Values duration time.Duration @@ -2067,8 +1669,7 @@ func TestDrip(t *testing.T) { {&url.Values{"numbytes": {"101"}}, 0, 101, http.StatusOK}, {&url.Values{"numbytes": {fmt.Sprintf("%d", maxBodySize)}}, 0, int(maxBodySize), http.StatusOK}, - {&url.Values{"code": {"100"}}, 0, 10, 100}, - {&url.Values{"code": {"404"}}, 0, 10, 404}, + {&url.Values{"code": {"404"}}, 0, 10, http.StatusNotFound}, {&url.Values{"code": {"599"}}, 0, 10, 599}, {&url.Values{"code": {"567"}}, 0, 10, 567}, @@ -2079,20 +1680,21 @@ func TestDrip(t *testing.T) { test := test t.Run(fmt.Sprintf("ok/%s", test.params.Encode()), func(t *testing.T) { t.Parallel() - url := "/drip?" + test.params.Encode() - start := time.Now() - r, _ := http.NewRequest("GET", url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + url := "/drip?" + test.params.Encode() + start := time.Now() + req := newTestRequest(t, "GET", url) + resp := must.DoReq(t, client, req) + body := must.ReadAll(t, resp.Body) // must read body before measuring elapsed time elapsed := time.Since(start) - assertStatusCode(t, w, test.code) - assertHeader(t, w, "Content-Type", "application/octet-stream") - assertHeader(t, w, "Content-Length", strconv.Itoa(test.numbytes)) - if len(w.Body.Bytes()) != test.numbytes { - t.Fatalf("expected %d bytes, got %d", test.numbytes, len(w.Body.Bytes())) + assert.StatusCode(t, resp, test.code) + assert.Header(t, resp, "Content-Type", "application/octet-stream") + assert.Header(t, resp, "Content-Length", strconv.Itoa(test.numbytes)) + + if len(body) != test.numbytes { + t.Fatalf("expected %d bytes, got %d", test.numbytes, len(body)) } if elapsed < test.duration { @@ -2101,23 +1703,45 @@ func TestDrip(t *testing.T) { }) } - t.Run("writes are actually incremmental", func(t *testing.T) { + t.Run("HTTP 100 Continue status code supported", func(t *testing.T) { + // The stdlib http client automagically handles 100 Continue responses + // by continuing the request until a "final" 200 OK response is + // received, which prevents us from confirming that a 100 Continue + // response is sent when using the http client directly. + // + // So, here we instead manally write the request to the wire and read + // the initial response, which will give us access to the 100 Continue + // indication we need. t.Parallel() - srv := httptest.NewServer(app) - defer srv.Close() + req := newTestRequest(t, "GET", "/drip?code=100") + reqBytes, err := httputil.DumpRequestOut(req, false) + assert.NilError(t, err) + + conn, err := net.Dial("tcp", srv.Listener.Addr().String()) + assert.NilError(t, err) + defer conn.Close() + + n, err := conn.Write(append(reqBytes, []byte("\r\n\r\n")...)) + assert.NilError(t, err) + assert.Equal(t, len(reqBytes)+4, n, "incorrect number of bytes written") + + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + assert.NilError(t, err) + assert.StatusCode(t, resp, 100) + }) + + t.Run("writes are actually incremmental", func(t *testing.T) { + t.Parallel() var ( duration = 100 * time.Millisecond numBytes = 3 wantDelay = duration / time.Duration(numBytes) - wantBytes = []byte{'*'} + endpoint = fmt.Sprintf("/drip?duration=%s&delay=%s&numbytes=%d", duration, wantDelay, numBytes) ) - resp, err := http.Get(srv.URL + fmt.Sprintf("/drip?duration=%s&delay=%s&numbytes=%d", duration, wantDelay, numBytes)) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - defer resp.Body.Close() + req := newTestRequest(t, "GET", endpoint) + resp := must.DoReq(t, client, req) // Here we read from the response one byte at a time, and ensure that // at least the expected delay occurs for each read. @@ -2125,7 +1749,7 @@ func TestDrip(t *testing.T) { // The request above includes an initial delay equal to the expected // wait between writes so that even the first iteration of this loop // expects to wait the same amount of time for a read. - buf := make([]byte, 1) + buf := make([]byte, 1024) for { start := time.Now() n, err := resp.Body.Read(buf) @@ -2134,12 +1758,10 @@ func TestDrip(t *testing.T) { if err == io.EOF { break } - assertNil(t, err) - assertIntEqual(t, n, 1) - if !reflect.DeepEqual(buf, wantBytes) { - t.Fatalf("unexpected bytes read: got %v, want %v", buf, wantBytes) - } + assert.NilError(t, err) + assert.Equal(t, n, 1, "incorrect number of bytes read") + assert.DeepEqual(t, buf[:n], []byte{'*'}, "unexpected bytes read") if gotDelay < wantDelay { t.Fatalf("to wait at least %s between reads, waited %s", wantDelay, gotDelay) } @@ -2148,8 +1770,6 @@ func TestDrip(t *testing.T) { t.Run("handle cancelation during initial delay", func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(app) - defer srv.Close() // For this test, we expect the client to time out and cancel the // request after 10ms. The handler should immediately write a 200 OK @@ -2158,46 +1778,44 @@ func TestDrip(t *testing.T) { // // So, we're testing that a) the client got an immediate 200 OK but // that b) the response body was empty. - client := http.Client{ - Timeout: time.Duration(10 * time.Millisecond), - } - resp, err := client.Get(srv.URL + "/drip?duration=500ms&delay=500ms") - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - defer resp.Body.Close() + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() - body, _ := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("error reading response body: %s", err) - } + req := newTestRequest(t, "GET", "/drip?duration=500ms&delay=500ms") + req = req.WithContext(ctx) + + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) - if len(body) != 0 { + body, err := io.ReadAll(resp.Body) + if !os.IsTimeout(err) { + t.Fatalf("expected client timeout while reading body, bot %s", err) + } + if len(body) > 0 { t.Fatalf("expected client timeout before body was written, got body %q", string(body)) } }) t.Run("handle cancelation during drip", func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(app) - defer srv.Close() - client := http.Client{ - Timeout: time.Duration(250 * time.Millisecond), - } - resp, err := client.Get(srv.URL + "/drip?duration=900ms&delay=100ms") - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + + req := newTestRequest(t, "GET", "/drip?duration=900ms&delay=100ms") + req = req.WithContext(ctx) + + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) // in this case, the timeout happens while trying to read the body body, err := io.ReadAll(resp.Body) - if err == nil { - t.Fatal("expected timeout reading body") + if !os.IsTimeout(err) { + t.Fatalf("expected timeout reading body, got %s", err) } // but we should have received a partial response - assertBytesEqual(t, body, []byte("**")) + assert.DeepEqual(t, body, []byte("**"), "incorrect partial body") }) badTests := []struct { @@ -2234,137 +1852,125 @@ func TestDrip(t *testing.T) { test := test t.Run(fmt.Sprintf("bad/%s", test.params.Encode()), func(t *testing.T) { t.Parallel() - url := "/drip?" + test.params.Encode() - r, _ := http.NewRequest("GET", url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.code) + url := "/drip?" + test.params.Encode() + req := newTestRequest(t, "GET", url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.code) }) } t.Run("ensure HEAD request works with streaming responses", func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(app) - defer srv.Close() - resp, err := http.Head(srv.URL + "/drip?duration=900ms&delay=100ms") - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("error reading response body: %s", err) - } + req := newTestRequest(t, "HEAD", "/drip?duration=900ms&delay=100ms") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected HTTP 200 OK rsponse, got %d", resp.StatusCode) - } + body := must.ReadAll(t, resp.Body) if bodySize := len(body); bodySize > 0 { - t.Fatalf("expected empty body from HEAD request, bot: %s", string(body)) + t.Fatalf("expected empty body from HEAD request, got: %s", string(body)) } }) } func TestRange(t *testing.T) { - t.Parallel() t.Run("ok_no_range", func(t *testing.T) { t.Parallel() + wantBytes := maxBodySize - 1 url := fmt.Sprintf("/range/%d", wantBytes) - r, _ := http.NewRequest("GET", url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req := newTestRequest(t, "GET", url) - assertStatusCode(t, w, http.StatusOK) - assertHeader(t, w, "ETag", fmt.Sprintf("range%d", wantBytes)) - assertHeader(t, w, "Accept-Ranges", "bytes") - assertHeader(t, w, "Content-Length", strconv.Itoa(int(wantBytes))) - assertContentType(t, w, "text/plain; charset=utf-8") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + assert.Header(t, resp, "ETag", fmt.Sprintf("range%d", wantBytes)) + assert.Header(t, resp, "Accept-Ranges", "bytes") + assert.Header(t, resp, "Content-Length", strconv.Itoa(int(wantBytes))) + assert.ContentType(t, resp, "text/plain; charset=utf-8") - if len(w.Body.String()) != int(wantBytes) { - t.Errorf("expected content length %d, got %d", wantBytes, len(w.Body.String())) + body := must.ReadAll(t, resp.Body) + if len(body) != int(wantBytes) { + t.Errorf("expected content length %d, got %d", wantBytes, len(body)) } }) t.Run("ok_range", func(t *testing.T) { t.Parallel() + url := "/range/100" - r, _ := http.NewRequest("GET", url, nil) - r.Header.Add("Range", "bytes=10-24") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusPartialContent) - assertHeader(t, w, "ETag", "range100") - assertHeader(t, w, "Accept-Ranges", "bytes") - assertHeader(t, w, "Content-Length", "15") - assertHeader(t, w, "Content-Range", "bytes 10-24/100") - assertBodyEquals(t, w, "klmnopqrstuvwxy") + req := newTestRequest(t, "GET", url) + req.Header.Add("Range", "bytes=10-24") + + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusPartialContent) + assert.Header(t, resp, "ETag", "range100") + assert.Header(t, resp, "Accept-Ranges", "bytes") + assert.Header(t, resp, "Content-Length", "15") + assert.Header(t, resp, "Content-Range", "bytes 10-24/100") + assert.BodyEquals(t, resp, "klmnopqrstuvwxy") }) t.Run("ok_range_first_16_bytes", func(t *testing.T) { t.Parallel() + url := "/range/1000" - r, _ := http.NewRequest("GET", url, nil) - r.Header.Add("Range", "bytes=0-15") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusPartialContent) - assertHeader(t, w, "ETag", "range1000") - assertHeader(t, w, "Accept-Ranges", "bytes") - assertHeader(t, w, "Content-Length", "16") - assertHeader(t, w, "Content-Range", "bytes 0-15/1000") - assertBodyEquals(t, w, "abcdefghijklmnop") + req := newTestRequest(t, "GET", url) + req.Header.Add("Range", "bytes=0-15") + + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusPartialContent) + assert.Header(t, resp, "ETag", "range1000") + assert.Header(t, resp, "Accept-Ranges", "bytes") + assert.Header(t, resp, "Content-Length", "16") + assert.Header(t, resp, "Content-Range", "bytes 0-15/1000") + assert.BodyEquals(t, resp, "abcdefghijklmnop") }) t.Run("ok_range_open_ended_last_6_bytes", func(t *testing.T) { t.Parallel() - url := "/range/26" - r, _ := http.NewRequest("GET", url, nil) - r.Header.Add("Range", "bytes=20-") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusPartialContent) - assertHeader(t, w, "ETag", "range26") - assertHeader(t, w, "Content-Length", "6") - assertHeader(t, w, "Content-Range", "bytes 20-25/26") - assertBodyEquals(t, w, "uvwxyz") + url := "/range/26" + req := newTestRequest(t, "GET", url) + req.Header.Add("Range", "bytes=20-") + + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusPartialContent) + assert.Header(t, resp, "ETag", "range26") + assert.Header(t, resp, "Content-Length", "6") + assert.Header(t, resp, "Content-Range", "bytes 20-25/26") + assert.BodyEquals(t, resp, "uvwxyz") }) t.Run("ok_range_suffix", func(t *testing.T) { t.Parallel() + url := "/range/26" - r, _ := http.NewRequest("GET", url, nil) - r.Header.Add("Range", "bytes=-5") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - t.Logf("headers = %v", w.Header()) - assertStatusCode(t, w, http.StatusPartialContent) - assertHeader(t, w, "ETag", "range26") - assertHeader(t, w, "Content-Length", "5") - assertHeader(t, w, "Content-Range", "bytes 21-25/26") - assertBodyEquals(t, w, "vwxyz") + req := newTestRequest(t, "GET", url) + req.Header.Add("Range", "bytes=-5") + + resp := must.DoReq(t, client, req) + t.Logf("headers = %v", resp.Header) + assert.StatusCode(t, resp, http.StatusPartialContent) + assert.Header(t, resp, "ETag", "range26") + assert.Header(t, resp, "Content-Length", "5") + assert.Header(t, resp, "Content-Range", "bytes 21-25/26") + assert.BodyEquals(t, resp, "vwxyz") }) t.Run("err_range_out_of_bounds", func(t *testing.T) { t.Parallel() - url := "/range/26" - r, _ := http.NewRequest("GET", url, nil) - r.Header.Add("Range", "bytes=-5") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusPartialContent) - assertHeader(t, w, "ETag", "range26") - assertHeader(t, w, "Content-Length", "5") - assertHeader(t, w, "Content-Range", "bytes 21-25/26") - assertBodyEquals(t, w, "vwxyz") + url := "/range/26" + req := newTestRequest(t, "GET", url) + req.Header.Add("Range", "bytes=-5") + + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusPartialContent) + assert.Header(t, resp, "ETag", "range26") + assert.Header(t, resp, "Content-Length", "5") + assert.Header(t, resp, "Content-Range", "bytes 21-25/26") + assert.BodyEquals(t, resp, "vwxyz") }) // Note: httpbin rejects these requests with invalid range headers, but the @@ -2381,11 +1987,11 @@ func TestRange(t *testing.T) { test := test t.Run(fmt.Sprintf("ok_bad_range_header/%s", test.rangeHeader), func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertBodyEquals(t, w, "abcdefghijklmnopqrstuvwxyz") + + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + assert.BodyEquals(t, resp, "abcdefghijklmnopqrstuvwxyz") }) } @@ -2405,71 +2011,52 @@ func TestRange(t *testing.T) { test := test t.Run("bad"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.code) + + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.code) }) } } func TestHTML(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/html", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertContentType(t, w, htmlContentType) - assertBodyContains(t, w, `<h1>Herman Melville - Moby-Dick</h1>`) + req := newTestRequest(t, "GET", "/html") + resp := must.DoReq(t, client, req) + assert.ContentType(t, resp, htmlContentType) + assert.BodyContains(t, resp, `<h1>Herman Melville - Moby-Dick</h1>`) } func TestRobots(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/robots.txt", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertContentType(t, w, "text/plain") - assertBodyContains(t, w, `Disallow: /deny`) + req := newTestRequest(t, "GET", "/robots.txt") + resp := must.DoReq(t, client, req) + assert.ContentType(t, resp, "text/plain") + assert.BodyContains(t, resp, `Disallow: /deny`) } func TestDeny(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/deny", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertContentType(t, w, "text/plain") - assertBodyContains(t, w, `YOU SHOULDN'T BE HERE`) + req := newTestRequest(t, "GET", "/deny") + resp := must.DoReq(t, client, req) + assert.ContentType(t, resp, "text/plain") + assert.BodyContains(t, resp, `YOU SHOULDN'T BE HERE`) } func TestCache(t *testing.T) { - t.Parallel() t.Run("ok_no_cache", func(t *testing.T) { t.Parallel() - url := "/cache" - r, _ := http.NewRequest("GET", url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) + url := "/cache" + req := newTestRequest(t, "GET", url) + resp := must.DoReq(t, client, req) - lastModified := w.Header().Get("Last-Modified") + _ = mustParseResponse[noBodyResponse](t, resp) + lastModified := resp.Header.Get("Last-Modified") if lastModified == "" { - t.Fatalf("did get Last-Modified header") - } - - etag := w.Header().Get("ETag") - if etag != sha1hash(lastModified) { - t.Fatalf("expected ETag header %v, got %v", sha1hash(lastModified), etag) - } - - var resp *noBodyResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) + t.Fatalf("expected Last-Modified header") } + assert.Header(t, resp, "ETag", sha1hash(lastModified)) }) tests := []struct { @@ -2483,27 +2070,25 @@ func TestCache(t *testing.T) { test := test t.Run(fmt.Sprintf("ok_cache/%s", test.headerKey), func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/cache", nil) - r.Header.Add(test.headerKey, test.headerVal) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusNotModified) + + req := newTestRequest(t, "GET", "/cache") + req.Header.Add(test.headerKey, test.headerVal) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusNotModified) }) } } func TestCacheControl(t *testing.T) { - t.Parallel() t.Run("ok_cache_control", func(t *testing.T) { t.Parallel() - url := "/cache/60" - r, _ := http.NewRequest("GET", url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) - assertHeader(t, w, "Cache-Control", "public, max-age=60") + url := "/cache/60" + req := newTestRequest(t, "GET", url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + assert.ContentType(t, resp, jsonContentType) + assert.Header(t, resp, "Cache-Control", "public, max-age=60") }) badTests := []struct { @@ -2518,24 +2103,23 @@ func TestCacheControl(t *testing.T) { test := test t.Run("bad"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.expectedStatus) + + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) }) } } func TestETag(t *testing.T) { - t.Parallel() t.Run("ok_no_headers", func(t *testing.T) { t.Parallel() + url := "/etag/abc" - r, _ := http.NewRequest("GET", url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertHeader(t, w, "ETag", `"abc"`) + req := newTestRequest(t, "GET", url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + assert.Header(t, resp, "ETag", `"abc"`) }) tests := []struct { @@ -2560,12 +2144,12 @@ func TestETag(t *testing.T) { test := test t.Run("ok_"+test.name, func(t *testing.T) { t.Parallel() + url := "/etag/" + test.etag - r, _ := http.NewRequest("GET", url, nil) - r.Header.Add(test.headerKey, test.headerVal) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.expectedStatus) + req := newTestRequest(t, "GET", url) + req.Header.Add(test.headerKey, test.headerVal) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) }) } @@ -2579,45 +2163,42 @@ func TestETag(t *testing.T) { test := test t.Run(fmt.Sprintf("bad/%s", test.url), func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.expectedStatus) + + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) }) } } func TestBytes(t *testing.T) { - t.Parallel() t.Run("ok_no_seed", func(t *testing.T) { t.Parallel() + url := "/bytes/1024" - r, _ := http.NewRequest("GET", url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req := newTestRequest(t, "GET", url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + assert.ContentType(t, resp, "application/octet-stream") - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, "application/octet-stream") - if len(w.Body.String()) != 1024 { - t.Errorf("expected content length 1024, got %d", len(w.Body.String())) + body := must.ReadAll(t, resp.Body) + if len(body) != 1024 { + t.Errorf("expected content length 1024, got %d", len(body)) } }) t.Run("ok_seed", func(t *testing.T) { t.Parallel() + url := "/bytes/16?seed=1234567890" - r, _ := http.NewRequest("GET", url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) + req := newTestRequest(t, "GET", url) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, "application/octet-stream") + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + assert.ContentType(t, resp, "application/octet-stream") - bodyHex := fmt.Sprintf("%x", w.Body.Bytes()) - wantHex := "bfcd2afa15a2b372c707985a22024a8e" - if bodyHex != wantHex { - t.Errorf("expected body in hexadecimal = %v, got %v", wantHex, bodyHex) - } + want := "\xbf\xcd*\xfa\x15\xa2\xb3r\xc7\a\x98Z\"\x02J\x8e" + assert.BodyEquals(t, resp, want) }) edgeCaseTests := []struct { @@ -2635,13 +2216,17 @@ func TestBytes(t *testing.T) { test := test t.Run("edge"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertHeader(t, w, "Content-Length", fmt.Sprintf("%d", test.expectedContentLength)) - if len(w.Body.Bytes()) != test.expectedContentLength { - t.Errorf("expected body of length %d, got %d", test.expectedContentLength, len(w.Body.Bytes())) + + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + t.Logf("status: %q", resp.Status) + t.Logf("headers: %v", resp.Header) + assert.Header(t, resp, "Content-Length", strconv.Itoa(test.expectedContentLength)) + + bodyLen := len(must.ReadAll(t, resp.Body)) + if bodyLen != test.expectedContentLength { + t.Errorf("expected body of length %d, got %d", test.expectedContentLength, bodyLen) } }) } @@ -2666,16 +2251,15 @@ func TestBytes(t *testing.T) { test := test t.Run("bad"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.expectedStatus) + + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) }) } } func TestStreamBytes(t *testing.T) { - t.Parallel() okTests := []struct { url string expectedContentLength int @@ -2696,24 +2280,15 @@ func TestStreamBytes(t *testing.T) { t.Run("ok"+test.url, func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(app) - defer srv.Close() - - resp, err := http.Get(srv.URL + test.url) - assertNil(t, err) - defer resp.Body.Close() - - if len(resp.TransferEncoding) != 1 || resp.TransferEncoding[0] != "chunked" { - t.Fatalf("expected Transfer-Encoding: chunked, got %#v", resp.TransferEncoding) - } + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) // Expect empty content-length due to streaming response - assertHeader(t, resp, "Content-Length", "") + assert.Header(t, resp, "Content-Length", "") + assert.DeepEqual(t, resp.TransferEncoding, []string{"chunked"}, "incorrect Transfer-Encoding header") - body, err := io.ReadAll(resp.Body) - assertNil(t, err) - if len(body) != test.expectedContentLength { - t.Fatalf("expected body of length %d, got %d", test.expectedContentLength, len(body)) + if bodySize := len(must.ReadAll(t, resp.Body)); bodySize != test.expectedContentLength { + t.Fatalf("expected body of length %d, got %d", test.expectedContentLength, bodySize) } }) } @@ -2735,16 +2310,14 @@ func TestStreamBytes(t *testing.T) { test := test t.Run("bad"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.code) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.code) }) } } func TestLinks(t *testing.T) { - t.Parallel() redirectTests := []struct { url string expectedLocation string @@ -2757,12 +2330,10 @@ func TestLinks(t *testing.T) { test := test t.Run("ok"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusFound) - assertHeader(t, w, "Location", test.expectedLocation) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusFound) + assert.Header(t, resp, "Location", test.expectedLocation) }) } @@ -2786,11 +2357,9 @@ func TestLinks(t *testing.T) { test := test t.Run("error"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, test.expectedStatus) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) }) } @@ -2810,19 +2379,16 @@ func TestLinks(t *testing.T) { test := test t.Run("ok"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, htmlContentType) - assertBodyEquals(t, w, test.expectedContent) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + assert.ContentType(t, resp, htmlContentType) + assert.BodyEquals(t, resp, test.expectedContent) }) } } func TestImage(t *testing.T) { - t.Parallel() acceptTests := []struct { acceptHeader string expectedContentType string @@ -2844,14 +2410,12 @@ func TestImage(t *testing.T) { test := test t.Run("ok/accept="+test.acceptHeader, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/image", nil) - r.Header.Set("Accept", test.acceptHeader) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, test.expectedStatus) + req := newTestRequest(t, "GET", "/image") + req.Header.Set("Accept", test.acceptHeader) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) if test.expectedContentType != "" { - assertContentType(t, w, test.expectedContentType) + assert.ContentType(t, resp, test.expectedContentType) } }) } @@ -2874,30 +2438,27 @@ func TestImage(t *testing.T) { test := test t.Run("error"+test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, test.expectedStatus) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, test.expectedStatus) }) } } func TestXML(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/xml", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertContentType(t, w, "application/xml") - assertBodyContains(t, w, `<?xml version='1.0' encoding='us-ascii'?>`) + req := newTestRequest(t, "GET", "/xml") + resp := must.DoReq(t, client, req) + assert.ContentType(t, resp, "application/xml") + assert.BodyContains(t, resp, `<?xml version='1.0' encoding='us-ascii'?>`) } func isValidUUIDv4(uuid string) error { if len(uuid) != 36 { return fmt.Errorf("uuid length: %d != 36", len(uuid)) } - r := regexp.MustCompile("^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[8|9|a|b][a-f0-9]{3}-[a-f0-9]{12}$") - if !r.MatchString(uuid) { + req := regexp.MustCompile("^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[8|9|a|b][a-f0-9]{3}-[a-f0-9]{12}$") + if !req.MatchString(uuid) { return errors.New("Failed to match against uuidv4 regex") } return nil @@ -2905,27 +2466,15 @@ func isValidUUIDv4(uuid string) error { func TestUUID(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/uuid", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusOK) - - // Test response unmarshalling - var resp *uuidResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } - - // Test if the value is an actual UUID - if err := isValidUUIDv4(resp.UUID); err != nil { - t.Fatalf("Invalid uuid %s: %s", resp.UUID, err) + req := newTestRequest(t, "GET", "/uuid") + resp := must.DoReq(t, client, req) + result := mustParseResponse[uuidResponse](t, resp) + if err := isValidUUIDv4(result.UUID); err != nil { + t.Fatalf("Invalid uuid %s: %s", result.UUID, err) } } func TestBase64(t *testing.T) { - t.Parallel() okTests := []struct { requestURL string want string @@ -2966,12 +2515,11 @@ func TestBase64(t *testing.T) { test := test t.Run("ok"+test.requestURL, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.requestURL, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, "text/plain") - assertBodyEquals(t, w, test.want) + req := newTestRequest(t, "GET", test.requestURL) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusOK) + assert.ContentType(t, resp, "text/plain") + assert.BodyEquals(t, resp, test.want) }) } @@ -2992,11 +2540,7 @@ func TestBase64(t *testing.T) { "decode failed", }, { - "/base64/decode/" + randStringBytes(Base64MaxLen+1), - "Cannot handle input", - }, - { - "/base64/decode/" + randStringBytes(Base64MaxLen+1), + "/base64/decode/" + strings.Repeat("X", Base64MaxLen+1), "Cannot handle input", }, { @@ -3027,66 +2571,52 @@ func TestBase64(t *testing.T) { test := test t.Run("error"+test.requestURL, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.requestURL, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusBadRequest) - assertBodyContains(t, w, test.expectedBodyContains) + req := newTestRequest(t, "GET", test.requestURL) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusBadRequest) + assert.BodyContains(t, resp, test.expectedBodyContains) }) } } func TestDumpRequest(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/dump/request?foo=bar", nil) - r.Host = "test-host" - r.Header.Set("x-test-header2", "Test-Value2") - r.Header.Set("x-test-header1", "Test-Value1") - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertContentType(t, w, "text/plain; charset=utf-8") - assertBodyEquals(t, w, "GET /dump/request?foo=bar HTTP/1.1\r\nHost: test-host\r\nX-Test-Header1: Test-Value1\r\nX-Test-Header2: Test-Value2\r\n\r\n") + req := newTestRequest(t, "GET", "/dump/request?foo=bar") + req.Host = "test-host" + req.Header.Set("x-test-header2", "Test-Value2") + req.Header.Set("x-test-header1", "Test-Value1") + + resp := must.DoReq(t, client, req) + assert.ContentType(t, resp, "text/plain; charset=utf-8") + assert.BodyEquals(t, resp, "GET /dump/request?foo=bar HTTP/1.1\r\nHost: test-host\r\nAccept-Encoding: gzip\r\nUser-Agent: Go-http-client/1.1\r\nX-Test-Header1: Test-Value1\r\nX-Test-Header2: Test-Value2\r\n\r\n") } func TestJSON(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", "/json", nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertContentType(t, w, jsonContentType) - assertBodyContains(t, w, `Wake up to WonderWidgets!`) + req := newTestRequest(t, "GET", "/json") + resp := must.DoReq(t, client, req) + assert.ContentType(t, resp, jsonContentType) + assert.BodyContains(t, resp, `Wake up to WonderWidgets!`) } func TestBearer(t *testing.T) { - t.Parallel() requestURL := "/bearer" t.Run("valid_token", func(t *testing.T) { t.Parallel() - token := "valid_token" - r, _ := http.NewRequest("GET", requestURL, nil) - r.Header.Set("Authorization", "Bearer "+token) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusOK) - var resp *bearerResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) - } + token := "valid_token" + req := newTestRequest(t, "GET", requestURL) + req.Header.Set("Authorization", "Bearer "+token) - if resp.Authenticated != true { - t.Fatalf("expected response key %s=%#v, got %#v", - "Authenticated", true, resp.Authenticated) - } - if resp.Token != token { - t.Fatalf("expected response key %s=%#v, got %#v", - "token", token, resp.Authenticated) + resp := must.DoReq(t, client, req) + result := mustParseResponse[bearerResponse](t, resp) + want := bearerResponse{ + Authenticated: true, + Token: token, } + assert.DeepEqual(t, result, want, "auth response mismatch") }) errorTests := []struct { @@ -3115,20 +2645,19 @@ func TestBearer(t *testing.T) { test := test t.Run("error"+test.authorizationHeader, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", requestURL, nil) + + req := newTestRequest(t, "GET", requestURL) if test.authorizationHeader != "" { - r.Header.Set("Authorization", test.authorizationHeader) + req.Header.Set("Authorization", test.authorizationHeader) } - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertHeader(t, w, "WWW-Authenticate", "Bearer") - assertStatusCode(t, w, http.StatusUnauthorized) + resp := must.DoReq(t, client, req) + assert.Header(t, resp, "WWW-Authenticate", "Bearer") + assert.StatusCode(t, resp, http.StatusUnauthorized) }) } } func TestNotImplemented(t *testing.T) { - t.Parallel() tests := []struct { url string }{ @@ -3138,53 +2667,68 @@ func TestNotImplemented(t *testing.T) { test := test t.Run(test.url, func(t *testing.T) { t.Parallel() - r, _ := http.NewRequest("GET", test.url, nil) - w := httptest.NewRecorder() - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusNotImplemented) + req := newTestRequest(t, "GET", test.url) + resp := must.DoReq(t, client, req) + assert.StatusCode(t, resp, http.StatusNotImplemented) }) } } func TestHostname(t *testing.T) { - t.Parallel() - loadResponse := func(t *testing.T, bodyBytes []byte) hostnameResponse { - var resp hostnameResponse - err := json.Unmarshal(bodyBytes, &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %q from JSON: %s", string(bodyBytes), err) - } - return resp - } - t.Run("default hostname", func(t *testing.T) { t.Parallel() - var ( - app = New() - r, _ = http.NewRequest("GET", "/hostname", nil) - w = httptest.NewRecorder() - ) - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - resp := loadResponse(t, w.Body.Bytes()) - if resp.Hostname != DefaultHostname { - t.Errorf("expected hostname %q, got %q", DefaultHostname, resp.Hostname) - } + req := newTestRequest(t, "GET", "/hostname") + resp := must.DoReq(t, client, req) + result := mustParseResponse[hostnameResponse](t, resp) + assert.Equal(t, result.Hostname, DefaultHostname, "hostname mismatch") }) t.Run("real hostname", func(t *testing.T) { t.Parallel() - var ( - realHostname = "real-hostname" - app = New(WithHostname(realHostname)) - r, _ = http.NewRequest("GET", "/hostname", nil) - w = httptest.NewRecorder() - ) - app.ServeHTTP(w, r) - assertStatusCode(t, w, http.StatusOK) - resp := loadResponse(t, w.Body.Bytes()) - if resp.Hostname != realHostname { - t.Errorf("expected hostname %q, got %q", realHostname, resp.Hostname) - } + + realHostname := "real-hostname" + app := New(WithHostname(realHostname)) + srv, client := newTestServer(app) + defer srv.Close() + + req, err := http.NewRequest("GET", srv.URL+"/hostname", nil) + assert.NilError(t, err) + + resp, err := client.Do(req) + assert.NilError(t, err) + + result := mustParseResponse[hostnameResponse](t, resp) + assert.Equal(t, result.Hostname, realHostname, "hostname mismatch") }) } + +func newTestServer(handler http.Handler) (*httptest.Server, *http.Client) { + srv := httptest.NewServer(handler) + client := srv.Client() + client.Timeout = 5 * time.Second + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + return srv, client +} + +func newTestRequest(t *testing.T, verb, path string) *http.Request { + t.Helper() + return newTestRequestWithBody(t, verb, path, nil) +} + +func newTestRequestWithBody(t *testing.T, verb, path string, body io.Reader) *http.Request { + t.Helper() + req, err := http.NewRequest(verb, srv.URL+path, body) + if err != nil { + t.Fatalf("failed to create request: %s", err) + } + return req +} + +func mustParseResponse[T any](t *testing.T, resp *http.Response) T { + t.Helper() + assert.StatusCode(t, resp, http.StatusOK) + assert.ContentType(t, resp, jsonContentType) + return must.Unmarshal[T](t, resp.Body) +} diff --git a/httpbin/helpers.go b/httpbin/helpers.go index 989bb3c08911b5c05170546cfa9c96a6074de3ad..9eca14532bc02422a79ef85be7a318d4db144263 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -143,10 +143,6 @@ func parseFiles(fileHeaders map[string][]*multipart.FileHeader) (map[string][]st // Note: this function expects callers to limit the the maximum size of the // request body. See, e.g., the limitRequestSize middleware. func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse) error { - if r.Body == nil { - return nil - } - // Always set resp.Data to the incoming request body, in case we don't know // how to handle the content type body, err := io.ReadAll(r.Body) diff --git a/httpbin/helpers_test.go b/httpbin/helpers_test.go index 1f5cd43e2a1957c8eeba25c8f3fe080af60f2e7a..a3f7bce40226a31d147d10c283b9ae91adae60db 100644 --- a/httpbin/helpers_test.go +++ b/httpbin/helpers_test.go @@ -8,42 +8,11 @@ import ( "mime/multipart" "net/http" "net/url" - "reflect" "testing" "time" -) - -func assertNil(t *testing.T, v interface{}) { - t.Helper() - if v != nil { - t.Fatalf("expected nil, got %#v", v) - } -} - -func assertNilError(t *testing.T, err error) { - t.Helper() - if err != nil { - t.Fatalf("expected nil error, got %s (%T)", err, err) - } -} - -func assertIntEqual(t *testing.T, a, b int) { - if a != b { - t.Errorf("expected %v == %v", a, b) - } -} - -func assertBytesEqual(t *testing.T, a, b []byte) { - if !reflect.DeepEqual(a, b) { - t.Errorf("expected %v == %v", a, b) - } -} -func assertError(t *testing.T, got, expected error) { - if got != expected { - t.Errorf("expected error %v, got %v", expected, got) - } -} + "github.com/mccutchen/go-httpbin/v2/internal/testing/assert" +) func mustParse(s string) *url.URL { u, e := url.Parse(s) @@ -54,7 +23,7 @@ func mustParse(s string) *url.URL { } func TestGetURL(t *testing.T) { - baseUrl, _ := url.Parse("http://example.com/something?foo=bar") + baseURL := mustParse("http://example.com/something?foo=bar") tests := []struct { name string input *http.Request @@ -63,7 +32,7 @@ func TestGetURL(t *testing.T) { { "basic test", &http.Request{ - URL: baseUrl, + URL: baseURL, Header: http.Header{}, }, mustParse("http://example.com/something?foo=bar"), @@ -71,7 +40,7 @@ func TestGetURL(t *testing.T) { { "if TLS is not nil, scheme is https", &http.Request{ - URL: baseUrl, + URL: baseURL, TLS: &tls.ConnectionState{}, Header: http.Header{}, }, @@ -80,7 +49,7 @@ func TestGetURL(t *testing.T) { { "if X-Forwarded-Proto is present, scheme is that value", &http.Request{ - URL: baseUrl, + URL: baseURL, Header: http.Header{"X-Forwarded-Proto": {"https"}}, }, mustParse("https://example.com/something?foo=bar"), @@ -88,7 +57,7 @@ func TestGetURL(t *testing.T) { { "if X-Forwarded-Proto is present, scheme is that value (2)", &http.Request{ - URL: baseUrl, + URL: baseURL, Header: http.Header{"X-Forwarded-Proto": {"bananas"}}, }, mustParse("bananas://example.com/something?foo=bar"), @@ -96,7 +65,7 @@ func TestGetURL(t *testing.T) { { "if X-Forwarded-Ssl is 'on', scheme is https", &http.Request{ - URL: baseUrl, + URL: baseURL, Header: http.Header{"X-Forwarded-Ssl": {"on"}}, }, mustParse("https://example.com/something?foo=bar"), @@ -114,9 +83,7 @@ func TestGetURL(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { res := getURL(test.input) - if res.String() != test.expected.String() { - t.Fatalf("expected %s, got %s", test.expected, res) - } + assert.Equal(t, res.String(), test.expected.String(), "URL mismatch") }) } } @@ -143,12 +110,8 @@ func TestParseDuration(t *testing.T) { t.Run(fmt.Sprintf("ok/%s", test.input), func(t *testing.T) { t.Parallel() result, err := parseDuration(test.input) - if err != nil { - t.Fatalf("unexpected error parsing duration %v: %s", test.input, err) - } - if result != test.expected { - t.Fatalf("expected %s, got %s", test.expected, result) - } + assert.NilError(t, err) + assert.Equal(t, result, test.expected, "incorrect duration") }) } @@ -186,23 +149,23 @@ func TestSyntheticByteStream(t *testing.T) { // read first half p := make([]byte, 5) count, err := s.Read(p) - assertNil(t, err) - assertIntEqual(t, count, 5) - assertBytesEqual(t, p, []byte{0, 1, 2, 3, 4}) + assert.NilError(t, err) + assert.Equal(t, count, 5, "incorrect number of bytes read") + assert.DeepEqual(t, p, []byte{0, 1, 2, 3, 4}, "incorrect bytes read") // read second half p = make([]byte, 5) count, err = s.Read(p) - assertError(t, err, io.EOF) - assertIntEqual(t, count, 5) - assertBytesEqual(t, p, []byte{5, 6, 7, 8, 9}) + assert.Error(t, err, io.EOF) + assert.Equal(t, count, 5, "incorrect number of bytes read") + assert.DeepEqual(t, p, []byte{5, 6, 7, 8, 9}, "incorrect bytes read") // can't read any more p = make([]byte, 5) count, err = s.Read(p) - assertError(t, err, io.EOF) - assertIntEqual(t, count, 0) - assertBytesEqual(t, p, []byte{0, 0, 0, 0, 0}) + assert.Error(t, err, io.EOF) + assert.Equal(t, count, 0, "incorrect number of bytes read") + assert.DeepEqual(t, p, []byte{0, 0, 0, 0, 0}, "incorrect bytes read") }) t.Run("read into too-large buffer", func(t *testing.T) { @@ -210,9 +173,9 @@ func TestSyntheticByteStream(t *testing.T) { s := newSyntheticByteStream(5, factory) p := make([]byte, 10) count, err := s.Read(p) - assertError(t, err, io.EOF) - assertIntEqual(t, count, 5) - assertBytesEqual(t, p, []byte{0, 1, 2, 3, 4, 0, 0, 0, 0, 0}) + assert.Error(t, err, io.EOF) + assert.Equal(t, count, 5, "incorrect number of bytes read") + assert.DeepEqual(t, p, []byte{0, 1, 2, 3, 4, 0, 0, 0, 0, 0}, "incorrect bytes read") }) t.Run("seek", func(t *testing.T) { @@ -222,37 +185,31 @@ func TestSyntheticByteStream(t *testing.T) { p := make([]byte, 5) s.Seek(10, io.SeekStart) count, err := s.Read(p) - assertNil(t, err) - assertIntEqual(t, count, 5) - assertBytesEqual(t, p, []byte{10, 11, 12, 13, 14}) + assert.NilError(t, err) + assert.Equal(t, count, 5, "incorrect number of bytes read") + assert.DeepEqual(t, p, []byte{10, 11, 12, 13, 14}, "incorrect bytes read") s.Seek(10, io.SeekCurrent) count, err = s.Read(p) - assertNil(t, err) - assertIntEqual(t, count, 5) - assertBytesEqual(t, p, []byte{25, 26, 27, 28, 29}) + assert.NilError(t, err) + assert.Equal(t, count, 5, "incorrect number of bytes read") + assert.DeepEqual(t, p, []byte{25, 26, 27, 28, 29}, "incorrect bytes read") s.Seek(10, io.SeekEnd) count, err = s.Read(p) - assertNil(t, err) - assertIntEqual(t, count, 5) - assertBytesEqual(t, p, []byte{90, 91, 92, 93, 94}) + assert.NilError(t, err) + assert.Equal(t, count, 5, "incorrect number of bytes read") + assert.DeepEqual(t, p, []byte{90, 91, 92, 93, 94}, "incorrect bytes read") - // invalid whence _, err = s.Seek(10, 666) - if err.Error() != "Seek: invalid whence" { - t.Errorf("Expected \"Seek: invalid whence\", got %#v", err.Error()) - } + assert.Equal(t, err.Error(), "Seek: invalid whence", "incorrect error for invalid whence") - // invalid offset _, err = s.Seek(-10, io.SeekStart) - if err.Error() != "Seek: invalid offset" { - t.Errorf("Expected \"Seek: invalid offset\", got %#v", err.Error()) - } + assert.Equal(t, err.Error(), "Seek: invalid offset", "incorrect error for invalid offset") }) } -func Test_getClientIP(t *testing.T) { +func TestGetClientIP(t *testing.T) { t.Parallel() makeHeaders := func(m map[string]string) http.Header { @@ -297,9 +254,7 @@ func Test_getClientIP(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() - if got := getClientIP(tc.given); got != tc.want { - t.Errorf("getClientIP() = %v, want %v", got, tc.want) - } + assert.Equal(t, getClientIP(tc.given), tc.want, "incorrect client ip") }) } } diff --git a/httpbin/responses.go b/httpbin/responses.go index ad08fcadb154e2c81a0f5ecf7be94c3886f19436..f8daade6819034ebb350eb457a8b40b3d8dc7f46 100644 --- a/httpbin/responses.go +++ b/httpbin/responses.go @@ -44,10 +44,10 @@ type bodyResponse struct { Origin string `json:"origin"` URL string `json:"url"` - Data string `json:"data"` - Files map[string][]string `json:"files"` - Form map[string][]string `json:"form"` - JSON interface{} `json:"json"` + Data string `json:"data"` + Files url.Values `json:"files"` + Form url.Values `json:"form"` + JSON interface{} `json:"json"` } type cookiesResponse map[string]string diff --git a/internal/testing/assert/assert.go b/internal/testing/assert/assert.go new file mode 100644 index 0000000000000000000000000000000000000000..112be3aebad3432184fdd7b234d318e383469c96 --- /dev/null +++ b/internal/testing/assert/assert.go @@ -0,0 +1,91 @@ +package assert + +import ( + "fmt" + "net/http" + "reflect" + "strings" + "testing" + + "github.com/mccutchen/go-httpbin/v2/internal/testing/must" +) + +// Equal asserts that two values are equal. +func Equal[T comparable](t *testing.T, want, got T, msg string, arg ...any) { + t.Helper() + if want != got { + if msg == "" { + msg = "expected values to match" + } + msg = fmt.Sprintf(msg, arg...) + t.Fatalf("%s:\nwant: %#v\n got: %#v", msg, want, got) + } +} + +// DeepEqual asserts that two values are deeply equal. +func DeepEqual[T any](t *testing.T, want, got T, msg string, arg ...any) { + t.Helper() + if !reflect.DeepEqual(want, got) { + if msg == "" { + msg = "expected values to match" + } + msg = fmt.Sprintf(msg, arg...) + t.Fatalf("%s:\nwant: %#v\n got: %#v", msg, want, got) + } +} + +// NilError asserts that an error is nil. +func NilError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("expected nil error, got %s (%T)", err, err) + } +} + +// Error asserts that an error is not nil. +func Error(t *testing.T, got, expected error) { + t.Helper() + if got != expected { + t.Fatalf("expected error %v, got %v", expected, got) + } +} + +// StatusCode asserts that a response has a specific status code. +func StatusCode(t *testing.T, resp *http.Response, code int) { + t.Helper() + if resp.StatusCode != code { + t.Fatalf("expected status code %d, got %d", code, resp.StatusCode) + } +} + +// Header asserts that a header key has a specific value in a response. +func Header(t *testing.T, resp *http.Response, key, want string) { + t.Helper() + got := resp.Header.Get(key) + if want != got { + t.Fatalf("expected header %s=%#v, got %#v", key, want, got) + } +} + +// ContentType asserts that a response has a specific Content-Type header +// value. +func ContentType(t *testing.T, resp *http.Response, contentType string) { + t.Helper() + Header(t, resp, "Content-Type", contentType) +} + +// BodyContains asserts that a response body contains a specific substring. +func BodyContains(t *testing.T, resp *http.Response, needle string) { + t.Helper() + body := must.ReadAll(t, resp.Body) + if !strings.Contains(body, needle) { + t.Fatalf("expected string %q in body %q", needle, body) + } +} + +// BodyEquals asserts that a response body is equal to a specific string. +func BodyEquals(t *testing.T, resp *http.Response, want string) { + t.Helper() + got := must.ReadAll(t, resp.Body) + Equal(t, got, want, "incorrect response body") +} diff --git a/internal/testing/must/must.go b/internal/testing/must/must.go new file mode 100644 index 0000000000000000000000000000000000000000..8737cd9d5058f3ecd9e77f22386241493456a58c --- /dev/null +++ b/internal/testing/must/must.go @@ -0,0 +1,46 @@ +package must + +import ( + "encoding/json" + "io" + "net/http" + "testing" + "time" +) + +// DoReq makes an HTTP request and fails the test if there is an error. +func DoReq(t *testing.T, client *http.Client, req *http.Request) *http.Response { + t.Helper() + start := time.Now() + resp, err := client.Do(req) + if err != nil { + t.Fatalf("error making HTTP request: %s %s: %s", req.Method, req.URL, err) + } + t.Logf("HTTP request: %s %s => %s (%s)", req.Method, req.URL, resp.Status, time.Since(start)) + return resp +} + +// ReadAll reads all bytes from an io.Reader and fails the test if there is an +// error. +func ReadAll(t *testing.T, r io.Reader) string { + t.Helper() + body, err := io.ReadAll(r) + if err != nil { + t.Fatalf("error reading: %s", err) + } + if rc, ok := r.(io.ReadCloser); ok { + rc.Close() + } + return string(body) +} + +// Unmarshal unmarshals JSON from an io.Reader into a value and fails the test +// if there is an error. +func Unmarshal[T any](t *testing.T, r io.Reader) T { + t.Helper() + var v T + if err := json.NewDecoder(r).Decode(&v); err != nil { + t.Fatal(err) + } + return v +}