From aac468e9b634dca075399b5441cf86e044e14451 Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Fri, 11 Nov 2022 00:20:37 -0500
Subject: [PATCH] Rework JSON serialization (#94)

Instead of serializing JSON to a temporary slice of bytes before writing
to the response, instead write directly to the response.

Also, pretty-print JSON responses for better readability (and to match
httpbin).

Note: With this change, we rely much more on the underlying net/http
server's implicit handling of the Content-Length header. One notable
side effect of this is that responses to HEAD requests no longer include
a Content-Length.
---
 httpbin/handlers.go      | 131 +++++++++++++++++----------------------
 httpbin/handlers_test.go |  38 ++++++------
 httpbin/helpers.go       |  16 ++++-
 httpbin/httpbin.go       |  15 +----
 httpbin/middleware.go    |   2 +-
 5 files changed, 95 insertions(+), 107 deletions(-)

diff --git a/httpbin/handlers.go b/httpbin/handlers.go
index 9a7b973..a401570 100644
--- a/httpbin/handlers.go
+++ b/httpbin/handlers.go
@@ -40,24 +40,26 @@ func (h *HTTPBin) UTF8(w http.ResponseWriter, r *http.Request) {
 
 // Get handles HTTP GET requests
 func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) {
-	resp := &noBodyResponse{
+	writeJSON(http.StatusOK, w, &noBodyResponse{
 		Args:    r.URL.Query(),
 		Headers: getRequestHeaders(r),
 		Origin:  getClientIP(r),
 		URL:     getURL(r).String(),
-	}
-	body, _ := jsonMarshalNoEscape(resp)
-	writeJSON(w, body, http.StatusOK)
+	})
 }
 
 // Anything returns anything that is passed to request.
 func (h *HTTPBin) Anything(w http.ResponseWriter, r *http.Request) {
-	switch r.Method {
-	case "HEAD":
+	// Short-circuit for HEAD requests, which should be handled like regular
+	// GET requests (where the autohead middleware will take care of discarding
+	// the body)
+	if r.Method == http.MethodHead {
 		h.Get(w, r)
-	default:
-		h.RequestWithBody(w, r)
+		return
 	}
+	// All other requests will be handled the same.  For compatibility with
+	// httpbin, the /anything endpoint even allows GET requests to have bodies.
+	h.RequestWithBody(w, r)
 }
 
 // RequestWithBody handles POST, PUT, and PATCH requests
@@ -75,72 +77,72 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	body, _ := jsonMarshalNoEscape(resp)
-	writeJSON(w, body, http.StatusOK)
+	writeJSON(http.StatusOK, w, resp)
 }
 
 // Gzip returns a gzipped response
 func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) {
-	resp := &gzipResponse{
+	var (
+		buf bytes.Buffer
+		gzw = gzip.NewWriter(&buf)
+	)
+	mustMarshalJSON(gzw, &noBodyResponse{
+		Args:    r.URL.Query(),
 		Headers: getRequestHeaders(r),
 		Origin:  getClientIP(r),
 		Gzipped: true,
-	}
-	body, _ := jsonMarshalNoEscape(resp)
-
-	buf := &bytes.Buffer{}
-	gzw := gzip.NewWriter(buf)
-	gzw.Write(body)
+	})
 	gzw.Close()
 
-	gzBody := buf.Bytes()
-
+	body := buf.Bytes()
 	w.Header().Set("Content-Encoding", "gzip")
-	writeJSON(w, gzBody, http.StatusOK)
+	w.Header().Set("Content-Type", jsonContentType)
+	w.Header().Set("Content-Length", strconv.Itoa(len(body)))
+	w.WriteHeader(http.StatusOK)
+	w.Write(body)
 }
 
 // Deflate returns a gzipped response
 func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) {
-	resp := &deflateResponse{
+	var (
+		buf bytes.Buffer
+		zw  = zlib.NewWriter(&buf)
+	)
+	mustMarshalJSON(zw, &noBodyResponse{
+		Args:     r.URL.Query(),
 		Headers:  getRequestHeaders(r),
 		Origin:   getClientIP(r),
 		Deflated: true,
-	}
-	body, _ := jsonMarshalNoEscape(resp)
-
-	buf := &bytes.Buffer{}
-	w2 := zlib.NewWriter(buf)
-	w2.Write(body)
-	w2.Close()
-
-	compressedBody := buf.Bytes()
+	})
+	zw.Close()
 
+	body := buf.Bytes()
 	w.Header().Set("Content-Encoding", "deflate")
-	writeJSON(w, compressedBody, http.StatusOK)
+	w.Header().Set("Content-Type", jsonContentType)
+	w.Header().Set("Content-Length", strconv.Itoa(len(body)))
+	w.WriteHeader(http.StatusOK)
+	w.Write(body)
 }
 
 // IP echoes the IP address of the incoming request
 func (h *HTTPBin) IP(w http.ResponseWriter, r *http.Request) {
-	body, _ := jsonMarshalNoEscape(&ipResponse{
+	writeJSON(http.StatusOK, w, &ipResponse{
 		Origin: getClientIP(r),
 	})
-	writeJSON(w, body, http.StatusOK)
 }
 
 // UserAgent echoes the incoming User-Agent header
 func (h *HTTPBin) UserAgent(w http.ResponseWriter, r *http.Request) {
-	body, _ := jsonMarshalNoEscape(&userAgentResponse{
+	writeJSON(http.StatusOK, w, &userAgentResponse{
 		UserAgent: r.Header.Get("User-Agent"),
 	})
-	writeJSON(w, body, http.StatusOK)
 }
 
 // Headers echoes the incoming request headers
 func (h *HTTPBin) Headers(w http.ResponseWriter, r *http.Request) {
-	body, _ := jsonMarshalNoEscape(&headersResponse{
+	writeJSON(http.StatusOK, w, &headersResponse{
 		Headers: getRequestHeaders(r),
 	})
-	writeJSON(w, body, http.StatusOK)
 }
 
 type statusCase struct {
@@ -303,11 +305,10 @@ func (h *HTTPBin) ResponseHeaders(w http.ResponseWriter, r *http.Request) {
 			w.Header().Add(k, v)
 		}
 	}
-	body, _ := jsonMarshalNoEscape(args)
 	if contentType := w.Header().Get("Content-Type"); contentType == "" {
 		w.Header().Set("Content-Type", jsonContentType)
 	}
-	w.Write(body)
+	mustMarshalJSON(w, args)
 }
 
 func redirectLocation(r *http.Request, relative bool, n int) string {
@@ -401,8 +402,7 @@ func (h *HTTPBin) Cookies(w http.ResponseWriter, r *http.Request) {
 	for _, c := range r.Cookies() {
 		resp[c.Name] = c.Value
 	}
-	body, _ := jsonMarshalNoEscape(resp)
-	writeJSON(w, body, http.StatusOK)
+	writeJSON(http.StatusOK, w, resp)
 }
 
 // SetCookies sets cookies as specified in query params and redirects to
@@ -456,11 +456,10 @@ func (h *HTTPBin) BasicAuth(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("WWW-Authenticate", `Basic realm="Fake Realm"`)
 	}
 
-	body, _ := jsonMarshalNoEscape(&authResponse{
+	writeJSON(status, w, authResponse{
 		Authorized: authorized,
 		User:       givenUser,
 	})
-	writeJSON(w, body, status)
 }
 
 // HiddenBasicAuth requires HTTP Basic authentication but returns a status of
@@ -482,11 +481,10 @@ func (h *HTTPBin) HiddenBasicAuth(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	body, _ := jsonMarshalNoEscape(&authResponse{
+	writeJSON(http.StatusOK, w, authResponse{
 		Authorized: authorized,
 		User:       givenUser,
 	})
-	writeJSON(w, body, http.StatusOK)
 }
 
 // Stream responds with max(n, 100) lines of JSON-encoded request data.
@@ -518,8 +516,9 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) {
 	f := w.(http.Flusher)
 	for i := 0; i < n; i++ {
 		resp.ID = i
-		line, _ := jsonMarshalNoEscape(resp)
-		w.Write(line)
+		// Call json.Marshal directly to avoid pretty printing
+		line, _ := json.Marshal(resp)
+		w.Write(append(line, '\n'))
 		f.Flush()
 	}
 }
@@ -708,7 +707,7 @@ func (h *HTTPBin) CacheControl(w http.ResponseWriter, r *http.Request) {
 	h.Get(w, r)
 }
 
-// ETag assumes the resource has the given etag and response to If-None-Match
+// ETag assumes the resource has the given etag and responds to If-None-Match
 // and If-Match headers appropriately.
 func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) {
 	parts := strings.Split(r.URL.Path, "/")
@@ -720,19 +719,17 @@ func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) {
 	etag := parts[2]
 	w.Header().Set("ETag", fmt.Sprintf(`"%s"`, etag))
 
-	// TODO: This mostly duplicates the work of Get() above, should this be
-	// pulled into a little helper?
-	resp := &noBodyResponse{
+	var buf bytes.Buffer
+	mustMarshalJSON(&buf, noBodyResponse{
 		Args:    r.URL.Query(),
 		Headers: getRequestHeaders(r),
 		Origin:  getClientIP(r),
 		URL:     getURL(r).String(),
-	}
-	body, _ := jsonMarshalNoEscape(resp)
+	})
 
 	// Let http.ServeContent deal with If-None-Match and If-Match headers:
 	// https://golang.org/pkg/net/http/#ServeContent
-	http.ServeContent(w, r, "response.json", time.Now(), bytes.NewReader(body))
+	http.ServeContent(w, r, "response.json", time.Now(), bytes.NewReader(buf.Bytes()))
 }
 
 // Bytes returns N random bytes generated with an optional seed
@@ -954,19 +951,17 @@ func (h *HTTPBin) DigestAuth(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	resp, _ := jsonMarshalNoEscape(&authResponse{
+	writeJSON(http.StatusOK, w, authResponse{
 		Authorized: true,
 		User:       user,
 	})
-	writeJSON(w, resp, http.StatusOK)
 }
 
 // UUID - responds with a generated UUID
 func (h *HTTPBin) UUID(w http.ResponseWriter, r *http.Request) {
-	resp, _ := jsonMarshalNoEscape(&uuidResponse{
+	writeJSON(http.StatusOK, w, uuidResponse{
 		UUID: uuidv4(),
 	})
-	writeJSON(w, resp, http.StatusOK)
 }
 
 // Base64 - encodes/decodes input data
@@ -995,7 +990,9 @@ func (h *HTTPBin) Base64(w http.ResponseWriter, r *http.Request) {
 
 // JSON - returns a sample json
 func (h *HTTPBin) JSON(w http.ResponseWriter, r *http.Request) {
-	writeJSON(w, mustStaticAsset("sample.json"), http.StatusOK)
+	w.Header().Set("Content-Type", jsonContentType)
+	w.WriteHeader(http.StatusOK)
+	w.Write(mustStaticAsset("sample.json"))
 }
 
 // Bearer - Prompts the user for authorization using bearer authentication.
@@ -1007,27 +1004,15 @@ func (h *HTTPBin) Bearer(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusUnauthorized)
 		return
 	}
-	body, _ := jsonMarshalNoEscape(&bearerResponse{
+	writeJSON(http.StatusOK, w, bearerResponse{
 		Authenticated: true,
 		Token:         tokenFields[1],
 	})
-	writeJSON(w, body, http.StatusOK)
 }
 
 // Hostname - returns the hostname.
 func (h *HTTPBin) Hostname(w http.ResponseWriter, r *http.Request) {
-	body, _ := jsonMarshalNoEscape(hostnameResponse{
+	writeJSON(http.StatusOK, w, hostnameResponse{
 		Hostname: h.hostname,
 	})
-	writeJSON(w, body, http.StatusOK)
-}
-
-// json.Marshal escapes HTML in strings while httpbin does not, so
-// we need to set up the encoder manually to reproduce that behavior.
-func jsonMarshalNoEscape(value interface{}) ([]byte, error) {
-	buffer := &bytes.Buffer{}
-	encoder := json.NewEncoder(buffer)
-	encoder.SetEscapeHTML(false)
-	err := encoder.Encode(value)
-	return buffer.Bytes(), err
 }
diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go
index 366bccd..a5637ae 100644
--- a/httpbin/handlers_test.go
+++ b/httpbin/handlers_test.go
@@ -263,16 +263,8 @@ func TestHead(t *testing.T) {
 			assertStatusCode(t, w, http.StatusOK)
 			assertBodyEquals(t, w, "")
 
-			contentLengthStr := w.Header().Get("Content-Length")
-			if contentLengthStr == "" {
-				t.Fatalf("missing Content-Length header in response")
-			}
-			contentLength, err := strconv.Atoi(contentLengthStr)
-			if err != nil {
-				t.Fatalf("error converting Content-Lengh %v to integer: %s", contentLengthStr, err)
-			}
-			if contentLength <= 0 {
-				t.Fatalf("Content-Length %v should be greater than 0", contentLengthStr)
+			if contentLength := w.Header().Get("Content-Length"); contentLength != "" {
+				t.Fatalf("did not expect Content-Length in response to HEAD request")
 			}
 		})
 	}
@@ -496,6 +488,18 @@ func TestAnything(t *testing.T) {
 			testRequestWithBody(t, verb, path)
 		}
 	}
+
+	t.Run("HEAD", func(t *testing.T) {
+		t.Parallel()
+		r, _ := http.NewRequest("HEAD", "/anything", nil)
+		w := httptest.NewRecorder()
+		handler.ServeHTTP(w, r)
+		assertStatusCode(t, w, http.StatusOK)
+		assertBodyEquals(t, w, "")
+		if contentLength := w.Header().Get("Content-Length"); contentLength != "" {
+			t.Fatalf("did not expect Content-Length in response to HEAD request")
+		}
+	})
 }
 
 // getFuncName uses runtime type reflection to get the name of the given
@@ -1591,10 +1595,8 @@ func TestGzip(t *testing.T) {
 	if err != nil {
 		t.Fatalf("error reading gzipped body: %s", err)
 	}
-
-	var resp *gzipResponse
-	err = json.Unmarshal(unzippedBody, &resp)
-	if err != nil {
+	var resp *noBodyResponse
+	if err := json.Unmarshal(unzippedBody, &resp); err != nil {
 		t.Fatalf("error unmarshalling response: %s", err)
 	}
 
@@ -1602,7 +1604,7 @@ func TestGzip(t *testing.T) {
 		t.Fatalf("expected resp.Gzipped == true")
 	}
 
-	if len(unzippedBody) >= zippedContentLength {
+	if len(unzippedBody) <= zippedContentLength {
 		t.Fatalf("expected compressed body")
 	}
 }
@@ -1622,7 +1624,7 @@ func TestDeflate(t *testing.T) {
 		t.Fatalf("missing Content-Length header in response")
 	}
 
-	contentLength, err := strconv.Atoi(contentLengthHeader)
+	compressedContentLength, err := strconv.Atoi(contentLengthHeader)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -1636,7 +1638,7 @@ func TestDeflate(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	var resp *deflateResponse
+	var resp *noBodyResponse
 	err = json.Unmarshal(body, &resp)
 	if err != nil {
 		t.Fatalf("error unmarshalling response: %s", err)
@@ -1646,7 +1648,7 @@ func TestDeflate(t *testing.T) {
 		t.Fatalf("expected resp.Deflated == true")
 	}
 
-	if len(body) >= contentLength {
+	if len(body) <= compressedContentLength {
 		t.Fatalf("expected compressed body")
 	}
 }
diff --git a/httpbin/helpers.go b/httpbin/helpers.go
index 48c18c3..705c010 100644
--- a/httpbin/helpers.go
+++ b/httpbin/helpers.go
@@ -82,13 +82,23 @@ func getURL(r *http.Request) *url.URL {
 
 func writeResponse(w http.ResponseWriter, status int, contentType string, body []byte) {
 	w.Header().Set("Content-Type", contentType)
-	w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
 	w.WriteHeader(status)
 	w.Write(body)
 }
 
-func writeJSON(w http.ResponseWriter, body []byte, status int) {
-	writeResponse(w, status, jsonContentType, body)
+func mustMarshalJSON(w io.Writer, val interface{}) {
+	encoder := json.NewEncoder(w)
+	encoder.SetEscapeHTML(false)
+	encoder.SetIndent("", "  ")
+	if err := encoder.Encode(val); err != nil {
+		panic(err.Error())
+	}
+}
+
+func writeJSON(status int, w http.ResponseWriter, val interface{}) {
+	w.Header().Set("Content-Type", jsonContentType)
+	w.WriteHeader(status)
+	mustMarshalJSON(w, val)
 }
 
 func writeHTML(w http.ResponseWriter, body []byte, status int) {
diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go
index 81493a5..f7d9c41 100644
--- a/httpbin/httpbin.go
+++ b/httpbin/httpbin.go
@@ -37,6 +37,9 @@ type noBodyResponse struct {
 	Headers http.Header `json:"headers"`
 	Origin  string      `json:"origin"`
 	URL     string      `json:"url"`
+
+	Deflated bool `json:"deflated,omitempty"`
+	Gzipped  bool `json:"gzipped,omitempty"`
 }
 
 // A generic response for any incoming request that might contain a body (POST,
@@ -60,18 +63,6 @@ type authResponse struct {
 	User       string `json:"user"`
 }
 
-type gzipResponse struct {
-	Headers http.Header `json:"headers"`
-	Origin  string      `json:"origin"`
-	Gzipped bool        `json:"gzipped"`
-}
-
-type deflateResponse struct {
-	Headers  http.Header `json:"headers"`
-	Origin   string      `json:"origin"`
-	Deflated bool        `json:"deflated"`
-}
-
 // An actual stream response body will be made up of one or more of these
 // structs, encoded as JSON and separated by newlines
 type streamResponse struct {
diff --git a/httpbin/middleware.go b/httpbin/middleware.go
index 13e937a..0f9e4e9 100644
--- a/httpbin/middleware.go
+++ b/httpbin/middleware.go
@@ -65,7 +65,7 @@ type headResponseWriter struct {
 }
 
 func (hw *headResponseWriter) Write(b []byte) (int, error) {
-	return 0, nil
+	return len(b), nil
 }
 
 // autohead automatically discards the body of responses to HEAD requests
-- 
GitLab