diff --git a/pkg/cli/client/client.go b/pkg/cli/client/client.go index b27e9448..4ef1f3a5 100644 --- a/pkg/cli/client/client.go +++ b/pkg/cli/client/client.go @@ -90,7 +90,13 @@ func doHTTPRequest(req *http.Request, verifyTLS bool, debug bool, httpClientLock.Lock() if httpClientsMap[host] == nil { - httpClient, err = common.CreateHTTPClient(verifyTLS, host, "") + httpClient, err = common.CreateHTTPClient(&common.HTTPClientOptions{ + // we want TLS enabled when verifyTLS is true. + TLSEnabled: verifyTLS, + VerifyTLS: verifyTLS, + Host: host, + CertOptions: common.HTTPClientCertOptions{}, + }) if err != nil { return nil, err } diff --git a/pkg/common/common.go b/pkg/common/common.go index d034e468..3b9abe8b 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -18,9 +18,9 @@ const ( httpTimeout = 5 * time.Minute certsPath = "/etc/containers/certs.d" homeCertsDir = ".config/containers/certs.d" - clientCertFilename = "client.cert" - clientKeyFilename = "client.key" - caCertFilename = "ca.crt" + ClientCertFilename = "client.cert" + ClientKeyFilename = "client.key" + CaCertFilename = "ca.crt" CosignSignature = "cosign" CosignSigKey = "dev.cosignproject.cosign/signature" diff --git a/pkg/common/http_client.go b/pkg/common/http_client.go index a4c62712..8d8f76a5 100644 --- a/pkg/common/http_client.go +++ b/pkg/common/http_client.go @@ -5,14 +5,13 @@ import ( "crypto/x509" "net/http" "os" - "path" "path/filepath" ) func GetTLSConfig(certsPath string, caCertPool *x509.CertPool) (*tls.Config, error) { - clientCert := filepath.Join(certsPath, clientCertFilename) - clientKey := filepath.Join(certsPath, clientKeyFilename) - caCertFile := filepath.Join(certsPath, caCertFilename) + clientCert := filepath.Join(certsPath, ClientCertFilename) + clientKey := filepath.Join(certsPath, ClientKeyFilename) + caCertFile := filepath.Join(certsPath, CaCertFilename) cert, err := tls.LoadX509KeyPair(clientCert, clientKey) if err != nil { @@ -59,9 +58,43 @@ func loadPerHostCerts(caCertPool *x509.CertPool, host string) *tls.Config { return nil } -func CreateHTTPClient(verifyTLS bool, host string, certDir string) (*http.Client, error) { +// Holds certificate options for an HTTP client. +type HTTPClientCertOptions struct { + ClientCertFile string // Holds the path to the client certificate file. Mandatory if ClientKeyFile is present. + ClientKeyFile string // Holds the path to the client key file. Mandatory if ClientCertFile is present. + RootCaCertFile string // Optional. Holds the path to the custom Root CA cert file. +} + +// Holds client options for creating an HTTP client. +type HTTPClientOptions struct { + // Results in a client with TLS config if true. + TLSEnabled bool + + // Results in a client without certificate config and TLS verification disabled if true. + // Note: if TLSEnabled is false and VerifyTLS is true, the client will not have the verification + // of insecure certificates set to false. For this, both TLSEnabled and VerifyTLS need to be + // true. + VerifyTLS bool + + // The target host for the imminent connection. Used for loading host specific certificates if any. + Host string + + // Certificate options for the client. + CertOptions HTTPClientCertOptions +} + +func CreateHTTPClient(clientOptions *HTTPClientOptions) (*http.Client, error) { htr := http.DefaultTransport.(*http.Transport).Clone() //nolint: forcetypeassert - if !verifyTLS { + + // If TLS is not enabled, return the client without any further TLS config. + if !clientOptions.TLSEnabled { + return &http.Client{ + Timeout: httpTimeout, + Transport: htr, + }, nil + } + + if !clientOptions.VerifyTLS { htr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint: gosec return &http.Client{ @@ -70,36 +103,37 @@ func CreateHTTPClient(verifyTLS bool, host string, certDir string) (*http.Client }, nil } - // Add a copy of the system cert pool + // Add a copy of the system cert pool. caCertPool, _ := x509.SystemCertPool() - tlsConfig := loadPerHostCerts(caCertPool, host) - if tlsConfig == nil { - tlsConfig = &tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12} - } - - htr.TLSClientConfig = tlsConfig - - if certDir != "" { - clientCert := path.Join(certDir, "client.cert") - clientKey := path.Join(certDir, "client.key") - caCertPath := path.Join(certDir, "ca.crt") - - caCert, err := os.ReadFile(caCertPath) + // Add a custom CA cert if present in the options. + if clientOptions.CertOptions.RootCaCertFile != "" { + caCert, err := os.ReadFile(clientOptions.CertOptions.RootCaCertFile) if err != nil { return nil, err } caCertPool.AppendCertsFromPEM(caCert) + } - cert, err := tls.LoadX509KeyPair(clientCert, clientKey) + // Load certificates specific to the host if any. + tlsConfig := loadPerHostCerts(caCertPool, clientOptions.Host) + if tlsConfig == nil { + tlsConfig = &tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12} + } + + // Try to load certificate key pair if either are present in the options. + if clientOptions.CertOptions.ClientCertFile != "" || clientOptions.CertOptions.ClientKeyFile != "" { + cert, err := tls.LoadX509KeyPair(clientOptions.CertOptions.ClientCertFile, clientOptions.CertOptions.ClientKeyFile) if err != nil { return nil, err } - htr.TLSClientConfig.Certificates = append(htr.TLSClientConfig.Certificates, cert) + tlsConfig.Certificates = append(tlsConfig.Certificates, cert) } + htr.TLSClientConfig = tlsConfig + return &http.Client{ Transport: htr, Timeout: httpTimeout, diff --git a/pkg/common/http_client_test.go b/pkg/common/http_client_test.go index 058e7505..c74473dd 100644 --- a/pkg/common/http_client_test.go +++ b/pkg/common/http_client_test.go @@ -2,6 +2,7 @@ package common_test import ( "crypto/x509" + "net/http" "os" "path" "testing" @@ -35,7 +36,16 @@ func TestHTTPClient(t *testing.T) { err = os.Chmod(path.Join(tempDir, "ca.crt"), 0o000) So(err, ShouldBeNil) - _, err = common.CreateHTTPClient(true, "localhost", tempDir) + _, err = common.CreateHTTPClient(&common.HTTPClientOptions{ + TLSEnabled: true, + VerifyTLS: true, + Host: "localhost", + CertOptions: common.HTTPClientCertOptions{ + ClientCertFile: path.Join(tempDir, common.ClientCertFilename), + ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename), + RootCaCertFile: path.Join(tempDir, common.CaCertFilename), + }, + }) So(err, ShouldNotBeNil) }) @@ -46,7 +56,124 @@ func TestHTTPClient(t *testing.T) { err = os.Chmod(path.Join(tempDir, "client.key"), 0o000) So(err, ShouldBeNil) - _, err = common.CreateHTTPClient(true, "localhost", tempDir) + _, err = common.CreateHTTPClient(&common.HTTPClientOptions{ + TLSEnabled: true, + VerifyTLS: true, + Host: "localhost", + CertOptions: common.HTTPClientCertOptions{ + ClientCertFile: path.Join(tempDir, common.ClientCertFilename), + ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename), + RootCaCertFile: path.Join(tempDir, common.CaCertFilename), + }, + }) So(err, ShouldNotBeNil) }) + + Convey("test CreateHTTPClient() no TLS", t, func() { + _, err := common.CreateHTTPClient(&common.HTTPClientOptions{}) + So(err, ShouldBeNil) + }) + + Convey("test CreateHTTPClient() with only client cert configured", t, func() { + tempDir := t.TempDir() + err := test.CopyTestKeysAndCerts(tempDir) + So(err, ShouldBeNil) + + _, err = common.CreateHTTPClient(&common.HTTPClientOptions{ + TLSEnabled: true, + VerifyTLS: true, + Host: "localhost", + CertOptions: common.HTTPClientCertOptions{ + ClientCertFile: path.Join(tempDir, common.ClientCertFilename), + }, + }) + So(err, ShouldNotBeNil) + }) + + Convey("test CreateHTTPClient() with only client key configured", t, func() { + tempDir := t.TempDir() + err := test.CopyTestKeysAndCerts(tempDir) + So(err, ShouldBeNil) + + _, err = common.CreateHTTPClient(&common.HTTPClientOptions{ + TLSEnabled: true, + VerifyTLS: true, + Host: "localhost", + CertOptions: common.HTTPClientCertOptions{ + ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename), + }, + }) + So(err, ShouldNotBeNil) + }) + + Convey("test CreateHTTPClient() with full certificate config", t, func() { + tempDir := t.TempDir() + err := test.CopyTestKeysAndCerts(tempDir) + So(err, ShouldBeNil) + + client, err := common.CreateHTTPClient(&common.HTTPClientOptions{ + TLSEnabled: true, + VerifyTLS: true, + Host: "localhost", + CertOptions: common.HTTPClientCertOptions{ + ClientCertFile: path.Join(tempDir, common.ClientCertFilename), + ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename), + RootCaCertFile: path.Join(tempDir, common.CaCertFilename), + }, + }) + So(err, ShouldBeNil) + + htr, ok := client.Transport.(*http.Transport) + So(ok, ShouldBeTrue) + So(htr.TLSClientConfig.RootCAs, ShouldNotBeNil) + So(htr.TLSClientConfig.Certificates, ShouldNotBeEmpty) + }) + + Convey("test CreateHTTPClient() with no TLS verify", t, func() { + tempDir := t.TempDir() + err := test.CopyTestKeysAndCerts(tempDir) + So(err, ShouldBeNil) + + client, err := common.CreateHTTPClient(&common.HTTPClientOptions{ + TLSEnabled: true, + VerifyTLS: false, + Host: "localhost", + CertOptions: common.HTTPClientCertOptions{ + ClientCertFile: path.Join(tempDir, common.ClientCertFilename), + ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename), + RootCaCertFile: path.Join(tempDir, common.CaCertFilename), + }, + }) + So(err, ShouldBeNil) + + htr, ok := client.Transport.(*http.Transport) + So(ok, ShouldBeTrue) + So(htr.TLSClientConfig.Certificates, ShouldBeEmpty) + So(htr.TLSClientConfig.RootCAs, ShouldBeNil) + So(htr.TLSClientConfig.InsecureSkipVerify, ShouldBeTrue) + }) + + Convey("test CreateHTTPClient() with no TLS, but TLS verify enabled", t, func() { + tempDir := t.TempDir() + err := test.CopyTestKeysAndCerts(tempDir) + So(err, ShouldBeNil) + + client, err := common.CreateHTTPClient(&common.HTTPClientOptions{ + TLSEnabled: false, + VerifyTLS: true, + Host: "localhost", + CertOptions: common.HTTPClientCertOptions{ + ClientCertFile: path.Join(tempDir, common.ClientCertFilename), + ClientKeyFile: path.Join(tempDir, common.ClientKeyFilename), + RootCaCertFile: path.Join(tempDir, common.CaCertFilename), + }, + }) + So(err, ShouldBeNil) + + htr, ok := client.Transport.(*http.Transport) + So(ok, ShouldBeTrue) + So(htr.TLSClientConfig.Certificates, ShouldBeEmpty) + So(htr.TLSClientConfig.RootCAs, ShouldBeNil) + So(htr.TLSClientConfig.InsecureSkipVerify, ShouldBeFalse) + }) } diff --git a/pkg/extensions/sync/httpclient/client.go b/pkg/extensions/sync/httpclient/client.go index 4b968d61..a2125dfe 100644 --- a/pkg/extensions/sync/httpclient/client.go +++ b/pkg/extensions/sync/httpclient/client.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/url" + "path/filepath" "strings" "sync" "time" @@ -114,7 +115,26 @@ func (httpClient *Client) SetConfig(config Config) error { httpClient.url = clientURL - client, err := common.CreateHTTPClient(config.TLSVerify, clientURL.Host, config.CertDir) + clientOpts := common.HTTPClientOptions{ + // we want TLS enabled when verifyTLS is true. + TLSEnabled: config.TLSVerify, + VerifyTLS: config.TLSVerify, + Host: clientURL.Host, + } + + if config.CertDir != "" { + // only configure the default cert file names if the CertDir was specified. + clientOpts.CertOptions = common.HTTPClientCertOptions{ + // filepath is the recommended library to use for joining paths + // taking into account the underlying OS. + // ref: https://stackoverflow.com/a/39182128 + ClientCertFile: filepath.Join(config.CertDir, common.ClientCertFilename), + ClientKeyFile: filepath.Join(config.CertDir, common.ClientKeyFilename), + RootCaCertFile: filepath.Join(config.CertDir, common.CaCertFilename), + } + } + + client, err := common.CreateHTTPClient(&clientOpts) if err != nil { return err }