diff --git a/main.go b/main.go index b1d224e2d8b83a2765b5d7dab6580243213c0b17..b08145d5e32356cdbf57208f055b0905087fb60a 100644 --- a/main.go +++ b/main.go @@ -83,8 +83,8 @@ func headers(w http.ResponseWriter, r *http.Request) { func app() http.Handler { h := http.NewServeMux() - h.HandleFunc("/", index) - h.HandleFunc("/get", get) + h.HandleFunc("/", methods(index, "GET")) + h.HandleFunc("/get", methods(get, "GET")) h.HandleFunc("/ip", ip) h.HandleFunc("/user-agent", userAgent) h.HandleFunc("/headers", headers) diff --git a/main_test.go b/main_test.go index bc6ccae99faecfb46c144cfb334ca9cf759282dc..1baf1a9a603abd6036761a63753fde0ed3fb430f 100644 --- a/main_test.go +++ b/main_test.go @@ -60,6 +60,16 @@ func TestGet__Basic(t *testing.T) { } } +func TestGet__OnlyAllowsGets(t *testing.T) { + r, _ := http.NewRequest("POST", "/get", nil) + w := httptest.NewRecorder() + app().ServeHTTP(w, r) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected HTTP 405, got %d", w.Code) + } +} + func TestGet__CORSHeadersWithoutRequestOrigin(t *testing.T) { r, _ := http.NewRequest("GET", "/get", nil) w := httptest.NewRecorder() diff --git a/middleware.go b/middleware.go index c25e383386db61bc71861c462ae1392c4e075da2..dc0708ca1a4295b777f19cbec521caaddf2ed785 100644 --- a/middleware.go +++ b/middleware.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "log" "net/http" ) @@ -32,3 +33,20 @@ func cors(h http.Handler) http.Handler { h.ServeHTTP(w, r) }) } + +func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc { + if len(methods) == 0 { + return h + } + methodMap := make(map[string]struct{}, len(methods)) + for _, m := range methods { + methodMap[m] = struct{}{} + } + return func(w http.ResponseWriter, r *http.Request) { + if _, ok := methodMap[r.Method]; !ok { + http.Error(w, fmt.Sprintf("method %s not allowed", r.Method), http.StatusMethodNotAllowed) + return + } + h.ServeHTTP(w, r) + } +}