diff --git a/caddyhttp/requestid/requestid.go b/caddyhttp/requestid/requestid.go index c3f69267f..b03c449f6 100644 --- a/caddyhttp/requestid/requestid.go +++ b/caddyhttp/requestid/requestid.go @@ -16,6 +16,7 @@ package requestid import ( "context" + "log" "net/http" "github.com/google/uuid" @@ -24,12 +25,29 @@ import ( // Handler is a middleware handler type Handler struct { - Next httpserver.Handler + Next httpserver.Handler + HeaderName string // (optional) header from which to read an existing ID } func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - reqid := uuid.New().String() - c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid) + var reqid uuid.UUID + + uuidFromHeader := r.Header.Get(h.HeaderName) + if h.HeaderName != "" && uuidFromHeader != "" { + // use the ID in the header field if it exists + var err error + reqid, err = uuid.Parse(uuidFromHeader) + if err != nil { + log.Printf("[NOTICE] Parsing request ID from %s header: %v", h.HeaderName, err) + reqid = uuid.New() + } + } else { + // otherwise, create a new one + reqid = uuid.New() + } + + // set the request ID on the context + c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid.String()) r = r.WithContext(c) return h.Next.ServeHTTP(w, r) diff --git a/caddyhttp/requestid/requestid_test.go b/caddyhttp/requestid/requestid_test.go index 80968221f..e68c8d2c0 100644 --- a/caddyhttp/requestid/requestid_test.go +++ b/caddyhttp/requestid/requestid_test.go @@ -15,34 +15,53 @@ package requestid import ( - "context" "net/http" + "net/http/httptest" "testing" - "github.com/google/uuid" "github.com/mholt/caddy/caddyhttp/httpserver" ) -func TestRequestID(t *testing.T) { - request, err := http.NewRequest("GET", "http://localhost/", nil) +func TestRequestIDHandler(t *testing.T) { + handler := Handler{ + Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string) + if value == "" { + t.Error("Request ID should not be empty") + } + return 0, nil + }), + } + + req, err := http.NewRequest("GET", "http://localhost/", nil) if err != nil { t.Fatal("Could not create HTTP request:", err) } + rec := httptest.NewRecorder() - reqid := uuid.New().String() - - c := context.WithValue(request.Context(), httpserver.RequestIDCtxKey, reqid) - - request = request.WithContext(c) - - // See caddyhttp/replacer.go - value, _ := request.Context().Value(httpserver.RequestIDCtxKey).(string) - - if value == "" { - t.Fatal("Request ID should not be empty") - } - - if value != reqid { - t.Fatal("Request ID does not match") - } + handler.ServeHTTP(rec, req) +} + +func TestRequestIDFromHeader(t *testing.T) { + headerName := "X-Request-ID" + headerValue := "71a75329-d9f9-4d25-957e-e689a7b68d78" + handler := Handler{ + Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string) + if value != headerValue { + t.Errorf("Request ID should be '%s' but got '%s'", headerValue, value) + } + return 0, nil + }), + HeaderName: headerName, + } + + req, err := http.NewRequest("GET", "http://localhost/", nil) + if err != nil { + t.Fatal("Could not create HTTP request:", err) + } + req.Header.Set(headerName, headerValue) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) } diff --git a/caddyhttp/requestid/setup.go b/caddyhttp/requestid/setup.go index 4da5a3683..689f99e33 100644 --- a/caddyhttp/requestid/setup.go +++ b/caddyhttp/requestid/setup.go @@ -27,14 +27,19 @@ func init() { } func setup(c *caddy.Controller) error { + var headerName string + for c.Next() { if c.NextArg() { - return c.ArgErr() //no arg expected. + headerName = c.Val() + } + if c.NextArg() { + return c.ArgErr() } } httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { - return Handler{Next: next} + return Handler{Next: next, HeaderName: headerName} }) return nil diff --git a/caddyhttp/requestid/setup_test.go b/caddyhttp/requestid/setup_test.go index aea123694..9c420787b 100644 --- a/caddyhttp/requestid/setup_test.go +++ b/caddyhttp/requestid/setup_test.go @@ -45,7 +45,15 @@ func TestSetup(t *testing.T) { } func TestSetupWithArg(t *testing.T) { - c := caddy.NewTestController("http", `requestid abc`) + c := caddy.NewTestController("http", `requestid X-Request-ID`) + err := setup(c) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } +} + +func TestSetupWithTooManyArgs(t *testing.T) { + c := caddy.NewTestController("http", `requestid foo bar`) err := setup(c) if err == nil { t.Errorf("Expected an error, got: %v", err)