diff --git a/caddyhttp/basicauth/basicauth.go b/caddyhttp/basicauth/basicauth.go index db12332a..eb7a3fe2 100644 --- a/caddyhttp/basicauth/basicauth.go +++ b/caddyhttp/basicauth/basicauth.go @@ -34,8 +34,7 @@ type BasicAuth struct { // ServeHTTP implements the httpserver.Handler interface. func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - var hasAuth bool - var isAuthenticated bool + var protected, isAuthenticated bool for _, rule := range a.Rules { for _, res := range rule.Resources { @@ -43,29 +42,33 @@ func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error continue } - // Path matches; parse auth header - username, password, ok := r.BasicAuth() - hasAuth = true + // path matches; this endpoint is protected + protected = true - // Check credentials + // parse auth header + username, password, ok := r.BasicAuth() + + // check credentials if !ok || username != rule.Username || !rule.Password(password) { continue } - // Flag set only on successful authentication + // by this point, authentication was successful isAuthenticated = true + + // remove credentials from request to avoid leaking upstream + r.Header.Del("Authorization") } } - if hasAuth { - if !isAuthenticated { - w.Header().Set("WWW-Authenticate", "Basic realm=\"Restricted\"") - return http.StatusUnauthorized, nil - } - // "It's an older code, sir, but it checks out. I was about to clear them." - return a.Next.ServeHTTP(w, r) + if protected && !isAuthenticated { + // browsers show a message that says something like: + // "The website says: " + // which is kinda dumb, but whatever. + w.Header().Set("WWW-Authenticate", "Basic realm=\"Restricted\"") + return http.StatusUnauthorized, nil } // Pass-through when no paths match diff --git a/caddyhttp/basicauth/basicauth_test.go b/caddyhttp/basicauth/basicauth_test.go index d0a66a89..9097b088 100644 --- a/caddyhttp/basicauth/basicauth_test.go +++ b/caddyhttp/basicauth/basicauth_test.go @@ -17,51 +17,57 @@ func TestBasicAuth(t *testing.T) { rw := BasicAuth{ Next: httpserver.HandlerFunc(contentHandler), Rules: []Rule{ - {Username: "test", Password: PlainMatcher("ttest"), Resources: []string{"/testing"}}, + {Username: "okuser", Password: PlainMatcher("okpass"), Resources: []string{"/testing"}}, }, } tests := []struct { - from string - result int - cred string + from string + result int + user string + password string }{ - {"/testing", http.StatusUnauthorized, "ttest:test"}, - {"/testing", http.StatusOK, "test:ttest"}, - {"/testing", http.StatusUnauthorized, ""}, + {"/testing", http.StatusOK, "okuser", "okpass"}, + {"/testing", http.StatusUnauthorized, "baduser", "okpass"}, + {"/testing", http.StatusUnauthorized, "okuser", "badpass"}, + {"/testing", http.StatusUnauthorized, "OKuser", "okpass"}, + {"/testing", http.StatusUnauthorized, "OKuser", "badPASS"}, + {"/testing", http.StatusUnauthorized, "", "okpass"}, + {"/testing", http.StatusUnauthorized, "okuser", ""}, + {"/testing", http.StatusUnauthorized, "", ""}, } for i, test := range tests { - req, err := http.NewRequest("GET", test.from, nil) if err != nil { - t.Fatalf("Test %d: Could not create HTTP request %v", i, err) + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) } - auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred)) - req.Header.Set("Authorization", auth) + req.SetBasicAuth(test.user, test.password) rec := httptest.NewRecorder() result, err := rw.ServeHTTP(rec, req) if err != nil { - t.Fatalf("Test %d: Could not ServeHTTP %v", i, err) + t.Fatalf("Test %d: Could not ServeHTTP: %v", i, err) } if result != test.result { - t.Errorf("Test %d: Expected Header '%d' but was '%d'", + t.Errorf("Test %d: Expected status code %d but was %d", i, test.result, result) } - if result == http.StatusUnauthorized { + if test.result == http.StatusUnauthorized { headers := rec.Header() if val, ok := headers["Www-Authenticate"]; ok { - if val[0] != "Basic realm=\"Restricted\"" { - t.Errorf("Test %d, Www-Authenticate should be %s provided %s", i, "Basic", val[0]) + if got, want := val[0], "Basic realm=\"Restricted\""; got != want { + t.Errorf("Test %d: Www-Authenticate header should be '%s', got: '%s'", i, want, got) } } else { - t.Errorf("Test %d, should provide a header Www-Authenticate", i) + t.Errorf("Test %d: response should have a 'Www-Authenticate' header", i) + } + } else { + if got, want := req.Header.Get("Authorization"), ""; got != want { + t.Errorf("Test %d: Expected Authorization header to be stripped from request after successful authentication, but is: %s", i, got) } } - } - } func TestMultipleOverlappingRules(t *testing.T) {