Skip to content
Snippets Groups Projects
Commit 355c0a67 authored by Will McCutchen's avatar Will McCutchen
Browse files

Extract duration parsing into helpers & strictly check bounds

parent bcbf8a52
No related branches found
No related tags found
No related merge requests found
...@@ -439,21 +439,11 @@ func (h *HTTPBin) Delay(w http.ResponseWriter, r *http.Request) { ...@@ -439,21 +439,11 @@ func (h *HTTPBin) Delay(w http.ResponseWriter, r *http.Request) {
return return
} }
delay, err := time.ParseDuration(parts[2]) delay, err := parseBoundedDuration(parts[2], 0, h.options.MaxResponseTime)
if err != nil {
n, err := strconv.ParseFloat(parts[2], 64)
if err != nil { if err != nil {
http.Error(w, "Invalid duration", http.StatusBadRequest) http.Error(w, "Invalid duration", http.StatusBadRequest)
return return
} }
delay = time.Duration(n*1000) * time.Millisecond
}
if delay > h.options.MaxResponseTime {
delay = h.options.MaxResponseTime
} else if delay < 0 {
delay = 0
}
<-time.After(delay) <-time.After(delay)
h.RequestWithBody(w, r) h.RequestWithBody(w, r)
......
...@@ -1236,15 +1236,11 @@ func TestDelay(t *testing.T) { ...@@ -1236,15 +1236,11 @@ func TestDelay(t *testing.T) {
// go-style durations are supported // go-style durations are supported
{"/delay/0ms", 0}, {"/delay/0ms", 0},
{"/delay/500ms", 500 * time.Millisecond}, {"/delay/500ms", 500 * time.Millisecond},
{"/delay/1.5s", maxResponseTime},
// as are floating point seconds // as are floating point seconds
{"/delay/0", 0}, {"/delay/0", 0},
{"/delay/0.5", 500 * time.Millisecond}, {"/delay/0.5", 500 * time.Millisecond},
{"/delay/1", maxResponseTime}, {"/delay/1", maxResponseTime},
{"/delay/1.5", maxResponseTime},
{"/delay/-1", 0},
{"/delay/-3.14", 0},
} }
for _, test := range okTests { for _, test := range okTests {
t.Run("ok"+test.url, func(t *testing.T) { t.Run("ok"+test.url, func(t *testing.T) {
...@@ -1278,6 +1274,12 @@ func TestDelay(t *testing.T) { ...@@ -1278,6 +1274,12 @@ func TestDelay(t *testing.T) {
{"/delay", http.StatusNotFound}, {"/delay", http.StatusNotFound},
{"/delay/foo", http.StatusBadRequest}, {"/delay/foo", http.StatusBadRequest},
{"/delay/1/foo", http.StatusNotFound}, {"/delay/1/foo", http.StatusNotFound},
{"/delay/1.5s", http.StatusBadRequest},
{"/delay/-1ms", http.StatusBadRequest},
{"/delay/1.5", http.StatusBadRequest},
{"/delay/-1", http.StatusBadRequest},
{"/delay/-3.14", http.StatusBadRequest},
} }
for _, test := range badTests { for _, test := range badTests {
......
...@@ -6,7 +6,9 @@ import ( ...@@ -6,7 +6,9 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time"
) )
func getOrigin(r *http.Request) string { func getOrigin(r *http.Request) string {
...@@ -102,3 +104,33 @@ func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse, maxMe ...@@ -102,3 +104,33 @@ func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse, maxMe
} }
return nil return nil
} }
// parseDuration takes a user's input as a string and attempts to convert it
// into a time.Duration
func parseDuration(input string) (time.Duration, error) {
d, err := time.ParseDuration(input)
if err != nil {
n, err := strconv.ParseFloat(input, 64)
if err != nil {
return 0, err
}
d = time.Duration(n*1000) * time.Millisecond
}
return d, nil
}
// parseBoundedDuration parses a time.Duration from user input and ensures that
// it is within a given maximum and minimum time
func parseBoundedDuration(input string, min, max time.Duration) (time.Duration, error) {
d, err := parseDuration(input)
if err != nil {
return 0, err
}
if d > max {
err = fmt.Errorf("duration %s longer than %s", d, max)
} else if d < min {
err = fmt.Errorf("duration %s shorter than %s", d, min)
}
return d, err
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment