From c9f4177de4999261f7d52b075d547489b6a0d9e8 Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Mon, 10 Jul 2023 12:58:36 -0400
Subject: [PATCH] Consistently parse and validate user-provided status codes
 (#137)

In testing out error handling after #135, I happened to stumble across
an unexpected panic for requests like `/status/1024` where the
user-provided status code is outside the legal bounds. So, here we take
a quick pass to ensure we're parsing and validating status codes the
same way everywhere.
---
 httpbin/handlers.go      | 27 ++++++++++-----------------
 httpbin/handlers_test.go | 33 +++++++++++++++++++++++++++++++++
 httpbin/helpers.go       | 15 +++++++++++++++
 3 files changed, 58 insertions(+), 17 deletions(-)

diff --git a/httpbin/handlers.go b/httpbin/handlers.go
index d812ba0..f7cfea7 100644
--- a/httpbin/handlers.go
+++ b/httpbin/handlers.go
@@ -254,9 +254,9 @@ func (h *HTTPBin) Status(w http.ResponseWriter, r *http.Request) {
 		writeError(w, http.StatusNotFound, nil)
 		return
 	}
-	code, err := strconv.Atoi(parts[2])
+	code, err := parseStatusCode(parts[2])
 	if err != nil {
-		writeError(w, http.StatusBadRequest, fmt.Errorf("invalid status %q: %w", parts[2], err))
+		writeError(w, http.StatusBadRequest, err)
 		return
 	}
 
@@ -297,7 +297,7 @@ func (h *HTTPBin) Unstable(w http.ResponseWriter, r *http.Request) {
 			writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %w", err))
 			return
 		} else if failureRate < 0 || failureRate > 1 {
-			writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %d not in interval [0, 1]", err))
+			writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %d not in range [0, 1]", err))
 			return
 		}
 	}
@@ -414,14 +414,10 @@ func (h *HTTPBin) RedirectTo(w http.ResponseWriter, r *http.Request) {
 	}
 
 	statusCode := http.StatusFound
-	rawStatusCode := q.Get("status_code")
-	if rawStatusCode != "" {
-		statusCode, err = strconv.Atoi(q.Get("status_code"))
+	if userStatusCode := q.Get("status_code"); userStatusCode != "" {
+		statusCode, err = parseBoundedStatusCode(userStatusCode, 300, 399)
 		if err != nil {
-			writeError(w, http.StatusBadRequest, fmt.Errorf("invalid status code: %w", err))
-			return
-		} else if statusCode < 300 || statusCode > 399 {
-			writeError(w, http.StatusBadRequest, errors.New("invalid status code: must be in range [300, 399]"))
+			writeError(w, http.StatusBadRequest, err)
 			return
 		}
 	}
@@ -617,18 +613,15 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) {
 			writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %w", err))
 			return
 		} else if numBytes < 1 || numBytes > h.MaxBodySize {
-			writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %d not in interval [1, %d]", numBytes, h.MaxBodySize))
+			writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %d not in range [1, %d]", numBytes, h.MaxBodySize))
 			return
 		}
 	}
 
 	if userCode := q.Get("code"); userCode != "" {
-		code, err = strconv.Atoi(userCode)
+		code, err = parseStatusCode(userCode)
 		if err != nil {
-			writeError(w, http.StatusBadRequest, fmt.Errorf("invalid code: %w", err))
-			return
-		} else if code < 100 || code >= 600 {
-			writeError(w, http.StatusBadRequest, fmt.Errorf("invalid code: %d not in interval [100, 599]", code))
+			writeError(w, http.StatusBadRequest, err)
 			return
 		}
 	}
@@ -713,7 +706,7 @@ func (h *HTTPBin) Range(w http.ResponseWriter, r *http.Request) {
 	w.Header().Add("Accept-Ranges", "bytes")
 
 	if numBytes <= 0 || numBytes > h.MaxBodySize {
-		writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %d not in interval [1, %d]", numBytes, h.MaxBodySize))
+		writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %d not in range [1, %d]", numBytes, h.MaxBodySize))
 		return
 	}
 
diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go
index ef37998..56e3b35 100644
--- a/httpbin/handlers_test.go
+++ b/httpbin/handlers_test.go
@@ -799,6 +799,7 @@ func TestStatus(t *testing.T) {
 		headers map[string]string
 		body    string
 	}{
+		// 100 is tested as a special case below
 		{200, nil, ""},
 		{300, map[string]string{"Location": "/image/jpeg"}, `<!doctype html>
 <head>
@@ -822,6 +823,8 @@ func TestStatus(t *testing.T) {
 </html>`},
 		{401, unauthorizedHeaders, ""},
 		{418, nil, "I'm a teapot!"},
+		{500, nil, ""}, // maximum allowed status code
+		{599, nil, ""}, // maximum allowed status code
 	}
 
 	for _, test := range tests {
@@ -848,6 +851,8 @@ func TestStatus(t *testing.T) {
 		{"/status/200/foo", http.StatusNotFound},
 		{"/status/3.14", http.StatusBadRequest},
 		{"/status/foo", http.StatusBadRequest},
+		{"/status/600", http.StatusBadRequest},
+		{"/status/1024", http.StatusBadRequest},
 	}
 
 	for _, test := range errorTests {
@@ -860,6 +865,34 @@ func TestStatus(t *testing.T) {
 			assert.StatusCode(t, resp, test.status)
 		})
 	}
+
+	t.Run("HTTP 100 Continue status code supported", func(t *testing.T) {
+		// The stdlib http client automagically handles 100 Continue responses
+		// by continuing the request until a "final" 200 OK response is
+		// received, which prevents us from confirming that a 100 Continue
+		// response is sent when using the http client directly.
+		//
+		// So, here we instead manally write the request to the wire and read
+		// the initial response, which will give us access to the 100 Continue
+		// indication we need.
+		t.Parallel()
+
+		conn, err := net.Dial("tcp", srv.Listener.Addr().String())
+		assert.NilError(t, err)
+		defer conn.Close()
+
+		req := newTestRequest(t, "GET", "/status/100")
+		reqBytes, err := httputil.DumpRequestOut(req, false)
+		assert.NilError(t, err)
+
+		n, err := conn.Write(reqBytes)
+		assert.NilError(t, err)
+		assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written")
+
+		resp, err := http.ReadResponse(bufio.NewReader(conn), req)
+		assert.NilError(t, err)
+		assert.StatusCode(t, resp, http.StatusContinue)
+	})
 }
 
 func TestUnstable(t *testing.T) {
diff --git a/httpbin/helpers.go b/httpbin/helpers.go
index aa41c79..76db70b 100644
--- a/httpbin/helpers.go
+++ b/httpbin/helpers.go
@@ -234,6 +234,21 @@ func encodeData(body []byte, contentType string) string {
 	return string("data:" + contentType + ";base64," + data)
 }
 
+func parseStatusCode(input string) (int, error) {
+	return parseBoundedStatusCode(input, 100, 599)
+}
+
+func parseBoundedStatusCode(input string, min, max int) (int, error) {
+	code, err := strconv.Atoi(input)
+	if err != nil {
+		return 0, fmt.Errorf("invalid status code: %q: %w", input, err)
+	}
+	if code < min || code > max {
+		return 0, fmt.Errorf("invalid status code: %d not in range [%d, %d]", code, min, max)
+	}
+	return code, nil
+}
+
 // parseDuration takes a user's input as a string and attempts to convert it
 // into a time.Duration. If not given as a go-style duration string, the input
 // is assumed to be seconds as a float.
-- 
GitLab