Skip to content
Snippets Groups Projects
Unverified Commit a05101cc authored by Will McCutchen's avatar Will McCutchen Committed by GitHub
Browse files

Add user agent and client IP to instrumentation (#78)

parent a9161da8
No related branches found
No related tags found
No related merge requests found
...@@ -51,7 +51,7 @@ func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) { ...@@ -51,7 +51,7 @@ func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) {
resp := &getResponse{ resp := &getResponse{
Args: r.URL.Query(), Args: r.URL.Query(),
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
Origin: getOrigin(r), Origin: getClientIP(r),
URL: getURL(r).String(), URL: getURL(r).String(),
} }
body, _ := json.Marshal(resp) body, _ := json.Marshal(resp)
...@@ -63,7 +63,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) { ...@@ -63,7 +63,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
resp := &bodyResponse{ resp := &bodyResponse{
Args: r.URL.Query(), Args: r.URL.Query(),
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
Origin: getOrigin(r), Origin: getClientIP(r),
URL: getURL(r).String(), URL: getURL(r).String(),
} }
...@@ -81,7 +81,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) { ...@@ -81,7 +81,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) {
resp := &gzipResponse{ resp := &gzipResponse{
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
Origin: getOrigin(r), Origin: getClientIP(r),
Gzipped: true, Gzipped: true,
} }
body, _ := json.Marshal(resp) body, _ := json.Marshal(resp)
...@@ -101,7 +101,7 @@ func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) { ...@@ -101,7 +101,7 @@ func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) {
func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) {
resp := &deflateResponse{ resp := &deflateResponse{
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
Origin: getOrigin(r), Origin: getClientIP(r),
Deflated: true, Deflated: true,
} }
body, _ := json.Marshal(resp) body, _ := json.Marshal(resp)
...@@ -120,7 +120,7 @@ func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) { ...@@ -120,7 +120,7 @@ func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) {
// IP echoes the IP address of the incoming request // IP echoes the IP address of the incoming request
func (h *HTTPBin) IP(w http.ResponseWriter, r *http.Request) { func (h *HTTPBin) IP(w http.ResponseWriter, r *http.Request) {
body, _ := json.Marshal(&ipResponse{ body, _ := json.Marshal(&ipResponse{
Origin: getOrigin(r), Origin: getClientIP(r),
}) })
writeJSON(w, body, http.StatusOK) writeJSON(w, body, http.StatusOK)
} }
...@@ -502,7 +502,7 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) { ...@@ -502,7 +502,7 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) {
resp := &streamResponse{ resp := &streamResponse{
Args: r.URL.Query(), Args: r.URL.Query(),
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
Origin: getOrigin(r), Origin: getClientIP(r),
URL: getURL(r).String(), URL: getURL(r).String(),
} }
...@@ -716,7 +716,7 @@ func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) { ...@@ -716,7 +716,7 @@ func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) {
resp := &getResponse{ resp := &getResponse{
Args: r.URL.Query(), Args: r.URL.Query(),
Headers: getRequestHeaders(r), Headers: getRequestHeaders(r),
Origin: getOrigin(r), Origin: getClientIP(r),
URL: getURL(r).String(), URL: getURL(r).String(),
} }
body, _ := json.Marshal(resp) body, _ := json.Marshal(resp)
......
...@@ -32,13 +32,23 @@ func getRequestHeaders(r *http.Request) http.Header { ...@@ -32,13 +32,23 @@ func getRequestHeaders(r *http.Request) http.Header {
return h return h
} }
func getOrigin(r *http.Request) string { // getClientIP tries to get a reasonable value for the IP address of the
forwardedFor := r.Header.Get("X-Forwarded-For") // client making the request. Note that this value will likely be trivial to
if forwardedFor == "" { // spoof, so do not rely on it for security purposes.
return r.RemoteAddr func getClientIP(r *http.Request) string {
// Special case some hosting platforms that provide the value directly.
if clientIP := r.Header.Get("Fly-Client-IP"); clientIP != "" {
return clientIP
} }
// take the first entry in a comma-separated list of IP addrs
return strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0]) // Try to pull a reasonable value from the X-Forwarded-For header, if
// present, by taking the first entry in a comma-separated list of IPs.
if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
return strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0])
}
// Finally, fall back on the actual remote addr from the request.
return r.RemoteAddr
} }
func getURL(r *http.Request) *url.URL { func getURL(r *http.Request) *url.URL {
......
...@@ -3,6 +3,7 @@ package httpbin ...@@ -3,6 +3,7 @@ package httpbin
import ( import (
"fmt" "fmt"
"io" "io"
"net/http"
"reflect" "reflect"
"testing" "testing"
"time" "time"
...@@ -33,7 +34,7 @@ func assertError(t *testing.T, got, expected error) { ...@@ -33,7 +34,7 @@ func assertError(t *testing.T, got, expected error) {
} }
func TestParseDuration(t *testing.T) { func TestParseDuration(t *testing.T) {
var okTests = []struct { okTests := []struct {
input string input string
expected time.Duration expected time.Duration
}{ }{
...@@ -61,7 +62,7 @@ func TestParseDuration(t *testing.T) { ...@@ -61,7 +62,7 @@ func TestParseDuration(t *testing.T) {
}) })
} }
var badTests = []struct { badTests := []struct {
input string input string
}{ }{
{"foo"}, {"foo"},
...@@ -154,3 +155,51 @@ func TestSyntheticByteStream(t *testing.T) { ...@@ -154,3 +155,51 @@ func TestSyntheticByteStream(t *testing.T) {
} }
}) })
} }
func Test_getClientIP(t *testing.T) {
makeHeaders := func(m map[string]string) http.Header {
h := make(http.Header, len(m))
for k, v := range m {
h.Set(k, v)
}
return h
}
tests := map[string]struct {
given *http.Request
want string
}{
"custom platform headers take precedence": {
given: &http.Request{
Header: makeHeaders(map[string]string{
"Fly-Client-IP": "9.9.9.9",
"X-Forwarded-For": "1.1.1.1,2.2.2.2,3.3.3.3",
}),
RemoteAddr: "0.0.0.0",
},
want: "9.9.9.9",
},
"x-forwarded-for is parsed": {
given: &http.Request{
Header: makeHeaders(map[string]string{
"X-Forwarded-For": "1.1.1.1,2.2.2.2,3.3.3.3",
}),
RemoteAddr: "0.0.0.0",
},
want: "1.1.1.1",
},
"remoteaddr is fallback": {
given: &http.Request{
RemoteAddr: "0.0.0.0",
},
want: "0.0.0.0",
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
if got := getClientIP(tt.given); got != tt.want {
t.Errorf("getClientIP() = %v, want %v", got, tt.want)
}
})
}
}
...@@ -123,22 +123,26 @@ func observe(o Observer, h http.Handler) http.Handler { ...@@ -123,22 +123,26 @@ func observe(o Observer, h http.Handler) http.Handler {
t := time.Now() t := time.Now()
h.ServeHTTP(mw, r) h.ServeHTTP(mw, r)
o(Result{ o(Result{
Status: mw.Status(), Status: mw.Status(),
Method: r.Method, Method: r.Method,
URI: r.URL.RequestURI(), URI: r.URL.RequestURI(),
Size: mw.Size(), Size: mw.Size(),
Duration: time.Since(t), Duration: time.Since(t),
UserAgent: r.Header.Get("User-Agent"),
ClientIP: getClientIP(r),
}) })
}) })
} }
// Result is the result of handling a request, used for instrumentation // Result is the result of handling a request, used for instrumentation
type Result struct { type Result struct {
Status int Status int
Method string Method string
URI string URI string
Size int64 Size int64
Duration time.Duration Duration time.Duration
UserAgent string
ClientIP string
} }
// Observer is a function that will be called with the details of a handled // Observer is a function that will be called with the details of a handled
...@@ -149,7 +153,7 @@ type Observer func(result Result) ...@@ -149,7 +153,7 @@ type Observer func(result Result)
// format using the given stdlib logger // format using the given stdlib logger
func StdLogObserver(l *log.Logger) Observer { func StdLogObserver(l *log.Logger) Observer {
const ( const (
logFmt = "time=%q status=%d method=%q uri=%q size_bytes=%d duration_ms=%0.02f" logFmt = "time=%q status=%d method=%q uri=%q size_bytes=%d duration_ms=%0.02f user_agent=%q client_ip=%s"
dateFmt = "2006-01-02T15:04:05.9999" dateFmt = "2006-01-02T15:04:05.9999"
) )
return func(result Result) { return func(result Result) {
...@@ -161,6 +165,8 @@ func StdLogObserver(l *log.Logger) Observer { ...@@ -161,6 +165,8 @@ func StdLogObserver(l *log.Logger) Observer {
result.URI, result.URI,
result.Size, result.Size,
result.Duration.Seconds()*1e3, // https://github.com/golang/go/issues/5491#issuecomment-66079585 result.Duration.Seconds()*1e3, // https://github.com/golang/go/issues/5491#issuecomment-66079585
result.UserAgent,
result.ClientIP,
) )
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment