From 6ec6c2fb061007bd7f1215c480976649a8518ae8 Mon Sep 17 00:00:00 2001 From: Will McCutchen <will@mccutch.org> Date: Sun, 25 Jun 2017 22:37:39 -0700 Subject: [PATCH] Ensure Host header is included in responses --- httpbin/handlers.go | 14 +++++++------- httpbin/handlers_test.go | 10 +++++++++- httpbin/helpers.go | 11 +++++++++++ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/httpbin/handlers.go b/httpbin/handlers.go index cc27870..3896c28 100644 --- a/httpbin/handlers.go +++ b/httpbin/handlers.go @@ -51,7 +51,7 @@ func (h *HTTPBin) UTF8(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) { resp := &getResponse{ Args: r.URL.Query(), - Headers: r.Header, + Headers: getRequestHeaders(r), Origin: getOrigin(r), URL: getURL(r).String(), } @@ -63,7 +63,7 @@ func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) { resp := &bodyResponse{ Args: r.URL.Query(), - Headers: r.Header, + Headers: getRequestHeaders(r), Origin: getOrigin(r), URL: getURL(r).String(), } @@ -81,7 +81,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) { // Gzip returns a gzipped response func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) { resp := &gzipResponse{ - Headers: r.Header, + Headers: getRequestHeaders(r), Origin: getOrigin(r), Gzipped: true, } @@ -101,7 +101,7 @@ func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) { // Deflate returns a gzipped response func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) { resp := &deflateResponse{ - Headers: r.Header, + Headers: getRequestHeaders(r), Origin: getOrigin(r), Deflated: true, } @@ -137,7 +137,7 @@ func (h *HTTPBin) UserAgent(w http.ResponseWriter, r *http.Request) { // Headers echoes the incoming request headers func (h *HTTPBin) Headers(w http.ResponseWriter, r *http.Request) { body, _ := json.Marshal(&headersResponse{ - Headers: r.Header, + Headers: getRequestHeaders(r), }) writeJSON(w, body, http.StatusOK) } @@ -437,7 +437,7 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) { resp := &streamResponse{ Args: r.URL.Query(), - Headers: r.Header, + Headers: getRequestHeaders(r), Origin: getOrigin(r), URL: getURL(r).String(), } @@ -650,7 +650,7 @@ func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) { // pulled into a little helper? resp := &getResponse{ Args: r.URL.Query(), - Headers: r.Header, + Headers: getRequestHeaders(r), Origin: getOrigin(r), URL: getURL(r).String(), } diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index 60868f7..2b2decf 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -287,6 +287,7 @@ func TestUserAgent(t *testing.T) { func TestHeaders(t *testing.T) { r, _ := http.NewRequest("GET", "/headers", nil) + r.Host = "test-host" r.Header.Set("User-Agent", "test") r.Header.Set("Foo-Header", "foo") r.Header.Add("Bar-Header", "bar1") @@ -303,13 +304,20 @@ func TestHeaders(t *testing.T) { t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) } + // Host header requires special treatment, because its an attribute of the + // http.Request struct itself, not part of its headers map + host := resp.Headers[http.CanonicalHeaderKey("Host")] + if host == nil || host[0] != "test-host" { + t.Fatalf("expected Host header \"test-host\", got %#v", host) + } + for k, expectedValues := range r.Header { values, ok := resp.Headers[http.CanonicalHeaderKey(k)] if !ok { t.Fatalf("expected header %#v in response", k) } if !reflect.DeepEqual(expectedValues, values) { - t.Fatalf("header value mismatch: %#v != %#v", values, expectedValues) + t.Fatalf("header %s value mismatch: %#v != %#v", k, values, expectedValues) } } } diff --git a/httpbin/helpers.go b/httpbin/helpers.go index 6a5341f..ba65116 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -16,6 +16,17 @@ import ( "time" ) +// requestHeaders takes in incoming request and returns an http.Header map +// suitable for inclusion in our response data structures. +// +// This is necessary to ensure that the incoming Host header is included, +// because golang only exposes that header on the http.Request struct itself. +func getRequestHeaders(r *http.Request) http.Header { + h := r.Header + h.Set("Host", r.Host) + return h +} + func getOrigin(r *http.Request) string { origin := r.Header.Get("X-Forwarded-For") if origin == "" { -- GitLab