Skip to content
Snippets Groups Projects
Unverified Commit aac468e9 authored by Will McCutchen's avatar Will McCutchen Committed by GitHub
Browse files

Rework JSON serialization (#94)

Instead of serializing JSON to a temporary slice of bytes before writing
to the response, instead write directly to the response.

Also, pretty-print JSON responses for better readability (and to match
httpbin).

Note: With this change, we rely much more on the underlying net/http
server's implicit handling of the Content-Length header. One notable
side effect of this is that responses to HEAD requests no longer include
a Content-Length.
parent bd122ea6
No related branches found
No related tags found
No related merge requests found
...@@ -40,24 +40,26 @@ func (h *HTTPBin) UTF8(w http.ResponseWriter, r *http.Request) { ...@@ -40,24 +40,26 @@ func (h *HTTPBin) UTF8(w http.ResponseWriter, r *http.Request) {
// Get handles HTTP GET requests // Get handles HTTP GET requests
func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) {
resp := &noBodyResponse{ writeJSON(http.StatusOK, w, &noBodyResponse{
Args: r.URL.Query(), Args: r.URL.Query(),
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
Origin: getClientIP(r), Origin: getClientIP(r),
URL: getURL(r).String(), URL: getURL(r).String(),
} })
body, _ := jsonMarshalNoEscape(resp)
writeJSON(w, body, http.StatusOK)
} }
// Anything returns anything that is passed to request. // Anything returns anything that is passed to request.
func (h *HTTPBin) Anything(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Anything(w http.ResponseWriter, r *http.Request) {
switch r.Method { // Short-circuit for HEAD requests, which should be handled like regular
case "HEAD": // GET requests (where the autohead middleware will take care of discarding
// the body)
if r.Method == http.MethodHead {
h.Get(w, r) h.Get(w, r)
default: return
h.RequestWithBody(w, r)
} }
// All other requests will be handled the same. For compatibility with
// httpbin, the /anything endpoint even allows GET requests to have bodies.
h.RequestWithBody(w, r)
} }
// RequestWithBody handles POST, PUT, and PATCH requests // RequestWithBody handles POST, PUT, and PATCH requests
...@@ -75,72 +77,72 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) { ...@@ -75,72 +77,72 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
return return
} }
body, _ := jsonMarshalNoEscape(resp) writeJSON(http.StatusOK, w, resp)
writeJSON(w, body, http.StatusOK)
} }
// Gzip returns a gzipped response // Gzip returns a gzipped response
func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) {
resp := &gzipResponse{ var (
buf bytes.Buffer
gzw = gzip.NewWriter(&buf)
)
mustMarshalJSON(gzw, &noBodyResponse{
Args: r.URL.Query(),
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
Origin: getClientIP(r), Origin: getClientIP(r),
Gzipped: true, Gzipped: true,
} })
body, _ := jsonMarshalNoEscape(resp)
buf := &bytes.Buffer{}
gzw := gzip.NewWriter(buf)
gzw.Write(body)
gzw.Close() gzw.Close()
gzBody := buf.Bytes() body := buf.Bytes()
w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Content-Encoding", "gzip")
writeJSON(w, gzBody, http.StatusOK) w.Header().Set("Content-Type", jsonContentType)
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
w.WriteHeader(http.StatusOK)
w.Write(body)
} }
// Deflate returns a gzipped response // Deflate returns a gzipped response
func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) {
resp := &deflateResponse{ var (
buf bytes.Buffer
zw = zlib.NewWriter(&buf)
)
mustMarshalJSON(zw, &noBodyResponse{
Args: r.URL.Query(),
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
Origin: getClientIP(r), Origin: getClientIP(r),
Deflated: true, Deflated: true,
} })
body, _ := jsonMarshalNoEscape(resp) zw.Close()
buf := &bytes.Buffer{}
w2 := zlib.NewWriter(buf)
w2.Write(body)
w2.Close()
compressedBody := buf.Bytes()
body := buf.Bytes()
w.Header().Set("Content-Encoding", "deflate") w.Header().Set("Content-Encoding", "deflate")
writeJSON(w, compressedBody, http.StatusOK) w.Header().Set("Content-Type", jsonContentType)
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
w.WriteHeader(http.StatusOK)
w.Write(body)
} }
// IP echoes the IP address of the incoming request // IP echoes the IP address of the incoming request
func (h *HTTPBin) IP(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) IP(w http.ResponseWriter, r *http.Request) {
body, _ := jsonMarshalNoEscape(&ipResponse{ writeJSON(http.StatusOK, w, &ipResponse{
Origin: getClientIP(r), Origin: getClientIP(r),
}) })
writeJSON(w, body, http.StatusOK)
} }
// UserAgent echoes the incoming User-Agent header // UserAgent echoes the incoming User-Agent header
func (h *HTTPBin) UserAgent(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) UserAgent(w http.ResponseWriter, r *http.Request) {
body, _ := jsonMarshalNoEscape(&userAgentResponse{ writeJSON(http.StatusOK, w, &userAgentResponse{
UserAgent: r.Header.Get("User-Agent"), UserAgent: r.Header.Get("User-Agent"),
}) })
writeJSON(w, body, http.StatusOK)
} }
// Headers echoes the incoming request headers // Headers echoes the incoming request headers
func (h *HTTPBin) Headers(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Headers(w http.ResponseWriter, r *http.Request) {
body, _ := jsonMarshalNoEscape(&headersResponse{ writeJSON(http.StatusOK, w, &headersResponse{
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
}) })
writeJSON(w, body, http.StatusOK)
} }
type statusCase struct { type statusCase struct {
...@@ -303,11 +305,10 @@ func (h *HTTPBin) ResponseHeaders(w http.ResponseWriter, r *http.Request) { ...@@ -303,11 +305,10 @@ func (h *HTTPBin) ResponseHeaders(w http.ResponseWriter, r *http.Request) {
w.Header().Add(k, v) w.Header().Add(k, v)
} }
} }
body, _ := jsonMarshalNoEscape(args)
if contentType := w.Header().Get("Content-Type"); contentType == "" { if contentType := w.Header().Get("Content-Type"); contentType == "" {
w.Header().Set("Content-Type", jsonContentType) w.Header().Set("Content-Type", jsonContentType)
} }
w.Write(body) mustMarshalJSON(w, args)
} }
func redirectLocation(r *http.Request, relative bool, n int) string { func redirectLocation(r *http.Request, relative bool, n int) string {
...@@ -401,8 +402,7 @@ func (h *HTTPBin) Cookies(w http.ResponseWriter, r *http.Request) { ...@@ -401,8 +402,7 @@ func (h *HTTPBin) Cookies(w http.ResponseWriter, r *http.Request) {
for _, c := range r.Cookies() { for _, c := range r.Cookies() {
resp[c.Name] = c.Value resp[c.Name] = c.Value
} }
body, _ := jsonMarshalNoEscape(resp) writeJSON(http.StatusOK, w, resp)
writeJSON(w, body, http.StatusOK)
} }
// SetCookies sets cookies as specified in query params and redirects to // SetCookies sets cookies as specified in query params and redirects to
...@@ -456,11 +456,10 @@ func (h *HTTPBin) BasicAuth(w http.ResponseWriter, r *http.Request) { ...@@ -456,11 +456,10 @@ func (h *HTTPBin) BasicAuth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `Basic realm="Fake Realm"`) w.Header().Set("WWW-Authenticate", `Basic realm="Fake Realm"`)
} }
body, _ := jsonMarshalNoEscape(&authResponse{ writeJSON(status, w, authResponse{
Authorized: authorized, Authorized: authorized,
User: givenUser, User: givenUser,
}) })
writeJSON(w, body, status)
} }
// HiddenBasicAuth requires HTTP Basic authentication but returns a status of // HiddenBasicAuth requires HTTP Basic authentication but returns a status of
...@@ -482,11 +481,10 @@ func (h *HTTPBin) HiddenBasicAuth(w http.ResponseWriter, r *http.Request) { ...@@ -482,11 +481,10 @@ func (h *HTTPBin) HiddenBasicAuth(w http.ResponseWriter, r *http.Request) {
return return
} }
body, _ := jsonMarshalNoEscape(&authResponse{ writeJSON(http.StatusOK, w, authResponse{
Authorized: authorized, Authorized: authorized,
User: givenUser, User: givenUser,
}) })
writeJSON(w, body, http.StatusOK)
} }
// Stream responds with max(n, 100) lines of JSON-encoded request data. // Stream responds with max(n, 100) lines of JSON-encoded request data.
...@@ -518,8 +516,9 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) { ...@@ -518,8 +516,9 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) {
f := w.(http.Flusher) f := w.(http.Flusher)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
resp.ID = i resp.ID = i
line, _ := jsonMarshalNoEscape(resp) // Call json.Marshal directly to avoid pretty printing
w.Write(line) line, _ := json.Marshal(resp)
w.Write(append(line, '\n'))
f.Flush() f.Flush()
} }
} }
...@@ -708,7 +707,7 @@ func (h *HTTPBin) CacheControl(w http.ResponseWriter, r *http.Request) { ...@@ -708,7 +707,7 @@ func (h *HTTPBin) CacheControl(w http.ResponseWriter, r *http.Request) {
h.Get(w, r) h.Get(w, r)
} }
// ETag assumes the resource has the given etag and response to If-None-Match // ETag assumes the resource has the given etag and responds to If-None-Match
// and If-Match headers appropriately. // and If-Match headers appropriately.
func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) {
parts := strings.Split(r.URL.Path, "/") parts := strings.Split(r.URL.Path, "/")
...@@ -720,19 +719,17 @@ func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) { ...@@ -720,19 +719,17 @@ func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) {
etag := parts[2] etag := parts[2]
w.Header().Set("ETag", fmt.Sprintf(`"%s"`, etag)) w.Header().Set("ETag", fmt.Sprintf(`"%s"`, etag))
// TODO: This mostly duplicates the work of Get() above, should this be var buf bytes.Buffer
// pulled into a little helper? mustMarshalJSON(&buf, noBodyResponse{
resp := &noBodyResponse{
Args: r.URL.Query(), Args: r.URL.Query(),
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
Origin: getClientIP(r), Origin: getClientIP(r),
URL: getURL(r).String(), URL: getURL(r).String(),
} })
body, _ := jsonMarshalNoEscape(resp)
// Let http.ServeContent deal with If-None-Match and If-Match headers: // Let http.ServeContent deal with If-None-Match and If-Match headers:
// https://golang.org/pkg/net/http/#ServeContent // https://golang.org/pkg/net/http/#ServeContent
http.ServeContent(w, r, "response.json", time.Now(), bytes.NewReader(body)) http.ServeContent(w, r, "response.json", time.Now(), bytes.NewReader(buf.Bytes()))
} }
// Bytes returns N random bytes generated with an optional seed // Bytes returns N random bytes generated with an optional seed
...@@ -954,19 +951,17 @@ func (h *HTTPBin) DigestAuth(w http.ResponseWriter, r *http.Request) { ...@@ -954,19 +951,17 @@ func (h *HTTPBin) DigestAuth(w http.ResponseWriter, r *http.Request) {
return return
} }
resp, _ := jsonMarshalNoEscape(&authResponse{ writeJSON(http.StatusOK, w, authResponse{
Authorized: true, Authorized: true,
User: user, User: user,
}) })
writeJSON(w, resp, http.StatusOK)
} }
// UUID - responds with a generated UUID // UUID - responds with a generated UUID
func (h *HTTPBin) UUID(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) UUID(w http.ResponseWriter, r *http.Request) {
resp, _ := jsonMarshalNoEscape(&uuidResponse{ writeJSON(http.StatusOK, w, uuidResponse{
UUID: uuidv4(), UUID: uuidv4(),
}) })
writeJSON(w, resp, http.StatusOK)
} }
// Base64 - encodes/decodes input data // Base64 - encodes/decodes input data
...@@ -995,7 +990,9 @@ func (h *HTTPBin) Base64(w http.ResponseWriter, r *http.Request) { ...@@ -995,7 +990,9 @@ func (h *HTTPBin) Base64(w http.ResponseWriter, r *http.Request) {
// JSON - returns a sample json // JSON - returns a sample json
func (h *HTTPBin) JSON(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) JSON(w http.ResponseWriter, r *http.Request) {
writeJSON(w, mustStaticAsset("sample.json"), http.StatusOK) w.Header().Set("Content-Type", jsonContentType)
w.WriteHeader(http.StatusOK)
w.Write(mustStaticAsset("sample.json"))
} }
// Bearer - Prompts the user for authorization using bearer authentication. // Bearer - Prompts the user for authorization using bearer authentication.
...@@ -1007,27 +1004,15 @@ func (h *HTTPBin) Bearer(w http.ResponseWriter, r *http.Request) { ...@@ -1007,27 +1004,15 @@ func (h *HTTPBin) Bearer(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
body, _ := jsonMarshalNoEscape(&bearerResponse{ writeJSON(http.StatusOK, w, bearerResponse{
Authenticated: true, Authenticated: true,
Token: tokenFields[1], Token: tokenFields[1],
}) })
writeJSON(w, body, http.StatusOK)
} }
// Hostname - returns the hostname. // Hostname - returns the hostname.
func (h *HTTPBin) Hostname(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Hostname(w http.ResponseWriter, r *http.Request) {
body, _ := jsonMarshalNoEscape(hostnameResponse{ writeJSON(http.StatusOK, w, hostnameResponse{
Hostname: h.hostname, Hostname: h.hostname,
}) })
writeJSON(w, body, http.StatusOK)
}
// json.Marshal escapes HTML in strings while httpbin does not, so
// we need to set up the encoder manually to reproduce that behavior.
func jsonMarshalNoEscape(value interface{}) ([]byte, error) {
buffer := &bytes.Buffer{}
encoder := json.NewEncoder(buffer)
encoder.SetEscapeHTML(false)
err := encoder.Encode(value)
return buffer.Bytes(), err
} }
...@@ -263,16 +263,8 @@ func TestHead(t *testing.T) { ...@@ -263,16 +263,8 @@ func TestHead(t *testing.T) {
assertStatusCode(t, w, http.StatusOK) assertStatusCode(t, w, http.StatusOK)
assertBodyEquals(t, w, "") assertBodyEquals(t, w, "")
contentLengthStr := w.Header().Get("Content-Length") if contentLength := w.Header().Get("Content-Length"); contentLength != "" {
if contentLengthStr == "" { t.Fatalf("did not expect Content-Length in response to HEAD request")
t.Fatalf("missing Content-Length header in response")
}
contentLength, err := strconv.Atoi(contentLengthStr)
if err != nil {
t.Fatalf("error converting Content-Lengh %v to integer: %s", contentLengthStr, err)
}
if contentLength <= 0 {
t.Fatalf("Content-Length %v should be greater than 0", contentLengthStr)
} }
}) })
} }
...@@ -496,6 +488,18 @@ func TestAnything(t *testing.T) { ...@@ -496,6 +488,18 @@ func TestAnything(t *testing.T) {
testRequestWithBody(t, verb, path) testRequestWithBody(t, verb, path)
} }
} }
t.Run("HEAD", func(t *testing.T) {
t.Parallel()
r, _ := http.NewRequest("HEAD", "/anything", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusOK)
assertBodyEquals(t, w, "")
if contentLength := w.Header().Get("Content-Length"); contentLength != "" {
t.Fatalf("did not expect Content-Length in response to HEAD request")
}
})
} }
// getFuncName uses runtime type reflection to get the name of the given // getFuncName uses runtime type reflection to get the name of the given
...@@ -1591,10 +1595,8 @@ func TestGzip(t *testing.T) { ...@@ -1591,10 +1595,8 @@ func TestGzip(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("error reading gzipped body: %s", err) t.Fatalf("error reading gzipped body: %s", err)
} }
var resp *noBodyResponse
var resp *gzipResponse if err := json.Unmarshal(unzippedBody, &resp); err != nil {
err = json.Unmarshal(unzippedBody, &resp)
if err != nil {
t.Fatalf("error unmarshalling response: %s", err) t.Fatalf("error unmarshalling response: %s", err)
} }
...@@ -1602,7 +1604,7 @@ func TestGzip(t *testing.T) { ...@@ -1602,7 +1604,7 @@ func TestGzip(t *testing.T) {
t.Fatalf("expected resp.Gzipped == true") t.Fatalf("expected resp.Gzipped == true")
} }
if len(unzippedBody) >= zippedContentLength { if len(unzippedBody) <= zippedContentLength {
t.Fatalf("expected compressed body") t.Fatalf("expected compressed body")
} }
} }
...@@ -1622,7 +1624,7 @@ func TestDeflate(t *testing.T) { ...@@ -1622,7 +1624,7 @@ func TestDeflate(t *testing.T) {
t.Fatalf("missing Content-Length header in response") t.Fatalf("missing Content-Length header in response")
} }
contentLength, err := strconv.Atoi(contentLengthHeader) compressedContentLength, err := strconv.Atoi(contentLengthHeader)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -1636,7 +1638,7 @@ func TestDeflate(t *testing.T) { ...@@ -1636,7 +1638,7 @@ func TestDeflate(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
var resp *deflateResponse var resp *noBodyResponse
err = json.Unmarshal(body, &resp) err = json.Unmarshal(body, &resp)
if err != nil { if err != nil {
t.Fatalf("error unmarshalling response: %s", err) t.Fatalf("error unmarshalling response: %s", err)
...@@ -1646,7 +1648,7 @@ func TestDeflate(t *testing.T) { ...@@ -1646,7 +1648,7 @@ func TestDeflate(t *testing.T) {
t.Fatalf("expected resp.Deflated == true") t.Fatalf("expected resp.Deflated == true")
} }
if len(body) >= contentLength { if len(body) <= compressedContentLength {
t.Fatalf("expected compressed body") t.Fatalf("expected compressed body")
} }
} }
......
...@@ -82,13 +82,23 @@ func getURL(r *http.Request) *url.URL { ...@@ -82,13 +82,23 @@ func getURL(r *http.Request) *url.URL {
func writeResponse(w http.ResponseWriter, status int, contentType string, body []byte) { func writeResponse(w http.ResponseWriter, status int, contentType string, body []byte) {
w.Header().Set("Content-Type", contentType) w.Header().Set("Content-Type", contentType)
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
w.WriteHeader(status) w.WriteHeader(status)
w.Write(body) w.Write(body)
} }
func writeJSON(w http.ResponseWriter, body []byte, status int) { func mustMarshalJSON(w io.Writer, val interface{}) {
writeResponse(w, status, jsonContentType, body) encoder := json.NewEncoder(w)
encoder.SetEscapeHTML(false)
encoder.SetIndent("", " ")
if err := encoder.Encode(val); err != nil {
panic(err.Error())
}
}
func writeJSON(status int, w http.ResponseWriter, val interface{}) {
w.Header().Set("Content-Type", jsonContentType)
w.WriteHeader(status)
mustMarshalJSON(w, val)
} }
func writeHTML(w http.ResponseWriter, body []byte, status int) { func writeHTML(w http.ResponseWriter, body []byte, status int) {
......
...@@ -37,6 +37,9 @@ type noBodyResponse struct { ...@@ -37,6 +37,9 @@ type noBodyResponse struct {
Headers http.Header `json:"headers"` Headers http.Header `json:"headers"`
Origin string `json:"origin"` Origin string `json:"origin"`
URL string `json:"url"` URL string `json:"url"`
Deflated bool `json:"deflated,omitempty"`
Gzipped bool `json:"gzipped,omitempty"`
} }
// A generic response for any incoming request that might contain a body (POST, // A generic response for any incoming request that might contain a body (POST,
...@@ -60,18 +63,6 @@ type authResponse struct { ...@@ -60,18 +63,6 @@ type authResponse struct {
User string `json:"user"` User string `json:"user"`
} }
type gzipResponse struct {
Headers http.Header `json:"headers"`
Origin string `json:"origin"`
Gzipped bool `json:"gzipped"`
}
type deflateResponse struct {
Headers http.Header `json:"headers"`
Origin string `json:"origin"`
Deflated bool `json:"deflated"`
}
// An actual stream response body will be made up of one or more of these // An actual stream response body will be made up of one or more of these
// structs, encoded as JSON and separated by newlines // structs, encoded as JSON and separated by newlines
type streamResponse struct { type streamResponse struct {
......
...@@ -65,7 +65,7 @@ type headResponseWriter struct { ...@@ -65,7 +65,7 @@ type headResponseWriter struct {
} }
func (hw *headResponseWriter) Write(b []byte) (int, error) { func (hw *headResponseWriter) Write(b []byte) (int, error) {
return 0, nil return len(b), nil
} }
// autohead automatically discards the body of responses to HEAD requests // autohead automatically discards the body of responses to HEAD requests
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment