diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index 14b71ad042d72d925684c43618a87b3ef3a09376..2bf8e050f47aa0e995c657853e959594bd49455a 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -216,24 +216,47 @@ func TestGet(t *testing.T) { } func TestHEAD(t *testing.T) { - r, _ := http.NewRequest("HEAD", "/", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) + testCases := []struct { + verb string + path string + wantCode int + }{ + {"HEAD", "/", http.StatusOK}, + {"HEAD", "/get", http.StatusOK}, + {"HEAD", "/head", http.StatusOK}, + {"HEAD", "/post", http.StatusMethodNotAllowed}, + {"GET", "/head", http.StatusMethodNotAllowed}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("%s %s", tc.verb, tc.path), func(t *testing.T) { + r, _ := http.NewRequest(tc.verb, tc.path, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) - assertStatusCode(t, w, 200) - assertBodyEquals(t, w, "") + assertStatusCode(t, w, tc.wantCode) - contentLengthStr := w.Header().Get("Content-Length") - if contentLengthStr == "" { - t.Fatalf("missing Content-Length header in response") - } - contentLength, err := strconv.Atoi(contentLengthStr) - if err != nil { - t.Fatalf("error converting Content-Lengh %v to integer: %s", contentLengthStr, err) - } - if contentLength <= 0 { - t.Fatalf("Content-Lengh %v should be greater than 0", contentLengthStr) + // we only do further validation when we get an OK response + if tc.wantCode != http.StatusOK { + return + } + + assertStatusCode(t, w, http.StatusOK) + assertBodyEquals(t, w, "") + + contentLengthStr := w.Header().Get("Content-Length") + if contentLengthStr == "" { + t.Fatalf("missing Content-Length header in response") + } + contentLength, err := strconv.Atoi(contentLengthStr) + if err != nil { + t.Fatalf("error converting Content-Lengh %v to integer: %s", contentLengthStr, err) + } + if contentLength <= 0 { + t.Fatalf("Content-Lengh %v should be greater than 0", contentLengthStr) + } + }) } + } func TestCORS(t *testing.T) { diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go index a410878c7961e97e845276cae3224a067db3f1d0..ee6c79dcf698e49f6dc6b503ea4c80cbc4a0ca0e 100644 --- a/httpbin/httpbin.go +++ b/httpbin/httpbin.go @@ -124,11 +124,12 @@ func (h *HTTPBin) Handler() http.Handler { mux.HandleFunc("/forms/post", methods(h.FormsPost, "GET")) mux.HandleFunc("/encoding/utf8", methods(h.UTF8, "GET")) + mux.HandleFunc("/delete", methods(h.RequestWithBody, "DELETE")) mux.HandleFunc("/get", methods(h.Get, "GET")) + mux.HandleFunc("/head", methods(h.Get, "HEAD")) + mux.HandleFunc("/patch", methods(h.RequestWithBody, "PATCH")) mux.HandleFunc("/post", methods(h.RequestWithBody, "POST")) mux.HandleFunc("/put", methods(h.RequestWithBody, "PUT")) - mux.HandleFunc("/patch", methods(h.RequestWithBody, "PATCH")) - mux.HandleFunc("/delete", methods(h.RequestWithBody, "DELETE")) mux.HandleFunc("/ip", h.IP) mux.HandleFunc("/user-agent", h.UserAgent)