diff --git a/pkg/api/controller.go b/pkg/api/controller.go index a1fce970..fdf09c22 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -31,14 +31,16 @@ const ( ) type Controller struct { - Config *config.Config - Router *mux.Router - StoreController storage.StoreController - Log log.Logger - Audit *log.Logger - Server *http.Server - Metrics monitoring.MetricServer - wgShutDown *goSync.WaitGroup // use it to gracefully shutdown goroutines + Config *config.Config + Router *mux.Router + StoreController storage.StoreController + Log log.Logger + Audit *log.Logger + Server *http.Server + Metrics monitoring.MetricServer + wgShutDown *goSync.WaitGroup // use it to gracefully shutdown goroutines + reloadCtx context.Context // use it to gracefully reload goroutines with new configuration + cancelOnReloadFunc context.CancelFunc // use it to stop goroutines } func NewController(config *config.Config) *Controller { @@ -48,6 +50,9 @@ func NewController(config *config.Config) *Controller { controller.Config = config controller.Log = logger controller.wgShutDown = new(goSync.WaitGroup) + /* context used to cancel go routines so that + we can change their config on the fly (restart routines with different config) */ + controller.reloadCtx, controller.cancelOnReloadFunc = context.WithCancel(context.Background()) if config.Log.Audit != "" { audit := log.NewAuditLogger(config.Log.Level, config.Log.Audit) @@ -321,12 +326,35 @@ func (c *Controller) InitImageStore() error { // Enable extensions if extension config is provided if c.Config.Extensions != nil && c.Config.Extensions.Sync != nil && *c.Config.Extensions.Sync.Enable { - ext.EnableSyncExtension(c.Config, c.wgShutDown, c.StoreController, c.Log) + ext.EnableSyncExtension(c.reloadCtx, c.Config, c.wgShutDown, c.StoreController, c.Log) } return nil } +func (c *Controller) LoadNewConfig(config *config.Config) { + // cancel go routines context so we can reload configuration + c.cancelOnReloadFunc() + + // reload access control config + c.Config.AccessControl = config.AccessControl + c.Config.HTTP.RawAccessControl = config.HTTP.RawAccessControl + + // create new context for the next config reload + c.reloadCtx, c.cancelOnReloadFunc = context.WithCancel(context.Background()) + + // Enable extensions if extension config is provided + if config.Extensions != nil && config.Extensions.Sync != nil { + // reload sync config + c.Config.Extensions.Sync = config.Extensions.Sync + ext.EnableSyncExtension(c.reloadCtx, c.Config, c.wgShutDown, c.StoreController, c.Log) + } else if c.Config.Extensions != nil { + c.Config.Extensions.Sync = nil + } + + c.Log.Info().Interface("reloaded params", c.Config.Sanitize()).Msg("new configuration settings") +} + func (c *Controller) Shutdown() { // wait gracefully c.wgShutDown.Wait() diff --git a/pkg/cli/config_reloader.go b/pkg/cli/config_reloader.go new file mode 100644 index 00000000..50ca4222 --- /dev/null +++ b/pkg/cli/config_reloader.go @@ -0,0 +1,72 @@ +package cli + +import ( + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog/log" + "zotregistry.io/zot/pkg/api" + "zotregistry.io/zot/pkg/api/config" +) + +type HotReloader struct { + watcher *fsnotify.Watcher + filePath string + ctlr *api.Controller +} + +func NewHotReloader(ctlr *api.Controller, filePath string) (*HotReloader, error) { + // creates a new file watcher + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + hotReloader := &HotReloader{ + watcher: watcher, + filePath: filePath, + ctlr: ctlr, + } + + return hotReloader, nil +} + +func (hr *HotReloader) Start() { + done := make(chan bool) + // run watcher + go func() { + defer hr.watcher.Close() + + go func() { + for { + select { + // watch for events + case event := <-hr.watcher.Events: + if event.Op == fsnotify.Write { + log.Info().Msg("config file changed, trying to reload config") + + newConfig := config.New() + + err := LoadConfiguration(newConfig, hr.filePath) + if err != nil { + log.Error().Err(err).Msg("couldn't reload config, retry writing it.") + + continue + } + + hr.ctlr.LoadNewConfig(newConfig) + } + // watch for errors + case err := <-hr.watcher.Errors: + log.Error().Err(err).Msgf("fsnotfy error while watching config %s", hr.filePath) + panic(err) + } + } + }() + + if err := hr.watcher.Add(hr.filePath); err != nil { + log.Error().Err(err).Msgf("error adding config file %s to FsNotify watcher", hr.filePath) + panic(err) + } + + <-done + }() +} diff --git a/pkg/cli/config_reloader_test.go b/pkg/cli/config_reloader_test.go new file mode 100644 index 00000000..6e8ac987 --- /dev/null +++ b/pkg/cli/config_reloader_test.go @@ -0,0 +1,384 @@ +package cli_test + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + "golang.org/x/crypto/bcrypt" + "zotregistry.io/zot/pkg/cli" + "zotregistry.io/zot/pkg/test" +) + +func TestConfigReloader(t *testing.T) { + oldArgs := os.Args + + defer func() { os.Args = oldArgs }() + + Convey("reload access control config", t, func(c C) { + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + + logFile, err := ioutil.TempFile("", "zot-log*.txt") + So(err, ShouldBeNil) + + username := "alice" + password := "alice" + + hash, err := bcrypt.GenerateFromPassword([]byte(password), 10) + if err != nil { + panic(err) + } + + usernameAndHash := fmt.Sprintf("%s:%s", username, string(hash)) + + htpasswdPath := test.MakeHtpasswdFileFromString(usernameAndHash) + defer os.Remove(htpasswdPath) + + defer os.Remove(logFile.Name()) // clean up + + content := fmt.Sprintf(`{ + "distSpecVersion": "0.1.0-dev", + "storage": { + "rootDirectory": "/tmp/zot" + }, + "http": { + "address": "127.0.0.1", + "port": "%s", + "realm": "zot", + "auth": { + "htpasswd": { + "path": "%s" + }, + "failDelay": 1 + }, + "accessControl": { + "**": { + "policies": [ + { + "users": ["charlie"], + "actions": ["read"] + } + ], + "defaultPolicy": ["read", "create"] + }, + "adminPolicy": { + "users": ["admin"], + "actions": ["read", "create", "update", "delete"] + } + } + }, + "log": { + "level": "debug", + "output": "%s" + } + }`, port, htpasswdPath, logFile.Name()) + + cfgfile, err := ioutil.TempFile("", "zot-test*.json") + So(err, ShouldBeNil) + + defer os.Remove(cfgfile.Name()) // clean up + + _, err = cfgfile.Write([]byte(content)) + So(err, ShouldBeNil) + + // err = cfgfile.Close() + // So(err, ShouldBeNil) + + os.Args = []string{"cli_test", "serve", cfgfile.Name()} + go func() { + err = cli.NewServerRootCmd().Execute() + So(err, ShouldBeNil) + }() + + test.WaitTillServerReady(baseURL) + + content = fmt.Sprintf(`{ + "distSpecVersion": "0.1.0-dev", + "storage": { + "rootDirectory": "/tmp/zot" + }, + "http": { + "address": "127.0.0.1", + "port": "%s", + "realm": "zot", + "auth": { + "htpasswd": { + "path": "%s" + }, + "failDelay": 1 + }, + "accessControl": { + "**": { + "policies": [ + { + "users": ["alice"], + "actions": ["read", "create", "update", "delete"] + } + ], + "defaultPolicy": ["read"] + }, + "adminPolicy": { + "users": ["admin"], + "actions": ["read", "create", "update", "delete"] + } + } + }, + "log": { + "level": "debug", + "output": "%s" + } + }`, port, htpasswdPath, logFile.Name()) + + err = cfgfile.Truncate(0) + So(err, ShouldBeNil) + + _, err = cfgfile.Seek(0, io.SeekStart) + So(err, ShouldBeNil) + + _, err = cfgfile.WriteString(content) + So(err, ShouldBeNil) + + err = cfgfile.Close() + So(err, ShouldBeNil) + + // wait for config reload + time.Sleep(2 * time.Second) + + data, err := os.ReadFile(logFile.Name()) + So(err, ShouldBeNil) + So(string(data), ShouldContainSubstring, "reloaded params") + So(string(data), ShouldContainSubstring, "new configuration settings") + So(string(data), ShouldContainSubstring, "\"Users\":[\"alice\"]") + So(string(data), ShouldContainSubstring, "\"Actions\":[\"read\",\"create\",\"update\",\"delete\"]") + }) + + Convey("reload sync config", t, func(c C) { + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + + logFile, err := ioutil.TempFile("", "zot-log*.txt") + So(err, ShouldBeNil) + + defer os.Remove(logFile.Name()) // clean up + + content := fmt.Sprintf(`{ + "distSpecVersion": "0.1.0-dev", + "storage": { + "rootDirectory": "/tmp/zot" + }, + "http": { + "address": "127.0.0.1", + "port": "%s" + }, + "log": { + "level": "debug", + "output": "%s" + }, + "extensions": { + "sync": { + "registries": [{ + "urls": ["http://localhost:8080"], + "tlsVerify": false, + "onDemand": true, + "maxRetries": 3, + "retryDelay": "15m", + "certDir": "", + "content":[ + { + "prefix": "zot-test", + "tags": { + "regex": ".*", + "semver": true + } + } + ] + }] + } + } + }`, port, logFile.Name()) + + cfgfile, err := ioutil.TempFile("", "zot-test*.json") + So(err, ShouldBeNil) + + defer os.Remove(cfgfile.Name()) // clean up + + _, err = cfgfile.Write([]byte(content)) + So(err, ShouldBeNil) + + // err = cfgfile.Close() + // So(err, ShouldBeNil) + + os.Args = []string{"cli_test", "serve", cfgfile.Name()} + go func() { + err = cli.NewServerRootCmd().Execute() + So(err, ShouldBeNil) + }() + + test.WaitTillServerReady(baseURL) + + content = fmt.Sprintf(`{ + "distSpecVersion": "0.1.0-dev", + "storage": { + "rootDirectory": "/tmp/zot" + }, + "http": { + "address": "127.0.0.1", + "port": "%s" + }, + "log": { + "level": "debug", + "output": "%s" + }, + "extensions": { + "sync": { + "registries": [{ + "urls": ["http://localhost:9999"], + "tlsVerify": true, + "onDemand": false, + "maxRetries": 10, + "retryDelay": "5m", + "certDir": "certs", + "content":[ + { + "prefix": "zot-cve-test", + "tags": { + "regex": "tag", + "semver": false + } + } + ] + }] + } + } + }`, port, logFile.Name()) + + err = cfgfile.Truncate(0) + So(err, ShouldBeNil) + + _, err = cfgfile.Seek(0, io.SeekStart) + So(err, ShouldBeNil) + + _, err = cfgfile.WriteString(content) + So(err, ShouldBeNil) + + err = cfgfile.Close() + So(err, ShouldBeNil) + + // wait for config reload + time.Sleep(2 * time.Second) + + data, err := os.ReadFile(logFile.Name()) + So(err, ShouldBeNil) + So(string(data), ShouldContainSubstring, "reloaded params") + So(string(data), ShouldContainSubstring, "new configuration settings") + So(string(data), ShouldContainSubstring, "\"URLs\":[\"http://localhost:9999\"]") + So(string(data), ShouldContainSubstring, "\"TLSVerify\":true") + So(string(data), ShouldContainSubstring, "\"OnDemand\":false") + So(string(data), ShouldContainSubstring, "\"MaxRetries\":10") + So(string(data), ShouldContainSubstring, "\"RetryDelay\":300000000000") + So(string(data), ShouldContainSubstring, "\"CertDir\":\"certs\"") + So(string(data), ShouldContainSubstring, "\"Prefix\":\"zot-cve-test\"") + So(string(data), ShouldContainSubstring, "\"Regex\":\"tag\"") + So(string(data), ShouldContainSubstring, "\"Semver\":false") + }) + + Convey("reload bad config", t, func(c C) { + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + + logFile, err := ioutil.TempFile("", "zot-log*.txt") + So(err, ShouldBeNil) + + defer os.Remove(logFile.Name()) // clean up + + content := fmt.Sprintf(`{ + "distSpecVersion": "0.1.0-dev", + "storage": { + "rootDirectory": "/tmp/zot" + }, + "http": { + "address": "127.0.0.1", + "port": "%s" + }, + "log": { + "level": "debug", + "output": "%s" + }, + "extensions": { + "sync": { + "registries": [{ + "urls": ["http://localhost:8080"], + "tlsVerify": false, + "onDemand": true, + "maxRetries": 3, + "retryDelay": "15m", + "certDir": "", + "content":[ + { + "prefix": "zot-test", + "tags": { + "regex": ".*", + "semver": true + } + } + ] + }] + } + } + }`, port, logFile.Name()) + + cfgfile, err := ioutil.TempFile("", "zot-test*.json") + So(err, ShouldBeNil) + + defer os.Remove(cfgfile.Name()) // clean up + + _, err = cfgfile.Write([]byte(content)) + So(err, ShouldBeNil) + + // err = cfgfile.Close() + // So(err, ShouldBeNil) + + os.Args = []string{"cli_test", "serve", cfgfile.Name()} + go func() { + err = cli.NewServerRootCmd().Execute() + So(err, ShouldBeNil) + }() + + test.WaitTillServerReady(baseURL) + + content = "[]" + + err = cfgfile.Truncate(0) + So(err, ShouldBeNil) + + _, err = cfgfile.Seek(0, io.SeekStart) + So(err, ShouldBeNil) + + _, err = cfgfile.WriteString(content) + So(err, ShouldBeNil) + + err = cfgfile.Close() + So(err, ShouldBeNil) + + // wait for config reload + time.Sleep(2 * time.Second) + + data, err := os.ReadFile(logFile.Name()) + So(err, ShouldBeNil) + So(string(data), ShouldNotContainSubstring, "reloaded params") + So(string(data), ShouldNotContainSubstring, "new configuration settings") + So(string(data), ShouldContainSubstring, "\"URLs\":[\"http://localhost:8080\"]") + So(string(data), ShouldContainSubstring, "\"TLSVerify\":false") + So(string(data), ShouldContainSubstring, "\"OnDemand\":true") + So(string(data), ShouldContainSubstring, "\"MaxRetries\":3") + So(string(data), ShouldContainSubstring, "\"CertDir\":\"\"") + So(string(data), ShouldContainSubstring, "\"Prefix\":\"zot-test\"") + So(string(data), ShouldContainSubstring, "\"Regex\":\".*\"") + So(string(data), ShouldContainSubstring, "\"Semver\":true") + }) +} diff --git a/pkg/cli/root.go b/pkg/cli/root.go index 103f4dd3..6022dcc1 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -7,7 +7,6 @@ import ( "time" glob "github.com/bmatcuk/doublestar/v4" - "github.com/fsnotify/fsnotify" "github.com/mitchellh/mapstructure" distspec "github.com/opencontainers/distribution-spec/specs-go" "github.com/rs/zerolog/log" @@ -38,46 +37,19 @@ func newServeCmd(conf *config.Config) *cobra.Command { Long: "`serve` stores and distributes OCI images", Run: func(cmd *cobra.Command, args []string) { if len(args) > 0 { - LoadConfiguration(conf, args[0]) + if err := LoadConfiguration(conf, args[0]); err != nil { + panic(err) + } } ctlr := api.NewController(conf) - // creates a new file watcher - watcher, err := fsnotify.NewWatcher() + hotReloader, err := NewHotReloader(ctlr, args[0]) if err != nil { panic(err) } - defer watcher.Close() - done := make(chan bool) - // run watcher - go func() { - go func() { - for { - select { - // watch for events - case event := <-watcher.Events: - if event.Op == fsnotify.Write { - log.Info().Msg("config file changed, trying to reload accessControl config") - newConfig := config.New() - LoadConfiguration(newConfig, args[0]) - ctlr.Config.AccessControl = newConfig.AccessControl - } - // watch for errors - case err := <-watcher.Errors: - log.Error().Err(err).Msgf("FsNotify error while watching config %s", args[0]) - panic(err) - } - } - }() - - if err := watcher.Add(args[0]); err != nil { - log.Error().Err(err).Msgf("error adding config file %s to FsNotify watcher", args[0]) - panic(err) - } - <-done - }() + hotReloader.Start() if err := ctlr.Run(); err != nil { panic(err) @@ -97,7 +69,9 @@ func newScrubCmd(conf *config.Config) *cobra.Command { Long: "`scrub` checks manifest/blob integrity", Run: func(cmd *cobra.Command, args []string) { if len(args) > 0 { - LoadConfiguration(conf, args[0]) + if err := LoadConfiguration(conf, args[0]); err != nil { + panic(err) + } } else { if err := cmd.Usage(); err != nil { panic(err) @@ -152,7 +126,10 @@ func newVerifyCmd(conf *config.Config) *cobra.Command { Long: "`verify` validates a zot config file", Run: func(cmd *cobra.Command, args []string) { if len(args) > 0 { - LoadConfiguration(conf, args[0]) + if err := LoadConfiguration(conf, args[0]); err != nil { + panic(err) + } + log.Info().Msgf("Config file %s is valid", args[0]) } }, @@ -220,12 +197,13 @@ func NewCliRootCmd() *cobra.Command { return rootCmd } -func validateConfiguration(config *config.Config) { +func validateConfiguration(config *config.Config) error { // enforce GC params if config.Storage.GCDelay < 0 { log.Error().Err(errors.ErrBadConfig). Msgf("invalid garbage-collect delay %v specified", config.Storage.GCDelay) - panic(errors.ErrBadConfig) + + return errors.ErrBadConfig } if !config.Storage.GC && config.Storage.GCDelay != 0 { @@ -238,7 +216,8 @@ func validateConfiguration(config *config.Config) { if config.HTTP.Auth == nil || (config.HTTP.Auth.HTPasswd.Path == "" && config.HTTP.Auth.LDAP == nil) { log.Error().Err(errors.ErrBadConfig). Msg("access control config requires httpasswd or ldap authentication to be enabled") - panic(errors.ErrBadConfig) + + return errors.ErrBadConfig } } @@ -246,13 +225,15 @@ func validateConfiguration(config *config.Config) { // enforce s3 driver in case of using storage driver if config.Storage.StorageDriver["name"] != storage.S3StorageDriverName { log.Error().Err(errors.ErrBadConfig).Msgf("unsupported storage driver: %s", config.Storage.StorageDriver["name"]) - panic(errors.ErrBadConfig) + + return errors.ErrBadConfig } // enforce filesystem storage in case sync feature is enabled if config.Extensions != nil && config.Extensions.Sync != nil { log.Error().Err(errors.ErrBadConfig).Msg("sync supports only filesystem storage") - panic(errors.ErrBadConfig) + + return errors.ErrBadConfig } } @@ -263,7 +244,8 @@ func validateConfiguration(config *config.Config) { if regCfg.MaxRetries != nil && regCfg.RetryDelay == nil { log.Error().Err(errors.ErrBadConfig).Msgf("extensions.sync.registries[%d].retryDelay"+ " is required when using extensions.sync.registries[%d].maxRetries", id, id) - panic(errors.ErrBadConfig) + + return errors.ErrBadConfig } if regCfg.Content != nil { @@ -271,7 +253,8 @@ func validateConfiguration(config *config.Config) { ok := glob.ValidatePattern(content.Prefix) if !ok { log.Error().Err(glob.ErrBadPattern).Str("pattern", content.Prefix).Msg("sync pattern could not be compiled") - panic(errors.ErrBadConfig) + + return glob.ErrBadPattern } } } @@ -288,7 +271,8 @@ func validateConfiguration(config *config.Config) { if storageConfig.StorageDriver["name"] != storage.S3StorageDriverName { log.Error().Err(errors.ErrBadConfig).Str("subpath", route).Msgf("unsupported storage driver: %s", storageConfig.StorageDriver["name"]) - panic(errors.ErrBadConfig) + + return errors.ErrBadConfig } } } @@ -301,10 +285,13 @@ func validateConfiguration(config *config.Config) { ok := glob.ValidatePattern(pattern) if !ok { log.Error().Err(glob.ErrBadPattern).Str("pattern", pattern).Msg("authorization pattern could not be compiled") - panic(errors.ErrBadConfig) + + return glob.ErrBadPattern } } } + + return nil } func applyDefaultValues(config *config.Config, viperInstance *viper.Viper) { @@ -382,7 +369,7 @@ func applyDefaultValues(config *config.Config, viperInstance *viper.Viper) { } } -func LoadConfiguration(config *config.Config, configPath string) { +func LoadConfiguration(config *config.Config, configPath string) error { // Default is dot (.) but because we allow glob patterns in authz // we need another key delimiter. viperInstance := viper.NewWithOptions(viper.KeyDelimiter("::")) @@ -391,29 +378,37 @@ func LoadConfiguration(config *config.Config, configPath string) { if err := viperInstance.ReadInConfig(); err != nil { log.Error().Err(err).Msg("error while reading configuration") - panic(err) + + return err } metaData := &mapstructure.Metadata{} if err := viperInstance.Unmarshal(&config, metadataConfig(metaData)); err != nil { log.Error().Err(err).Msg("error while unmarshalling new config") - panic(err) + + return err } if len(metaData.Keys) == 0 || len(metaData.Unused) > 0 { log.Error().Err(errors.ErrBadConfig).Msg("bad configuration, retry writing it") - panic(errors.ErrBadConfig) + + return errors.ErrBadConfig } err := config.LoadAccessControlConfig(viperInstance) if err != nil { log.Error().Err(err).Msg("unable to unmarshal config's accessControl") - panic(err) + + return err } // defaults applyDefaultValues(config, viperInstance) // various config checks - validateConfiguration(config) + if err := validateConfiguration(config); err != nil { + return err + } + + return nil } diff --git a/pkg/cli/root_test.go b/pkg/cli/root_test.go index 866aca8b..0ddcc8f3 100644 --- a/pkg/cli/root_test.go +++ b/pkg/cli/root_test.go @@ -234,7 +234,7 @@ func TestVerify(t *testing.T) { content := []byte(`{"storage":{"rootDirectory":"/tmp/zot"}, "http":{"address":"127.0.0.1","port":"8080","realm":"zot", "auth":{"htpasswd":{"path":"test/data/htpasswd"},"failDelay":1}, - "accessControl":{"\|":{"policies":[],"defaultPolicy":[]}}}}`) + "accessControl":{"[":{"policies":[],"defaultPolicy":[]}}}}`) _, err = tmpfile.Write(content) So(err, ShouldBeNil) err = tmpfile.Close() @@ -299,16 +299,19 @@ func TestVerify(t *testing.T) { func TestLoadConfig(t *testing.T) { Convey("Test viper load config", t, func(c C) { config := config.New() - So(func() { cli.LoadConfiguration(config, "../../examples/config-policy.json") }, ShouldNotPanic) + err := cli.LoadConfiguration(config, "../../examples/config-policy.json") + So(err, ShouldBeNil) }) } func TestGC(t *testing.T) { Convey("Test GC config", t, func(c C) { config := config.New() - So(func() { cli.LoadConfiguration(config, "../../examples/config-multiple.json") }, ShouldNotPanic) + err := cli.LoadConfiguration(config, "../../examples/config-multiple.json") + So(err, ShouldBeNil) So(config.Storage.GCDelay, ShouldEqual, storage.DefaultGCDelay) - So(func() { cli.LoadConfiguration(config, "../../examples/config-gc.json") }, ShouldNotPanic) + err = cli.LoadConfiguration(config, "../../examples/config-gc.json") + So(err, ShouldBeNil) So(config.Storage.GCDelay, ShouldNotEqual, storage.DefaultGCDelay) }) @@ -330,7 +333,8 @@ func TestGC(t *testing.T) { err = ioutil.WriteFile(file.Name(), contents, 0o600) So(err, ShouldBeNil) - So(func() { cli.LoadConfiguration(config, file.Name()) }, ShouldNotPanic) + err = cli.LoadConfiguration(config, file.Name()) + So(err, ShouldBeNil) }) Convey("Negative GC delay", func() { @@ -347,7 +351,8 @@ func TestGC(t *testing.T) { err = ioutil.WriteFile(file.Name(), contents, 0o600) So(err, ShouldBeNil) - So(func() { cli.LoadConfiguration(config, file.Name()) }, ShouldPanic) + err = cli.LoadConfiguration(config, file.Name()) + So(err, ShouldNotBeNil) }) }) } @@ -547,7 +552,8 @@ func TestApplyDefaultValues(t *testing.T) { err = os.Chmod(file.Name(), 0o777) So(err, ShouldBeNil) - cli.LoadConfiguration(oldConfig, file.Name()) + err = cli.LoadConfiguration(oldConfig, file.Name()) + So(err, ShouldBeNil) configContent, err = ioutil.ReadFile(file.Name()) So(err, ShouldBeNil) @@ -563,7 +569,8 @@ func TestApplyDefaultValues(t *testing.T) { err = os.Chmod(file.Name(), 0o444) So(err, ShouldBeNil) - cli.LoadConfiguration(oldConfig, file.Name()) + err = cli.LoadConfiguration(oldConfig, file.Name()) + So(err, ShouldBeNil) configContent, err = ioutil.ReadFile(file.Name()) So(err, ShouldBeNil) diff --git a/pkg/extensions/extensions.go b/pkg/extensions/extensions.go index 9ee5a5cb..ea511e37 100644 --- a/pkg/extensions/extensions.go +++ b/pkg/extensions/extensions.go @@ -4,6 +4,7 @@ package extensions import ( + "context" goSync "sync" "time" @@ -69,14 +70,14 @@ func EnableExtensions(config *config.Config, log log.Logger, rootDir string) { } // EnableSyncExtension enables sync extension. -func EnableSyncExtension(config *config.Config, wg *goSync.WaitGroup, +func EnableSyncExtension(ctx context.Context, config *config.Config, wg *goSync.WaitGroup, storeController storage.StoreController, log log.Logger) { if config.Extensions.Sync != nil && *config.Extensions.Sync.Enable { - if err := sync.Run(*config.Extensions.Sync, storeController, wg, log); err != nil { + if err := sync.Run(ctx, *config.Extensions.Sync, storeController, wg, log); err != nil { log.Error().Err(err).Msg("Error encountered while setting up syncing") } } else { - log.Info().Msg("Sync registries config not provided, skipping sync") + log.Info().Msg("Sync registries config not provided or disabled, skipping sync") } } diff --git a/pkg/extensions/minimal.go b/pkg/extensions/minimal.go index ea94d159..02e1b9f9 100644 --- a/pkg/extensions/minimal.go +++ b/pkg/extensions/minimal.go @@ -4,6 +4,7 @@ package extensions import ( + "context" goSync "sync" "time" @@ -25,7 +26,7 @@ func EnableExtensions(config *config.Config, log log.Logger, rootDir string) { } // EnableSyncExtension ... -func EnableSyncExtension(config *config.Config, wg *goSync.WaitGroup, +func EnableSyncExtension(ctx context.Context, config *config.Config, wg *goSync.WaitGroup, storeController storage.StoreController, log log.Logger) { log.Warn().Msg("skipping enabling sync extension because given zot binary doesn't support any extensions," + "please build zot full binary for this feature") diff --git a/pkg/extensions/sync/sync.go b/pkg/extensions/sync/sync.go index 30c6bc51..87bf6126 100644 --- a/pkg/extensions/sync/sync.go +++ b/pkg/extensions/sync/sync.go @@ -184,8 +184,8 @@ func filterImagesBySemver(upstreamReferences *[]types.ImageReference, content Co } // imagesToCopyFromRepos lists all images given a registry name and its repos. -func imagesToCopyFromUpstream(registryName string, repos []string, upstreamCtx *types.SystemContext, - content Content, log log.Logger) ([]types.ImageReference, error) { +func imagesToCopyFromUpstream(ctx context.Context, registryName string, repos []string, + upstreamCtx *types.SystemContext, content Content, log log.Logger) ([]types.ImageReference, error) { var upstreamReferences []types.ImageReference for _, repoName := range repos { @@ -196,7 +196,7 @@ func imagesToCopyFromUpstream(registryName string, repos []string, upstreamCtx * return nil, err } - tags, err := getImageTags(context.Background(), upstreamCtx, repoRef) + tags, err := getImageTags(ctx, upstreamCtx, repoRef) if err != nil { log.Error().Err(err).Msgf("couldn't fetch tags for %s", repoRef) @@ -279,8 +279,9 @@ func getUpstreamContext(regCfg *RegistryConfig, credentials Credentials) *types. return upstreamCtx } -func syncRegistry(regCfg RegistryConfig, upstreamURL string, storeController storage.StoreController, - localCtx *types.SystemContext, policyCtx *signature.PolicyContext, credentials Credentials, log log.Logger) error { +func syncRegistry(ctx context.Context, regCfg RegistryConfig, upstreamURL string, + storeController storage.StoreController, localCtx *types.SystemContext, + policyCtx *signature.PolicyContext, credentials Credentials, log log.Logger) error { log.Info().Msgf("syncing registry: %s", upstreamURL) var err error @@ -306,7 +307,7 @@ func syncRegistry(regCfg RegistryConfig, upstreamURL string, storeController sto return err } - if err = retry.RetryIfNecessary(context.Background(), func() error { + if err = retry.RetryIfNecessary(ctx, func() error { catalog, err = getUpstreamCatalog(httpClient, upstreamURL, log) return err @@ -330,8 +331,8 @@ func syncRegistry(regCfg RegistryConfig, upstreamURL string, storeController sto r := repos id := contentID - if err = retry.RetryIfNecessary(context.Background(), func() error { - refs, err := imagesToCopyFromUpstream(upstreamAddr, r, upstreamCtx, regCfg.Content[id], log) + if err = retry.RetryIfNecessary(ctx, func() error { + refs, err := imagesToCopyFromUpstream(ctx, upstreamAddr, r, upstreamCtx, regCfg.Content[id], log) images = append(images, refs...) return err @@ -356,7 +357,7 @@ func syncRegistry(regCfg RegistryConfig, upstreamURL string, storeController sto imageStore := storeController.GetImageStore(repo) - canBeSkipped, err := canSkipImage(repo, tag, upstreamImageRef, imageStore, upstreamCtx, log) + canBeSkipped, err := canSkipImage(ctx, repo, tag, upstreamImageRef, imageStore, upstreamCtx, log) if err != nil { log.Error().Err(err).Msgf("couldn't check if the upstream image %s can be skipped", upstreamImageRef.DockerReference()) @@ -378,8 +379,8 @@ func syncRegistry(regCfg RegistryConfig, upstreamURL string, storeController sto log.Info().Msgf("copying image %s to %s", upstreamImageRef.DockerReference(), localCachePath) - if err = retry.RetryIfNecessary(context.Background(), func() error { - _, err = copy.Image(context.Background(), policyCtx, localImageRef, upstreamImageRef, &options) + if err = retry.RetryIfNecessary(ctx, func() error { + _, err = copy.Image(ctx, policyCtx, localImageRef, upstreamImageRef, &options) return err }, retryOptions); err != nil { @@ -397,7 +398,7 @@ func syncRegistry(regCfg RegistryConfig, upstreamURL string, storeController sto return err } - if err = retry.RetryIfNecessary(context.Background(), func() error { + if err = retry.RetryIfNecessary(ctx, func() error { err = syncSignatures(httpClient, storeController, upstreamURL, repo, tag, log) return err @@ -435,7 +436,8 @@ func getLocalContexts(log log.Logger) (*types.SystemContext, *signature.PolicyCo return localCtx, policyContext, nil } -func Run(cfg Config, storeController storage.StoreController, wtgrp *goSync.WaitGroup, logger log.Logger) error { +func Run(ctx context.Context, cfg Config, storeController storage.StoreController, + wtgrp *goSync.WaitGroup, logger log.Logger) error { var credentialsFile CredentialsFile var err error @@ -476,19 +478,18 @@ func Run(cfg Config, storeController storage.StoreController, wtgrp *goSync.Wait tlogger := log.Logger{Logger: logger.With().Caller().Timestamp().Logger()} // schedule each registry sync - go func(regCfg RegistryConfig, logger log.Logger) { - // run on intervals - for ; true; <-ticker.C { + go func(ctx context.Context, regCfg RegistryConfig, logger log.Logger) { + for { // increment reference since will be busy, so shutdown has to wait wtgrp.Add(1) for _, upstreamURL := range regCfg.URLs { upstreamAddr := StripRegistryTransport(upstreamURL) // first try syncing main registry - if err := syncRegistry(regCfg, upstreamURL, storeController, localCtx, policyCtx, + if err := syncRegistry(ctx, regCfg, upstreamURL, storeController, localCtx, policyCtx, credentialsFile[upstreamAddr], logger); err != nil { logger.Error().Err(err).Str("registry", upstreamURL). - Msg("sync exited with error, falling back to auxiliary registries") + Msg("sync exited with error, falling back to auxiliary registries if any") } else { // if success fall back to main registry break @@ -496,8 +497,18 @@ func Run(cfg Config, storeController storage.StoreController, wtgrp *goSync.Wait } // mark as done after a single sync run wtgrp.Done() + + select { + case <-ctx.Done(): + ticker.Stop() + + return + case <-ticker.C: + // run on intervals + continue + } } - }(regCfg, tlogger) + }(ctx, regCfg, tlogger) } logger.Info().Msg("finished setting up sync") diff --git a/pkg/extensions/sync/sync_internal_test.go b/pkg/extensions/sync/sync_internal_test.go index 19b10490..0aa3a084 100644 --- a/pkg/extensions/sync/sync_internal_test.go +++ b/pkg/extensions/sync/sync_internal_test.go @@ -139,9 +139,15 @@ func TestSyncInternal(t *testing.T) { CertDir: "", } - cfg := Config{Registries: []RegistryConfig{syncRegistryConfig}, CredentialsFile: "/invalid/path/to/file"} + defaultValue := true + cfg := Config{ + Registries: []RegistryConfig{syncRegistryConfig}, + Enable: &defaultValue, + CredentialsFile: "/invalid/path/to/file", + } + ctx := context.Background() - So(Run(cfg, storage.StoreController{}, new(goSync.WaitGroup), log.NewLogger("debug", "")), ShouldNotBeNil) + So(Run(ctx, cfg, storage.StoreController{}, new(goSync.WaitGroup), log.NewLogger("debug", "")), ShouldNotBeNil) _, err = getFileCredentials("/invalid/path/to/file") So(err, ShouldNotBeNil) @@ -248,10 +254,11 @@ func TestSyncInternal(t *testing.T) { repos := []string{"repo1"} upstreamCtx := &types.SystemContext{} - _, err := imagesToCopyFromUpstream("localhost:4566", repos, upstreamCtx, Content{}, log.NewLogger("debug", "")) + _, err := imagesToCopyFromUpstream(context.Background(), "localhost:4566", repos, upstreamCtx, + Content{}, log.NewLogger("debug", "")) So(err, ShouldNotBeNil) - _, err = imagesToCopyFromUpstream("docker://localhost:4566", repos, upstreamCtx, + _, err = imagesToCopyFromUpstream(context.Background(), "docker://localhost:4566", repos, upstreamCtx, Content{}, log.NewLogger("debug", "")) So(err, ShouldNotBeNil) }) @@ -302,7 +309,8 @@ func TestSyncInternal(t *testing.T) { So(err, ShouldBeNil) So(taggedRef, ShouldNotBeNil) - canBeSkipped, err := canSkipImage(testImage, testImageTag, upstreamRef, imageStore, &types.SystemContext{}, log) + canBeSkipped, err := canSkipImage(context.Background(), testImage, testImageTag, upstreamRef, + imageStore, &types.SystemContext{}, log) So(err, ShouldNotBeNil) So(canBeSkipped, ShouldBeFalse) @@ -311,7 +319,8 @@ func TestSyncInternal(t *testing.T) { panic(err) } - canBeSkipped, err = canSkipImage(testImage, testImageTag, upstreamRef, imageStore, &types.SystemContext{}, log) + canBeSkipped, err = canSkipImage(context.Background(), testImage, testImageTag, upstreamRef, + imageStore, &types.SystemContext{}, log) So(err, ShouldNotBeNil) So(canBeSkipped, ShouldBeFalse) }) diff --git a/pkg/extensions/sync/sync_test.go b/pkg/extensions/sync/sync_test.go index 9260172f..e78a87ac 100644 --- a/pkg/extensions/sync/sync_test.go +++ b/pkg/extensions/sync/sync_test.go @@ -35,6 +35,7 @@ import ( "gopkg.in/resty.v1" "zotregistry.io/zot/pkg/api" "zotregistry.io/zot/pkg/api/config" + "zotregistry.io/zot/pkg/cli" extconf "zotregistry.io/zot/pkg/extensions/config" "zotregistry.io/zot/pkg/extensions/sync" "zotregistry.io/zot/pkg/storage" @@ -642,6 +643,126 @@ func TestOnDemandPermsDenied(t *testing.T) { }) } +func TestConfigReloader(t *testing.T) { + Convey("Verify periodically sync config reloader works", t, func() { + duration, _ := time.ParseDuration("3s") + + sctlr, srcBaseURL, srcDir, _, _ := startUpstreamServer(t, false, false) + defer os.RemoveAll(srcDir) + + defer func() { + sctlr.Shutdown() + }() + + var tlsVerify bool + + syncRegistryConfig := sync.RegistryConfig{ + Content: []sync.Content{ + { + Prefix: testImage, + }, + }, + URLs: []string{srcBaseURL}, + PollInterval: duration, + TLSVerify: &tlsVerify, + CertDir: "", + OnDemand: true, + } + + defaultVal := true + syncConfig := &sync.Config{ + Enable: &defaultVal, + Registries: []sync.RegistryConfig{syncRegistryConfig}, + } + + destPort := test.GetFreePort() + destConfig := config.New() + destBaseURL := test.GetBaseURL(destPort) + + destConfig.HTTP.Port = destPort + + destDir, err := ioutil.TempDir("", "oci-dest-repo-test") + if err != nil { + panic(err) + } + + defer os.RemoveAll(destDir) + + destConfig.Storage.RootDirectory = destDir + + destConfig.Extensions = &extconf.ExtensionConfig{} + destConfig.Extensions.Search = nil + destConfig.Extensions.Sync = syncConfig + + logFile, err := ioutil.TempFile("", "zot-log*.txt") + So(err, ShouldBeNil) + + defer os.Remove(logFile.Name()) // clean up + + destConfig.Log.Output = logFile.Name() + + dctlr := api.NewController(destConfig) + + defer func() { + dctlr.Shutdown() + }() + + go func() { + // this blocks + if err := dctlr.Run(); err != nil { + return + } + }() + + // wait till ready + for { + _, err := resty.R().Get(destBaseURL) + if err == nil { + break + } + + time.Sleep(100 * time.Millisecond) + } + + content := fmt.Sprintf(`{"distSpecVersion": "0.1.0-dev", "storage": {"rootDirectory": "%s"}, + "http": {"address": "127.0.0.1", "port": "%s", "ReadOnly": false}, + "log": {"level": "debug", "output": "%s"}}`, destDir, destPort, logFile.Name()) + + cfgfile, err := ioutil.TempFile("", "zot-test*.json") + So(err, ShouldBeNil) + + defer os.Remove(cfgfile.Name()) // clean up + + _, err = cfgfile.Write([]byte(content)) + So(err, ShouldBeNil) + + hotReloader, err := cli.NewHotReloader(dctlr, cfgfile.Name()) + So(err, ShouldBeNil) + + hotReloader.Start() + + // let it sync + time.Sleep(3 * time.Second) + + // modify config + _, err = cfgfile.WriteString(" ") + So(err, ShouldBeNil) + + err = cfgfile.Close() + So(err, ShouldBeNil) + + time.Sleep(2 * time.Second) + + data, err := os.ReadFile(logFile.Name()) + t.Logf("downstream log: %s", string(data)) + So(err, ShouldBeNil) + So(string(data), ShouldContainSubstring, "reloaded params") + So(string(data), ShouldContainSubstring, "new configuration settings") + So(string(data), ShouldContainSubstring, "\"Sync\":null") + So(string(data), ShouldNotContainSubstring, "sync:") + }) +} + func TestBadTLS(t *testing.T) { Convey("Verify sync TLS feature", t, func() { updateDuration, _ := time.ParseDuration("30m") @@ -2501,7 +2622,7 @@ func TestOnDemandMultipleRetries(t *testing.T) { done := make(chan bool) go func() { /* watch .sync local cache, make sure just one .sync/subdir is populated with image - the lock from ondemand should prevent spawning multiple go routines for the same image*/ + the channel from ondemand should prevent spawning multiple go routines for the same image*/ for { time.Sleep(250 * time.Millisecond) select { diff --git a/pkg/extensions/sync/utils.go b/pkg/extensions/sync/utils.go index d4ffc803..d050ad35 100644 --- a/pkg/extensions/sync/utils.go +++ b/pkg/extensions/sync/utils.go @@ -574,7 +574,7 @@ func getLocalImageRef(imageStore storage.ImageStore, repo, tag string) (types.Im } // canSkipImage returns whether or not the image can be skipped from syncing. -func canSkipImage(repo, tag string, upstreamRef types.ImageReference, +func canSkipImage(ctx context.Context, repo, tag string, upstreamRef types.ImageReference, imageStore storage.ImageStore, upstreamCtx *types.SystemContext, log log.Logger) (bool, error) { // filter already pulled images _, localImageDigest, _, err := imageStore.GetImageManifest(repo, tag) @@ -588,7 +588,7 @@ func canSkipImage(repo, tag string, upstreamRef types.ImageReference, return false, err } - upstreamImageDigest, err := docker.GetDigest(context.Background(), upstreamCtx, upstreamRef) + upstreamImageDigest, err := docker.GetDigest(ctx, upstreamCtx, upstreamRef) if err != nil { log.Error().Err(err).Msgf("couldn't get upstream image %s manifest", upstreamRef.DockerReference())