From 527829eff635a452a852147463d1d0130fe9e482 Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Sat, 27 Aug 2016 22:29:54 -0700
Subject: [PATCH] All /get tests ported and passing

---
 main.go      | 29 ++++++++++++++--
 main_test.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 120 insertions(+), 4 deletions(-)

diff --git a/main.go b/main.go
index 288e276..a6c9770 100644
--- a/main.go
+++ b/main.go
@@ -77,6 +77,27 @@ func logger(h http.Handler) http.Handler {
 	})
 }
 
+func cors(h http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		origin := r.Header.Get("Origin")
+		if origin == "" {
+			origin = "*"
+		}
+		respHeader := w.Header()
+		respHeader.Set("Access-Control-Allow-Origin", origin)
+		respHeader.Set("Access-Control-Allow-Credentials", "true")
+
+		if r.Method == "OPTIONS" {
+			respHeader.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
+			respHeader.Set("Access-Control-Max-Age", "3600")
+			if r.Header.Get("Access-Control-Request-Headers") != "" {
+				respHeader.Set("Access-Control-Allow-Headers", r.Header.Get("Access-Control-Request-Headers"))
+			}
+		}
+		h.ServeHTTP(w, r)
+	})
+}
+
 func get(w http.ResponseWriter, r *http.Request) {
 	r.ParseForm()
 	resp := &Resp{
@@ -86,11 +107,15 @@ func get(w http.ResponseWriter, r *http.Request) {
 	writeResponse(w, r, resp)
 }
 
-func main() {
+func app() http.Handler {
 	h := http.NewServeMux()
 	h.HandleFunc("/get", get)
+	return logger(cors(h))
+}
 
+func main() {
+	a := app()
 	log.Printf("listening on 9999")
-	err := http.ListenAndServe(":9999", logger(h))
+	err := http.ListenAndServe(":9999", a)
 	log.Fatal(err)
 }
diff --git a/main_test.go b/main_test.go
index 2068b3d..f5d0867 100644
--- a/main_test.go
+++ b/main_test.go
@@ -4,15 +4,16 @@ import (
 	"encoding/json"
 	"net/http"
 	"net/http/httptest"
+	"strings"
 	"testing"
 )
 
-func TestGet(t *testing.T) {
+func TestGet__Basic(t *testing.T) {
 	r, _ := http.NewRequest("GET", "/get", nil)
 	r.Host = "localhost"
 	r.Header.Set("User-Agent", "test")
 	w := httptest.NewRecorder()
-	get(w, r)
+	app().ServeHTTP(w, r)
 
 	if w.Code != 200 {
 		t.Fatalf("expected status code 200, got %d", w.Code)
@@ -47,3 +48,93 @@ func TestGet(t *testing.T) {
 		}
 	}
 }
+
+func TestGet__CORSHeadersWithoutRequestOrigin(t *testing.T) {
+	r, _ := http.NewRequest("GET", "/get", nil)
+	w := httptest.NewRecorder()
+	app().ServeHTTP(w, r)
+
+	if w.Header().Get("Access-Control-Allow-Origin") != "*" {
+		t.Fatalf("expected Access-Control-Allow-Origin=*, got %#v", w.Header().Get("Access-Control-Allow-Origin"))
+	}
+}
+
+func TestGet__CORSHeadersWithRequestOrigin(t *testing.T) {
+	r, _ := http.NewRequest("GET", "/get", nil)
+	r.Header.Set("Origin", "origin")
+	w := httptest.NewRecorder()
+	app().ServeHTTP(w, r)
+
+	if w.Header().Get("Access-Control-Allow-Origin") != "origin" {
+		t.Fatalf("expected Access-Control-Allow-Origin=origin, got %#v", w.Header().Get("Access-Control-Allow-Origin"))
+	}
+}
+
+func TestGet__CORSHeadersWithOptionsVerb(t *testing.T) {
+	r, _ := http.NewRequest("OPTIONS", "/get", nil)
+	w := httptest.NewRecorder()
+	app().ServeHTTP(w, r)
+
+	var headerTests = []struct {
+		key      string
+		expected string
+	}{
+		{"Access-Control-Allow-Origin", "*"},
+		{"Access-Control-Allow-Credentials", "true"},
+		{"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"},
+		{"Access-Control-Max-Age", "3600"},
+		{"Access-Control-Allow-Headers", ""},
+	}
+	for _, test := range headerTests {
+		if w.Header().Get(test.key) != test.expected {
+			t.Fatalf("expected %s = %#v, got %#v", test.key, test.expected, w.Header().Get(test.key))
+		}
+	}
+}
+
+func TestGet__CORSAllowHeaders(t *testing.T) {
+	r, _ := http.NewRequest("OPTIONS", "/get", nil)
+	r.Header.Set("Access-Control-Request-Headers", "X-Test-Header")
+	w := httptest.NewRecorder()
+	app().ServeHTTP(w, r)
+
+	var headerTests = []struct {
+		key      string
+		expected string
+	}{
+		{"Access-Control-Allow-Headers", "X-Test-Header"},
+	}
+	for _, test := range headerTests {
+		if w.Header().Get(test.key) != test.expected {
+			t.Fatalf("expected %s = %#v, got %#v", test.key, test.expected, w.Header().Get(test.key))
+		}
+	}
+}
+
+func TestGet__XForwardedProto(t *testing.T) {
+	var tests = []struct {
+		key   string
+		value string
+	}{
+		{"X-Forwarded-Proto", "https"},
+		{"X-Forwarded-Protocol", "https"},
+		{"X-Forwarded-Ssl", "on"},
+	}
+
+	for _, test := range tests {
+		r, _ := http.NewRequest("GET", "/get", nil)
+		r.Header.Set(test.key, test.value)
+		w := httptest.NewRecorder()
+		app().ServeHTTP(w, r)
+
+		var resp *Resp
+		err := json.Unmarshal(w.Body.Bytes(), &resp)
+		if err != nil {
+			t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err)
+		}
+
+		if !strings.HasPrefix(resp.URL, "https://") {
+			t.Fatalf("%s=%s should result in https URL", test.key, test.value)
+		}
+	}
+}
-- 
GitLab