From 355c0a671200919692ee3facb51638a5659bfecd Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Sat, 31 Dec 2016 23:04:04 -0500
Subject: [PATCH] Extract duration parsing into helpers & strictly check bounds

---
 httpbin/handlers.go      | 16 +++-------------
 httpbin/handlers_test.go | 10 ++++++----
 httpbin/helpers.go       | 32 ++++++++++++++++++++++++++++++++
 3 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/httpbin/handlers.go b/httpbin/handlers.go
index f4cd31b..e8a3c07 100644
--- a/httpbin/handlers.go
+++ b/httpbin/handlers.go
@@ -439,20 +439,10 @@ func (h *HTTPBin) Delay(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	delay, err := time.ParseDuration(parts[2])
+	delay, err := parseBoundedDuration(parts[2], 0, h.options.MaxResponseTime)
 	if err != nil {
-		n, err := strconv.ParseFloat(parts[2], 64)
-		if err != nil {
-			http.Error(w, "Invalid duration", http.StatusBadRequest)
-			return
-		}
-		delay = time.Duration(n*1000) * time.Millisecond
-	}
-
-	if delay > h.options.MaxResponseTime {
-		delay = h.options.MaxResponseTime
-	} else if delay < 0 {
-		delay = 0
+		http.Error(w, "Invalid duration", http.StatusBadRequest)
+		return
 	}
 
 	<-time.After(delay)
diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go
index 6071562..07036a3 100644
--- a/httpbin/handlers_test.go
+++ b/httpbin/handlers_test.go
@@ -1236,15 +1236,11 @@ func TestDelay(t *testing.T) {
 		// go-style durations are supported
 		{"/delay/0ms", 0},
 		{"/delay/500ms", 500 * time.Millisecond},
-		{"/delay/1.5s", maxResponseTime},
 
 		// as are floating point seconds
 		{"/delay/0", 0},
 		{"/delay/0.5", 500 * time.Millisecond},
 		{"/delay/1", maxResponseTime},
-		{"/delay/1.5", maxResponseTime},
-		{"/delay/-1", 0},
-		{"/delay/-3.14", 0},
 	}
 	for _, test := range okTests {
 		t.Run("ok"+test.url, func(t *testing.T) {
@@ -1278,6 +1274,12 @@ func TestDelay(t *testing.T) {
 		{"/delay", http.StatusNotFound},
 		{"/delay/foo", http.StatusBadRequest},
 		{"/delay/1/foo", http.StatusNotFound},
+
+		{"/delay/1.5s", http.StatusBadRequest},
+		{"/delay/-1ms", http.StatusBadRequest},
+		{"/delay/1.5", http.StatusBadRequest},
+		{"/delay/-1", http.StatusBadRequest},
+		{"/delay/-3.14", http.StatusBadRequest},
 	}
 
 	for _, test := range badTests {
diff --git a/httpbin/helpers.go b/httpbin/helpers.go
index e5f9e6a..f5e79a5 100644
--- a/httpbin/helpers.go
+++ b/httpbin/helpers.go
@@ -6,7 +6,9 @@ import (
 	"io/ioutil"
 	"net/http"
 	"net/url"
+	"strconv"
 	"strings"
+	"time"
 )
 
 func getOrigin(r *http.Request) string {
@@ -102,3 +104,33 @@ func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse, maxMe
 	}
 	return nil
 }
+
+// parseDuration takes a user's input as a string and attempts to convert it
+// into a time.Duration
+func parseDuration(input string) (time.Duration, error) {
+	d, err := time.ParseDuration(input)
+	if err != nil {
+		n, err := strconv.ParseFloat(input, 64)
+		if err != nil {
+			return 0, err
+		}
+		d = time.Duration(n*1000) * time.Millisecond
+	}
+	return d, nil
+}
+
+// parseBoundedDuration parses a time.Duration from user input and ensures that
+// it is within a given maximum and minimum time
+func parseBoundedDuration(input string, min, max time.Duration) (time.Duration, error) {
+	d, err := parseDuration(input)
+	if err != nil {
+		return 0, err
+	}
+
+	if d > max {
+		err = fmt.Errorf("duration %s longer than %s", d, max)
+	} else if d < min {
+		err = fmt.Errorf("duration %s shorter than %s", d, min)
+	}
+	return d, err
+}
-- 
GitLab