Skip to content
Snippets Groups Projects
Unverified Commit 793a5046 authored by Will McCutchen's avatar Will McCutchen Committed by GitHub
Browse files

Merge pull request #5 from na--/head-and-options-methods-2

Add support for HEAD and fix the OPTIONS HTTP method
parents 5f1d7077 0379c252
No related branches found
No related tags found
No related merge requests found
......@@ -192,6 +192,27 @@ func TestGet(t *testing.T) {
}
}
func TestHEAD(t *testing.T) {
r, _ := http.NewRequest("HEAD", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, 200)
assertBodyEquals(t, w, "")
contentLengthStr := w.HeaderMap.Get("Content-Length")
if contentLengthStr == "" {
t.Fatalf("missing Content-Length header in response")
}
contentLength, err := strconv.Atoi(contentLengthStr)
if err != nil {
t.Fatalf("error converting Content-Lengh %v to integer: %s", contentLengthStr, err)
}
if contentLength <= 0 {
t.Fatalf("Content-Lengh %v should be greater than 0", contentLengthStr)
}
}
func TestCORS(t *testing.T) {
t.Run("CORS/no_request_origin", func(t *testing.T) {
r, _ := http.NewRequest("GET", "/get", nil)
......@@ -213,13 +234,15 @@ func TestCORS(t *testing.T) {
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, 200)
var headerTests = []struct {
key string
expected string
}{
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Credentials", "true"},
{"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"},
{"Access-Control-Allow-Methods", "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS"},
{"Access-Control-Max-Age", "3600"},
{"Access-Control-Allow-Headers", ""},
}
......@@ -234,6 +257,8 @@ func TestCORS(t *testing.T) {
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, 200)
var headerTests = []struct {
key string
expected string
......
......@@ -27,6 +27,16 @@ func getRequestHeaders(r *http.Request) http.Header {
return h
}
// Copies all headers from src to dst
// From https://golang.org/src/net/http/httputil/reverseproxy.go#L109
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func getOrigin(r *http.Request) string {
origin := r.Header.Get("X-Forwarded-For")
if origin == "" {
......
......@@ -174,8 +174,8 @@ func (h *HTTPBin) Handler() http.Handler {
var handler http.Handler
handler = mux
handler = limitRequestSize(h.options.MaxMemory, handler)
handler = metaRequests(handler)
handler = logger(handler)
handler = cors(handler)
return handler
}
......
......@@ -4,10 +4,11 @@ import (
"fmt"
"log"
"net/http"
"net/http/httptest"
"time"
)
func cors(h http.Handler) http.Handler {
func metaRequests(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin == "" {
......@@ -17,14 +18,24 @@ func cors(h http.Handler) http.Handler {
respHeader.Set("Access-Control-Allow-Origin", origin)
respHeader.Set("Access-Control-Allow-Credentials", "true")
if r.Method == "OPTIONS" {
respHeader.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
respHeader.Set("Access-Control-Max-Age", "3600")
switch r.Method {
case "OPTIONS":
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, HEAD, PUT, DELETE, PATCH, OPTIONS")
w.Header().Set("Access-Control-Max-Age", "3600")
if r.Header.Get("Access-Control-Request-Headers") != "" {
respHeader.Set("Access-Control-Allow-Headers", r.Header.Get("Access-Control-Request-Headers"))
}
w.Header().Set("Access-Control-Allow-Headers", r.Header.Get("Access-Control-Request-Headers"))
}
w.WriteHeader(200)
case "HEAD":
rwRec := httptest.NewRecorder()
r.Method = "GET"
h.ServeHTTP(rwRec, r)
copyHeader(w.Header(), rwRec.Header())
w.WriteHeader(rwRec.Code)
default:
h.ServeHTTP(w, r)
}
})
}
......@@ -92,10 +103,11 @@ func (mw *metaResponseWriter) Size() int {
func logger(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqMethod, reqURI := r.Method, r.URL.RequestURI()
mw := &metaResponseWriter{w: w}
t := time.Now()
h.ServeHTTP(mw, r)
duration := time.Now().Sub(t)
log.Printf("status=%d method=%s uri=%q size=%d duration=%s", mw.Status(), r.Method, r.URL.RequestURI(), mw.Size(), duration)
log.Printf("status=%d method=%s uri=%q size=%d duration=%s", mw.Status(), reqMethod, reqURI, mw.Size(), duration)
})
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment