From 98042d8dbd7a482f6d9c0763f6442fb254f985fb Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 7 Feb 2025 15:25:19 -0500
Subject: [PATCH] remove leaking logic in favor of recovery logic.

---
 cmd/cli/dns_proxy.go        | 370 ++++++++++++++----------------------
 cmd/cli/prog.go             |  35 ++--
 cmd/cli/resolvconf.go       |   2 +-
 cmd/cli/upstream_monitor.go |  13 ++
 4 files changed, 174 insertions(+), 246 deletions(-)

diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go
index 0447eef..31e8aa8 100644
--- a/cmd/cli/dns_proxy.go
+++ b/cmd/cli/dns_proxy.go
@@ -432,23 +432,6 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
 	serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
 	upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
-	upstreamMapKey := strings.Join(upstreams, "_")
-
-	leaked := false
-	if len(upstreamConfigs) > 0 {
-		p.leakingQueryMu.Lock()
-		if p.leakingQueryRunning[upstreamMapKey] || p.leakingQueryRunning["all"] {
-			upstreamConfigs = nil
-			leaked = true
-			if p.leakingQueryRunning["all"] {
-				ctrld.Log(ctx, mainLog.Load().Debug(), "all upstreams marked down for network change, leaking query to OS resolver")
-			} else {
-				ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
-			}
-		}
-		p.leakingQueryMu.Unlock()
-	}
-
 	if len(upstreamConfigs) == 0 {
 		upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
 		upstreams = []string{upstreamOS}
@@ -472,11 +455,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
 	// 4. Try remote upstream.
 	isLanOrPtrQuery := false
 	if req.ufr.matched {
-		if leaked {
-			ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v (leaked)", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
-		} else {
-			ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
-		}
+		ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
 	} else {
 		switch {
 		case isSrvLookup(req.msg):
@@ -557,13 +536,6 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
 			isNetworkErr := errNetworkError(err)
 			if isNetworkErr {
 				p.um.increaseFailureCount(upstreams[n])
-				if p.um.isDown(upstreams[n]) {
-					p.um.mu.RLock()
-					if !p.um.checking[upstreams[n]] {
-						go p.checkUpstream(upstreams[n], upstreamConfig)
-					}
-					p.um.mu.RUnlock()
-				}
 			}
 			// For timeout error (i.e: context deadline exceed), force re-bootstrapping.
 			var e net.Error
@@ -594,16 +566,6 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
 				ctrld.Log(ctx, logger, "DNS loop detected")
 				continue
 			}
-			if p.um.isDown(upstreams[n]) {
-				// never skip the OS resolver, since we usually query this resolver when we
-				// have no other upstreams to query
-				if upstreams[n] != upstreamOS {
-					logger.
- Bool("is_os_resolver", upstreams[n] == upstreamOS) - ctrld.Log(ctx, logger, "Upstream is down") - continue - } - } answer := resolve(n, upstreamConfig, req.msg) if answer == nil { if serveStaleCache && staleAnswer != nil { @@ -651,20 +613,29 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { return res } ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) - if p.leakOnUpstreamFailure() { - p.leakingQueryMu.Lock() - // get the map key as concact of upstreams - if !p.leakingQueryRunning[upstreamMapKey] { - p.leakingQueryRunning[upstreamMapKey] = true - // get a map of the failed upstreams - failedUpstreams := make(map[string]*ctrld.UpstreamConfig) - for n, upstream := range upstreamConfigs { - failedUpstreams[upstreams[n]] = upstream + + // if we have no healthy upstreams, trigger recovery flow + if p.recoverOnUpstreamFailure() { + if p.um.countHealthy(upstreams) == 0 { + p.recoveryCancelMu.Lock() + if p.recoveryCancel == nil { + var reason RecoveryReason + if upstreams[0] == upstreamOS { + reason = RecoveryReasonOSFailure + } else { + reason = RecoveryReasonRegularFailure + } + mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) + go p.handleRecovery(reason) + } else { + mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection") } - go p.performLeakingQuery(failedUpstreams, upstreamMapKey) + p.recoveryCancelMu.Unlock() + } else { + mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger") } - p.leakingQueryMu.Unlock() } + answer := new(dns.Msg) answer.SetRcode(req.msg, dns.RcodeServerFailure) res.answer = answer @@ -994,86 +965,6 @@ func (p *prog) selfUninstallCoolOfPeriod() { p.selfUninstallMu.Unlock() } -// performLeakingQuery performs necessary works to leak queries to OS resolver. -// once we store the leakingQuery flag, we are leaking queries to OS resolver -// we then start testing all the upstreams forever, waiting for success, but in parallel -func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamConfig, upstreamMapKey string) { - - mainLog.Load().Warn().Msgf("leaking queries for failed upstreams [%v] to OS resolver", failedUpstreams) - - // Signal dns watchers to stop, so changes made below won't be reverted. 
-	p.leakingQueryMu.Lock()
-	p.leakingQueryRunning[upstreamMapKey] = true
-	p.leakingQueryMu.Unlock()
-	defer func() {
-		p.leakingQueryMu.Lock()
-		p.leakingQueryRunning[upstreamMapKey] = false
-		p.leakingQueryMu.Unlock()
-		mainLog.Load().Warn().Msg("stop leaking query")
-	}()
-
-	// we only want to reset DNS when our resolver is broken
-	// this allows us to find the new OS resolver nameservers
-	// we skip the all upstream lock key to prevent duplicate calls
-	if p.um.isDown(upstreamOS) && upstreamMapKey != "all" {
-
-		mainLog.Load().Debug().Msg("OS resolver is down, reinitializing")
-		p.reinitializeOSResolver(false)
-
-	}
-
-	// Test all failed upstreams in parallel
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	// if a network change, delay upstream checks by 1s
-	// this is to ensure we actually leak queries to OS resolver
-	// We have observed some captive portals leak queries to public upstreams
-	// This can cause the captive portal on MacOS to not trigger a popup
-	if upstreamMapKey != "all" {
-		mainLog.Load().Debug().Msg("network change leaking queries, delaying upstream checks by 1s")
-		time.Sleep(1 * time.Second)
-	}
-
-	upstreamCh := make(chan string, len(failedUpstreams))
-	for name, uc := range failedUpstreams {
-		go func(name string, uc *ctrld.UpstreamConfig) {
-			for {
-				select {
-				case <-ctx.Done():
-					return
-				default:
-					// make sure this upstream is not already being checked
-					p.um.mu.RLock()
-					if p.um.checking[name] {
-						p.um.mu.RUnlock()
-						continue
-					}
-					p.um.mu.RUnlock()
-					mainLog.Load().Debug().
-						Str("upstream", name).
-						Msg("checking upstream")
-
-					p.checkUpstream(name, uc)
-					mainLog.Load().Debug().
-						Str("upstream", name).
-						Msg("upstream recovered")
-					upstreamCh <- name
-					return
-				}
-			}
-		}(name, uc)
-	}
-
-	// Wait for any upstream to recover
-	name := <-upstreamCh
-
-	mainLog.Load().Info().
-		Str("upstream", name).
-		Msg("stopping leak as upstream recovered")
-
-}
-
 // forceFetchingAPI sends signal to force syncing API config if run in cd mode,
 // and the domain == "cdUID.verify.controld.com"
 func (p *prog) forceFetchingAPI(domain string) {
@@ -1245,85 +1136,6 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
 	return answer
 }
 
-// reinitializeOSResolver reinitializes the OS resolver
-// by removing ctrld listenr from the interface, collecting the network nameservers
-// and re-initializing the OS resolver with the nameservers
-// applying listener back to the interface
-func (p *prog) reinitializeOSResolver(networkChange bool) {
-	// Cancel any existing operations.
-	p.resetCtxMu.Lock()
-	defer p.resetCtxMu.Unlock()
-
-	p.leakingQueryReset.Store(true)
-
-	mainLog.Load().Debug().Msg("attempting to reset DNS")
-	// Remove the listener immediately.
-	p.resetDNS()
-	mainLog.Load().Debug().Msg("DNS reset completed")
-
-	if networkChange {
-		// If we're already waiting on a recovery from a previous network change,
-		// cancel that wait to avoid stale recovery.
-		p.recoveryCancelMu.Lock()
-		if p.recoveryCancel != nil {
-			mainLog.Load().Debug().Msg("Cancelling previous recovery wait due to new network change")
-			p.recoveryCancel()
-			p.recoveryCancel = nil
-		}
-
-		ctx, cancel := context.WithCancel(context.Background())
-		p.recoveryCancel = cancel
-		p.recoveryCancelMu.Unlock()
-
-		// Launch a goroutine that monitors the non-OS upstreams.
-		go func() {
-			recoveredUpstream, err := p.waitForNonOSResolverRecovery(ctx)
-			if err != nil {
-				mainLog.Load().Warn().Err(err).Msg("No non-OS upstream recovered within the timeout; not re-enabling the listener")
-				return
-			}
-
-			mainLog.Load().Info().Msgf("Non-OS upstream %q recovered; initializing OS resolver and attaching DNS listener", recoveredUpstream)
-
-			// Initialize OS resolver regardless of upstream recovery.
-			mainLog.Load().Debug().Msg("initializing OS resolver")
-			ns := ctrld.InitializeOsResolver(true)
-			if len(ns) == 0 {
-				mainLog.Load().Warn().Msgf("no nameservers found, using existing OS resolver values")
-			} else {
-				mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns)
-			}
-
-			p.setDNS()
-			p.logInterfacesState()
-
-			// allow watchers to reset changes
-			p.leakingQueryReset.Store(false)
-
-			// Clear the recovery cancel func as recovery has been achieved.
-			p.recoveryCancelMu.Lock()
-			p.recoveryCancel = nil
-			p.recoveryCancelMu.Unlock()
-		}()
-
-		// Optionally flush DNS caches (if needed).
-		if err := FlushDNSCache(); err != nil {
-			mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache")
-		}
-		if runtime.GOOS == "darwin" {
-			// delay putting back the ctrld listener to allow for captive portal to trigger
-			time.Sleep(5 * time.Second)
-		}
-	} else {
-		// For non-network-change cases, immediately re-enable the listener.
-		p.setDNS()
-		p.logInterfacesState()
-
-		// allow watchers to reset changes
-		p.leakingQueryReset.Store(false)
-	}
-}
-
 // FlushDNSCache flushes the DNS cache on macOS.
 func FlushDNSCache() error {
 	// if not macOS, return
@@ -1457,7 +1269,10 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
 					ctrld.SetDefaultLocalIPv6(ip)
 				}
 				mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
-				p.reinitializeOSResolver(true)
+
+				if p.recoverOnUpstreamFailure() {
+					p.handleRecovery(RecoveryReasonNetworkChange)
+				}
 			})
 			mon.Start()
@@ -1551,53 +1366,154 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro
 	return err
 }
 
-// waitForNonOSResolverRecovery spawns a health check for each non-OS upstream
-// and returns when the first one recovers.
-func (p *prog) waitForNonOSResolverRecovery(ctx context.Context) (string, error) {
+// handleRecovery performs a unified recovery: it removes our DNS settings,
+// cancels any in-flight recovery check on a network change (while coalescing
+// duplicate upstream-failure triggers), waits for an upstream to recover on a
+// cancellable context without a timeout, and then re-applies the DNS settings.
+func (p *prog) handleRecovery(reason RecoveryReason) {
+	mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings")
+
+	// For network changes, cancel any existing recovery check because the network state has changed.
+	if reason == RecoveryReasonNetworkChange {
+		p.recoveryCancelMu.Lock()
+		if p.recoveryCancel != nil {
+			mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)")
+			p.recoveryCancel()
+			p.recoveryCancel = nil
+		}
+		p.recoveryCancelMu.Unlock()
+	} else {
+		// For upstream failures, if a recovery is already in progress, do nothing new.
+		p.recoveryCancelMu.Lock()
+		if p.recoveryCancel != nil {
+			mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
+			p.recoveryCancelMu.Unlock()
+			return
+		}
+		p.recoveryCancelMu.Unlock()
+	}
+
+	// Create a new recovery context without a fixed timeout.
+	p.recoveryCancelMu.Lock()
+	recoveryCtx, cancel := context.WithCancel(context.Background())
+	p.recoveryCancel = cancel
+	p.recoveryCancelMu.Unlock()
+
+	// Immediately remove our DNS settings from the interface.
+	// Set recoveryRunning to prevent watchdogs from putting the listener back on the interface.
+	p.recoveryRunning.Store(true)
+	p.resetDNS()
+
+	// For an OS failure, reinitialize OS resolver nameservers immediately.
+	if reason == RecoveryReasonOSFailure {
+		mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers")
+		ns := ctrld.InitializeOsResolver(true)
+		if len(ns) == 0 {
+			mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
+		} else {
+			mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
+		}
+	}
+
+	// Build upstream map based on the recovery reason.
+	upstreams := p.buildRecoveryUpstreams(reason)
+
+	// Wait indefinitely until one of the upstreams recovers.
+	recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
+	if err != nil {
+		mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed")
+		p.recoveryCancelMu.Lock()
+		p.recoveryCancel = nil
+		p.recoveryCancelMu.Unlock()
+		return
+	}
+	mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered)
+
+	// For network changes we also reinitialize the OS resolver.
+	if reason == RecoveryReasonNetworkChange {
+		ns := ctrld.InitializeOsResolver(true)
+		if len(ns) == 0 {
+			mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values")
+		} else {
+			mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
+		}
+	}
+
+	// Apply our DNS settings back and log the interface state.
+	p.setDNS()
+	p.logInterfacesState()
+
+	// Allow watchdogs to put the listener back on the interface if it's changed for any reason.
+	p.recoveryRunning.Store(false)
+
+	// Clear the recovery cancellation for a clean slate.
+	p.recoveryCancelMu.Lock()
+	p.recoveryCancel = nil
+	p.recoveryCancelMu.Unlock()
+}
+
+// waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers.
+// It returns the name of the recovered upstream, or an error if the context is canceled first.
+func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string]*ctrld.UpstreamConfig) (string, error) {
 	recoveredCh := make(chan string, 1)
 	var wg sync.WaitGroup
 
-	// Loop over your upstream configuration; skip the OS resolver.
-	for k, uc := range p.cfg.Upstream {
-		if uc.Type == ctrld.ResolverTypeOS {
-			continue
-		}
+	mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams))
 
-		upstreamName := upstreamPrefix + k
-		mainLog.Load().Debug().Msgf("Launching recovery check for upstream: %s", upstreamName)
+	for name, uc := range upstreams {
 		wg.Add(1)
 		go func(name string, uc *ctrld.UpstreamConfig) {
 			defer wg.Done()
+			mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name)
 			for {
 				select {
 				case <-ctx.Done():
-					mainLog.Load().Debug().Msgf("Context done for upstream %s; stopping recovery check", name)
+					mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name)
 					return
 				default:
+					// checkUpstreamOnce will reset any failure counters on success.
 					if err := p.checkUpstreamOnce(name, uc); err == nil {
-						mainLog.Load().Debug().Msgf("Upstream %s is healthy; signaling recovery", name)
+						mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name)
 						select {
 						case recoveredCh <- name:
+							mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name)
 						default:
+							mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered")
 						}
 						return
-					} else {
-						mainLog.Load().Debug().Msgf("Upstream %s not healthy, retrying...", name)
 					}
+					mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
 					time.Sleep(checkUpstreamBackoffSleep)
 				}
 			}
-		}(upstreamName, uc)
+		}(name, uc)
 	}
 
 	var recovered string
 	select {
 	case recovered = <-recoveredCh:
-		mainLog.Load().Debug().Msgf("Received recovered upstream: %s", recovered)
 	case <-ctx.Done():
		return "", ctx.Err()
 	}
-	wg.Wait()
 	return recovered, nil
 }
+
+// buildRecoveryUpstreams constructs the map of upstream configurations to test.
+// For OS failures we supply the manual OS resolver upstream configuration.
+// For network change or regular failure we use the upstreams defined in p.cfg (ignoring OS).
+func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.UpstreamConfig {
+	upstreams := make(map[string]*ctrld.UpstreamConfig)
+	switch reason {
+	case RecoveryReasonOSFailure:
+		upstreams[upstreamOS] = osUpstreamConfig
+	case RecoveryReasonNetworkChange, RecoveryReasonRegularFailure:
+		// Use all configured upstreams except any OS type.
+		for k, uc := range p.cfg.Upstream {
+			if uc.Type != ctrld.ResolverTypeOS {
+				upstreams[upstreamPrefix+k] = uc
+			}
+		}
+	}
+	return upstreams
+}
diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go
index 3dc9e1b..8a86bcf 100644
--- a/cmd/cli/prog.go
+++ b/cmd/cli/prog.go
@@ -48,6 +48,17 @@ const (
 	ctrldServiceName = "ctrld"
 )
 
+// RecoveryReason provides context for why we are waiting for recovery.
+// Recovery involves removing the listener IP from the interface and
+// waiting for the upstreams to work before returning.
+type RecoveryReason int
+
+const (
+	RecoveryReasonNetworkChange RecoveryReason = iota
+	RecoveryReasonRegularFailure
+	RecoveryReasonOSFailure
+)
+
 // ControlSocketName returns name for control unix socket.
 func ControlSocketName() string {
 	if isMobile() {
@@ -118,14 +129,9 @@ type prog struct {
 	loopMu sync.Mutex
 	loop   map[string]bool
 
-	leakingQueryMu      sync.Mutex
-	leakingQueryRunning map[string]bool
-	leakingQueryReset   atomic.Bool
-
-	resetCtxMu sync.Mutex
-
 	recoveryCancelMu sync.Mutex
 	recoveryCancel   context.CancelFunc
+	recoveryRunning  atomic.Bool
 
 	started       chan struct{}
 	onStartedDone chan struct{}
@@ -429,7 +435,6 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
 	}
 	p.onStartedDone = make(chan struct{})
 	p.loop = make(map[string]bool)
-	p.leakingQueryRunning = make(map[string]bool)
 	p.lanLoopGuard = newLoopGuard()
 	p.ptrLoopGuard = newLoopGuard()
 	p.cacheFlushDomainsMap = nil
@@ -779,7 +784,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
 			mainLog.Load().Debug().Msg("stop dns watchdog")
 			return
 		case <-ticker.C:
-			if p.leakingQueryReset.Load() {
+			if p.recoveryRunning.Load() {
 				return
 			}
 			if dnsChanged(iface, ns) {
@@ -980,16 +985,10 @@ func findWorkingInterface(currentIface string) string {
 	return currentIface
 }
 
-// leakOnUpstreamFailure reports whether ctrld should leak query to OS resolver when failed to connect all upstreams.
-func (p *prog) leakOnUpstreamFailure() bool {
-	if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil {
-		return *ptr
-	}
-	// Default is false on routers, since this leaking is only useful for devices that move between networks.
-	if router.Name() != "" {
-		return false
-	}
-	return true
+// recoverOnUpstreamFailure reports whether ctrld should recover from upstream failure.
+func (p *prog) recoverOnUpstreamFailure() bool {
+	// Disabled on routers, since this recovery flow is only useful for devices that move between networks.
+	return router.Name() == ""
 }
 
 func randomLocalIP() string {
diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go
index 9d37d68..0f3f731 100644
--- a/cmd/cli/resolvconf.go
+++ b/cmd/cli/resolvconf.go
@@ -67,7 +67,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
 			mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
 			return
 		case event, ok := <-watcher.Events:
-			if p.leakingQueryReset.Load() {
+			if p.recoveryRunning.Load() {
 				return
 			}
 			if !ok {
diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go
index df52a14..e42b3c1 100644
--- a/cmd/cli/upstream_monitor.go
+++ b/cmd/cli/upstream_monitor.go
@@ -145,3 +145,16 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) {
 		time.Sleep(checkUpstreamBackoffSleep)
 	}
 }
+
+// countHealthy returns the number of upstreams in the provided list that are considered healthy.
+func (um *upstreamMonitor) countHealthy(upstreams []string) int {
+	var count int
+	um.mu.RLock()
+	defer um.mu.RUnlock()
+	for _, upstream := range upstreams {
+		if !um.isDown(upstream) {
+			count++
+		}
+	}
+	return count
+}
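-- 
Reviewer note, not part of the patch (placed after the signature delimiter so
it is ignored when the patch is applied): below is a minimal, self-contained Go
sketch of the "first upstream to recover wins" pattern that the new
waitForUpstreamRecovery follows. The probe function is a stand-in for
checkUpstreamOnce, and every name in the sketch is illustrative rather than
taken from the ctrld codebase.

package main

import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"time"
)

// waitForFirst launches one prober per upstream name. Each prober loops until
// its health check succeeds or the context is canceled; the first success is
// reported through a buffered channel so the winning send never blocks, and
// later winners fall into the default branch and simply exit.
func waitForFirst(ctx context.Context, names []string, probe func(string) error) (string, error) {
	recovered := make(chan string, 1)
	for _, name := range names {
		go func(name string) {
			for {
				select {
				case <-ctx.Done():
					return
				default:
				}
				if err := probe(name); err == nil {
					select {
					case recovered <- name: // first success wins
					default: // another upstream already won
					}
					return
				}
				time.Sleep(50 * time.Millisecond) // backoff between retries
			}
		}(name)
	}
	select {
	case name := <-recovered:
		return name, nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

func main() {
	// Cap the demo at two seconds; handleRecovery instead uses a cancellable
	// context with no deadline, so it can wait indefinitely while remaining
	// preemptible by a network change.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	probe := func(name string) error { // flaky stand-in for a real health check
		if rand.Intn(4) == 0 {
			return nil
		}
		return errors.New("still down")
	}
	name, err := waitForFirst(ctx, []string{"upstream.0", "upstream.1"}, probe)
	fmt.Println(name, err)
}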