Skip to content
Snippets Groups Projects
Commit 527829ef authored by Will McCutchen's avatar Will McCutchen
Browse files

All /get tests ported and passing

parent 2362728a
Branches
Tags
No related merge requests found
...@@ -77,6 +77,27 @@ func logger(h http.Handler) http.Handler { ...@@ -77,6 +77,27 @@ func logger(h http.Handler) http.Handler {
}) })
} }
func cors(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin == "" {
origin = "*"
}
respHeader := w.Header()
respHeader.Set("Access-Control-Allow-Origin", origin)
respHeader.Set("Access-Control-Allow-Credentials", "true")
if r.Method == "OPTIONS" {
respHeader.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
respHeader.Set("Access-Control-Max-Age", "3600")
if r.Header.Get("Access-Control-Request-Headers") != "" {
respHeader.Set("Access-Control-Allow-Headers", r.Header.Get("Access-Control-Request-Headers"))
}
}
h.ServeHTTP(w, r)
})
}
func get(w http.ResponseWriter, r *http.Request) { func get(w http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
resp := &Resp{ resp := &Resp{
...@@ -86,11 +107,15 @@ func get(w http.ResponseWriter, r *http.Request) { ...@@ -86,11 +107,15 @@ func get(w http.ResponseWriter, r *http.Request) {
writeResponse(w, r, resp) writeResponse(w, r, resp)
} }
func main() { func app() http.Handler {
h := http.NewServeMux() h := http.NewServeMux()
h.HandleFunc("/get", get) h.HandleFunc("/get", get)
return logger(cors(h))
}
func main() {
a := app()
log.Printf("listening on 9999") log.Printf("listening on 9999")
err := http.ListenAndServe(":9999", logger(h)) err := http.ListenAndServe(":9999", a)
log.Fatal(err) log.Fatal(err)
} }
...@@ -4,15 +4,16 @@ import ( ...@@ -4,15 +4,16 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
) )
func TestGet(t *testing.T) { func TestGet__Basic(t *testing.T) {
r, _ := http.NewRequest("GET", "/get", nil) r, _ := http.NewRequest("GET", "/get", nil)
r.Host = "localhost" r.Host = "localhost"
r.Header.Set("User-Agent", "test") r.Header.Set("User-Agent", "test")
w := httptest.NewRecorder() w := httptest.NewRecorder()
get(w, r) app().ServeHTTP(w, r)
if w.Code != 200 { if w.Code != 200 {
t.Fatalf("expected status code 200, got %d", w.Code) t.Fatalf("expected status code 200, got %d", w.Code)
...@@ -47,3 +48,93 @@ func TestGet(t *testing.T) { ...@@ -47,3 +48,93 @@ func TestGet(t *testing.T) {
} }
} }
} }
func TestGet__CORSHeadersWithoutRequestOrigin(t *testing.T) {
r, _ := http.NewRequest("GET", "/get", nil)
w := httptest.NewRecorder()
app().ServeHTTP(w, r)
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
t.Fatalf("expected Access-Control-Allow-Origin=*, got %#v", w.Header().Get("Access-Control-Allow-Origin"))
}
}
func TestGet__CORSHeadersWithRequestOrigin(t *testing.T) {
r, _ := http.NewRequest("GET", "/get", nil)
r.Header.Set("Origin", "origin")
w := httptest.NewRecorder()
app().ServeHTTP(w, r)
if w.Header().Get("Access-Control-Allow-Origin") != "origin" {
t.Fatalf("expected Access-Control-Allow-Origin=origin, got %#v", w.Header().Get("Access-Control-Allow-Origin"))
}
}
func TestGet__CORSHeadersWithOptionsVerb(t *testing.T) {
r, _ := http.NewRequest("OPTIONS", "/get", nil)
w := httptest.NewRecorder()
app().ServeHTTP(w, r)
var headerTests = []struct {
key string
expected string
}{
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Credentials", "true"},
{"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"},
{"Access-Control-Max-Age", "3600"},
{"Access-Control-Allow-Headers", ""},
}
for _, test := range headerTests {
if w.Header().Get(test.key) != test.expected {
t.Fatalf("expected %s = %#v, got %#v", test.key, test.expected, w.Header().Get(test.key))
}
}
}
func TestGet__CORSAllowHeaders(t *testing.T) {
r, _ := http.NewRequest("OPTIONS", "/get", nil)
r.Header.Set("Access-Control-Request-Headers", "X-Test-Header")
w := httptest.NewRecorder()
app().ServeHTTP(w, r)
var headerTests = []struct {
key string
expected string
}{
{"Access-Control-Allow-Headers", "X-Test-Header"},
}
for _, test := range headerTests {
if w.Header().Get(test.key) != test.expected {
t.Fatalf("expected %s = %#v, got %#v", test.key, test.expected, w.Header().Get(test.key))
}
}
}
func TestGet__XForwardedProto(t *testing.T) {
var tests = []struct {
key string
value string
}{
{"X-Forwarded-Proto", "https"},
{"X-Forwarded-Protocol", "https"},
{"X-Forwarded-Ssl", "on"},
}
for _, test := range tests {
r, _ := http.NewRequest("GET", "/get", nil)
r.Header.Set(test.key, test.value)
w := httptest.NewRecorder()
app().ServeHTTP(w, r)
var resp *Resp
err := json.Unmarshal(w.Body.Bytes(), &resp)
if err != nil {
t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err)
}
if !strings.HasPrefix(resp.URL, "https://") {
t.Fatalf("%s=%s should result in https URL", test.key, test.value)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment