From f2c7b90bdf48de19394b8cd79336d6c4b0cf8c05 Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Wed, 12 Oct 2016 10:53:57 -0700
Subject: [PATCH] Improve test coverage

---
 httpbin/handlers_test.go | 227 ++++++++++++++++++++++++++++++++++++++-
 httpbin/helpers.go       |   2 +-
 httpbin/middleware.go    |   3 -
 3 files changed, 224 insertions(+), 8 deletions(-)

diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go
index 970ba8d..149d877 100644
--- a/httpbin/handlers_test.go
+++ b/httpbin/handlers_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"mime/multipart"
 	"net/http"
 	"net/http/httptest"
 	"net/url"
@@ -42,6 +43,13 @@ func assertBodyContains(t *testing.T, w *httptest.ResponseRecorder, needle strin
 	}
 }
 
+func TestNewHTTPBin__NilOptions(t *testing.T) {
+	h := NewHTTPBin(nil)
+	if h.options.MaxMemory != 0 {
+		t.Fatalf("expected default MaxMemory == 0, got %#v", h.options.MaxMemory)
+	}
+}
+
 func TestIndex(t *testing.T) {
 	r, _ := http.NewRequest("GET", "/", nil)
 	w := httptest.NewRecorder()
@@ -133,6 +141,14 @@ func TestGet__WithParams(t *testing.T) {
 	}
 }
 
+func TestGet__InvalidQuery(t *testing.T) {
+	r, _ := http.NewRequest("GET", "/get?foo=%ZZ", nil)
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+
+	assertStatusCode(t, w, http.StatusBadRequest)
+}
+
 func TestGet__OnlyAllowsGets(t *testing.T) {
 	r, _ := http.NewRequest("POST", "/get", nil)
 	w := httptest.NewRecorder()
@@ -382,6 +398,76 @@ func TestPost__FormEncodedBodyNoContentType(t *testing.T) {
 	}
 }
 
+func TestPost__MultiPartBody(t *testing.T) {
+	params := map[string][]string{
+		"foo": {"foo"},
+		"bar": {"bar1", "bar2"},
+	}
+
+	// Prepare a form that you will submit to that URL.
+	var body bytes.Buffer
+	mw := multipart.NewWriter(&body)
+
+	for k, vs := range params {
+		for _, v := range vs {
+			fw, err := mw.CreateFormField(k)
+			if err != nil {
+				t.Fatalf("error creating multipart form field %s: %s", k, err)
+			}
+			if _, err := fw.Write([]byte(v)); err != nil {
+				t.Fatalf("error writing multipart form value %#v for key %s: %s", v, k, err)
+			}
+		}
+	}
+	mw.Close()
+
+	r, _ := http.NewRequest("POST", "/post", bytes.NewReader(body.Bytes()))
+	r.Header.Set("Content-Type", mw.FormDataContentType())
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+
+	assertStatusCode(t, w, http.StatusOK)
+	assertContentType(t, w, "application/json; encoding=utf-8")
+
+	var resp *bodyResponse
+	err := json.Unmarshal(w.Body.Bytes(), &resp)
+	if err != nil {
+		t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err)
+	}
+
+	if len(resp.Args) > 0 {
+		t.Fatalf("expected no query params, got %#v", resp.Args)
+	}
+	if len(resp.Form) != len(params) {
+		t.Fatalf("expected %d form values, got %d", len(params), len(resp.Form))
+	}
+	for k, expectedValues := range params {
+		values, ok := resp.Form[k]
+		if !ok {
+			t.Fatalf("expected form field %#v in response", k)
+		}
+		if !reflect.DeepEqual(expectedValues, values) {
+			t.Fatalf("form value mismatch: %#v != %#v", values, expectedValues)
+		}
+	}
+}
+
+func TestPost__InvalidFormEncodedBody(t *testing.T) {
+	r, _ := http.NewRequest("POST", "/post", strings.NewReader("%ZZ"))
+	r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+	assertStatusCode(t, w, http.StatusBadRequest)
+}
+
+func TestPost__InvalidMultiPartBody(t *testing.T) {
+	r, _ := http.NewRequest("POST", "/post", strings.NewReader("%ZZ"))
+	r.Header.Set("Content-Type", "multipart/form-data; etc")
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+	assertStatusCode(t, w, http.StatusBadRequest)
+}
+
 func TestPost__JSON(t *testing.T) {
 	type testInput struct {
 		Foo  string
@@ -435,6 +521,14 @@ func TestPost__JSON(t *testing.T) {
 	}
 }
 
+func TestPost__InvalidJSON(t *testing.T) {
+	r, _ := http.NewRequest("POST", "/post", bytes.NewReader([]byte("foo")))
+	r.Header.Set("Content-Type", "application/json; charset=utf-8")
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+	assertStatusCode(t, w, http.StatusBadRequest)
+}
+
 func TestPost__BodyTooBig(t *testing.T) {
 	body := make([]byte, maxMemory+1)
 
@@ -446,6 +540,88 @@ func TestPost__BodyTooBig(t *testing.T) {
 	assertContentType(t, w, "application/json; encoding=utf-8")
 }
 
+func TestPost__InvalidQueryParams(t *testing.T) {
+	r, _ := http.NewRequest("POST", "/post?foo=%ZZ", nil)
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+	assertStatusCode(t, w, http.StatusBadRequest)
+}
+
+func TestPost__QueryParams(t *testing.T) {
+	params := url.Values{}
+	params.Set("foo", "foo")
+	params.Add("bar", "bar1")
+	params.Add("bar", "bar2")
+
+	r, _ := http.NewRequest("POST", fmt.Sprintf("/post?%s", params.Encode()), nil)
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+
+	assertStatusCode(t, w, http.StatusOK)
+	assertContentType(t, w, "application/json; encoding=utf-8")
+
+	var resp *bodyResponse
+	err := json.Unmarshal(w.Body.Bytes(), &resp)
+	if err != nil {
+		t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err)
+	}
+
+	if resp.Args.Encode() != params.Encode() {
+		t.Fatalf("expected args = %#v in response, got %#v", params.Encode(), resp.Args.Encode())
+	}
+
+	if len(resp.Form) > 0 {
+		t.Fatalf("expected form data, got %#v", resp.Form)
+	}
+}
+
+func TestPost__QueryParamsAndBody(t *testing.T) {
+	args := url.Values{}
+	args.Set("query1", "foo")
+	args.Add("query2", "bar1")
+	args.Add("query2", "bar2")
+
+	form := url.Values{}
+	form.Set("form1", "foo")
+	form.Add("form2", "bar1")
+	form.Add("form2", "bar2")
+
+	url := fmt.Sprintf("/post?%s", args.Encode())
+	body := strings.NewReader(form.Encode())
+
+	r, _ := http.NewRequest("POST", url, body)
+	r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+
+	assertStatusCode(t, w, http.StatusOK)
+	assertContentType(t, w, "application/json; encoding=utf-8")
+
+	var resp *bodyResponse
+	err := json.Unmarshal(w.Body.Bytes(), &resp)
+	if err != nil {
+		t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err)
+	}
+
+	if resp.Args.Encode() != args.Encode() {
+		t.Fatalf("expected args = %#v in response, got %#v", args.Encode(), resp.Args.Encode())
+	}
+
+	if len(resp.Form) != len(form) {
+		t.Fatalf("expected %d form values, got %d", len(form), len(resp.Form))
+	}
+	for k, expectedValues := range form {
+		values, ok := resp.Form[k]
+		if !ok {
+			t.Fatalf("expected form field %#v in response", k)
+		}
+		if !reflect.DeepEqual(expectedValues, values) {
+			t.Fatalf("form value mismatch: %#v != %#v", values, expectedValues)
+		}
+	}
+}
+
+// TODO: implement and test more complex /status endpoint
 func TestStatus__Simple(t *testing.T) {
 	redirectHeaders := map[string]string{
 		"Location": "/redirect/1",
@@ -486,11 +662,30 @@ func TestStatus__Simple(t *testing.T) {
 	}
 }
 
-func TestResponseHeaders(t *testing.T) {
+func TestStatus__Errors(t *testing.T) {
+	var tests = []struct {
+		given  interface{}
+		status int
+	}{
+		{"", http.StatusBadRequest},
+		{"200/foo", http.StatusNotFound},
+		{3.14, http.StatusBadRequest},
+		{"foo", http.StatusBadRequest},
+	}
+
+	for _, test := range tests {
+		url := fmt.Sprintf("/status/%v", test.given)
+		r, _ := http.NewRequest("GET", url, nil)
+		w := httptest.NewRecorder()
+		handler.ServeHTTP(w, r)
+		assertStatusCode(t, w, test.status)
+	}
+}
+
+func TestResponseHeaders__OK(t *testing.T) {
 	headers := map[string][]string{
-		"Content-Type": {"test/test"},
-		"Foo":          {"foo"},
-		"Bar":          {"bar1, bar2"},
+		"Foo": {"foo"},
+		"Bar": {"bar1, bar2"},
 	}
 
 	params := url.Values{}
@@ -505,6 +700,7 @@ func TestResponseHeaders(t *testing.T) {
 	handler.ServeHTTP(w, r)
 
 	assertStatusCode(t, w, http.StatusOK)
+	assertContentType(t, w, "application/json; encoding=utf-8")
 
 	for k, expectedValues := range headers {
 		values, ok := w.HeaderMap[k]
@@ -533,6 +729,27 @@ func TestResponseHeaders(t *testing.T) {
 	}
 }
 
+func TestResponseHeaders__OverrideContentType(t *testing.T) {
+	contentType := "text/test"
+
+	params := url.Values{}
+	params.Set("Content-Type", contentType)
+
+	r, _ := http.NewRequest("GET", fmt.Sprintf("/response-headers?%s", params.Encode()), nil)
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+
+	assertStatusCode(t, w, http.StatusOK)
+	assertContentType(t, w, contentType)
+}
+
+func TestResponseHeaders__InvalidQuery(t *testing.T) {
+	r, _ := http.NewRequest("GET", "/response-headers?foo=%ZZ", nil)
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+	assertStatusCode(t, w, http.StatusBadRequest)
+}
+
 func TestRelativeRedirect__OK(t *testing.T) {
 	var tests = []struct {
 		n        int
@@ -560,7 +777,9 @@ func TestRelativeRedirect__Errors(t *testing.T) {
 	}{
 		{3.14, http.StatusBadRequest},
 		{-1, http.StatusBadRequest},
+		{"", http.StatusBadRequest},
 		{"foo", http.StatusBadRequest},
+		{"10/bar", http.StatusNotFound},
 	}
 
 	for _, test := range tests {
diff --git a/httpbin/helpers.go b/httpbin/helpers.go
index debda75..9ded2e4 100644
--- a/httpbin/helpers.go
+++ b/httpbin/helpers.go
@@ -72,7 +72,7 @@ func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse, maxMe
 			return err
 		}
 		resp.Form = r.PostForm
-	case ct == "multipart/form-data":
+	case strings.HasPrefix(ct, "multipart/form-data"):
 		err := r.ParseMultipartForm(maxMemory)
 		if err != nil {
 			return err
diff --git a/httpbin/middleware.go b/httpbin/middleware.go
index 8769302..ec04086 100644
--- a/httpbin/middleware.go
+++ b/httpbin/middleware.go
@@ -35,9 +35,6 @@ func cors(h http.Handler) http.Handler {
 }
 
 func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc {
-	if len(methods) == 0 {
-		return h
-	}
 	methodMap := make(map[string]struct{}, len(methods))
 	for _, m := range methods {
 		methodMap[m] = struct{}{}
-- 
GitLab