From 6ec6c2fb061007bd7f1215c480976649a8518ae8 Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Sun, 25 Jun 2017 22:37:39 -0700
Subject: [PATCH] Ensure Host header is included in responses

---
 httpbin/handlers.go      | 14 +++++++-------
 httpbin/handlers_test.go | 10 +++++++++-
 httpbin/helpers.go       | 11 +++++++++++
 3 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/httpbin/handlers.go b/httpbin/handlers.go
index cc27870..3896c28 100644
--- a/httpbin/handlers.go
+++ b/httpbin/handlers.go
@@ -51,7 +51,7 @@ func (h *HTTPBin) UTF8(w http.ResponseWriter, r *http.Request) {
 func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) {
 	resp := &getResponse{
 		Args:    r.URL.Query(),
-		Headers: r.Header,
+		Headers: getRequestHeaders(r),
 		Origin:  getOrigin(r),
 		URL:     getURL(r).String(),
 	}
@@ -63,7 +63,7 @@ func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) {
 func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
 	resp := &bodyResponse{
 		Args:    r.URL.Query(),
-		Headers: r.Header,
+		Headers: getRequestHeaders(r),
 		Origin:  getOrigin(r),
 		URL:     getURL(r).String(),
 	}
@@ -81,7 +81,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
 // Gzip returns a gzipped response
 func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) {
 	resp := &gzipResponse{
-		Headers: r.Header,
+		Headers: getRequestHeaders(r),
 		Origin:  getOrigin(r),
 		Gzipped: true,
 	}
@@ -101,7 +101,7 @@ func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) {
 // Deflate returns a gzipped response
 func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) {
 	resp := &deflateResponse{
-		Headers:  r.Header,
+		Headers:  getRequestHeaders(r),
 		Origin:   getOrigin(r),
 		Deflated: true,
 	}
@@ -137,7 +137,7 @@ func (h *HTTPBin) UserAgent(w http.ResponseWriter, r *http.Request) {
 // Headers echoes the incoming request headers
 func (h *HTTPBin) Headers(w http.ResponseWriter, r *http.Request) {
 	body, _ := json.Marshal(&headersResponse{
-		Headers: r.Header,
+		Headers: getRequestHeaders(r),
 	})
 	writeJSON(w, body, http.StatusOK)
 }
@@ -437,7 +437,7 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) {
 
 	resp := &streamResponse{
 		Args:    r.URL.Query(),
-		Headers: r.Header,
+		Headers: getRequestHeaders(r),
 		Origin:  getOrigin(r),
 		URL:     getURL(r).String(),
 	}
@@ -650,7 +650,7 @@ func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) {
 	// pulled into a little helper?
 	resp := &getResponse{
 		Args:    r.URL.Query(),
-		Headers: r.Header,
+		Headers: getRequestHeaders(r),
 		Origin:  getOrigin(r),
 		URL:     getURL(r).String(),
 	}
diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go
index 60868f7..2b2decf 100644
--- a/httpbin/handlers_test.go
+++ b/httpbin/handlers_test.go
@@ -287,6 +287,7 @@ func TestUserAgent(t *testing.T) {
 
 func TestHeaders(t *testing.T) {
 	r, _ := http.NewRequest("GET", "/headers", nil)
+	r.Host = "test-host"
 	r.Header.Set("User-Agent", "test")
 	r.Header.Set("Foo-Header", "foo")
 	r.Header.Add("Bar-Header", "bar1")
@@ -303,13 +304,20 @@ func TestHeaders(t *testing.T) {
 		t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err)
 	}
 
+	// Host header requires special treatment, because its an attribute of the
+	// http.Request struct itself, not part of its headers map
+	host := resp.Headers[http.CanonicalHeaderKey("Host")]
+	if host == nil || host[0] != "test-host" {
+		t.Fatalf("expected Host header \"test-host\", got %#v", host)
+	}
+
 	for k, expectedValues := range r.Header {
 		values, ok := resp.Headers[http.CanonicalHeaderKey(k)]
 		if !ok {
 			t.Fatalf("expected header %#v in response", k)
 		}
 		if !reflect.DeepEqual(expectedValues, values) {
-			t.Fatalf("header value mismatch: %#v != %#v", values, expectedValues)
+			t.Fatalf("header %s value mismatch: %#v != %#v", k, values, expectedValues)
 		}
 	}
 }
diff --git a/httpbin/helpers.go b/httpbin/helpers.go
index 6a5341f..ba65116 100644
--- a/httpbin/helpers.go
+++ b/httpbin/helpers.go
@@ -16,6 +16,17 @@ import (
 	"time"
 )
 
+// requestHeaders takes in incoming request and returns an http.Header map
+// suitable for inclusion in our response data structures.
+//
+// This is necessary to ensure that the incoming Host header is included,
+// because golang only exposes that header on the http.Request struct itself.
+func getRequestHeaders(r *http.Request) http.Header {
+	h := r.Header
+	h.Set("Host", r.Host)
+	return h
+}
+
 func getOrigin(r *http.Request) string {
 	origin := r.Header.Get("X-Forwarded-For")
 	if origin == "" {
-- 
GitLab