diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go
index ac808db..25e3e53 100644
--- a/cmd/cli/dns_proxy.go
+++ b/cmd/cli/dns_proxy.go
@@ -7,7 +7,9 @@ import (
 	"errors"
 	"fmt"
 	"net"
+	"net/http"
 	"net/netip"
+	"os/exec"
 	"runtime"
 	"slices"
 	"strconv"
@@ -42,7 +44,7 @@ const (
 var osUpstreamConfig = &ctrld.UpstreamConfig{
 	Name:    "OS resolver",
 	Type:    ctrld.ResolverTypeOS,
-	Timeout: 2000,
+	Timeout: 3000,
 }
 
 var privateUpstreamConfig = &ctrld.UpstreamConfig{
@@ -436,10 +438,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
 	leaked := false
 	if len(upstreamConfigs) > 0 {
 		p.leakingQueryMu.Lock()
-		if p.leakingQueryRunning[upstreamMapKey] {
+		if p.leakingQueryRunning[upstreamMapKey] || p.leakingQueryRunning["all"] {
 			upstreamConfigs = nil
 			leaked = true
-			ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
+			if p.leakingQueryRunning["all"] {
+				ctrld.Log(ctx, mainLog.Load().Debug(), "all upstreams marked down for network change, leaking query to OS resolver")
+			} else {
+				ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
+			}
 		}
 		p.leakingQueryMu.Unlock()
 	}
@@ -576,13 +582,13 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
 			Bool("is_lan_query", isLanOrPtrQuery)
 
 		if p.isLoop(upstreamConfig) {
-			logger.Msg("DNS loop detected")
+			ctrld.Log(ctx, logger, "DNS loop detected")
 			continue
 		}
 		if p.um.isDown(upstreams[n]) {
 			logger.
-				Bool("is_os_resolver", upstreams[n] == upstreamOS).
-				Msg("Upstream is down")
+				Bool("is_os_resolver", upstreams[n] == upstreamOS)
+			ctrld.Log(ctx, logger, "Upstream is down")
 			continue
 		}
 		answer := resolve(n, upstreamConfig, req.msg)
@@ -995,10 +1001,11 @@ func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamCon
 	// we only want to reset DNS when our resolver is broken
 	// this allows us to find the new OS resolver nameservers
-	if p.um.isDown(upstreamOS) {
+	// we skip the "all" upstream lock key to prevent duplicate calls
+	if p.um.isDown(upstreamOS) && upstreamMapKey != "all" {
 		mainLog.Load().Debug().Msg("OS resolver is down, reinitializing")
-		p.reinitializeOSResolver()
+		p.reinitializeOSResolver(false)
 	}
@@ -1006,6 +1013,15 @@ func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamCon
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
+	// on a network change, delay upstream checks by 1s
+	// to ensure we actually leak queries to the OS resolver.
+	// We have observed some captive portals leak queries to public upstreams,
+	// which can prevent the captive portal popup from triggering on macOS.
+	if upstreamMapKey == "all" {
+		mainLog.Load().Debug().Msg("network change leaking queries, delaying upstream checks by 1s")
+		time.Sleep(1 * time.Second)
+	}
+
 	upstreamCh := make(chan string, len(failedUpstreams))
 	for name, uc := range failedUpstreams {
 		go func(name string, uc *ctrld.UpstreamConfig) {
@@ -1213,7 +1229,7 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
 // by removing ctrld listenr from the interface, collecting the network nameservers
 // and re-initializing the OS resolver with the nameservers
 // applying listener back to the interface
-func (p *prog) reinitializeOSResolver() {
+func (p *prog) reinitializeOSResolver(networkChange bool) {
 	// Cancel any existing operations
 	p.resetCtxMu.Lock()
 	if p.resetCancel != nil {
@@ -1232,6 +1248,21 @@
 	p.leakingQueryReset.Store(true)
 	defer p.leakingQueryReset.Store(false)
+
+	defer func() {
+		// start leaking queries immediately
+		if networkChange {
+			// mark all upstreams as failed and pass them to performLeakingQuery
+			failedUpstreams := make(map[string]*ctrld.UpstreamConfig)
+			for _, upstream := range p.cfg.Upstream {
+				failedUpstreams[upstream.Name] = upstream
+			}
+			go p.performLeakingQuery(failedUpstreams, "all")
+		}
+		if err := FlushDNSCache(); err != nil {
+			mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache")
+		}
+	}()
 
 	select {
 	case <-ctx.Done():
 		mainLog.Load().Debug().Msg("DNS reset cancelled by new network change")
@@ -1264,6 +1295,51 @@
 	}
 }
 
+func triggerCaptiveCheck() {
+	// if not macOS, return
+	if runtime.GOOS != "darwin" {
+		return
+	}
+
+	// Wait for a short period to ensure DNS reinitialization is complete.
+	time.Sleep(2 * time.Second)
+
+	// Trigger a lookup for captive.apple.com.
+	// This can be done either via a DNS query or an HTTP GET.
+	// Here we use a simple HTTP GET, which is what the macOS Captive Network Assistant uses.
+	client := &http.Client{
+		Timeout: 5 * time.Second,
+	}
+	resp, err := client.Get("http://captive.apple.com/hotspot-detect.html")
+	if err != nil {
+		mainLog.Load().Debug().Err(err).Msg("failed to trigger captive portal check")
+		return
+	}
+	resp.Body.Close()
+	mainLog.Load().Debug().Msg("triggered captive portal check by querying captive.apple.com")
+}
+
+// FlushDNSCache flushes the DNS cache on macOS.
+func FlushDNSCache() error {
+	// if not macOS, return
+	if runtime.GOOS != "darwin" {
+		return nil
+	}
+
+	// Flush the DNS cache via mDNSResponder.
+	// This is typically needed on modern macOS systems.
+	if err := exec.Command("sudo", "killall", "-HUP", "mDNSResponder").Run(); err != nil {
+		return fmt.Errorf("failed to flush mDNSResponder: %w", err)
+	}
+
+	// Optionally, flush the directory services cache.
+	if err := exec.Command("sudo", "dscacheutil", "-flushcache").Run(); err != nil {
+		return fmt.Errorf("failed to flush dscacheutil: %w", err)
+	}
+
+	return nil
+}
+
 // monitorNetworkChanges starts monitoring for network interface changes
 func (p *prog) monitorNetworkChanges() error {
 	mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: "))
@@ -1321,7 +1397,7 @@
 	}
 
 	if activeInterfaceExists {
-		p.reinitializeOSResolver()
+		p.reinitializeOSResolver(true)
 	} else {
 		mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization")
 	}
diff --git a/resolver.go b/resolver.go
index f036967..34a6cdd 100644
--- a/resolver.go
+++ b/resolver.go
@@ -237,7 +237,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	dnsClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
+	dnsClient := &dns.Client{Net: "udp", Timeout: 3 * time.Second}
 	ch := make(chan *osResolverResult, numServers)
 	wg := &sync.WaitGroup{}
 	wg.Add(numServers)
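
For reference, a minimal standalone sketch of the captive-portal probe that `triggerCaptiveCheck` performs, not part of the patch. It assumes Apple's standard probe endpoint (`http://captive.apple.com/hotspot-detect.html`), whose body contains "Success" on an unrestricted network; a captive portal typically intercepts the request and serves its login page instead:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)

func main() {
	// Same style of probe as the patch: plain HTTP GET with a short timeout.
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Get("http://captive.apple.com/hotspot-detect.html")
	if err != nil {
		fmt.Println("probe failed:", err)
		return
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		fmt.Println("read failed:", err)
		return
	}
	// On an open network Apple's endpoint returns a page containing "Success";
	// anything else suggests the request was intercepted by a portal.
	if strings.Contains(string(body), "Success") {
		fmt.Println("no captive portal detected")
	} else {
		fmt.Printf("possible captive portal (status %d)\n", resp.StatusCode)
	}
}
```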