diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 1ec2c72..c45e56f 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1139,6 +1139,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { stopCh: stopCh, reloadCh: make(chan struct{}), reloadDoneCh: make(chan struct{}), + apiReloadCh: make(chan *ctrld.Config), cfg: &cfg, appCallback: appCallback, } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 0362c02..37647e4 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -28,6 +28,7 @@ import ( "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/clientinfo" + "github.com/Control-D-Inc/ctrld/internal/controld" "github.com/Control-D-Inc/ctrld/internal/dnscache" "github.com/Control-D-Inc/ctrld/internal/router" ) @@ -71,6 +72,7 @@ type prog struct { stopCh chan struct{} reloadCh chan struct{} // For Windows. reloadDoneCh chan struct{} + apiReloadCh chan *ctrld.Config logConn net.Conn cs *controlServer csSetDnsDone chan struct{} @@ -128,11 +130,15 @@ func (p *prog) runWait() { p.run(reload, reloadCh) reload = true }() + + var newCfg *ctrld.Config select { case sig := <-reloadSigCh: logger.Notice().Msgf("got signal: %s, reloading...", sig.String()) case <-p.reloadCh: logger.Notice().Msg("reloading...") + case apiCfg := <-p.apiReloadCh: + newCfg = apiCfg case <-p.stopCh: close(reloadCh) return @@ -142,28 +148,31 @@ func (p *prog) runWait() { close(reloadCh) <-done } - newCfg := &ctrld.Config{} - v := viper.NewWithOptions(viper.KeyDelimiter("::")) - ctrld.InitConfig(v, "ctrld") - if configPath != "" { - v.SetConfigFile(configPath) - } - if err := v.ReadInConfig(); err != nil { - logger.Err(err).Msg("could not read new config") - waitOldRunDone() - continue - } - if err := v.Unmarshal(&newCfg); err != nil { - logger.Err(err).Msg("could not unmarshal new config") - waitOldRunDone() - continue - } - if cdUID != "" { - if err := processCDFlags(newCfg); err != nil { - logger.Err(err).Msg("could not fetch ControlD config") + + if newCfg == nil { + newCfg = &ctrld.Config{} + v := viper.NewWithOptions(viper.KeyDelimiter("::")) + ctrld.InitConfig(v, "ctrld") + if configPath != "" { + v.SetConfigFile(configPath) + } + if err := v.ReadInConfig(); err != nil { + logger.Err(err).Msg("could not read new config") waitOldRunDone() continue } + if err := v.Unmarshal(&newCfg); err != nil { + logger.Err(err).Msg("could not unmarshal new config") + waitOldRunDone() + continue + } + if cdUID != "" { + if err := processCDFlags(newCfg); err != nil { + logger.Err(err).Msg("could not fetch ControlD config") + waitOldRunDone() + continue + } + } } waitOldRunDone() @@ -230,6 +239,59 @@ func (p *prog) postRun() { } } +// apiConfigReload calls API to check for latest config update then reload ctrld if necessary. +func (p *prog) apiConfigReload() { + if cdUID == "" { + return + } + + secs := 3600 + if p.cfg.Service.RefreshTime != nil && *p.cfg.Service.RefreshTime > 0 { + secs = *p.cfg.Service.RefreshTime + } + + ticker := time.NewTicker(time.Duration(secs) * time.Second) + defer ticker.Stop() + + logger := mainLog.Load().With().Str("mode", "api-reload").Logger() + logger.Debug().Msg("starting custom config reload timer") + lastUpdated := time.Now().Unix() + for { + select { + case <-ticker.C: + resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + selfUninstall(err, p, logger) + if err != nil { + logger.Warn().Err(err).Msg("could not fetch resolver config") + continue + } + + if resolverConfig.Ctrld.CustomConfig == "" { + continue + } + + if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated { + lastUpdated = time.Now().Unix() + cfg := &ctrld.Config{} + if err := validateCdRemoteConfig(resolverConfig, cfg); err != nil { + logger.Warn().Err(err).Msg("skipping invalid custom config") + if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil { + logger.Error().Err(err).Msg("could not mark custom last update failed") + } + break + } + setListenerDefaultValue(cfg) + logger.Debug().Msg("custom config changes detected, reloading...") + p.apiReloadCh <- cfg + } else { + logger.Debug().Msg("custom config does not change") + } + case <-p.stopCh: + return + } + } +} + func (p *prog) setupUpstream(cfg *ctrld.Config) { localUpstreams := make([]string, 0, len(cfg.Upstream)) ptrNameservers := make([]string, 0, len(cfg.Upstream)) @@ -420,6 +482,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if p.logConn != nil { _ = p.logConn.Close() } + go p.apiConfigReload() p.postRun() } wg.Wait() diff --git a/config.go b/config.go index 5c531cd..67420e5 100644 --- a/config.go +++ b/config.go @@ -209,6 +209,7 @@ type ServiceConfig struct { MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"` DnsWatchdogEnabled *bool `mapstructure:"dns_watchdog_enabled" toml:"dns_watchdog_enabled,omitempty"` DnsWatchdogInvterval *time.Duration `mapstructure:"dns_watchdog_interval" toml:"dns_watchdog_interval,omitempty"` + RefreshTime *int `mapstructure:"refresh_time" toml:"refresh_time,omitempty"` Daemon bool `mapstructure:"-" toml:"-"` AllocateIP bool `mapstructure:"-" toml:"-"` } diff --git a/docs/config.md b/docs/config.md index 1ae4978..ab5cf73 100644 --- a/docs/config.md +++ b/docs/config.md @@ -273,6 +273,14 @@ If the time duration is non-positive, default value will be used. - Required: no - Default: 20s +### refresh_time +Time in seconds between each iteration that reloads custom config if changed. + +The value must be a positive number, any invalid value will be ignored and default value will be used. +- Type: number +- Required: no +- Default: 3600 + ## Upstream The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to. diff --git a/internal/controld/config.go b/internal/controld/config.go index d2b564a..01e114b 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -33,7 +33,8 @@ const ( type ResolverConfig struct { DOH string `json:"doh"` Ctrld struct { - CustomConfig string `json:"custom_config"` + CustomConfig string `json:"custom_config"` + CustomLastUpdate int64 `json:"custom_last_update"` } `json:"ctrld"` Exclude []string `json:"exclude"` UID string `json:"uid"` @@ -76,17 +77,28 @@ func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, e req.ClientID = clientID } body, _ := json.Marshal(req) - return postUtilityAPI(version, cdDev, bytes.NewReader(body)) + return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) } // FetchResolverUID fetch resolver uid from provision token. func FetchResolverUID(pt, version string, cdDev bool) (*ResolverConfig, error) { hostname, _ := os.Hostname() body, _ := json.Marshal(utilityOrgRequest{ProvToken: pt, Hostname: hostname}) - return postUtilityAPI(version, cdDev, bytes.NewReader(body)) + return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) } -func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig, error) { +// UpdateCustomLastFailed calls API to mark custom config is bad. +func UpdateCustomLastFailed(rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) { + uid, clientID := ParseRawUID(rawUID) + req := utilityRequest{UID: uid} + if clientID != "" { + req.ClientID = clientID + } + body, _ := json.Marshal(req) + return postUtilityAPI(version, cdDev, true, bytes.NewReader(body)) +} + +func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) { apiUrl := resolverDataURLCom if cdDev { apiUrl = resolverDataURLDev @@ -98,6 +110,9 @@ func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig q := req.URL.Query() q.Set("platform", "ctrld") q.Set("version", version) + if lastUpdatedFailed { + q.Set("custom_last_failed", "1") + } req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/json") transport := http.DefaultTransport.(*http.Transport).Clone()