0
Fork 0
mirror of https://github.com/project-zot/zot.git synced 2024-12-30 22:34:13 -05:00

refactor(http): refactor http client to accept more customisable options (#2414)

refactor(http): refactor http client to take options struct

This commit updates the arguments for the `CreateHTTPClient`
function to consume a struct which can be extended as required.
It replaces the certPath argument with a struct of 3 paths for
client ertificate, client key, and ca cert. It also adds
a TLSEnabled option for when an HTTP Client is required
without any further TLS config.

Existing consumers of this function have been updated so that
they can work as they do today. This change is a no-op for
existing features.

This allows for certificate paths to be customised and
allows other modules to re-use the same HTTP client and get
the benefits of mTLS support and per-host certificates.

Signed-off-by: Vishwas Rajashekar <vrajashe@cisco.com>
This commit is contained in:
Vishwas R 2024-05-07 02:13:41 +05:30 committed by GitHub
parent 4671e412fc
commit be5ad66797
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 216 additions and 29 deletions

View file

@ -90,7 +90,13 @@ func doHTTPRequest(req *http.Request, verifyTLS bool, debug bool,
httpClientLock.Lock() httpClientLock.Lock()
if httpClientsMap[host] == nil { if httpClientsMap[host] == nil {
httpClient, err = common.CreateHTTPClient(verifyTLS, host, "") httpClient, err = common.CreateHTTPClient(&common.HTTPClientOptions{
// we want TLS enabled when verifyTLS is true.
TLSEnabled: verifyTLS,
VerifyTLS: verifyTLS,
Host: host,
CertOptions: common.HTTPClientCertOptions{},
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -18,9 +18,9 @@ const (
httpTimeout = 5 * time.Minute httpTimeout = 5 * time.Minute
certsPath = "/etc/containers/certs.d" certsPath = "/etc/containers/certs.d"
homeCertsDir = ".config/containers/certs.d" homeCertsDir = ".config/containers/certs.d"
clientCertFilename = "client.cert" ClientCertFilename = "client.cert"
clientKeyFilename = "client.key" ClientKeyFilename = "client.key"
caCertFilename = "ca.crt" CaCertFilename = "ca.crt"
CosignSignature = "cosign" CosignSignature = "cosign"
CosignSigKey = "dev.cosignproject.cosign/signature" CosignSigKey = "dev.cosignproject.cosign/signature"

View file

@ -5,14 +5,13 @@ import (
"crypto/x509" "crypto/x509"
"net/http" "net/http"
"os" "os"
"path"
"path/filepath" "path/filepath"
) )
func GetTLSConfig(certsPath string, caCertPool *x509.CertPool) (*tls.Config, error) { func GetTLSConfig(certsPath string, caCertPool *x509.CertPool) (*tls.Config, error) {
clientCert := filepath.Join(certsPath, clientCertFilename) clientCert := filepath.Join(certsPath, ClientCertFilename)
clientKey := filepath.Join(certsPath, clientKeyFilename) clientKey := filepath.Join(certsPath, ClientKeyFilename)
caCertFile := filepath.Join(certsPath, caCertFilename) caCertFile := filepath.Join(certsPath, CaCertFilename)
cert, err := tls.LoadX509KeyPair(clientCert, clientKey) cert, err := tls.LoadX509KeyPair(clientCert, clientKey)
if err != nil { if err != nil {
@ -59,9 +58,43 @@ func loadPerHostCerts(caCertPool *x509.CertPool, host string) *tls.Config {
return nil return nil
} }
func CreateHTTPClient(verifyTLS bool, host string, certDir string) (*http.Client, error) { // Holds certificate options for an HTTP client.
type HTTPClientCertOptions struct {
ClientCertFile string // Holds the path to the client certificate file. Mandatory if ClientKeyFile is present.
ClientKeyFile string // Holds the path to the client key file. Mandatory if ClientCertFile is present.
RootCaCertFile string // Optional. Holds the path to the custom Root CA cert file.
}
// Holds client options for creating an HTTP client.
type HTTPClientOptions struct {
// Results in a client with TLS config if true.
TLSEnabled bool
// Results in a client without certificate config and TLS verification disabled if true.
// Note: if TLSEnabled is false and VerifyTLS is true, the client will not have the verification
// of insecure certificates set to false. For this, both TLSEnabled and VerifyTLS need to be
// true.
VerifyTLS bool
// The target host for the imminent connection. Used for loading host specific certificates if any.
Host string
// Certificate options for the client.
CertOptions HTTPClientCertOptions
}
func CreateHTTPClient(clientOptions *HTTPClientOptions) (*http.Client, error) {
htr := http.DefaultTransport.(*http.Transport).Clone() //nolint: forcetypeassert htr := http.DefaultTransport.(*http.Transport).Clone() //nolint: forcetypeassert
if !verifyTLS {
// If TLS is not enabled, return the client without any further TLS config.
if !clientOptions.TLSEnabled {
return &http.Client{
Timeout: httpTimeout,
Transport: htr,
}, nil
}
if !clientOptions.VerifyTLS {
htr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint: gosec htr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint: gosec
return &http.Client{ return &http.Client{
@ -70,36 +103,37 @@ func CreateHTTPClient(verifyTLS bool, host string, certDir string) (*http.Client
}, nil }, nil
} }
// Add a copy of the system cert pool // Add a copy of the system cert pool.
caCertPool, _ := x509.SystemCertPool() caCertPool, _ := x509.SystemCertPool()
tlsConfig := loadPerHostCerts(caCertPool, host) // Add a custom CA cert if present in the options.
if tlsConfig == nil { if clientOptions.CertOptions.RootCaCertFile != "" {
tlsConfig = &tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12} caCert, err := os.ReadFile(clientOptions.CertOptions.RootCaCertFile)
}
htr.TLSClientConfig = tlsConfig
if certDir != "" {
clientCert := path.Join(certDir, "client.cert")
clientKey := path.Join(certDir, "client.key")
caCertPath := path.Join(certDir, "ca.crt")
caCert, err := os.ReadFile(caCertPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
caCertPool.AppendCertsFromPEM(caCert) caCertPool.AppendCertsFromPEM(caCert)
}
cert, err := tls.LoadX509KeyPair(clientCert, clientKey) // Load certificates specific to the host if any.
tlsConfig := loadPerHostCerts(caCertPool, clientOptions.Host)
if tlsConfig == nil {
tlsConfig = &tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}
}
// Try to load certificate key pair if either are present in the options.
if clientOptions.CertOptions.ClientCertFile != "" || clientOptions.CertOptions.ClientKeyFile != "" {
cert, err := tls.LoadX509KeyPair(clientOptions.CertOptions.ClientCertFile, clientOptions.CertOptions.ClientKeyFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
htr.TLSClientConfig.Certificates = append(htr.TLSClientConfig.Certificates, cert) tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
} }
htr.TLSClientConfig = tlsConfig
return &http.Client{ return &http.Client{
Transport: htr, Transport: htr,
Timeout: httpTimeout, Timeout: httpTimeout,

View file

@ -2,6 +2,7 @@ package common_test
import ( import (
"crypto/x509" "crypto/x509"
"net/http"
"os" "os"
"path" "path"
"testing" "testing"
@ -35,7 +36,16 @@ func TestHTTPClient(t *testing.T) {
err = os.Chmod(path.Join(tempDir, "ca.crt"), 0o000) err = os.Chmod(path.Join(tempDir, "ca.crt"), 0o000)
So(err, ShouldBeNil) So(err, ShouldBeNil)
_, err = common.CreateHTTPClient(true, "localhost", tempDir) _, err = common.CreateHTTPClient(&common.HTTPClientOptions{
TLSEnabled: true,
VerifyTLS: true,
Host: "localhost",
CertOptions: common.HTTPClientCertOptions{
ClientCertFile: path.Join(tempDir, common.ClientCertFilename),
ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename),
RootCaCertFile: path.Join(tempDir, common.CaCertFilename),
},
})
So(err, ShouldNotBeNil) So(err, ShouldNotBeNil)
}) })
@ -46,7 +56,124 @@ func TestHTTPClient(t *testing.T) {
err = os.Chmod(path.Join(tempDir, "client.key"), 0o000) err = os.Chmod(path.Join(tempDir, "client.key"), 0o000)
So(err, ShouldBeNil) So(err, ShouldBeNil)
_, err = common.CreateHTTPClient(true, "localhost", tempDir) _, err = common.CreateHTTPClient(&common.HTTPClientOptions{
TLSEnabled: true,
VerifyTLS: true,
Host: "localhost",
CertOptions: common.HTTPClientCertOptions{
ClientCertFile: path.Join(tempDir, common.ClientCertFilename),
ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename),
RootCaCertFile: path.Join(tempDir, common.CaCertFilename),
},
})
So(err, ShouldNotBeNil) So(err, ShouldNotBeNil)
}) })
Convey("test CreateHTTPClient() no TLS", t, func() {
_, err := common.CreateHTTPClient(&common.HTTPClientOptions{})
So(err, ShouldBeNil)
})
Convey("test CreateHTTPClient() with only client cert configured", t, func() {
tempDir := t.TempDir()
err := test.CopyTestKeysAndCerts(tempDir)
So(err, ShouldBeNil)
_, err = common.CreateHTTPClient(&common.HTTPClientOptions{
TLSEnabled: true,
VerifyTLS: true,
Host: "localhost",
CertOptions: common.HTTPClientCertOptions{
ClientCertFile: path.Join(tempDir, common.ClientCertFilename),
},
})
So(err, ShouldNotBeNil)
})
Convey("test CreateHTTPClient() with only client key configured", t, func() {
tempDir := t.TempDir()
err := test.CopyTestKeysAndCerts(tempDir)
So(err, ShouldBeNil)
_, err = common.CreateHTTPClient(&common.HTTPClientOptions{
TLSEnabled: true,
VerifyTLS: true,
Host: "localhost",
CertOptions: common.HTTPClientCertOptions{
ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename),
},
})
So(err, ShouldNotBeNil)
})
Convey("test CreateHTTPClient() with full certificate config", t, func() {
tempDir := t.TempDir()
err := test.CopyTestKeysAndCerts(tempDir)
So(err, ShouldBeNil)
client, err := common.CreateHTTPClient(&common.HTTPClientOptions{
TLSEnabled: true,
VerifyTLS: true,
Host: "localhost",
CertOptions: common.HTTPClientCertOptions{
ClientCertFile: path.Join(tempDir, common.ClientCertFilename),
ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename),
RootCaCertFile: path.Join(tempDir, common.CaCertFilename),
},
})
So(err, ShouldBeNil)
htr, ok := client.Transport.(*http.Transport)
So(ok, ShouldBeTrue)
So(htr.TLSClientConfig.RootCAs, ShouldNotBeNil)
So(htr.TLSClientConfig.Certificates, ShouldNotBeEmpty)
})
Convey("test CreateHTTPClient() with no TLS verify", t, func() {
tempDir := t.TempDir()
err := test.CopyTestKeysAndCerts(tempDir)
So(err, ShouldBeNil)
client, err := common.CreateHTTPClient(&common.HTTPClientOptions{
TLSEnabled: true,
VerifyTLS: false,
Host: "localhost",
CertOptions: common.HTTPClientCertOptions{
ClientCertFile: path.Join(tempDir, common.ClientCertFilename),
ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename),
RootCaCertFile: path.Join(tempDir, common.CaCertFilename),
},
})
So(err, ShouldBeNil)
htr, ok := client.Transport.(*http.Transport)
So(ok, ShouldBeTrue)
So(htr.TLSClientConfig.Certificates, ShouldBeEmpty)
So(htr.TLSClientConfig.RootCAs, ShouldBeNil)
So(htr.TLSClientConfig.InsecureSkipVerify, ShouldBeTrue)
})
Convey("test CreateHTTPClient() with no TLS, but TLS verify enabled", t, func() {
tempDir := t.TempDir()
err := test.CopyTestKeysAndCerts(tempDir)
So(err, ShouldBeNil)
client, err := common.CreateHTTPClient(&common.HTTPClientOptions{
TLSEnabled: false,
VerifyTLS: true,
Host: "localhost",
CertOptions: common.HTTPClientCertOptions{
ClientCertFile: path.Join(tempDir, common.ClientCertFilename),
ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename),
RootCaCertFile: path.Join(tempDir, common.CaCertFilename),
},
})
So(err, ShouldBeNil)
htr, ok := client.Transport.(*http.Transport)
So(ok, ShouldBeTrue)
So(htr.TLSClientConfig.Certificates, ShouldBeEmpty)
So(htr.TLSClientConfig.RootCAs, ShouldBeNil)
So(htr.TLSClientConfig.InsecureSkipVerify, ShouldBeFalse)
})
} }

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"path/filepath"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -114,7 +115,26 @@ func (httpClient *Client) SetConfig(config Config) error {
httpClient.url = clientURL httpClient.url = clientURL
client, err := common.CreateHTTPClient(config.TLSVerify, clientURL.Host, config.CertDir) clientOpts := common.HTTPClientOptions{
// we want TLS enabled when verifyTLS is true.
TLSEnabled: config.TLSVerify,
VerifyTLS: config.TLSVerify,
Host: clientURL.Host,
}
if config.CertDir != "" {
// only configure the default cert file names if the CertDir was specified.
clientOpts.CertOptions = common.HTTPClientCertOptions{
// filepath is the recommended library to use for joining paths
// taking into account the underlying OS.
// ref: https://stackoverflow.com/a/39182128
ClientCertFile: filepath.Join(config.CertDir, common.ClientCertFilename),
ClientKeyFile: filepath.Join(config.CertDir, common.ClientKeyFilename),
RootCaCertFile: filepath.Join(config.CertDir, common.CaCertFilename),
}
}
client, err := common.CreateHTTPClient(&clientOpts)
if err != nil { if err != nil {
return err return err
} }