From d917d404bb7140172a3d870bb88e94780ac9393d Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Wed, 30 Sep 2020 20:36:16 -0400
Subject: [PATCH] Make /drip more compatible with httpbin's implementation:

- Set Content-Length

- Send status code before any delay

- Match default parameter values
---
 httpbin/handlers.go      | 44 ++++++++++++++++++++--------------------
 httpbin/handlers_test.go | 33 +++++++++++++++++++++++++-----
 httpbin/httpbin.go       | 30 +++++++++++++++++++++++++--
 3 files changed, 78 insertions(+), 29 deletions(-)

diff --git a/httpbin/handlers.go b/httpbin/handlers.go
index 8e0d068..fb307e2 100644
--- a/httpbin/handlers.go
+++ b/httpbin/handlers.go
@@ -481,15 +481,16 @@ func (h *HTTPBin) Delay(w http.ResponseWriter, r *http.Request) {
 func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) {
 	q := r.URL.Query()
 
-	duration := time.Duration(0)
-	delay := time.Duration(0)
-	numbytes := int64(10)
-	code := http.StatusOK
+	var (
+		duration = h.DefaultParams.DripDuration
+		delay    = h.DefaultParams.DripDelay
+		numBytes = h.DefaultParams.DripNumBytes
+		code     = http.StatusOK
 
-	var err error
+		err error
+	)
 
-	userDuration := q.Get("duration")
-	if userDuration != "" {
+	if userDuration := q.Get("duration"); userDuration != "" {
 		duration, err = parseBoundedDuration(userDuration, 0, h.MaxDuration)
 		if err != nil {
 			http.Error(w, "Invalid duration", http.StatusBadRequest)
@@ -497,8 +498,7 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
-	userDelay := q.Get("delay")
-	if userDelay != "" {
+	if userDelay := q.Get("delay"); userDelay != "" {
 		delay, err = parseBoundedDuration(userDelay, 0, h.MaxDuration)
 		if err != nil {
 			http.Error(w, "Invalid delay", http.StatusBadRequest)
@@ -506,17 +506,15 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
-	userNumBytes := q.Get("numbytes")
-	if userNumBytes != "" {
-		numbytes, err = strconv.ParseInt(userNumBytes, 10, 64)
-		if err != nil || numbytes <= 0 || numbytes > h.MaxBodySize {
+	if userNumBytes := q.Get("numbytes"); userNumBytes != "" {
+		numBytes, err = strconv.ParseInt(userNumBytes, 10, 64)
+		if err != nil || numBytes <= 0 || numBytes > h.MaxBodySize {
 			http.Error(w, "Invalid numbytes", http.StatusBadRequest)
 			return
 		}
 	}
 
-	userCode := q.Get("code")
-	if userCode != "" {
+	if userCode := q.Get("code"); userCode != "" {
 		code, err = strconv.Atoi(userCode)
 		if err != nil || code < 100 || code >= 600 {
 			http.Error(w, "Invalid code", http.StatusBadRequest)
@@ -529,7 +527,13 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	pause := duration / time.Duration(numbytes)
+	pause := duration / time.Duration(numBytes)
+	flusher := w.(http.Flusher)
+
+	w.Header().Set("Content-Type", "application/octet-stream")
+	w.Header().Set("Content-Length", fmt.Sprintf("%d", numBytes))
+	w.WriteHeader(code)
+	flusher.Flush()
 
 	select {
 	case <-r.Context().Done():
@@ -537,13 +541,9 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) {
 	case <-time.After(delay):
 	}
 
-	w.WriteHeader(code)
-	w.Header().Set("Content-Type", "application/octet-stream")
-
-	f := w.(http.Flusher)
-	for i := int64(0); i < numbytes; i++ {
+	for i := int64(0); i < numBytes; i++ {
 		w.Write([]byte("*"))
-		f.Flush()
+		flusher.Flush()
 
 		select {
 		case <-r.Context().Done():
diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go
index 7873586..1cc62f5 100644
--- a/httpbin/handlers_test.go
+++ b/httpbin/handlers_test.go
@@ -27,7 +27,14 @@ const maxBodySize int64 = 1024 * 1024
 const maxDuration time.Duration = 1 * time.Second
 const alphanumLetters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
 
+var testDefaultParams = DefaultParams{
+	DripDelay:    0,
+	DripDuration: 100 * time.Millisecond,
+	DripNumBytes: 10,
+}
+
 var app = New(
+	WithDefaultParams(testDefaultParams),
 	WithMaxBodySize(maxBodySize),
 	WithMaxDuration(maxDuration),
 	WithObserver(StdLogObserver(log.New(ioutil.Discard, "", 0))),
@@ -1507,7 +1514,6 @@ func TestDrip(t *testing.T) {
 	for _, test := range okTests {
 		t.Run(fmt.Sprintf("ok/%s", test.params.Encode()), func(t *testing.T) {
 			url := "/drip?" + test.params.Encode()
-
 			start := time.Now()
 
 			r, _ := http.NewRequest("GET", url, nil)
@@ -1516,8 +1522,9 @@ func TestDrip(t *testing.T) {
 
 			elapsed := time.Since(start)
 
-			assertHeader(t, w, "Content-Type", "application/octet-stream")
 			assertStatusCode(t, w, test.code)
+			assertHeader(t, w, "Content-Type", "application/octet-stream")
+			assertHeader(t, w, "Content-Length", strconv.Itoa(test.numbytes))
 			if len(w.Body.Bytes()) != test.numbytes {
 				t.Fatalf("expected %d bytes, got %d", test.numbytes, len(w.Body.Bytes()))
 			}
@@ -1532,13 +1539,29 @@ func TestDrip(t *testing.T) {
 		srv := httptest.NewServer(handler)
 		defer srv.Close()
 
+		// For this test, we expect the client to time out and cancel the
+		// request after 10ms.  The handler should immediately write a 200 OK
+		// status before the client timeout, preventing a client error, but it
+		// will wait 500ms to write anything to the response body.
+		//
+		// So, we're testing that a) the client got an immediate 200 OK but
+		// that b) the response body was empty.
 		client := http.Client{
 			Timeout: time.Duration(10 * time.Millisecond),
 		}
 		resp, err := client.Get(srv.URL + "/drip?duration=500ms&delay=500ms")
-		if err == nil {
-			body, _ := ioutil.ReadAll(resp.Body)
-			t.Fatalf("expected timeout error, got %d %s", resp.StatusCode, body)
+		if err != nil {
+			t.Fatalf("unexpected error: %s", err)
+		}
+		defer resp.Body.Close()
+
+		body, _ := ioutil.ReadAll(resp.Body)
+		if err != nil {
+			t.Fatalf("error reading response body: %s", err)
+		}
+
+		if len(body) != 0 {
+			t.Fatalf("expected client timeout before body was written, got body %q", string(body))
 		}
 	})
 
diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go
index 7c00474..7d7a8a3 100644
--- a/httpbin/httpbin.go
+++ b/httpbin/httpbin.go
@@ -96,6 +96,24 @@ type HTTPBin struct {
 
 	// Observer called with the result of each handled request
 	Observer Observer
+
+	// Default parameter values
+	DefaultParams DefaultParams
+}
+
+// DefaultParams defines default parameter values
+type DefaultParams struct {
+	DripDuration time.Duration
+	DripDelay    time.Duration
+	DripNumBytes int64
+}
+
+// DefaultDefaultParams defines the DefaultParams that are used by default. In
+// general, these should match the original httpbin.org's defaults.
+var DefaultDefaultParams = DefaultParams{
+	DripDuration: 2 * time.Second,
+	DripDelay:    2 * time.Second,
+	DripNumBytes: 10,
 }
 
 // Handler returns an http.Handler that exposes all HTTPBin endpoints
@@ -197,8 +215,9 @@ func (h *HTTPBin) Handler() http.Handler {
 // New creates a new HTTPBin instance
 func New(opts ...OptionFunc) *HTTPBin {
 	h := &HTTPBin{
-		MaxBodySize: DefaultMaxBodySize,
-		MaxDuration: DefaultMaxDuration,
+		MaxBodySize:   DefaultMaxBodySize,
+		MaxDuration:   DefaultMaxDuration,
+		DefaultParams: DefaultDefaultParams,
 	}
 	for _, opt := range opts {
 		opt(h)
@@ -210,6 +229,13 @@ func New(opts ...OptionFunc) *HTTPBin {
 // instance
 type OptionFunc func(*HTTPBin)
 
+// WithDefaultParams sets the default params handlers will use
+func WithDefaultParams(defaultParams DefaultParams) OptionFunc {
+	return func(h *HTTPBin) {
+		h.DefaultParams = defaultParams
+	}
+}
+
 // WithMaxBodySize sets the maximum amount of memory
 func WithMaxBodySize(m int64) OptionFunc {
 	return func(h *HTTPBin) {
-- 
GitLab