diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 6e7ecbe..008c34a 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1031,14 +1031,22 @@ func uninstall(p *prog, s service.Service) { // restore static DNS settings or DHCP p.resetDNS(false, true) - // if present restore the original DNS settings - if netIface, err := netInterface(p.runningIface); err == nil { - if err := restoreDNS(netIface); err != nil { - mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface") - } else { - mainLog.Load().Debug().Msg("Restored DNS on interface successfully") + // Iterate over all physical interfaces and restore DNS if a saved static config exists. + withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error { + file := savedStaticDnsSettingsFilePath(i) + if _, err := os.Stat(file); err == nil { + if err := restoreDNS(i); err != nil { + mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) + } else { + mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) + err = os.Remove(file) + if err != nil { + mainLog.Load().Debug().Err(err).Msgf("Could not remove saved static DNS file for interface %s", i.Name) + } + } } - } + return nil + }) if router.Name() != "" { mainLog.Load().Debug().Msg("Router cleanup") diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 2bfe71e..d8eabeb 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -242,6 +242,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c os.Exit(deactivationPinInvalidExitCode) } currentIface = runningIface(s) + mainLog.Load().Debug().Msgf("current interface on start: %s", currentIface.Name) } ctx, cancel := context.WithCancel(context.Background()) @@ -339,13 +340,17 @@ NOTE: running "ctrld start" without any arguments will start already installed c mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) } + // if already running, dont restart + if isCtrldRunning { + mainLog.Load().Notice().Msg("service is already running") + return + } + initInteractiveLogging() tasks := []task{ - {s.Stop, false, "Stop"}, - resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error { if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { return err } @@ -424,10 +429,10 @@ NOTE: running "ctrld start" without any arguments will start already installed c {s.Stop, false, "Stop"}, {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"}, {func() error { return ensureUninstall(s) }, false, "Ensure uninstall"}, - resetDnsTask(p, s, isCtrldInstalled, currentIface), + //resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error { if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { return err } @@ -611,14 +616,18 @@ func initStopCmd() *cobra.Command { // restore static DNS settings or DHCP p.resetDNS(false, true) - // restore DNS settings - if netIface, err := netInterface(p.runningIface); err == nil { - if err := restoreDNS(netIface); err != nil { - mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface") - } else { - mainLog.Load().Debug().Msg("Restored DNS on interface successfully") + // Iterate over all physical interfaces and restore static DNS if a saved static config exists. + withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error { + file := savedStaticDnsSettingsFilePath(i) + if _, err := os.Stat(file); err == nil { + if err := restoreDNS(i); err != nil { + mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) + } else { + mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) + } } - } + return nil + }) if router.WaitProcessExited() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) @@ -1046,9 +1055,16 @@ func initInterfacesCmd() *cobra.Command { Short: "List network interfaces of the host", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + withEachPhysicalInterfaces("", "Interface list", func(i *net.Interface) error { fmt.Printf("Index : %d\n", i.Index) fmt.Printf("Name : %s\n", i.Name) + var status string + if i.Flags&net.FlagUp != 0 { + status = "Up" + } else { + status = "Down" + } + fmt.Printf("Status: %s\n", status) addrs, _ := i.Addrs() for i, ipaddr := range addrs { if i == 0 { diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 990cc57..4866267 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -147,15 +147,32 @@ func restoreDNS(iface *net.Interface) (err error) { } } - for _, ns := range [][]string{v4ns, v6ns} { - if len(ns) == 0 { - continue - } - mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name) - err = setDNS(iface, ns) + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return fmt.Errorf("restoreDNS: %w", err) + } - if err != nil { - return err + if len(v4ns) > 0 { + mainLog.Load().Debug().Msgf("restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns) + if err := setDNS(iface, v4ns); err != nil { + return fmt.Errorf("restoreDNS (IPv4): %w", err) + } + } else { + mainLog.Load().Debug().Msgf("restoring IPv4 DHCP for interface %q", iface.Name) + if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil { + return fmt.Errorf("restoreDNS (IPv4 clear): %w", err) + } + } + + if len(v6ns) > 0 { + mainLog.Load().Debug().Msgf("restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns) + if err := setDNS(iface, v6ns); err != nil { + return fmt.Errorf("restoreDNS (IPv6): %w", err) + } + } else { + mainLog.Load().Debug().Msgf("restoring IPv6 DHCP for interface %q", iface.Name) + if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil { + return fmt.Errorf("restoreDNS (IPv6 clear): %w", err) } } } @@ -180,43 +197,65 @@ func currentDNS(iface *net.Interface) []string { return ns } -// currentStaticDNS returns the current static DNS settings of given interface. +// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys +// like "NameServer" and "ProfileNameServer". func currentStaticDNS(iface *net.Interface) ([]string, error) { luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err) + return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err) } guid, err := luid.GUID() if err != nil { - return nil, fmt.Errorf("luid.GUID: %w", err) + return nil, fmt.Errorf("fallback luid.GUID: %w", err) } + var ns []string - for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} { - found := false + keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} + for _, path := range keyPaths { interfaceKeyPath := path + guid.String() k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE) if err != nil { - return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err) + mainLog.Load().Debug().Err(err).Msgf("failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name) + continue } - for _, key := range []string{"NameServer", "ProfileNameServer"} { - if found { - continue - } - value, _, err := k.GetStringValue(key) - if err != nil && !errors.Is(err, registry.ErrNotExist) { - return nil, fmt.Errorf("%s: %w", key, err) - } - if len(value) > 0 { - found = true - for _, e := range strings.Split(value, ",") { - ns = append(ns, strings.TrimRight(e, "\x00")) + func() { + defer k.Close() + for _, keyName := range []string{"NameServer", "ProfileNameServer"} { + value, _, err := k.GetStringValue(keyName) + if err != nil && !errors.Is(err, registry.ErrNotExist) { + mainLog.Load().Debug().Err(err).Msgf("error reading %s registry key", keyName) + continue + } + if len(value) > 0 { + mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value) + parsed := parseDNSServers(value) + ns = append(ns, parsed...) } } - } + }() + } + if len(ns) == 0 { + mainLog.Load().Debug().Msgf("no static DNS values found for interface %q", iface.Name) } return ns, nil } +// parseDNSServers splits a DNS server string that may be comma- or space-separated, +// and trims any extraneous whitespace or null characters. +func parseDNSServers(val string) []string { + fields := strings.FieldsFunc(val, func(r rune) bool { + return r == ' ' || r == ',' + }) + var servers []string + for _, f := range fields { + trimmed := strings.TrimSpace(f) + if len(trimmed) > 0 { + servers = append(servers, trimmed) + } + } + return servers +} + // addDnsServerForwarders adds given nameservers to DNS server forwarders list, // and also removing old forwarders if provided. func addDnsServerForwarders(nameservers, old []string) error { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 60f138b..25547f3 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -276,7 +276,7 @@ func (p *prog) preRun() { func (p *prog) postRun() { if !service.Interactive() { - p.resetDNS(false, true) + p.resetDNS(false, false) ns := ctrld.InitializeOsResolver(false) mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) p.setDNS() @@ -788,6 +788,24 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces } if dnsChanged(iface, ns) { logger.Debug().Msg("DNS settings were changed, re-applying settings") + // Check if the interface already has static DNS servers configured. + // currentStaticDNS is an OS-dependent helper that returns the current static DNS. + staticDNS, err := currentStaticDNS(iface) + if err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name) + } else if len(staticDNS) > 0 { + //filter out loopback addresses + staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool { + return net.ParseIP(s).IsLoopback() + }) + // if we have a static config and no saved IPs already, save them + if len(staticDNS) > 0 && len(savedStaticNameservers(iface)) == 0 { + // Save these static DNS values so that they can be restored later. + if err := saveCurrentStaticDNS(iface); err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name) + } + } + } if err := setDNS(iface, ns); err != nil { mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings") } @@ -795,6 +813,26 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces if allIfaces { withEachPhysicalInterfaces(iface.Name, "", func(i *net.Interface) error { if dnsChanged(i, ns) { + + // Check if the interface already has static DNS servers configured. + // currentStaticDNS is an OS-dependent helper that returns the current static DNS. + staticDNS, err := currentStaticDNS(i) + if err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name) + } else if len(staticDNS) > 0 { + //filter out loopback addresses + staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool { + return net.ParseIP(s).IsLoopback() + }) + // if we have a static config and no saved IPs already, save them + if len(staticDNS) > 0 && len(savedStaticNameservers(i)) == 0 { + // Save these static DNS values so that they can be restored later. + if err := saveCurrentStaticDNS(i); err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name) + } + } + } + if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil { mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings") } else { @@ -841,7 +879,7 @@ func (p *prog) resetDNS(isStart bool, restoreStatic bool) { // If any static DNS value is not our own listener, assume an admin override. hasManualConfig := false for _, ns := range current { - if ns != "127.0.0.1" { + if ns != "127.0.0.1" && ns != "::1" { hasManualConfig = true break } @@ -1221,7 +1259,7 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net. // TODO: investigate whether we should report this error? if err := f(netIface); err == nil { if context != "" { - mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name) + mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", context, i.Name) } } else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) { mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name) @@ -1250,13 +1288,28 @@ func saveCurrentStaticDNS(iface *net.Interface) error { return errSaveCurrentStaticDNSNotSupported } file := savedStaticDnsSettingsFilePath(iface) - ns, _ := currentStaticDNS(iface) + ns, err := currentStaticDNS(iface) + if err != nil { + mainLog.Load().Warn().Err(err).Msgf("could not get current static DNS settings for %q", iface.Name) + return err + } if len(ns) == 0 { + mainLog.Load().Debug().Msgf("no static DNS settings for %q, removing old static DNS settings file", iface.Name) _ = os.Remove(file) // removing old static DNS settings return nil } + //filter out loopback addresses + ns = slices.DeleteFunc(ns, func(s string) bool { + return net.ParseIP(s).IsLoopback() + }) + //if we now have no static DNS settings and the file already exists + // return and do not save the file + if len(ns) == 0 { + mainLog.Load().Debug().Msgf("loopback on %q, skipping saving static DNS settings", iface.Name) + return nil + } if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) { - mainLog.Load().Warn().Err(err).Msg("could not remove old static DNS settings file") + mainLog.Load().Warn().Err(err).Msgf("could not remove old static DNS settings file: %s", file) } nss := strings.Join(ns, ",") mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, nss) diff --git a/cmd/ctrld/main.go b/cmd/ctrld/main.go index af204ad..1f761e6 100644 --- a/cmd/ctrld/main.go +++ b/cmd/ctrld/main.go @@ -1,7 +1,13 @@ package main -import "github.com/Control-D-Inc/ctrld/cmd/cli" +import ( + "os" + + "github.com/Control-D-Inc/ctrld/cmd/cli" +) func main() { cli.Main() + // make sure we exit with 0 if there are no errors + os.Exit(0) } diff --git a/resolver.go b/resolver.go index 3c6a0a7..677738b 100644 --- a/resolver.go +++ b/resolver.go @@ -199,9 +199,12 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { case ResolverTypeDOQ: return &doqResolver{uc: uc}, nil case ResolverTypeOS: + resolverMutex.Lock() if or == nil { + ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver") or = newResolverWithNameserver(defaultNameservers()) } + resolverMutex.Unlock() return or, nil case ResolverTypeLegacy: return &legacyResolver{uc: uc}, nil @@ -473,9 +476,13 @@ func LookupIP(domain string) []string { } func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { + resolverMutex.Lock() if or == nil { + ProxyLogger.Load().Debug().Msgf("Initialize OS resolver in lookupIP") or = newResolverWithNameserver(defaultNameservers()) } + resolverMutex.Unlock() + nss := *or.lanServers.Load() nss = append(nss, *or.publicServers.Load()...) if withBootstrapDNS {