From 1d207379cb9eca98940ed8f24bf58ff1cb145ad7 Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 7 Feb 2025 00:59:47 -0500
Subject: [PATCH] wait for healthy upstream before accepting queries on
 network change

---
 cmd/cli/dns_proxy.go | 144 ++++++++++++++++++++++++++++++++++++++-----
 cmd/cli/prog.go      |   3 +
 2 files changed, 130 insertions(+), 17 deletions(-)

diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go
index f7bbe6e..cd5fb60 100644
--- a/cmd/cli/dns_proxy.go
+++ b/cmd/cli/dns_proxy.go
@@ -1250,7 +1250,7 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
 // and re-initializing the OS resolver with the nameservers
 // applying listener back to the interface
 func (p *prog) reinitializeOSResolver(networkChange bool) {
-	// Cancel any existing operations
+	// Cancel any existing operations.
 	p.resetCtxMu.Lock()
 	defer p.resetCtxMu.Unlock()
@@ -1261,9 +1261,11 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
 	}()
 
 	mainLog.Load().Debug().Msg("attempting to reset DNS")
+	// Remove the listener immediately.
 	p.resetDNS()
 	mainLog.Load().Debug().Msg("DNS reset completed")
 
+	// Initialize OS resolver regardless of upstream recovery.
 	mainLog.Load().Debug().Msg("initializing OS resolver")
 	ns := ctrld.InitializeOsResolver(true)
 	if len(ns) == 0 {
@@ -1272,18 +1274,41 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
 		mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns)
 	}
 
-	// start leaking queries immediately
 	if networkChange {
-		// set all upstreams to failed and provide to performLeakingQuery
-		failedUpstreams := make(map[string]*ctrld.UpstreamConfig)
-		// Iterate over both key and upstream to ensure that we have a fallback key
-		for key, upstream := range p.cfg.Upstream {
-			mainLog.Load().Debug().Msgf("network change upstream checking: %v, key: %q", upstream, key)
-			mapKey := upstreamPrefix + key
-			failedUpstreams[mapKey] = upstream
+		// If we're already waiting on a recovery from a previous network change,
+		// cancel that wait so we don't act on stale results.
+		p.recoveryCancelMu.Lock()
+		if p.recoveryCancel != nil {
+			mainLog.Load().Debug().Msg("cancelling previous recovery wait due to new network change")
+			p.recoveryCancel()
+			p.recoveryCancel = nil
 		}
-		go p.performLeakingQuery(failedUpstreams, "all")
+		// Create a new context (with a timeout) for this recovery wait.
+		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+		p.recoveryCancel = cancel
+		p.recoveryCancelMu.Unlock()
 
+		// Launch a goroutine that waits for a non-OS upstream to recover.
+		go func() {
+			// Always release this recovery context so the per-upstream
+			// health checks stop once we return.
+			defer cancel()
+			recoveredUpstream, err := p.waitForNonOSResolverRecovery(ctx)
+			if err != nil {
+				mainLog.Load().Warn().Err(err).Msg("no non-OS upstream recovered within the timeout; not re-enabling the listener")
+				return
+			}
+			mainLog.Load().Info().Msgf("non-OS upstream %q recovered; reattaching DNS", recoveredUpstream)
+			p.setDNS()
+			p.logInterfacesState()
+
+			// Clear the recovery cancel func now that recovery is done.
+			p.recoveryCancelMu.Lock()
+			p.recoveryCancel = nil
+			p.recoveryCancelMu.Unlock()
+		}()
+
+		// Flush the DNS cache so clients pick up the resolver change.
 		if err := FlushDNSCache(); err != nil {
 			mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache")
 		}
@@ -1291,13 +1316,11 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
 		// delay putting back the ctrld listener to allow for captive portal to trigger
 		time.Sleep(5 * time.Second)
-	}
-
-	mainLog.Load().Debug().Msg("setting DNS configuration")
-	p.setDNS()
-	mainLog.Load().Debug().Msg("DNS configuration set successfully")
-	p.logInterfacesState()
-
+	} else {
+		// For non-network-change cases, immediately re-enable the listener.
+		p.setDNS()
+		p.logInterfacesState()
+	}
 }
 
 // FlushDNSCache flushes the DNS cache on macOS.
 func FlushDNSCache() error {
@@ -1488,3 +1511,90 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool {
 	}
 	return true
 }
+
+// checkUpstreamOnce sends a single test query to the given upstream.
+// It returns nil if the upstream responds successfully.
+func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
+	mainLog.Load().Debug().Msgf("starting check for upstream: %s", upstream)
+
+	resolver, err := ctrld.NewResolver(uc)
+	if err != nil {
+		mainLog.Load().Error().Err(err).Msgf("failed to create resolver for upstream %s", upstream)
+		return err
+	}
+
+	msg := new(dns.Msg)
+	msg.SetQuestion(".", dns.TypeNS)
+
+	timeout := 1000 * time.Millisecond
+	if uc.Timeout > 0 {
+		timeout = time.Millisecond * time.Duration(uc.Timeout)
+	}
+	mainLog.Load().Debug().Msgf("timeout for upstream %s: %s", upstream, timeout)
+
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	mainLog.Load().Debug().Msgf("re-bootstrapping resolver for upstream: %s", upstream)
+	uc.ReBootstrap()
+
+	start := time.Now()
+	_, err = resolver.Resolve(ctx, msg)
+	duration := time.Since(start)
+
+	if err != nil {
+		mainLog.Load().Error().Err(err).Msgf("upstream %s check failed after %v", upstream, duration)
+	} else {
+		mainLog.Load().Debug().Msgf("upstream %s responded successfully in %v", upstream, duration)
+	}
+	return err
+}
+
+// waitForNonOSResolverRecovery spawns a health check for each non-OS upstream
+// and returns as soon as the first one recovers. The per-upstream checkers
+// exit when their upstream recovers or when the caller cancels ctx.
+func (p *prog) waitForNonOSResolverRecovery(ctx context.Context) (string, error) {
+	recoveredCh := make(chan string, 1)
+
+	// Loop over the configured upstreams; skip the OS resolver.
+	for k, uc := range p.cfg.Upstream {
+		if uc.Type == ctrld.ResolverTypeOS {
+			continue
+		}
+
+		upstreamName := upstreamPrefix + k
+		mainLog.Load().Debug().Msgf("launching recovery check for upstream: %s", upstreamName)
+		go func(name string, uc *ctrld.UpstreamConfig) {
+			for {
+				select {
+				case <-ctx.Done():
+					mainLog.Load().Debug().Msgf("context done for upstream %s; stopping recovery check", name)
+					return
+				default:
+				}
+				if err := p.checkUpstreamOnce(name, uc); err == nil {
+					mainLog.Load().Debug().Msgf("upstream %s is healthy; signaling recovery", name)
+					// Non-blocking send: only the first recovery matters.
+					select {
+					case recoveredCh <- name:
+					default:
+					}
+					return
+				}
+				mainLog.Load().Debug().Msgf("upstream %s not healthy, retrying...", name)
+				time.Sleep(checkUpstreamBackoffSleep)
+			}
+		}(upstreamName, uc)
+	}
+
+	// Return on the first recovered upstream, or when ctx expires. Do not
+	// wait for the remaining checkers here: they stop once the caller
+	// cancels ctx, and blocking on them would delay reattaching DNS.
+	select {
+	case recovered := <-recoveredCh:
+		mainLog.Load().Debug().Msgf("received recovered upstream: %s", recovered)
+		return recovered, nil
+	case <-ctx.Done():
+		return "", ctx.Err()
+	}
+}
diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go
index d7a9a95..3dc9e1b 100644
--- a/cmd/cli/prog.go
+++ b/cmd/cli/prog.go
@@ -124,6 +124,9 @@ type prog struct {
 	resetCtxMu sync.Mutex
 
+	recoveryCancelMu sync.Mutex
+	recoveryCancel   context.CancelFunc
+
 	started       chan struct{}
 	onStartedDone chan struct{}
 	onStarted     []func()
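
Note: the recovery wait above boils down to a "first healthy upstream wins"
pattern: one prober goroutine per upstream, a size-one buffered channel with
a non-blocking send, and a context deadline as the overall timeout. Below is
a minimal, self-contained sketch of that pattern; the probe function and the
upstream names are illustrative placeholders, not ctrld APIs:

    package main

    import (
        "context"
        "errors"
        "fmt"
        "math/rand"
        "time"
    )

    // probe stands in for a real upstream health check (e.g. a DNS query
    // with a short timeout). Placeholder only.
    func probe(upstream string) error {
        if rand.Intn(4) == 0 {
            return nil
        }
        return errors.New("unhealthy")
    }

    // waitForFirstHealthy returns the first upstream whose probe succeeds,
    // or ctx.Err() if none succeeds before the deadline.
    func waitForFirstHealthy(ctx context.Context, upstreams []string) (string, error) {
        recovered := make(chan string, 1)
        for _, u := range upstreams {
            go func(u string) {
                for {
                    if probe(u) == nil {
                        select {
                        case recovered <- u: // first success wins
                        default: // another prober already won
                        }
                        return
                    }
                    select {
                    case <-ctx.Done(): // overall deadline hit; stop probing
                        return
                    case <-time.After(500 * time.Millisecond): // retry backoff
                    }
                }
            }(u)
        }
        select {
        case u := <-recovered:
            return u, nil
        case <-ctx.Done():
            return "", ctx.Err()
        }
    }

    func main() {
        ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
        defer cancel()
        fmt.Println(waitForFirstHealthy(ctx, []string{"dns.0", "dns.1"}))
    }

The size-one buffer plus non-blocking send means no prober ever blocks on the
channel, so every goroutine exits on its own once the context is cancelled.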