From f77f01c3d7e25dc2efb7a0065316fadf11bc8329 Mon Sep 17 00:00:00 2001 From: Will McCutchen <will@mccutch.org> Date: Sun, 28 Aug 2016 11:20:56 -0700 Subject: [PATCH] Add methods() middleware for restricting HTTP methods --- main.go | 4 ++-- main_test.go | 10 ++++++++++ middleware.go | 18 ++++++++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index b1d224e..b08145d 100644 --- a/main.go +++ b/main.go @@ -83,8 +83,8 @@ func headers(w http.ResponseWriter, r *http.Request) { func app() http.Handler { h := http.NewServeMux() - h.HandleFunc("/", index) - h.HandleFunc("/get", get) + h.HandleFunc("/", methods(index, "GET")) + h.HandleFunc("/get", methods(get, "GET")) h.HandleFunc("/ip", ip) h.HandleFunc("/user-agent", userAgent) h.HandleFunc("/headers", headers) diff --git a/main_test.go b/main_test.go index bc6ccae..1baf1a9 100644 --- a/main_test.go +++ b/main_test.go @@ -60,6 +60,16 @@ func TestGet__Basic(t *testing.T) { } } +func TestGet__OnlyAllowsGets(t *testing.T) { + r, _ := http.NewRequest("POST", "/get", nil) + w := httptest.NewRecorder() + app().ServeHTTP(w, r) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected HTTP 405, got %d", w.Code) + } +} + func TestGet__CORSHeadersWithoutRequestOrigin(t *testing.T) { r, _ := http.NewRequest("GET", "/get", nil) w := httptest.NewRecorder() diff --git a/middleware.go b/middleware.go index c25e383..dc0708c 100644 --- a/middleware.go +++ b/middleware.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "log" "net/http" ) @@ -32,3 +33,20 @@ func cors(h http.Handler) http.Handler { h.ServeHTTP(w, r) }) } + +func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc { + if len(methods) == 0 { + return h + } + methodMap := make(map[string]struct{}, len(methods)) + for _, m := range methods { + methodMap[m] = struct{}{} + } + return func(w http.ResponseWriter, r *http.Request) { + if _, ok := methodMap[r.Method]; !ok { + http.Error(w, fmt.Sprintf("method %s not allowed", r.Method), http.StatusMethodNotAllowed) + return + } + h.ServeHTTP(w, r) + } +} -- GitLab