diff --git a/.travis.yml b/.travis.yml index 822272c802b84dc4d6072f0067eb3d5429ca460c..e79e6e4f858b18c0c3273a15df422bc10ba638a4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,6 @@ go: - '1.9' - '1.10' - '1.11' - - 'tip' script: - make lint - make test diff --git a/README.md b/README.md index ab2c09a03d3d18ad81777f49eee9f3fdd78b51e9..8fcf1a3c9ba0a723e20efc16273f99b244dab9fa 100644 --- a/README.md +++ b/README.md @@ -15,11 +15,11 @@ variables: ``` $ go-httpbin -help -Usage of ./dist/go-httpbin: +Usage of go-httpbin: + -max-body-size int + Maximum size of request or response, in bytes (default 1048576) -max-duration duration Maximum duration a response may take (default 10s) - -max-memory int - Maximum size of request or response, in bytes (default 1048576) -port int Port to listen on (default 8080) ``` @@ -47,8 +47,8 @@ import ( ) func TestSlowResponse(t *testing.T) { - handler := httpbin.NewHTTPBin().Handler() - srv := httptest.NewServer(handler) + svc := httpbin.New() + srv := httptest.NewServer(svc.Handler()) defer srv.Close() client := http.Client{ diff --git a/httpbin/helpers.go b/httpbin/helpers.go index 5571c224b26da94078e8996693262f013872283f..ec44cc9717668444e202899df49758eef427fe5f 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -27,16 +27,6 @@ func getRequestHeaders(r *http.Request) http.Header { return h } -// Copies all headers from src to dst -// From https://golang.org/src/net/http/httputil/reverseproxy.go#L109 -func copyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) - } - } -} - func getOrigin(r *http.Request) string { origin := r.Header.Get("X-Forwarded-For") if origin == "" { diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go index e42ee2bfda6f3c3d30e2fd5b40f263545227c104..d1d28e3e8dde76cb13e95271a421b5ec170d1e15 100644 --- a/httpbin/httpbin.go +++ b/httpbin/httpbin.go @@ -171,10 +171,12 @@ func (h *HTTPBin) Handler() http.Handler { var handler http.Handler handler = mux handler = limitRequestSize(h.MaxBodySize, handler) - handler = metaRequests(handler) + handler = preflight(handler) + handler = autohead(handler) if h.Observer != nil { handler = observe(h.Observer, handler) } + return handler } diff --git a/httpbin/middleware.go b/httpbin/middleware.go index c6041307e039a49e4f04e4d95c5ad43901725915..6576610a1b4d90af590bd42a62df6e6c32ff245c 100644 --- a/httpbin/middleware.go +++ b/httpbin/middleware.go @@ -4,11 +4,10 @@ import ( "fmt" "log" "net/http" - "net/http/httptest" "time" ) -func metaRequests(h http.Handler) http.Handler { +func preflight(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") if origin == "" { @@ -18,24 +17,17 @@ func metaRequests(h http.Handler) http.Handler { respHeader.Set("Access-Control-Allow-Origin", origin) respHeader.Set("Access-Control-Allow-Credentials", "true") - switch r.Method { - case "OPTIONS": + if r.Method == "OPTIONS" { w.Header().Set("Access-Control-Allow-Methods", "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS") w.Header().Set("Access-Control-Max-Age", "3600") if r.Header.Get("Access-Control-Request-Headers") != "" { w.Header().Set("Access-Control-Allow-Headers", r.Header.Get("Access-Control-Request-Headers")) } w.WriteHeader(200) - case "HEAD": - rwRec := httptest.NewRecorder() - r.Method = "GET" - h.ServeHTTP(rwRec, r) - - copyHeader(w.Header(), rwRec.Header()) - w.WriteHeader(rwRec.Code) - default: - h.ServeHTTP(w, r) + return } + + h.ServeHTTP(w, r) }) } @@ -43,6 +35,10 @@ func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc { methodMap := make(map[string]struct{}, len(methods)) for _, m := range methods { methodMap[m] = struct{}{} + // GET implies support for HEAD + if m == "GET" { + methodMap["HEAD"] = struct{}{} + } } return func(w http.ResponseWriter, r *http.Request) { if _, ok := methodMap[r.Method]; !ok { @@ -62,6 +58,26 @@ func limitRequestSize(maxSize int64, h http.Handler) http.Handler { }) } +// headResponseWriter implements http.ResponseWriter in order to discard the +// body of the response +type headResponseWriter struct { + http.ResponseWriter +} + +func (hw *headResponseWriter) Write(b []byte) (int, error) { + return 0, nil +} + +// autohead automatically discards the body of responses to HEAD requests +func autohead(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w = &headResponseWriter{w} + } + h.ServeHTTP(w, r) + }) +} + // metaResponseWriter implements http.ResponseWriter and http.Flusher in order // to record a response's status code and body size for logging purposes. type metaResponseWriter struct {