diff --git a/httpbin/handlers.go b/httpbin/handlers.go index b5dbf643a2fc897b9ce0c17183c04b30142a9b2e..3ff4ccfbbbe30dbe33631245856c0775d53dc618 100644 --- a/httpbin/handlers.go +++ b/httpbin/handlers.go @@ -51,7 +51,7 @@ func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) { resp := &getResponse{ Args: r.URL.Query(), Headers: getRequestHeaders(r), - Origin: getOrigin(r), + Origin: getClientIP(r), URL: getURL(r).String(), } body, _ := json.Marshal(resp) @@ -63,7 +63,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) { resp := &bodyResponse{ Args: r.URL.Query(), Headers: getRequestHeaders(r), - Origin: getOrigin(r), + Origin: getClientIP(r), URL: getURL(r).String(), } @@ -81,7 +81,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) { resp := &gzipResponse{ Headers: getRequestHeaders(r), - Origin: getOrigin(r), + Origin: getClientIP(r), Gzipped: true, } body, _ := json.Marshal(resp) @@ -101,7 +101,7 @@ func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) { resp := &deflateResponse{ Headers: getRequestHeaders(r), - Origin: getOrigin(r), + Origin: getClientIP(r), Deflated: true, } body, _ := json.Marshal(resp) @@ -120,7 +120,7 @@ func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) { // IP echoes the IP address of the incoming request func (h *HTTPBin) IP(w http.ResponseWriter, r *http.Request) { body, _ := json.Marshal(&ipResponse{ - Origin: getOrigin(r), + Origin: getClientIP(r), }) writeJSON(w, body, http.StatusOK) } @@ -502,7 +502,7 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) { resp := &streamResponse{ Args: r.URL.Query(), Headers: getRequestHeaders(r), - Origin: getOrigin(r), + Origin: getClientIP(r), URL: getURL(r).String(), } @@ -716,7 +716,7 @@ func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) { resp := &getResponse{ Args: r.URL.Query(), Headers: getRequestHeaders(r), - Origin: getOrigin(r), + Origin: getClientIP(r), URL: getURL(r).String(), } body, _ := json.Marshal(resp) diff --git a/httpbin/helpers.go b/httpbin/helpers.go index 366b8096103f64186c821f420c28a1d8d01d637e..e715752113848c467bb10e33e4779b95ea7673da 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -32,13 +32,23 @@ func getRequestHeaders(r *http.Request) http.Header { return h } -func getOrigin(r *http.Request) string { - forwardedFor := r.Header.Get("X-Forwarded-For") - if forwardedFor == "" { - return r.RemoteAddr +// getClientIP tries to get a reasonable value for the IP address of the +// client making the request. Note that this value will likely be trivial to +// spoof, so do not rely on it for security purposes. +func getClientIP(r *http.Request) string { + // Special case some hosting platforms that provide the value directly. + if clientIP := r.Header.Get("Fly-Client-IP"); clientIP != "" { + return clientIP } - // take the first entry in a comma-separated list of IP addrs - return strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0]) + + // Try to pull a reasonable value from the X-Forwarded-For header, if + // present, by taking the first entry in a comma-separated list of IPs. + if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" { + return strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0]) + } + + // Finally, fall back on the actual remote addr from the request. + return r.RemoteAddr } func getURL(r *http.Request) *url.URL { diff --git a/httpbin/helpers_test.go b/httpbin/helpers_test.go index 0636e0f0f9e6de652476a8eef1eebc936ed84e56..931ad9db745d7290710aa48b825b6b67a0fb1b07 100644 --- a/httpbin/helpers_test.go +++ b/httpbin/helpers_test.go @@ -3,6 +3,7 @@ package httpbin import ( "fmt" "io" + "net/http" "reflect" "testing" "time" @@ -33,7 +34,7 @@ func assertError(t *testing.T, got, expected error) { } func TestParseDuration(t *testing.T) { - var okTests = []struct { + okTests := []struct { input string expected time.Duration }{ @@ -61,7 +62,7 @@ func TestParseDuration(t *testing.T) { }) } - var badTests = []struct { + badTests := []struct { input string }{ {"foo"}, @@ -154,3 +155,51 @@ func TestSyntheticByteStream(t *testing.T) { } }) } + +func Test_getClientIP(t *testing.T) { + makeHeaders := func(m map[string]string) http.Header { + h := make(http.Header, len(m)) + for k, v := range m { + h.Set(k, v) + } + return h + } + + tests := map[string]struct { + given *http.Request + want string + }{ + "custom platform headers take precedence": { + given: &http.Request{ + Header: makeHeaders(map[string]string{ + "Fly-Client-IP": "9.9.9.9", + "X-Forwarded-For": "1.1.1.1,2.2.2.2,3.3.3.3", + }), + RemoteAddr: "0.0.0.0", + }, + want: "9.9.9.9", + }, + "x-forwarded-for is parsed": { + given: &http.Request{ + Header: makeHeaders(map[string]string{ + "X-Forwarded-For": "1.1.1.1,2.2.2.2,3.3.3.3", + }), + RemoteAddr: "0.0.0.0", + }, + want: "1.1.1.1", + }, + "remoteaddr is fallback": { + given: &http.Request{ + RemoteAddr: "0.0.0.0", + }, + want: "0.0.0.0", + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + if got := getClientIP(tt.given); got != tt.want { + t.Errorf("getClientIP() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/httpbin/middleware.go b/httpbin/middleware.go index 5186789cdcd25713893b540c11dfde9bc878dcbf..13e937a3edac8ea983e50c6d0caef1cab9f4cad6 100644 --- a/httpbin/middleware.go +++ b/httpbin/middleware.go @@ -123,22 +123,26 @@ func observe(o Observer, h http.Handler) http.Handler { t := time.Now() h.ServeHTTP(mw, r) o(Result{ - Status: mw.Status(), - Method: r.Method, - URI: r.URL.RequestURI(), - Size: mw.Size(), - Duration: time.Since(t), + Status: mw.Status(), + Method: r.Method, + URI: r.URL.RequestURI(), + Size: mw.Size(), + Duration: time.Since(t), + UserAgent: r.Header.Get("User-Agent"), + ClientIP: getClientIP(r), }) }) } // Result is the result of handling a request, used for instrumentation type Result struct { - Status int - Method string - URI string - Size int64 - Duration time.Duration + Status int + Method string + URI string + Size int64 + Duration time.Duration + UserAgent string + ClientIP string } // Observer is a function that will be called with the details of a handled @@ -149,7 +153,7 @@ type Observer func(result Result) // format using the given stdlib logger func StdLogObserver(l *log.Logger) Observer { const ( - logFmt = "time=%q status=%d method=%q uri=%q size_bytes=%d duration_ms=%0.02f" + logFmt = "time=%q status=%d method=%q uri=%q size_bytes=%d duration_ms=%0.02f user_agent=%q client_ip=%s" dateFmt = "2006-01-02T15:04:05.9999" ) return func(result Result) { @@ -161,6 +165,8 @@ func StdLogObserver(l *log.Logger) Observer { result.URI, result.Size, result.Duration.Seconds()*1e3, // https://github.com/golang/go/issues/5491#issuecomment-66079585 + result.UserAgent, + result.ClientIP, ) } }