diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index aabd3cc..007f45e 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1141,6 +1141,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { reloadDoneCh: make(chan struct{}), dnsWatcherStopCh: make(chan struct{}), apiReloadCh: make(chan *ctrld.Config), + apiForceReloadCh: make(chan struct{}), cfg: &cfg, appCallback: appCallback, } diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 0bf85f2..b9eb8f5 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -151,6 +151,7 @@ func (p *prog) serveDNS(listenerNum string) error { ufr: ur, }) go p.doSelfUninstall(pr.answer) + answer = pr.answer rtt := time.Since(t) ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) @@ -168,6 +169,7 @@ func (p *prog) serveDNS(listenerNum string) error { go func() { p.WithLabelValuesInc(statsQueriesCount, labelValues...) p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...) + p.forceFetchingAPI(domain) }() if err := w.WriteMsg(answer); err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client") @@ -926,6 +928,41 @@ func (p *prog) performCaptivePortalDetection() { mainLog.Load().Warn().Msg("captive portal login finished, stop leaking query") } +// forceFetchingAPI sends signal to force syncing API config if run in cd mode, +// and the domain == "cdUID.verify.controld.com" +func (p *prog) forceFetchingAPI(domain string) { + if cdUID == "" { + return + } + resolverID, parent, _ := strings.Cut(domain, ".") + if resolverID != cdUID { + return + } + switch { + case cdDev && parent == "verify.controld.dev": + // match ControlD dev + case parent == "verify.controld.com": + // match ControlD + default: + return + } + _ = p.apiForceReloadGroup.DoChan("force_sync_api", func() (interface{}, error) { + p.apiForceReloadCh <- struct{}{} + // Wait here to prevent abusing API if we are flooded. + time.Sleep(timeDurationOrDefault(p.cfg.Service.ForceRefetchWaitTime, 30) * time.Second) + return nil, nil + }) +} + +// timeDurationOrDefault returns time duration value from n if not nil. +// Otherwise, it returns time duration value defaultN. +func timeDurationOrDefault(n *int, defaultN int) time.Duration { + if n != nil && *n > 0 { + return time.Duration(*n) + } + return time.Duration(defaultN) +} + // queryFromSelf reports whether the input IP is from device running ctrld. func queryFromSelf(ip string) bool { netIP := netip.MustParseAddr(ip) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 34f050d..781edd1 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -21,11 +21,11 @@ import ( "syscall" "time" - "tailscale.com/net/netmon" - "github.com/kardianos/service" "github.com/rs/zerolog" "github.com/spf13/viper" + "golang.org/x/sync/singleflight" + "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld" @@ -69,19 +69,21 @@ var svcConfig = &service.Config{ var useSystemdResolved = false type prog struct { - mu sync.Mutex - waitCh chan struct{} - stopCh chan struct{} - reloadCh chan struct{} // For Windows. - reloadDoneCh chan struct{} - apiReloadCh chan *ctrld.Config - logConn net.Conn - cs *controlServer - csSetDnsDone chan struct{} - csSetDnsOk bool - dnsWatchDogOnce sync.Once - dnsWg sync.WaitGroup - dnsWatcherStopCh chan struct{} + mu sync.Mutex + waitCh chan struct{} + stopCh chan struct{} + reloadCh chan struct{} // For Windows. + reloadDoneCh chan struct{} + apiReloadCh chan *ctrld.Config + apiForceReloadCh chan struct{} + apiForceReloadGroup singleflight.Group + logConn net.Conn + cs *controlServer + csSetDnsDone chan struct{} + csSetDnsOk bool + dnsWatchDogOnce sync.Once + dnsWg sync.WaitGroup + dnsWatcherStopCh chan struct{} cfg *ctrld.Config localUpstreams []string @@ -255,47 +257,48 @@ func (p *prog) apiConfigReload() { return } - secs := 3600 - if p.cfg.Service.RefetchTime != nil && *p.cfg.Service.RefetchTime > 0 { - secs = *p.cfg.Service.RefetchTime - } - - ticker := time.NewTicker(time.Duration(secs) * time.Second) + ticker := time.NewTicker(timeDurationOrDefault(p.cfg.Service.RefetchTime, 3600) * time.Second) defer ticker.Stop() logger := mainLog.Load().With().Str("mode", "api-reload").Logger() logger.Debug().Msg("starting custom config reload timer") lastUpdated := time.Now().Unix() + + doReloadApiConfig := func(forced bool, logger zerolog.Logger) { + resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + selfUninstallCheck(err, p, logger) + if err != nil { + logger.Warn().Err(err).Msg("could not fetch resolver config") + return + } + + if resolverConfig.Ctrld.CustomConfig == "" { + return + } + + if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated || forced { + lastUpdated = time.Now().Unix() + cfg := &ctrld.Config{} + if err := validateCdRemoteConfig(resolverConfig, cfg); err != nil { + logger.Warn().Err(err).Msg("skipping invalid custom config") + if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil { + logger.Error().Err(err).Msg("could not mark custom last update failed") + } + return + } + setListenerDefaultValue(cfg) + logger.Debug().Msg("custom config changes detected, reloading...") + p.apiReloadCh <- cfg + } else { + logger.Debug().Msg("custom config does not change") + } + } for { select { + case <-p.apiForceReloadCh: + doReloadApiConfig(true, logger.With().Bool("forced", true).Logger()) case <-ticker.C: - resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) - selfUninstallCheck(err, p, logger) - if err != nil { - logger.Warn().Err(err).Msg("could not fetch resolver config") - continue - } - - if resolverConfig.Ctrld.CustomConfig == "" { - continue - } - - if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated { - lastUpdated = time.Now().Unix() - cfg := &ctrld.Config{} - if err := validateCdRemoteConfig(resolverConfig, cfg); err != nil { - logger.Warn().Err(err).Msg("skipping invalid custom config") - if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil { - logger.Error().Err(err).Msg("could not mark custom last update failed") - } - break - } - setListenerDefaultValue(cfg) - logger.Debug().Msg("custom config changes detected, reloading...") - p.apiReloadCh <- cfg - } else { - logger.Debug().Msg("custom config does not change") - } + doReloadApiConfig(false, logger) case <-p.stopCh: return } diff --git a/config.go b/config.go index 86ca4b7..d20c695 100644 --- a/config.go +++ b/config.go @@ -217,6 +217,7 @@ type ServiceConfig struct { DnsWatchdogEnabled *bool `mapstructure:"dns_watchdog_enabled" toml:"dns_watchdog_enabled,omitempty"` DnsWatchdogInvterval *time.Duration `mapstructure:"dns_watchdog_interval" toml:"dns_watchdog_interval,omitempty"` RefetchTime *int `mapstructure:"refetch_time" toml:"refetch_time,omitempty"` + ForceRefetchWaitTime *int `mapstructure:"force_refetch_wait_time" toml:"force_refetch_wait_time,omitempty"` Daemon bool `mapstructure:"-" toml:"-"` AllocateIP bool `mapstructure:"-" toml:"-"` }