0
Fork 0
mirror of https://github.com/project-zot/zot.git synced 2025-01-06 22:40:28 -05:00
zot/pkg/extensions/sync/httpclient/client.go
Vishwas R be5ad66797
refactor(http): refactor http client to accept more customisable options (#2414)
refactor(http): refactor http client to take options struct

This commit updates the arguments for the `CreateHTTPClient`
function to consume a struct which can be extended as required.
It replaces the certPath argument with a struct of 3 paths for
client ertificate, client key, and ca cert. It also adds
a TLSEnabled option for when an HTTP Client is required
without any further TLS config.

Existing consumers of this function have been updated so that
they can work as they do today. This change is a no-op for
existing features.

This allows for certificate paths to be customised and
allows other modules to re-use the same HTTP client and get
the benefits of mTLS support and per-host certificates.

Signed-off-by: Vishwas Rajashekar <vrajashe@cisco.com>
2024-05-06 13:43:41 -07:00

482 lines
11 KiB
Go

package client
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/url"
"path/filepath"
"strings"
"sync"
"time"
zerr "zotregistry.dev/zot/errors"
"zotregistry.dev/zot/pkg/common"
"zotregistry.dev/zot/pkg/log"
)
// Timing and lifetime settings used by the sync HTTP client.
const (
// minimumTokenLifetimeSeconds is the shortest lifetime given to a bearer token;
// newBearerToken raises any smaller expires_in value to this.
minimumTokenLifetimeSeconds = 60 // in seconds
// pingTimeout bounds the request issued by Ping.
pingTimeout = 5 * time.Second
// tokenBuffer is used to renew a token before it actually expires
// to account for the time to process requests on the server.
tokenBuffer = 5 * time.Second
)
// authType identifies how requests to the registry are authenticated.
type authType int
// Supported authentication schemes; getAuthType selects one from the
// www-authenticate header of a response.
const (
// noneAuth means no Authorization header is added to requests.
noneAuth authType = iota
// basicAuth means requests carry the username/password from the client config.
basicAuth
// tokenAuth means requests carry a bearer token.
tokenAuth
)
// challengeParams holds the parameters extracted from a www-authenticate
// challenge header.
type challengeParams struct {
// realm is parsed as the base URL of the token endpoint.
realm string
// service is sent as the "service" query parameter of the token request.
service string
// scope is sent as the "scope" query parameter of the token request.
scope string
// err is the value of the challenge's "error" parameter (e.g. "insufficient_scope").
err string
}
// bearerToken is the JSON document returned by a token endpoint, together
// with the locally computed expiration time.
type bearerToken struct {
// Token is the value sent in the Authorization header; newBearerToken
// copies AccessToken into it when it is empty.
Token string `json:"token"` //nolint: tagliatelle
AccessToken string `json:"access_token"` //nolint: tagliatelle
// ExpiresIn is the token lifetime in seconds.
ExpiresIn int `json:"expires_in"` //nolint: tagliatelle
// IssuedAt is the issue time; newBearerToken substitutes the current time when it is zero.
IssuedAt time.Time `json:"issued_at"` //nolint: tagliatelle
// expirationTime is IssuedAt plus ExpiresIn seconds; it is not read from the JSON payload.
expirationTime time.Time
}
// isExpired reports whether the token should no longer be used.
// The token is considered expired tokenBuffer ahead of its actual
// expiration time so that it gets renewed a bit earlier.
func (token *bearerToken) isExpired() bool {
	renewAt := token.expirationTime.Add(-tokenBuffer)

	return time.Now().After(renewAt)
}
// Config holds the connection settings of the registry the Client talks to.
type Config struct {
// URL is the base URL of the registry; SetConfig parses it with url.Parse.
URL string
// Username is used for basic auth and as the "account" of token requests.
Username string
Password string
// CertDir, when non-empty, is the directory from which the default client
// certificate, client key and CA certificate file names are resolved.
CertDir string
// TLSVerify is passed to the underlying HTTP client as both TLSEnabled and VerifyTLS.
TLSVerify bool
}
// Client performs HTTP requests against a single registry, handling basic
// and bearer-token authentication.
type Client struct {
// config is the configuration last applied through SetConfig.
config *Config
client *http.Client
// url is the parsed form of the configured registry URL.
url *url.URL
// authType is set by Ping from the registry's www-authenticate header.
authType authType
// cache stores bearer tokens keyed by namespace.
cache *TokenCache
// lock is taken by the exported methods before they read or modify the client.
lock *sync.RWMutex
log log.Logger
}
// New creates a Client with an empty token cache and applies config to it.
// It returns an error when the configuration cannot be applied.
func New(config Config, log log.Logger) (*Client, error) {
	client := &Client{
		log:   log,
		lock:  new(sync.RWMutex),
		cache: NewTokenCache(),
	}

	err := client.SetConfig(config)
	if err != nil {
		return nil, err
	}

	return client, nil
}
// GetConfig returns the configuration currently held by the client.
// The read is done under the client's read lock.
func (httpClient *Client) GetConfig() *Config {
	httpClient.lock.RLock()
	current := httpClient.config
	httpClient.lock.RUnlock()

	return current
}
// GetHostname returns the host part of the configured registry URL.
// The read is done under the client's read lock.
func (httpClient *Client) GetHostname() string {
	httpClient.lock.RLock()
	defer httpClient.lock.RUnlock()

	registryURL := httpClient.url

	return registryURL.Host
}
// GetBaseURL returns the configured registry URL in string form.
// The read is done under the client's read lock.
func (httpClient *Client) GetBaseURL() string {
	httpClient.lock.RLock()
	defer httpClient.lock.RUnlock()

	registryURL := httpClient.url

	return registryURL.String()
}
// SetConfig applies config to the client: it parses config.URL, builds a new
// underlying HTTP client for that host and stores the configuration.
//
// The client state (url, http client, config) is replaced only after every
// step succeeded, so a failing call leaves the previous state fully intact.
// It returns the error from parsing the URL or from creating the HTTP client.
func (httpClient *Client) SetConfig(config Config) error {
	httpClient.lock.Lock()
	defer httpClient.lock.Unlock()

	clientURL, err := url.Parse(config.URL)
	if err != nil {
		return err
	}

	clientOpts := common.HTTPClientOptions{
		// we want TLS enabled when verifyTLS is true.
		TLSEnabled: config.TLSVerify,
		VerifyTLS:  config.TLSVerify,
		Host:       clientURL.Host,
	}

	if config.CertDir != "" {
		// only configure the default cert file names if the CertDir was specified.
		clientOpts.CertOptions = common.HTTPClientCertOptions{
			// filepath is the recommended library to use for joining paths
			// taking into account the underlying OS.
			// ref: https://stackoverflow.com/a/39182128
			ClientCertFile: filepath.Join(config.CertDir, common.ClientCertFilename),
			ClientKeyFile:  filepath.Join(config.CertDir, common.ClientKeyFilename),
			RootCaCertFile: filepath.Join(config.CertDir, common.CaCertFilename),
		}
	}

	client, err := common.CreateHTTPClient(&clientOpts)
	if err != nil {
		return err
	}

	// commit the new state in one go, now that nothing can fail anymore;
	// assigning the url earlier would pair a new url with the old http client
	// and config whenever CreateHTTPClient returned an error.
	httpClient.url = clientURL
	httpClient.client = client
	httpClient.config = &config

	return nil
}
// Ping issues an unauthenticated GET on the registry's "/v2/" endpoint,
// bounded by pingTimeout, and records the authentication type advertised
// in the response. It returns true when the request succeeded with a status
// code between 200 and 403 (inclusive), false otherwise.
func (httpClient *Client) Ping() bool {
	httpClient.lock.Lock()
	defer httpClient.lock.Unlock()

	endpoint := httpClient.url.JoinPath("/v2/").String()

	// for the ping function we want to timeout fast
	ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
	defer cancel()

	//nolint: bodyclose
	resp, _, err := httpClient.get(ctx, endpoint, false)
	if err != nil {
		return false
	}

	httpClient.getAuthType(resp)

	status := resp.StatusCode
	if status < http.StatusOK || status > http.StatusForbidden {
		httpClient.log.Error().Str("url", endpoint).Int("statusCode", status).
			Str("component", "sync").Msg("failed to ping registry")

		return false
	}

	return true
}
// MakeGetRequest performs a GET request on the registry URL extended with the
// given route elements, asking for mediaType through the Accept header when it
// is not empty. A non-empty response body is JSON-decoded into resultPtr.
//
// It returns the raw body, the response Content-Type, the status code and an
// error. When the request itself fails the status code is -1; when the
// registry answers with anything other than 200 the error carries the
// response body as its message.
func (httpClient *Client) MakeGetRequest(ctx context.Context, resultPtr interface{}, mediaType string,
	route ...string,
) ([]byte, string, int, error) {
	httpClient.lock.RLock()
	defer httpClient.lock.RUnlock()

	target := *httpClient.url
	for _, elem := range route {
		target = *target.JoinPath(elem)
	}

	// we know that the second route argument is always the repo name.
	// need it for caching tokens, it's not used in requests made to authz server.
	var namespace string
	if len(route) > 1 {
		namespace = route[1]
	}

	target.RawQuery = target.Query().Encode()
	rawURL := target.String()

	//nolint: bodyclose,contextcheck
	resp, body, err := httpClient.makeAndDoRequest(http.MethodGet, mediaType, namespace, rawURL)
	if err != nil {
		httpClient.log.Error().Err(err).Str("url", rawURL).Str("component", "sync").
			Str("errorType", common.TypeOf(err)).
			Msg("failed to make request")

		return nil, "", -1, err
	}

	if resp.StatusCode != http.StatusOK {
		return nil, "", resp.StatusCode, errors.New(string(body)) //nolint:goerr113
	}

	// read blob
	var decodeErr error
	if len(body) != 0 {
		decodeErr = json.Unmarshal(body, &resultPtr)
	}

	return body, resp.Header.Get("Content-Type"), resp.StatusCode, decodeErr
}
// getAuthType stores the authentication type advertised by the
// www-authenticate header of resp: token auth when the header mentions
// "bearer", basic auth when it mentions "basic", no auth otherwise.
// The match is case-insensitive.
func (httpClient *Client) getAuthType(resp *http.Response) {
	challenge := strings.ToLower(resp.Header.Get("www-authenticate"))

	switch {
	case strings.Contains(challenge, "bearer"):
		httpClient.authType = tokenAuth
	case strings.Contains(challenge, "basic"):
		httpClient.authType = basicAuth
	default:
		httpClient.authType = noneAuth
	}
}
// setupAuth adds the credentials matching the client's authentication type
// to req: a bearer token (obtained for namespace) for token auth, the
// configured username/password for basic auth, nothing otherwise.
// It returns an error only when a bearer token could not be obtained.
func (httpClient *Client) setupAuth(req *http.Request, namespace string) error {
	switch httpClient.authType {
	case tokenAuth:
		target := req.URL.String()

		token, err := httpClient.getToken(target, namespace)
		if err != nil {
			httpClient.log.Error().Err(err).Str("url", target).Str("component", "sync").
				Str("errorType", common.TypeOf(err)).
				Msg("failed to get token from authorization realm")

			return err
		}

		req.Header.Set("Authorization", "Bearer "+token.Token)
	case basicAuth:
		req.SetBasicAuth(httpClient.config.Username, httpClient.config.Password)
	case noneAuth:
		// nothing to add to the request
	}

	return nil
}
// get performs a GET request on url bound to ctx. When setAuth is true and
// both a username and a password are configured, basic credentials are
// attached to the request. It returns the response, its body and an error.
func (httpClient *Client) get(ctx context.Context, url string, setAuth bool) (*http.Response, []byte, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) //nolint
	if err != nil {
		return nil, nil, err
	}

	if setAuth {
		if cfg := httpClient.config; cfg.Username != "" && cfg.Password != "" {
			req.SetBasicAuth(cfg.Username, cfg.Password)
		}
	}

	return httpClient.doRequest(req)
}
// doRequest sends req with the underlying HTTP client and reads the whole
// response body. The body is closed before returning, so callers must use
// the returned byte slice rather than resp.Body. Failures to send the
// request or to read the body are logged and returned.
func (httpClient *Client) doRequest(req *http.Request) (*http.Response, []byte, error) {
	resp, sendErr := httpClient.client.Do(req)
	if sendErr != nil {
		httpClient.log.Error().Err(sendErr).Str("url", req.URL.String()).Str("component", "sync").
			Str("errorType", common.TypeOf(sendErr)).
			Msg("failed to make request")

		return nil, nil, sendErr
	}

	defer resp.Body.Close()

	payload, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		httpClient.log.Error().Err(readErr).Str("url", req.URL.String()).
			Str("errorType", common.TypeOf(readErr)).
			Msg("failed to read body")

		return nil, nil, readErr
	}

	return resp, payload, nil
}
// makeAndDoRequest builds a request for method and urlStr, attaches the
// client's credentials (tokens are looked up for namespace), sets the Accept
// header to mediaType when it is not empty and sends the request.
//
// When the registry rejects the request with an "insufficient_scope"
// challenge, a token for the scope named in the challenge is requested and
// the request is sent one more time with it.
func (httpClient *Client) makeAndDoRequest(method, mediaType, namespace, urlStr string,
) (*http.Response, []byte, error) {
	req, err := http.NewRequest(method, urlStr, nil) //nolint
	if err != nil {
		return nil, nil, err
	}

	if authErr := httpClient.setupAuth(req, namespace); authErr != nil {
		return nil, nil, authErr
	}

	if mediaType != "" {
		req.Header.Set("Accept", mediaType)
	}

	resp, body, err := httpClient.doRequest(req)
	if err != nil {
		return nil, nil, err
	}

	// let's retry one time if we get an insufficient_scope error
	retry, challenge := needsRetryWithUpdatedScope(err, resp)
	if !retry {
		return resp, body, err
	}

	var (
		tokenURL *url.URL
		token    *bearerToken
	)

	if tokenURL, err = getTokenURLFromChallengeParams(challenge, httpClient.config.Username); err != nil {
		return nil, nil, err
	}

	if token, err = httpClient.getTokenFromURL(tokenURL.String(), namespace); err != nil {
		return nil, nil, err
	}

	req.Header.Set("Authorization", "Bearer "+token.Token)

	return httpClient.doRequest(req)
}
// getTokenFromURL requests a bearer token from the token endpoint urlStr,
// sending basic credentials when they are configured, and caches the token
// under namespace. Any status other than 200 yields
// zerr.ErrUnauthorizedAccess.
func (httpClient *Client) getTokenFromURL(urlStr, namespace string) (*bearerToken, error) {
	//nolint: bodyclose
	resp, payload, err := httpClient.get(context.Background(), urlStr, true)

	switch {
	case err != nil:
		return nil, err
	case resp.StatusCode != http.StatusOK:
		return nil, zerr.ErrUnauthorizedAccess
	}

	fresh, err := newBearerToken(payload)
	if err != nil {
		return nil, err
	}

	// cache it
	httpClient.cache.Set(namespace, fresh)

	return fresh, nil
}
// getToken returns a bearer token for namespace. A cached token that has not
// expired is returned as is; otherwise urlStr is requested without
// credentials to obtain the authentication challenge, and a new token is
// fetched from the authorization realm named in that challenge.
func (httpClient *Client) getToken(urlStr, namespace string) (*bearerToken, error) {
	// first check cache
	if cached := httpClient.cache.Get(namespace); cached != nil && !cached.isExpired() {
		return cached, nil
	}

	//nolint: bodyclose
	resp, _, err := httpClient.get(context.Background(), urlStr, false)
	if err != nil {
		return nil, err
	}

	challenge, err := parseAuthHeader(resp)
	if err != nil {
		return nil, err
	}

	realmURL, err := getTokenURLFromChallengeParams(challenge, httpClient.config.Username)
	if err != nil {
		return nil, err
	}

	return httpClient.getTokenFromURL(realmURL.String(), namespace)
}
// newBearerToken decodes the JSON response of a token endpoint.
//
// After decoding, Token falls back to AccessToken when empty, ExpiresIn is
// raised to minimumTokenLifetimeSeconds when smaller, IssuedAt defaults to
// the current UTC time when missing, and the expiration time is computed as
// IssuedAt plus ExpiresIn seconds.
//
// It returns an error when blob is not valid JSON for a bearerToken.
func newBearerToken(blob []byte) (*bearerToken, error) {
	token := new(bearerToken)

	// decode into the struct itself, not into a pointer to the pointer:
	// with &token a JSON `null` body resets token to nil and the field
	// accesses below panic with a nil pointer dereference.
	if err := json.Unmarshal(blob, token); err != nil {
		return nil, err
	}

	if token.Token == "" {
		token.Token = token.AccessToken
	}

	if token.ExpiresIn < minimumTokenLifetimeSeconds {
		token.ExpiresIn = minimumTokenLifetimeSeconds
	}

	if token.IssuedAt.IsZero() {
		token.IssuedAt = time.Now().UTC()
	}

	token.expirationTime = token.IssuedAt.Add(time.Duration(token.ExpiresIn) * time.Second)

	return token, nil
}
// getTokenURLFromChallengeParams builds the token endpoint URL from an
// authentication challenge: the realm is used as base URL and the "service"
// and "scope" query parameters are set from the challenge. The "account"
// parameter is added only when account is not empty.
// It returns an error when the realm is not a valid URL.
func getTokenURLFromChallengeParams(params challengeParams, account string) (*url.URL, error) {
	tokenURL, err := url.Parse(params.realm)
	if err != nil {
		return nil, err
	}

	values := tokenURL.Query()

	for _, pair := range [][2]string{
		{"service", params.service},
		{"scope", params.scope},
	} {
		values.Set(pair[0], pair[1])
	}

	if account != "" {
		values.Set("account", account)
	}

	tokenURL.RawQuery = values.Encode()

	return tokenURL, nil
}
// parseAuthHeader extracts the realm, service, scope and error parameters
// from the www-authenticate header of resp.
//
// The header is expected to look like
//
//	Bearer realm="https://auth/token",service="registry",scope="repository:foo:pull,push"
//
// Parameters are separated by commas found outside of quoted strings, so
// quoted values may themselves contain commas (multiple scope actions),
// and optional whitespace around parameters is ignored. Only the first "="
// separates a key from its value, so values may contain "=" as well.
// Unknown parameters are ignored.
//
// It returns zerr.ErrParsingAuthHeader when an element (including an empty
// header) is not of the form key=value.
func parseAuthHeader(resp *http.Response) (challengeParams, error) {
	params := challengeParams{}
	authHeader := resp.Header.Get("www-authenticate")

	for _, elem := range splitAuthHeaderParams(authHeader) {
		elem = strings.TrimSpace(elem)

		// drop the leading auth scheme ("Bearer realm=..."); never index into
		// a split result blindly, an element without a space must not panic.
		if scheme, rest, found := strings.Cut(elem, " "); found && strings.EqualFold(scheme, "bearer") {
			elem = strings.TrimSpace(rest)
		}

		authKey, authValue, found := strings.Cut(elem, "=")
		if !found {
			return params, zerr.ErrParsingAuthHeader
		}

		authKey = strings.ReplaceAll(strings.TrimSpace(authKey), "\"", "")
		authValue = strings.ReplaceAll(strings.TrimSpace(authValue), "\"", "")

		switch authKey {
		case "realm":
			params.realm = authValue
		case "service":
			params.service = authValue
		case "scope":
			params.scope = authValue
		case "error":
			params.err = authValue
		}
	}

	return params, nil
}

// splitAuthHeaderParams splits header on the commas located outside of
// double-quoted strings. It always returns at least one element.
func splitAuthHeaderParams(header string) []string {
	var (
		parts    []string
		inQuotes bool
		start    int
	)

	for idx := 0; idx < len(header); idx++ {
		switch header[idx] {
		case '\\':
			// skip the character escaped inside a quoted string
			if inQuotes {
				idx++
			}
		case '"':
			inQuotes = !inQuotes
		case ',':
			if !inQuotes {
				parts = append(parts, header[start:idx])
				start = idx + 1
			}
		}
	}

	return append(parts, header[start:])
}
// needsRetryWithUpdatedScope reports whether a request should be retried
// with a token for a different scope: err is nil, the response status is 401
// and its auth header carries an "insufficient_scope" error together with a
// non-empty scope. The parsed challenge parameters are returned alongside.
func needsRetryWithUpdatedScope(err error, resp *http.Response) (bool, challengeParams) {
	if err != nil || resp.StatusCode != http.StatusUnauthorized {
		return false, challengeParams{}
	}

	params, parseErr := parseAuthHeader(resp)
	if parseErr != nil {
		return false, params
	}

	retry := params.err == "insufficient_scope" && params.scope != ""

	return retry, params
}