diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 986e069..be8b5af 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -418,7 +418,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if err := p.router.Cleanup(); err != nil { mainLog.Load().Error().Err(err).Msg("could not cleanup router") } - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) }) } } @@ -1030,7 +1031,8 @@ func uninstall(p *prog, s service.Service) { mainLog.Load().Warn().Err(err).Msg("post uninstallation failed, please check system/service log for details error") return } - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) // if present restore the original DNS settings if netIface, err := netInterface(p.runningIface); err == nil { @@ -1779,12 +1781,14 @@ func resetDnsNoLog(p *prog) { if verbose < 3 { lvl := zerolog.GlobalLevel() zerolog.SetGlobalLevel(zerolog.Disabled) - p.resetDNS() + // This is startup so interface settings may have changed + p.resetDNS(true, true) zerolog.SetGlobalLevel(lvl) return } // For debugging purpose, still emit log. - p.resetDNS() + // This is startup so interface settings may have changed + p.resetDNS(true, true) } // resetDnsTask returns a task which perform reset DNS operation. @@ -1806,10 +1810,10 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceRe } p.runningIface = iface if isCtrldInstalled { - mainLog.Load().Debug().Msg("restore system DNS settings") if status, _ := s.Status(); status == service.StatusRunning { mainLog.Load().Fatal().Msg("reset DNS while ctrld still running is not safe") } + mainLog.Load().Debug().Msg("Start resetDNS") resetDnsNoLog(p) } iface = oldIface @@ -1868,8 +1872,8 @@ func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool { logger.Warn().Err(err).Msg("failed to create new service") return false } - - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) tasks := []task{{s.Uninstall, true, "Uninstall"}} if doTasks(tasks) { diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 49dfb8f..2bfe71e 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -355,7 +355,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c }, false, "Save current DNS"}, {func() error { return ConfigureWindowsServiceFailureActions(ctrldServiceName) - }, false, "Configure Windows service failure actions"}, + }, false, "Configure service failure actions"}, {s.Start, true, "Start"}, {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, } @@ -608,7 +608,8 @@ func initStopCmd() *cobra.Command { } if doTasks([]task{{s.Stop, true, "Stop"}}) { p.router.Cleanup() - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) // restore DNS settings if netIface, err := netInterface(p.runningIface); err == nil { @@ -714,7 +715,8 @@ func initRestartCmd() *cobra.Command { {s.Stop, true, "Stop"}, {func() error { p.router.Cleanup() - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) return nil }, false, "Cleanup"}, {func() error { @@ -994,13 +996,13 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if os.IsNotExist(err) { continue } - mainLog.Load().Warn().Err(err).Msg("failed to remove file") + mainLog.Load().Warn().Err(err).Msgf("failed to remove file: %s", file) } else { mainLog.Load().Debug().Msgf("file removed: %s", file) } } if err := selfDeleteExe(); err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to remove file") + mainLog.Load().Warn().Err(err).Msg("failed to delete ctrld binary") } else { if !supportedSelfDelete { mainLog.Load().Debug().Msgf("file removed: %s", bin) @@ -1266,7 +1268,8 @@ func initUpgradeCmd() *cobra.Command { {s.Stop, true, "Stop"}, {func() error { p.router.Cleanup() - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) return nil }, false, "Cleanup"}, {func() error { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 799dc58..694131d 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -20,7 +20,6 @@ import ( "golang.org/x/sync/errgroup" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" - "tailscale.com/types/logger" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/controld" @@ -1179,7 +1178,10 @@ func FlushDNSCache() error { // monitorNetworkChanges starts monitoring for network interface changes func (p *prog) monitorNetworkChanges(ctx context.Context) error { - mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) + mon, err := netmon.New(func(format string, args ...any) { + // Always fetch the latest logger (and inject the prefix) + mainLog.Load().Printf("netmon: "+format, args...) + }) if err != nil { return fmt.Errorf("creating network monitor: %w", err) } @@ -1457,7 +1459,10 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // Immediately remove our DNS settings from the interface. // set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface p.recoveryRunning.Store(true) - p.resetDNS() + // we do not want to restore any static DNS settings + // we must try to get the DHCP values, any static DNS settings + // will be appended to nameservers from the saved interface values + p.resetDNS(false, false) // For an OS failure, reinitialize OS resolver nameservers immediately. if reason == RecoveryReasonOSFailure { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index bafc8a4..df004ab 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -268,7 +268,7 @@ func (p *prog) preRun() { if runtime.GOOS == "darwin" { p.onStopped = append(p.onStopped, func() { if !service.Interactive() { - p.resetDNS() + p.resetDNS(false, true) } }) } @@ -276,7 +276,7 @@ func (p *prog) preRun() { func (p *prog) postRun() { if !service.Interactive() { - p.resetDNS() + p.resetDNS(false, true) ns := ctrld.InitializeOsResolver(false) mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) p.setDNS() @@ -809,7 +809,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces } } -func (p *prog) resetDNS() { +// resetDNS performs a DNS reset on the running interface. +// The parameter isStart indicates whether this is being called as part of a start (or restart) +// command. When true, we check if the current static DNS configuration already differs from the +// service listener (127.0.0.1). If so, we assume that an admin has manually changed the interface's +// static DNS settings and we do not override them using the potentially out-of-date saved file. +// Otherwise, we restore the saved configuration (if any) or reset to DHCP. +func (p *prog) resetDNS(isStart bool, restoreStatic bool) { if p.runningIface == "" { mainLog.Load().Debug().Msg("no running interface, skipping resetDNS") return @@ -822,17 +828,47 @@ func (p *prog) resetDNS() { logger.Error().Err(err).Msg("could not get interface") return } - if err := restoreNetworkManager(); err != nil { logger.Error().Err(err).Msg("could not restore NetworkManager") return } - logger.Debug().Msg("Restoring DNS for interface") - if err := resetDNS(netIface); err != nil { - logger.Error().Err(err).Msgf("could not reset DNS") - return + + // If starting, check the current static DNS configuration. + if isStart { + current, err := currentStaticDNS(netIface) + if err != nil { + logger.Warn().Err(err).Msg("unable to obtain current static DNS configuration; proceeding to restore saved config") + } else if len(current) > 0 { + // If any static DNS value is not our own listener, assume an admin override. + hasManualConfig := false + for _, ns := range current { + if ns != "127.0.0.1" { + hasManualConfig = true + break + } + } + if hasManualConfig { + logger.Debug().Msgf("Detected manual DNS configuration on interface %q: %v; not overriding with saved configuration", netIface.Name, current) + return + } + } + } + + // Default logic: if there is a saved static DNS configuration, restore it. + saved := savedStaticNameservers(netIface) + if len(saved) > 0 && restoreStatic { + logger.Debug().Msgf("Restoring interface %q from saved static config: %v", netIface.Name, saved) + if err := setDNS(netIface, saved); err != nil { + logger.Error().Err(err).Msgf("failed to restore static DNS config on interface %q", netIface.Name) + return + } + } else { + logger.Debug().Msgf("No saved static DNS config for interface %q; resetting to DHCP", netIface.Name) + if err := resetDNS(netIface); err != nil { + logger.Error().Err(err).Msgf("failed to reset DNS to DHCP on interface %q", netIface.Name) + return + } } - logger.Debug().Msg("Restoring DNS successfully") if allIfaces { withEachPhysicalInterfaces(netIface.Name, "reset DNS", resetDnsIgnoreUnusableInterface) } diff --git a/nameservers_darwin.go b/nameservers_darwin.go index d536d78..3aef77d 100644 --- a/nameservers_darwin.go +++ b/nameservers_darwin.go @@ -155,6 +155,8 @@ func getDHCPNameservers(iface string) ([]string, error) { } func getAllDHCPNameservers() []string { + logger := *ProxyLogger.Load() + interfaces, err := net.Interfaces() if err != nil { return nil @@ -213,5 +215,32 @@ func getAllDHCPNameservers() []string { } } + // if we have static DNS servers saved for the current default route, we should add them to the list + drIfaceName, err := netmon.DefaultRouteInterface() + Log(context.Background(), logger.Debug(), "checking for static DNS servers for default route interface: %s", drIfaceName) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get default route interface: %v", err) + } else { + drIface, err := net.InterfaceByName(drIfaceName) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get interface by name %s: %v", drIfaceName, err) + } else if drIface != nil { + if _, err := patchNetIfaceName(drIface); err != nil { + Log(context.Background(), logger.Debug(), + "Failed to patch interface name %s: %v", drIfaceName, err) + } + staticNs, file := SavedStaticNameservers(drIface) + Log(context.Background(), logger.Debug(), + "static dns servers from %s: %v", file, staticNs) + if len(staticNs) > 0 { + Log(context.Background(), logger.Debug(), + "Adding static DNS servers from %s: %v", drIface.Name, staticNs) + allNameservers = append(allNameservers, staticNs...) + } + } + } + return allNameservers } diff --git a/nameservers_windows.go b/nameservers_windows.go index 54fb8b6..36d67fa 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -20,6 +20,7 @@ import ( "github.com/rs/zerolog" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "tailscale.com/net/netmon" ) const ( @@ -303,6 +304,28 @@ func getDNSServers(ctx context.Context) ([]string, error) { } } + // if we have static DNS servers saved for the current default route, we should add them to the list + drIfaceName, err := netmon.DefaultRouteInterface() + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get default route interface: %v", err) + } else { + drIface, err := net.InterfaceByName(drIfaceName) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get interface by name %s: %v", drIfaceName, err) + } else { + staticNs, file := SavedStaticNameservers(drIface) + Log(context.Background(), logger.Debug(), + "static dns servers from %s: %v", file, staticNs) + if len(staticNs) > 0 { + Log(context.Background(), logger.Debug(), + "Adding static DNS servers from %s: %v", drIfaceName, staticNs) + ns = append(ns, staticNs...) + } + } + } + if len(ns) == 0 { return nil, fmt.Errorf("no valid DNS servers found") } diff --git a/resolver.go b/resolver.go index e5abef2..df823fc 100644 --- a/resolver.go +++ b/resolver.go @@ -549,6 +549,11 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) // - Gateway IP address (depends on OS). // - Input servers. func NewBootstrapResolver(servers ...string) Resolver { + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + Log(context.Background(), logger.Debug(), "NewBootstrapResolver called with servers: %v", servers) nss := defaultNameservers() nss = append([]string{controldPublicDnsWithPort}, nss...) for _, ns := range servers { @@ -565,6 +570,13 @@ func NewBootstrapResolver(servers ...string) Resolver { // // This is useful for doing PTR lookup in LAN network. func NewPrivateResolver() Resolver { + + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + Log(context.Background(), logger.Debug(), "NewPrivateResolver called") + nss := defaultNameservers() resolveConfNss := nameserversFromResolvconf() localRfc1918Addrs := Rfc1918Addresses() @@ -609,6 +621,11 @@ func NewResolverWithNameserver(nameservers []string) Resolver { // newResolverWithNameserver returns an OS resolver from given nameservers list. // The caller must ensure each server in list is formed "ip:53". func newResolverWithNameserver(nameservers []string) *osResolver { + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers) r := &osResolver{} var publicNss []string var lanNss []string diff --git a/staticdns.go b/staticdns.go new file mode 100644 index 0000000..5cb1697 --- /dev/null +++ b/staticdns.go @@ -0,0 +1,118 @@ +package ctrld + +import ( + "bufio" + "bytes" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" +) + +var homedir string + +// absHomeDir returns the absolute path to given filename using home directory as root dir. +func absHomeDir(filename string) string { + if homedir != "" { + return filepath.Join(homedir, filename) + } + dir, err := userHomeDir() + if err != nil { + return filename + } + return filepath.Join(dir, filename) +} + +func dirWritable(dir string) (bool, error) { + f, err := os.CreateTemp(dir, "") + if err != nil { + return false, err + } + defer os.Remove(f.Name()) + return true, f.Close() +} + +func userHomeDir() (string, error) { + // viper will expand for us. + if runtime.GOOS == "windows" { + // If we're on windows, use the install path for this. + exePath, err := os.Executable() + if err != nil { + return "", err + } + + return filepath.Dir(exePath), nil + } + dir := "/etc/controld" + if err := os.MkdirAll(dir, 0750); err != nil { + return os.UserHomeDir() // fallback to user home directory + } + if ok, _ := dirWritable(dir); !ok { + return os.UserHomeDir() + } + return dir, nil +} + +// SavedStaticDnsSettingsFilePath returns the file path where the static DNS settings +// for the provided interface are saved. +func SavedStaticDnsSettingsFilePath(iface *net.Interface) string { + // The file is stored in the user home directory under a hidden file. + return absHomeDir(".dns_" + iface.Name) +} + +// SavedStaticNameservers returns the stored static nameservers for the given interface. +func SavedStaticNameservers(iface *net.Interface) ([]string, string) { + file := SavedStaticDnsSettingsFilePath(iface) + data, err := os.ReadFile(file) + if err != nil || len(data) == 0 { + return nil, file + } + saveValues := strings.Split(string(data), ",") + var ns []string + for _, v := range saveValues { + // Skip any IP that is loopback + if ip := net.ParseIP(v); ip != nil && ip.IsLoopback() { + continue + } + ns = append(ns, v) + } + return ns, file +} + +func patchNetIfaceName(iface *net.Interface) (bool, error) { + b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output() + if err != nil { + return false, err + } + + patched := false + if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" { + patched = true + iface.Name = name + } + return patched, nil +} + +func networkServiceName(ifaceName string, r io.Reader) string { + scanner := bufio.NewScanner(r) + prevLine := "" + for scanner.Scan() { + line := scanner.Text() + if strings.Contains(line, "*") { + // Network services is disabled. + continue + } + if !strings.Contains(line, "Device: "+ifaceName) { + prevLine = line + continue + } + parts := strings.SplitN(prevLine, " ", 2) + if len(parts) == 2 { + return strings.TrimSpace(parts[1]) + } + } + return "" +}