From fed5c09b7145366c6971efa417325d7dd18c77a9 Mon Sep 17 00:00:00 2001 From: Roxana Nemulescu Date: Tue, 6 Jul 2021 19:50:46 +0300 Subject: [PATCH] TLS certs in CLI client resolve #194 --- Makefile | 8 +- pkg/cli/client.go | 110 ++++++++++- pkg/cli/client_test.go | 405 ++++++++++++++++++++++++++++++++++++++ pkg/cli/cve_cmd_test.go | 50 +---- pkg/cli/image_cmd_test.go | 22 ++- 5 files changed, 539 insertions(+), 56 deletions(-) create mode 100644 pkg/cli/client_test.go diff --git a/Makefile b/Makefile index a94c8fbd..20687286 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ TMPDIR := $(shell mktemp -d) STACKER := $(shell which stacker) .PHONY: all -all: doc binary binary-minimal debug test check +all: doc binary binary-minimal debug test test-clean check .PHONY: binary-minimal binary-minimal: doc @@ -25,8 +25,14 @@ debug: doc .PHONY: test test: $(shell mkdir -p test/data; cd test/data; ../scripts/gen_certs.sh; cd ${TOP_LEVEL}; sudo skopeo --insecure-policy copy -q docker://public.ecr.aws/t0x7q1g8/centos:7 oci:${TOP_LEVEL}/test/data/zot-test:0.0.1;sudo skopeo --insecure-policy copy -q docker://public.ecr.aws/t0x7q1g8/centos:8 oci:${TOP_LEVEL}/test/data/zot-cve-test:0.0.1) + $(shell sudo mkdir -p /etc/containers/certs.d/127.0.0.1:8089/; sudo cp test/data/client.* /etc/containers/certs.d/127.0.0.1:8089/; sudo cp test/data/ca.* /etc/containers/certs.d/127.0.0.1:8089/;) + $(shell sudo chmod a=rwx /etc/containers/certs.d/127.0.0.1:8089/*.key) go test -tags extended -v -race -cover -coverpkg ./... -coverprofile=coverage.txt -covermode=atomic ./... +.PHONY: test-clean +test-clean: + $(shell sudo rm -rf /etc/containers/certs.d/127.0.0.1:8089/) + .PHONY: covhtml covhtml: go tool cover -html=coverage.txt -o coverage.html diff --git a/pkg/cli/client.go b/pkg/cli/client.go index ba4dcf77..da3af454 100644 --- a/pkg/cli/client.go +++ b/pkg/cli/client.go @@ -6,11 +6,14 @@ import ( "bytes" "context" "crypto/tls" + "crypto/x509" "encoding/json" "errors" "io/ioutil" "net/http" "net/url" + "os" + "path/filepath" "strings" "sync" "time" @@ -18,16 +21,39 @@ import ( zotErrors "github.com/anuvu/zot/errors" ) -var httpClient *http.Client //nolint: gochecknoglobals +var httpClientsMap = make(map[string]*http.Client) //nolint: gochecknoglobals +var httpClientLock sync.Mutex //nolint: gochecknoglobals -const httpTimeout = 5 * time.Minute +const ( + httpTimeout = 5 * time.Minute + certsPath = "/etc/containers/certs.d" + homeCertsDir = ".config/containers/certs.d" + clientCertFilename = "client.cert" + clientKeyFilename = "client.key" + caCertFilename = "ca.crt" +) -func createHTTPClient(verifyTLS bool) *http.Client { +func createHTTPClient(verifyTLS bool, host string) *http.Client { var tr = http.DefaultTransport.(*http.Transport).Clone() if !verifyTLS { tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint: gosec + + return &http.Client{ + Timeout: httpTimeout, + Transport: tr, + } } + // Add a copy of the system cert pool + caCertPool, _ := x509.SystemCertPool() + + tlsConfig := loadPerHostCerts(caCertPool, host) + if tlsConfig == nil { + tlsConfig = &tls.Config{RootCAs: caCertPool} + } + + tr = &http.Transport{TLSClientConfig: tlsConfig} + return &http.Client{ Timeout: httpTimeout, Transport: tr, @@ -70,10 +96,22 @@ func makeGraphQLRequest(url, query, username, } func doHTTPRequest(req *http.Request, verifyTLS bool, resultsPtr interface{}) (http.Header, error) { - if httpClient == nil { - httpClient = createHTTPClient(verifyTLS) + var httpClient *http.Client + + host := req.Host + + httpClientLock.Lock() + + if httpClientsMap[host] == nil { + httpClient = createHTTPClient(verifyTLS, host) + + httpClientsMap[host] = httpClient + } else { + httpClient = httpClientsMap[host] } + httpClientLock.Unlock() + resp, err := httpClient.Do(req) if err != nil { return nil, err @@ -98,6 +136,68 @@ func doHTTPRequest(req *http.Request, verifyTLS bool, resultsPtr interface{}) (h return resp.Header, nil } +func loadPerHostCerts(caCertPool *x509.CertPool, host string) *tls.Config { + // Check if the /home/user/.config/containers/certs.d/$IP:$PORT dir exists + home := os.Getenv("HOME") + clientCertsDir := filepath.Join(home, homeCertsDir, host) + + if dirExists(clientCertsDir) { + tlsConfig, err := getTLSConfig(clientCertsDir, caCertPool) + + if err == nil { + return tlsConfig + } + } + + // Check if the /etc/containers/certs.d/$IP:$PORT dir exists + clientCertsDir = filepath.Join(certsPath, host) + if dirExists(clientCertsDir) { + tlsConfig, err := getTLSConfig(clientCertsDir, caCertPool) + + if err == nil { + return tlsConfig + } + } + + return nil +} + +func getTLSConfig(certsPath string, caCertPool *x509.CertPool) (*tls.Config, error) { + clientCert := filepath.Join(certsPath, clientCertFilename) + clientKey := filepath.Join(certsPath, clientKeyFilename) + caCertFile := filepath.Join(certsPath, caCertFilename) + + cert, err := tls.LoadX509KeyPair(clientCert, clientKey) + if err != nil { + return nil, err + } + + caCert, err := ioutil.ReadFile(caCertFile) + if err != nil { + return nil, err + } + + caCertPool.AppendCertsFromPEM(caCert) + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + }, nil +} + +func dirExists(d string) bool { + fi, err := os.Stat(d) + if err != nil && os.IsNotExist(err) { + return false + } + + if !fi.IsDir() { + return false + } + + return true +} + func isURL(str string) bool { u, err := url.Parse(str) return err == nil && u.Scheme != "" && u.Host != "" diff --git a/pkg/cli/client_test.go b/pkg/cli/client_test.go new file mode 100644 index 00000000..5929ecbc --- /dev/null +++ b/pkg/cli/client_test.go @@ -0,0 +1,405 @@ +// +build extended + +package cli //nolint:testpackage + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "io/ioutil" + "os" + "path" + "path/filepath" + + "gopkg.in/resty.v1" + + "testing" + "time" + + "github.com/anuvu/zot/pkg/api" + . "github.com/smartystreets/goconvey/convey" +) + +const ( + BaseURL1 = "http://127.0.0.1:8088" + BaseSecureURL1 = "https://127.0.0.1:8088" + HOST1 = "127.0.0.1:8088" + SecurePort1 = "8088" + BaseURL2 = "http://127.0.0.1:8089" + BaseSecureURL2 = "https://127.0.0.1:8089" + SecurePort2 = "8089" + BaseURL3 = "http://127.0.0.1:8090" + BaseSecureURL3 = "https://127.0.0.1:8090" + SecurePort3 = "8090" + username = "test" + passphrase = "test" + ServerCert = "../../test/data/server.cert" + ServerKey = "../../test/data/server.key" + CACert = "../../test/data/ca.crt" + sourceCertsDir = "../../test/data" + certsDir1 = "/.config/containers/certs.d/127.0.0.1:8088/" +) + +func makeHtpasswdFile() string { + f, err := ioutil.TempFile("", "htpasswd-") + if err != nil { + panic(err) + } + + // bcrypt(username="test", passwd="test") + content := []byte("test:$2y$05$hlbSXDp6hzDLu6VwACS39ORvVRpr3OMR4RlJ31jtlaOEGnPjKZI1m\n") + if err := ioutil.WriteFile(f.Name(), content, 0600); err != nil { + panic(err) + } + + return f.Name() +} + +func TestTLSWithAuth(t *testing.T) { + Convey("Make a new controller", t, func() { + caCert, err := ioutil.ReadFile(CACert) + So(err, ShouldBeNil) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool}) + defer func() { resty.SetTLSClientConfig(nil) }() + config := api.NewConfig() + config.HTTP.Port = SecurePort1 + htpasswdPath := makeHtpasswdFile() + defer os.Remove(htpasswdPath) + + config.HTTP.Auth = &api.AuthConfig{ + HTPasswd: api.AuthHTPasswd{ + Path: htpasswdPath, + }, + } + + config.HTTP.TLS = &api.TLSConfig{ + Cert: ServerCert, + Key: ServerKey, + CACert: CACert, + } + + c := api.NewController(config) + dir, err := ioutil.TempDir("", "oci-repo-test") + if err != nil { + panic(err) + } + defer os.RemoveAll(dir) + c.Config.Storage.RootDirectory = dir + go func() { + // this blocks + if err := c.Run(); err != nil { + return + } + }() + + // wait till ready + for { + _, err := resty.R().Get(BaseSecureURL1) + if err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + defer func() { + ctx := context.Background() + _ = c.Server.Shutdown(ctx) + }() + + Convey("Test with htpassw auth", func() { + configPath := makeConfigFile(`{"configs":[{"_name":"imagetest","showspinner":false}]}`) + defer os.Remove(configPath) + + home := os.Getenv("HOME") + destCertsDir := filepath.Join(home, certsDir1) + if err = copyFiles(sourceCertsDir, destCertsDir); err != nil { + panic(err) + } + defer os.RemoveAll(destCertsDir) + + args := []string{"imagetest", "--name", "dummyImageName", "--url", HOST1} + imageCmd := NewImageCommand(new(searchService)) + imageBuff := bytes.NewBufferString("") + imageCmd.SetOut(imageBuff) + imageCmd.SetErr(ioutil.Discard) + imageCmd.SetArgs(args) + err := imageCmd.Execute() + So(err, ShouldNotBeNil) + So(imageBuff.String(), ShouldContainSubstring, "invalid URL format") + + args = []string{"imagetest"} + configPath = makeConfigFile( + fmt.Sprintf(`{"configs":[{"_name":"imagetest","url":"%s/v2/_catalog","showspinner":false}]}`, + BaseSecureURL1)) + defer os.Remove(configPath) + imageCmd = NewImageCommand(new(searchService)) + imageBuff = bytes.NewBufferString("") + imageCmd.SetOut(imageBuff) + imageCmd.SetErr(ioutil.Discard) + imageCmd.SetArgs(args) + err = imageCmd.Execute() + So(err, ShouldNotBeNil) + So(imageBuff.String(), ShouldContainSubstring, "check credentials") + + user := fmt.Sprintf("%s:%s", username, passphrase) + args = []string{"imagetest", "-u", user} + configPath = makeConfigFile( + fmt.Sprintf(`{"configs":[{"_name":"imagetest","url":"%s/v2/_catalog","showspinner":false}]}`, + BaseSecureURL1)) + defer os.Remove(configPath) + imageCmd = NewImageCommand(new(searchService)) + imageBuff = bytes.NewBufferString("") + imageCmd.SetOut(imageBuff) + imageCmd.SetErr(ioutil.Discard) + imageCmd.SetArgs(args) + err = imageCmd.Execute() + So(err, ShouldBeNil) + }) + }) +} + +func TestTLSWithoutAuth(t *testing.T) { + Convey("Home certs - Make a new controller", t, func() { + caCert, err := ioutil.ReadFile(CACert) + So(err, ShouldBeNil) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool}) + defer func() { resty.SetTLSClientConfig(nil) }() + config := api.NewConfig() + config.HTTP.Port = SecurePort1 + config.HTTP.TLS = &api.TLSConfig{ + Cert: ServerCert, + Key: ServerKey, + CACert: CACert, + } + + c := api.NewController(config) + dir, err := ioutil.TempDir("", "oci-repo-test") + if err != nil { + panic(err) + } + defer os.RemoveAll(dir) + c.Config.Storage.RootDirectory = dir + go func() { + // this blocks + if err := c.Run(); err != nil { + return + } + }() + + // wait till ready + for { + _, err := resty.R().Get(BaseURL1) + if err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + defer func() { + ctx := context.Background() + _ = c.Server.Shutdown(ctx) + }() + + Convey("Certs in user's home", func() { + configPath := makeConfigFile( + fmt.Sprintf(`{"configs":[{"_name":"imagetest","url":"%s/v2/_catalog","showspinner":false}]}`, + BaseSecureURL1)) + defer os.Remove(configPath) + + home := os.Getenv("HOME") + destCertsDir := filepath.Join(home, certsDir1) + if err = copyFiles(sourceCertsDir, destCertsDir); err != nil { + panic(err) + } + defer os.RemoveAll(destCertsDir) + + args := []string{"imagetest"} + imageCmd := NewImageCommand(new(searchService)) + imageBuff := bytes.NewBufferString("") + imageCmd.SetOut(imageBuff) + imageCmd.SetErr(ioutil.Discard) + imageCmd.SetArgs(args) + err := imageCmd.Execute() + So(err, ShouldBeNil) + }) + }) + + Convey("Privileged certs - Make a new controller", t, func() { + caCert, err := ioutil.ReadFile(CACert) + So(err, ShouldBeNil) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool}) + defer func() { resty.SetTLSClientConfig(nil) }() + config := api.NewConfig() + config.HTTP.Port = SecurePort2 + config.HTTP.TLS = &api.TLSConfig{ + Cert: ServerCert, + Key: ServerKey, + CACert: CACert, + } + + c := api.NewController(config) + dir, err := ioutil.TempDir("", "oci-repo-test") + if err != nil { + panic(err) + } + defer os.RemoveAll(dir) + c.Config.Storage.RootDirectory = dir + go func() { + // this blocks + if err := c.Run(); err != nil { + return + } + }() + + // wait till ready + for { + _, err := resty.R().Get(BaseURL2) + if err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + defer func() { + ctx := context.Background() + _ = c.Server.Shutdown(ctx) + }() + + Convey("Certs in privileged path", func() { + configPath := makeConfigFile( + fmt.Sprintf(`{"configs":[{"_name":"imagetest","url":"%s/v2/_catalog","showspinner":false}]}`, + BaseSecureURL2)) + defer os.Remove(configPath) + + args := []string{"imagetest"} + imageCmd := NewImageCommand(new(searchService)) + imageBuff := bytes.NewBufferString("") + imageCmd.SetOut(imageBuff) + imageCmd.SetErr(ioutil.Discard) + imageCmd.SetArgs(args) + err := imageCmd.Execute() + So(err, ShouldBeNil) + }) + }) +} + +func TestTLSBadCerts(t *testing.T) { + Convey("Make a new controller", t, func() { + caCert, err := ioutil.ReadFile(CACert) + So(err, ShouldBeNil) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + resty.SetTLSClientConfig(&tls.Config{RootCAs: caCertPool}) + defer func() { resty.SetTLSClientConfig(nil) }() + config := api.NewConfig() + config.HTTP.Port = SecurePort3 + config.HTTP.TLS = &api.TLSConfig{ + Cert: ServerCert, + Key: ServerKey, + CACert: CACert, + } + + c := api.NewController(config) + dir, err := ioutil.TempDir("", "oci-repo-test") + if err != nil { + panic(err) + } + defer os.RemoveAll(dir) + c.Config.Storage.RootDirectory = dir + go func() { + // this blocks + if err := c.Run(); err != nil { + return + } + }() + + // wait till ready + for { + _, err := resty.R().Get(BaseURL3) + if err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + defer func() { + ctx := context.Background() + _ = c.Server.Shutdown(ctx) + }() + + Convey("Test with system certs", func() { + configPath := makeConfigFile( + fmt.Sprintf(`{"configs":[{"_name":"imagetest","url":"%s/v2/_catalog","showspinner":false}]}`, + BaseSecureURL3)) + defer os.Remove(configPath) + + args := []string{"imagetest"} + imageCmd := NewImageCommand(new(searchService)) + imageBuff := bytes.NewBufferString("") + imageCmd.SetOut(imageBuff) + imageCmd.SetErr(ioutil.Discard) + imageCmd.SetArgs(args) + err := imageCmd.Execute() + So(err, ShouldNotBeNil) + So(imageBuff.String(), ShouldContainSubstring, "certificate signed by unknown authority") + }) + }) +} + +func copyFiles(sourceDir string, destDir string) error { + sourceMeta, err := os.Stat(sourceDir) + if err != nil { + return err + } + + if err := os.MkdirAll(destDir, sourceMeta.Mode()); err != nil { + return err + } + + files, err := ioutil.ReadDir(sourceDir) + if err != nil { + return err + } + + for _, file := range files { + sourceFilePath := path.Join(sourceDir, file.Name()) + destFilePath := path.Join(destDir, file.Name()) + + if file.IsDir() { + if err = copyFiles(sourceFilePath, destFilePath); err != nil { + return err + } + } else { + sourceFile, err := os.Open(sourceFilePath) + if err != nil { + return err + } + defer sourceFile.Close() + + destFile, err := os.Create(destFilePath) + if err != nil { + return err + } + defer destFile.Close() + + if _, err = io.Copy(destFile, sourceFile); err != nil { + return err + } + } + } + + return nil +} diff --git a/pkg/cli/cve_cmd_test.go b/pkg/cli/cve_cmd_test.go index 609a1785..b0258574 100644 --- a/pkg/cli/cve_cmd_test.go +++ b/pkg/cli/cve_cmd_test.go @@ -6,7 +6,6 @@ import ( "bytes" "context" "fmt" - "io" "io/ioutil" "os" "path" @@ -285,8 +284,8 @@ func TestSearchCVECmd(t *testing.T) { } func TestServerCVEResponse(t *testing.T) { - port := "8080" - url := "http://127.0.0.1:8080" + port := getFreePort() + url := getBaseURL(port) config := api.NewConfig() config.HTTP.Port = port c := api.NewController(config) @@ -481,48 +480,3 @@ func TestServerCVEResponse(t *testing.T) { }) }) } - -func copyFiles(sourceDir string, destDir string) error { - sourceMeta, err := os.Stat(sourceDir) - if err != nil { - return err - } - - if err := os.MkdirAll(destDir, sourceMeta.Mode()); err != nil { - return err - } - - files, err := ioutil.ReadDir(sourceDir) - if err != nil { - return err - } - - for _, file := range files { - sourceFilePath := path.Join(sourceDir, file.Name()) - destFilePath := path.Join(destDir, file.Name()) - - if file.IsDir() { - if err = copyFiles(sourceFilePath, destFilePath); err != nil { - return err - } - } else { - sourceFile, err := os.Open(sourceFilePath) - if err != nil { - return err - } - defer sourceFile.Close() - - destFile, err := os.Create(destFilePath) - if err != nil { - return err - } - defer destFile.Close() - - if _, err = io.Copy(destFile, sourceFile); err != nil { - return err - } - } - } - - return nil -} diff --git a/pkg/cli/image_cmd_test.go b/pkg/cli/image_cmd_test.go index 0cf1ebb4..9397715e 100644 --- a/pkg/cli/image_cmd_test.go +++ b/pkg/cli/image_cmd_test.go @@ -25,9 +25,27 @@ import ( "github.com/anuvu/zot/pkg/extensions" godigest "github.com/opencontainers/go-digest" ispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/phayes/freeport" . "github.com/smartystreets/goconvey/convey" ) +const ( + BaseURL = "http://127.0.0.1:%s" +) + +func getBaseURL(port string) string { + return fmt.Sprintf(BaseURL, port) +} + +func getFreePort() string { + port, err := freeport.GetFreePort() + if err != nil { + panic(err) + } + + return fmt.Sprint(port) +} + func TestSearchImageCmd(t *testing.T) { Convey("Test image help", t, func() { args := []string{"--help"} @@ -282,8 +300,8 @@ func TestOutputFormat(t *testing.T) { func TestServerResponse(t *testing.T) { Convey("Test from real server", t, func() { - port := "8080" - url := "http://127.0.0.1:8080" + port := getFreePort() + url := getBaseURL(port) config := api.NewConfig() config.HTTP.Port = port config.Extensions = &extensions.ExtensionConfig{