Skip to content
Snippets Groups Projects
Unverified Commit a05101cc authored by Will McCutchen's avatar Will McCutchen Committed by GitHub
Browse files

Add user agent and client IP to instrumentation (#78)

parent a9161da8
No related branches found
No related tags found
No related merge requests found
......@@ -51,7 +51,7 @@ func (h *HTTPBin) Get(w http.ResponseWriter, r *http.Request) {
resp := &getResponse{
Args: r.URL.Query(),
Headers: getRequestHeaders(r),
Origin: getOrigin(r),
Origin: getClientIP(r),
URL: getURL(r).String(),
}
body, _ := json.Marshal(resp)
......@@ -63,7 +63,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
resp := &bodyResponse{
Args: r.URL.Query(),
Headers: getRequestHeaders(r),
Origin: getOrigin(r),
Origin: getClientIP(r),
URL: getURL(r).String(),
}
......@@ -81,7 +81,7 @@ func (h *HTTPBin) RequestWithBody(w http.ResponseWriter, r *http.Request) {
func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) {
resp := &gzipResponse{
Headers: getRequestHeaders(r),
Origin: getOrigin(r),
Origin: getClientIP(r),
Gzipped: true,
}
body, _ := json.Marshal(resp)
......@@ -101,7 +101,7 @@ func (h *HTTPBin) Gzip(w http.ResponseWriter, r *http.Request) {
func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) {
resp := &deflateResponse{
Headers: getRequestHeaders(r),
Origin: getOrigin(r),
Origin: getClientIP(r),
Deflated: true,
}
body, _ := json.Marshal(resp)
......@@ -120,7 +120,7 @@ func (h *HTTPBin) Deflate(w http.ResponseWriter, r *http.Request) {
// IP echoes the IP address of the incoming request
func (h *HTTPBin) IP(w http.ResponseWriter, r *http.Request) {
body, _ := json.Marshal(&ipResponse{
Origin: getOrigin(r),
Origin: getClientIP(r),
})
writeJSON(w, body, http.StatusOK)
}
......@@ -502,7 +502,7 @@ func (h *HTTPBin) Stream(w http.ResponseWriter, r *http.Request) {
resp := &streamResponse{
Args: r.URL.Query(),
Headers: getRequestHeaders(r),
Origin: getOrigin(r),
Origin: getClientIP(r),
URL: getURL(r).String(),
}
......@@ -716,7 +716,7 @@ func (h *HTTPBin) ETag(w http.ResponseWriter, r *http.Request) {
resp := &getResponse{
Args: r.URL.Query(),
Headers: getRequestHeaders(r),
Origin: getOrigin(r),
Origin: getClientIP(r),
URL: getURL(r).String(),
}
body, _ := json.Marshal(resp)
......
......@@ -32,15 +32,25 @@ func getRequestHeaders(r *http.Request) http.Header {
return h
}
func getOrigin(r *http.Request) string {
forwardedFor := r.Header.Get("X-Forwarded-For")
if forwardedFor == "" {
return r.RemoteAddr
}
// take the first entry in a comma-separated list of IP addrs
// getClientIP tries to get a reasonable value for the IP address of the
// client making the request. Note that this value will likely be trivial to
// spoof, so do not rely on it for security purposes.
func getClientIP(r *http.Request) string {
// Special case some hosting platforms that provide the value directly.
if clientIP := r.Header.Get("Fly-Client-IP"); clientIP != "" {
return clientIP
}
// Try to pull a reasonable value from the X-Forwarded-For header, if
// present, by taking the first entry in a comma-separated list of IPs.
if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
return strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0])
}
// Finally, fall back on the actual remote addr from the request.
return r.RemoteAddr
}
func getURL(r *http.Request) *url.URL {
scheme := r.Header.Get("X-Forwarded-Proto")
if scheme == "" {
......
......@@ -3,6 +3,7 @@ package httpbin
import (
"fmt"
"io"
"net/http"
"reflect"
"testing"
"time"
......@@ -33,7 +34,7 @@ func assertError(t *testing.T, got, expected error) {
}
func TestParseDuration(t *testing.T) {
var okTests = []struct {
okTests := []struct {
input string
expected time.Duration
}{
......@@ -61,7 +62,7 @@ func TestParseDuration(t *testing.T) {
})
}
var badTests = []struct {
badTests := []struct {
input string
}{
{"foo"},
......@@ -154,3 +155,51 @@ func TestSyntheticByteStream(t *testing.T) {
}
})
}
func Test_getClientIP(t *testing.T) {
makeHeaders := func(m map[string]string) http.Header {
h := make(http.Header, len(m))
for k, v := range m {
h.Set(k, v)
}
return h
}
tests := map[string]struct {
given *http.Request
want string
}{
"custom platform headers take precedence": {
given: &http.Request{
Header: makeHeaders(map[string]string{
"Fly-Client-IP": "9.9.9.9",
"X-Forwarded-For": "1.1.1.1,2.2.2.2,3.3.3.3",
}),
RemoteAddr: "0.0.0.0",
},
want: "9.9.9.9",
},
"x-forwarded-for is parsed": {
given: &http.Request{
Header: makeHeaders(map[string]string{
"X-Forwarded-For": "1.1.1.1,2.2.2.2,3.3.3.3",
}),
RemoteAddr: "0.0.0.0",
},
want: "1.1.1.1",
},
"remoteaddr is fallback": {
given: &http.Request{
RemoteAddr: "0.0.0.0",
},
want: "0.0.0.0",
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
if got := getClientIP(tt.given); got != tt.want {
t.Errorf("getClientIP() = %v, want %v", got, tt.want)
}
})
}
}
......@@ -128,6 +128,8 @@ func observe(o Observer, h http.Handler) http.Handler {
URI: r.URL.RequestURI(),
Size: mw.Size(),
Duration: time.Since(t),
UserAgent: r.Header.Get("User-Agent"),
ClientIP: getClientIP(r),
})
})
}
......@@ -139,6 +141,8 @@ type Result struct {
URI string
Size int64
Duration time.Duration
UserAgent string
ClientIP string
}
// Observer is a function that will be called with the details of a handled
......@@ -149,7 +153,7 @@ type Observer func(result Result)
// format using the given stdlib logger
func StdLogObserver(l *log.Logger) Observer {
const (
logFmt = "time=%q status=%d method=%q uri=%q size_bytes=%d duration_ms=%0.02f"
logFmt = "time=%q status=%d method=%q uri=%q size_bytes=%d duration_ms=%0.02f user_agent=%q client_ip=%s"
dateFmt = "2006-01-02T15:04:05.9999"
)
return func(result Result) {
......@@ -161,6 +165,8 @@ func StdLogObserver(l *log.Logger) Observer {
result.URI,
result.Size,
result.Duration.Seconds()*1e3, // https://github.com/golang/go/issues/5491#issuecomment-66079585
result.UserAgent,
result.ClientIP,
)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment