wait for healthy upstream before accepting queries on network change

Alex
2025-02-07 00:59:47 -05:00
committed by Cuong Manh Le
parent fb49cb71e3
commit 1d207379cb
2 changed files with 128 additions and 16 deletions
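In short, the change stops re-attaching the ctrld listener immediately after a network change: the listener is removed, the OS resolver is re-initialized, and the listener only comes back once a non-OS upstream answers a health probe or a 60-second recovery window expires. Below is a compact, standalone sketch of that flow; detachDNS, attachDNS and waitForRecovery are hypothetical stand-ins, not ctrld APIs, and in the actual change the wait runs in a goroutine so the caller returns immediately.

package main

// Sketch of the flow this commit introduces: stay detached until an upstream
// proves healthy, or give up when the recovery window expires.

import (
	"context"
	"fmt"
	"time"
)

func detachDNS() { fmt.Println("listener detached") }
func attachDNS() { fmt.Println("listener re-attached") }

// waitForRecovery blocks until some non-OS upstream answers a probe or ctx
// expires; in ctrld this role is played by waitForNonOSResolverRecovery.
func waitForRecovery(ctx context.Context) (string, error) {
	select {
	case <-time.After(2 * time.Second): // pretend an upstream recovered
		return "upstream.0", nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

func onNetworkChange() {
	detachDNS() // remove the listener immediately

	// Give upstreams up to 60 seconds to recover before giving up.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	upstream, err := waitForRecovery(ctx)
	if err != nil {
		fmt.Println("no upstream recovered within the timeout:", err)
		return
	}
	fmt.Printf("upstream %s recovered\n", upstream)
	attachDNS()
}

func main() {
	onNetworkChange()
}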


@@ -1250,7 +1250,7 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
// and re-initializing the OS resolver with the nameservers
// applying listener back to the interface
func (p *prog) reinitializeOSResolver(networkChange bool) {
	// Cancel any existing operations
	// Cancel any existing operations.
	p.resetCtxMu.Lock()
	defer p.resetCtxMu.Unlock()
@@ -1261,9 +1261,11 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
	}()
	mainLog.Load().Debug().Msg("attempting to reset DNS")
	// Remove the listener immediately.
	p.resetDNS()
	mainLog.Load().Debug().Msg("DNS reset completed")
	// Initialize OS resolver regardless of upstream recovery.
	mainLog.Load().Debug().Msg("initializing OS resolver")
	ns := ctrld.InitializeOsResolver(true)
	if len(ns) == 0 {
@@ -1272,18 +1274,38 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
		mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns)
	}
	// start leaking queries immediately
	if networkChange {
		// set all upstreams to failed and provide to performLeakingQuery
		failedUpstreams := make(map[string]*ctrld.UpstreamConfig)
		// Iterate over both key and upstream to ensure that we have a fallback key
		for key, upstream := range p.cfg.Upstream {
			mainLog.Load().Debug().Msgf("network change upstream checking: %v, key: %q", upstream, key)
			mapKey := upstreamPrefix + key
			failedUpstreams[mapKey] = upstream
		// If we're already waiting on a recovery from a previous network change,
		// cancel that wait to avoid stale recovery.
		p.recoveryCancelMu.Lock()
		if p.recoveryCancel != nil {
			mainLog.Load().Debug().Msg("Cancelling previous recovery wait due to new network change")
			p.recoveryCancel()
			p.recoveryCancel = nil
		}
		go p.performLeakingQuery(failedUpstreams, "all")
		// Create a new context (with a timeout) for this recovery wait.
		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
		p.recoveryCancel = cancel
		p.recoveryCancelMu.Unlock()
		// Launch a goroutine that monitors the non-OS upstreams.
		go func() {
			recoveredUpstream, err := p.waitForNonOSResolverRecovery(ctx)
			if err != nil {
				mainLog.Load().Warn().Err(err).Msg("No non-OS upstream recovered within the timeout; not re-enabling the listener")
				return
			}
			mainLog.Load().Info().Msgf("Non-OS upstream %q recovered; reattaching DNS", recoveredUpstream)
			p.setDNS()
			p.logInterfacesState()
			// Clear the recovery cancel func as recovery has been achieved.
			p.recoveryCancelMu.Lock()
			p.recoveryCancel = nil
			p.recoveryCancelMu.Unlock()
		}()
		// Optionally flush DNS caches (if needed).
		if err := FlushDNSCache(); err != nil {
			mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache")
		}
@@ -1291,13 +1313,11 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
			// delay putting back the ctrld listener to allow for captive portal to trigger
			time.Sleep(5 * time.Second)
		}
	} else {
		// For non-network-change cases, immediately re-enable the listener.
		p.setDNS()
		p.logInterfacesState()
	}
	mainLog.Load().Debug().Msg("setting DNS configuration")
	p.setDNS()
	mainLog.Load().Debug().Msg("DNS configuration set successfully")
	p.logInterfacesState()
}
// FlushDNSCache flushes the DNS cache on macOS.
@@ -1488,3 +1508,92 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool {
	}
	return true
}

// checkUpstreamOnce sends a test query to the specified upstream.
// Returns nil if the upstream responds successfully.
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
	mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream)
	resolver, err := ctrld.NewResolver(uc)
	if err != nil {
		mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
		return err
	}
	msg := new(dns.Msg)
	msg.SetQuestion(".", dns.TypeNS)
	timeout := 1000 * time.Millisecond
	if uc.Timeout > 0 {
		timeout = time.Millisecond * time.Duration(uc.Timeout)
	}
	mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout)
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	uc.ReBootstrap()
	mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)
	start := time.Now()
	_, err = resolver.Resolve(ctx, msg)
	duration := time.Since(start)
	if err != nil {
		mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration)
	} else {
		mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration)
	}
	return err
}

// waitForNonOSResolverRecovery spawns a health check for each non-OS upstream
// and returns when the first one recovers.
func (p *prog) waitForNonOSResolverRecovery(ctx context.Context) (string, error) {
	recoveredCh := make(chan string, 1)
	var wg sync.WaitGroup
	// Loop over the upstream configuration; skip the OS resolver.
	for k, uc := range p.cfg.Upstream {
		if uc.Type == ctrld.ResolverTypeOS {
			continue
		}
		upstreamName := upstreamPrefix + k
		mainLog.Load().Debug().Msgf("Launching recovery check for upstream: %s", upstreamName)
		wg.Add(1)
		go func(name string, uc *ctrld.UpstreamConfig) {
			defer wg.Done()
			for {
				select {
				case <-ctx.Done():
					mainLog.Load().Debug().Msgf("Context done for upstream %s; stopping recovery check", name)
					return
				default:
					if err := p.checkUpstreamOnce(name, uc); err == nil {
						mainLog.Load().Debug().Msgf("Upstream %s is healthy; signaling recovery", name)
						select {
						case recoveredCh <- name:
						default:
						}
						return
					} else {
						mainLog.Load().Debug().Msgf("Upstream %s not healthy, retrying...", name)
					}
					time.Sleep(checkUpstreamBackoffSleep)
				}
			}
		}(upstreamName, uc)
	}
	var recovered string
	select {
	case recovered = <-recoveredCh:
		mainLog.Load().Debug().Msgf("Received recovered upstream: %s", recovered)
	case <-ctx.Done():
		return "", ctx.Err()
	}
	wg.Wait()
	return recovered, nil
}
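checkUpstreamOnce above goes through ctrld's own resolver abstraction (ctrld.NewResolver / Resolve). For readers who want to reproduce the probe outside ctrld, a rough equivalent can be written directly against github.com/miekg/dns, the library behind the dns.Msg and dns.TypeNS types used here; the resolver address below is a placeholder, not something the commit specifies.

package main

// Rough standalone sketch of a single upstream health probe: send an NS query
// for the root zone with a short timeout and treat any answer as "healthy".

import (
	"context"
	"fmt"
	"time"

	"github.com/miekg/dns"
)

func probe(ctx context.Context, addr string) error {
	msg := new(dns.Msg)
	msg.SetQuestion(".", dns.TypeNS) // same test query the commit uses

	client := &dns.Client{Timeout: time.Second}
	resp, rtt, err := client.ExchangeContext(ctx, msg, addr)
	if err != nil {
		return fmt.Errorf("probe %s: %w", addr, err)
	}
	fmt.Printf("%s answered %s in %s\n", addr, dns.RcodeToString[resp.Rcode], rtt)
	return nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	if err := probe(ctx, "76.76.2.0:53"); err != nil { // placeholder resolver address
		fmt.Println("upstream still unhealthy:", err)
	}
}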


@@ -124,6 +124,9 @@ type prog struct {
	resetCtxMu       sync.Mutex
	recoveryCancelMu sync.Mutex
	recoveryCancel   context.CancelFunc
	started          chan struct{}
	onStartedDone    chan struct{}
	onStarted        []func()
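The recoveryCancelMu and recoveryCancel fields added here let a later network change cancel a recovery wait that is still in flight before starting its own. A small illustrative sketch of that guard pattern follows; the type and method names are made up, not ctrld's.

package main

// Sketch of the guard behind recoveryCancelMu/recoveryCancel: each network
// change cancels any in-flight recovery wait before starting a new one.

import (
	"context"
	"sync"
	"time"
)

type recoveryGuard struct {
	mu     sync.Mutex
	cancel context.CancelFunc
}

// begin cancels any previous recovery wait and returns a fresh context for
// the new one.
func (g *recoveryGuard) begin(timeout time.Duration) context.Context {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.cancel != nil {
		g.cancel() // stop the stale wait from re-attaching the listener later
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	g.cancel = cancel
	return ctx
}

// done releases the stored cancel func once recovery has been achieved.
func (g *recoveryGuard) done() {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.cancel != nil {
		g.cancel()
		g.cancel = nil
	}
}

func main() {
	var g recoveryGuard
	ctx := g.begin(60 * time.Second) // first network change
	_ = g.begin(60 * time.Second)    // second change cancels the first wait
	<-ctx.Done()                     // the first context is now cancelled
	g.done()
}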