From c4135f952f007359532b09d0b086fad60f806c77 Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Fri, 14 Oct 2016 17:16:49 -0700
Subject: [PATCH] Simpler, unified redirect tests

---
 httpbin/handlers_test.go | 140 ++++++++++++---------------------------
 httpbin/httpbin.go       |  10 ++-
 2 files changed, 50 insertions(+), 100 deletions(-)

diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go
index 6d3ab68..36919ce 100644
--- a/httpbin/handlers_test.go
+++ b/httpbin/handlers_test.go
@@ -735,130 +735,72 @@ func TestResponseHeaders__OverrideContentType(t *testing.T) {
 	assertContentType(t, w, contentType)
 }
 
-func TestRedirect__OK(t *testing.T) {
+func TestRedirects__OK(t *testing.T) {
 	var tests = []struct {
-		relative bool
-		n        int
-		location string
+		requestURL       string
+		expectedLocation string
 	}{
-		{true, 1, "/get"},
-		{true, 2, "/relative-redirect/1"},
-		{true, 100, "/relative-redirect/99"},
+		{"/redirect/1", "/get"},
+		{"/redirect/2", "/relative-redirect/1"},
+		{"/redirect/100", "/relative-redirect/99"},
 
-		{false, 1, "http://host/get"},
-		{false, 2, "http://host/absolute-redirect/1"},
-		{false, 100, "http://host/absolute-redirect/99"},
-	}
-
-	for _, test := range tests {
-		u := fmt.Sprintf("/redirect/%d", test.n)
-		if !test.relative {
-			u = fmt.Sprintf("%s?absolute=true", u)
-		}
-		r, _ := http.NewRequest("GET", u, nil)
-		r.Host = "host"
-		w := httptest.NewRecorder()
-		handler.ServeHTTP(w, r)
-
-		assertStatusCode(t, w, http.StatusFound)
-		assertHeader(t, w, "Location", test.location)
-	}
-}
-
-func TestRedirect__Errors(t *testing.T) {
-	var tests = []struct {
-		relative bool
-		given    interface{}
-		status   int
-	}{
-		{true, 3.14, http.StatusBadRequest},
-		{true, -1, http.StatusBadRequest},
-		{true, "", http.StatusBadRequest},
-		{true, "foo", http.StatusBadRequest},
-		{true, "10/bar", http.StatusNotFound},
-
-		{false, 3.14, http.StatusBadRequest},
-		{false, -1, http.StatusBadRequest},
-		{false, "", http.StatusBadRequest},
-		{false, "foo", http.StatusBadRequest},
-		{false, "10/bar", http.StatusNotFound},
-	}
+		{"/redirect/1?absolute=true", "http://host/get"},
+		{"/redirect/2?absolute=TRUE", "http://host/absolute-redirect/1"},
+		{"/redirect/100?absolute=True", "http://host/absolute-redirect/99"},
 
-	for _, test := range tests {
-		u := fmt.Sprintf("/redirect/%v", test.given)
-		if !test.relative {
-			u = fmt.Sprintf("%s?absolute=true", u)
-		}
-		r, _ := http.NewRequest("GET", u, nil)
-		w := httptest.NewRecorder()
-		handler.ServeHTTP(w, r)
+		{"/redirect/100?absolute=t", "/relative-redirect/99"},
+		{"/redirect/100?absolute=1", "/relative-redirect/99"},
+		{"/redirect/100?absolute=yes", "/relative-redirect/99"},
 
-		assertStatusCode(t, w, test.status)
-	}
-}
+		{"/relative-redirect/1", "/get"},
+		{"/relative-redirect/2", "/relative-redirect/1"},
+		{"/relative-redirect/100", "/relative-redirect/99"},
 
-func TestAbsoluteAndRelativeRedirects__OK(t *testing.T) {
-	var tests = []struct {
-		relative bool
-		n        int
-		location string
-	}{
-		{true, 1, "/get"},
-		{true, 2, "/relative-redirect/1"},
-		{true, 100, "/relative-redirect/99"},
-
-		{false, 1, "http://host/get"},
-		{false, 2, "http://host/absolute-redirect/1"},
-		{false, 100, "http://host/absolute-redirect/99"},
+		{"/absolute-redirect/1", "http://host/get"},
+		{"/absolute-redirect/2", "http://host/absolute-redirect/1"},
+		{"/absolute-redirect/100", "http://host/absolute-redirect/99"},
 	}
 
 	for _, test := range tests {
-		var urlTemplate string
-		if test.relative {
-			urlTemplate = "/relative-redirect/%d"
-		} else {
-			urlTemplate = "/absolute-redirect/%d"
-		}
-		r, _ := http.NewRequest("GET", fmt.Sprintf(urlTemplate, test.n), nil)
+		r, _ := http.NewRequest("GET", test.requestURL, nil)
 		r.Host = "host"
 		w := httptest.NewRecorder()
 		handler.ServeHTTP(w, r)
 
 		assertStatusCode(t, w, http.StatusFound)
-		assertHeader(t, w, "Location", test.location)
+		assertHeader(t, w, "Location", test.expectedLocation)
 	}
 }
 
-func TestAbsoluteAndRelativeRedirects__Errors(t *testing.T) {
+func TestRedirects__Errors(t *testing.T) {
 	var tests = []struct {
-		relative bool
-		given    interface{}
-		status   int
+		requestURL     string
+		expectedStatus int
 	}{
-		{true, 3.14, http.StatusBadRequest},
-		{true, -1, http.StatusBadRequest},
-		{true, "", http.StatusBadRequest},
-		{true, "foo", http.StatusBadRequest},
-		{true, "10/bar", http.StatusNotFound},
+		{"/redirect", http.StatusNotFound},
+		{"/redirect/", http.StatusBadRequest},
+		{"/redirect/3.14", http.StatusBadRequest},
+		{"/redirect/foo", http.StatusBadRequest},
+		{"/redirect/10/foo", http.StatusNotFound},
 
-		{false, 3.14, http.StatusBadRequest},
-		{false, -1, http.StatusBadRequest},
-		{false, "", http.StatusBadRequest},
-		{false, "foo", http.StatusBadRequest},
-		{false, "10/bar", http.StatusNotFound},
+		{"/relative-redirect", http.StatusNotFound},
+		{"/relative-redirect/", http.StatusBadRequest},
+		{"/relative-redirect/3.14", http.StatusBadRequest},
+		{"/relative-redirect/foo", http.StatusBadRequest},
+		{"/relative-redirect/10/foo", http.StatusNotFound},
+
+		{"/absolute-redirect", http.StatusNotFound},
+		{"/absolute-redirect/", http.StatusBadRequest},
+		{"/absolute-redirect/3.14", http.StatusBadRequest},
+		{"/absolute-redirect/foo", http.StatusBadRequest},
+		{"/absolute-redirect/10/foo", http.StatusNotFound},
 	}
 
 	for _, test := range tests {
-		var urlTemplate string
-		if test.relative {
-			urlTemplate = "/relative-redirect/%v"
-		} else {
-			urlTemplate = "/absolute-redirect/%v"
-		}
-		r, _ := http.NewRequest("GET", fmt.Sprintf(urlTemplate, test.given), nil)
+		r, _ := http.NewRequest("GET", test.requestURL, nil)
 		w := httptest.NewRecorder()
 		handler.ServeHTTP(w, r)
 
-		assertStatusCode(t, w, test.status)
+		assertStatusCode(t, w, test.expectedStatus)
 	}
 }
diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go
index 35fca2f..cac976f 100644
--- a/httpbin/httpbin.go
+++ b/httpbin/httpbin.go
@@ -64,13 +64,21 @@ func (h *HTTPBin) Handler() http.Handler {
 	mux.HandleFunc("/ip", h.IP)
 	mux.HandleFunc("/user-agent", h.UserAgent)
 	mux.HandleFunc("/headers", h.Headers)
-	mux.HandleFunc("/status/", h.Status)
 	mux.HandleFunc("/response-headers", h.ResponseHeaders)
 
+	mux.HandleFunc("/status/", h.Status)
 	mux.HandleFunc("/redirect/", h.Redirect)
 	mux.HandleFunc("/relative-redirect/", h.RelativeRedirect)
 	mux.HandleFunc("/absolute-redirect/", h.AbsoluteRedirect)
 
+	// Make sure our ServeMux doesn't "helpfully" redirect these invalid
+	// endpoints by adding a trailing slash. See the ServeMux docs for more
+	// info: https://golang.org/pkg/net/http/#ServeMux
+	mux.HandleFunc("/status", http.NotFound)
+	mux.HandleFunc("/redirect", http.NotFound)
+	mux.HandleFunc("/relative-redirect", http.NotFound)
+	mux.HandleFunc("/absolute-redirect", http.NotFound)
+
 	return logger(cors(mux))
 }
 
-- 
GitLab