From c5cb2f4802fac63859778c72989fe88dce35fe35 Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Tue, 15 Jan 2019 17:45:21 -0800
Subject: [PATCH] Misc updates (#12)

* Stop testing on tip, it takes too long

* Refactor automatic handling of HEAD requests

* Update README
---
 .travis.yml           |  1 -
 README.md             | 10 +++++-----
 httpbin/helpers.go    | 10 ----------
 httpbin/httpbin.go    |  4 +++-
 httpbin/middleware.go | 42 +++++++++++++++++++++++++++++-------------
 5 files changed, 37 insertions(+), 30 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 822272c..e79e6e4 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -4,7 +4,6 @@ go:
   - '1.9'
   - '1.10'
   - '1.11'
-  - 'tip'
 script:
   - make lint
   - make test
diff --git a/README.md b/README.md
index ab2c09a..8fcf1a3 100644
--- a/README.md
+++ b/README.md
@@ -15,11 +15,11 @@ variables:
 
 ```
 $ go-httpbin -help
-Usage of ./dist/go-httpbin:
+Usage of go-httpbin:
+  -max-body-size int
+        Maximum size of request or response, in bytes (default 1048576)
   -max-duration duration
         Maximum duration a response may take (default 10s)
-  -max-memory int
-        Maximum size of request or response, in bytes (default 1048576)
   -port int
         Port to listen on (default 8080)
 ```
@@ -47,8 +47,8 @@ import (
 )
 
 func TestSlowResponse(t *testing.T) {
-    handler := httpbin.NewHTTPBin().Handler()
-    srv := httptest.NewServer(handler)
+    svc := httpbin.New()
+    srv := httptest.NewServer(svc.Handler())
     defer srv.Close()
 
     client := http.Client{
diff --git a/httpbin/helpers.go b/httpbin/helpers.go
index 5571c22..ec44cc9 100644
--- a/httpbin/helpers.go
+++ b/httpbin/helpers.go
@@ -27,16 +27,6 @@ func getRequestHeaders(r *http.Request) http.Header {
 	return h
 }
 
-// Copies all headers from src to dst
-// From https://golang.org/src/net/http/httputil/reverseproxy.go#L109
-func copyHeader(dst, src http.Header) {
-	for k, vv := range src {
-		for _, v := range vv {
-			dst.Add(k, v)
-		}
-	}
-}
-
 func getOrigin(r *http.Request) string {
 	origin := r.Header.Get("X-Forwarded-For")
 	if origin == "" {
diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go
index e42ee2b..d1d28e3 100644
--- a/httpbin/httpbin.go
+++ b/httpbin/httpbin.go
@@ -171,10 +171,12 @@ func (h *HTTPBin) Handler() http.Handler {
 	var handler http.Handler
 	handler = mux
 	handler = limitRequestSize(h.MaxBodySize, handler)
-	handler = metaRequests(handler)
+	handler = preflight(handler)
+	handler = autohead(handler)
 	if h.Observer != nil {
 		handler = observe(h.Observer, handler)
 	}
+
 	return handler
 }
 
diff --git a/httpbin/middleware.go b/httpbin/middleware.go
index c604130..6576610 100644
--- a/httpbin/middleware.go
+++ b/httpbin/middleware.go
@@ -4,11 +4,10 @@ import (
 	"fmt"
 	"log"
 	"net/http"
-	"net/http/httptest"
 	"time"
 )
 
-func metaRequests(h http.Handler) http.Handler {
+func preflight(h http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		origin := r.Header.Get("Origin")
 		if origin == "" {
@@ -18,24 +17,17 @@ func metaRequests(h http.Handler) http.Handler {
 		respHeader.Set("Access-Control-Allow-Origin", origin)
 		respHeader.Set("Access-Control-Allow-Credentials", "true")
 
-		switch r.Method {
-		case "OPTIONS":
+		if r.Method == "OPTIONS" {
 			w.Header().Set("Access-Control-Allow-Methods", "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS")
 			w.Header().Set("Access-Control-Max-Age", "3600")
 			if r.Header.Get("Access-Control-Request-Headers") != "" {
 				w.Header().Set("Access-Control-Allow-Headers", r.Header.Get("Access-Control-Request-Headers"))
 			}
 			w.WriteHeader(200)
-		case "HEAD":
-			rwRec := httptest.NewRecorder()
-			r.Method = "GET"
-			h.ServeHTTP(rwRec, r)
-
-			copyHeader(w.Header(), rwRec.Header())
-			w.WriteHeader(rwRec.Code)
-		default:
-			h.ServeHTTP(w, r)
+			return
 		}
+
+		h.ServeHTTP(w, r)
 	})
 }
 
@@ -43,6 +35,10 @@ func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc {
 	methodMap := make(map[string]struct{}, len(methods))
 	for _, m := range methods {
 		methodMap[m] = struct{}{}
+		// GET implies support for HEAD
+		if m == "GET" {
+			methodMap["HEAD"] = struct{}{}
+		}
 	}
 	return func(w http.ResponseWriter, r *http.Request) {
 		if _, ok := methodMap[r.Method]; !ok {
@@ -62,6 +58,26 @@ func limitRequestSize(maxSize int64, h http.Handler) http.Handler {
 	})
 }
 
+// headResponseWriter implements http.ResponseWriter in order to discard the
+// body of the response
+type headResponseWriter struct {
+	http.ResponseWriter
+}
+
+func (hw *headResponseWriter) Write(b []byte) (int, error) {
+	return 0, nil
+}
+
+// autohead automatically discards the body of responses to HEAD requests
+func autohead(h http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.Method == "HEAD" {
+			w = &headResponseWriter{w}
+		}
+		h.ServeHTTP(w, r)
+	})
+}
+
 // metaResponseWriter implements http.ResponseWriter and http.Flusher in order
 // to record a response's status code and body size for logging purposes.
 type metaResponseWriter struct {
-- 
GitLab