mirror of
https://github.com/project-zot/zot.git
synced 2024-12-30 22:34:13 -05:00
c2facc9958
Signed-off-by: evanebb <78433178+evanebb@users.noreply.github.com>
484 lines
11 KiB
Go
484 lines
11 KiB
Go
package client
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
zerr "zotregistry.dev/zot/errors"
|
|
"zotregistry.dev/zot/pkg/common"
|
|
"zotregistry.dev/zot/pkg/log"
|
|
)
|
|
|
|
// Timing parameters for registry pings and bearer token handling.
const (
	// minimumTokenLifetimeSeconds is the smallest token lifetime, in seconds,
	// accepted when a bearer token is decoded; shorter values are raised to it.
	minimumTokenLifetimeSeconds = 60

	// pingTimeout bounds the request issued by Ping so that it fails fast.
	pingTimeout = 5 * time.Second

	// tokenBuffer is subtracted from a token's expiration time so the token is
	// renewed slightly before it really expires, which accounts for the time
	// the server needs to process requests.
	tokenBuffer = 5 * time.Second
)
|
|
|
|
// authType identifies the authentication scheme detected for a registry.
type authType int

// Known authentication schemes. The zero value, noneAuth, means that no
// recognised challenge was found.
const (
	// noneAuth: requests are sent without credentials.
	noneAuth authType = iota
	// basicAuth: requests carry HTTP basic credentials.
	basicAuth
	// tokenAuth: requests carry a bearer token.
	tokenAuth
)
|
|
|
|
// challengeParams holds the parameters read from a WWW-Authenticate header.
type challengeParams struct {
	// realm, service and scope are used to build the token endpoint URL.
	realm, service, scope string
	// err is the value of the challenge's "error" parameter, for example
	// "insufficient_scope".
	err string
}
|
|
|
|
// bearerToken is the JSON document returned by a token endpoint, extended with
// a locally computed expiration instant.
type bearerToken struct {
	// Token is the value sent in the "Authorization: Bearer" header.
	Token string `json:"token"` //nolint: tagliatelle
	// AccessToken is used in place of Token when Token is empty.
	AccessToken string `json:"access_token"` //nolint: tagliatelle
	// ExpiresIn is the token lifetime in seconds.
	ExpiresIn int `json:"expires_in"` //nolint: tagliatelle
	// IssuedAt is the instant the token was issued.
	IssuedAt time.Time `json:"issued_at"` //nolint: tagliatelle

	// expirationTime is IssuedAt plus ExpiresIn seconds. It is unexported and
	// therefore never read from or written to JSON.
	expirationTime time.Time
}
|
|
|
|
func (token *bearerToken) isExpired() bool {
|
|
// use tokenBuffer to expire it a bit earlier
|
|
return time.Now().After(token.expirationTime.Add(-1 * tokenBuffer))
|
|
}
|
|
|
|
// Config describes how to reach and authenticate against an upstream registry.
type Config struct {
	// URL is the registry base URL; an "https" scheme enables TLS.
	URL string
	// Username and Password are the credentials used for basic auth and for
	// authenticated requests to the token endpoint.
	Username, Password string
	// CertDir, when non-empty, is the directory from which the client
	// certificate, client key and CA certificate file paths are derived.
	CertDir string
	// TLSVerify is handed to the HTTP client options as VerifyTLS.
	TLSVerify bool
}
|
|
|
|
type Client struct {
|
|
config *Config
|
|
client *http.Client
|
|
url *url.URL
|
|
authType authType
|
|
cache *TokenCache
|
|
lock *sync.RWMutex
|
|
log log.Logger
|
|
}
|
|
|
|
func New(config Config, log log.Logger) (*Client, error) {
|
|
client := &Client{log: log, lock: new(sync.RWMutex)}
|
|
|
|
client.cache = NewTokenCache()
|
|
|
|
if err := client.SetConfig(config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func (httpClient *Client) GetConfig() *Config {
|
|
httpClient.lock.RLock()
|
|
defer httpClient.lock.RUnlock()
|
|
|
|
return httpClient.config
|
|
}
|
|
|
|
func (httpClient *Client) GetHostname() string {
|
|
httpClient.lock.RLock()
|
|
defer httpClient.lock.RUnlock()
|
|
|
|
return httpClient.url.Host
|
|
}
|
|
|
|
func (httpClient *Client) GetBaseURL() string {
|
|
httpClient.lock.RLock()
|
|
defer httpClient.lock.RUnlock()
|
|
|
|
return httpClient.url.String()
|
|
}
|
|
|
|
func (httpClient *Client) SetConfig(config Config) error {
|
|
httpClient.lock.Lock()
|
|
defer httpClient.lock.Unlock()
|
|
|
|
clientURL, err := url.Parse(config.URL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
httpClient.url = clientURL
|
|
|
|
// we want TLS enabled if the upstream registry URL is an HTTPS URL
|
|
tlsEnabled := clientURL.Scheme == "https"
|
|
|
|
clientOpts := common.HTTPClientOptions{
|
|
TLSEnabled: tlsEnabled,
|
|
VerifyTLS: config.TLSVerify,
|
|
Host: clientURL.Host,
|
|
}
|
|
|
|
if config.CertDir != "" {
|
|
// only configure the default cert file names if the CertDir was specified.
|
|
clientOpts.CertOptions = common.HTTPClientCertOptions{
|
|
// filepath is the recommended library to use for joining paths
|
|
// taking into account the underlying OS.
|
|
// ref: https://stackoverflow.com/a/39182128
|
|
ClientCertFile: filepath.Join(config.CertDir, common.ClientCertFilename),
|
|
ClientKeyFile: filepath.Join(config.CertDir, common.ClientKeyFilename),
|
|
RootCaCertFile: filepath.Join(config.CertDir, common.CaCertFilename),
|
|
}
|
|
}
|
|
|
|
client, err := common.CreateHTTPClient(&clientOpts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
httpClient.client = client
|
|
httpClient.config = &config
|
|
|
|
return nil
|
|
}
|
|
|
|
func (httpClient *Client) Ping() bool {
|
|
httpClient.lock.Lock()
|
|
defer httpClient.lock.Unlock()
|
|
|
|
pingURL := *httpClient.url
|
|
|
|
pingURL = *pingURL.JoinPath("/v2/")
|
|
|
|
// for the ping function we want to timeout fast
|
|
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
|
defer cancel()
|
|
|
|
//nolint: bodyclose
|
|
resp, _, err := httpClient.get(ctx, pingURL.String(), false)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
httpClient.getAuthType(resp)
|
|
|
|
if resp.StatusCode >= http.StatusOK && resp.StatusCode <= http.StatusForbidden {
|
|
return true
|
|
}
|
|
|
|
httpClient.log.Error().Str("url", pingURL.String()).Int("statusCode", resp.StatusCode).
|
|
Str("component", "sync").Msg("failed to ping registry")
|
|
|
|
return false
|
|
}
|
|
|
|
func (httpClient *Client) MakeGetRequest(ctx context.Context, resultPtr interface{}, mediaType string,
|
|
route ...string,
|
|
) ([]byte, string, int, error) {
|
|
httpClient.lock.RLock()
|
|
defer httpClient.lock.RUnlock()
|
|
|
|
var namespace string
|
|
|
|
url := *httpClient.url
|
|
for idx, path := range route {
|
|
url = *url.JoinPath(path)
|
|
|
|
// we know that the second route argument is always the repo name.
|
|
// need it for caching tokens, it's not used in requests made to authz server.
|
|
if idx == 1 {
|
|
namespace = path
|
|
}
|
|
}
|
|
|
|
url.RawQuery = url.Query().Encode()
|
|
//nolint: bodyclose,contextcheck
|
|
resp, body, err := httpClient.makeAndDoRequest(http.MethodGet, mediaType, namespace, url.String())
|
|
if err != nil {
|
|
httpClient.log.Error().Err(err).Str("url", url.String()).Str("component", "sync").
|
|
Str("errorType", common.TypeOf(err)).
|
|
Msg("failed to make request")
|
|
|
|
return nil, "", -1, err
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, "", resp.StatusCode, errors.New(string(body)) //nolint:goerr113
|
|
}
|
|
|
|
// read blob
|
|
if len(body) > 0 {
|
|
err = json.Unmarshal(body, &resultPtr)
|
|
}
|
|
|
|
return body, resp.Header.Get("Content-Type"), resp.StatusCode, err
|
|
}
|
|
|
|
func (httpClient *Client) getAuthType(resp *http.Response) {
|
|
authHeader := resp.Header.Get("www-authenticate")
|
|
|
|
authHeaderLower := strings.ToLower(authHeader)
|
|
|
|
//nolint: gocritic
|
|
if strings.Contains(authHeaderLower, "bearer") {
|
|
httpClient.authType = tokenAuth
|
|
} else if strings.Contains(authHeaderLower, "basic") {
|
|
httpClient.authType = basicAuth
|
|
} else {
|
|
httpClient.authType = noneAuth
|
|
}
|
|
}
|
|
|
|
func (httpClient *Client) setupAuth(req *http.Request, namespace string) error {
|
|
if httpClient.authType == tokenAuth {
|
|
token, err := httpClient.getToken(req.URL.String(), namespace)
|
|
if err != nil {
|
|
httpClient.log.Error().Err(err).Str("url", req.URL.String()).Str("component", "sync").
|
|
Str("errorType", common.TypeOf(err)).
|
|
Msg("failed to get token from authorization realm")
|
|
|
|
return err
|
|
}
|
|
|
|
req.Header.Set("Authorization", "Bearer "+token.Token)
|
|
} else if httpClient.authType == basicAuth {
|
|
req.SetBasicAuth(httpClient.config.Username, httpClient.config.Password)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (httpClient *Client) get(ctx context.Context, url string, setAuth bool) (*http.Response, []byte, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) //nolint
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if setAuth && httpClient.config.Username != "" && httpClient.config.Password != "" {
|
|
req.SetBasicAuth(httpClient.config.Username, httpClient.config.Password)
|
|
}
|
|
|
|
return httpClient.doRequest(req)
|
|
}
|
|
|
|
func (httpClient *Client) doRequest(req *http.Request) (*http.Response, []byte, error) {
|
|
resp, err := httpClient.client.Do(req)
|
|
if err != nil {
|
|
httpClient.log.Error().Err(err).Str("url", req.URL.String()).Str("component", "sync").
|
|
Str("errorType", common.TypeOf(err)).
|
|
Msg("failed to make request")
|
|
|
|
return nil, nil, err
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
httpClient.log.Error().Err(err).Str("url", req.URL.String()).
|
|
Str("errorType", common.TypeOf(err)).
|
|
Msg("failed to read body")
|
|
|
|
return nil, nil, err
|
|
}
|
|
|
|
return resp, body, nil
|
|
}
|
|
|
|
func (httpClient *Client) makeAndDoRequest(method, mediaType, namespace, urlStr string,
|
|
) (*http.Response, []byte, error) {
|
|
req, err := http.NewRequest(method, urlStr, nil) //nolint
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if err := httpClient.setupAuth(req, namespace); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if mediaType != "" {
|
|
req.Header.Set("Accept", mediaType)
|
|
}
|
|
|
|
resp, body, err := httpClient.doRequest(req)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// let's retry one time if we get an insufficient_scope error
|
|
if ok, challengeParams := needsRetryWithUpdatedScope(err, resp); ok {
|
|
var tokenURL *url.URL
|
|
|
|
var token *bearerToken
|
|
|
|
tokenURL, err = getTokenURLFromChallengeParams(challengeParams, httpClient.config.Username)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
token, err = httpClient.getTokenFromURL(tokenURL.String(), namespace)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
req.Header.Set("Authorization", "Bearer "+token.Token)
|
|
|
|
resp, body, err = httpClient.doRequest(req)
|
|
}
|
|
|
|
return resp, body, err
|
|
}
|
|
|
|
func (httpClient *Client) getTokenFromURL(urlStr, namespace string) (*bearerToken, error) {
|
|
//nolint: bodyclose
|
|
resp, body, err := httpClient.get(context.Background(), urlStr, true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, zerr.ErrUnauthorizedAccess
|
|
}
|
|
|
|
token, err := newBearerToken(body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// cache it
|
|
httpClient.cache.Set(namespace, token)
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// Gets bearer token from Authorization realm.
|
|
func (httpClient *Client) getToken(urlStr, namespace string) (*bearerToken, error) {
|
|
// first check cache
|
|
token := httpClient.cache.Get(namespace)
|
|
if token != nil && !token.isExpired() {
|
|
return token, nil
|
|
}
|
|
|
|
//nolint: bodyclose
|
|
resp, _, err := httpClient.get(context.Background(), urlStr, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
challengeParams, err := parseAuthHeader(resp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tokenURL, err := getTokenURLFromChallengeParams(challengeParams, httpClient.config.Username)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return httpClient.getTokenFromURL(tokenURL.String(), namespace)
|
|
}
|
|
|
|
func newBearerToken(blob []byte) (*bearerToken, error) {
|
|
token := new(bearerToken)
|
|
if err := json.Unmarshal(blob, &token); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if token.Token == "" {
|
|
token.Token = token.AccessToken
|
|
}
|
|
|
|
if token.ExpiresIn < minimumTokenLifetimeSeconds {
|
|
token.ExpiresIn = minimumTokenLifetimeSeconds
|
|
}
|
|
|
|
if token.IssuedAt.IsZero() {
|
|
token.IssuedAt = time.Now().UTC()
|
|
}
|
|
|
|
token.expirationTime = token.IssuedAt.Add(time.Duration(token.ExpiresIn) * time.Second)
|
|
|
|
return token, nil
|
|
}
|
|
|
|
func getTokenURLFromChallengeParams(params challengeParams, account string) (*url.URL, error) {
|
|
parsedRealm, err := url.Parse(params.realm)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query := parsedRealm.Query()
|
|
query.Set("service", params.service)
|
|
query.Set("scope", params.scope)
|
|
|
|
if account != "" {
|
|
query.Set("account", account)
|
|
}
|
|
|
|
parsedRealm.RawQuery = query.Encode()
|
|
|
|
return parsedRealm, nil
|
|
}
|
|
|
|
func parseAuthHeader(resp *http.Response) (challengeParams, error) {
|
|
authHeader := resp.Header.Get("www-authenticate")
|
|
|
|
authHeaderSlice := strings.Split(authHeader, ",")
|
|
|
|
params := challengeParams{}
|
|
|
|
for _, elem := range authHeaderSlice {
|
|
if strings.Contains(strings.ToLower(elem), "bearer") {
|
|
elem = strings.Split(elem, " ")[1]
|
|
}
|
|
|
|
elem := strings.ReplaceAll(elem, "\"", "")
|
|
|
|
elemSplit := strings.Split(elem, "=")
|
|
if len(elemSplit) != 2 { //nolint:mnd
|
|
return params, zerr.ErrParsingAuthHeader
|
|
}
|
|
|
|
authKey := elemSplit[0]
|
|
|
|
authValue := elemSplit[1]
|
|
|
|
switch authKey {
|
|
case "realm":
|
|
params.realm = authValue
|
|
case "service":
|
|
params.service = authValue
|
|
case "scope":
|
|
params.scope = authValue
|
|
case "error":
|
|
params.err = authValue
|
|
}
|
|
}
|
|
|
|
return params, nil
|
|
}
|
|
|
|
// Checks if the auth headers in the response contain an indication of a failed
|
|
// authorization because of an "insufficient_scope" error.
|
|
func needsRetryWithUpdatedScope(err error, resp *http.Response) (bool, challengeParams) {
|
|
params := challengeParams{}
|
|
if err == nil && resp.StatusCode == http.StatusUnauthorized {
|
|
params, err = parseAuthHeader(resp)
|
|
if err != nil {
|
|
return false, params
|
|
}
|
|
|
|
if params.err == "insufficient_scope" {
|
|
if params.scope != "" {
|
|
return true, params
|
|
}
|
|
}
|
|
}
|
|
|
|
return false, params
|
|
}
|