0
Fork 0
mirror of https://github.com/project-zot/zot.git synced 2024-12-30 22:34:13 -05:00

fix(sync): added bearer client for sync (#2222)

fixed ping function taking too much time

closes: #2213 #2212

Signed-off-by: Petu Eusebiu <peusebiu@cisco.com>
This commit is contained in:
peusebiu 2024-02-14 19:18:10 +02:00 committed by GitHub
parent d0eb043be5
commit 8e68255946
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 943 additions and 135 deletions

View file

@ -936,7 +936,7 @@ Configure each registry sync:
] ]
}, },
{ {
"urls": ["https://docker.io/library"], "urls": ["https://index.docker.io"],
"onDemand": true, # doesn't have content, don't periodically pull, pull just on demand. "onDemand": true, # doesn't have content, don't periodically pull, pull just on demand.
"tlsVerify": true, "tlsVerify": true,
"maxRetries": 3, "maxRetries": 3,

View file

@ -482,7 +482,7 @@ func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc {
if err != nil { if err != nil {
ctlr.Log.Error().Err(err).Msg("failed to parse Authorization header") ctlr.Log.Error().Err(err).Msg("failed to parse Authorization header")
response.Header().Set("Content-Type", "application/json") response.Header().Set("Content-Type", "application/json")
zcommon.WriteJSON(response, http.StatusInternalServerError, apiErr.NewError(apiErr.UNSUPPORTED)) zcommon.WriteJSON(response, http.StatusUnauthorized, apiErr.NewError(apiErr.UNSUPPORTED))
return return
} }

View file

@ -3114,7 +3114,7 @@ func TestBearerAuth(t *testing.T) {
Get(baseURL + "/v2/") Get(baseURL + "/v2/")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp, ShouldNotBeNil) So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
resp, err = resty.R().SetHeader("Authorization", resp, err = resty.R().SetHeader("Authorization",
fmt.Sprintf("Bearer %s", goodToken.AccessToken)).Options(baseURL + "/v2/") fmt.Sprintf("Bearer %s", goodToken.AccessToken)).Options(baseURL + "/v2/")

View file

@ -240,7 +240,7 @@ func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, reques
response.Header().Set(constants.DistAPIVersion, "registry/2.0") response.Header().Set(constants.DistAPIVersion, "registry/2.0")
// NOTE: compatibility workaround - return this header in "allowed-read" mode to allow for clients to // NOTE: compatibility workaround - return this header in "allowed-read" mode to allow for clients to
// work correctly // work correctly
if rh.c.Config.HTTP.Auth != nil { if rh.c.Config.IsBasicAuthnEnabled() || rh.c.Config.IsBearerAuthEnabled() {
// don't send auth headers if request is coming from UI // don't send auth headers if request is coming from UI
if request.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue { if request.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue {
if rh.c.Config.HTTP.Auth.Bearer != nil { if rh.c.Config.HTTP.Auth.Bearer != nil {

View file

@ -1,18 +1,12 @@
package common package common
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json"
"errors"
"io"
"net/http" "net/http"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"zotregistry.dev/zot/pkg/log"
) )
func GetTLSConfig(certsPath string, caCertPool *x509.CertPool) (*tls.Config, error) { func GetTLSConfig(certsPath string, caCertPool *x509.CertPool) (*tls.Config, error) {
@ -107,57 +101,7 @@ func CreateHTTPClient(verifyTLS bool, host string, certDir string) (*http.Client
} }
return &http.Client{ return &http.Client{
Timeout: httpTimeout,
Transport: htr, Transport: htr,
Timeout: httpTimeout,
}, nil }, nil
} }
func MakeHTTPGetRequest(ctx context.Context, httpClient *http.Client,
username string, password string, resultPtr interface{},
blobURL string, mediaType string, log log.Logger,
) ([]byte, string, int, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, blobURL, nil) //nolint
if err != nil {
return nil, "", 0, err
}
if mediaType != "" {
req.Header.Set("Accept", mediaType)
}
if username != "" && password != "" {
req.SetBasicAuth(username, password)
}
resp, err := httpClient.Do(req)
if err != nil {
log.Error().Str("errorType", TypeOf(err)).
Err(err).Str("blobURL", blobURL).Msg("couldn't get blob")
return nil, "", -1, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Error().Str("errorType", TypeOf(err)).
Err(err).Str("blobURL", blobURL).Msg("couldn't get blob")
return nil, "", resp.StatusCode, err
}
if resp.StatusCode != http.StatusOK {
return nil, "", resp.StatusCode, errors.New(string(body)) //nolint:goerr113
}
// read blob
if len(body) > 0 {
err = json.Unmarshal(body, &resultPtr)
if err != nil {
return body, "", resp.StatusCode, err
}
}
return body, resp.Header.Get("Content-Type"), resp.StatusCode, err
}

View file

@ -1,19 +1,14 @@
package common_test package common_test
import ( import (
"context"
"crypto/x509" "crypto/x509"
"os" "os"
"path" "path"
"testing" "testing"
ispec "github.com/opencontainers/image-spec/specs-go/v1"
. "github.com/smartystreets/goconvey/convey" . "github.com/smartystreets/goconvey/convey"
"zotregistry.dev/zot/pkg/api"
"zotregistry.dev/zot/pkg/api/config"
"zotregistry.dev/zot/pkg/common" "zotregistry.dev/zot/pkg/common"
"zotregistry.dev/zot/pkg/log"
test "zotregistry.dev/zot/pkg/test/common" test "zotregistry.dev/zot/pkg/test/common"
) )
@ -54,30 +49,4 @@ func TestHTTPClient(t *testing.T) {
_, err = common.CreateHTTPClient(true, "localhost", tempDir) _, err = common.CreateHTTPClient(true, "localhost", tempDir)
So(err, ShouldNotBeNil) So(err, ShouldNotBeNil)
}) })
Convey("test MakeHTTPGetRequest() no permissions on key", t, func() {
port := test.GetFreePort()
baseURL := test.GetBaseURL(port)
conf := config.New()
conf.HTTP.Port = port
ctlr := api.NewController(conf)
tempDir := t.TempDir()
err := test.CopyTestKeysAndCerts(tempDir)
So(err, ShouldBeNil)
ctlr.Config.Storage.RootDirectory = tempDir
cm := test.NewControllerManager(ctlr)
cm.StartServer()
defer cm.StopServer()
test.WaitTillServerReady(baseURL)
var resultPtr interface{}
httpClient, err := common.CreateHTTPClient(true, "localhost", tempDir)
So(err, ShouldBeNil)
_, _, _, err = common.MakeHTTPGetRequest(context.Background(), httpClient, "", "",
resultPtr, baseURL+"/v2/", ispec.MediaTypeImageManifest, log.NewLogger("", ""))
So(err, ShouldBeNil)
})
} }

View file

@ -50,6 +50,8 @@ func EnableSyncExtension(config *config.Config, metaDB mTypes.MetaDB,
service, err := sync.New(registryConfig, credsPath, tmpDir, storeController, metaDB, log) service, err := sync.New(registryConfig, credsPath, tmpDir, storeController, metaDB, log)
if err != nil { if err != nil {
log.Error().Err(err).Msg("failed to initialize sync extension")
return nil, err return nil, err
} }

View file

@ -0,0 +1,58 @@
package client
import (
"sync"
)
// Key:Value store for bearer tokens, key is namespace, value is token.
// We are storing only pull scoped tokens, the http client is for pulling only.
type TokenCache struct {
entries sync.Map
}
func NewTokenCache() *TokenCache {
return &TokenCache{
entries: sync.Map{},
}
}
func (c *TokenCache) Set(namespace string, token *bearerToken) {
if c == nil || token == nil {
return
}
defer c.prune()
c.entries.Store(namespace, token)
}
func (c *TokenCache) Get(namespace string) *bearerToken {
if c == nil {
return nil
}
val, ok := c.entries.Load(namespace)
if !ok {
return nil
}
bearerToken, ok := val.(*bearerToken)
if !ok {
return nil
}
return bearerToken
}
func (c *TokenCache) prune() {
c.entries.Range(func(key, val any) bool {
bearerToken, ok := val.(*bearerToken)
if ok {
if bearerToken.isExpired() {
c.entries.Delete(key)
}
}
return true
})
}

View file

@ -2,15 +2,56 @@ package client
import ( import (
"context" "context"
"encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"sync" "sync"
"time"
zerr "zotregistry.dev/zot/errors"
"zotregistry.dev/zot/pkg/common" "zotregistry.dev/zot/pkg/common"
"zotregistry.dev/zot/pkg/log" "zotregistry.dev/zot/pkg/log"
) )
const (
minimumTokenLifetimeSeconds = 60 // in seconds
pingTimeout = 5 * time.Second
// tokenBuffer is used to renew a token before it actually expires
// to account for the time to process requests on the server.
tokenBuffer = 5 * time.Second
)
type authType int
const (
noneAuth authType = iota
basicAuth
tokenAuth
)
type challengeParams struct {
realm string
service string
scope string
err string
}
type bearerToken struct {
Token string `json:"token"` //nolint: tagliatelle
AccessToken string `json:"access_token"` //nolint: tagliatelle
ExpiresIn int `json:"expires_in"` //nolint: tagliatelle
IssuedAt time.Time `json:"issued_at"` //nolint: tagliatelle
expirationTime time.Time
}
func (token *bearerToken) isExpired() bool {
// use tokenBuffer to expire it a bit earlier
return time.Now().After(token.expirationTime.Add(-1 * tokenBuffer))
}
type Config struct { type Config struct {
URL string URL string
Username string Username string
@ -23,12 +64,17 @@ type Client struct {
config *Config config *Config
client *http.Client client *http.Client
url *url.URL url *url.URL
authType authType
cache *TokenCache
lock *sync.RWMutex lock *sync.RWMutex
log log.Logger log log.Logger
} }
func New(config Config, log log.Logger) (*Client, error) { func New(config Config, log log.Logger) (*Client, error) {
client := &Client{log: log, lock: new(sync.RWMutex)} client := &Client{log: log, lock: new(sync.RWMutex)}
client.cache = NewTokenCache()
if err := client.SetConfig(config); err != nil { if err := client.SetConfig(config); err != nil {
return nil, err return nil, err
} }
@ -50,6 +96,13 @@ func (httpClient *Client) GetHostname() string {
return httpClient.url.Host return httpClient.url.Host
} }
func (httpClient *Client) GetBaseURL() string {
httpClient.lock.RLock()
defer httpClient.lock.RUnlock()
return httpClient.url.String()
}
func (httpClient *Client) SetConfig(config Config) error { func (httpClient *Client) SetConfig(config Config) error {
httpClient.lock.Lock() httpClient.lock.Lock()
defer httpClient.lock.Unlock() defer httpClient.lock.Unlock()
@ -73,41 +126,30 @@ func (httpClient *Client) SetConfig(config Config) error {
} }
func (httpClient *Client) Ping() bool { func (httpClient *Client) Ping() bool {
httpClient.lock.RLock() httpClient.lock.Lock()
defer httpClient.lock.RUnlock() defer httpClient.lock.Unlock()
pingURL := *httpClient.url pingURL := *httpClient.url
pingURL = *pingURL.JoinPath("/v2/") pingURL = *pingURL.JoinPath("/v2/")
req, err := http.NewRequest(http.MethodGet, pingURL.String(), nil) //nolint // for the ping function we want to timeout fast
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
//nolint: bodyclose
resp, _, err := httpClient.get(ctx, pingURL.String(), false)
if err != nil { if err != nil {
return false return false
} }
resp, err := httpClient.client.Do(req) httpClient.getAuthType(resp)
if err != nil {
httpClient.log.Error().Err(err).Str("url", pingURL.String()).Str("component", "sync").
Msg("failed to ping registry")
return false if resp.StatusCode >= http.StatusOK && resp.StatusCode <= http.StatusForbidden {
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusUnauthorized {
return true return true
} }
body, err := io.ReadAll(resp.Body) httpClient.log.Error().Str("url", pingURL.String()).Int("statusCode", resp.StatusCode).
if err != nil {
httpClient.log.Error().Err(err).Str("url", pingURL.String()).
Msg("failed to read body while pinging registry")
return false
}
httpClient.log.Error().Str("url", pingURL.String()).Str("body", string(body)).Int("statusCode", resp.StatusCode).
Str("component", "sync").Msg("failed to ping registry") Str("component", "sync").Msg("failed to ping registry")
return false return false
@ -119,17 +161,302 @@ func (httpClient *Client) MakeGetRequest(ctx context.Context, resultPtr interfac
httpClient.lock.RLock() httpClient.lock.RLock()
defer httpClient.lock.RUnlock() defer httpClient.lock.RUnlock()
url := *httpClient.url var namespace string
for _, r := range route { url := *httpClient.url
url = *url.JoinPath(r) for idx, path := range route {
url = *url.JoinPath(path)
// we know that the second route argument is always the repo name.
// need it for caching tokens, it's not used in requests made to authz server.
if idx == 1 {
namespace = path
}
} }
url.RawQuery = url.Query().Encode() url.RawQuery = url.Query().Encode()
//nolint: bodyclose,contextcheck
resp, body, err := httpClient.makeAndDoRequest(http.MethodGet, mediaType, namespace, url.String())
if err != nil {
httpClient.log.Error().Err(err).Str("url", url.String()).Str("component", "sync").
Str("errorType", common.TypeOf(err)).
Msg("failed to make request")
body, mediaType, statusCode, err := common.MakeHTTPGetRequest(ctx, httpClient.client, httpClient.config.Username, return nil, "", -1, err
httpClient.config.Password, resultPtr, }
url.String(), mediaType, httpClient.log)
if resp.StatusCode != http.StatusOK {
return body, mediaType, statusCode, err return nil, "", resp.StatusCode, errors.New(string(body)) //nolint:goerr113
}
// read blob
if len(body) > 0 {
err = json.Unmarshal(body, &resultPtr)
}
return body, resp.Header.Get("Content-Type"), resp.StatusCode, err
}
func (httpClient *Client) getAuthType(resp *http.Response) {
authHeader := resp.Header.Get("www-authenticate")
authHeaderLower := strings.ToLower(authHeader)
//nolint: gocritic
if strings.Contains(authHeaderLower, "bearer") {
httpClient.authType = tokenAuth
} else if strings.Contains(authHeaderLower, "basic") {
httpClient.authType = basicAuth
} else {
httpClient.authType = noneAuth
}
}
func (httpClient *Client) setupAuth(req *http.Request, namespace string) error {
if httpClient.authType == tokenAuth {
token, err := httpClient.getToken(req.URL.String(), namespace)
if err != nil {
httpClient.log.Error().Err(err).Str("url", req.URL.String()).Str("component", "sync").
Str("errorType", common.TypeOf(err)).
Msg("failed to get token from authorization realm")
return err
}
req.Header.Set("Authorization", "Bearer "+token.Token)
} else if httpClient.authType == basicAuth {
req.SetBasicAuth(httpClient.config.Username, httpClient.config.Password)
}
return nil
}
func (httpClient *Client) get(ctx context.Context, url string, setAuth bool) (*http.Response, []byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) //nolint
if err != nil {
return nil, nil, err
}
if setAuth && httpClient.config.Username != "" && httpClient.config.Password != "" {
req.SetBasicAuth(httpClient.config.Username, httpClient.config.Password)
}
return httpClient.doRequest(req)
}
func (httpClient *Client) doRequest(req *http.Request) (*http.Response, []byte, error) {
resp, err := httpClient.client.Do(req)
if err != nil {
httpClient.log.Error().Err(err).Str("url", req.URL.String()).Str("component", "sync").
Str("errorType", common.TypeOf(err)).
Msg("failed to make request")
return nil, nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
httpClient.log.Error().Err(err).Str("url", req.URL.String()).
Str("errorType", common.TypeOf(err)).
Msg("failed to read body")
return nil, nil, err
}
return resp, body, nil
}
func (httpClient *Client) makeAndDoRequest(method, mediaType, namespace, urlStr string,
) (*http.Response, []byte, error) {
req, err := http.NewRequest(method, urlStr, nil) //nolint
if err != nil {
return nil, nil, err
}
if err := httpClient.setupAuth(req, namespace); err != nil {
return nil, nil, err
}
if mediaType != "" {
req.Header.Set("Accept", mediaType)
}
resp, body, err := httpClient.doRequest(req)
if err != nil {
return nil, nil, err
}
// let's retry one time if we get an insufficient_scope error
if ok, challengeParams := needsRetryWithUpdatedScope(err, resp); ok {
var tokenURL *url.URL
var token *bearerToken
tokenURL, err = getTokenURLFromChallengeParams(challengeParams, httpClient.config.Username)
if err != nil {
return nil, nil, err
}
token, err = httpClient.getTokenFromURL(tokenURL.String(), namespace)
if err != nil {
return nil, nil, err
}
req.Header.Set("Authorization", "Bearer "+token.Token)
resp, body, err = httpClient.doRequest(req)
}
return resp, body, err
}
func (httpClient *Client) getTokenFromURL(urlStr, namespace string) (*bearerToken, error) {
//nolint: bodyclose
resp, body, err := httpClient.get(context.Background(), urlStr, true)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, zerr.ErrUnauthorizedAccess
}
token, err := newBearerToken(body)
if err != nil {
return nil, err
}
// cache it
httpClient.cache.Set(namespace, token)
return token, nil
}
// Gets bearer token from Authorization realm.
func (httpClient *Client) getToken(urlStr, namespace string) (*bearerToken, error) {
// first check cache
token := httpClient.cache.Get(namespace)
if token != nil && !token.isExpired() {
return token, nil
}
//nolint: bodyclose
resp, _, err := httpClient.get(context.Background(), urlStr, false)
if err != nil {
return nil, err
}
challengeParams, err := parseAuthHeader(resp)
if err != nil {
return nil, err
}
tokenURL, err := getTokenURLFromChallengeParams(challengeParams, httpClient.config.Username)
if err != nil {
return nil, err
}
return httpClient.getTokenFromURL(tokenURL.String(), namespace)
}
func newBearerToken(blob []byte) (*bearerToken, error) {
token := new(bearerToken)
if err := json.Unmarshal(blob, &token); err != nil {
return nil, err
}
if token.Token == "" {
token.Token = token.AccessToken
}
if token.ExpiresIn < minimumTokenLifetimeSeconds {
token.ExpiresIn = minimumTokenLifetimeSeconds
}
if token.IssuedAt.IsZero() {
token.IssuedAt = time.Now().UTC()
}
token.expirationTime = token.IssuedAt.Add(time.Duration(token.ExpiresIn) * time.Second)
return token, nil
}
func getTokenURLFromChallengeParams(params challengeParams, account string) (*url.URL, error) {
parsedRealm, err := url.Parse(params.realm)
if err != nil {
return nil, err
}
query := parsedRealm.Query()
query.Set("service", params.service)
query.Set("scope", params.scope)
if account != "" {
query.Set("account", account)
}
parsedRealm.RawQuery = query.Encode()
return parsedRealm, nil
}
func parseAuthHeader(resp *http.Response) (challengeParams, error) {
authHeader := resp.Header.Get("www-authenticate")
authHeaderSlice := strings.Split(authHeader, ",")
params := challengeParams{}
for _, elem := range authHeaderSlice {
if strings.Contains(strings.ToLower(elem), "bearer") {
elem = strings.Split(elem, " ")[1]
}
elem := strings.ReplaceAll(elem, "\"", "")
elemSplit := strings.Split(elem, "=")
if len(elemSplit) != 2 { //nolint: gomnd
return params, zerr.ErrParsingAuthHeader
}
authKey := elemSplit[0]
authValue := elemSplit[1]
switch authKey {
case "realm":
params.realm = authValue
case "service":
params.service = authValue
case "scope":
params.scope = authValue
case "error":
params.err = authValue
}
}
return params, nil
}
// Checks if the auth headers in the response contain an indication of a failed
// authorization because of an "insufficient_scope" error.
func needsRetryWithUpdatedScope(err error, resp *http.Response) (bool, challengeParams) {
params := challengeParams{}
if err == nil && resp.StatusCode == http.StatusUnauthorized {
params, err = parseAuthHeader(resp)
if err != nil {
return false, params
}
if params.err == "insufficient_scope" {
if params.scope != "" {
return true, params
}
}
}
return false, params
} }

View file

@ -0,0 +1,167 @@
package client
import (
"net/http"
"net/http/httptest"
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
"zotregistry.dev/zot/pkg/log"
)
func TestTokenCache(t *testing.T) {
Convey("Get/Set tokens", t, func() {
tokenCache := NewTokenCache()
token := &bearerToken{
Token: "tokenA",
ExpiresIn: 3,
IssuedAt: time.Now(),
}
token.expirationTime = token.IssuedAt.Add(time.Duration(token.ExpiresIn) * time.Second).Add(tokenBuffer)
tokenCache.Set("repo", token)
cachedToken := tokenCache.Get("repo")
So(cachedToken.Token, ShouldEqual, token.Token)
// add token which expires soon
token2 := &bearerToken{
Token: "tokenB",
ExpiresIn: 1,
IssuedAt: time.Now(),
}
token2.expirationTime = token2.IssuedAt.Add(time.Duration(token2.ExpiresIn) * time.Second).Add(tokenBuffer)
tokenCache.Set("repo2", token2)
cachedToken = tokenCache.Get("repo2")
So(cachedToken.Token, ShouldEqual, token2.Token)
time.Sleep(1 * time.Second)
// token3 should be expired when adding a new one
token3 := &bearerToken{
Token: "tokenC",
ExpiresIn: 3,
IssuedAt: time.Now(),
}
token3.expirationTime = token3.IssuedAt.Add(time.Duration(token3.ExpiresIn) * time.Second).Add(tokenBuffer)
tokenCache.Set("repo3", token3)
cachedToken = tokenCache.Get("repo3")
So(cachedToken.Token, ShouldEqual, token3.Token)
// token2 should be expired
token = tokenCache.Get("repo2")
So(token, ShouldBeNil)
time.Sleep(2 * time.Second)
// the rest of them should also be expired
tokenCache.Set("repo4", &bearerToken{
Token: "tokenD",
})
// token1 should be expired
token = tokenCache.Get("repo1")
So(token, ShouldBeNil)
})
Convey("Error paths", t, func() {
tokenCache := NewTokenCache()
token := tokenCache.Get("repo")
So(token, ShouldBeNil)
tokenCache = nil
token = tokenCache.Get("repo")
So(token, ShouldBeNil)
tokenCache = NewTokenCache()
tokenCache.Set("repo", nil)
token = tokenCache.Get("repo")
So(token, ShouldBeNil)
})
}
func TestNeedsRetryOnInsuficientScope(t *testing.T) {
resp := http.Response{
Status: "401 Unauthorized",
StatusCode: http.StatusUnauthorized,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: map[string][]string{
"Content-Length": {"145"},
"Content-Type": {"application/json"},
"Date": {"Fri, 26 Aug 2022 08:03:13 GMT"},
"X-Content-Type-Options": {"nosniff"},
},
Request: nil,
}
Convey("Test client retries on insufficient scope", t, func() {
resp.Header["Www-Authenticate"] = []string{
`Bearer realm="https://registry.suse.com/auth",service="SUSE Linux Docker Registry"` +
`,scope="registry:catalog:*",error="insufficient_scope"`,
}
expectedScope := "registry:catalog:*"
expectedRealm := "https://registry.suse.com/auth"
expectedService := "SUSE Linux Docker Registry"
needsRetry, params := needsRetryWithUpdatedScope(nil, &resp)
So(needsRetry, ShouldBeTrue)
So(params.scope, ShouldEqual, expectedScope)
So(params.realm, ShouldEqual, expectedRealm)
So(params.service, ShouldEqual, expectedService)
})
Convey("Test client fails on insufficient scope", t, func() {
resp.Header["Www-Authenticate"] = []string{
`Bearer realm="https://registry.suse.com/auth=error"`,
}
needsRetry, _ := needsRetryWithUpdatedScope(nil, &resp)
So(needsRetry, ShouldBeFalse)
})
}
func TestClient(t *testing.T) {
Convey("Test client", t, func() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
client, err := New(Config{
URL: server.URL,
TLSVerify: false,
}, log.NewLogger("", ""))
So(err, ShouldBeNil)
Convey("Test Ping() fails", func() {
ok := client.Ping()
So(ok, ShouldBeFalse)
})
Convey("Test makeAndDoRequest() fails", func() {
client.authType = tokenAuth
//nolint: bodyclose
_, _, err := client.makeAndDoRequest(http.MethodGet, "application/json", "catalog", server.URL)
So(err, ShouldNotBeNil)
})
Convey("Test setupAuth() fails", func() {
request, err := http.NewRequest(http.MethodGet, server.URL, nil) //nolint: noctx
So(err, ShouldBeNil)
client.authType = tokenAuth
err = client.setupAuth(request, "catalog")
So(err, ShouldNotBeNil)
})
})
}

View file

@ -109,14 +109,20 @@ func (onDemand *BaseOnDemand) syncImage(ctx context.Context, repo, reference str
var err error var err error
for serviceID, service := range onDemand.services { for serviceID, service := range onDemand.services {
err = service.SetNextAvailableURL() err = service.SetNextAvailableURL()
if err != nil {
isPingErr := errors.Is(err, zerr.ErrSyncPingRegistry)
if err != nil && !isPingErr {
syncResult <- err syncResult <- err
return return
} }
// no need to try to sync inline if there is a ping error, we want to retry in background
if !isPingErr {
err = service.SyncImage(ctx, repo, reference) err = service.SyncImage(ctx, repo, reference)
if err != nil { }
if err != nil || isPingErr {
if errors.Is(err, zerr.ErrManifestNotFound) || if errors.Is(err, zerr.ErrManifestNotFound) ||
errors.Is(err, zerr.ErrSyncImageFilteredOut) || errors.Is(err, zerr.ErrSyncImageFilteredOut) ||
errors.Is(err, zerr.ErrSyncImageNotSigned) { errors.Is(err, zerr.ErrSyncImageNotSigned) {

View file

@ -6,6 +6,7 @@ package sync
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/containers/image/v5/docker" "github.com/containers/image/v5/docker"
dockerReference "github.com/containers/image/v5/docker/reference" dockerReference "github.com/containers/image/v5/docker/reference"
@ -58,6 +59,26 @@ func (registry *RemoteRegistry) GetRepositories(ctx context.Context) ([]string,
return catalog.Repositories, nil return catalog.Repositories, nil
} }
func (registry *RemoteRegistry) GetDockerRemoteRepo(repo string) string {
dockerNamespace := "library"
dockerRegistry := "docker.io"
remoteHost := registry.client.GetHostname()
repoRef, err := parseRepositoryReference(fmt.Sprintf("%s/%s", remoteHost, repo))
if err != nil {
return repo
}
if !strings.Contains(repo, dockerNamespace) &&
strings.Contains(repoRef.String(), dockerNamespace) &&
strings.Contains(repoRef.String(), dockerRegistry) {
return fmt.Sprintf("%s/%s", dockerNamespace, repo)
}
return repo
}
func (registry *RemoteRegistry) GetImageReference(repo, reference string) (types.ImageReference, error) { func (registry *RemoteRegistry) GetImageReference(repo, reference string) (types.ImageReference, error) {
remoteHost := registry.client.GetHostname() remoteHost := registry.client.GetHostname()

View file

@ -93,9 +93,12 @@ func New(
service.retryOptions = retryOptions service.retryOptions = retryOptions
service.storeController = storeController service.storeController = storeController
err = service.SetNextAvailableClient() // try to set next client.
if err != nil { if err := service.SetNextAvailableClient(); err != nil {
return nil, err // if it's a ping issue, it will be retried
if !errors.Is(err, zerr.ErrSyncPingRegistry) {
return service, err
}
} }
service.references = references.NewReferences( service.references = references.NewReferences(
@ -118,7 +121,14 @@ func (service *BaseService) SetNextAvailableClient() error {
return nil return nil
} }
found := false
for _, url := range service.config.URLs { for _, url := range service.config.URLs {
// skip current client
if service.client != nil && service.client.GetBaseURL() == url {
continue
}
remoteAddress := StripRegistryTransport(url) remoteAddress := StripRegistryTransport(url)
credentials := service.credentials[remoteAddress] credentials := service.credentials[remoteAddress]
@ -149,12 +159,14 @@ func (service *BaseService) SetNextAvailableClient() error {
return err return err
} }
if !service.client.Ping() { if service.client.Ping() {
continue found = true
break
} }
} }
if service.client == nil { if service.client == nil || !found {
return zerr.ErrSyncPingRegistry return zerr.ErrSyncPingRegistry
} }
@ -241,6 +253,8 @@ func (service *BaseService) SyncReference(ctx context.Context, repo string,
} }
} }
remoteRepo = service.remote.GetDockerRemoteRepo(remoteRepo)
service.log.Info().Str("remote", remoteURL).Str("repository", repo).Str("subject", subjectDigestStr). service.log.Info().Str("remote", remoteURL).Str("repository", repo).Str("subject", subjectDigestStr).
Str("reference type", referenceType).Msg("syncing reference for image") Str("reference type", referenceType).Msg("syncing reference for image")
@ -263,6 +277,8 @@ func (service *BaseService) SyncImage(ctx context.Context, repo, reference strin
} }
} }
remoteRepo = service.remote.GetDockerRemoteRepo(remoteRepo)
service.log.Info().Str("remote", remoteURL).Str("repository", repo).Str("reference", reference). service.log.Info().Str("remote", remoteURL).Str("repository", repo).Str("reference", reference).
Msg("syncing image") Msg("syncing image")

View file

@ -63,6 +63,9 @@ type Remote interface {
GetRepoTags(repo string) ([]string, error) GetRepoTags(repo string) ([]string, error)
// Get manifest content, mediaType, digest given an ImageReference // Get manifest content, mediaType, digest given an ImageReference
GetManifestContent(imageReference types.ImageReference) ([]byte, string, digest.Digest, error) GetManifestContent(imageReference types.ImageReference) ([]byte, string, digest.Digest, error)
// In the case of public dockerhub images 'library' namespace is added to the repo names of images
// eg: alpine -> library/alpine
GetDockerRemoteRepo(repo string) string
} }
// Local registry. // Local registry.
@ -111,6 +114,11 @@ func (gen *TaskGenerator) Next() (scheduler.Task, error) {
return nil, nil return nil, nil
} }
// a task with this repo is already running
if gen.lastRepo == repo {
return nil, nil
}
gen.lastRepo = repo gen.lastRepo = repo
return newSyncRepoTask(gen.lastRepo, gen.Service), nil return newSyncRepoTask(gen.lastRepo, gen.Service), nil

View file

@ -11,6 +11,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"os" "os"
"os/exec" "os/exec"
"path" "path"
@ -47,6 +48,7 @@ import (
"zotregistry.dev/zot/pkg/log" "zotregistry.dev/zot/pkg/log"
mTypes "zotregistry.dev/zot/pkg/meta/types" mTypes "zotregistry.dev/zot/pkg/meta/types"
storageConstants "zotregistry.dev/zot/pkg/storage/constants" storageConstants "zotregistry.dev/zot/pkg/storage/constants"
authutils "zotregistry.dev/zot/pkg/test/auth"
test "zotregistry.dev/zot/pkg/test/common" test "zotregistry.dev/zot/pkg/test/common"
. "zotregistry.dev/zot/pkg/test/image-utils" . "zotregistry.dev/zot/pkg/test/image-utils"
"zotregistry.dev/zot/pkg/test/mocks" "zotregistry.dev/zot/pkg/test/mocks"
@ -2364,6 +2366,284 @@ func TestTLS(t *testing.T) {
}) })
} }
func TestBearerAuth(t *testing.T) {
Convey("Verify periodically sync bearer auth", t, func() {
updateDuration, _ := time.ParseDuration("1h")
// a repo for which clients do not have access, sync shouldn't be able to sync it
unauthorizedNamespace := testCveImage
authTestServer := authutils.MakeAuthTestServer(ServerKey, unauthorizedNamespace)
defer authTestServer.Close()
sctlr, srcBaseURL, _, _, srcClient := makeUpstreamServer(t, false, false)
aurl, err := url.Parse(authTestServer.URL)
So(err, ShouldBeNil)
sctlr.Config.HTTP.Auth = &config.AuthConfig{
Bearer: &config.BearerConfig{
Cert: ServerCert,
Realm: authTestServer.URL + "/auth/token",
Service: aurl.Host,
},
}
scm := test.NewControllerManager(sctlr)
scm.StartAndWait(sctlr.Config.HTTP.Port)
defer scm.StopServer()
registryName := sync.StripRegistryTransport(srcBaseURL)
credentialsFile := makeCredentialsFile(fmt.Sprintf(`{"%s":{"username": "%s", "password": "%s"}}`,
registryName, username, password))
var tlsVerify bool
syncRegistryConfig := syncconf.RegistryConfig{
Content: []syncconf.Content{
{
Prefix: "**", // sync everything
},
},
URLs: []string{srcBaseURL},
PollInterval: updateDuration,
TLSVerify: &tlsVerify,
CertDir: "",
}
defaultVal := true
syncConfig := &syncconf.Config{
Enable: &defaultVal,
CredentialsFile: credentialsFile,
Registries: []syncconf.RegistryConfig{syncRegistryConfig},
}
dctlr, destBaseURL, _, destClient := makeDownstreamServer(t, false, syncConfig)
dcm := test.NewControllerManager(dctlr)
dcm.StartAndWait(dctlr.Config.HTTP.Port)
defer dcm.StopServer()
var srcTagsList TagsList
var destTagsList TagsList
resp, err := srcClient.R().Get(srcBaseURL + "/v2/")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
authorizationHeader := authutils.ParseBearerAuthHeader(resp.Header().Get("WWW-Authenticate"))
resp, err = resty.R().
SetQueryParam("service", authorizationHeader.Service).
Get(authorizationHeader.Realm)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
var goodToken authutils.AccessTokenResponse
err = json.Unmarshal(resp.Body(), &goodToken)
So(err, ShouldBeNil)
resp, err = srcClient.R().
SetHeader("Authorization", fmt.Sprintf("Bearer %s", goodToken.AccessToken)).
Get(srcBaseURL + "/v2/")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
resp, err = srcClient.R().Get(srcBaseURL + "/v2/" + testImage + "/tags/list")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
authorizationHeader = authutils.ParseBearerAuthHeader(resp.Header().Get("WWW-Authenticate"))
resp, err = resty.R().
SetQueryParam("service", authorizationHeader.Service).
SetQueryParam("scope", authorizationHeader.Scope).
Get(authorizationHeader.Realm)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
goodToken = authutils.AccessTokenResponse{}
err = json.Unmarshal(resp.Body(), &goodToken)
So(err, ShouldBeNil)
resp, err = srcClient.R().SetHeader("Authorization", fmt.Sprintf("Bearer %s", goodToken.AccessToken)).
Get(srcBaseURL + "/v2/" + testImage + "/tags/list")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
err = json.Unmarshal(resp.Body(), &srcTagsList)
if err != nil {
panic(err)
}
for {
resp, err = destClient.R().Get(destBaseURL + "/v2/" + testImage + "/tags/list")
if err != nil {
panic(err)
}
err = json.Unmarshal(resp.Body(), &destTagsList)
if err != nil {
panic(err)
}
if len(destTagsList.Tags) > 0 {
break
}
time.Sleep(500 * time.Millisecond)
}
So(destTagsList, ShouldResemble, srcTagsList)
waitSyncFinish(dctlr.Config.Log.Output)
resp, err = destClient.R().Get(destBaseURL + "/v2/" + testImage + "/manifests/" + testImageTag)
So(err, ShouldBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
// unauthorized namespace
resp, err = destClient.R().Get(destBaseURL + "/v2/" + testCveImage + "/manifests/" + testImageTag)
So(err, ShouldBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusNotFound)
})
Convey("Verify ondemand sync bearer auth", t, func() {
// a repo for which clients do not have access, sync shouldn't be able to sync it
unauthorizedNamespace := testCveImage
authTestServer := authutils.MakeAuthTestServer(ServerKey, unauthorizedNamespace)
defer authTestServer.Close()
sctlr, srcBaseURL, _, _, srcClient := makeUpstreamServer(t, false, false)
aurl, err := url.Parse(authTestServer.URL)
So(err, ShouldBeNil)
sctlr.Config.HTTP.Auth = &config.AuthConfig{
Bearer: &config.BearerConfig{
Cert: ServerCert,
Realm: authTestServer.URL + "/auth/token",
Service: aurl.Host,
},
}
scm := test.NewControllerManager(sctlr)
scm.StartAndWait(sctlr.Config.HTTP.Port)
defer scm.StopServer()
registryName := sync.StripRegistryTransport(srcBaseURL)
credentialsFile := makeCredentialsFile(fmt.Sprintf(`{"%s":{"username": "%s", "password": "%s"}}`,
registryName, username, password))
var tlsVerify bool
syncRegistryConfig := syncconf.RegistryConfig{
Content: []syncconf.Content{
{
Prefix: "**", // sync everything
},
},
URLs: []string{srcBaseURL},
TLSVerify: &tlsVerify,
OnDemand: true,
CertDir: "",
}
defaultVal := true
syncConfig := &syncconf.Config{
Enable: &defaultVal,
CredentialsFile: credentialsFile,
Registries: []syncconf.RegistryConfig{syncRegistryConfig},
}
dctlr, destBaseURL, _, destClient := makeDownstreamServer(t, false, syncConfig)
dcm := test.NewControllerManager(dctlr)
dcm.StartAndWait(dctlr.Config.HTTP.Port)
defer dcm.StopServer()
var srcTagsList TagsList
var destTagsList TagsList
resp, err := srcClient.R().Get(srcBaseURL + "/v2/")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
authorizationHeader := authutils.ParseBearerAuthHeader(resp.Header().Get("WWW-Authenticate"))
resp, err = resty.R().
SetQueryParam("service", authorizationHeader.Service).
Get(authorizationHeader.Realm)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
var goodToken authutils.AccessTokenResponse
err = json.Unmarshal(resp.Body(), &goodToken)
So(err, ShouldBeNil)
resp, err = srcClient.R().
SetHeader("Authorization", fmt.Sprintf("Bearer %s", goodToken.AccessToken)).
Get(srcBaseURL + "/v2/")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
resp, err = srcClient.R().Get(srcBaseURL + "/v2/" + testImage + "/tags/list")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized)
authorizationHeader = authutils.ParseBearerAuthHeader(resp.Header().Get("WWW-Authenticate"))
resp, err = resty.R().
SetQueryParam("service", authorizationHeader.Service).
SetQueryParam("scope", authorizationHeader.Scope).
Get(authorizationHeader.Realm)
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
goodToken = authutils.AccessTokenResponse{}
err = json.Unmarshal(resp.Body(), &goodToken)
So(err, ShouldBeNil)
resp, err = srcClient.R().SetHeader("Authorization", fmt.Sprintf("Bearer %s", goodToken.AccessToken)).
Get(srcBaseURL + "/v2/" + testImage + "/tags/list")
So(err, ShouldBeNil)
So(resp, ShouldNotBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
err = json.Unmarshal(resp.Body(), &srcTagsList)
if err != nil {
panic(err)
}
// sync on demand
resp, err = destClient.R().Get(destBaseURL + "/v2/" + testImage + "/manifests/" + testImageTag)
So(err, ShouldBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusOK)
resp, err = destClient.R().Get(destBaseURL + "/v2/" + testImage + "/tags/list")
if err != nil {
panic(err)
}
err = json.Unmarshal(resp.Body(), &destTagsList)
if err != nil {
panic(err)
}
So(destTagsList, ShouldResemble, srcTagsList)
// unauthorized namespace
resp, err = destClient.R().Get(destBaseURL + "/v2/" + testCveImage + "/manifests/" + testImageTag)
So(err, ShouldBeNil)
So(resp.StatusCode(), ShouldEqual, http.StatusNotFound)
})
}
func TestBasicAuth(t *testing.T) { func TestBasicAuth(t *testing.T) {
Convey("Verify sync basic auth", t, func() { Convey("Verify sync basic auth", t, func() {
updateDuration, _ := time.ParseDuration("1h") updateDuration, _ := time.ParseDuration("1h")

View file

@ -20,10 +20,20 @@ type SyncRemote struct {
// Get a list of tags given a repo // Get a list of tags given a repo
GetRepoTagsFn func(repo string) ([]string, error) GetRepoTagsFn func(repo string) ([]string, error)
GetDockerRemoteRepoFn func(repo string) string
// Get manifest content, mediaType, digest given an ImageReference // Get manifest content, mediaType, digest given an ImageReference
GetManifestContentFn func(imageReference types.ImageReference) ([]byte, string, digest.Digest, error) GetManifestContentFn func(imageReference types.ImageReference) ([]byte, string, digest.Digest, error)
} }
func (remote SyncRemote) GetDockerRemoteRepo(repo string) string {
if remote.GetDockerRemoteRepoFn != nil {
return remote.GetDockerRemoteRepoFn(repo)
}
return ""
}
func (remote SyncRemote) GetImageReference(repo string, tag string) (types.ImageReference, error) { func (remote SyncRemote) GetImageReference(repo string, tag string) (types.ImageReference, error) {
if remote.GetImageReferenceFn != nil { if remote.GetImageReferenceFn != nil {
return remote.GetImageReferenceFn(repo, tag) return remote.GetImageReferenceFn(repo, tag)