From c9f4177de4999261f7d52b075d547489b6a0d9e8 Mon Sep 17 00:00:00 2001 From: Will McCutchen <will@mccutch.org> Date: Mon, 10 Jul 2023 12:58:36 -0400 Subject: [PATCH] Consistently parse and validate user-provided status codes (#137) In testing out error handling after #135, I happened to stumble across an unexpected panic for requests like `/status/1024` where the user-provided status code is outside the legal bounds. So, here we take a quick pass to ensure we're parsing and validating status codes the same way everywhere. --- httpbin/handlers.go | 27 ++++++++++----------------- httpbin/handlers_test.go | 33 +++++++++++++++++++++++++++++++++ httpbin/helpers.go | 15 +++++++++++++++ 3 files changed, 58 insertions(+), 17 deletions(-) diff --git a/httpbin/handlers.go b/httpbin/handlers.go index d812ba0..f7cfea7 100644 --- a/httpbin/handlers.go +++ b/httpbin/handlers.go @@ -254,9 +254,9 @@ func (h *HTTPBin) Status(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusNotFound, nil) return } - code, err := strconv.Atoi(parts[2]) + code, err := parseStatusCode(parts[2]) if err != nil { - writeError(w, http.StatusBadRequest, fmt.Errorf("invalid status %q: %w", parts[2], err)) + writeError(w, http.StatusBadRequest, err) return } @@ -297,7 +297,7 @@ func (h *HTTPBin) Unstable(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %w", err)) return } else if failureRate < 0 || failureRate > 1 { - writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %d not in interval [0, 1]", err)) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %d not in range [0, 1]", err)) return } } @@ -414,14 +414,10 @@ func (h *HTTPBin) RedirectTo(w http.ResponseWriter, r *http.Request) { } statusCode := http.StatusFound - rawStatusCode := q.Get("status_code") - if rawStatusCode != "" { - statusCode, err = strconv.Atoi(q.Get("status_code")) + if userStatusCode := q.Get("status_code"); userStatusCode != "" { + statusCode, err = parseBoundedStatusCode(userStatusCode, 300, 399) if err != nil { - writeError(w, http.StatusBadRequest, fmt.Errorf("invalid status code: %w", err)) - return - } else if statusCode < 300 || statusCode > 399 { - writeError(w, http.StatusBadRequest, errors.New("invalid status code: must be in range [300, 399]")) + writeError(w, http.StatusBadRequest, err) return } } @@ -617,18 +613,15 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %w", err)) return } else if numBytes < 1 || numBytes > h.MaxBodySize { - writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %d not in interval [1, %d]", numBytes, h.MaxBodySize)) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %d not in range [1, %d]", numBytes, h.MaxBodySize)) return } } if userCode := q.Get("code"); userCode != "" { - code, err = strconv.Atoi(userCode) + code, err = parseStatusCode(userCode) if err != nil { - writeError(w, http.StatusBadRequest, fmt.Errorf("invalid code: %w", err)) - return - } else if code < 100 || code >= 600 { - writeError(w, http.StatusBadRequest, fmt.Errorf("invalid code: %d not in interval [100, 599]", code)) + writeError(w, http.StatusBadRequest, err) return } } @@ -713,7 +706,7 @@ func (h *HTTPBin) Range(w http.ResponseWriter, r *http.Request) { w.Header().Add("Accept-Ranges", "bytes") if numBytes <= 0 || numBytes > h.MaxBodySize { - writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %d not in interval [1, %d]", numBytes, h.MaxBodySize)) + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %d not in range [1, %d]", numBytes, h.MaxBodySize)) return } diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index ef37998..56e3b35 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -799,6 +799,7 @@ func TestStatus(t *testing.T) { headers map[string]string body string }{ + // 100 is tested as a special case below {200, nil, ""}, {300, map[string]string{"Location": "/image/jpeg"}, `<!doctype html> <head> @@ -822,6 +823,8 @@ func TestStatus(t *testing.T) { </html>`}, {401, unauthorizedHeaders, ""}, {418, nil, "I'm a teapot!"}, + {500, nil, ""}, // maximum allowed status code + {599, nil, ""}, // maximum allowed status code } for _, test := range tests { @@ -848,6 +851,8 @@ func TestStatus(t *testing.T) { {"/status/200/foo", http.StatusNotFound}, {"/status/3.14", http.StatusBadRequest}, {"/status/foo", http.StatusBadRequest}, + {"/status/600", http.StatusBadRequest}, + {"/status/1024", http.StatusBadRequest}, } for _, test := range errorTests { @@ -860,6 +865,34 @@ func TestStatus(t *testing.T) { assert.StatusCode(t, resp, test.status) }) } + + t.Run("HTTP 100 Continue status code supported", func(t *testing.T) { + // The stdlib http client automagically handles 100 Continue responses + // by continuing the request until a "final" 200 OK response is + // received, which prevents us from confirming that a 100 Continue + // response is sent when using the http client directly. + // + // So, here we instead manally write the request to the wire and read + // the initial response, which will give us access to the 100 Continue + // indication we need. + t.Parallel() + + conn, err := net.Dial("tcp", srv.Listener.Addr().String()) + assert.NilError(t, err) + defer conn.Close() + + req := newTestRequest(t, "GET", "/status/100") + reqBytes, err := httputil.DumpRequestOut(req, false) + assert.NilError(t, err) + + n, err := conn.Write(reqBytes) + assert.NilError(t, err) + assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written") + + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + assert.NilError(t, err) + assert.StatusCode(t, resp, http.StatusContinue) + }) } func TestUnstable(t *testing.T) { diff --git a/httpbin/helpers.go b/httpbin/helpers.go index aa41c79..76db70b 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -234,6 +234,21 @@ func encodeData(body []byte, contentType string) string { return string("data:" + contentType + ";base64," + data) } +func parseStatusCode(input string) (int, error) { + return parseBoundedStatusCode(input, 100, 599) +} + +func parseBoundedStatusCode(input string, min, max int) (int, error) { + code, err := strconv.Atoi(input) + if err != nil { + return 0, fmt.Errorf("invalid status code: %q: %w", input, err) + } + if code < min || code > max { + return 0, fmt.Errorf("invalid status code: %d not in range [%d, %d]", code, min, max) + } + return code, nil +} + // parseDuration takes a user's input as a string and attempts to convert it // into a time.Duration. If not given as a go-style duration string, the input // is assumed to be seconds as a float. -- GitLab