0
Fork 0
mirror of https://github.com/project-zot/zot.git synced 2024-12-30 22:34:13 -05:00

fix(authn): create sessions only if UI header value is supplied (#1919)

Signed-off-by: Petu Eusebiu <peusebiu@cisco.com>
This commit is contained in:
peusebiu 2023-10-12 16:37:55 +03:00 committed by GitHub
parent d1fcab421a
commit a91c0c5cfe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 280 additions and 198 deletions

View file

@ -127,10 +127,12 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce
userAc.AddGroups(groups) userAc.AddGroups(groups)
userAc.SaveOnRequest(request) userAc.SaveOnRequest(request)
// saved logged session // saved logged session only if the request comes from web (has UI session header value)
if hasSessionHeader(request) {
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil { if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
return false, err return false, err
} }
}
// we have already populated the request context with userAc // we have already populated the request context with userAc
if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil { if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil {
@ -163,9 +165,12 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce
userAc.AddGroups(groups) userAc.AddGroups(groups)
userAc.SaveOnRequest(request) userAc.SaveOnRequest(request)
// saved logged session only if the request comes from web (has UI session header value)
if hasSessionHeader(request) {
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil { if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
return false, err return false, err
} }
}
// we have already populated the request context with userAc // we have already populated the request context with userAc
if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil { if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil {

View file

@ -2686,6 +2686,7 @@ func TestOpenIDMiddleware(t *testing.T) {
for _, testcase := range testCases { for _, testcase := range testCases {
t.Run(testcase.testCaseName, func(t *testing.T) { t.Run(testcase.testCaseName, func(t *testing.T) {
Convey("make controller", t, func() {
dir := t.TempDir() dir := t.TempDir()
ctlr.Config.Storage.RootDirectory = dir ctlr.Config.Storage.RootDirectory = dir
@ -2697,7 +2698,7 @@ func TestOpenIDMiddleware(t *testing.T) {
defer cm.StopServer() defer cm.StopServer()
test.WaitTillServerReady(baseURL) test.WaitTillServerReady(baseURL)
Convey("browser client requests", t, func() { Convey("browser client requests", func() {
Convey("login with no provider supplied", func() { Convey("login with no provider supplied", func() {
client := resty.New() client := resty.New()
client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
@ -2711,14 +2712,82 @@ func TestOpenIDMiddleware(t *testing.T) {
So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest) So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest)
}) })
//nolint: dupl
Convey("make sure sessions are not used without UI header value", func() {
sessionsNo, err := getNumberOfSessions(conf.Storage.RootDirectory)
So(err, ShouldBeNil)
So(sessionsNo, ShouldEqual, 0)
client := resty.New()
// without header should not create session
resp, err := client.R().SetBasicAuth(htpasswdUsername, passphrase).Get(baseURL + "/v2/")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
sessionsNo, err = getNumberOfSessions(conf.Storage.RootDirectory)
So(err, ShouldBeNil)
So(sessionsNo, ShouldEqual, 0)
client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue)
resp, err = client.R().SetBasicAuth(htpasswdUsername, passphrase).Get(baseURL + "/v2/")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
sessionsNo, err = getNumberOfSessions(conf.Storage.RootDirectory)
So(err, ShouldBeNil)
So(sessionsNo, ShouldEqual, 1)
// set cookies
client.SetCookies(resp.Cookies())
// should get same cookie
resp, err = client.R().SetBasicAuth(htpasswdUsername, passphrase).Get(baseURL + "/v2/")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
sessionsNo, err = getNumberOfSessions(conf.Storage.RootDirectory)
So(err, ShouldBeNil)
So(sessionsNo, ShouldEqual, 1)
resp, err = client.R().
SetBasicAuth(htpasswdUsername, passphrase).
Get(baseURL + constants.FullMgmt)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
client.SetCookies(resp.Cookies())
// call endpoint with session, without credentials, (added to client after previous request)
resp, err = client.R().
Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
resp, err = client.R().SetBasicAuth(htpasswdUsername, passphrase).Get(baseURL + "/v2/")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
sessionsNo, err = getNumberOfSessions(conf.Storage.RootDirectory)
So(err, ShouldBeNil)
So(sessionsNo, ShouldEqual, 1)
})
Convey("login with openid and get catalog with session", func() { Convey("login with openid and get catalog with session", func() {
client := resty.New() client := resty.New()
client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue)
Convey("with callback_ui value provided", func() { Convey("with callback_ui value provided", func() {
// first login user // first login user
resp, err := client.R(). resp, err := client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
SetQueryParam("provider", "oidc"). SetQueryParam("provider", "oidc").
SetQueryParam("callback_ui", baseURL+"/v2/"). SetQueryParam("callback_ui", baseURL+"/v2/").
Get(baseURL + constants.LoginPath) Get(baseURL + constants.LoginPath)
@ -2729,7 +2798,6 @@ func TestOpenIDMiddleware(t *testing.T) {
// first login user // first login user
resp, err := client.R(). resp, err := client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
SetQueryParam("provider", "oidc"). SetQueryParam("provider", "oidc").
Get(baseURL + constants.LoginPath) Get(baseURL + constants.LoginPath)
So(err, ShouldBeNil) So(err, ShouldBeNil)
@ -2740,7 +2808,6 @@ func TestOpenIDMiddleware(t *testing.T) {
// call endpoint with session (added to client after previous request) // call endpoint with session (added to client after previous request)
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Get(baseURL + "/v2/_catalog") Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
@ -2748,14 +2815,12 @@ func TestOpenIDMiddleware(t *testing.T) {
// logout with options method for coverage // logout with options method for coverage
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Options(baseURL + constants.LogoutPath) Options(baseURL + constants.LogoutPath)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
// logout user // logout user
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Post(baseURL + constants.LogoutPath) Post(baseURL + constants.LogoutPath)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
@ -2763,7 +2828,6 @@ func TestOpenIDMiddleware(t *testing.T) {
// calling endpoint should fail with unauthorized access // calling endpoint should fail with unauthorized access
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Get(baseURL + "/v2/_catalog") Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
@ -2773,6 +2837,7 @@ func TestOpenIDMiddleware(t *testing.T) {
//nolint: dupl //nolint: dupl
Convey("login with basic auth(htpasswd) and get catalog with session", func() { Convey("login with basic auth(htpasswd) and get catalog with session", func() {
client := resty.New() client := resty.New()
client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue)
// without creds, should get access error // without creds, should get access error
resp, err := client.R().Get(baseURL + "/v2/") resp, err := client.R().Get(baseURL + "/v2/")
@ -2806,14 +2871,12 @@ func TestOpenIDMiddleware(t *testing.T) {
// call endpoint with session, without credentials, (added to client after previous request) // call endpoint with session, without credentials, (added to client after previous request)
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Get(baseURL + "/v2/_catalog") Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK) So(resp.StatusCode(), ShouldEqual, http.StatusOK)
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Get(baseURL + constants.FullMgmt) Get(baseURL + constants.FullMgmt)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
@ -2821,7 +2884,6 @@ func TestOpenIDMiddleware(t *testing.T) {
// logout user // logout user
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Post(baseURL + constants.LogoutPath) Post(baseURL + constants.LogoutPath)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
@ -2829,7 +2891,6 @@ func TestOpenIDMiddleware(t *testing.T) {
// calling endpoint should fail with unauthorized access // calling endpoint should fail with unauthorized access
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Get(baseURL + "/v2/_catalog") Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
@ -2839,6 +2900,7 @@ func TestOpenIDMiddleware(t *testing.T) {
//nolint: dupl //nolint: dupl
Convey("login with ldap and get catalog", func() { Convey("login with ldap and get catalog", func() {
client := resty.New() client := resty.New()
client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue)
// without creds, should get access error // without creds, should get access error
resp, err := client.R().Get(baseURL + "/v2/") resp, err := client.R().Get(baseURL + "/v2/")
@ -2872,14 +2934,12 @@ func TestOpenIDMiddleware(t *testing.T) {
// call endpoint with session, without credentials, (added to client after previous request) // call endpoint with session, without credentials, (added to client after previous request)
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Get(baseURL + "/v2/_catalog") Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK) So(resp.StatusCode(), ShouldEqual, http.StatusOK)
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Get(baseURL + constants.FullMgmt) Get(baseURL + constants.FullMgmt)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
@ -2887,7 +2947,6 @@ func TestOpenIDMiddleware(t *testing.T) {
// logout user // logout user
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Post(baseURL + constants.LogoutPath) Post(baseURL + constants.LogoutPath)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
@ -2895,7 +2954,6 @@ func TestOpenIDMiddleware(t *testing.T) {
// calling endpoint should fail with unauthorized access // calling endpoint should fail with unauthorized access
resp, err = client.R(). resp, err = client.R().
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
Get(baseURL + "/v2/_catalog") Get(baseURL + "/v2/_catalog")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
@ -2921,6 +2979,7 @@ func TestOpenIDMiddleware(t *testing.T) {
}) })
}) })
}) })
})
} }
} }
@ -3273,6 +3332,7 @@ func TestAuthnSessionErrors(t *testing.T) {
}() }()
client := resty.New() client := resty.New()
client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue)
// first htpasswd saveSessionLoggedUser() error // first htpasswd saveSessionLoggedUser() error
resp, err := client.R(). resp, err := client.R().
@ -9764,3 +9824,20 @@ func getEmptyImageConfig() ([]byte, godigest.Digest) {
return configBlobContent, configBlobDigestRaw return configBlobContent, configBlobDigestRaw
} }
func getNumberOfSessions(rootDir string) (int, error) {
rootDirContents, err := os.ReadDir(path.Join(rootDir, "_sessions"))
if err != nil {
return -1, err
}
sessionsNo := 0
for _, file := range rootDirContents {
if !file.IsDir() && strings.HasPrefix(file.Name(), "session_") {
sessionsNo += 1
}
}
return sessionsNo, nil
}