Skip to content
Snippets Groups Projects
Commit ef914ab8 authored by Will McCutchen's avatar Will McCutchen
Browse files

Update /ip endpoint to return single origin IP

parent b30cb3a2
No related branches found
No related tags found
No related merge requests found
......@@ -318,8 +318,44 @@ func TestCORS(t *testing.T) {
}
func TestIP(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
remoteAddr string
headers map[string]string
wantOrigin string
}{
"remote addr used if no x-forwarded-for": {
remoteAddr: "192.168.0.100",
wantOrigin: "192.168.0.100",
},
"remote addr used if x-forwarded-for empty": {
remoteAddr: "192.168.0.100",
headers: map[string]string{"X-Forwarded-For": ""},
wantOrigin: "192.168.0.100",
},
"first entry in x-forwarded-for used if present": {
remoteAddr: "192.168.0.100",
headers: map[string]string{"X-Forwarded-For": "10.1.1.1, 10.2.2.2, 10.3.3.3"},
wantOrigin: "10.1.1.1",
},
"single entry x-forwarded-for ok": {
remoteAddr: "192.168.0.100",
headers: map[string]string{"X-Forwarded-For": "10.1.1.1"},
wantOrigin: "10.1.1.1",
},
}
for name, tc := range testCases {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
r, _ := http.NewRequest("GET", "/ip", nil)
r.RemoteAddr = "192.168.0.100"
r.RemoteAddr = tc.remoteAddr
for k, v := range tc.headers {
r.Header.Set(k, v)
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
......@@ -332,8 +368,10 @@ func TestIP(t *testing.T) {
t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err)
}
if resp.Origin != r.RemoteAddr {
t.Fatalf("%#v != %#v", resp.Origin, r.RemoteAddr)
if resp.Origin != tc.wantOrigin {
t.Fatalf("got %q, want %q", resp.Origin, tc.wantOrigin)
}
})
}
}
......
......@@ -33,11 +33,12 @@ func getRequestHeaders(r *http.Request) http.Header {
}
func getOrigin(r *http.Request) string {
origin := r.Header.Get("X-Forwarded-For")
if origin == "" {
origin = r.RemoteAddr
forwardedFor := r.Header.Get("X-Forwarded-For")
if forwardedFor == "" {
return r.RemoteAddr
}
return origin
// take the first entry in a comma-separated list of IP addrs
return strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0])
}
func getURL(r *http.Request) *url.URL {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment