diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index c8313aed2f9eccf0cc26c7517088d9309de199f2..ab70f5c6fe4d1549780886e1c887cafedf36cb09 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -192,6 +192,27 @@ func TestGet(t *testing.T) { } } +func TestHEAD(t *testing.T) { + r, _ := http.NewRequest("HEAD", "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assertStatusCode(t, w, 200) + assertBodyEquals(t, w, "") + + contentLengthStr := w.HeaderMap.Get("Content-Length") + if contentLengthStr == "" { + t.Fatalf("missing Content-Length header in response") + } + contentLength, err := strconv.Atoi(contentLengthStr) + if err != nil { + t.Fatalf("error converting Content-Lengh %v to integer: %s", contentLengthStr, err) + } + if contentLength <= 0 { + t.Fatalf("Content-Lengh %v should be greater than 0", contentLengthStr) + } +} + func TestCORS(t *testing.T) { t.Run("CORS/no_request_origin", func(t *testing.T) { r, _ := http.NewRequest("GET", "/get", nil) @@ -213,13 +234,15 @@ func TestCORS(t *testing.T) { w := httptest.NewRecorder() handler.ServeHTTP(w, r) + assertStatusCode(t, w, 200) + var headerTests = []struct { key string expected string }{ {"Access-Control-Allow-Origin", "*"}, {"Access-Control-Allow-Credentials", "true"}, - {"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"}, + {"Access-Control-Allow-Methods", "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS"}, {"Access-Control-Max-Age", "3600"}, {"Access-Control-Allow-Headers", ""}, } @@ -234,6 +257,8 @@ func TestCORS(t *testing.T) { w := httptest.NewRecorder() handler.ServeHTTP(w, r) + assertStatusCode(t, w, 200) + var headerTests = []struct { key string expected string diff --git a/httpbin/helpers.go b/httpbin/helpers.go index ec44cc9717668444e202899df49758eef427fe5f..5571c224b26da94078e8996693262f013872283f 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -27,6 +27,16 @@ func getRequestHeaders(r *http.Request) http.Header { return h } +// Copies all headers from src to dst +// From https://golang.org/src/net/http/httputil/reverseproxy.go#L109 +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + func getOrigin(r *http.Request) string { origin := r.Header.Get("X-Forwarded-For") if origin == "" { diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go index c33f1e57067563fde761df7ffb2e303a1b4336ae..ad12c29b4f4a40000ee30fadf8078462c356809a 100644 --- a/httpbin/httpbin.go +++ b/httpbin/httpbin.go @@ -174,8 +174,8 @@ func (h *HTTPBin) Handler() http.Handler { var handler http.Handler handler = mux handler = limitRequestSize(h.options.MaxMemory, handler) + handler = metaRequests(handler) handler = logger(handler) - handler = cors(handler) return handler } diff --git a/httpbin/middleware.go b/httpbin/middleware.go index dacfce593e9d231972dafd6ae1a95d7c0ebce773..60033fa57d675bc7f2eb125fbe15819f70cc43b7 100644 --- a/httpbin/middleware.go +++ b/httpbin/middleware.go @@ -4,10 +4,11 @@ import ( "fmt" "log" "net/http" + "net/http/httptest" "time" ) -func cors(h http.Handler) http.Handler { +func metaRequests(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") if origin == "" { @@ -17,14 +18,24 @@ func cors(h http.Handler) http.Handler { respHeader.Set("Access-Control-Allow-Origin", origin) respHeader.Set("Access-Control-Allow-Credentials", "true") - if r.Method == "OPTIONS" { - respHeader.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS") - respHeader.Set("Access-Control-Max-Age", "3600") + switch r.Method { + case "OPTIONS": + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS") + w.Header().Set("Access-Control-Max-Age", "3600") if r.Header.Get("Access-Control-Request-Headers") != "" { - respHeader.Set("Access-Control-Allow-Headers", r.Header.Get("Access-Control-Request-Headers")) + w.Header().Set("Access-Control-Allow-Headers", r.Header.Get("Access-Control-Request-Headers")) } + w.WriteHeader(200) + case "HEAD": + rwRec := httptest.NewRecorder() + r.Method = "GET" + h.ServeHTTP(rwRec, r) + + copyHeader(w.Header(), rwRec.Header()) + w.WriteHeader(rwRec.Code) + default: + h.ServeHTTP(w, r) } - h.ServeHTTP(w, r) }) } @@ -92,10 +103,11 @@ func (mw *metaResponseWriter) Size() int { func logger(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqMethod, reqURI := r.Method, r.URL.RequestURI() mw := &metaResponseWriter{w: w} t := time.Now() h.ServeHTTP(mw, r) duration := time.Now().Sub(t) - log.Printf("status=%d method=%s uri=%q size=%d duration=%s", mw.Status(), r.Method, r.URL.RequestURI(), mw.Size(), duration) + log.Printf("status=%d method=%s uri=%q size=%d duration=%s", mw.Status(), reqMethod, reqURI, mw.Size(), duration) }) }