diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 4934c5a..39a5977 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1029,6 +1029,16 @@ func uninstall(p *prog, s service.Service) { return } p.resetDNS() + + // if present restore the original DNS settings + if netIface, err := netInterface(p.runningIface); err == nil { + if err := restoreDNS(netIface); err != nil { + mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface") + } else { + mainLog.Load().Debug().Msg("Restored DNS on interface successfully") + } + } + if router.Name() != "" { mainLog.Load().Debug().Msg("Router cleanup") } diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index e5f655f..bae0cf1 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -541,6 +541,16 @@ func initStopCmd() *cobra.Command { if doTasks([]task{{s.Stop, true}}) { p.router.Cleanup() p.resetDNS() + + // restore DNS settings + if netIface, err := netInterface(p.runningIface); err == nil { + if err := restoreDNS(netIface); err != nil { + mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface") + } else { + mainLog.Load().Debug().Msg("Restored DNS on interface successfully") + } + } + if router.WaitProcessExited() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() diff --git a/cmd/cli/control_client.go b/cmd/cli/control_client.go index 73002e8..7382d4e 100644 --- a/cmd/cli/control_client.go +++ b/cmd/cli/control_client.go @@ -25,6 +25,10 @@ func newControlClient(addr string) *controlClient { } func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) { + // for log/send, set the timeout to 5 minutes + if path == sendLogsPath { + c.c.Timeout = time.Minute * 5 + } return c.c.Post("http://unix"+path, contentTypeJson, data) } diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index d1daea3..36285e5 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -27,8 +27,8 @@ const ( deactivationPath = "/deactivation" cdPath = "/cd" ifacePath = "/iface" - viewLogsPath = "/logs/view" - sendLogsPath = "/logs/send" + viewLogsPath = "/log/view" + sendLogsPath = "/log/send" ) type ifaceResponse struct { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 4f4b980..b2c0f23 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -542,8 +542,10 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { if upstreamConfig == nil { continue } + ctrld.Log(ctx, mainLog.Load().Debug(), "attempting upstream [ %s ] at index: %d, upstream at index: %s", upstreamConfig.String(), n, upstreams[n]) + if p.isLoop(upstreamConfig) { - mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint) + mainLog.Load().Warn().Msgf("dns loop detected, upstream: %s", upstreamConfig.String()) continue } if p.um.isDown(upstreams[n]) { @@ -929,6 +931,11 @@ func (p *prog) selfUninstallCoolOfPeriod() { // performLeakingQuery performs necessary works to leak queries to OS resolver. func (p *prog) performLeakingQuery() { mainLog.Load().Warn().Msg("leaking query to OS resolver") + + // Create a context with timeout for the entire operation + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + // Signal dns watchers to stop, so changes made below won't be reverted. p.leakingQuery.Store(true) defer func() { @@ -936,20 +943,81 @@ func (p *prog) performLeakingQuery() { p.leakingQueryMu.Lock() p.leakingQueryWasRun = false p.leakingQueryMu.Unlock() + mainLog.Load().Warn().Msg("stop leaking query") }() - // Reset DNS, so queries are forwarded to OS resolver normally. - p.resetDNS() - // Check remote upstream in background, so ctrld could be back to normal - // operation as long as the network is back online. - for name, uc := range p.cfg.Upstream { - p.checkUpstream(name, uc) + + // Create channels to coordinate operations + resetDone := make(chan struct{}) + checkDone := make(chan struct{}) + + // Reset DNS with timeout + go func() { + defer close(resetDone) + mainLog.Load().Debug().Msg("attempting to reset DNS") + p.resetDNS() + mainLog.Load().Debug().Msg("DNS reset completed") + }() + + // Wait for reset with timeout + select { + case <-resetDone: + mainLog.Load().Debug().Msg("DNS reset successful") + case <-ctx.Done(): + mainLog.Load().Error().Msg("DNS reset timed out") + return } - // After all upstream back, re-initializing OS resolver. + + // Check upstream in background with progress tracking + go func() { + defer close(checkDone) + mainLog.Load().Debug().Msg("starting upstream checks") + for name, uc := range p.cfg.Upstream { + select { + case <-ctx.Done(): + return + default: + mainLog.Load().Debug(). + Str("upstream", name). + Msg("checking upstream") + p.checkUpstream(name, uc) + } + } + mainLog.Load().Debug().Msg("upstream checks completed") + }() + + // Wait for upstream checks + select { + case <-checkDone: + mainLog.Load().Debug().Msg("upstream checks successful") + case <-ctx.Done(): + mainLog.Load().Error().Msg("upstream checks timed out") + return + } + + // Initialize OS resolver with timeout + mainLog.Load().Debug().Msg("initializing OS resolver") ns := ctrld.InitializeOsResolver() mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns) - p.dnsWg.Wait() + + // Wait for DNS operations to complete + waitCh := make(chan struct{}) + go func() { + p.dnsWg.Wait() + close(waitCh) + }() + + select { + case <-waitCh: + mainLog.Load().Debug().Msg("DNS operations completed") + case <-ctx.Done(): + mainLog.Load().Error().Msg("DNS operations timed out") + return + } + + // Set DNS with timeout + mainLog.Load().Debug().Msg("setting DNS configuration") p.setDNS() - mainLog.Load().Warn().Msg("stop leaking query") + mainLog.Load().Debug().Msg("DNS configuration set successfully") } // forceFetchingAPI sends signal to force syncing API config if run in cd mode, diff --git a/cmd/cli/os_darwin.go b/cmd/cli/os_darwin.go index f319056..841be76 100644 --- a/cmd/cli/os_darwin.go +++ b/cmd/cli/os_darwin.go @@ -70,11 +70,6 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { // TODO(cuonglm): use system API func resetDNS(iface *net.Interface) error { - if ns := savedStaticNameservers(iface); len(ns) > 0 { - if err := setDNS(iface, ns); err == nil { - return nil - } - } cmd := "networksetup" args := []string{"-setdnsservers", iface.Name, "empty"} if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil { @@ -83,6 +78,15 @@ func resetDNS(iface *net.Interface) error { return nil } +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { + if ns := savedStaticNameservers(iface); len(ns) > 0 { + err = setDNS(iface, ns) + } + return err +} + func currentDNS(_ *net.Interface) []string { return resolvconffile.NameServers("") } diff --git a/cmd/cli/os_freebsd.go b/cmd/cli/os_freebsd.go index bddffca..72da485 100644 --- a/cmd/cli/os_freebsd.go +++ b/cmd/cli/os_freebsd.go @@ -76,6 +76,12 @@ func resetDNS(iface *net.Interface) error { return nil } +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { + return err +} + func currentDNS(_ *net.Interface) []string { return resolvconffile.NameServers("") } diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index ade5881..3f815e8 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -195,6 +195,12 @@ func resetDNS(iface *net.Interface) (err error) { }) } +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { + return err +} + func currentDNS(iface *net.Interface) []string { for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} { if ns := fn(iface.Name); len(ns) > 0 { diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 5ff9360..990cc57 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -130,8 +130,12 @@ func resetDNS(iface *net.Interface) error { if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil { return fmt.Errorf("could not reset DNS ipv6: %w", err) } + return nil +} - // If there's static DNS saved, restoring it. +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { if nss := savedStaticNameservers(iface); len(nss) > 0 { v4ns := make([]string, 0, 2) v6ns := make([]string, 0, 2) @@ -148,12 +152,14 @@ func resetDNS(iface *net.Interface) error { continue } mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name) - if err := setDNS(iface, ns); err != nil { + err = setDNS(iface, ns) + + if err != nil { return err } } } - return nil + return err } func currentDNS(iface *net.Interface) []string { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 29c1120..b1fb18b 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -626,9 +626,31 @@ func (p *prog) setDNS() { return } logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() - netIface, err := netInterface(p.runningIface) - if err != nil { - logger.Error().Err(err).Msg("could not get interface") + + const maxDNSRetryAttempts = 3 + const retryDelay = 1 * time.Second + var netIface *net.Interface + var err error + for attempt := 1; attempt <= maxDNSRetryAttempts; attempt++ { + netIface, err = netInterface(p.runningIface) + if err == nil { + break + } + if attempt < maxDNSRetryAttempts { + // Try to find a different working interface + newIface := findWorkingInterface(p.runningIface) + if newIface != p.runningIface { + p.runningIface = newIface + logger = mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger.Info().Msg("switched to new interface") + continue + } + + logger.Warn().Err(err).Int("attempt", attempt).Msg("could not get interface, retrying...") + time.Sleep(retryDelay) + continue + } + logger.Error().Err(err).Msg("could not get interface after all attempts") return } if err := setupNetworkManager(); err != nil { @@ -766,6 +788,7 @@ func (p *prog) resetDNS() { logger.Error().Err(err).Msg("could not get interface") return } + if err := restoreNetworkManager(); err != nil { logger.Error().Err(err).Msg("could not restore NetworkManager") return @@ -781,6 +804,131 @@ func (p *prog) resetDNS() { } } +// findWorkingInterface looks for a network interface with a valid IP configuration +func findWorkingInterface(currentIface string) string { + // Helper to check if IP is valid (not link-local) + isValidIP := func(ip net.IP) bool { + return ip != nil && + !ip.IsLinkLocalUnicast() && + !ip.IsLinkLocalMulticast() && + !ip.IsLoopback() && + !ip.IsUnspecified() + } + + // Helper to check if interface has valid IP configuration + hasValidIPConfig := func(iface *net.Interface) bool { + if iface == nil || iface.Flags&net.FlagUp == 0 { + return false + } + + addrs, err := iface.Addrs() + if err != nil { + mainLog.Load().Debug(). + Str("interface", iface.Name). + Err(err). + Msg("failed to get interface addresses") + return false + } + + for _, addr := range addrs { + // Check for IP network + if ipNet, ok := addr.(*net.IPNet); ok { + if isValidIP(ipNet.IP) { + return true + } + } + } + return false + } + + // Get default route interface + defaultRoute, err := netmon.DefaultRoute() + if err != nil { + mainLog.Load().Debug(). + Err(err). + Msg("failed to get default route") + } else { + mainLog.Load().Debug(). + Str("default_route_iface", defaultRoute.InterfaceName). + Msg("found default route") + } + + // Get all interfaces + ifaces, err := net.Interfaces() + if err != nil { + mainLog.Load().Error().Err(err).Msg("failed to list network interfaces") + return currentIface // Return current interface as fallback + } + + var firstWorkingIface string + var currentIfaceValid bool + + // Single pass through interfaces + for _, iface := range ifaces { + // Must be physical (has MAC address) + if len(iface.HardwareAddr) == 0 { + continue + } + // Skip interfaces that are: + // - Loopback + // - Not up + // - Point-to-point (like VPN tunnels) + if iface.Flags&net.FlagLoopback != 0 || + iface.Flags&net.FlagUp == 0 || + iface.Flags&net.FlagPointToPoint != 0 { + continue + } + + if !hasValidIPConfig(&iface) { + continue + } + + // Found working physical interface + if err == nil && defaultRoute.InterfaceName == iface.Name { + // Found interface with default route - use it immediately + mainLog.Load().Info(). + Str("old_iface", currentIface). + Str("new_iface", iface.Name). + Msg("switching to interface with default route") + return iface.Name + } + + // Keep track of first working interface as fallback + if firstWorkingIface == "" { + firstWorkingIface = iface.Name + } + + // Check if this is our current interface + if iface.Name == currentIface { + currentIfaceValid = true + } + } + + // Return interfaces in order of preference: + // 1. Current interface if it's still valid + if currentIfaceValid { + mainLog.Load().Debug(). + Str("interface", currentIface). + Msg("keeping current interface") + return currentIface + } + + // 2. First working interface found + if firstWorkingIface != "" { + mainLog.Load().Info(). + Str("old_iface", currentIface). + Str("new_iface", firstWorkingIface). + Msg("switching to first working physical interface") + return firstWorkingIface + } + + // 3. Fall back to current interface if nothing else works + mainLog.Load().Warn(). + Str("current_iface", currentIface). + Msg("no working physical interface found, keeping current") + return currentIface +} + // leakOnUpstreamFailure reports whether ctrld should leak query to OS resolver when failed to connect all upstreams. func (p *prog) leakOnUpstreamFailure() bool { if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil { @@ -1049,7 +1197,16 @@ func savedStaticDnsSettingsFilePath(iface *net.Interface) string { func savedStaticNameservers(iface *net.Interface) []string { file := savedStaticDnsSettingsFilePath(iface) if data, _ := os.ReadFile(file); len(data) > 0 { - return strings.Split(string(data), ",") + saveValues := strings.Split(string(data), ",") + returnValues := []string{} + // check each one, if its in loopback range, remove it + for _, v := range saveValues { + if net.ParseIP(v).IsLoopback() { + continue + } + returnValues = append(returnValues, v) + } + return returnValues } return nil } diff --git a/config.go b/config.go index 4302c5d..c88404c 100644 --- a/config.go +++ b/config.go @@ -886,3 +886,12 @@ func upstreamUID() string { return hex.EncodeToString(b) } } + +// String returns a string representation of the UpstreamConfig for logging. +func (uc *UpstreamConfig) String() string { + if uc == nil { + return "" + } + return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}", + uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack) +} diff --git a/resolver.go b/resolver.go index 82a395e..3189dfb 100644 --- a/resolver.go +++ b/resolver.go @@ -147,16 +147,82 @@ var testNameServerFn = testNameserver // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. func testNameserver(addr string) bool { - msg := new(dns.Msg) - msg.SetQuestion("controld.com.", dns.TypeNS) - client := new(dns.Client) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - _, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(addr, "53")) - if err != nil { - ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr) + // Skip link-local addresses without scope IDs and deprecated site-local addresses + if ip, err := netip.ParseAddr(addr); err == nil { + if ip.Is6() { + if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") { + ProxyLogger.Load().Debug(). + Str("nameserver", addr). + Msg("skipping link-local IPv6 address without scope ID") + return false + } + // Skip deprecated site-local addresses (fec0::/10) + if strings.HasPrefix(ip.String(), "fec0:") { + ProxyLogger.Load().Debug(). + Str("nameserver", addr). + Msg("skipping deprecated site-local IPv6 address") + return false + } + } } - return err == nil + + ProxyLogger.Load().Debug(). + Str("input_addr", addr). + Msg("testing nameserver") + + // Handle both IPv4 and IPv6 addresses + serverAddr := addr + host, port, err := net.SplitHostPort(addr) + if err != nil { + // No port in address, add default port 53 + serverAddr = net.JoinHostPort(addr, "53") + } else if port == "" { + // Has split markers but empty port + serverAddr = net.JoinHostPort(host, "53") + } + + ProxyLogger.Load().Debug(). + Str("server_addr", serverAddr). + Msg("using server address") + + // Test domains that are likely to exist and respond quickly + testDomains := []struct { + name string + qtype uint16 + }{ + {".", dns.TypeNS}, // Root NS query - should always work + {"controld.com.", dns.TypeA}, // Fallback to a reliable domain + {"google.com.", dns.TypeA}, // Fallback to a reliable domain + } + + client := &dns.Client{ + Timeout: 2 * time.Second, + Net: "udp", + } + + // Try each test query until one succeeds + for _, test := range testDomains { + msg := new(dns.Msg) + msg.SetQuestion(test.name, test.qtype) + msg.RecursionDesired = true + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + resp, _, err := client.ExchangeContext(ctx, msg, serverAddr) + cancel() + + if err == nil && resp != nil { + return true + } + + ProxyLogger.Load().Error(). + Err(err). + Str("nameserver", serverAddr). + Str("test_domain", test.name). + Str("query_type", dns.TypeToString[test.qtype]). + Msg("DNS availability test failed") + } + + return false } // Resolver is the interface that wraps the basic DNS operations. @@ -222,7 +288,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error ctx, cancel := context.WithCancel(ctx) defer cancel() - dnsClient := &dns.Client{Net: "udp"} + dnsClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second} ch := make(chan *osResolverResult, numServers) wg := &sync.WaitGroup{} wg.Add(numServers) @@ -264,11 +330,14 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: switch { case res.server == controldPublicDnsWithPort: - controldSuccessAnswer = res.answer // only use ControlD answer as last one. + Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", res.server) + controldSuccessAnswer = res.answer case !res.lan && publicServerAnswer == nil: - publicServerAnswer = res.answer // use public DNS answer after LAN server.. + Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", res.server) + publicServerAnswer = res.answer publicServer = res.server default: + Log(ctx, ProxyLogger.Load().Debug(), "got LAN answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil @@ -276,6 +345,8 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error case res.answer != nil: nonSuccessAnswer = res.answer nonSuccessServer = res.server + Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d", + res.server, res.answer.Rcode) } errs = append(errs, res.err) }