From c233ad9b1b57a959068817e9a7400d9dbf778e32 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 29 Jul 2024 23:05:49 +0700 Subject: [PATCH] cmd/cli: write new config file on reload --- cmd/cli/cli.go | 8 ++++---- cmd/cli/cli_test.go | 2 +- cmd/cli/control_server.go | 2 +- cmd/cli/dns_proxy.go | 3 ++- cmd/cli/metrics.go | 2 +- cmd/cli/prog.go | 12 ++++++++++-- cmd/cli/prometheus.go | 2 +- 7 files changed, 20 insertions(+), 11 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index cebf054..d7b654d 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1183,7 +1183,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } if updated { - if err := writeConfigFile(); err != nil { + if err := writeConfigFile(&cfg); err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to write config file") } else { mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile) @@ -1277,7 +1277,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } } -func writeConfigFile() error { +func writeConfigFile(cfg *ctrld.Config) error { if cfu := v.ConfigFileUsed(); cfu != "" { defaultConfigFile = cfu } else if configPath != "" { @@ -1330,7 +1330,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { } nop := zerolog.Nop() _, _ = tryUpdateListenerConfig(&cfg, &nop, true) - if err := writeConfigFile(); err != nil { + if err := writeConfigFile(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) } else { fp, err := filepath.Abs(defaultConfigFile) @@ -2391,7 +2391,7 @@ func doGenerateNextDNSConfig(uid string) error { mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile) generateNextDNSConfig(uid) updateListenerConfig(&cfg) - return writeConfigFile() + return writeConfigFile(&cfg) } func noticeWritingControlDConfig() error { diff --git a/cmd/cli/cli_test.go b/cmd/cli/cli_test.go index fcede32..eae2673 100644 --- a/cmd/cli/cli_test.go +++ b/cmd/cli/cli_test.go @@ -16,7 +16,7 @@ func Test_writeConfigFile(t *testing.T) { _, err := os.Stat(configPath) assert.True(t, os.IsNotExist(err)) - assert.NoError(t, writeConfigFile()) + assert.NoError(t, writeConfigFile(&cfg)) _, err = os.Stat(configPath) require.NoError(t, err) diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 5fe2cc3..f69c301 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -73,7 +73,7 @@ func (p *prog) registerControlServerHandler() { sort.Slice(clients, func(i, j int) bool { return clients[i].IP.Less(clients[j].IP) }) - if p.cfg.Service.MetricsQueryStats { + if p.metricsQueryStats.Load() { for _, client := range clients { client.IncludeQueryCount = true dm := &dto.Metric{} diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 9519468..33e5ebb 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -93,6 +93,7 @@ func (p *prog) serveDNS(listenerNum string) error { _ = w.WriteMsg(answer) return } + listenerConfig := p.cfg.Listener[listenerNum] reqId := requestID() ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { @@ -847,7 +848,7 @@ func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) { // - Number of refused queries seen so far equals to selfUninstallMaxQueries. // - The cdUID is deleted. func (p *prog) doSelfUninstall(answer *dns.Msg) { - if !p.canSelfUninstall || answer == nil || answer.Rcode != dns.RcodeRefused { + if !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused { return } diff --git a/cmd/cli/metrics.go b/cmd/cli/metrics.go index ee64975..565cdcc 100644 --- a/cmd/cli/metrics.go +++ b/cmd/cli/metrics.go @@ -107,7 +107,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) { reg := prometheus.NewRegistry() // Register queries count stats if enabled. - if cfg.Service.MetricsQueryStats { + if p.metricsQueryStats.Load() { reg.MustRegister(statsQueriesCount) reg.MustRegister(statsClientQueriesCount) } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 00b8c0d..0362c02 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -17,6 +17,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -88,10 +89,11 @@ type prog struct { router router.Router ptrLoopGuard *loopGuard lanLoopGuard *loopGuard + metricsQueryStats atomic.Bool selfUninstallMu sync.Mutex refusedQueryCount int - canSelfUninstall bool + canSelfUninstall atomic.Bool checkingSelfUninstall bool loopMu sync.Mutex @@ -187,6 +189,10 @@ func (p *prog) runWait() { continue } + if err := writeConfigFile(newCfg); err != nil { + logger.Err(err).Msg("could not write new config") + } + // This needs to be done here, otherwise, the DNS handler may observe an invalid // upstream config because its initialization function have not been called yet. mainLog.Load().Debug().Msg("setup upstream with new config") @@ -197,6 +203,7 @@ func (p *prog) runWait() { p.mu.Unlock() logger.Notice().Msg("reloading config successfully") + select { case p.reloadDoneCh <- struct{}{}: default: @@ -249,7 +256,7 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) { } // Self-uninstallation is ok If there is only 1 ControlD upstream, and no remote config. if len(cfg.Upstream) == 1 && isControlDUpstream { - p.canSelfUninstall = true + p.canSelfUninstall.Store(true) } p.localUpstreams = localUpstreams p.ptrNameservers = ptrNameservers @@ -286,6 +293,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.lanLoopGuard = newLoopGuard() p.ptrLoopGuard = newLoopGuard() p.cacheFlushDomainsMap = nil + p.metricsQueryStats.Store(p.cfg.Service.MetricsQueryStats) if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { diff --git a/cmd/cli/prometheus.go b/cmd/cli/prometheus.go index fc2fc5d..9082a58 100644 --- a/cmd/cli/prometheus.go +++ b/cmd/cli/prometheus.go @@ -51,7 +51,7 @@ var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{ // WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled. func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) { - if p.cfg.Service.MetricsQueryStats { + if p.metricsQueryStats.Load() { c.WithLabelValues(lvs...).Inc() } }