mirror of
https://github.com/project-zot/zot.git
synced 2024-12-30 22:34:13 -05:00
fix(authn): create sessions only if UI header value is supplied (#1919)
Signed-off-by: Petu Eusebiu <peusebiu@cisco.com>
This commit is contained in:
parent
d1fcab421a
commit
a91c0c5cfe
2 changed files with 280 additions and 198 deletions
|
@ -127,10 +127,12 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce
|
||||||
userAc.AddGroups(groups)
|
userAc.AddGroups(groups)
|
||||||
userAc.SaveOnRequest(request)
|
userAc.SaveOnRequest(request)
|
||||||
|
|
||||||
// saved logged session
|
// saved logged session only if the request comes from web (has UI session header value)
|
||||||
|
if hasSessionHeader(request) {
|
||||||
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
|
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// we have already populated the request context with userAc
|
// we have already populated the request context with userAc
|
||||||
if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil {
|
if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil {
|
||||||
|
@ -163,9 +165,12 @@ func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, userAc *reqCtx.UserAcce
|
||||||
userAc.AddGroups(groups)
|
userAc.AddGroups(groups)
|
||||||
userAc.SaveOnRequest(request)
|
userAc.SaveOnRequest(request)
|
||||||
|
|
||||||
|
// saved logged session only if the request comes from web (has UI session header value)
|
||||||
|
if hasSessionHeader(request) {
|
||||||
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
|
if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// we have already populated the request context with userAc
|
// we have already populated the request context with userAc
|
||||||
if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil {
|
if err := ctlr.MetaDB.SetUserGroups(request.Context(), groups); err != nil {
|
||||||
|
|
|
@ -2686,6 +2686,7 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
for _, testcase := range testCases {
|
for _, testcase := range testCases {
|
||||||
t.Run(testcase.testCaseName, func(t *testing.T) {
|
t.Run(testcase.testCaseName, func(t *testing.T) {
|
||||||
|
Convey("make controller", t, func() {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
|
||||||
ctlr.Config.Storage.RootDirectory = dir
|
ctlr.Config.Storage.RootDirectory = dir
|
||||||
|
@ -2697,7 +2698,7 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
defer cm.StopServer()
|
defer cm.StopServer()
|
||||||
test.WaitTillServerReady(baseURL)
|
test.WaitTillServerReady(baseURL)
|
||||||
|
|
||||||
Convey("browser client requests", t, func() {
|
Convey("browser client requests", func() {
|
||||||
Convey("login with no provider supplied", func() {
|
Convey("login with no provider supplied", func() {
|
||||||
client := resty.New()
|
client := resty.New()
|
||||||
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
|
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
|
||||||
|
@ -2711,14 +2712,82 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest)
|
So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
//nolint: dupl
|
||||||
|
Convey("make sure sessions are not used without UI header value", func() {
|
||||||
|
sessionsNo, err := getNumberOfSessions(conf.Storage.RootDirectory)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(sessionsNo, ShouldEqual, 0)
|
||||||
|
|
||||||
|
client := resty.New()
|
||||||
|
|
||||||
|
// without header should not create session
|
||||||
|
resp, err := client.R().SetBasicAuth(htpasswdUsername, passphrase).Get(baseURL + "/v2/")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp, ShouldNotBeNil)
|
||||||
|
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||||
|
|
||||||
|
sessionsNo, err = getNumberOfSessions(conf.Storage.RootDirectory)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(sessionsNo, ShouldEqual, 0)
|
||||||
|
|
||||||
|
client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue)
|
||||||
|
|
||||||
|
resp, err = client.R().SetBasicAuth(htpasswdUsername, passphrase).Get(baseURL + "/v2/")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp, ShouldNotBeNil)
|
||||||
|
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||||
|
|
||||||
|
sessionsNo, err = getNumberOfSessions(conf.Storage.RootDirectory)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(sessionsNo, ShouldEqual, 1)
|
||||||
|
|
||||||
|
// set cookies
|
||||||
|
client.SetCookies(resp.Cookies())
|
||||||
|
|
||||||
|
// should get same cookie
|
||||||
|
resp, err = client.R().SetBasicAuth(htpasswdUsername, passphrase).Get(baseURL + "/v2/")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp, ShouldNotBeNil)
|
||||||
|
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||||
|
|
||||||
|
sessionsNo, err = getNumberOfSessions(conf.Storage.RootDirectory)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(sessionsNo, ShouldEqual, 1)
|
||||||
|
|
||||||
|
resp, err = client.R().
|
||||||
|
SetBasicAuth(htpasswdUsername, passphrase).
|
||||||
|
Get(baseURL + constants.FullMgmt)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp, ShouldNotBeNil)
|
||||||
|
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||||
|
|
||||||
|
client.SetCookies(resp.Cookies())
|
||||||
|
|
||||||
|
// call endpoint with session, without credentials, (added to client after previous request)
|
||||||
|
resp, err = client.R().
|
||||||
|
Get(baseURL + "/v2/_catalog")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp, ShouldNotBeNil)
|
||||||
|
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||||
|
|
||||||
|
resp, err = client.R().SetBasicAuth(htpasswdUsername, passphrase).Get(baseURL + "/v2/")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp, ShouldNotBeNil)
|
||||||
|
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||||
|
|
||||||
|
sessionsNo, err = getNumberOfSessions(conf.Storage.RootDirectory)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(sessionsNo, ShouldEqual, 1)
|
||||||
|
})
|
||||||
|
|
||||||
Convey("login with openid and get catalog with session", func() {
|
Convey("login with openid and get catalog with session", func() {
|
||||||
client := resty.New()
|
client := resty.New()
|
||||||
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
|
client.SetRedirectPolicy(test.CustomRedirectPolicy(20))
|
||||||
|
client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue)
|
||||||
|
|
||||||
Convey("with callback_ui value provided", func() {
|
Convey("with callback_ui value provided", func() {
|
||||||
// first login user
|
// first login user
|
||||||
resp, err := client.R().
|
resp, err := client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
SetQueryParam("provider", "oidc").
|
SetQueryParam("provider", "oidc").
|
||||||
SetQueryParam("callback_ui", baseURL+"/v2/").
|
SetQueryParam("callback_ui", baseURL+"/v2/").
|
||||||
Get(baseURL + constants.LoginPath)
|
Get(baseURL + constants.LoginPath)
|
||||||
|
@ -2729,7 +2798,6 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
// first login user
|
// first login user
|
||||||
resp, err := client.R().
|
resp, err := client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
SetQueryParam("provider", "oidc").
|
SetQueryParam("provider", "oidc").
|
||||||
Get(baseURL + constants.LoginPath)
|
Get(baseURL + constants.LoginPath)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
@ -2740,7 +2808,6 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
// call endpoint with session (added to client after previous request)
|
// call endpoint with session (added to client after previous request)
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Get(baseURL + "/v2/_catalog")
|
Get(baseURL + "/v2/_catalog")
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
|
@ -2748,14 +2815,12 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
// logout with options method for coverage
|
// logout with options method for coverage
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Options(baseURL + constants.LogoutPath)
|
Options(baseURL + constants.LogoutPath)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
|
|
||||||
// logout user
|
// logout user
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Post(baseURL + constants.LogoutPath)
|
Post(baseURL + constants.LogoutPath)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
|
@ -2763,7 +2828,6 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
// calling endpoint should fail with unauthorized access
|
// calling endpoint should fail with unauthorized access
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Get(baseURL + "/v2/_catalog")
|
Get(baseURL + "/v2/_catalog")
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
|
@ -2773,6 +2837,7 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
//nolint: dupl
|
//nolint: dupl
|
||||||
Convey("login with basic auth(htpasswd) and get catalog with session", func() {
|
Convey("login with basic auth(htpasswd) and get catalog with session", func() {
|
||||||
client := resty.New()
|
client := resty.New()
|
||||||
|
client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue)
|
||||||
|
|
||||||
// without creds, should get access error
|
// without creds, should get access error
|
||||||
resp, err := client.R().Get(baseURL + "/v2/")
|
resp, err := client.R().Get(baseURL + "/v2/")
|
||||||
|
@ -2806,14 +2871,12 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
// call endpoint with session, without credentials, (added to client after previous request)
|
// call endpoint with session, without credentials, (added to client after previous request)
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Get(baseURL + "/v2/_catalog")
|
Get(baseURL + "/v2/_catalog")
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||||
|
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Get(baseURL + constants.FullMgmt)
|
Get(baseURL + constants.FullMgmt)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
|
@ -2821,7 +2884,6 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
// logout user
|
// logout user
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Post(baseURL + constants.LogoutPath)
|
Post(baseURL + constants.LogoutPath)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
|
@ -2829,7 +2891,6 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
// calling endpoint should fail with unauthorized access
|
// calling endpoint should fail with unauthorized access
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Get(baseURL + "/v2/_catalog")
|
Get(baseURL + "/v2/_catalog")
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
|
@ -2839,6 +2900,7 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
//nolint: dupl
|
//nolint: dupl
|
||||||
Convey("login with ldap and get catalog", func() {
|
Convey("login with ldap and get catalog", func() {
|
||||||
client := resty.New()
|
client := resty.New()
|
||||||
|
client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue)
|
||||||
|
|
||||||
// without creds, should get access error
|
// without creds, should get access error
|
||||||
resp, err := client.R().Get(baseURL + "/v2/")
|
resp, err := client.R().Get(baseURL + "/v2/")
|
||||||
|
@ -2872,14 +2934,12 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
// call endpoint with session, without credentials, (added to client after previous request)
|
// call endpoint with session, without credentials, (added to client after previous request)
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Get(baseURL + "/v2/_catalog")
|
Get(baseURL + "/v2/_catalog")
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
|
||||||
|
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Get(baseURL + constants.FullMgmt)
|
Get(baseURL + constants.FullMgmt)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
|
@ -2887,7 +2947,6 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
// logout user
|
// logout user
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Post(baseURL + constants.LogoutPath)
|
Post(baseURL + constants.LogoutPath)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
|
@ -2895,7 +2954,6 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
|
|
||||||
// calling endpoint should fail with unauthorized access
|
// calling endpoint should fail with unauthorized access
|
||||||
resp, err = client.R().
|
resp, err = client.R().
|
||||||
SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue).
|
|
||||||
Get(baseURL + "/v2/_catalog")
|
Get(baseURL + "/v2/_catalog")
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
So(resp, ShouldNotBeNil)
|
So(resp, ShouldNotBeNil)
|
||||||
|
@ -2921,6 +2979,7 @@ func TestOpenIDMiddleware(t *testing.T) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3273,6 +3332,7 @@ func TestAuthnSessionErrors(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client := resty.New()
|
client := resty.New()
|
||||||
|
client.SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue)
|
||||||
|
|
||||||
// first htpasswd saveSessionLoggedUser() error
|
// first htpasswd saveSessionLoggedUser() error
|
||||||
resp, err := client.R().
|
resp, err := client.R().
|
||||||
|
@ -9764,3 +9824,20 @@ func getEmptyImageConfig() ([]byte, godigest.Digest) {
|
||||||
|
|
||||||
return configBlobContent, configBlobDigestRaw
|
return configBlobContent, configBlobDigestRaw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getNumberOfSessions(rootDir string) (int, error) {
|
||||||
|
rootDirContents, err := os.ReadDir(path.Join(rootDir, "_sessions"))
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionsNo := 0
|
||||||
|
|
||||||
|
for _, file := range rootDirContents {
|
||||||
|
if !file.IsDir() && strings.HasPrefix(file.Name(), "session_") {
|
||||||
|
sessionsNo += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessionsNo, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue