remove leaking logic in favor of recovery logic.

Alex
2025-02-07 15:25:19 -05:00
committed by Cuong Manh Le
parent af4b826b68
commit 98042d8dbd
4 changed files with 174 additions and 246 deletions


@@ -432,23 +432,6 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
upstreamMapKey := strings.Join(upstreams, "_")
leaked := false
if len(upstreamConfigs) > 0 {
p.leakingQueryMu.Lock()
if p.leakingQueryRunning[upstreamMapKey] || p.leakingQueryRunning["all"] {
upstreamConfigs = nil
leaked = true
if p.leakingQueryRunning["all"] {
ctrld.Log(ctx, mainLog.Load().Debug(), "all upstreams marked down for network change, leaking query to OS resolver")
} else {
ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
}
}
p.leakingQueryMu.Unlock()
}
if len(upstreamConfigs) == 0 {
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
upstreams = []string{upstreamOS}
@@ -472,11 +455,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
// 4. Try remote upstream.
isLanOrPtrQuery := false
if req.ufr.matched {
if leaked {
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v (leaked)", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
} else {
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
}
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
} else {
switch {
case isSrvLookup(req.msg):
@@ -557,13 +536,6 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
isNetworkErr := errNetworkError(err)
if isNetworkErr {
p.um.increaseFailureCount(upstreams[n])
if p.um.isDown(upstreams[n]) {
p.um.mu.RLock()
if !p.um.checking[upstreams[n]] {
go p.checkUpstream(upstreams[n], upstreamConfig)
}
p.um.mu.RUnlock()
}
}
// For timeout errors (i.e. context deadline exceeded), force re-bootstrapping.
var e net.Error
@@ -594,16 +566,6 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
ctrld.Log(ctx, logger, "DNS loop detected")
continue
}
if p.um.isDown(upstreams[n]) {
// never skip the OS resolver, since we usually query this resolver when we
// have no other upstreams to query
if upstreams[n] != upstreamOS {
logger.
Bool("is_os_resolver", upstreams[n] == upstreamOS)
ctrld.Log(ctx, logger, "Upstream is down")
continue
}
}
answer := resolve(n, upstreamConfig, req.msg)
if answer == nil {
if serveStaleCache && staleAnswer != nil {
@@ -651,20 +613,29 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
return res
}
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
if p.leakOnUpstreamFailure() {
p.leakingQueryMu.Lock()
// get the map key as concat of upstreams
if !p.leakingQueryRunning[upstreamMapKey] {
p.leakingQueryRunning[upstreamMapKey] = true
// get a map of the failed upstreams
failedUpstreams := make(map[string]*ctrld.UpstreamConfig)
for n, upstream := range upstreamConfigs {
failedUpstreams[upstreams[n]] = upstream
// if we have no healthy upstreams, trigger recovery flow
if p.recoverOnUpstreamFailure() {
if p.um.countHealthy(upstreams) == 0 {
p.recoveryCancelMu.Lock()
if p.recoveryCancel == nil {
var reason RecoveryReason
if upstreams[0] == upstreamOS {
reason = RecoveryReasonOSFailure
} else {
reason = RecoveryReasonRegularFailure
}
mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
go p.handleRecovery(reason)
} else {
mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
}
go p.performLeakingQuery(failedUpstreams, upstreamMapKey)
p.recoveryCancelMu.Unlock()
} else {
mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
}
p.leakingQueryMu.Unlock()
}
answer := new(dns.Msg)
answer.SetRcode(req.msg, dns.RcodeServerFailure)
res.answer = answer
@@ -994,86 +965,6 @@ func (p *prog) selfUninstallCoolOfPeriod() {
p.selfUninstallMu.Unlock()
}
// performLeakingQuery performs the necessary work to leak queries to the OS resolver.
// Once the leakingQuery flag is set, queries are leaked to the OS resolver;
// we then test all the failed upstreams in parallel, indefinitely, until one succeeds.
func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamConfig, upstreamMapKey string) {
mainLog.Load().Warn().Msgf("leaking queries for failed upstreams [%v] to OS resolver", failedUpstreams)
// Signal dns watchers to stop, so changes made below won't be reverted.
p.leakingQueryMu.Lock()
p.leakingQueryRunning[upstreamMapKey] = true
p.leakingQueryMu.Unlock()
defer func() {
p.leakingQueryMu.Lock()
p.leakingQueryRunning[upstreamMapKey] = false
p.leakingQueryMu.Unlock()
mainLog.Load().Warn().Msg("stop leaking query")
}()
// we only want to reset DNS when our resolver is broken
// this allows us to find the new OS resolver nameservers
// we skip the all upstream lock key to prevent duplicate calls
if p.um.isDown(upstreamOS) && upstreamMapKey != "all" {
mainLog.Load().Debug().Msg("OS resolver is down, reinitializing")
p.reinitializeOSResolver(false)
}
// Test all failed upstreams in parallel
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// if a network change, delay upstream checks by 1s
// this is to ensure we actually leak queries to OS resolver
// We have observed some captive portals leak queries to public upstreams
// This can cause the captive portal on MacOS to not trigger a popup
if upstreamMapKey != "all" {
mainLog.Load().Debug().Msg("network change leaking queries, delaying upstream checks by 1s")
time.Sleep(1 * time.Second)
}
upstreamCh := make(chan string, len(failedUpstreams))
for name, uc := range failedUpstreams {
go func(name string, uc *ctrld.UpstreamConfig) {
for {
select {
case <-ctx.Done():
return
default:
// make sure this upstream is not already being checked
p.um.mu.RLock()
if p.um.checking[name] {
p.um.mu.RUnlock()
continue
}
p.um.mu.RUnlock()
mainLog.Load().Debug().
Str("upstream", name).
Msg("checking upstream")
p.checkUpstream(name, uc)
mainLog.Load().Debug().
Str("upstream", name).
Msg("upstream recovered")
upstreamCh <- name
return
}
}
}(name, uc)
}
// Wait for any upstream to recover
name := <-upstreamCh
mainLog.Load().Info().
Str("upstream", name).
Msg("stopping leak as upstream recovered")
}
// forceFetchingAPI sends a signal to force syncing the API config if running in cd mode,
// and the domain == "cdUID.verify.controld.com"
func (p *prog) forceFetchingAPI(domain string) {
@@ -1245,85 +1136,6 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
return answer
}
// reinitializeOSResolver reinitializes the OS resolver
// by removing the ctrld listener from the interface, collecting the network nameservers,
// re-initializing the OS resolver with those nameservers,
// and applying the listener back to the interface.
func (p *prog) reinitializeOSResolver(networkChange bool) {
// Cancel any existing operations.
p.resetCtxMu.Lock()
defer p.resetCtxMu.Unlock()
p.leakingQueryReset.Store(true)
mainLog.Load().Debug().Msg("attempting to reset DNS")
// Remove the listener immediately.
p.resetDNS()
mainLog.Load().Debug().Msg("DNS reset completed")
if networkChange {
// If we're already waiting on a recovery from a previous network change,
// cancel that wait to avoid stale recovery.
p.recoveryCancelMu.Lock()
if p.recoveryCancel != nil {
mainLog.Load().Debug().Msg("Cancelling previous recovery wait due to new network change")
p.recoveryCancel()
p.recoveryCancel = nil
}
ctx, cancel := context.WithCancel(context.Background())
p.recoveryCancel = cancel
p.recoveryCancelMu.Unlock()
// Launch a goroutine that monitors the non-OS upstreams.
go func() {
recoveredUpstream, err := p.waitForNonOSResolverRecovery(ctx)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("No non-OS upstream recovered within the timeout; not re-enabling the listener")
return
}
mainLog.Load().Info().Msgf("Non-OS upstream %q recovered; initializing OS resolver and attaching DNS listener", recoveredUpstream)
// Initialize OS resolver regardless of upstream recovery.
mainLog.Load().Debug().Msg("initializing OS resolver")
ns := ctrld.InitializeOsResolver(true)
if len(ns) == 0 {
mainLog.Load().Warn().Msgf("no nameservers found, using existing OS resolver values")
} else {
mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns)
}
p.setDNS()
p.logInterfacesState()
// allow watchers to reset changes
p.leakingQueryReset.Store(false)
// Clear the recovery cancel func as recovery has been achieved.
p.recoveryCancelMu.Lock()
p.recoveryCancel = nil
p.recoveryCancelMu.Unlock()
}()
// Optionally flush DNS caches (if needed).
if err := FlushDNSCache(); err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache")
}
if runtime.GOOS == "darwin" {
// delay putting back the ctrld listener to allow for captive portal to trigger
time.Sleep(5 * time.Second)
}
} else {
// For non-network-change cases, immediately re-enable the listener.
p.setDNS()
p.logInterfacesState()
// allow watchers to reset changes
p.leakingQueryReset.Store(false)
}
}
// FlushDNSCache flushes the DNS cache on macOS.
func FlushDNSCache() error {
// if not macOS, return
@@ -1457,7 +1269,10 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
ctrld.SetDefaultLocalIPv6(ip)
}
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
p.reinitializeOSResolver(true)
if p.recoverOnUpstreamFailure() {
p.handleRecovery(RecoveryReasonNetworkChange)
}
})
mon.Start()
@@ -1551,53 +1366,154 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro
return err
}
// waitForNonOSResolverRecovery spawns a health check for each non-OS upstream
// and returns when the first one recovers.
func (p *prog) waitForNonOSResolverRecovery(ctx context.Context) (string, error) {
// handleRecovery performs a unified recovery: it removes the DNS settings,
// cancels any existing recovery check on a network change (while coalescing duplicate
// upstream-failure recoveries), waits for an upstream to recover (using a cancellable context without a timeout),
// and then re-applies the DNS settings.
func (p *prog) handleRecovery(reason RecoveryReason) {
mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings")
// For network changes, cancel any existing recovery check because the network state has changed.
if reason == RecoveryReasonNetworkChange {
p.recoveryCancelMu.Lock()
if p.recoveryCancel != nil {
mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)")
p.recoveryCancel()
p.recoveryCancel = nil
}
p.recoveryCancelMu.Unlock()
} else {
// For upstream failures, if a recovery is already in progress, do nothing new.
p.recoveryCancelMu.Lock()
if p.recoveryCancel != nil {
mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
p.recoveryCancelMu.Unlock()
return
}
p.recoveryCancelMu.Unlock()
}
// Create a new recovery context without a fixed timeout.
p.recoveryCancelMu.Lock()
recoveryCtx, cancel := context.WithCancel(context.Background())
p.recoveryCancel = cancel
p.recoveryCancelMu.Unlock()
// Immediately remove our DNS settings from the interface.
// set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface
p.recoveryRunning.Store(true)
p.resetDNS()
// For an OS failure, reinitialize OS resolver nameservers immediately.
if reason == RecoveryReasonOSFailure {
mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers")
ns := ctrld.InitializeOsResolver(true)
if len(ns) == 0 {
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
} else {
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
}
}
// Build upstream map based on the recovery reason.
upstreams := p.buildRecoveryUpstreams(reason)
// Wait indefinitely until one of the upstreams recovers.
recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
if err != nil {
mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed")
p.recoveryCancelMu.Lock()
p.recoveryCancel = nil
p.recoveryCancelMu.Unlock()
return
}
mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered)
// For network changes we also reinitialize the OS resolver.
if reason == RecoveryReasonNetworkChange {
ns := ctrld.InitializeOsResolver(true)
if len(ns) == 0 {
mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values")
} else {
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
}
}
// Apply our DNS settings back and log the interface state.
p.setDNS()
p.logInterfacesState()
// allow watchdogs to put the listener back on the interface if it's changed for any reason
p.recoveryRunning.Store(false)
// Clear the recovery cancellation for a clean slate.
p.recoveryCancelMu.Lock()
p.recoveryCancel = nil
p.recoveryCancelMu.Unlock()
}
// waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers.
// It returns the name of the recovered upstream, or an error if the context is canceled.
func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string]*ctrld.UpstreamConfig) (string, error) {
recoveredCh := make(chan string, 1)
var wg sync.WaitGroup
// Loop over your upstream configuration; skip the OS resolver.
for k, uc := range p.cfg.Upstream {
if uc.Type == ctrld.ResolverTypeOS {
continue
}
mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams))
upstreamName := upstreamPrefix + k
mainLog.Load().Debug().Msgf("Launching recovery check for upstream: %s", upstreamName)
for name, uc := range upstreams {
wg.Add(1)
go func(name string, uc *ctrld.UpstreamConfig) {
defer wg.Done()
mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name)
for {
select {
case <-ctx.Done():
mainLog.Load().Debug().Msgf("Context done for upstream %s; stopping recovery check", name)
mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name)
return
default:
// checkUpstreamOnce will reset any failure counters on success.
if err := p.checkUpstreamOnce(name, uc); err == nil {
mainLog.Load().Debug().Msgf("Upstream %s is healthy; signaling recovery", name)
mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name)
select {
case recoveredCh <- name:
mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name)
default:
mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered")
}
return
} else {
mainLog.Load().Debug().Msgf("Upstream %s not healthy, retrying...", name)
}
mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
time.Sleep(checkUpstreamBackoffSleep)
}
}
}(upstreamName, uc)
}(name, uc)
}
var recovered string
select {
case recovered = <-recoveredCh:
mainLog.Load().Debug().Msgf("Received recovered upstream: %s", recovered)
case <-ctx.Done():
return "", ctx.Err()
}
wg.Wait()
return recovered, nil
}
// buildRecoveryUpstreams constructs the map of upstream configurations to test.
// For OS failures we supply the manual OS resolver upstream configuration.
// For network change or regular failure we use the upstreams defined in p.cfg (ignoring OS).
func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.UpstreamConfig {
upstreams := make(map[string]*ctrld.UpstreamConfig)
switch reason {
case RecoveryReasonOSFailure:
upstreams[upstreamOS] = osUpstreamConfig
case RecoveryReasonNetworkChange, RecoveryReasonRegularFailure:
// Use all configured upstreams except any OS type.
for k, uc := range p.cfg.Upstream {
if uc.Type != ctrld.ResolverTypeOS {
upstreams[upstreamPrefix+k] = uc
}
}
}
return upstreams
}
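
The trickiest part of handleRecovery above is the duplicate-trigger handling: the mutex-guarded recoveryCancel field doubles as an in-progress flag, so a nil cancel func means no recovery is running, a network change preempts the current run, and a plain upstream failure is coalesced. A minimal, self-contained sketch of that pattern (illustrative only, not code from this commit; the recoverer type and trigger function are made up for the example):

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// recoverer coalesces recovery triggers the way the recoveryCancel field does:
// a nil CancelFunc means "no recovery running", a network change preempts the
// current run, and a plain upstream failure is deduplicated.
type recoverer struct {
    mu     sync.Mutex
    cancel context.CancelFunc
}

func (r *recoverer) trigger(preempt bool, work func(ctx context.Context)) {
    r.mu.Lock()
    if r.cancel != nil {
        if !preempt {
            r.mu.Unlock()
            fmt.Println("recovery already in progress; skipping duplicate trigger")
            return
        }
        r.cancel() // network change: cancel the stale recovery first
    }
    ctx, cancel := context.WithCancel(context.Background())
    r.cancel = cancel
    r.mu.Unlock()

    go func() {
        work(ctx)
        r.mu.Lock()
        if ctx.Err() == nil { // not preempted by a newer trigger
            r.cancel = nil
        }
        r.mu.Unlock()
        cancel() // release this recovery's context resources
    }()
}

func main() {
    var r recoverer
    wait := func(ctx context.Context) {
        select {
        case <-ctx.Done():
            fmt.Println("recovery cancelled")
        case <-time.After(100 * time.Millisecond):
            fmt.Println("recovery finished")
        }
    }
    r.trigger(false, wait) // starts a recovery
    r.trigger(false, wait) // coalesced: skipped while the first is running
    r.trigger(true, wait)  // network change: cancels the first, starts anew
    time.Sleep(300 * time.Millisecond)
}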


@@ -48,6 +48,17 @@ const (
ctrldServiceName = "ctrld"
)
// RecoveryReason provides context for why we are waiting for recovery.
// Recovery involves removing the listener IP from the interface and
// waiting for an upstream to work before returning.
type RecoveryReason int
const (
RecoveryReasonNetworkChange RecoveryReason = iota
RecoveryReasonRegularFailure
RecoveryReasonOSFailure
)
// ControlSocketName returns name for control unix socket.
func ControlSocketName() string {
if isMobile() {
@@ -118,14 +129,9 @@ type prog struct {
loopMu sync.Mutex
loop map[string]bool
leakingQueryMu sync.Mutex
leakingQueryRunning map[string]bool
leakingQueryReset atomic.Bool
resetCtxMu sync.Mutex
recoveryCancelMu sync.Mutex
recoveryCancel context.CancelFunc
recoveryRunning atomic.Bool
started chan struct{}
onStartedDone chan struct{}
@@ -429,7 +435,6 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
}
p.onStartedDone = make(chan struct{})
p.loop = make(map[string]bool)
p.leakingQueryRunning = make(map[string]bool)
p.lanLoopGuard = newLoopGuard()
p.ptrLoopGuard = newLoopGuard()
p.cacheFlushDomainsMap = nil
@@ -779,7 +784,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
mainLog.Load().Debug().Msg("stop dns watchdog")
return
case <-ticker.C:
if p.leakingQueryReset.Load() {
if p.recoveryRunning.Load() {
return
}
if dnsChanged(iface, ns) {
@@ -980,16 +985,10 @@ func findWorkingInterface(currentIface string) string {
return currentIface
}
// leakOnUpstreamFailure reports whether ctrld should leak queries to the OS resolver when it fails to connect to all upstreams.
func (p *prog) leakOnUpstreamFailure() bool {
if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil {
return *ptr
}
// Default is false on routers, since this leaking is only useful for devices that move between networks.
if router.Name() != "" {
return false
}
return true
// recoverOnUpstreamFailure reports whether ctrld should recover from upstream failure.
func (p *prog) recoverOnUpstreamFailure() bool {
// Default is false on routers, since this recovery flow is only useful for devices that move between networks.
return router.Name() == ""
}
func randomLocalIP() string {


@@ -67,7 +67,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
return
case event, ok := <-watcher.Events:
if p.leakingQueryReset.Load() {
if p.recoveryRunning.Load() {
return
}
if !ok {


@@ -145,3 +145,16 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) {
time.Sleep(checkUpstreamBackoffSleep)
}
}
// countHealthy returns the number of upstreams in the provided list that are considered healthy.
func (um *upstreamMonitor) countHealthy(upstreams []string) int {
var count int
um.mu.RLock()
defer um.mu.RUnlock()
for _, upstream := range upstreams {
if !um.isDown(upstream) {
count++
}
}
return count
}
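
countHealthy only decides whether recovery should start; once it does, waitForUpstreamRecovery (earlier in this diff) fans out one probe goroutine per upstream and takes whichever reports healthy first, using a buffered channel of size one with a non-blocking send. A standalone sketch of that first-to-recover pattern, with stub check functions standing in for checkUpstreamOnce (illustrative only, not code from this commit):

package main

import (
    "context"
    "fmt"
    "math/rand"
    "time"
)

// probeAll checks every upstream concurrently and returns the name of the
// first one whose check succeeds, or ctx.Err() if the caller cancels the wait.
func probeAll(ctx context.Context, upstreams map[string]func() error) (string, error) {
    recoveredCh := make(chan string, 1) // buffered so the winner never blocks
    for name, check := range upstreams {
        go func(name string, check func() error) {
            for {
                select {
                case <-ctx.Done():
                    return
                default:
                    if err := check(); err == nil {
                        select {
                        case recoveredCh <- name: // first recovered upstream wins
                        default: // another upstream already won
                        }
                        return
                    }
                    time.Sleep(50 * time.Millisecond) // backoff between probes
                }
            }
        }(name, check)
    }
    select {
    case name := <-recoveredCh:
        return name, nil
    case <-ctx.Done():
        return "", ctx.Err()
    }
}

func main() {
    upstreams := map[string]func() error{
        "upstream.0": func() error { return fmt.Errorf("still down") },
        "upstream.1": func() error {
            if rand.Intn(3) == 0 { // pretend this one eventually comes back
                return nil
            }
            return fmt.Errorf("still down")
        },
    }
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()
    if name, err := probeAll(ctx, upstreams); err == nil {
        fmt.Println("recovered:", name)
    } else {
        fmt.Println("no upstream recovered:", err)
    }
}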