Skip to content
Snippets Groups Projects
Unverified Commit c9f4177d authored by Will McCutchen's avatar Will McCutchen Committed by GitHub
Browse files

Consistently parse and validate user-provided status codes (#137)

In testing out error handling after #135, I happened to stumble across
an unexpected panic for requests like `/status/1024` where the
user-provided status code is outside the legal bounds. So, here we take
a quick pass to ensure we're parsing and validating status codes the
same way everywhere.
parent 9e306405
No related branches found
No related tags found
No related merge requests found
......@@ -254,9 +254,9 @@ func (h *HTTPBin) Status(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusNotFound, nil)
return
}
code, err := strconv.Atoi(parts[2])
code, err := parseStatusCode(parts[2])
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid status %q: %w", parts[2], err))
writeError(w, http.StatusBadRequest, err)
return
}
......@@ -297,7 +297,7 @@ func (h *HTTPBin) Unstable(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %w", err))
return
} else if failureRate < 0 || failureRate > 1 {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %d not in interval [0, 1]", err))
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %d not in range [0, 1]", err))
return
}
}
......@@ -414,14 +414,10 @@ func (h *HTTPBin) RedirectTo(w http.ResponseWriter, r *http.Request) {
}
statusCode := http.StatusFound
rawStatusCode := q.Get("status_code")
if rawStatusCode != "" {
statusCode, err = strconv.Atoi(q.Get("status_code"))
if userStatusCode := q.Get("status_code"); userStatusCode != "" {
statusCode, err = parseBoundedStatusCode(userStatusCode, 300, 399)
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid status code: %w", err))
return
} else if statusCode < 300 || statusCode > 399 {
writeError(w, http.StatusBadRequest, errors.New("invalid status code: must be in range [300, 399]"))
writeError(w, http.StatusBadRequest, err)
return
}
}
......@@ -617,18 +613,15 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %w", err))
return
} else if numBytes < 1 || numBytes > h.MaxBodySize {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %d not in interval [1, %d]", numBytes, h.MaxBodySize))
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %d not in range [1, %d]", numBytes, h.MaxBodySize))
return
}
}
if userCode := q.Get("code"); userCode != "" {
code, err = strconv.Atoi(userCode)
code, err = parseStatusCode(userCode)
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid code: %w", err))
return
} else if code < 100 || code >= 600 {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid code: %d not in interval [100, 599]", code))
writeError(w, http.StatusBadRequest, err)
return
}
}
......@@ -713,7 +706,7 @@ func (h *HTTPBin) Range(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Accept-Ranges", "bytes")
if numBytes <= 0 || numBytes > h.MaxBodySize {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %d not in interval [1, %d]", numBytes, h.MaxBodySize))
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %d not in range [1, %d]", numBytes, h.MaxBodySize))
return
}
......
......@@ -799,6 +799,7 @@ func TestStatus(t *testing.T) {
headers map[string]string
body string
}{
// 100 is tested as a special case below
{200, nil, ""},
{300, map[string]string{"Location": "/image/jpeg"}, `<!doctype html>
<head>
......@@ -822,6 +823,8 @@ func TestStatus(t *testing.T) {
</html>`},
{401, unauthorizedHeaders, ""},
{418, nil, "I'm a teapot!"},
{500, nil, ""}, // maximum allowed status code
{599, nil, ""}, // maximum allowed status code
}
for _, test := range tests {
......@@ -848,6 +851,8 @@ func TestStatus(t *testing.T) {
{"/status/200/foo", http.StatusNotFound},
{"/status/3.14", http.StatusBadRequest},
{"/status/foo", http.StatusBadRequest},
{"/status/600", http.StatusBadRequest},
{"/status/1024", http.StatusBadRequest},
}
for _, test := range errorTests {
......@@ -860,6 +865,34 @@ func TestStatus(t *testing.T) {
assert.StatusCode(t, resp, test.status)
})
}
t.Run("HTTP 100 Continue status code supported", func(t *testing.T) {
// The stdlib http client automagically handles 100 Continue responses
// by continuing the request until a "final" 200 OK response is
// received, which prevents us from confirming that a 100 Continue
// response is sent when using the http client directly.
//
// So, here we instead manally write the request to the wire and read
// the initial response, which will give us access to the 100 Continue
// indication we need.
t.Parallel()
conn, err := net.Dial("tcp", srv.Listener.Addr().String())
assert.NilError(t, err)
defer conn.Close()
req := newTestRequest(t, "GET", "/status/100")
reqBytes, err := httputil.DumpRequestOut(req, false)
assert.NilError(t, err)
n, err := conn.Write(reqBytes)
assert.NilError(t, err)
assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written")
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
assert.NilError(t, err)
assert.StatusCode(t, resp, http.StatusContinue)
})
}
func TestUnstable(t *testing.T) {
......
......@@ -234,6 +234,21 @@ func encodeData(body []byte, contentType string) string {
return string("data:" + contentType + ";base64," + data)
}
func parseStatusCode(input string) (int, error) {
return parseBoundedStatusCode(input, 100, 599)
}
func parseBoundedStatusCode(input string, min, max int) (int, error) {
code, err := strconv.Atoi(input)
if err != nil {
return 0, fmt.Errorf("invalid status code: %q: %w", input, err)
}
if code < min || code > max {
return 0, fmt.Errorf("invalid status code: %d not in range [%d, %d]", code, min, max)
}
return code, nil
}
// parseDuration takes a user's input as a string and attempts to convert it
// into a time.Duration. If not given as a go-style duration string, the input
// is assumed to be seconds as a float.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment