From f77f01c3d7e25dc2efb7a0065316fadf11bc8329 Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Sun, 28 Aug 2016 11:20:56 -0700
Subject: [PATCH] Add methods() middleware for restricting HTTP methods

---
 main.go       |  4 ++--
 main_test.go  | 10 ++++++++++
 middleware.go | 18 ++++++++++++++++++
 3 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/main.go b/main.go
index b1d224e..b08145d 100644
--- a/main.go
+++ b/main.go
@@ -83,8 +83,8 @@ func headers(w http.ResponseWriter, r *http.Request) {
 
 func app() http.Handler {
 	h := http.NewServeMux()
-	h.HandleFunc("/", index)
-	h.HandleFunc("/get", get)
+	h.HandleFunc("/", methods(index, "GET"))
+	h.HandleFunc("/get", methods(get, "GET"))
 	h.HandleFunc("/ip", ip)
 	h.HandleFunc("/user-agent", userAgent)
 	h.HandleFunc("/headers", headers)
diff --git a/main_test.go b/main_test.go
index bc6ccae..1baf1a9 100644
--- a/main_test.go
+++ b/main_test.go
@@ -60,6 +60,16 @@ func TestGet__Basic(t *testing.T) {
 	}
 }
 
+func TestGet__OnlyAllowsGets(t *testing.T) {
+	r, _ := http.NewRequest("POST", "/get", nil)
+	w := httptest.NewRecorder()
+	app().ServeHTTP(w, r)
+
+	if w.Code != http.StatusMethodNotAllowed {
+		t.Fatalf("expected HTTP 405, got %d", w.Code)
+	}
+}
+
 func TestGet__CORSHeadersWithoutRequestOrigin(t *testing.T) {
 	r, _ := http.NewRequest("GET", "/get", nil)
 	w := httptest.NewRecorder()
diff --git a/middleware.go b/middleware.go
index c25e383..dc0708c 100644
--- a/middleware.go
+++ b/middleware.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"fmt"
 	"log"
 	"net/http"
 )
@@ -32,3 +33,20 @@ func cors(h http.Handler) http.Handler {
 		h.ServeHTTP(w, r)
 	})
 }
+
+func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc {
+	if len(methods) == 0 {
+		return h
+	}
+	methodMap := make(map[string]struct{}, len(methods))
+	for _, m := range methods {
+		methodMap[m] = struct{}{}
+	}
+	return func(w http.ResponseWriter, r *http.Request) {
+		if _, ok := methodMap[r.Method]; !ok {
+			http.Error(w, fmt.Sprintf("method %s not allowed", r.Method), http.StatusMethodNotAllowed)
+			return
+		}
+		h.ServeHTTP(w, r)
+	}
+}
-- 
GitLab