Skip to content
Snippets Groups Projects
Commit f2c7b90b authored by Will McCutchen's avatar Will McCutchen
Browse files

Improve test coverage

parent 4ece6f11
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
...@@ -42,6 +43,13 @@ func assertBodyContains(t *testing.T, w *httptest.ResponseRecorder, needle strin ...@@ -42,6 +43,13 @@ func assertBodyContains(t *testing.T, w *httptest.ResponseRecorder, needle strin
} }
} }
func TestNewHTTPBin__NilOptions(t *testing.T) {
h := NewHTTPBin(nil)
if h.options.MaxMemory != 0 {
t.Fatalf("expected default MaxMemory == 0, got %#v", h.options.MaxMemory)
}
}
func TestIndex(t *testing.T) { func TestIndex(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil) r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
...@@ -133,6 +141,14 @@ func TestGet__WithParams(t *testing.T) { ...@@ -133,6 +141,14 @@ func TestGet__WithParams(t *testing.T) {
} }
} }
func TestGet__InvalidQuery(t *testing.T) {
r, _ := http.NewRequest("GET", "/get?foo=%ZZ", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusBadRequest)
}
func TestGet__OnlyAllowsGets(t *testing.T) { func TestGet__OnlyAllowsGets(t *testing.T) {
r, _ := http.NewRequest("POST", "/get", nil) r, _ := http.NewRequest("POST", "/get", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
...@@ -382,6 +398,76 @@ func TestPost__FormEncodedBodyNoContentType(t *testing.T) { ...@@ -382,6 +398,76 @@ func TestPost__FormEncodedBodyNoContentType(t *testing.T) {
} }
} }
func TestPost__MultiPartBody(t *testing.T) {
params := map[string][]string{
"foo": {"foo"},
"bar": {"bar1", "bar2"},
}
// Prepare a form that you will submit to that URL.
var body bytes.Buffer
mw := multipart.NewWriter(&body)
for k, vs := range params {
for _, v := range vs {
fw, err := mw.CreateFormField(k)
if err != nil {
t.Fatalf("error creating multipart form field %s: %s", k, err)
}
if _, err := fw.Write([]byte(v)); err != nil {
t.Fatalf("error writing multipart form value %#v for key %s: %s", v, k, err)
}
}
}
mw.Close()
r, _ := http.NewRequest("POST", "/post", bytes.NewReader(body.Bytes()))
r.Header.Set("Content-Type", mw.FormDataContentType())
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusOK)
assertContentType(t, w, "application/json; encoding=utf-8")
var resp *bodyResponse
err := json.Unmarshal(w.Body.Bytes(), &resp)
if err != nil {
t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err)
}
if len(resp.Args) > 0 {
t.Fatalf("expected no query params, got %#v", resp.Args)
}
if len(resp.Form) != len(params) {
t.Fatalf("expected %d form values, got %d", len(params), len(resp.Form))
}
for k, expectedValues := range params {
values, ok := resp.Form[k]
if !ok {
t.Fatalf("expected form field %#v in response", k)
}
if !reflect.DeepEqual(expectedValues, values) {
t.Fatalf("form value mismatch: %#v != %#v", values, expectedValues)
}
}
}
func TestPost__InvalidFormEncodedBody(t *testing.T) {
r, _ := http.NewRequest("POST", "/post", strings.NewReader("%ZZ"))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusBadRequest)
}
func TestPost__InvalidMultiPartBody(t *testing.T) {
r, _ := http.NewRequest("POST", "/post", strings.NewReader("%ZZ"))
r.Header.Set("Content-Type", "multipart/form-data; etc")
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusBadRequest)
}
func TestPost__JSON(t *testing.T) { func TestPost__JSON(t *testing.T) {
type testInput struct { type testInput struct {
Foo string Foo string
...@@ -435,6 +521,14 @@ func TestPost__JSON(t *testing.T) { ...@@ -435,6 +521,14 @@ func TestPost__JSON(t *testing.T) {
} }
} }
func TestPost__InvalidJSON(t *testing.T) {
r, _ := http.NewRequest("POST", "/post", bytes.NewReader([]byte("foo")))
r.Header.Set("Content-Type", "application/json; charset=utf-8")
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusBadRequest)
}
func TestPost__BodyTooBig(t *testing.T) { func TestPost__BodyTooBig(t *testing.T) {
body := make([]byte, maxMemory+1) body := make([]byte, maxMemory+1)
...@@ -446,6 +540,88 @@ func TestPost__BodyTooBig(t *testing.T) { ...@@ -446,6 +540,88 @@ func TestPost__BodyTooBig(t *testing.T) {
assertContentType(t, w, "application/json; encoding=utf-8") assertContentType(t, w, "application/json; encoding=utf-8")
} }
func TestPost__InvalidQueryParams(t *testing.T) {
r, _ := http.NewRequest("POST", "/post?foo=%ZZ", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusBadRequest)
}
func TestPost__QueryParams(t *testing.T) {
params := url.Values{}
params.Set("foo", "foo")
params.Add("bar", "bar1")
params.Add("bar", "bar2")
r, _ := http.NewRequest("POST", fmt.Sprintf("/post?%s", params.Encode()), nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusOK)
assertContentType(t, w, "application/json; encoding=utf-8")
var resp *bodyResponse
err := json.Unmarshal(w.Body.Bytes(), &resp)
if err != nil {
t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err)
}
if resp.Args.Encode() != params.Encode() {
t.Fatalf("expected args = %#v in response, got %#v", params.Encode(), resp.Args.Encode())
}
if len(resp.Form) > 0 {
t.Fatalf("expected form data, got %#v", resp.Form)
}
}
func TestPost__QueryParamsAndBody(t *testing.T) {
args := url.Values{}
args.Set("query1", "foo")
args.Add("query2", "bar1")
args.Add("query2", "bar2")
form := url.Values{}
form.Set("form1", "foo")
form.Add("form2", "bar1")
form.Add("form2", "bar2")
url := fmt.Sprintf("/post?%s", args.Encode())
body := strings.NewReader(form.Encode())
r, _ := http.NewRequest("POST", url, body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusOK)
assertContentType(t, w, "application/json; encoding=utf-8")
var resp *bodyResponse
err := json.Unmarshal(w.Body.Bytes(), &resp)
if err != nil {
t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err)
}
if resp.Args.Encode() != args.Encode() {
t.Fatalf("expected args = %#v in response, got %#v", args.Encode(), resp.Args.Encode())
}
if len(resp.Form) != len(form) {
t.Fatalf("expected %d form values, got %d", len(form), len(resp.Form))
}
for k, expectedValues := range form {
values, ok := resp.Form[k]
if !ok {
t.Fatalf("expected form field %#v in response", k)
}
if !reflect.DeepEqual(expectedValues, values) {
t.Fatalf("form value mismatch: %#v != %#v", values, expectedValues)
}
}
}
// TODO: implement and test more complex /status endpoint
func TestStatus__Simple(t *testing.T) { func TestStatus__Simple(t *testing.T) {
redirectHeaders := map[string]string{ redirectHeaders := map[string]string{
"Location": "/redirect/1", "Location": "/redirect/1",
...@@ -486,9 +662,28 @@ func TestStatus__Simple(t *testing.T) { ...@@ -486,9 +662,28 @@ func TestStatus__Simple(t *testing.T) {
} }
} }
func TestResponseHeaders(t *testing.T) { func TestStatus__Errors(t *testing.T) {
var tests = []struct {
given interface{}
status int
}{
{"", http.StatusBadRequest},
{"200/foo", http.StatusNotFound},
{3.14, http.StatusBadRequest},
{"foo", http.StatusBadRequest},
}
for _, test := range tests {
url := fmt.Sprintf("/status/%v", test.given)
r, _ := http.NewRequest("GET", url, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, test.status)
}
}
func TestResponseHeaders__OK(t *testing.T) {
headers := map[string][]string{ headers := map[string][]string{
"Content-Type": {"test/test"},
"Foo": {"foo"}, "Foo": {"foo"},
"Bar": {"bar1, bar2"}, "Bar": {"bar1, bar2"},
} }
...@@ -505,6 +700,7 @@ func TestResponseHeaders(t *testing.T) { ...@@ -505,6 +700,7 @@ func TestResponseHeaders(t *testing.T) {
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusOK) assertStatusCode(t, w, http.StatusOK)
assertContentType(t, w, "application/json; encoding=utf-8")
for k, expectedValues := range headers { for k, expectedValues := range headers {
values, ok := w.HeaderMap[k] values, ok := w.HeaderMap[k]
...@@ -533,6 +729,27 @@ func TestResponseHeaders(t *testing.T) { ...@@ -533,6 +729,27 @@ func TestResponseHeaders(t *testing.T) {
} }
} }
func TestResponseHeaders__OverrideContentType(t *testing.T) {
contentType := "text/test"
params := url.Values{}
params.Set("Content-Type", contentType)
r, _ := http.NewRequest("GET", fmt.Sprintf("/response-headers?%s", params.Encode()), nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusOK)
assertContentType(t, w, contentType)
}
func TestResponseHeaders__InvalidQuery(t *testing.T) {
r, _ := http.NewRequest("GET", "/response-headers?foo=%ZZ", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
assertStatusCode(t, w, http.StatusBadRequest)
}
func TestRelativeRedirect__OK(t *testing.T) { func TestRelativeRedirect__OK(t *testing.T) {
var tests = []struct { var tests = []struct {
n int n int
...@@ -560,7 +777,9 @@ func TestRelativeRedirect__Errors(t *testing.T) { ...@@ -560,7 +777,9 @@ func TestRelativeRedirect__Errors(t *testing.T) {
}{ }{
{3.14, http.StatusBadRequest}, {3.14, http.StatusBadRequest},
{-1, http.StatusBadRequest}, {-1, http.StatusBadRequest},
{"", http.StatusBadRequest},
{"foo", http.StatusBadRequest}, {"foo", http.StatusBadRequest},
{"10/bar", http.StatusNotFound},
} }
for _, test := range tests { for _, test := range tests {
......
...@@ -72,7 +72,7 @@ func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse, maxMe ...@@ -72,7 +72,7 @@ func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse, maxMe
return err return err
} }
resp.Form = r.PostForm resp.Form = r.PostForm
case ct == "multipart/form-data": case strings.HasPrefix(ct, "multipart/form-data"):
err := r.ParseMultipartForm(maxMemory) err := r.ParseMultipartForm(maxMemory)
if err != nil { if err != nil {
return err return err
......
...@@ -35,9 +35,6 @@ func cors(h http.Handler) http.Handler { ...@@ -35,9 +35,6 @@ func cors(h http.Handler) http.Handler {
} }
func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc { func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc {
if len(methods) == 0 {
return h
}
methodMap := make(map[string]struct{}, len(methods)) methodMap := make(map[string]struct{}, len(methods))
for _, m := range methods { for _, m := range methods {
methodMap[m] = struct{}{} methodMap[m] = struct{}{}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment