Skip to content
Snippets Groups Projects
Commit 9e4e1b6a authored by Will McCutchen's avatar Will McCutchen
Browse files

Use middleware to limit request body size

parent f6cde8d9
Branches
Tags
No related merge requests found
...@@ -63,7 +63,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) { ...@@ -63,7 +63,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
URL: getURL(r).String(), URL: getURL(r).String(),
} }
err := parseBody(w, r, resp, h.options.MaxMemory) err := parseBody(w, r, resp)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("error parsing request body: %s", err), http.StatusBadRequest) http.Error(w, fmt.Sprintf("error parsing request body: %s", err), http.StatusBadRequest)
return return
......
...@@ -70,14 +70,14 @@ func writeHTML(w http.ResponseWriter, body []byte, status int) { ...@@ -70,14 +70,14 @@ func writeHTML(w http.ResponseWriter, body []byte, status int) {
// parseBody handles parsing a request body into our standard API response, // parseBody handles parsing a request body into our standard API response,
// taking care to only consume the request body once based on the Content-Type // taking care to only consume the request body once based on the Content-Type
// of the request. The given Resp will be updated. // of the request. The given bodyResponse will be modified.
func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse, maxMemory int64) error { //
// Note: this function expects callers to limit the the maximum size of the
// request body. See, e.g., the limitRequestSize middleware.
func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse) error {
if r.Body == nil { if r.Body == nil {
return nil return nil
} }
// Restrict size of request body
r.Body = http.MaxBytesReader(w, r.Body, maxMemory)
defer r.Body.Close() defer r.Body.Close()
ct := r.Header.Get("Content-Type") ct := r.Header.Get("Content-Type")
...@@ -89,7 +89,10 @@ func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse, maxMe ...@@ -89,7 +89,10 @@ func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse, maxMe
} }
resp.Form = r.PostForm resp.Form = r.PostForm
case strings.HasPrefix(ct, "multipart/form-data"): case strings.HasPrefix(ct, "multipart/form-data"):
err := r.ParseMultipartForm(maxMemory) // The memory limit here only restricts how many parts will be kept in
// memory before overflowing to disk:
// http://localhost:8080/pkg/net/http/#Request.ParseMultipartForm
err := r.ParseMultipartForm(1024 * 1024)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -157,7 +157,13 @@ func (h *HTTPBin) Handler() http.Handler { ...@@ -157,7 +157,13 @@ func (h *HTTPBin) Handler() http.Handler {
mux.HandleFunc("/stream-bytes", http.NotFound) mux.HandleFunc("/stream-bytes", http.NotFound)
mux.HandleFunc("/links", http.NotFound) mux.HandleFunc("/links", http.NotFound)
return logger(cors(mux)) // Apply global middleware
var handler http.Handler
handler = mux
handler = limitRequestSize(h.options.MaxMemory, handler)
handler = logger(handler)
handler = cors(handler)
return handler
} }
// NewHTTPBin creates a new HTTPBin // NewHTTPBin creates a new HTTPBin
......
...@@ -47,3 +47,12 @@ func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc { ...@@ -47,3 +47,12 @@ func methods(h http.HandlerFunc, methods ...string) http.HandlerFunc {
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
} }
} }
func limitRequestSize(maxSize int64, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Body != nil {
r.Body = http.MaxBytesReader(w, r.Body, maxSize)
}
h.ServeHTTP(w, r)
})
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment