diff --git a/main.go b/main.go index 288e276df4fbd40fd5a2e78446e186711622921c..a6c97702d14c165c74c8f13f861d57fcefb1ddb4 100644 --- a/main.go +++ b/main.go @@ -77,6 +77,27 @@ func logger(h http.Handler) http.Handler { }) } +func cors(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin == "" { + origin = "*" + } + respHeader := w.Header() + respHeader.Set("Access-Control-Allow-Origin", origin) + respHeader.Set("Access-Control-Allow-Credentials", "true") + + if r.Method == "OPTIONS" { + respHeader.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS") + respHeader.Set("Access-Control-Max-Age", "3600") + if r.Header.Get("Access-Control-Request-Headers") != "" { + respHeader.Set("Access-Control-Allow-Headers", r.Header.Get("Access-Control-Request-Headers")) + } + } + h.ServeHTTP(w, r) + }) +} + func get(w http.ResponseWriter, r *http.Request) { r.ParseForm() resp := &Resp{ @@ -86,11 +107,15 @@ func get(w http.ResponseWriter, r *http.Request) { writeResponse(w, r, resp) } -func main() { +func app() http.Handler { h := http.NewServeMux() h.HandleFunc("/get", get) + return logger(cors(h)) +} +func main() { + a := app() log.Printf("listening on 9999") - err := http.ListenAndServe(":9999", logger(h)) + err := http.ListenAndServe(":9999", a) log.Fatal(err) } diff --git a/main_test.go b/main_test.go index 2068b3d8a4c78ab395e7a1b59d999604e084fba3..f5d0867e4c682df3795f87b9a67a4cf479664eaa 100644 --- a/main_test.go +++ b/main_test.go @@ -4,15 +4,16 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" ) -func TestGet(t *testing.T) { +func TestGet__Basic(t *testing.T) { r, _ := http.NewRequest("GET", "/get", nil) r.Host = "localhost" r.Header.Set("User-Agent", "test") w := httptest.NewRecorder() - get(w, r) + app().ServeHTTP(w, r) if w.Code != 200 { t.Fatalf("expected status code 200, got %d", w.Code) @@ -47,3 +48,93 @@ func TestGet(t *testing.T) { } } } + +func TestGet__CORSHeadersWithoutRequestOrigin(t *testing.T) { + r, _ := http.NewRequest("GET", "/get", nil) + w := httptest.NewRecorder() + app().ServeHTTP(w, r) + + if w.Header().Get("Access-Control-Allow-Origin") != "*" { + t.Fatalf("expected Access-Control-Allow-Origin=*, got %#v", w.Header().Get("Access-Control-Allow-Origin")) + } +} + +func TestGet__CORSHeadersWithRequestOrigin(t *testing.T) { + r, _ := http.NewRequest("GET", "/get", nil) + r.Header.Set("Origin", "origin") + w := httptest.NewRecorder() + app().ServeHTTP(w, r) + + if w.Header().Get("Access-Control-Allow-Origin") != "origin" { + t.Fatalf("expected Access-Control-Allow-Origin=origin, got %#v", w.Header().Get("Access-Control-Allow-Origin")) + } +} + +func TestGet__CORSHeadersWithOptionsVerb(t *testing.T) { + r, _ := http.NewRequest("OPTIONS", "/get", nil) + w := httptest.NewRecorder() + app().ServeHTTP(w, r) + + var headerTests = []struct { + key string + expected string + }{ + {"Access-Control-Allow-Origin", "*"}, + {"Access-Control-Allow-Credentials", "true"}, + {"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"}, + {"Access-Control-Max-Age", "3600"}, + {"Access-Control-Allow-Headers", ""}, + } + for _, test := range headerTests { + if w.Header().Get(test.key) != test.expected { + t.Fatalf("expected %s = %#v, got %#v", test.key, test.expected, w.Header().Get(test.key)) + } + } +} + +func TestGet__CORSAllowHeaders(t *testing.T) { + r, _ := http.NewRequest("OPTIONS", "/get", nil) + r.Header.Set("Access-Control-Request-Headers", "X-Test-Header") + w := httptest.NewRecorder() + app().ServeHTTP(w, r) + + var headerTests = []struct { + key string + expected string + }{ + {"Access-Control-Allow-Headers", "X-Test-Header"}, + } + for _, test := range headerTests { + if w.Header().Get(test.key) != test.expected { + t.Fatalf("expected %s = %#v, got %#v", test.key, test.expected, w.Header().Get(test.key)) + } + } +} + +func TestGet__XForwardedProto(t *testing.T) { + var tests = []struct { + key string + value string + }{ + {"X-Forwarded-Proto", "https"}, + {"X-Forwarded-Protocol", "https"}, + {"X-Forwarded-Ssl", "on"}, + } + + for _, test := range tests { + r, _ := http.NewRequest("GET", "/get", nil) + r.Header.Set(test.key, test.value) + w := httptest.NewRecorder() + app().ServeHTTP(w, r) + + var resp *Resp + err := json.Unmarshal(w.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) + } + + if !strings.HasPrefix(resp.URL, "https://") { + t.Fatalf("%s=%s should result in https URL", test.key, test.value) + } + } +}