Skip to content
Snippets Groups Projects
Unverified Commit 0de8ec96 authored by Will McCutchen's avatar Will McCutchen Committed by GitHub
Browse files

Refactor test suite to make real requests to a real server (#131)

In the course of validating #125, we discovered that using the stdlib's
[`httptest.ResponseRecorder`][0] mechanism to drive the vast majority of
our unit tests led to some slight brokenness due to subtle differences
in the way those "simulated" requests are handled vs "real" requests to
a live HTTP server, as [explained in this comment][1].

That prompted me to do a big ass refactor of the entire test suite,
swapping httptest.ResponseRecorder for interacting with a live server
instance via [`httptest.Server`][2].

This should make the test suite more accurate and reliable in the long
run by ensuring that the vast majority of tests are making actual HTTP
requests and reading responses from the wire.

Note that updating these tests also uncovered a few minor bugs in
existing handler code, fixed in a separate commit for visibility.

P.S. I'm awfully sorry to anyone who tries to merge or rebase local test
changes after this refactor lands, that is goign to be a nightmare. If
you run into issues resolving conflicts, feel free to ping me and I can
try to help!

[0]: https://pkg.go.dev/net/http/httptest#ResponseRecorder
[1]: https://github.com/mccutchen/go-httpbin/pull/125#issuecomment-1596176645
[2]: https://pkg.go.dev/net/http/httptest#Server
parent 499044e1
No related branches found
No related tags found
No related merge requests found
module github.com/mccutchen/go-httpbin/v2
go 1.16
go 1.18
......@@ -17,6 +17,8 @@ import (
"github.com/mccutchen/go-httpbin/v2/httpbin/digest"
)
var nilValues = url.Values{}
func notImplementedHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not implemented", http.StatusNotImplemented)
}
......@@ -71,6 +73,8 @@ func (h *HTTPBin) Anything(w http.ResponseWriter, r *http.Request) {
func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
resp := &bodyResponse{
Args: r.URL.Query(),
Files: nilValues,
Form: nilValues,
Headers: getRequestHeaders(r),
Method: r.Method,
Origin: getClientIP(r),
......@@ -628,25 +632,38 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) {
return
}
pause := duration / time.Duration(numBytes)
flusher := w.(http.Flusher)
pause := duration
if numBytes > 1 {
// compensate for lack of pause after final write (i.e. if we're
// writing 10 bytes, we will only pause 9 times)
pause = duration / time.Duration(numBytes-1)
}
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Length", fmt.Sprintf("%d", numBytes))
w.WriteHeader(code)
flusher := w.(http.Flusher)
flusher.Flush()
// wait for initial delay before writing response body
select {
case <-r.Context().Done():
return
case <-time.After(delay):
}
// write response body byte-by-byte, pausing between each write
b := []byte{'*'}
for i := int64(0); i < numBytes; i++ {
w.Write(b)
flusher.Flush()
// don't pause after last byte
if i == numBytes-1 {
break
}
select {
case <-r.Context().Done():
return
......@@ -829,9 +846,10 @@ func handleBytes(w http.ResponseWriter, r *http.Request, streaming bool) {
}
}()
} else {
// if not streaming, we will write the whole response at once
chunkSize = numBytes
w.Header().Set("Content-Length", strconv.Itoa(numBytes))
write = func(chunk []byte) {
w.Header().Set("Content-Length", strconv.Itoa(len(chunk)))
w.Write(chunk)
}
}
......@@ -878,6 +896,7 @@ func (h *HTTPBin) Links(w http.ResponseWriter, r *http.Request) {
offset, err := strconv.Atoi(parts[3])
if err != nil {
http.Error(w, "Invalid offset", http.StatusBadRequest)
return
}
doLinksPage(w, r, n, offset)
return
......@@ -937,6 +956,7 @@ func doImage(w http.ResponseWriter, kind string) {
img, err := staticAsset("image." + kind)
if err != nil {
http.Error(w, "Not Found", http.StatusNotFound)
return
}
contentType := "image/" + kind
if kind == "svg" {
......
This diff is collapsed.
......@@ -143,10 +143,6 @@ func parseFiles(fileHeaders map[string][]*multipart.FileHeader) (map[string][]st
// Note: this function expects callers to limit the the maximum size of the
// request body. See, e.g., the limitRequestSize middleware.
func parseBody(w http.ResponseWriter, r *http.Request, resp *bodyResponse) error {
if r.Body == nil {
return nil
}
// Always set resp.Data to the incoming request body, in case we don't know
// how to handle the content type
body, err := io.ReadAll(r.Body)
......
......@@ -8,42 +8,11 @@ import (
"mime/multipart"
"net/http"
"net/url"
"reflect"
"testing"
"time"
)
func assertNil(t *testing.T, v interface{}) {
t.Helper()
if v != nil {
t.Fatalf("expected nil, got %#v", v)
}
}
func assertNilError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("expected nil error, got %s (%T)", err, err)
}
}
func assertIntEqual(t *testing.T, a, b int) {
if a != b {
t.Errorf("expected %v == %v", a, b)
}
}
func assertBytesEqual(t *testing.T, a, b []byte) {
if !reflect.DeepEqual(a, b) {
t.Errorf("expected %v == %v", a, b)
}
}
func assertError(t *testing.T, got, expected error) {
if got != expected {
t.Errorf("expected error %v, got %v", expected, got)
}
}
"github.com/mccutchen/go-httpbin/v2/internal/testing/assert"
)
func mustParse(s string) *url.URL {
u, e := url.Parse(s)
......@@ -54,7 +23,7 @@ func mustParse(s string) *url.URL {
}
func TestGetURL(t *testing.T) {
baseUrl, _ := url.Parse("http://example.com/something?foo=bar")
baseURL := mustParse("http://example.com/something?foo=bar")
tests := []struct {
name string
input *http.Request
......@@ -63,7 +32,7 @@ func TestGetURL(t *testing.T) {
{
"basic test",
&http.Request{
URL: baseUrl,
URL: baseURL,
Header: http.Header{},
},
mustParse("http://example.com/something?foo=bar"),
......@@ -71,7 +40,7 @@ func TestGetURL(t *testing.T) {
{
"if TLS is not nil, scheme is https",
&http.Request{
URL: baseUrl,
URL: baseURL,
TLS: &tls.ConnectionState{},
Header: http.Header{},
},
......@@ -80,7 +49,7 @@ func TestGetURL(t *testing.T) {
{
"if X-Forwarded-Proto is present, scheme is that value",
&http.Request{
URL: baseUrl,
URL: baseURL,
Header: http.Header{"X-Forwarded-Proto": {"https"}},
},
mustParse("https://example.com/something?foo=bar"),
......@@ -88,7 +57,7 @@ func TestGetURL(t *testing.T) {
{
"if X-Forwarded-Proto is present, scheme is that value (2)",
&http.Request{
URL: baseUrl,
URL: baseURL,
Header: http.Header{"X-Forwarded-Proto": {"bananas"}},
},
mustParse("bananas://example.com/something?foo=bar"),
......@@ -96,7 +65,7 @@ func TestGetURL(t *testing.T) {
{
"if X-Forwarded-Ssl is 'on', scheme is https",
&http.Request{
URL: baseUrl,
URL: baseURL,
Header: http.Header{"X-Forwarded-Ssl": {"on"}},
},
mustParse("https://example.com/something?foo=bar"),
......@@ -114,9 +83,7 @@ func TestGetURL(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
res := getURL(test.input)
if res.String() != test.expected.String() {
t.Fatalf("expected %s, got %s", test.expected, res)
}
assert.Equal(t, res.String(), test.expected.String(), "URL mismatch")
})
}
}
......@@ -143,12 +110,8 @@ func TestParseDuration(t *testing.T) {
t.Run(fmt.Sprintf("ok/%s", test.input), func(t *testing.T) {
t.Parallel()
result, err := parseDuration(test.input)
if err != nil {
t.Fatalf("unexpected error parsing duration %v: %s", test.input, err)
}
if result != test.expected {
t.Fatalf("expected %s, got %s", test.expected, result)
}
assert.NilError(t, err)
assert.Equal(t, result, test.expected, "incorrect duration")
})
}
......@@ -186,23 +149,23 @@ func TestSyntheticByteStream(t *testing.T) {
// read first half
p := make([]byte, 5)
count, err := s.Read(p)
assertNil(t, err)
assertIntEqual(t, count, 5)
assertBytesEqual(t, p, []byte{0, 1, 2, 3, 4})
assert.NilError(t, err)
assert.Equal(t, count, 5, "incorrect number of bytes read")
assert.DeepEqual(t, p, []byte{0, 1, 2, 3, 4}, "incorrect bytes read")
// read second half
p = make([]byte, 5)
count, err = s.Read(p)
assertError(t, err, io.EOF)
assertIntEqual(t, count, 5)
assertBytesEqual(t, p, []byte{5, 6, 7, 8, 9})
assert.Error(t, err, io.EOF)
assert.Equal(t, count, 5, "incorrect number of bytes read")
assert.DeepEqual(t, p, []byte{5, 6, 7, 8, 9}, "incorrect bytes read")
// can't read any more
p = make([]byte, 5)
count, err = s.Read(p)
assertError(t, err, io.EOF)
assertIntEqual(t, count, 0)
assertBytesEqual(t, p, []byte{0, 0, 0, 0, 0})
assert.Error(t, err, io.EOF)
assert.Equal(t, count, 0, "incorrect number of bytes read")
assert.DeepEqual(t, p, []byte{0, 0, 0, 0, 0}, "incorrect bytes read")
})
t.Run("read into too-large buffer", func(t *testing.T) {
......@@ -210,9 +173,9 @@ func TestSyntheticByteStream(t *testing.T) {
s := newSyntheticByteStream(5, factory)
p := make([]byte, 10)
count, err := s.Read(p)
assertError(t, err, io.EOF)
assertIntEqual(t, count, 5)
assertBytesEqual(t, p, []byte{0, 1, 2, 3, 4, 0, 0, 0, 0, 0})
assert.Error(t, err, io.EOF)
assert.Equal(t, count, 5, "incorrect number of bytes read")
assert.DeepEqual(t, p, []byte{0, 1, 2, 3, 4, 0, 0, 0, 0, 0}, "incorrect bytes read")
})
t.Run("seek", func(t *testing.T) {
......@@ -222,37 +185,31 @@ func TestSyntheticByteStream(t *testing.T) {
p := make([]byte, 5)
s.Seek(10, io.SeekStart)
count, err := s.Read(p)
assertNil(t, err)
assertIntEqual(t, count, 5)
assertBytesEqual(t, p, []byte{10, 11, 12, 13, 14})
assert.NilError(t, err)
assert.Equal(t, count, 5, "incorrect number of bytes read")
assert.DeepEqual(t, p, []byte{10, 11, 12, 13, 14}, "incorrect bytes read")
s.Seek(10, io.SeekCurrent)
count, err = s.Read(p)
assertNil(t, err)
assertIntEqual(t, count, 5)
assertBytesEqual(t, p, []byte{25, 26, 27, 28, 29})
assert.NilError(t, err)
assert.Equal(t, count, 5, "incorrect number of bytes read")
assert.DeepEqual(t, p, []byte{25, 26, 27, 28, 29}, "incorrect bytes read")
s.Seek(10, io.SeekEnd)
count, err = s.Read(p)
assertNil(t, err)
assertIntEqual(t, count, 5)
assertBytesEqual(t, p, []byte{90, 91, 92, 93, 94})
assert.NilError(t, err)
assert.Equal(t, count, 5, "incorrect number of bytes read")
assert.DeepEqual(t, p, []byte{90, 91, 92, 93, 94}, "incorrect bytes read")
// invalid whence
_, err = s.Seek(10, 666)
if err.Error() != "Seek: invalid whence" {
t.Errorf("Expected \"Seek: invalid whence\", got %#v", err.Error())
}
assert.Equal(t, err.Error(), "Seek: invalid whence", "incorrect error for invalid whence")
// invalid offset
_, err = s.Seek(-10, io.SeekStart)
if err.Error() != "Seek: invalid offset" {
t.Errorf("Expected \"Seek: invalid offset\", got %#v", err.Error())
}
assert.Equal(t, err.Error(), "Seek: invalid offset", "incorrect error for invalid offset")
})
}
func Test_getClientIP(t *testing.T) {
func TestGetClientIP(t *testing.T) {
t.Parallel()
makeHeaders := func(m map[string]string) http.Header {
......@@ -297,9 +254,7 @@ func Test_getClientIP(t *testing.T) {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
if got := getClientIP(tc.given); got != tc.want {
t.Errorf("getClientIP() = %v, want %v", got, tc.want)
}
assert.Equal(t, getClientIP(tc.given), tc.want, "incorrect client ip")
})
}
}
......
......@@ -45,8 +45,8 @@ type bodyResponse struct {
URL string `json:"url"`
Data string `json:"data"`
Files map[string][]string `json:"files"`
Form map[string][]string `json:"form"`
Files url.Values `json:"files"`
Form url.Values `json:"form"`
JSON interface{} `json:"json"`
}
......
package assert
import (
"fmt"
"net/http"
"reflect"
"strings"
"testing"
"github.com/mccutchen/go-httpbin/v2/internal/testing/must"
)
// Equal asserts that two values are equal.
func Equal[T comparable](t *testing.T, want, got T, msg string, arg ...any) {
t.Helper()
if want != got {
if msg == "" {
msg = "expected values to match"
}
msg = fmt.Sprintf(msg, arg...)
t.Fatalf("%s:\nwant: %#v\n got: %#v", msg, want, got)
}
}
// DeepEqual asserts that two values are deeply equal.
func DeepEqual[T any](t *testing.T, want, got T, msg string, arg ...any) {
t.Helper()
if !reflect.DeepEqual(want, got) {
if msg == "" {
msg = "expected values to match"
}
msg = fmt.Sprintf(msg, arg...)
t.Fatalf("%s:\nwant: %#v\n got: %#v", msg, want, got)
}
}
// NilError asserts that an error is nil.
func NilError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("expected nil error, got %s (%T)", err, err)
}
}
// Error asserts that an error is not nil.
func Error(t *testing.T, got, expected error) {
t.Helper()
if got != expected {
t.Fatalf("expected error %v, got %v", expected, got)
}
}
// StatusCode asserts that a response has a specific status code.
func StatusCode(t *testing.T, resp *http.Response, code int) {
t.Helper()
if resp.StatusCode != code {
t.Fatalf("expected status code %d, got %d", code, resp.StatusCode)
}
}
// Header asserts that a header key has a specific value in a response.
func Header(t *testing.T, resp *http.Response, key, want string) {
t.Helper()
got := resp.Header.Get(key)
if want != got {
t.Fatalf("expected header %s=%#v, got %#v", key, want, got)
}
}
// ContentType asserts that a response has a specific Content-Type header
// value.
func ContentType(t *testing.T, resp *http.Response, contentType string) {
t.Helper()
Header(t, resp, "Content-Type", contentType)
}
// BodyContains asserts that a response body contains a specific substring.
func BodyContains(t *testing.T, resp *http.Response, needle string) {
t.Helper()
body := must.ReadAll(t, resp.Body)
if !strings.Contains(body, needle) {
t.Fatalf("expected string %q in body %q", needle, body)
}
}
// BodyEquals asserts that a response body is equal to a specific string.
func BodyEquals(t *testing.T, resp *http.Response, want string) {
t.Helper()
got := must.ReadAll(t, resp.Body)
Equal(t, got, want, "incorrect response body")
}
package must
import (
"encoding/json"
"io"
"net/http"
"testing"
"time"
)
// DoReq makes an HTTP request and fails the test if there is an error.
func DoReq(t *testing.T, client *http.Client, req *http.Request) *http.Response {
t.Helper()
start := time.Now()
resp, err := client.Do(req)
if err != nil {
t.Fatalf("error making HTTP request: %s %s: %s", req.Method, req.URL, err)
}
t.Logf("HTTP request: %s %s => %s (%s)", req.Method, req.URL, resp.Status, time.Since(start))
return resp
}
// ReadAll reads all bytes from an io.Reader and fails the test if there is an
// error.
func ReadAll(t *testing.T, r io.Reader) string {
t.Helper()
body, err := io.ReadAll(r)
if err != nil {
t.Fatalf("error reading: %s", err)
}
if rc, ok := r.(io.ReadCloser); ok {
rc.Close()
}
return string(body)
}
// Unmarshal unmarshals JSON from an io.Reader into a value and fails the test
// if there is an error.
func Unmarshal[T any](t *testing.T, r io.Reader) T {
t.Helper()
var v T
if err := json.NewDecoder(r).Decode(&v); err != nil {
t.Fatal(err)
}
return v
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment