From 7eb054a56dcefdd1fdf634264eb572c369dd0bc1 Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Sun, 28 Aug 2016 22:56:50 -0700
Subject: [PATCH] Add /post endpoint

---
 helpers.go   |  43 +++++++++++++++
 main.go      |  27 +++++++++-
 main_test.go | 146 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 213 insertions(+), 3 deletions(-)

diff --git a/helpers.go b/helpers.go
index 8e3779b..e7c4a15 100644
--- a/helpers.go
+++ b/helpers.go
@@ -2,8 +2,10 @@ package main
 
 import (
 	"encoding/json"
+	"io/ioutil"
 	"net/http"
 	"net/url"
+	"strings"
 )
 
 func getOrigin(r *http.Request) string {
@@ -56,3 +58,44 @@ func writeJSON(w http.ResponseWriter, body []byte) {
 	w.Header().Set("Content-Type", "application/json; encoding=utf-8")
 	w.Write(body)
 }
+
+// parseBody handles parsing a request body into our standard API response,
+// taking care to only consume the request body once based on the Content-Type
+// of the request. The given Resp will be updated.
+func parseBody(w http.ResponseWriter, r *http.Request, resp *Resp) error {
+	if r.Body == nil {
+		return nil
+	}
+
+	// Restrict size of request body
+	r.Body = http.MaxBytesReader(w, r.Body, maxMemory)
+
+	ct := r.Header.Get("Content-Type")
+	switch {
+	case ct == "application/x-www-form-urlencoded":
+		err := r.ParseForm()
+		if err != nil {
+			return err
+		}
+		resp.Form = r.PostForm
+	case ct == "multipart/form-data":
+		err := r.ParseMultipartForm(maxMemory)
+		if err != nil {
+			return err
+		}
+		resp.Form = r.PostForm
+	case strings.HasPrefix(ct, "application/json"):
+		dec := json.NewDecoder(r.Body)
+		err := dec.Decode(&resp.JSON)
+		if err != nil {
+			return err
+		}
+	default:
+		data, err := ioutil.ReadAll(r.Body)
+		if err != nil {
+			return err
+		}
+		resp.Data = data
+	}
+	return nil
+}
diff --git a/main.go b/main.go
index 7649071..a99de7c 100644
--- a/main.go
+++ b/main.go
@@ -16,10 +16,10 @@ type Resp struct {
 	Origin  string      `json:"origin"`
 	URL     string      `json:"url"`
 
-	Data  string              `json:"data,omitempty"`
+	Data  []byte              `json:"data,omitempty"`
 	Files map[string][]string `json:"files,omitempty"`
 	Form  map[string][]string `json:"form,omitempty"`
-	JSON  map[string][]string `json:"json,omitempty"`
+	JSON  interface{}         `json:"json,omitempty"`
 }
 
 // IPResp is the response for the /ip endpoint
@@ -37,6 +37,9 @@ type UserAgentResp struct {
 	UserAgent string `json:"user-agent"`
 }
 
+// Max size of a request body we'll handle
+const maxMemory = 1024*1024*5 + 1
+
 // Index must be wrapped by the withTemplates middleware before it can be used
 func index(w http.ResponseWriter, r *http.Request, t *template.Template) {
 	t = t.Lookup("index.html")
@@ -70,6 +73,25 @@ func get(w http.ResponseWriter, r *http.Request) {
 	writeResponse(w, r, resp)
 }
 
+func post(w http.ResponseWriter, r *http.Request) {
+	args, err := url.ParseQuery(r.URL.RawQuery)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("error parsing query params: %s", err), http.StatusBadRequest)
+		return
+	}
+
+	resp := &Resp{
+		Args:    args,
+		Headers: r.Header,
+	}
+
+	err = parseBody(w, r, resp)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("error parsing request body: %s", err), http.StatusBadRequest)
+	}
+	writeResponse(w, r, resp)
+}
+
 func ip(w http.ResponseWriter, r *http.Request) {
 	body, _ := json.Marshal(&IPResp{
 		Origin: getOrigin(r),
@@ -97,6 +119,7 @@ func app() http.Handler {
 	h.HandleFunc("/", methods(templateWrapper(index), "GET"))
 	h.HandleFunc("/forms/post", methods(templateWrapper(formsPost), "GET"))
 	h.HandleFunc("/get", methods(get, "GET"))
+	h.HandleFunc("/post", methods(post, "POST"))
 	h.HandleFunc("/ip", ip)
 	h.HandleFunc("/user-agent", userAgent)
 	h.HandleFunc("/headers", headers)
diff --git a/main_test.go b/main_test.go
index 327bedb..1c8563d 100644
--- a/main_test.go
+++ b/main_test.go
@@ -1,9 +1,11 @@
 package main
 
 import (
+	"bytes"
 	"encoding/json"
 	"net/http"
 	"net/http/httptest"
+	"net/url"
 	"reflect"
 	"strings"
 	"testing"
@@ -225,7 +227,149 @@ func TestHeaders(t *testing.T) {
 			t.Fatalf("expected header %#v in response", k)
 		}
 		if !reflect.DeepEqual(expectedValues, values) {
-			t.Fatalf("%#v != %#v", values, expectedValues)
+			t.Fatalf("header value mismatch: %#v != %#v", values, expectedValues)
 		}
 	}
 }
+
+func TestPost__EmptyBody(t *testing.T) {
+	r, _ := http.NewRequest("POST", "/post", nil)
+	w := httptest.NewRecorder()
+	app().ServeHTTP(w, r)
+
+	var resp *Resp
+	err := json.Unmarshal(w.Body.Bytes(), &resp)
+	if err != nil {
+		t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err)
+	}
+
+	if len(resp.Args) > 0 {
+		t.Fatalf("expected no query params, got %#v", resp.Args)
+	}
+	if len(resp.Form) > 0 {
+		t.Fatalf("expected no form data, got %#v", resp.Form)
+	}
+}
+
+func TestPost__FormEncodedBody(t *testing.T) {
+	params := url.Values{}
+	params.Set("foo", "foo")
+	params.Add("bar", "bar1")
+	params.Add("bar", "bar2")
+
+	r, _ := http.NewRequest("POST", "/post", strings.NewReader(params.Encode()))
+	r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	w := httptest.NewRecorder()
+	app().ServeHTTP(w, r)
+
+	var resp *Resp
+	err := json.Unmarshal(w.Body.Bytes(), &resp)
+	if err != nil {
+		t.Fatalf("failed to unmarshal body %#v from JSON: %s", w.Body.String(), err)
+	}
+
+	if len(resp.Args) > 0 {
+		t.Fatalf("expected no query params, got %#v", resp.Args)
+	}
+	if len(resp.Form) != len(params) {
+		t.Fatalf("expected %d form values, got %d", len(params), len(resp.Form))
+	}
+	for k, expectedValues := range params {
+		values, ok := resp.Form[k]
+		if !ok {
+			t.Fatalf("expected form field %#v in response", k)
+		}
+		if !reflect.DeepEqual(expectedValues, values) {
+			t.Fatalf("form value mismatch: %#v != %#v", values, expectedValues)
+		}
+	}
+}
+
+func TestPost__FormEncodedBodyNoContentType(t *testing.T) {
+	params := url.Values{}
+	params.Set("foo", "foo")
+	params.Add("bar", "bar1")
+	params.Add("bar", "bar2")
+
+	r, _ := http.NewRequest("POST", "/post", strings.NewReader(params.Encode()))
+	w := httptest.NewRecorder()
+	app().ServeHTTP(w, r)
+
+	var resp *Resp
+	err := json.Unmarshal(w.Body.Bytes(), &resp)
+	if err != nil {
+		t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err)
+	}
+
+	if len(resp.Args) > 0 {
+		t.Fatalf("expected no query params, got %#v", resp.Args)
+	}
+	if len(resp.Form) != 0 {
+		t.Fatalf("expected no form values, got %d", len(resp.Form))
+	}
+	if string(resp.Data) != params.Encode() {
+		t.Fatalf("response data mismatch, %#v != %#v", string(resp.Data), params.Encode())
+	}
+}
+
+func TestPost__JSON(t *testing.T) {
+	type testInput struct {
+		Foo  string
+		Bar  int
+		Baz  []float64
+		Quux map[int]string
+	}
+	input := &testInput{
+		Foo:  "foo",
+		Bar:  123,
+		Baz:  []float64{1.0, 1.1, 1.2},
+		Quux: map[int]string{1: "one", 2: "two", 3: "three"},
+	}
+	inputBody, _ := json.Marshal(input)
+
+	r, _ := http.NewRequest("POST", "/post", bytes.NewReader(inputBody))
+	r.Header.Set("Content-Type", "application/json; charset=utf-8")
+	w := httptest.NewRecorder()
+	app().ServeHTTP(w, r)
+
+	var resp *Resp
+	err := json.Unmarshal(w.Body.Bytes(), &resp)
+	if err != nil {
+		t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err)
+	}
+
+	if len(resp.Args) > 0 {
+		t.Fatalf("expected no query params, got %#v", resp.Args)
+	}
+	if len(resp.Form) != 0 {
+		t.Fatalf("expected no form values, got %d", len(resp.Form))
+	}
+	if resp.Data != nil {
+		t.Fatalf("expected no data, got %#v", resp.Data)
+	}
+
+	// Need to re-marshall just the JSON field from the response in order to
+	// re-unmarshall it into our expected type
+	outputBodyBytes, _ := json.Marshal(resp.JSON)
+	output := &testInput{}
+	err = json.Unmarshal(outputBodyBytes, output)
+	if err != nil {
+		t.Fatalf("failed to round-trip JSON: coult not re-unmarshal JSON: %s", err)
+	}
+
+	if !reflect.DeepEqual(input, output) {
+		t.Fatalf("failed to round-trip JSON: %#v != %#v", output, input)
+	}
+}
+
+func TestPost__BodyTooBig(t *testing.T) {
+	body := make([]byte, maxMemory+1)
+
+	r, _ := http.NewRequest("POST", "/post", bytes.NewReader(body))
+	w := httptest.NewRecorder()
+	app().ServeHTTP(w, r)
+
+	if w.Code != http.StatusBadRequest {
+		t.Fatalf("expected code %d, got %d", http.StatusBadRequest, w.Code)
+	}
+}
-- 
GitLab