diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index 6d3ab68839d50fe1e4d1bdc2e35c0d8bbe6edb3d..36919ceb70093ed7acbb7db7b98c0cdc4ce2967a 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -735,130 +735,72 @@ func TestResponseHeaders__OverrideContentType(t *testing.T) { assertContentType(t, w, contentType) } -func TestRedirect__OK(t *testing.T) { +func TestRedirects__OK(t *testing.T) { var tests = []struct { - relative bool - n int - location string + requestURL string + expectedLocation string }{ - {true, 1, "/get"}, - {true, 2, "/relative-redirect/1"}, - {true, 100, "/relative-redirect/99"}, + {"/redirect/1", "/get"}, + {"/redirect/2", "/relative-redirect/1"}, + {"/redirect/100", "/relative-redirect/99"}, - {false, 1, "http://host/get"}, - {false, 2, "http://host/absolute-redirect/1"}, - {false, 100, "http://host/absolute-redirect/99"}, - } - - for _, test := range tests { - u := fmt.Sprintf("/redirect/%d", test.n) - if !test.relative { - u = fmt.Sprintf("%s?absolute=true", u) - } - r, _ := http.NewRequest("GET", u, nil) - r.Host = "host" - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - - assertStatusCode(t, w, http.StatusFound) - assertHeader(t, w, "Location", test.location) - } -} - -func TestRedirect__Errors(t *testing.T) { - var tests = []struct { - relative bool - given interface{} - status int - }{ - {true, 3.14, http.StatusBadRequest}, - {true, -1, http.StatusBadRequest}, - {true, "", http.StatusBadRequest}, - {true, "foo", http.StatusBadRequest}, - {true, "10/bar", http.StatusNotFound}, - - {false, 3.14, http.StatusBadRequest}, - {false, -1, http.StatusBadRequest}, - {false, "", http.StatusBadRequest}, - {false, "foo", http.StatusBadRequest}, - {false, "10/bar", http.StatusNotFound}, - } + {"/redirect/1?absolute=true", "http://host/get"}, + {"/redirect/2?absolute=TRUE", "http://host/absolute-redirect/1"}, + {"/redirect/100?absolute=True", "http://host/absolute-redirect/99"}, - for _, test := range tests { - u := fmt.Sprintf("/redirect/%v", test.given) - if !test.relative { - u = fmt.Sprintf("%s?absolute=true", u) - } - r, _ := http.NewRequest("GET", u, nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) + {"/redirect/100?absolute=t", "/relative-redirect/99"}, + {"/redirect/100?absolute=1", "/relative-redirect/99"}, + {"/redirect/100?absolute=yes", "/relative-redirect/99"}, - assertStatusCode(t, w, test.status) - } -} + {"/relative-redirect/1", "/get"}, + {"/relative-redirect/2", "/relative-redirect/1"}, + {"/relative-redirect/100", "/relative-redirect/99"}, -func TestAbsoluteAndRelativeRedirects__OK(t *testing.T) { - var tests = []struct { - relative bool - n int - location string - }{ - {true, 1, "/get"}, - {true, 2, "/relative-redirect/1"}, - {true, 100, "/relative-redirect/99"}, - - {false, 1, "http://host/get"}, - {false, 2, "http://host/absolute-redirect/1"}, - {false, 100, "http://host/absolute-redirect/99"}, + {"/absolute-redirect/1", "http://host/get"}, + {"/absolute-redirect/2", "http://host/absolute-redirect/1"}, + {"/absolute-redirect/100", "http://host/absolute-redirect/99"}, } for _, test := range tests { - var urlTemplate string - if test.relative { - urlTemplate = "/relative-redirect/%d" - } else { - urlTemplate = "/absolute-redirect/%d" - } - r, _ := http.NewRequest("GET", fmt.Sprintf(urlTemplate, test.n), nil) + r, _ := http.NewRequest("GET", test.requestURL, nil) r.Host = "host" w := httptest.NewRecorder() handler.ServeHTTP(w, r) assertStatusCode(t, w, http.StatusFound) - assertHeader(t, w, "Location", test.location) + assertHeader(t, w, "Location", test.expectedLocation) } } -func TestAbsoluteAndRelativeRedirects__Errors(t *testing.T) { +func TestRedirects__Errors(t *testing.T) { var tests = []struct { - relative bool - given interface{} - status int + requestURL string + expectedStatus int }{ - {true, 3.14, http.StatusBadRequest}, - {true, -1, http.StatusBadRequest}, - {true, "", http.StatusBadRequest}, - {true, "foo", http.StatusBadRequest}, - {true, "10/bar", http.StatusNotFound}, + {"/redirect", http.StatusNotFound}, + {"/redirect/", http.StatusBadRequest}, + {"/redirect/3.14", http.StatusBadRequest}, + {"/redirect/foo", http.StatusBadRequest}, + {"/redirect/10/foo", http.StatusNotFound}, - {false, 3.14, http.StatusBadRequest}, - {false, -1, http.StatusBadRequest}, - {false, "", http.StatusBadRequest}, - {false, "foo", http.StatusBadRequest}, - {false, "10/bar", http.StatusNotFound}, + {"/relative-redirect", http.StatusNotFound}, + {"/relative-redirect/", http.StatusBadRequest}, + {"/relative-redirect/3.14", http.StatusBadRequest}, + {"/relative-redirect/foo", http.StatusBadRequest}, + {"/relative-redirect/10/foo", http.StatusNotFound}, + + {"/absolute-redirect", http.StatusNotFound}, + {"/absolute-redirect/", http.StatusBadRequest}, + {"/absolute-redirect/3.14", http.StatusBadRequest}, + {"/absolute-redirect/foo", http.StatusBadRequest}, + {"/absolute-redirect/10/foo", http.StatusNotFound}, } for _, test := range tests { - var urlTemplate string - if test.relative { - urlTemplate = "/relative-redirect/%v" - } else { - urlTemplate = "/absolute-redirect/%v" - } - r, _ := http.NewRequest("GET", fmt.Sprintf(urlTemplate, test.given), nil) + r, _ := http.NewRequest("GET", test.requestURL, nil) w := httptest.NewRecorder() handler.ServeHTTP(w, r) - assertStatusCode(t, w, test.status) + assertStatusCode(t, w, test.expectedStatus) } } diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go index 35fca2f46dc2e1ca129a816b9123b51dfc44a43e..cac976f1e19b73b40aa80e62daffd9a3d0246cd5 100644 --- a/httpbin/httpbin.go +++ b/httpbin/httpbin.go @@ -64,13 +64,21 @@ func (h *HTTPBin) Handler() http.Handler { mux.HandleFunc("/ip", h.IP) mux.HandleFunc("/user-agent", h.UserAgent) mux.HandleFunc("/headers", h.Headers) - mux.HandleFunc("/status/", h.Status) mux.HandleFunc("/response-headers", h.ResponseHeaders) + mux.HandleFunc("/status/", h.Status) mux.HandleFunc("/redirect/", h.Redirect) mux.HandleFunc("/relative-redirect/", h.RelativeRedirect) mux.HandleFunc("/absolute-redirect/", h.AbsoluteRedirect) + // Make sure our ServeMux doesn't "helpfully" redirect these invalid + // endpoints by adding a trailing slash. See the ServeMux docs for more + // info: https://golang.org/pkg/net/http/#ServeMux + mux.HandleFunc("/status", http.NotFound) + mux.HandleFunc("/redirect", http.NotFound) + mux.HandleFunc("/relative-redirect", http.NotFound) + mux.HandleFunc("/absolute-redirect", http.NotFound) + return logger(cors(mux)) }