From 499044e10c54d01bda9918c7fe004a96c7d077d0 Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Mon, 26 Jun 2023 17:22:14 -0400
Subject: [PATCH] fix: include Transfer-Encoding when echoing request headers
 (#130)

Fixes #128
---
 httpbin/handlers_test.go | 40 ++++++++++++++++++++++++++++++++++++++++
 httpbin/helpers.go       |  8 ++++++--
 httpbin/helpers_test.go  | 10 +++++++++-
 3 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go
index 447ddd3..e687ef6 100644
--- a/httpbin/handlers_test.go
+++ b/httpbin/handlers_test.go
@@ -553,6 +553,7 @@ func testRequestWithBody(t *testing.T, verb, path string) {
 		testRequestWithBodyQueryParams,
 		testRequestWithBodyQueryParamsAndBody,
 		testRequestWithBodyBinaryBody,
+		testRequestWithBodyTransferEncoding,
 	}
 	for _, testFunc := range testFuncs {
 		testFunc := testFunc
@@ -1030,6 +1031,45 @@ func testRequestWithBodyQueryParamsAndBody(t *testing.T, verb, path string) {
 	}
 }
 
+func testRequestWithBodyTransferEncoding(t *testing.T, verb string, path string) {
+	testCases := []struct {
+		given string
+		want  string
+	}{
+		{"", ""},
+		{"identity", ""},
+		{"chunked", "chunked"},
+	}
+	for _, tc := range testCases {
+		tc := tc
+		t.Run("transfer-encoding/"+tc.given, func(t *testing.T) {
+			t.Parallel()
+
+			srv := httptest.NewServer(app)
+			defer srv.Close()
+
+			r, _ := http.NewRequest(verb, srv.URL+path, bytes.NewReader([]byte("{}")))
+			if tc.given != "" {
+				r.TransferEncoding = []string{tc.given}
+			}
+
+			httpResp, err := srv.Client().Do(r)
+			assertNilError(t, err)
+			assertIntEqual(t, httpResp.StatusCode, http.StatusOK)
+
+			var resp *bodyResponse
+			if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
+				t.Fatalf("failed to unmarshal body from JSON: %s", err)
+			}
+
+			got := resp.Headers.Get("Transfer-Encoding")
+			if got != tc.want {
+				t.Errorf("expected Transfer-Encoding %#v, got %#v", tc.want, got)
+			}
+		})
+	}
+}
+
 // TODO: implement and test more complex /status endpoint
 func TestStatus(t *testing.T) {
 	t.Parallel()
diff --git a/httpbin/helpers.go b/httpbin/helpers.go
index 606983a..989bb3c 100644
--- a/httpbin/helpers.go
+++ b/httpbin/helpers.go
@@ -25,11 +25,15 @@ const Base64MaxLen = 2000
 // requestHeaders takes in incoming request and returns an http.Header map
 // suitable for inclusion in our response data structures.
 //
-// This is necessary to ensure that the incoming Host header is included,
-// because golang only exposes that header on the http.Request struct itself.
+// This is necessary to ensure that the incoming Host and Transfer-Encoding
+// headers are included, because golang only exposes those values on the
+// http.Request struct itself.
 func getRequestHeaders(r *http.Request) http.Header {
 	h := r.Header
 	h.Set("Host", r.Host)
+	if len(r.TransferEncoding) > 0 {
+		h.Set("Transfer-Encoding", strings.Join(r.TransferEncoding, ","))
+	}
 	return h
 }
 
diff --git a/httpbin/helpers_test.go b/httpbin/helpers_test.go
index a11c691..1f5cd43 100644
--- a/httpbin/helpers_test.go
+++ b/httpbin/helpers_test.go
@@ -14,8 +14,16 @@ import (
 )
 
 func assertNil(t *testing.T, v interface{}) {
+	t.Helper()
 	if v != nil {
-		t.Errorf("expected nil, got %#v", v)
+		t.Fatalf("expected nil, got %#v", v)
+	}
+}
+
+func assertNilError(t *testing.T, err error) {
+	t.Helper()
+	if err != nil {
+		t.Fatalf("expected nil error, got %s (%T)", err, err)
 	}
 }
 
-- 
GitLab