From a05101cce941692f3c025f1777c51e7da2d651bb Mon Sep 17 00:00:00 2001
From: Will McCutchen <will@mccutch.org>
Date: Fri, 1 Apr 2022 18:30:18 -0400
Subject: [PATCH] Add user agent and client IP to instrumentation (#78)

---
 httpbin/handlers.go     | 14 +++++------
 httpbin/helpers.go      | 22 ++++++++++++-----
 httpbin/helpers_test.go | 53 +++++++++++++++++++++++++++++++++++++++--
 httpbin/middleware.go   | 28 +++++++++++++---------
 4 files changed, 91 insertions(+), 26 deletions(-)

diff --git a/httpbin/handlers.go b/httpbin/handlers.go
index b5dbf64..3ff4ccf 100644
--- a/httpbin/handlers.go
+++ b/httpbin/handlers.go
@@ -51,7 +51,7 @@ func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) {
 	resp := &getResponse{
 		Args:    r.URL.Query(),
 		Headers: getRequestHeaders(r),
-		Origin:  getOrigin(r),
+		Origin:  getClientIP(r),
 		URL:     getURL(r).String(),
 	}
 	body, _ := json.Marshal(resp)
@@ -63,7 +63,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
 	resp := &bodyResponse{
 		Args:    r.URL.Query(),
 		Headers: getRequestHeaders(r),
-		Origin:  getOrigin(r),
+		Origin:  getClientIP(r),
 		URL:     getURL(r).String(),
 	}
 
@@ -81,7 +81,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
 func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) {
 	resp := &gzipResponse{
 		Headers: getRequestHeaders(r),
-		Origin:  getOrigin(r),
+		Origin:  getClientIP(r),
 		Gzipped: true,
 	}
 	body, _ := json.Marshal(resp)
@@ -101,7 +101,7 @@ func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) {
 func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) {
 	resp := &deflateResponse{
 		Headers:  getRequestHeaders(r),
-		Origin:   getOrigin(r),
+		Origin:   getClientIP(r),
 		Deflated: true,
 	}
 	body, _ := json.Marshal(resp)
@@ -120,7 +120,7 @@ func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) {
 // IP echoes the IP address of the incoming request
 func (h *HTTPBin) IP(w http.ResponseWriter, r *http.Request) {
 	body, _ := json.Marshal(&ipResponse{
-		Origin: getOrigin(r),
+		Origin: getClientIP(r),
 	})
 	writeJSON(w, body, http.StatusOK)
 }
@@ -502,7 +502,7 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) {
 	resp := &streamResponse{
 		Args:    r.URL.Query(),
 		Headers: getRequestHeaders(r),
-		Origin:  getOrigin(r),
+		Origin:  getClientIP(r),
 		URL:     getURL(r).String(),
 	}
 
@@ -716,7 +716,7 @@ func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) {
 	resp := &getResponse{
 		Args:    r.URL.Query(),
 		Headers: getRequestHeaders(r),
-		Origin:  getOrigin(r),
+		Origin:  getClientIP(r),
 		URL:     getURL(r).String(),
 	}
 	body, _ := json.Marshal(resp)
diff --git a/httpbin/helpers.go b/httpbin/helpers.go
index 366b809..e715752 100644
--- a/httpbin/helpers.go
+++ b/httpbin/helpers.go
@@ -32,13 +32,23 @@ func getRequestHeaders(r *http.Request) http.Header {
 	return h
 }
 
-func getOrigin(r *http.Request) string {
-	forwardedFor := r.Header.Get("X-Forwarded-For")
-	if forwardedFor == "" {
-		return r.RemoteAddr
+// getClientIP tries to get a reasonable value for the IP address of the
+// client making the request. Note that this value will likely be trivial to
+// spoof, so do not rely on it for security purposes.
+func getClientIP(r *http.Request) string {
+	// Special case some hosting platforms that provide the value directly.
+	if clientIP := r.Header.Get("Fly-Client-IP"); clientIP != "" {
+		return clientIP
 	}
-	// take the first entry in a comma-separated list of IP addrs
-	return strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0])
+
+	// Try to pull a reasonable value from the X-Forwarded-For header, if
+	// present, by taking the first entry in a comma-separated list of IPs.
+	if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
+		return strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0])
+	}
+
+	// Finally, fall back on the actual remote addr from the request.
+	return r.RemoteAddr
 }
 
 func getURL(r *http.Request) *url.URL {
diff --git a/httpbin/helpers_test.go b/httpbin/helpers_test.go
index 0636e0f..931ad9d 100644
--- a/httpbin/helpers_test.go
+++ b/httpbin/helpers_test.go
@@ -3,6 +3,7 @@ package httpbin
 import (
 	"fmt"
 	"io"
+	"net/http"
 	"reflect"
 	"testing"
 	"time"
@@ -33,7 +34,7 @@ func assertError(t *testing.T, got, expected error) {
 }
 
 func TestParseDuration(t *testing.T) {
-	var okTests = []struct {
+	okTests := []struct {
 		input    string
 		expected time.Duration
 	}{
@@ -61,7 +62,7 @@ func TestParseDuration(t *testing.T) {
 		})
 	}
 
-	var badTests = []struct {
+	badTests := []struct {
 		input string
 	}{
 		{"foo"},
@@ -154,3 +155,51 @@ func TestSyntheticByteStream(t *testing.T) {
 		}
 	})
 }
+
+func Test_getClientIP(t *testing.T) {
+	makeHeaders := func(m map[string]string) http.Header {
+		h := make(http.Header, len(m))
+		for k, v := range m {
+			h.Set(k, v)
+		}
+		return h
+	}
+
+	tests := map[string]struct {
+		given *http.Request
+		want  string
+	}{
+		"custom platform headers take precedence": {
+			given: &http.Request{
+				Header: makeHeaders(map[string]string{
+					"Fly-Client-IP":   "9.9.9.9",
+					"X-Forwarded-For": "1.1.1.1,2.2.2.2,3.3.3.3",
+				}),
+				RemoteAddr: "0.0.0.0",
+			},
+			want: "9.9.9.9",
+		},
+		"x-forwarded-for is parsed": {
+			given: &http.Request{
+				Header: makeHeaders(map[string]string{
+					"X-Forwarded-For": "1.1.1.1,2.2.2.2,3.3.3.3",
+				}),
+				RemoteAddr: "0.0.0.0",
+			},
+			want: "1.1.1.1",
+		},
+		"remoteaddr is fallback": {
+			given: &http.Request{
+				RemoteAddr: "0.0.0.0",
+			},
+			want: "0.0.0.0",
+		},
+	}
+	for name, tt := range tests {
+		t.Run(name, func(t *testing.T) {
+			if got := getClientIP(tt.given); got != tt.want {
+				t.Errorf("getClientIP() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/httpbin/middleware.go b/httpbin/middleware.go
index 5186789..13e937a 100644
--- a/httpbin/middleware.go
+++ b/httpbin/middleware.go
@@ -123,22 +123,26 @@ func observe(o Observer, h http.Handler) http.Handler {
 		t := time.Now()
 		h.ServeHTTP(mw, r)
 		o(Result{
-			Status:   mw.Status(),
-			Method:   r.Method,
-			URI:      r.URL.RequestURI(),
-			Size:     mw.Size(),
-			Duration: time.Since(t),
+			Status:    mw.Status(),
+			Method:    r.Method,
+			URI:       r.URL.RequestURI(),
+			Size:      mw.Size(),
+			Duration:  time.Since(t),
+			UserAgent: r.Header.Get("User-Agent"),
+			ClientIP:  getClientIP(r),
 		})
 	})
 }
 
 // Result is the result of handling a request, used for instrumentation
 type Result struct {
-	Status   int
-	Method   string
-	URI      string
-	Size     int64
-	Duration time.Duration
+	Status    int
+	Method    string
+	URI       string
+	Size      int64
+	Duration  time.Duration
+	UserAgent string
+	ClientIP  string
 }
 
 // Observer is a function that will be called with the details of a handled
@@ -149,7 +153,7 @@ type Observer func(result Result)
 // format using the given stdlib logger
 func StdLogObserver(l *log.Logger) Observer {
 	const (
-		logFmt  = "time=%q status=%d method=%q uri=%q size_bytes=%d duration_ms=%0.02f"
+		logFmt  = "time=%q status=%d method=%q uri=%q size_bytes=%d duration_ms=%0.02f user_agent=%q client_ip=%s"
 		dateFmt = "2006-01-02T15:04:05.9999"
 	)
 	return func(result Result) {
@@ -161,6 +165,8 @@ func StdLogObserver(l *log.Logger) Observer {
 			result.URI,
 			result.Size,
 			result.Duration.Seconds()*1e3, // https://github.com/golang/go/issues/5491#issuecomment-66079585
+			result.UserAgent,
+			result.ClientIP,
 		)
 	}
 }
-- 
GitLab