From 9e306405584c63cf444f8eb1291e4641c2a8043e Mon Sep 17 00:00:00 2001 From: Will McCutchen <will@mccutch.org> Date: Sat, 1 Jul 2023 13:14:16 -0400 Subject: [PATCH] Improve and standardize error handling (#135) Standardize on structured JSON error responses everywhere we can, with only one exception where the error is a warning for humans to read. Fixes #108 by adding a check to every request in the test suite to ensure that errors are never served with a Content-Type that might enable XSS or other vulnerabilities. While we're at it, continue refining the test suite and further adopting some of the testing helpers added in #131. --- httpbin/handlers.go | 199 ++++++++++++++++-------------- httpbin/handlers_test.go | 166 ++++++++----------------- httpbin/helpers.go | 14 ++- httpbin/httpbin.go | 1 + httpbin/middleware.go | 6 + httpbin/middleware_test.go | 39 ++++++ httpbin/options.go | 15 ++- httpbin/responses.go | 6 + internal/testing/assert/assert.go | 18 +++ 9 files changed, 254 insertions(+), 210 deletions(-) create mode 100644 httpbin/middleware_test.go diff --git a/httpbin/handlers.go b/httpbin/handlers.go index 0e63ceb..d812ba0 100644 --- a/httpbin/handlers.go +++ b/httpbin/handlers.go @@ -5,11 +5,11 @@ import ( "compress/gzip" "compress/zlib" "encoding/json" + "errors" "fmt" "net/http" "net/http/httputil" "net/url" - "sort" "strconv" "strings" "time" @@ -20,14 +20,13 @@ import ( var nilValues = url.Values{} func notImplementedHandler(w http.ResponseWriter, r *http.Request) { - http.Error(w, "Not implemented", http.StatusNotImplemented) + writeError(w, http.StatusNotImplemented, nil) } // Index renders an HTML index page func (h *HTTPBin) Index(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - msg := fmt.Sprintf("Not Found (go-httpbin does not handle the path %s)", r.URL.Path) - http.Error(w, msg, http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } w.Header().Set("Content-Security-Policy", "default-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' camo.githubusercontent.com") @@ -81,9 +80,8 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) { URL: getURL(r).String(), } - err := parseBody(w, r, resp) - if err != nil { - http.Error(w, fmt.Sprintf("error parsing request body: %s", err), http.StatusBadRequest) + if err := parseBody(w, r, resp); err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("error parsing request body: %w", err)) return } @@ -253,15 +251,19 @@ var ( func (h *HTTPBin) Status(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 3 { - http.Error(w, "Not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } code, err := strconv.Atoi(parts[2]) if err != nil { - http.Error(w, "Invalid status", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid status %q: %w", parts[2], err)) return } + // default to plain text content type, which may be overriden by headers + // for special cases + w.Header().Set("Content-Type", textContentType) + if specialCase, ok := statusSpecialCases[code]; ok { for key, val := range specialCase.headers { w.Header().Set(key, val) @@ -283,29 +285,29 @@ func (h *HTTPBin) Unstable(w http.ResponseWriter, r *http.Request) { // rng/seed rng, err := parseSeed(r.URL.Query().Get("seed")) if err != nil { - http.Error(w, "invalid seed", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid seed: %w", err)) return } // failure_rate - var failureRate float64 - rawFailureRate := r.URL.Query().Get("failure_rate") - if rawFailureRate != "" { + failureRate := 0.5 + if rawFailureRate := r.URL.Query().Get("failure_rate"); rawFailureRate != "" { failureRate, err = strconv.ParseFloat(rawFailureRate, 64) - if err != nil || failureRate < 0 || failureRate > 1 { - http.Error(w, "invalid failure_rate", http.StatusBadRequest) + if err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %w", err)) + return + } else if failureRate < 0 || failureRate > 1 { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %d not in interval [0, 1]", err)) return } - } else { - failureRate = 0.5 } - var status int + status := http.StatusOK if rng.Float64() < failureRate { status = http.StatusInternalServerError - } else { - status = http.StatusOK } + + w.Header().Set("Content-Type", textContentType) w.WriteHeader(status) } @@ -350,12 +352,15 @@ func redirectLocation(r *http.Request, relative bool, n int) string { func doRedirect(w http.ResponseWriter, r *http.Request, relative bool) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 3 { - http.Error(w, "Not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } n, err := strconv.Atoi(parts[2]) - if err != nil || n < 1 { - http.Error(w, "Invalid redirect", http.StatusBadRequest) + if err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid redirect count: %w", err)) + return + } else if n < 1 { + writeError(w, http.StatusBadRequest, errors.New("redirect count must be > 0")) return } @@ -389,29 +394,21 @@ func (h *HTTPBin) RedirectTo(w http.ResponseWriter, r *http.Request) { inputURL := q.Get("url") if inputURL == "" { - http.Error(w, "Missing URL", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, errors.New("missing required query parameter: url")) return } u, err := url.Parse(inputURL) if err != nil { - http.Error(w, "Invalid URL", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid url: %w", err)) return } if u.IsAbs() && len(h.AllowedRedirectDomains) > 0 { if _, ok := h.AllowedRedirectDomains[u.Hostname()]; !ok { - domainListItems := make([]string, 0, len(h.AllowedRedirectDomains)) - for domain := range h.AllowedRedirectDomains { - domainListItems = append(domainListItems, fmt.Sprintf("- %s", domain)) - } - sort.Strings(domainListItems) - formattedDomains := strings.Join(domainListItems, "\n") - msg := fmt.Sprintf(`Forbidden redirect URL. Please be careful with this link. - -Allowed redirect destinations: -%s`, formattedDomains) - http.Error(w, msg, http.StatusForbidden) + // for this error message we do not use our standard JSON response + // because we want it to be more obviously human readable. + writeResponse(w, http.StatusForbidden, "text/plain", []byte(h.forbiddenRedirectError)) return } } @@ -420,8 +417,11 @@ Allowed redirect destinations: rawStatusCode := q.Get("status_code") if rawStatusCode != "" { statusCode, err = strconv.Atoi(q.Get("status_code")) - if err != nil || statusCode < 300 || statusCode > 399 { - http.Error(w, "Invalid status code", http.StatusBadRequest) + if err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid status code: %w", err)) + return + } else if statusCode < 300 || statusCode > 399 { + writeError(w, http.StatusBadRequest, errors.New("invalid status code: must be in range [300, 399]")) return } } @@ -475,7 +475,7 @@ func (h *HTTPBin) DeleteCookies(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) BasicAuth(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 4 { - http.Error(w, "Not Found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } expectedUser := parts[2] @@ -501,7 +501,7 @@ func (h *HTTPBin) BasicAuth(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) HiddenBasicAuth(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 4 { - http.Error(w, "Not Found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } expectedUser := parts[2] @@ -511,7 +511,7 @@ func (h *HTTPBin) HiddenBasicAuth(w http.ResponseWriter, r *http.Request) { authorized := givenUser == expectedUser && givenPass == expectedPass if !authorized { - http.Error(w, "Not Found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } @@ -525,12 +525,12 @@ func (h *HTTPBin) HiddenBasicAuth(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 3 { - http.Error(w, "Not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } n, err := strconv.Atoi(parts[2]) if err != nil { - http.Error(w, "Invalid integer", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %w", err)) return } @@ -562,13 +562,13 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Delay(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 3 { - http.Error(w, "Not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } delay, err := parseBoundedDuration(parts[2], 0, h.MaxDuration) if err != nil { - http.Error(w, "Invalid duration", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid duration: %w", err)) return } @@ -598,7 +598,7 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) { if userDuration := q.Get("duration"); userDuration != "" { duration, err = parseBoundedDuration(userDuration, 0, h.MaxDuration) if err != nil { - http.Error(w, "Invalid duration", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid duration: %w", err)) return } } @@ -606,23 +606,29 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) { if userDelay := q.Get("delay"); userDelay != "" { delay, err = parseBoundedDuration(userDelay, 0, h.MaxDuration) if err != nil { - http.Error(w, "Invalid delay", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid delay: %w", err)) return } } if userNumBytes := q.Get("numbytes"); userNumBytes != "" { numBytes, err = strconv.ParseInt(userNumBytes, 10, 64) - if err != nil || numBytes <= 0 || numBytes > h.MaxBodySize { - http.Error(w, "Invalid numbytes", http.StatusBadRequest) + if err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %w", err)) + return + } else if numBytes < 1 || numBytes > h.MaxBodySize { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %d not in interval [1, %d]", numBytes, h.MaxBodySize)) return } } if userCode := q.Get("code"); userCode != "" { code, err = strconv.Atoi(userCode) - if err != nil || code < 100 || code >= 600 { - http.Error(w, "Invalid code", http.StatusBadRequest) + if err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid code: %w", err)) + return + } else if code < 100 || code >= 600 { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid code: %d not in interval [100, 599]", code)) return } } @@ -693,13 +699,13 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Range(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 3 { - http.Error(w, "Not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } numBytes, err := strconv.ParseInt(parts[2], 10, 64) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %w", err)) return } @@ -707,7 +713,7 @@ func (h *HTTPBin) Range(w http.ResponseWriter, r *http.Request) { w.Header().Add("Accept-Ranges", "bytes") if numBytes <= 0 || numBytes > h.MaxBodySize { - http.Error(w, "Invalid number of bytes", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %d not in interval [1, %d]", numBytes, h.MaxBodySize)) return } @@ -754,13 +760,13 @@ func (h *HTTPBin) Cache(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) CacheControl(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 3 { - http.Error(w, "Not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } seconds, err := strconv.ParseInt(parts[2], 10, 64) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid seconds: %w", err)) return } @@ -773,12 +779,13 @@ func (h *HTTPBin) CacheControl(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 3 { - http.Error(w, "Not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } etag := parts[2] w.Header().Set("ETag", fmt.Sprintf(`"%s"`, etag)) + w.Header().Set("Content-Type", textContentType) var buf bytes.Buffer mustMarshalJSON(&buf, noBodyResponse{ @@ -811,18 +818,25 @@ func (h *HTTPBin) StreamBytes(w http.ResponseWriter, r *http.Request) { func handleBytes(w http.ResponseWriter, r *http.Request, streaming bool) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 3 { - http.Error(w, "Not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } numBytes, err := strconv.Atoi(parts[2]) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid byte count: %w", err)) + return + } + + // rng/seed + rng, err := parseSeed(r.URL.Query().Get("seed")) + if err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid seed: %w", err)) return } if numBytes < 0 { - http.Error(w, "Bad Request", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid byte count: %d must be greater than 0", numBytes)) return } @@ -845,7 +859,7 @@ func handleBytes(w http.ResponseWriter, r *http.Request, streaming bool) { if r.URL.Query().Get("chunk_size") != "" { chunkSize, err = strconv.Atoi(r.URL.Query().Get("chunk_size")) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid chunk_size: %w", err)) return } } else { @@ -868,13 +882,6 @@ func handleBytes(w http.ResponseWriter, r *http.Request, streaming bool) { } } - // rng/seed - rng, err := parseSeed(r.URL.Query().Get("seed")) - if err != nil { - http.Error(w, "invalid seed", http.StatusBadRequest) - return - } - w.Header().Set("Content-Type", binaryContentType) w.WriteHeader(http.StatusOK) @@ -895,13 +902,16 @@ func handleBytes(w http.ResponseWriter, r *http.Request, streaming bool) { func (h *HTTPBin) Links(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 3 && len(parts) != 4 { - http.Error(w, "Not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } n, err := strconv.Atoi(parts[2]) - if err != nil || n < 0 || n > 256 { - http.Error(w, "Invalid link count", http.StatusBadRequest) + if err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid link count: %w", err)) + return + } else if n < 0 || n > 256 { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid link count: %d must be in range [0, 256]", n)) return } @@ -909,7 +919,7 @@ func (h *HTTPBin) Links(w http.ResponseWriter, r *http.Request) { if len(parts) == 4 { offset, err := strconv.Atoi(parts[3]) if err != nil { - http.Error(w, "Invalid offset", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid offset: %w", err)) return } doLinksPage(w, r, n, offset) @@ -941,16 +951,21 @@ func doLinksPage(w http.ResponseWriter, r *http.Request, n int, offset int) { // ImageAccept responds with an appropriate image based on the Accept header func (h *HTTPBin) ImageAccept(w http.ResponseWriter, r *http.Request) { accept := r.Header.Get("Accept") - if accept == "" || strings.Contains(accept, "image/png") || strings.Contains(accept, "image/*") { + switch { + case accept == "": + fallthrough // default to png + case strings.Contains(accept, "image/*"): + fallthrough // default to png + case strings.Contains(accept, "image/png"): doImage(w, "png") - } else if strings.Contains(accept, "image/webp") { + case strings.Contains(accept, "image/webp"): doImage(w, "webp") - } else if strings.Contains(accept, "image/svg+xml") { + case strings.Contains(accept, "image/svg+xml"): doImage(w, "svg") - } else if strings.Contains(accept, "image/jpeg") { + case strings.Contains(accept, "image/jpeg"): doImage(w, "jpeg") - } else { - http.Error(w, "Unsupported media type", http.StatusUnsupportedMediaType) + default: + writeError(w, http.StatusUnsupportedMediaType, nil) } } @@ -958,7 +973,7 @@ func (h *HTTPBin) ImageAccept(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Image(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.Path, "/") if len(parts) != 3 { - http.Error(w, "Not found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } doImage(w, parts[2]) @@ -969,7 +984,7 @@ func (h *HTTPBin) Image(w http.ResponseWriter, r *http.Request) { func doImage(w http.ResponseWriter, kind string) { img, err := staticAsset("image." + kind) if err != nil { - http.Error(w, "Not Found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } contentType := "image/" + kind @@ -994,7 +1009,7 @@ func (h *HTTPBin) DigestAuth(w http.ResponseWriter, r *http.Request) { count := len(parts) if count != 5 && count != 6 { - http.Error(w, "Not Found", http.StatusNotFound) + writeError(w, http.StatusNotFound, nil) return } @@ -1008,11 +1023,11 @@ func (h *HTTPBin) DigestAuth(w http.ResponseWriter, r *http.Request) { } if qop != "auth" { - http.Error(w, "Invalid QOP directive", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid QOP directive: %q != \"auth\"", qop)) return } if algoName != "MD5" && algoName != "SHA-256" { - http.Error(w, "Invalid algorithm", http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid algorithm: %s must be one of MD5 or SHA-256", algoName)) return } @@ -1023,7 +1038,7 @@ func (h *HTTPBin) DigestAuth(w http.ResponseWriter, r *http.Request) { if !digest.Check(r, user, password) { w.Header().Set("WWW-Authenticate", digest.Challenge("go-httpbin", algorithm)) - w.WriteHeader(http.StatusUnauthorized) + writeError(w, http.StatusUnauthorized, nil) return } @@ -1044,21 +1059,19 @@ func (h *HTTPBin) UUID(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Base64(w http.ResponseWriter, r *http.Request) { b, err := newBase64Helper(r.URL.Path) if err != nil { - http.Error(w, fmt.Sprintf("%s", err), http.StatusBadRequest) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid base64 data: %w", err)) return } var result []byte - var base64Error error - if b.operation == "decode" { - result, base64Error = b.Decode() + result, err = b.Decode() } else { - result, base64Error = b.Encode() + result, err = b.Encode() } - if base64Error != nil { - http.Error(w, fmt.Sprintf("%s failed: %s", b.operation, base64Error), http.StatusBadRequest) + if err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("%s failed: %w", b.operation, err)) return } writeResponse(w, http.StatusOK, textContentType, result) @@ -1072,7 +1085,7 @@ func (h *HTTPBin) Base64(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) DumpRequest(w http.ResponseWriter, r *http.Request) { dump, err := httputil.DumpRequest(r, true) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, fmt.Errorf("failed to dump request: %w", err)) return } w.Write(dump) @@ -1091,7 +1104,7 @@ func (h *HTTPBin) Bearer(w http.ResponseWriter, r *http.Request) { tokenFields := strings.Fields(reqToken) if len(tokenFields) != 2 || tokenFields[0] != "Bearer" { w.Header().Set("WWW-Authenticate", "Bearer") - w.WriteHeader(http.StatusUnauthorized) + writeError(w, http.StatusUnauthorized, nil) return } writeJSON(http.StatusOK, w, bearerResponse{ diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index 6a49c1d..ef37998 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -8,7 +8,6 @@ import ( "context" "encoding/base64" "encoding/json" - "errors" "fmt" "io" "log" @@ -45,6 +44,9 @@ var ( ) func TestMain(m *testing.M) { + // enable additional safety checks + testMode = true + app = New( WithAllowedRedirectDomains([]string{ "httpbingo.org", @@ -82,7 +84,13 @@ func TestIndex(t *testing.T) { req := newTestRequest(t, "GET", "/foo") resp := must.DoReq(t, client, req) assert.StatusCode(t, resp, http.StatusNotFound) - assert.BodyContains(t, resp, "/foo") + assert.ContentType(t, resp, jsonContentType) + got := must.Unmarshal[errorRespnose](t, resp.Body) + want := errorRespnose{ + StatusCode: http.StatusNotFound, + Error: "Not Found", + } + assert.DeepEqual(t, got, want, "incorrect error response") }) } @@ -107,7 +115,7 @@ func TestUTF8(t *testing.T) { } func TestGet(t *testing.T) { - doGetRequest := func(t *testing.T, path string, params url.Values, headers *http.Header) noBodyResponse { + doGetRequest := func(t *testing.T, path string, params url.Values, headers http.Header) noBodyResponse { t.Helper() if params != nil { @@ -115,11 +123,9 @@ func TestGet(t *testing.T) { } req := newTestRequest(t, "GET", path) req.Header.Set("User-Agent", "test") - if headers != nil { - for k, vs := range *headers { - for _, v := range vs { - req.Header.Set(k, v) - } + for k, vs := range headers { + for _, v := range vs { + req.Header.Add(k, v) } } @@ -183,7 +189,7 @@ func TestGet(t *testing.T) { test := test t.Run(test.key, func(t *testing.T) { t.Parallel() - headers := &http.Header{} + headers := http.Header{} headers.Set(test.key, test.value) result := doGetRequest(t, "/get", nil, headers) if !strings.HasPrefix(result.URL, "https://") { @@ -598,12 +604,9 @@ func testRequestWithBodyMultiPartBody(t *testing.T, verb, path string) { for k, vs := range params { for _, v := range vs { fw, err := mw.CreateFormField(k) - if err != nil { - t.Fatalf("error creating multipart form field %s: %s", k, err) - } - if _, err := fw.Write([]byte(v)); err != nil { - t.Fatalf("error writing multipart form value %#v for key %s: %s", v, k, err) - } + assert.NilError(t, err) + _, err = fw.Write([]byte(v)) + assert.NilError(t, err) } } mw.Close() @@ -695,10 +698,7 @@ func testRequestWithBodyJSON(t *testing.T, verb, path string) { roundTrippedInputBytes, err := json.Marshal(result.JSON) assert.NilError(t, err) - var roundTrippedInput testInput - if err := json.Unmarshal(roundTrippedInputBytes, &roundTrippedInput); err != nil { - t.Fatalf("failed to round-trip JSON: coult not re-unmarshal JSON: %s", err) - } + roundTrippedInput := must.Unmarshal[testInput](t, bytes.NewReader(roundTrippedInputBytes)) assert.DeepEqual(t, roundTrippedInput, input, "round-tripped JSON mismatch") } @@ -1016,18 +1016,21 @@ func TestRedirects(t *testing.T) { }{ {"/redirect", http.StatusNotFound}, {"/redirect/", http.StatusBadRequest}, + {"/redirect/-1", http.StatusBadRequest}, {"/redirect/3.14", http.StatusBadRequest}, {"/redirect/foo", http.StatusBadRequest}, {"/redirect/10/foo", http.StatusNotFound}, {"/relative-redirect", http.StatusNotFound}, {"/relative-redirect/", http.StatusBadRequest}, + {"/relative-redirect/-1", http.StatusBadRequest}, {"/relative-redirect/3.14", http.StatusBadRequest}, {"/relative-redirect/foo", http.StatusBadRequest}, {"/relative-redirect/10/foo", http.StatusNotFound}, {"/absolute-redirect", http.StatusNotFound}, {"/absolute-redirect/", http.StatusBadRequest}, + {"/absolute-redirect/-1", http.StatusBadRequest}, {"/absolute-redirect/3.14", http.StatusBadRequest}, {"/absolute-redirect/foo", http.StatusBadRequest}, {"/absolute-redirect/10/foo", http.StatusNotFound}, @@ -1080,6 +1083,7 @@ func TestRedirectTo(t *testing.T) { {"/redirect-to?status_code=302", http.StatusBadRequest}, // missing url {"/redirect-to?url=foo&status_code=201", http.StatusBadRequest}, // invalid status code {"/redirect-to?url=foo&status_code=418", http.StatusBadRequest}, // invalid status code + {"/redirect-to?url=foo&status_code=foo", http.StatusBadRequest}, // invalid status code {"/redirect-to?url=http%3A%2F%2Ffoo%25%25bar&status_code=418", http.StatusBadRequest}, // invalid URL } for _, test := range badTests { @@ -1093,15 +1097,6 @@ func TestRedirectTo(t *testing.T) { }) } - // error message matches redirect configuration in global shared test app - allowedDomainsError := `Forbidden redirect URL. Please be careful with this link. - -Allowed redirect destinations: -- example.org -- httpbingo.org -- www.example.com -` - allowListTests := []struct { url string expectedStatus int @@ -1121,7 +1116,7 @@ Allowed redirect destinations: defer consumeAndCloseBody(resp) assert.StatusCode(t, resp, test.expectedStatus) if test.expectedStatus >= 400 { - assert.BodyEquals(t, resp, allowedDomainsError) + assert.BodyEquals(t, resp, app.forbiddenRedirectError) } }) } @@ -1436,25 +1431,16 @@ func TestGzip(t *testing.T) { } zippedContentLength, err := strconv.Atoi(zippedContentLengthStr) - if err != nil { - t.Fatalf("error converting Content-Lengh %v to integer: %s", zippedContentLengthStr, err) - } + assert.NilError(t, err) gzipReader, err := gzip.NewReader(resp.Body) - if err != nil { - t.Fatalf("error creating gzip reader: %s", err) - } + assert.NilError(t, err) unzippedBody, err := io.ReadAll(gzipReader) - if err != nil { - t.Fatalf("error reading gzipped body: %s", err) - } + assert.NilError(t, err) result := must.Unmarshal[noBodyResponse](t, bytes.NewBuffer(unzippedBody)) - - if result.Gzipped != true { - t.Fatalf("expected resp.Gzipped == true") - } + assert.Equal(t, result.Gzipped, true, "expected resp.Gzipped == true") if len(unzippedBody) <= zippedContentLength { t.Fatalf("expected compressed body") @@ -1477,24 +1463,16 @@ func TestDeflate(t *testing.T) { } compressedContentLength, err := strconv.Atoi(contentLengthHeader) - if err != nil { - t.Fatal(err) - } + assert.NilError(t, err) reader, err := zlib.NewReader(resp.Body) - if err != nil { - t.Fatal(err) - } + assert.NilError(t, err) + body, err := io.ReadAll(reader) - if err != nil { - t.Fatal(err) - } + assert.NilError(t, err) result := must.Unmarshal[noBodyResponse](t, bytes.NewBuffer(body)) - - if result.Deflated != true { - t.Fatalf("expected resp.Deflated == true") - } + assert.Equal(t, result.Deflated, true, "expected result.Deflated == true") if len(body) <= compressedContentLength { t.Fatalf("expected compressed body") @@ -1527,22 +1505,14 @@ func TestStream(t *testing.T) { assert.Header(t, resp, "Content-Length", "") assert.DeepEqual(t, resp.TransferEncoding, []string{"chunked"}, "expected Transfer-Encoding: chunked") - var sr *streamResponse - i := 0 scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { - if err := json.Unmarshal(scanner.Bytes(), &sr); err != nil { - t.Fatalf("error unmarshalling response: %s", err) - } - if sr.ID != i { - t.Fatalf("bad id: %v != %v", sr.ID, i) - } + sr := must.Unmarshal[streamResponse](t, bytes.NewReader(scanner.Bytes())) + assert.Equal(t, sr.ID, i, "bad id") i++ } - if err := scanner.Err(); err != nil { - t.Fatalf("error scanning streaming response: %s", err) - } + assert.NilError(t, scanner.Err()) }) } @@ -1627,9 +1597,7 @@ func TestDelay(t *testing.T) { w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, "GET", "/delay/1s", nil) app.ServeHTTP(w, req) - if w.Code != 499 { - t.Errorf("expected 499, got %d", w.Code) - } + assert.Equal(t, w.Code, 499, "incorrect status code") }) badTests := []struct { @@ -1706,12 +1674,15 @@ func TestDrip(t *testing.T) { req := newTestRequest(t, "GET", url) resp := must.DoReq(t, client, req) defer consumeAndCloseBody(resp) - body := must.ReadAll(t, resp.Body) // must read body before measuring elapsed time + assert.BodySize(t, resp, test.numbytes) // must read body before measuring elapsed time elapsed := time.Since(start) assert.StatusCode(t, resp, test.code) assert.ContentType(t, resp, binaryContentType) assert.Header(t, resp, "Content-Length", strconv.Itoa(test.numbytes)) + if elapsed < test.duration { + t.Fatalf("expected minimum duration of %s, request took %s", test.duration, elapsed) + } // Note: while the /drip endpoint seems like an ideal use case for // using chunked transfer encoding to stream data to the client, it @@ -1719,14 +1690,6 @@ func TestDrip(t *testing.T) { // server and client, so it is important to ensure that it writes a // "regular," un-chunked response. assert.DeepEqual(t, resp.TransferEncoding, nil, "unexpected Transfer-Encoding header") - - if len(body) != test.numbytes { - t.Fatalf("expected %d bytes, got %d", test.numbytes, len(body)) - } - - if elapsed < test.duration { - t.Fatalf("expected minimum duration of %s, request took %s", test.duration, elapsed) - } }) } @@ -1897,15 +1860,10 @@ func TestDrip(t *testing.T) { t.Run("ensure HEAD request works with streaming responses", func(t *testing.T) { t.Parallel() - req := newTestRequest(t, "HEAD", "/drip?duration=900ms&delay=100ms") resp := must.DoReq(t, client, req) assert.StatusCode(t, resp, http.StatusOK) - - body := must.ReadAll(t, resp.Body) - if bodySize := len(body); bodySize > 0 { - t.Fatalf("expected empty body from HEAD request, got: %s", string(body)) - } + assert.BodySize(t, resp, 0) }) } @@ -1923,11 +1881,7 @@ func TestRange(t *testing.T) { assert.Header(t, resp, "Accept-Ranges", "bytes") assert.Header(t, resp, "Content-Length", strconv.Itoa(int(wantBytes))) assert.ContentType(t, resp, textContentType) - - body := must.ReadAll(t, resp.Body) - if len(body) != int(wantBytes) { - t.Errorf("expected content length %d, got %d", wantBytes, len(body)) - } + assert.BodySize(t, resp, int(wantBytes)) }) t.Run("ok_range", func(t *testing.T) { @@ -2215,11 +2169,7 @@ func TestBytes(t *testing.T) { resp := must.DoReq(t, client, req) assert.StatusCode(t, resp, http.StatusOK) assert.ContentType(t, resp, binaryContentType) - - body := must.ReadAll(t, resp.Body) - if len(body) != 1024 { - t.Errorf("expected content length 1024, got %d", len(body)) - } + assert.BodySize(t, resp, 1024) }) t.Run("ok_seed", func(t *testing.T) { @@ -2258,10 +2208,7 @@ func TestBytes(t *testing.T) { assert.StatusCode(t, resp, http.StatusOK) assert.Header(t, resp, "Content-Length", strconv.Itoa(test.expectedContentLength)) - bodyLen := len(must.ReadAll(t, resp.Body)) - if bodyLen != test.expectedContentLength { - t.Errorf("expected body of length %d, got %d", test.expectedContentLength, bodyLen) - } + assert.BodySize(t, resp, test.expectedContentLength) }) } @@ -2321,10 +2268,7 @@ func TestStreamBytes(t *testing.T) { // Expect empty content-length due to streaming response assert.Header(t, resp, "Content-Length", "") assert.DeepEqual(t, resp.TransferEncoding, []string{"chunked"}, "incorrect Transfer-Encoding header") - - if bodySize := len(must.ReadAll(t, resp.Body)); bodySize != test.expectedContentLength { - t.Fatalf("expected body of length %d, got %d", test.expectedContentLength, bodySize) - } + assert.BodySize(t, resp, test.expectedContentLength) }) } @@ -2494,15 +2438,13 @@ func TestXML(t *testing.T) { assert.BodyContains(t, resp, `<?xml version='1.0' encoding='us-ascii'?>`) } -func isValidUUIDv4(uuid string) error { - if len(uuid) != 36 { - return fmt.Errorf("uuid length: %d != 36", len(uuid)) - } +func testValidUUIDv4(t *testing.T, uuid string) { + t.Helper() + assert.Equal(t, len(uuid), 36, "incorrect uuid length") req := regexp.MustCompile("^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[8|9|a|b][a-f0-9]{3}-[a-f0-9]{12}$") if !req.MatchString(uuid) { - return errors.New("Failed to match against uuidv4 regex") + t.Fatalf("invalid uuid %q", uuid) } - return nil } func TestUUID(t *testing.T) { @@ -2510,9 +2452,7 @@ func TestUUID(t *testing.T) { req := newTestRequest(t, "GET", "/uuid") resp := must.DoReq(t, client, req) result := mustParseResponse[uuidResponse](t, resp) - if err := isValidUUIDv4(result.UUID); err != nil { - t.Fatalf("Invalid uuid %s: %s", result.UUID, err) - } + testValidUUIDv4(t, result.UUID) } func TestBase64(t *testing.T) { @@ -2765,9 +2705,7 @@ func newTestRequest(t *testing.T, verb, path string) *http.Request { func newTestRequestWithBody(t *testing.T, verb, path string, body io.Reader) *http.Request { t.Helper() req, err := http.NewRequest(verb, srv.URL+path, body) - if err != nil { - t.Fatalf("failed to create request: %s", err) - } + assert.NilError(t, err) return req } diff --git a/httpbin/helpers.go b/httpbin/helpers.go index bca6813..aa41c79 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -114,6 +114,17 @@ func writeHTML(w http.ResponseWriter, body []byte, status int) { writeResponse(w, status, htmlContentType, body) } +func writeError(w http.ResponseWriter, code int, err error) { + resp := errorRespnose{ + Error: http.StatusText(code), + StatusCode: code, + } + if err != nil { + resp.Detail = err.Error() + } + writeJSON(code, w, resp) +} + // parseFiles handles reading the contents of files in a multipart FileHeader // and returning a map that can be used as the Files attribute of a response func parseFiles(fileHeaders map[string][]*multipart.FileHeader) (map[string][]string, error) { @@ -344,8 +355,7 @@ func sha1hash(input string) string { func uuidv4() string { buff := make([]byte, 16) - _, err := crypto_rand.Read(buff[:]) - if err != nil { + if _, err := crypto_rand.Read(buff[:]); err != nil { panic(err) } buff[6] = (buff[6] & 0x0f) | 0x40 // Version 4 diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go index 3f0f4a4..db73b66 100644 --- a/httpbin/httpbin.go +++ b/httpbin/httpbin.go @@ -44,6 +44,7 @@ type HTTPBin struct { // Set of hosts to which the /redirect-to endpoint will allow redirects AllowedRedirectDomains map[string]struct{} + forbiddenRedirectError string // The hostname to expose via /hostname. hostname string diff --git a/httpbin/middleware.go b/httpbin/middleware.go index 0f9e4e9..15d507a 100644 --- a/httpbin/middleware.go +++ b/httpbin/middleware.go @@ -78,6 +78,9 @@ func autohead(h http.Handler) http.Handler { }) } +// testMode enables additional safety checks to be enabled in the test suite. +var testMode = false + // metaResponseWriter implements http.ResponseWriter and http.Flusher in order // to record a response's status code and body size for logging purposes. type metaResponseWriter struct { @@ -93,6 +96,9 @@ func (mw *metaResponseWriter) Write(b []byte) (int, error) { } func (mw *metaResponseWriter) WriteHeader(s int) { + if testMode && mw.status != 0 { + panic(fmt.Errorf("HTTP status already set to %d, cannot set to %d", mw.status, s)) + } mw.w.WriteHeader(s) mw.status = s } diff --git a/httpbin/middleware_test.go b/httpbin/middleware_test.go new file mode 100644 index 0000000..21b8673 --- /dev/null +++ b/httpbin/middleware_test.go @@ -0,0 +1,39 @@ +package httpbin + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/mccutchen/go-httpbin/v2/internal/testing/assert" +) + +func TestTestMode(t *testing.T) { + // This test ensures that we use testMode in our test suite, and ensures + // that it is working as expected. + assert.Equal(t, testMode, true, "expected testMode to be turned on in test suite") + + // We want to ensure that, in testMode, a handler calling WriteHeader twice + // will cause a panic. This happens most often when we forget to return + // early after writing an error response, and has helped identify and fix + // some subtly broken error handling. + observer := func(r Result) {} + handler := observe(observer, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.WriteHeader(http.StatusOK) + })) + + defer func() { + r := recover() + if r == nil { + t.Fatalf("expected to catch panic") + } + err, ok := r.(error) + assert.Equal(t, ok, true, "expected panic to be an error") + assert.Equal(t, err.Error(), "HTTP status already set to 400, cannot set to 200", "incorrectp panic error message") + }() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + handler.ServeHTTP(w, r) +} diff --git a/httpbin/options.go b/httpbin/options.go index 4a202b5..37fb270 100644 --- a/httpbin/options.go +++ b/httpbin/options.go @@ -1,6 +1,11 @@ package httpbin -import "time" +import ( + "fmt" + "sort" + "strings" + "time" +) // OptionFunc uses the "functional options" pattern to customize an HTTPBin // instance @@ -46,9 +51,17 @@ func WithObserver(o Observer) OptionFunc { func WithAllowedRedirectDomains(hosts []string) OptionFunc { return func(h *HTTPBin) { hostSet := make(map[string]struct{}, len(hosts)) + formattedListItems := make([]string, 0, len(hosts)) for _, host := range hosts { hostSet[host] = struct{}{} + formattedListItems = append(formattedListItems, fmt.Sprintf("- %s", host)) } h.AllowedRedirectDomains = hostSet + + sort.Strings(formattedListItems) + h.forbiddenRedirectError = fmt.Sprintf(`Forbidden redirect URL. Please be careful with this link. + +Allowed redirect destinations: +%s`, strings.Join(formattedListItems, "\n")) } } diff --git a/httpbin/responses.go b/httpbin/responses.go index da39f34..63421ed 100644 --- a/httpbin/responses.go +++ b/httpbin/responses.go @@ -81,3 +81,9 @@ type bearerResponse struct { type hostnameResponse struct { Hostname string `json:"hostname"` } + +type errorRespnose struct { + StatusCode int `json:"status_code"` + Error string `json:"error"` + Detail string `json:"detail,omitempty"` +} diff --git a/internal/testing/assert/assert.go b/internal/testing/assert/assert.go index 96efe65..36b5e89 100644 --- a/internal/testing/assert/assert.go +++ b/internal/testing/assert/assert.go @@ -57,6 +57,17 @@ func StatusCode(t *testing.T, resp *http.Response, code int) { if resp.StatusCode != code { t.Fatalf("expected status code %d, got %d", code, resp.StatusCode) } + if resp.StatusCode >= 400 { + // Ensure our error responses are never served as HTML, so that we do + // not need to worry about XSS or other attacks in error responses. + if ct := resp.Header.Get("Content-Type"); !isSafeContentType(ct) { + t.Errorf("HTTP %s error served with dangerous content type: %s", resp.Status, ct) + } + } +} + +func isSafeContentType(ct string) bool { + return strings.HasPrefix(ct, "application/json") || strings.HasPrefix(ct, "text/plain") || strings.HasPrefix(ct, "application/octet-stream") } // Header asserts that a header key has a specific value in a response. @@ -91,6 +102,13 @@ func BodyEquals(t *testing.T, resp *http.Response, want string) { Equal(t, got, want, "incorrect response body") } +// BodySize asserts that a response body is a specific size. +func BodySize(t *testing.T, resp *http.Response, want int) { + t.Helper() + got := must.ReadAll(t, resp.Body) + Equal(t, len(got), want, "incorrect response body size") +} + // DurationRange asserts that a duration is within a specific range. func DurationRange(t *testing.T, got, min, max time.Duration) { t.Helper() -- GitLab