From 2d3779ec27ce59e2b3fc07c4aa039a59111b06ed Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 4 Feb 2025 18:38:48 -0500 Subject: [PATCH] fix MacOS nameserver detection, fix not installed errors for commands copy fix get valid ifaces in nameservers_bsd nameservers on MacOS can be found in resolv.conf reliably nameservers on MacOS can be found in resolv.conf reliably exclude local IPs from MacOS resolve conf check use scutil for MacOS, simplify reinit logic to prevent duplicate calls add more dns server fetching options never skip OS resolver in IsDown check split dsb and darwin nameserver methods, add delay for setting DNS on interface on network change. increase delay to 5s but only on MacOS --- cmd/cli/commands.go | 56 ++++++++++ cmd/cli/dns_proxy.go | 86 ++++++--------- cmd/cli/log_writer.go | 4 +- cmd/cli/prog.go | 4 +- nameservers_bsd.go | 19 +--- nameservers_darwin.go | 243 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 337 insertions(+), 75 deletions(-) create mode 100644 nameservers_darwin.go diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 8396c19..3dae547 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -43,6 +43,20 @@ func initLogCmd() *cobra.Command { checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") @@ -82,6 +96,20 @@ func initLogCmd() *cobra.Command { checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") @@ -765,6 +793,20 @@ func initReloadCmd(restartCmd *cobra.Command) *cobra.Command { Short: "Reload the ctrld service", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") @@ -1045,6 +1087,20 @@ func initClientsCmd() *cobra.Command { checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 18ac373..646bafb 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -585,10 +585,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { continue } if p.um.isDown(upstreams[n]) { - logger. - Bool("is_os_resolver", upstreams[n] == upstreamOS) - ctrld.Log(ctx, logger, "Upstream is down") - continue + // never skip the OS resolver, since we usually query this resolver when we + // have no other upstreams to query + if upstreams[n] != upstreamOS { + logger. + Bool("is_os_resolver", upstreams[n] == upstreamOS) + ctrld.Log(ctx, logger, "Upstream is down") + continue + } } answer := resolve(n, upstreamConfig, req.msg) if answer == nil { @@ -1231,67 +1235,43 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M func (p *prog) reinitializeOSResolver(networkChange bool) { // Cancel any existing operations p.resetCtxMu.Lock() - if p.resetCancel != nil { - p.resetCancel() - } - - // Create new context for this operation - ctx, cancel := context.WithCancel(context.Background()) - p.resetCtx = ctx - p.resetCancel = cancel - p.resetCtxMu.Unlock() - - // Ensure cleanup - defer cancel() + defer p.resetCtxMu.Unlock() p.leakingQueryReset.Store(true) defer p.leakingQueryReset.Store(false) - defer func() { - // start leaking queries immediately - if networkChange { - // set all upstreams to failed and provide to performLeakingQuery - failedUpstreams := make(map[string]*ctrld.UpstreamConfig) - for _, upstream := range p.cfg.Upstream { - failedUpstreams[upstream.Name] = upstream - } - go p.performLeakingQuery(failedUpstreams, "all") + mainLog.Load().Debug().Msg("attempting to reset DNS") + p.resetDNS() + mainLog.Load().Debug().Msg("DNS reset completed") + + mainLog.Load().Debug().Msg("initializing OS resolver") + ns := ctrld.InitializeOsResolver() + mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) + + // start leaking queries immediately// start leaking queries immediately + if networkChange { + // set all upstreams to failed and provide to performLeakingQuery + failedUpstreams := make(map[string]*ctrld.UpstreamConfig) + for _, upstream := range p.cfg.Upstream { + failedUpstreams[upstream.Name] = upstream } + go p.performLeakingQuery(failedUpstreams, "all") + if err := FlushDNSCache(); err != nil { mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache") } - }() - select { - case <-ctx.Done(): - mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") - return - default: - mainLog.Load().Debug().Msg("attempting to reset DNS") - p.resetDNS() - mainLog.Load().Debug().Msg("DNS reset completed") + if runtime.GOOS == "darwin" { + // delay putting back the ctrld listener to allow for captive portal to trigger + time.Sleep(5 * time.Second) + } } - select { - case <-ctx.Done(): - mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") - return - default: - mainLog.Load().Debug().Msg("initializing OS resolver") - ns := ctrld.InitializeOsResolver() - mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) - } + mainLog.Load().Debug().Msg("setting DNS configuration") + p.setDNS() + mainLog.Load().Debug().Msg("DNS configuration set successfully") + p.logInterfacesState() - select { - case <-ctx.Done(): - mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") - return - default: - mainLog.Load().Debug().Msg("setting DNS configuration") - p.setDNS() - mainLog.Load().Debug().Msg("DNS configuration set successfully") - p.logInterfacesState() - } } // FlushDNSCache flushes the DNS cache on macOS. diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 1bcfe3d..339d984 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -20,9 +20,9 @@ const ( logWriterSmallSize = 1024 * 1024 * 1 // 1 MB logWriterInitialSize = 32 * 1024 // 32 KB logSentInterval = time.Minute - logStartEndMarker = "\n\n=== START_END ===\n\n" + logStartEndMarker = "\n\n=== INIT_END ===\n\n" logLogEndMarker = "\n\n=== LOG_END ===\n\n" - logWarnEndMarker = "\n\n=== WARN END ===\n\n" + logWarnEndMarker = "\n\n=== WARN_END ===\n\n" ) type logViewResponse struct { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 4c9270c..41dc2c4 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -123,9 +123,7 @@ type prog struct { leakingQueryRunning map[string]bool leakingQueryReset atomic.Bool - resetCtx context.Context - resetCancel context.CancelFunc - resetCtxMu sync.Mutex + resetCtxMu sync.Mutex started chan struct{} onStartedDone chan struct{} diff --git a/nameservers_bsd.go b/nameservers_bsd.go index 2beebd0..b835060 100644 --- a/nameservers_bsd.go +++ b/nameservers_bsd.go @@ -1,19 +1,16 @@ -//go:build darwin || dragonfly || freebsd || netbsd || openbsd +//go:build dragonfly || freebsd || netbsd || openbsd package ctrld import ( "net" - "os/exec" - "runtime" - "strings" "syscall" "golang.org/x/net/route" ) func dnsFns() []dnsFn { - return []dnsFn{dnsFromRIB, dnsFromIPConfig} + return []dnsFn{dnsFromRIB} } func dnsFromRIB() []string { @@ -49,18 +46,6 @@ func dnsFromRIB() []string { return dns } -func dnsFromIPConfig() []string { - if runtime.GOOS != "darwin" { - return nil - } - cmd := exec.Command("ipconfig", "getoption", "", "domain_name_server") - out, _ := cmd.Output() - if ip := net.ParseIP(strings.TrimSpace(string(out))); ip != nil { - return []string{ip.String()} - } - return nil -} - func toNetIP(addr route.Addr) net.IP { switch t := addr.(type) { case *route.Inet4Addr: diff --git a/nameservers_darwin.go b/nameservers_darwin.go new file mode 100644 index 0000000..bec6ce4 --- /dev/null +++ b/nameservers_darwin.go @@ -0,0 +1,243 @@ +//go:build darwin + +package ctrld + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "os/exec" + "regexp" + "slices" + "strings" + "time" + + "github.com/rs/zerolog" + "tailscale.com/net/netmon" +) + +func dnsFns() []dnsFn { + return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers} +} + +// dnsFromResolvConf reads nameservers from /etc/resolv.conf +func dnsFromResolvConf() []string { + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + + const ( + maxRetries = 10 + retryInterval = 100 * time.Millisecond + ) + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var dns []string + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(retryInterval) + } + + file, err := os.Open("/etc/resolv.conf") + if err != nil { + Log(context.Background(), logger.Error(), "failed to open /etc/resolv.conf (attempt %d/%d)", attempt+1, maxRetries) + continue + } + defer file.Close() + + var localDNS []string + seen := make(map[string]bool) + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) < 2 || fields[0] != "nameserver" { + continue + } + if ip := net.ParseIP(fields[1]); ip != nil { + // skip loopback IPs + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + if ip.String() == ipStr { + continue + } + } + if !seen[ip.String()] { + seen[ip.String()] = true + localDNS = append(localDNS, ip.String()) + } + } + } + + if err := scanner.Err(); err != nil { + Log(context.Background(), logger.Error(), "error reading /etc/resolv.conf (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + // If we successfully read the file and found nameservers, return them + if len(localDNS) > 0 { + return localDNS + } + } + + return dns +} + +func getDNSFromScutil() []string { + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + + const ( + maxRetries = 10 + retryInterval = 100 * time.Millisecond + ) + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var nameservers []string + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(retryInterval) + } + + cmd := exec.Command("scutil", "--dns") + output, err := cmd.Output() + if err != nil { + Log(context.Background(), logger.Error(), "failed to execute scutil --dns (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + var localDNS []string + seen := make(map[string]bool) + + scanner := bufio.NewScanner(bytes.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "nameserver[") { + parts := strings.Split(line, ":") + if len(parts) == 2 { + ns := strings.TrimSpace(parts[1]) + if ip := net.ParseIP(ns); ip != nil { + // skip loopback IPs + isLocal := false + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + if ip.String() == ipStr { + isLocal = true + break + } + } + if !isLocal && !seen[ip.String()] { + seen[ip.String()] = true + localDNS = append(localDNS, ip.String()) + } + } + } + } + } + + if err := scanner.Err(); err != nil { + Log(context.Background(), logger.Error(), "error scanning scutil output (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + // If we successfully read the output and found nameservers, return them + if len(localDNS) > 0 { + return localDNS + } + } + + return nameservers +} + +func getDHCPNameservers(iface string) ([]string, error) { + // Run the ipconfig command for the given interface. + cmd := exec.Command("ipconfig", "getpacket", iface) + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("error running ipconfig: %v", err) + } + + // Look for a line like: + // domain_name_servers = 192.168.1.1 8.8.8.8; + re := regexp.MustCompile(`domain_name_servers\s*=\s*(.*);`) + matches := re.FindStringSubmatch(string(output)) + if len(matches) < 2 { + return nil, fmt.Errorf("no DHCP nameservers found") + } + + // Split the nameservers by whitespace. + nameservers := strings.Fields(matches[1]) + return nameservers, nil +} + +func getAllDHCPNameservers() []string { + interfaces, err := net.Interfaces() + if err != nil { + return nil + } + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var allNameservers []string + seen := make(map[string]bool) + + for _, iface := range interfaces { + // Skip interfaces that are: + // - down + // - loopback + // - not physical (virtual) + // - point-to-point (like VPN interfaces) + // - without MAC address (non-physical) + if iface.Flags&net.FlagUp == 0 || + iface.Flags&net.FlagLoopback != 0 || + iface.Flags&net.FlagPointToPoint != 0 || + (iface.Flags&net.FlagBroadcast == 0 && + iface.Flags&net.FlagMulticast == 0) || + len(iface.HardwareAddr) == 0 || + strings.HasPrefix(iface.Name, "utun") || + strings.HasPrefix(iface.Name, "llw") || + strings.HasPrefix(iface.Name, "awdl") { + continue + } + + // Verify it's a valid MAC address (should be 6 bytes for IEEE 802 MAC-48) + if len(iface.HardwareAddr) != 6 { + continue + } + + nameservers, err := getDHCPNameservers(iface.Name) + if err != nil { + continue + } + + // Add unique nameservers to the result, skipping local IPs + for _, ns := range nameservers { + if ip := net.ParseIP(ns); ip != nil { + // skip loopback and local IPs + isLocal := false + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + if ip.String() == v.String() { + isLocal = true + break + } + } + if !isLocal && !seen[ns] { + seen[ns] = true + allNameservers = append(allNameservers, ns) + } + } + } + } + + return allNameservers +}