diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index 970ba8d031850bd36fd17805dd583809cf4ebad2..149d877c743dff1ef8dcb15f2183b687a512c685 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "mime/multipart" "net/http" "net/http/httptest" "net/url" @@ -42,6 +43,13 @@ func assertBodyContains(t *testing.T, w *httptest.ResponseRecorder, needle strin } } +func TestNewHTTPBin__NilOptions(t *testing.T) { + h := NewHTTPBin(nil) + if h.options.MaxMemory != 0 { + t.Fatalf("expected default MaxMemory == 0, got %#v", h.options.MaxMemory) + } +} + func TestIndex(t *testing.T) { r, _ := http.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -133,6 +141,14 @@ func TestGet__WithParams(t *testing.T) { } } +func TestGet__InvalidQuery(t *testing.T) { + r, _ := http.NewRequest("GET", "/get?foo=%ZZ", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assertStatusCode(t, w, http.StatusBadRequest) +} + func TestGet__OnlyAllowsGets(t *testing.T) { r, _ := http.NewRequest("POST", "/get", nil) w := httptest.NewRecorder() @@ -382,6 +398,76 @@ func TestPost__FormEncodedBodyNoContentType(t *testing.T) { } } +func TestPost__MultiPartBody(t *testing.T) { + params := map[string][]string{ + "foo": {"foo"}, + "bar": {"bar1", "bar2"}, + } + + // Prepare a form that you will submit to that URL. + var body bytes.Buffer + mw := multipart.NewWriter(&body) + + for k, vs := range params { + for _, v := range vs { + fw, err := mw.CreateFormField(k) + if err != nil { + t.Fatalf("error creating multipart form field %s: %s", k, err) + } + if _, err := fw.Write([]byte(v)); err != nil { + t.Fatalf("error writing multipart form value %#v for key %s: %s", v, k, err) + } + } + } + mw.Close() + + r, _ := http.NewRequest("POST", "/post", bytes.NewReader(body.Bytes())) + r.Header.Set("Content-Type", mw.FormDataContentType()) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assertStatusCode(t, w, http.StatusOK) + assertContentType(t, w, "application/json; encoding=utf-8") + + var resp *bodyResponse + err := json.Unmarshal(w.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err) + } + + if len(resp.Args) > 0 { + t.Fatalf("expected no query params, got %#v", resp.Args) + } + if len(resp.Form) != len(params) { + t.Fatalf("expected %d form values, got %d", len(params), len(resp.Form)) + } + for k, expectedValues := range params { + values, ok := resp.Form[k] + if !ok { + t.Fatalf("expected form field %#v in response", k) + } + if !reflect.DeepEqual(expectedValues, values) { + t.Fatalf("form value mismatch: %#v != %#v", values, expectedValues) + } + } +} + +func TestPost__InvalidFormEncodedBody(t *testing.T) { + r, _ := http.NewRequest("POST", "/post", strings.NewReader("%ZZ")) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + assertStatusCode(t, w, http.StatusBadRequest) +} + +func TestPost__InvalidMultiPartBody(t *testing.T) { + r, _ := http.NewRequest("POST", "/post", strings.NewReader("%ZZ")) + r.Header.Set("Content-Type", "multipart/form-data; etc") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + assertStatusCode(t, w, http.StatusBadRequest) +} + func TestPost__JSON(t *testing.T) { type testInput struct { Foo string @@ -435,6 +521,14 @@ func TestPost__JSON(t *testing.T) { } } +func TestPost__InvalidJSON(t *testing.T) { + r, _ := http.NewRequest("POST", "/post", bytes.NewReader([]byte("foo"))) + r.Header.Set("Content-Type", "application/json; charset=utf-8") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + assertStatusCode(t, w, http.StatusBadRequest) +} + func TestPost__BodyTooBig(t *testing.T) { body := make([]byte, maxMemory+1) @@ -446,6 +540,88 @@ func TestPost__BodyTooBig(t *testing.T) { assertContentType(t, w, "application/json; encoding=utf-8") } +func TestPost__InvalidQueryParams(t *testing.T) { + r, _ := http.NewRequest("POST", "/post?foo=%ZZ", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + assertStatusCode(t, w, http.StatusBadRequest) +} + +func TestPost__QueryParams(t *testing.T) { + params := url.Values{} + params.Set("foo", "foo") + params.Add("bar", "bar1") + params.Add("bar", "bar2") + + r, _ := http.NewRequest("POST", fmt.Sprintf("/post?%s", params.Encode()), nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assertStatusCode(t, w, http.StatusOK) + assertContentType(t, w, "application/json; encoding=utf-8") + + var resp *bodyResponse + err := json.Unmarshal(w.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err) + } + + if resp.Args.Encode() != params.Encode() { + t.Fatalf("expected args = %#v in response, got %#v", params.Encode(), resp.Args.Encode()) + } + + if len(resp.Form) > 0 { + t.Fatalf("expected form data, got %#v", resp.Form) + } +} + +func TestPost__QueryParamsAndBody(t *testing.T) { + args := url.Values{} + args.Set("query1", "foo") + args.Add("query2", "bar1") + args.Add("query2", "bar2") + + form := url.Values{} + form.Set("form1", "foo") + form.Add("form2", "bar1") + form.Add("form2", "bar2") + + url := fmt.Sprintf("/post?%s", args.Encode()) + body := strings.NewReader(form.Encode()) + + r, _ := http.NewRequest("POST", url, body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assertStatusCode(t, w, http.StatusOK) + assertContentType(t, w, "application/json; encoding=utf-8") + + var resp *bodyResponse + err := json.Unmarshal(w.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err) + } + + if resp.Args.Encode() != args.Encode() { + t.Fatalf("expected args = %#v in response, got %#v", args.Encode(), resp.Args.Encode()) + } + + if len(resp.Form) != len(form) { + t.Fatalf("expected %d form values, got %d", len(form), len(resp.Form)) + } + for k, expectedValues := range form { + values, ok := resp.Form[k] + if !ok { + t.Fatalf("expected form field %#v in response", k) + } + if !reflect.DeepEqual(expectedValues, values) { + t.Fatalf("form value mismatch: %#v != %#v", values, expectedValues) + } + } +} + +// TODO: implement and test more complex /status endpoint func TestStatus__Simple(t *testing.T) { redirectHeaders := map[string]string{ "Location": "/redirect/1", @@ -486,11 +662,30 @@ func TestStatus__Simple(t *testing.T) { } } -func TestResponseHeaders(t *testing.T) { +func TestStatus__Errors(t *testing.T) { + var tests = []struct { + given interface{} + status int + }{ + {"", http.StatusBadRequest}, + {"200/foo", http.StatusNotFound}, + {3.14, http.StatusBadRequest}, + {"foo", http.StatusBadRequest}, + } + + for _, test := range tests { + url := fmt.Sprintf("/status/%v", test.given) + r, _ := http.NewRequest("GET", url, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + assertStatusCode(t, w, test.status) + } +} + +func TestResponseHeaders__OK(t *testing.T) { headers := map[string][]string{ - "Content-Type": {"test/test"}, - "Foo": {"foo"}, - "Bar": {"bar1, bar2"}, + "Foo": {"foo"}, + "Bar": {"bar1, bar2"}, } params := url.Values{} @@ -505,6 +700,7 @@ func TestResponseHeaders(t *testing.T) { handler.ServeHTTP(w, r) assertStatusCode(t, w, http.StatusOK) + assertContentType(t, w, "application/json; encoding=utf-8") for k, expectedValues := range headers { values, ok := w.HeaderMap[k] @@ -533,6 +729,27 @@ func TestResponseHeaders(t *testing.T) { } } +func TestResponseHeaders__OverrideContentType(t *testing.T) { + contentType := "text/test" + + params := url.Values{} + params.Set("Content-Type", contentType) + + r, _ := http.NewRequest("GET", fmt.Sprintf("/response-headers?%s", params.Encode()), nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assertStatusCode(t, w, http.StatusOK) + assertContentType(t, w, contentType) +} + +func TestResponseHeaders__InvalidQuery(t *testing.T) { + r, _ := http.NewRequest("GET", "/response-headers?foo=%ZZ", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + assertStatusCode(t, w, http.StatusBadRequest) +} + func TestRelativeRedirect__OK(t *testing.T) { var tests = []struct { n int @@ -560,7 +777,9 @@ func TestRelativeRedirect__Errors(t *testing.T) { }{ {3.14, http.StatusBadRequest}, {-1, http.StatusBadRequest}, + {"", http.StatusBadRequest}, {"foo", http.StatusBadRequest}, + {"10/bar", http.StatusNotFound}, } for _, test := range tests { diff --git a/httpbin/helpers.go b/httpbin/helpers.go index debda7501b5497cf342526ad43f5ac2d65acb371..9ded2e4edd252c8602ef606cef1909438320f017 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -72,7 +72,7 @@ func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse, maxMe return err } resp.Form = r.PostForm - case ct == "multipart/form-data": + case strings.HasPrefix(ct, "multipart/form-data"): err := r.ParseMultipartForm(maxMemory) if err != nil { return err diff --git a/httpbin/middleware.go b/httpbin/middleware.go index 87693028abf55c8856df07c445260fc06b23a51a..ec040869ee9f15fb18107e88707c28e15ca2c2d3 100644 --- a/httpbin/middleware.go +++ b/httpbin/middleware.go @@ -35,9 +35,6 @@ func cors(h http.Handler) http.Handler { } func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc { - if len(methods) == 0 { - return h - } methodMap := make(map[string]struct{}, len(methods)) for _, m := range methods { methodMap[m] = struct{}{}