diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 8396c19..3dae547 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -43,6 +43,20 @@ func initLogCmd() *cobra.Command { checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") @@ -82,6 +96,20 @@ func initLogCmd() *cobra.Command { checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") @@ -765,6 +793,20 @@ func initReloadCmd(restartCmd *cobra.Command) *cobra.Command { Short: "Reload the ctrld service", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") @@ -1045,6 +1087,20 @@ func initClientsCmd() *cobra.Command { checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 18ac373..646bafb 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -585,10 +585,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { continue } if p.um.isDown(upstreams[n]) { - logger. - Bool("is_os_resolver", upstreams[n] == upstreamOS) - ctrld.Log(ctx, logger, "Upstream is down") - continue + // never skip the OS resolver, since we usually query this resolver when we + // have no other upstreams to query + if upstreams[n] != upstreamOS { + logger. + Bool("is_os_resolver", upstreams[n] == upstreamOS) + ctrld.Log(ctx, logger, "Upstream is down") + continue + } } answer := resolve(n, upstreamConfig, req.msg) if answer == nil { @@ -1231,67 +1235,43 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M func (p *prog) reinitializeOSResolver(networkChange bool) { // Cancel any existing operations p.resetCtxMu.Lock() - if p.resetCancel != nil { - p.resetCancel() - } - - // Create new context for this operation - ctx, cancel := context.WithCancel(context.Background()) - p.resetCtx = ctx - p.resetCancel = cancel - p.resetCtxMu.Unlock() - - // Ensure cleanup - defer cancel() + defer p.resetCtxMu.Unlock() p.leakingQueryReset.Store(true) defer p.leakingQueryReset.Store(false) - defer func() { - // start leaking queries immediately - if networkChange { - // set all upstreams to failed and provide to performLeakingQuery - failedUpstreams := make(map[string]*ctrld.UpstreamConfig) - for _, upstream := range p.cfg.Upstream { - failedUpstreams[upstream.Name] = upstream - } - go p.performLeakingQuery(failedUpstreams, "all") + mainLog.Load().Debug().Msg("attempting to reset DNS") + p.resetDNS() + mainLog.Load().Debug().Msg("DNS reset completed") + + mainLog.Load().Debug().Msg("initializing OS resolver") + ns := ctrld.InitializeOsResolver() + mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) + + // start leaking queries immediately// start leaking queries immediately + if networkChange { + // set all upstreams to failed and provide to performLeakingQuery + failedUpstreams := make(map[string]*ctrld.UpstreamConfig) + for _, upstream := range p.cfg.Upstream { + failedUpstreams[upstream.Name] = upstream } + go p.performLeakingQuery(failedUpstreams, "all") + if err := FlushDNSCache(); err != nil { mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache") } - }() - select { - case <-ctx.Done(): - mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") - return - default: - mainLog.Load().Debug().Msg("attempting to reset DNS") - p.resetDNS() - mainLog.Load().Debug().Msg("DNS reset completed") + if runtime.GOOS == "darwin" { + // delay putting back the ctrld listener to allow for captive portal to trigger + time.Sleep(5 * time.Second) + } } - select { - case <-ctx.Done(): - mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") - return - default: - mainLog.Load().Debug().Msg("initializing OS resolver") - ns := ctrld.InitializeOsResolver() - mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) - } + mainLog.Load().Debug().Msg("setting DNS configuration") + p.setDNS() + mainLog.Load().Debug().Msg("DNS configuration set successfully") + p.logInterfacesState() - select { - case <-ctx.Done(): - mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") - return - default: - mainLog.Load().Debug().Msg("setting DNS configuration") - p.setDNS() - mainLog.Load().Debug().Msg("DNS configuration set successfully") - p.logInterfacesState() - } } // FlushDNSCache flushes the DNS cache on macOS. diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 1bcfe3d..339d984 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -20,9 +20,9 @@ const ( logWriterSmallSize = 1024 * 1024 * 1 // 1 MB logWriterInitialSize = 32 * 1024 // 32 KB logSentInterval = time.Minute - logStartEndMarker = "\n\n=== START_END ===\n\n" + logStartEndMarker = "\n\n=== INIT_END ===\n\n" logLogEndMarker = "\n\n=== LOG_END ===\n\n" - logWarnEndMarker = "\n\n=== WARN END ===\n\n" + logWarnEndMarker = "\n\n=== WARN_END ===\n\n" ) type logViewResponse struct { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 4c9270c..41dc2c4 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -123,9 +123,7 @@ type prog struct { leakingQueryRunning map[string]bool leakingQueryReset atomic.Bool - resetCtx context.Context - resetCancel context.CancelFunc - resetCtxMu sync.Mutex + resetCtxMu sync.Mutex started chan struct{} onStartedDone chan struct{} diff --git a/nameservers_bsd.go b/nameservers_bsd.go index 2beebd0..b835060 100644 --- a/nameservers_bsd.go +++ b/nameservers_bsd.go @@ -1,19 +1,16 @@ -//go:build darwin || dragonfly || freebsd || netbsd || openbsd +//go:build dragonfly || freebsd || netbsd || openbsd package ctrld import ( "net" - "os/exec" - "runtime" - "strings" "syscall" "golang.org/x/net/route" ) func dnsFns() []dnsFn { - return []dnsFn{dnsFromRIB, dnsFromIPConfig} + return []dnsFn{dnsFromRIB} } func dnsFromRIB() []string { @@ -49,18 +46,6 @@ func dnsFromRIB() []string { return dns } -func dnsFromIPConfig() []string { - if runtime.GOOS != "darwin" { - return nil - } - cmd := exec.Command("ipconfig", "getoption", "", "domain_name_server") - out, _ := cmd.Output() - if ip := net.ParseIP(strings.TrimSpace(string(out))); ip != nil { - return []string{ip.String()} - } - return nil -} - func toNetIP(addr route.Addr) net.IP { switch t := addr.(type) { case *route.Inet4Addr: diff --git a/nameservers_darwin.go b/nameservers_darwin.go new file mode 100644 index 0000000..bec6ce4 --- /dev/null +++ b/nameservers_darwin.go @@ -0,0 +1,243 @@ +//go:build darwin + +package ctrld + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "os/exec" + "regexp" + "slices" + "strings" + "time" + + "github.com/rs/zerolog" + "tailscale.com/net/netmon" +) + +func dnsFns() []dnsFn { + return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers} +} + +// dnsFromResolvConf reads nameservers from /etc/resolv.conf +func dnsFromResolvConf() []string { + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + + const ( + maxRetries = 10 + retryInterval = 100 * time.Millisecond + ) + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var dns []string + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(retryInterval) + } + + file, err := os.Open("/etc/resolv.conf") + if err != nil { + Log(context.Background(), logger.Error(), "failed to open /etc/resolv.conf (attempt %d/%d)", attempt+1, maxRetries) + continue + } + defer file.Close() + + var localDNS []string + seen := make(map[string]bool) + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) < 2 || fields[0] != "nameserver" { + continue + } + if ip := net.ParseIP(fields[1]); ip != nil { + // skip loopback IPs + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + if ip.String() == ipStr { + continue + } + } + if !seen[ip.String()] { + seen[ip.String()] = true + localDNS = append(localDNS, ip.String()) + } + } + } + + if err := scanner.Err(); err != nil { + Log(context.Background(), logger.Error(), "error reading /etc/resolv.conf (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + // If we successfully read the file and found nameservers, return them + if len(localDNS) > 0 { + return localDNS + } + } + + return dns +} + +func getDNSFromScutil() []string { + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + + const ( + maxRetries = 10 + retryInterval = 100 * time.Millisecond + ) + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var nameservers []string + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(retryInterval) + } + + cmd := exec.Command("scutil", "--dns") + output, err := cmd.Output() + if err != nil { + Log(context.Background(), logger.Error(), "failed to execute scutil --dns (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + var localDNS []string + seen := make(map[string]bool) + + scanner := bufio.NewScanner(bytes.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "nameserver[") { + parts := strings.Split(line, ":") + if len(parts) == 2 { + ns := strings.TrimSpace(parts[1]) + if ip := net.ParseIP(ns); ip != nil { + // skip loopback IPs + isLocal := false + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + if ip.String() == ipStr { + isLocal = true + break + } + } + if !isLocal && !seen[ip.String()] { + seen[ip.String()] = true + localDNS = append(localDNS, ip.String()) + } + } + } + } + } + + if err := scanner.Err(); err != nil { + Log(context.Background(), logger.Error(), "error scanning scutil output (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + // If we successfully read the output and found nameservers, return them + if len(localDNS) > 0 { + return localDNS + } + } + + return nameservers +} + +func getDHCPNameservers(iface string) ([]string, error) { + // Run the ipconfig command for the given interface. + cmd := exec.Command("ipconfig", "getpacket", iface) + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("error running ipconfig: %v", err) + } + + // Look for a line like: + // domain_name_servers = 192.168.1.1 8.8.8.8; + re := regexp.MustCompile(`domain_name_servers\s*=\s*(.*);`) + matches := re.FindStringSubmatch(string(output)) + if len(matches) < 2 { + return nil, fmt.Errorf("no DHCP nameservers found") + } + + // Split the nameservers by whitespace. + nameservers := strings.Fields(matches[1]) + return nameservers, nil +} + +func getAllDHCPNameservers() []string { + interfaces, err := net.Interfaces() + if err != nil { + return nil + } + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var allNameservers []string + seen := make(map[string]bool) + + for _, iface := range interfaces { + // Skip interfaces that are: + // - down + // - loopback + // - not physical (virtual) + // - point-to-point (like VPN interfaces) + // - without MAC address (non-physical) + if iface.Flags&net.FlagUp == 0 || + iface.Flags&net.FlagLoopback != 0 || + iface.Flags&net.FlagPointToPoint != 0 || + (iface.Flags&net.FlagBroadcast == 0 && + iface.Flags&net.FlagMulticast == 0) || + len(iface.HardwareAddr) == 0 || + strings.HasPrefix(iface.Name, "utun") || + strings.HasPrefix(iface.Name, "llw") || + strings.HasPrefix(iface.Name, "awdl") { + continue + } + + // Verify it's a valid MAC address (should be 6 bytes for IEEE 802 MAC-48) + if len(iface.HardwareAddr) != 6 { + continue + } + + nameservers, err := getDHCPNameservers(iface.Name) + if err != nil { + continue + } + + // Add unique nameservers to the result, skipping local IPs + for _, ns := range nameservers { + if ip := net.ParseIP(ns); ip != nil { + // skip loopback and local IPs + isLocal := false + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + if ip.String() == v.String() { + isLocal = true + break + } + } + if !isLocal && !seen[ns] { + seen[ns] = true + allNameservers = append(allNameservers, ns) + } + } + } + } + + return allNameservers +}