diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go
index b2c0f23..341a830 100644
--- a/cmd/cli/dns_proxy.go
+++ b/cmd/cli/dns_proxy.go
@@ -19,6 +19,7 @@ import (
     "golang.org/x/sync/errgroup"
     "tailscale.com/net/netmon"
     "tailscale.com/net/tsaddr"
+    "tailscale.com/types/logger"

     "github.com/Control-D-Inc/ctrld"
     "github.com/Control-D-Inc/ctrld/internal/controld"
@@ -77,6 +78,12 @@ type upstreamForResult struct {
 }

 func (p *prog) serveDNS(listenerNum string) error {
+    // Start network monitoring
+    if err := p.monitorNetworkChanges(); err != nil {
+        mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring")
+        // Don't return here; we still want the DNS service to run
+    }
+
     listenerConfig := p.cfg.Listener[listenerNum]
     // make sure ip is allocated
     if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
@@ -418,11 +425,17 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
     serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
     upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
+    upstreamMapKey := strings.Join(upstreams, "_")
+
     leaked := false
-    if len(upstreamConfigs) > 0 && p.leakingQuery.Load() {
-        upstreamConfigs = nil
-        leaked = true
-        ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
+    if len(upstreamConfigs) > 0 {
+        p.leakingQueryMu.Lock()
+        if p.leakingQueryRunning[upstreamMapKey] {
+            upstreamConfigs = nil
+            leaked = true
+            ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
+        }
+        p.leakingQueryMu.Unlock()
     }

     if len(upstreamConfigs) == 0 {
@@ -601,9 +614,15 @@
     ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
     if p.leakOnUpstreamFailure() {
         p.leakingQueryMu.Lock()
-        if !p.leakingQueryWasRun {
-            p.leakingQueryWasRun = true
-            go p.performLeakingQuery()
+        // skip if a leaking-query run is already in flight for this upstream set
+        if !p.leakingQueryRunning[upstreamMapKey] {
+            p.leakingQueryRunning[upstreamMapKey] = true
+            // build a map of the failed upstreams, keyed by upstream name
+            failedUpstreams := make(map[string]*ctrld.UpstreamConfig)
+            for n, upstream := range upstreamConfigs {
+                failedUpstreams[upstreams[n]] = upstream
+            }
+            go p.performLeakingQuery(failedUpstreams, upstreamMapKey)
         }
         p.leakingQueryMu.Unlock()
     }
@@ -929,95 +948,66 @@ func (p *prog) selfUninstallCoolOfPeriod() {
 }

 // performLeakingQuery performs the necessary work to leak queries to the OS resolver.
-func (p *prog) performLeakingQuery() {
-    mainLog.Load().Warn().Msg("leaking query to OS resolver")
+// Once the leaking flag for an upstream set is stored, queries for that set are leaked
+// to the OS resolver; all failed upstreams are then probed in parallel until one recovers.
+func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamConfig, upstreamMapKey string) {

-    // Create a context with timeout for the entire operation
-    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-    defer cancel()
+    mainLog.Load().Warn().Msgf("leaking queries for failed upstreams [%v] to OS resolver", failedUpstreams)

     // Signal dns watchers to stop, so changes made below won't be reverted.
-    p.leakingQuery.Store(true)
+    p.leakingQueryMu.Lock()
+    p.leakingQueryRunning[upstreamMapKey] = true
+    p.leakingQueryMu.Unlock()
     defer func() {
-        p.leakingQuery.Store(false)
         p.leakingQueryMu.Lock()
-        p.leakingQueryWasRun = false
+        p.leakingQueryRunning[upstreamMapKey] = false
         p.leakingQueryMu.Unlock()
         mainLog.Load().Warn().Msg("stop leaking query")
     }()

-    // Create channels to coordinate operations
-    resetDone := make(chan struct{})
-    checkDone := make(chan struct{})
+    // we only want to reset DNS when our resolver is broken
+    // this allows us to find the new OS resolver nameservers
+    if p.um.isDown(upstreamOS) {

-    // Reset DNS with timeout
-    go func() {
-        defer close(resetDone)
-        mainLog.Load().Debug().Msg("attempting to reset DNS")
-        p.resetDNS()
-        mainLog.Load().Debug().Msg("DNS reset completed")
-    }()
+        mainLog.Load().Debug().Msg("OS resolver is down, reinitializing")
+        p.reinitializeOSResolver()

-    // Wait for reset with timeout
-    select {
-    case <-resetDone:
-        mainLog.Load().Debug().Msg("DNS reset successful")
-    case <-ctx.Done():
-        mainLog.Load().Error().Msg("DNS reset timed out")
-        return
     }

-    // Check upstream in background with progress tracking
-    go func() {
-        defer close(checkDone)
-        mainLog.Load().Debug().Msg("starting upstream checks")
-        for name, uc := range p.cfg.Upstream {
-            select {
-            case <-ctx.Done():
-                return
-            default:
-                mainLog.Load().Debug().
-                    Str("upstream", name).
-                    Msg("checking upstream")
-                p.checkUpstream(name, uc)
+    // Test all failed upstreams in parallel
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    upstreamCh := make(chan string, len(failedUpstreams))
+    for name, uc := range failedUpstreams {
+        go func(name string, uc *ctrld.UpstreamConfig) {
+            mainLog.Load().Debug().
+                Str("upstream", name).
+                Msg("checking upstream")
+
+            for {
+                select {
+                case <-ctx.Done():
+                    return
+                default:
+                    p.checkUpstream(name, uc)
+                    mainLog.Load().Debug().
+                        Str("upstream", name).
+                        Msg("upstream recovered")
+                    upstreamCh <- name
+                    return
+                }
             }
-        }
-        mainLog.Load().Debug().Msg("upstream checks completed")
-    }()
-
-    // Wait for upstream checks
-    select {
-    case <-checkDone:
-        mainLog.Load().Debug().Msg("upstream checks successful")
-    case <-ctx.Done():
-        mainLog.Load().Error().Msg("upstream checks timed out")
-        return
+        }(name, uc)
     }

-    // Initialize OS resolver with timeout
-    mainLog.Load().Debug().Msg("initializing OS resolver")
-    ns := ctrld.InitializeOsResolver()
-    mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns)
+    // Wait for any upstream to recover
+    name := <-upstreamCh

-    // Wait for DNS operations to complete
-    waitCh := make(chan struct{})
-    go func() {
-        p.dnsWg.Wait()
-        close(waitCh)
-    }()
+    mainLog.Load().Info().
+        Str("upstream", name).
+ Msg("stopping leak as upstream recovered") - select { - case <-waitCh: - mainLog.Load().Debug().Msg("DNS operations completed") - case <-ctx.Done(): - mainLog.Load().Error().Msg("DNS operations timed out") - return - } - - // Set DNS with timeout - mainLog.Load().Debug().Msg("setting DNS configuration") - p.setDNS() - mainLog.Load().Debug().Msg("DNS configuration set successfully") } // forceFetchingAPI sends signal to force syncing API config if run in cd mode, @@ -1190,3 +1180,157 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M answer.SetReply(m) return answer } + +// reinitializeOSResolver reinitializes the OS resolver +// by removing ctrld listenr from the interface, collecting the network nameservers +// and re-initializing the OS resolver with the nameservers +// applying listener back to the interface +func (p *prog) reinitializeOSResolver() { + // Cancel any existing operations + p.resetCtxMu.Lock() + if p.resetCancel != nil { + p.resetCancel() + } + + // Create new context for this operation + ctx, cancel := context.WithCancel(context.Background()) + p.resetCtx = ctx + p.resetCancel = cancel + p.resetCtxMu.Unlock() + + // Ensure cleanup + defer cancel() + + p.leakingQueryReset.Store(true) + defer p.leakingQueryReset.Store(false) + + select { + case <-ctx.Done(): + mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") + return + default: + mainLog.Load().Debug().Msg("attempting to reset DNS") + p.resetDNS() + mainLog.Load().Debug().Msg("DNS reset completed") + } + + select { + case <-ctx.Done(): + mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") + return + default: + mainLog.Load().Debug().Msg("initializing OS resolver") + ns := ctrld.InitializeOsResolver() + mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns) + } + + select { + case <-ctx.Done(): + mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") + return + default: + mainLog.Load().Debug().Msg("setting DNS configuration") + p.setDNS() + mainLog.Load().Debug().Msg("DNS configuration set successfully") + } +} + +// monitorNetworkChanges starts monitoring for network interface changes +func (p *prog) monitorNetworkChanges() error { + // Create network monitor + mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) + if err != nil { + return fmt.Errorf("creating network monitor: %w", err) + } + + mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { + // Get map of valid interfaces + validIfaces := validInterfacesMap() + + // Parse old and new interface states + oldIfs := parseInterfaceState(delta.Old) + newIfs := parseInterfaceState(delta.New) + + // Check for changes in valid interfaces + changed := false + activeInterfaceExists := false + + for ifaceName := range validIfaces { + + oldState, oldExists := oldIfs[strings.ToLower(ifaceName)] + newState, newExists := newIfs[strings.ToLower(ifaceName)] + + if newState != "" && newState != "down" { + activeInterfaceExists = true + } + + if oldExists != newExists || oldState != newState { + changed = true + mainLog.Load().Debug(). + Str("interface", ifaceName). + Str("old_state", oldState). + Str("new_state", newState). + Msg("Valid interface changed state") + break + } else { + mainLog.Load().Debug(). + Str("interface", ifaceName). + Str("old_state", oldState). + Str("new_state", newState). 
+ Msg("Valid interface unchanged") + } + } + + if !changed { + mainLog.Load().Debug().Msgf("Ignoring interface change - no valid interfaces affected") + return + } + + mainLog.Load().Debug().Msgf("Network change detected: from %v to %v", delta.Old, delta.New) + if activeInterfaceExists { + p.reinitializeOSResolver() + } else { + mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization") + } + }) + + mon.Start() + mainLog.Load().Debug().Msg("Network monitor started") + return nil +} + +// parseInterfaceState parses the interface state string into a map of interface name -> state +func parseInterfaceState(state *netmon.State) map[string]string { + if state == nil { + return nil + } + + result := make(map[string]string) + + // Extract ifs={...} section + stateStr := state.String() + ifsStart := strings.Index(stateStr, "ifs={") + if ifsStart == -1 { + return result + } + + ifsStr := stateStr[ifsStart+5:] + ifsEnd := strings.Index(ifsStr, "}") + if ifsEnd == -1 { + return result + } + + // Parse each interface entry + ifaces := strings.Split(ifsStr[:ifsEnd], " ") + for _, iface := range ifaces { + parts := strings.Split(iface, ":") + if len(parts) != 2 { + continue + } + name := strings.ToLower(parts[0]) + state := parts[1] + result[name] = state + } + + return result +} diff --git a/cmd/cli/net_darwin.go b/cmd/cli/net_darwin.go index ec7e517..6233161 100644 --- a/cmd/cli/net_darwin.go +++ b/cmd/cli/net_darwin.go @@ -17,9 +17,8 @@ func patchNetIfaceName(iface *net.Interface) (bool, error) { patched := false if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" { - iface.Name = name - mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface") patched = true + iface.Name = name } return patched, nil } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 6aa95b1..a68dad2 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -115,9 +115,13 @@ type prog struct { loopMu sync.Mutex loop map[string]bool - leakingQueryMu sync.Mutex - leakingQueryWasRun bool - leakingQuery atomic.Bool + leakingQueryMu sync.Mutex + leakingQueryRunning map[string]bool + leakingQueryReset atomic.Bool + + resetCtx context.Context + resetCancel context.CancelFunc + resetCtxMu sync.Mutex started chan struct{} onStartedDone chan struct{} @@ -420,6 +424,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } p.onStartedDone = make(chan struct{}) p.loop = make(map[string]bool) + p.leakingQueryRunning = make(map[string]bool) p.lanLoopGuard = newLoopGuard() p.ptrLoopGuard = newLoopGuard() p.cacheFlushDomainsMap = nil @@ -737,12 +742,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces if !requiredMultiNICsConfig() { return } + logger := mainLog.Load().With().Str("iface", iface.Name).Logger() + logger.Debug().Msg("start DNS settings watchdog") - mainLog.Load().Debug().Msg("start DNS settings watchdog") ns := nameservers slices.Sort(ns) ticker := time.NewTicker(p.dnsWatchdogDuration()) - logger := mainLog.Load().With().Str("iface", iface.Name).Logger() + for { select { case <-p.dnsWatcherStopCh: @@ -751,7 +757,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces mainLog.Load().Debug().Msg("stop dns watchdog") return case <-ticker.C: - if p.leakingQuery.Load() { + if p.leakingQueryReset.Load() { return } if dnsChanged(iface, ns) { diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 6df7be6..367ffe7 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ 
             mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
             return
         case event, ok := <-watcher.Events:
-            if p.leakingQuery.Load() {
+            if p.leakingQueryReset.Load() {
                 return
             }
             if !ok {
diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go
index 1f3484b..e37db4d 100644
--- a/cmd/cli/upstream_monitor.go
+++ b/cmd/cli/upstream_monitor.go
@@ -44,10 +44,6 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {

 // increaseFailureCount increases the failed queries count for an upstream by 1.
 func (um *upstreamMonitor) increaseFailureCount(upstream string) {
-    // Do not count "upstream.os", since it must not be down for leaking queries.
-    if upstream == upstreamOS {
-        return
-    }
     um.mu.Lock()
     defer um.mu.Unlock()

diff --git a/resolver.go b/resolver.go
index 3189dfb..0097fe0 100644
--- a/resolver.go
+++ b/resolver.go
@@ -78,9 +78,7 @@ func availableNameservers() []string {
         if _, ok := machineIPsMap[ns]; ok {
             continue
         }
-        if testNameServerFn(ns) {
-            nss = append(nss, ns)
-        }
+        nss = append(nss, ns)
     }
     return nss
 }
@@ -100,11 +98,9 @@ func InitializeOsResolver() []string {
 // - First available LAN servers are saved and stored.
 // - On later calls, if no LAN servers are available, the saved servers above will be used.
 func initializeOsResolver(servers []string) []string {
-    var (
-        lanNss    []string
-        publicNss []string
-    )
+    var lanNss, publicNss []string
+
+    // First categorize servers
     for _, ns := range servers {
         addr, err := netip.ParseAddr(ns)
         if err != nil {
@@ -117,28 +113,84 @@
             publicNss = append(publicNss, server)
         }
     }
+
+    // Store initial servers immediately
     if len(lanNss) > 0 {
-        // Saved first initialized LAN servers.
or.initializedLanServers.CompareAndSwap(nil, &lanNss) - } - if len(lanNss) == 0 { - var nss []string - p := or.initializedLanServers.Load() - if p != nil { - for _, ns := range *p { - if testNameServerFn(ns) { - nss = append(nss, ns) - } - } - } - or.lanServers.Store(&nss) - } else { or.lanServers.Store(&lanNss) } + if len(publicNss) == 0 { - publicNss = append(publicNss, controldPublicDnsWithPort) + publicNss = []string{controldPublicDnsWithPort} } or.publicServers.Store(&publicNss) + + // Test servers in background and remove failures + go func() { + // Test servers in parallel but maintain order + type result struct { + index int + server string + valid bool + } + + testServers := func(servers []string) []string { + if len(servers) == 0 { + return nil + } + + results := make(chan result, len(servers)) + var wg sync.WaitGroup + + for i, server := range servers { + wg.Add(1) + go func(idx int, s string) { + defer wg.Done() + results <- result{ + index: idx, + server: s, + valid: testNameServerFn(s), + } + }(i, server) + } + + go func() { + wg.Wait() + close(results) + }() + + // Collect results maintaining original order + validServers := make([]string, 0, len(servers)) + ordered := make([]result, 0, len(servers)) + for r := range results { + ordered = append(ordered, r) + } + slices.SortFunc(ordered, func(a, b result) int { + return a.index - b.index + }) + for _, r := range ordered { + if r.valid { + validServers = append(validServers, r.server) + } else { + ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing") + } + } + return validServers + } + + // Test and update LAN servers + if validLanNss := testServers(lanNss); len(validLanNss) > 0 { + or.lanServers.Store(&validLanNss) + } + + // Test and update public servers + validPublicNss := testServers(publicNss) + if len(validPublicNss) == 0 { + validPublicNss = []string{controldPublicDnsWithPort} + } + or.publicServers.Store(&validPublicNss) + }() + return slices.Concat(lanNss, publicNss) } @@ -192,7 +244,6 @@ func testNameserver(addr string) bool { }{ {".", dns.TypeNS}, // Root NS query - should always work {"controld.com.", dns.TypeA}, // Fallback to a reliable domain - {"google.com.", dns.TypeA}, // Fallback to a reliable domain } client := &dns.Client{ @@ -330,10 +381,8 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: switch { case res.server == controldPublicDnsWithPort: - Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", res.server) controldSuccessAnswer = res.answer case !res.lan && publicServerAnswer == nil: - Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", res.server) publicServerAnswer = res.answer publicServer = res.server default: @@ -351,14 +400,17 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error errs = append(errs, res.err) } if publicServerAnswer != nil { + Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", publicServer) logAnswer(publicServer) return publicServerAnswer, nil } if controldSuccessAnswer != nil { + Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort) logAnswer(controldPublicDnsWithPort) return controldSuccessAnswer, nil } if nonSuccessAnswer != nil { + Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer) logAnswer(nonSuccessServer) return nonSuccessAnswer, nil } diff --git a/resolver_test.go 
index 7eab744..de8cca0 100644
--- a/resolver_test.go
+++ b/resolver_test.go
@@ -3,13 +3,10 @@ package ctrld

 import (
     "context"
     "net"
-    "slices"
     "sync"
     "testing"
     "time"

-    "github.com/stretchr/testify/assert"
-
     "github.com/miekg/dns"
 )
@@ -178,71 +175,6 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H
     return server, addr, nil
 }

-func Test_initializeOsResolver(t *testing.T) {
-    testNameServerFn = testNameserverTest
-    lanServer1 := "192.168.1.1"
-    lanServer1WithPort := net.JoinHostPort("192.168.1.1", "53")
-    lanServer2 := "10.0.10.69"
-    lanServer2WithPort := net.JoinHostPort("10.0.10.69", "53")
-    lanServer3 := "192.168.40.1"
-    lanServer3WithPort := net.JoinHostPort("192.168.40.1", "53")
-    wanServer := "1.1.1.1"
-    lanServers := []string{lanServer1WithPort, lanServer2WithPort}
-    publicServers := []string{net.JoinHostPort(wanServer, "53")}
-
-    or = newResolverWithNameserver(defaultNameservers())
-
-    // First initialization, initialized servers are saved.
-    initializeOsResolver([]string{lanServer1, lanServer2, wanServer})
-    p := or.initializedLanServers.Load()
-    assert.NotNil(t, p)
-    assert.True(t, slices.Equal(*p, lanServers))
-    assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
-    assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
-
-    // No new LAN servers, but lanServer2 gone, initialized servers not changed.
-    initializeOsResolver([]string{lanServer1, wanServer})
-    p = or.initializedLanServers.Load()
-    assert.NotNil(t, p)
-    assert.True(t, slices.Equal(*p, lanServers))
-    assert.True(t, slices.Equal(*or.lanServers.Load(), []string{lanServer1WithPort}))
-    assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
-
-    // New LAN servers, they are used, initialized servers not changed.
-    initializeOsResolver([]string{lanServer3, wanServer})
-    p = or.initializedLanServers.Load()
-    assert.NotNil(t, p)
-    assert.True(t, slices.Equal(*p, lanServers))
-    assert.True(t, slices.Equal(*or.lanServers.Load(), []string{lanServer3WithPort}))
-    assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
-
-    // No LAN server available, initialized servers will be used.
-    initializeOsResolver([]string{wanServer})
-    p = or.initializedLanServers.Load()
-    assert.NotNil(t, p)
-    assert.True(t, slices.Equal(*p, lanServers))
-    assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
-    assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
-
-    // No Public server, ControlD Public DNS will be used.
-    initializeOsResolver([]string{})
-    p = or.initializedLanServers.Load()
-    assert.NotNil(t, p)
-    assert.True(t, slices.Equal(*p, lanServers))
-    assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
-    assert.True(t, slices.Equal(*or.publicServers.Load(), []string{controldPublicDnsWithPort}))
-
-    // No LAN server available, initialized servers is unavailable, nothing will be used.
-    nonSuccessTestServerMap[lanServer1WithPort] = true
-    nonSuccessTestServerMap[lanServer2WithPort] = true
-    initializeOsResolver([]string{wanServer})
-    p = or.initializedLanServers.Load()
-    assert.NotNil(t, p)
-    assert.True(t, slices.Equal(*p, lanServers))
-    assert.Empty(t, *or.lanServers.Load())
-    assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
-}
-
 func successHandler() dns.HandlerFunc {
     return func(w dns.ResponseWriter, msg *dns.Msg) {
         m := new(dns.Msg)
@@ -258,9 +190,3 @@ func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc {
         w.WriteMsg(m)
     }
 }
-
-var nonSuccessTestServerMap = map[string]bool{}
-
-func testNameserverTest(addr string) bool {
-    return !nonSuccessTestServerMap[addr]
-}
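
Reviewer note: the core concurrency pattern in this patch is the per-upstream-set guard (leakingQueryMu plus the leakingQueryRunning map), which ensures at most one leaking-query worker runs per upstream set. Below is a minimal, standalone Go sketch of that pattern under illustrative names (guard, tryStart, done are not ctrld identifiers); it only demonstrates that concurrent callers for the same key collapse into a single in-flight worker.

package main

import (
	"fmt"
	"strings"
	"sync"
	"time"
)

// guard allows at most one in-flight worker per key.
type guard struct {
	mu      sync.Mutex
	running map[string]bool
}

// tryStart reports whether the caller won the right to run the worker for key.
func (g *guard) tryStart(key string) bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.running[key] {
		return false // a worker for this upstream set is already in flight
	}
	g.running[key] = true
	return true
}

// done clears the in-flight flag so a later failure can start a new worker.
func (g *guard) done(key string) {
	g.mu.Lock()
	g.running[key] = false
	g.mu.Unlock()
}

func main() {
	g := &guard{running: make(map[string]bool)}
	// Same key derivation as the patch: upstream names joined with "_".
	key := strings.Join([]string{"upstream.0", "upstream.1"}, "_")

	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			if !g.tryStart(key) {
				fmt.Printf("caller %d: leak already running for %q\n", i, key)
				return
			}
			defer g.done(key)
			fmt.Printf("caller %d: leaking queries for %q\n", i, key)
			time.Sleep(10 * time.Millisecond) // stand-in for probing failed upstreams
		}(i)
	}
	wg.Wait()
}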
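
Reviewer note: the background validation added to initializeOsResolver fans out one goroutine per nameserver, then re-sorts the results by original index so the surviving list preserves input order. A standalone sketch of that order-preserving fan-out follows; testNameserver here is a stub standing in for the real DNS probe, and the TEST-NET-1 address 192.0.2.1 plays the part of a dead server.

package main

import (
	"fmt"
	"slices"
	"sync"
)

// testNameserver is a stub; the real implementation sends DNS probes.
func testNameserver(addr string) bool {
	return addr != "192.0.2.1:53"
}

// testServers validates servers in parallel while preserving input order,
// mirroring the closure added to initializeOsResolver.
func testServers(servers []string) []string {
	if len(servers) == 0 {
		return nil
	}
	type result struct {
		index  int
		server string
		valid  bool
	}
	results := make(chan result, len(servers))
	var wg sync.WaitGroup
	for i, s := range servers {
		wg.Add(1)
		go func(idx int, s string) {
			defer wg.Done()
			results <- result{index: idx, server: s, valid: testNameserver(s)}
		}(i, s)
	}
	go func() { wg.Wait(); close(results) }()

	// Collect everything, then restore the original submission order.
	ordered := make([]result, 0, len(servers))
	for r := range results {
		ordered = append(ordered, r)
	}
	slices.SortFunc(ordered, func(a, b result) int { return a.index - b.index })

	valid := make([]string, 0, len(servers))
	for _, r := range ordered {
		if r.valid {
			valid = append(valid, r.server)
		}
	}
	return valid
}

func main() {
	fmt.Println(testServers([]string{"10.0.0.1:53", "192.0.2.1:53", "1.1.1.1:53"}))
	// Output: [10.0.0.1:53 1.1.1.1:53], order preserved, dead server dropped
}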