mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
wait for healthy upstream before accepting queries on network change
This commit is contained in:
@@ -1250,7 +1250,7 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
|
||||
// and re-initializing the OS resolver with the nameservers
|
||||
// applying listener back to the interface
|
||||
func (p *prog) reinitializeOSResolver(networkChange bool) {
|
||||
// Cancel any existing operations
|
||||
// Cancel any existing operations.
|
||||
p.resetCtxMu.Lock()
|
||||
defer p.resetCtxMu.Unlock()
|
||||
|
||||
@@ -1261,9 +1261,11 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
|
||||
}()
|
||||
|
||||
mainLog.Load().Debug().Msg("attempting to reset DNS")
|
||||
// Remove the listener immediately.
|
||||
p.resetDNS()
|
||||
mainLog.Load().Debug().Msg("DNS reset completed")
|
||||
|
||||
// Initialize OS resolver regardless of upstream recovery.
|
||||
mainLog.Load().Debug().Msg("initializing OS resolver")
|
||||
ns := ctrld.InitializeOsResolver(true)
|
||||
if len(ns) == 0 {
|
||||
@@ -1272,18 +1274,38 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
|
||||
mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns)
|
||||
}
|
||||
|
||||
// start leaking queries immediately
|
||||
if networkChange {
|
||||
// set all upstreams to failed and provide to performLeakingQuery
|
||||
failedUpstreams := make(map[string]*ctrld.UpstreamConfig)
|
||||
// Iterate over both key and upstream to ensure that we have a fallback key
|
||||
for key, upstream := range p.cfg.Upstream {
|
||||
mainLog.Load().Debug().Msgf("network change upstream checking: %v, key: %q", upstream, key)
|
||||
mapKey := upstreamPrefix + key
|
||||
failedUpstreams[mapKey] = upstream
|
||||
// If we're already waiting on a recovery from a previous network change,
|
||||
// cancel that wait to avoid stale recovery.
|
||||
p.recoveryCancelMu.Lock()
|
||||
if p.recoveryCancel != nil {
|
||||
mainLog.Load().Debug().Msg("Cancelling previous recovery wait due to new network change")
|
||||
p.recoveryCancel()
|
||||
p.recoveryCancel = nil
|
||||
}
|
||||
go p.performLeakingQuery(failedUpstreams, "all")
|
||||
// Create a new context (with a timeout) for this recovery wait.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
p.recoveryCancel = cancel
|
||||
p.recoveryCancelMu.Unlock()
|
||||
|
||||
// Launch a goroutine that monitors the non-OS upstreams.
|
||||
go func() {
|
||||
recoveredUpstream, err := p.waitForNonOSResolverRecovery(ctx)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("No non-OS upstream recovered within the timeout; not re-enabling the listener")
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msgf("Non-OS upstream %q recovered; reattaching DNS", recoveredUpstream)
|
||||
p.setDNS()
|
||||
p.logInterfacesState()
|
||||
|
||||
// Clear the recovery cancel func as recovery has been achieved.
|
||||
p.recoveryCancelMu.Lock()
|
||||
p.recoveryCancel = nil
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}()
|
||||
|
||||
// Optionally flush DNS caches (if needed).
|
||||
if err := FlushDNSCache(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache")
|
||||
}
|
||||
@@ -1291,13 +1313,11 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
|
||||
// delay putting back the ctrld listener to allow for captive portal to trigger
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
} else {
|
||||
// For non-network-change cases, immediately re-enable the listener.
|
||||
p.setDNS()
|
||||
p.logInterfacesState()
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Msg("setting DNS configuration")
|
||||
p.setDNS()
|
||||
mainLog.Load().Debug().Msg("DNS configuration set successfully")
|
||||
p.logInterfacesState()
|
||||
|
||||
}
|
||||
|
||||
// FlushDNSCache flushes the DNS cache on macOS.
|
||||
@@ -1488,3 +1508,92 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkUpstreamOnce sends a test query to the specified upstream.
|
||||
// Returns nil if the upstream responds successfully.
|
||||
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
|
||||
mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream)
|
||||
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
|
||||
return err
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
|
||||
timeout := 1000 * time.Millisecond
|
||||
if uc.Timeout > 0 {
|
||||
timeout = time.Millisecond * time.Duration(uc.Timeout)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
uc.ReBootstrap()
|
||||
mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)
|
||||
|
||||
start := time.Now()
|
||||
_, err = resolver.Resolve(ctx, msg)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// waitForNonOSResolverRecovery spawns a health check for each non-OS upstream
|
||||
// and returns when the first one recovers.
|
||||
func (p *prog) waitForNonOSResolverRecovery(ctx context.Context) (string, error) {
|
||||
recoveredCh := make(chan string, 1)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Loop over your upstream configuration; skip the OS resolver.
|
||||
for k, uc := range p.cfg.Upstream {
|
||||
if uc.Type == ctrld.ResolverTypeOS {
|
||||
continue
|
||||
}
|
||||
|
||||
upstreamName := upstreamPrefix + k
|
||||
mainLog.Load().Debug().Msgf("Launching recovery check for upstream: %s", upstreamName)
|
||||
wg.Add(1)
|
||||
go func(name string, uc *ctrld.UpstreamConfig) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mainLog.Load().Debug().Msgf("Context done for upstream %s; stopping recovery check", name)
|
||||
return
|
||||
default:
|
||||
if err := p.checkUpstreamOnce(name, uc); err == nil {
|
||||
mainLog.Load().Debug().Msgf("Upstream %s is healthy; signaling recovery", name)
|
||||
select {
|
||||
case recoveredCh <- name:
|
||||
default:
|
||||
}
|
||||
return
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Upstream %s not healthy, retrying...", name)
|
||||
}
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
}
|
||||
}
|
||||
}(upstreamName, uc)
|
||||
}
|
||||
|
||||
var recovered string
|
||||
select {
|
||||
case recovered = <-recoveredCh:
|
||||
mainLog.Load().Debug().Msgf("Received recovered upstream: %s", recovered)
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return recovered, nil
|
||||
}
|
||||
|
||||
@@ -124,6 +124,9 @@ type prog struct {
|
||||
|
||||
resetCtxMu sync.Mutex
|
||||
|
||||
recoveryCancelMu sync.Mutex
|
||||
recoveryCancel context.CancelFunc
|
||||
|
||||
started chan struct{}
|
||||
onStartedDone chan struct{}
|
||||
onStarted []func()
|
||||
|
||||
Reference in New Issue
Block a user