From 4d810261a4bb26322e0ec84e071a05670bc9527a Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 9 Feb 2024 12:35:43 +0700 Subject: [PATCH] cmd/cli: only save/restore static DNS The save/restore DNS functionality always perform its job, even though the DNS is not static, aka set by DHCP. That may lead to confusion to users. Since DHCP settings was changed to static settings, even though the namesers set are the same. To fix this, ctrld should save/restore only there's actual static DNS set. For DHCP, thing should work as-is like we are doing. --- cmd/cli/cli.go | 8 ++--- cmd/cli/os_darwin.go | 24 +++++++++++++- cmd/cli/os_freebsd.go | 5 +++ cmd/cli/os_linux.go | 5 +++ cmd/cli/os_windows.go | 76 ++++++++++++++++++++++++++++++++++++++----- cmd/cli/prog.go | 30 ++++++++++------- 6 files changed, 122 insertions(+), 26 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index de2e93c..e97b53c 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -270,10 +270,7 @@ func initCLI() { {func() error { return ensureUninstall(s) }, false}, {func() error { // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error { - saveCurrentDNS(i) - return nil - }) + withEachPhysicalInterfaces("", "save DNS settings", saveCurrentStaticDNS) return nil }, false}, {s.Install, false}, @@ -2197,7 +2194,8 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M // powershell runs the given powershell command. func powershell(cmd string) ([]byte, error) { - return exec.Command("powershell", "-Command", cmd).CombinedOutput() + out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput() + return bytes.TrimSpace(out), err } // windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. diff --git a/cmd/cli/os_darwin.go b/cmd/cli/os_darwin.go index 1c61efd..aa39094 100644 --- a/cmd/cli/os_darwin.go +++ b/cmd/cli/os_darwin.go @@ -1,6 +1,8 @@ package cli import ( + "bufio" + "bytes" "net" "os/exec" @@ -44,7 +46,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { // TODO(cuonglm): use system API func resetDNS(iface *net.Interface) error { - if ns := savedNameservers(iface); len(ns) > 0 { + if ns := savedStaticNameservers(iface); len(ns) > 0 { if err := setDNS(iface, ns); err == nil { return nil } @@ -62,3 +64,23 @@ func resetDNS(iface *net.Interface) error { func currentDNS(_ *net.Interface) []string { return resolvconffile.NameServers("") } + +// currentStaticDNS returns the current static DNS settings of given interface. +func currentStaticDNS(iface *net.Interface) []string { + cmd := "networksetup" + args := []string{"-getdnsservers", iface.Name} + out, err := exec.Command(cmd, args...).Output() + if err != nil { + mainLog.Load().Error().Err(err).Msg("could not get current static DNS") + return nil + } + scanner := bufio.NewScanner(bytes.NewReader(out)) + var ns []string + for scanner.Scan() { + line := scanner.Text() + if ip := net.ParseIP(line); ip != nil { + ns = append(ns, ip.String()) + } + } + return ns +} diff --git a/cmd/cli/os_freebsd.go b/cmd/cli/os_freebsd.go index a6d6dde..a8de0c6 100644 --- a/cmd/cli/os_freebsd.go +++ b/cmd/cli/os_freebsd.go @@ -66,3 +66,8 @@ func resetDNS(iface *net.Interface) error { func currentDNS(_ *net.Interface) []string { return resolvconffile.NameServers("") } + +// currentStaticDNS returns the current static DNS settings of given interface. +func currentStaticDNS(iface *net.Interface) []string { + return currentDNS(iface) +} diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index 3036d03..c7661f0 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -203,6 +203,11 @@ func currentDNS(iface *net.Interface) []string { return nil } +// currentStaticDNS returns the current static DNS settings of given interface. +func currentStaticDNS(iface *net.Interface) []string { + return currentDNS(iface) +} + func getDNSByResolvectl(iface string) []string { b, err := exec.Command("resolvectl", "dns", "-i", iface).Output() if err != nil { diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 4bcbc1b..bfb8cba 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -15,7 +15,11 @@ import ( ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) -const forwardersFilename = ".forwarders.txt" +const ( + forwardersFilename = ".forwarders.txt" + v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` +) var ( setDNSOnce sync.Once @@ -47,7 +51,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { } }) primaryDNS := nameservers[0] - if err := setPrimaryDNS(iface, primaryDNS); err != nil { + if err := setPrimaryDNS(iface, primaryDNS, true); err != nil { return err } if len(nameservers) > 1 { @@ -76,25 +80,48 @@ func resetDNS(iface *net.Interface) error { } }) - if ns := savedNameservers(iface); len(ns) > 0 { - if err := setDNS(iface, ns); err == nil { - return nil - } - } + // Restoring ipv6 first. if ctrldnet.SupportsIPv6ListenLocal() { if output, err := netsh("interface", "ipv6", "set", "dnsserver", strconv.Itoa(iface.Index), "dhcp"); err != nil { mainLog.Load().Warn().Err(err).Msgf("failed to reset ipv6 DNS: %s", string(output)) } } + // Restoring ipv4 DHCP. output, err := netsh("interface", "ipv4", "set", "dnsserver", strconv.Itoa(iface.Index), "dhcp") if err != nil { mainLog.Load().Error().Err(err).Msgf("failed to reset ipv4 DNS: %s", string(output)) return err } + // If there's static DNS saved, restoring it. + if nss := savedStaticNameservers(iface); len(nss) > 0 { + v4ns := make([]string, 0, 2) + v6ns := make([]string, 0, 2) + for _, ns := range nss { + if ctrldnet.IsIPv6(ns) { + v6ns = append(v6ns, ns) + } else { + v4ns = append(v4ns, ns) + } + } + + for _, ns := range [][]string{v4ns, v6ns} { + if len(ns) == 0 { + continue + } + primaryDNS := ns[0] + if err := setPrimaryDNS(iface, primaryDNS, false); err != nil { + return err + } + if len(ns) > 1 { + secondaryDNS := ns[1] + _ = addSecondaryDNS(iface, secondaryDNS) + } + } + } return nil } -func setPrimaryDNS(iface *net.Interface, dns string) error { +func setPrimaryDNS(iface *net.Interface, dns string, disablev6 bool) error { ipVer := "ipv4" if ctrldnet.IsIPv6(dns) { ipVer = "ipv6" @@ -105,7 +132,7 @@ func setPrimaryDNS(iface *net.Interface, dns string) error { mainLog.Load().Error().Err(err).Msgf("failed to set primary DNS: %s", string(output)) return err } - if ipVer == "ipv4" && ctrldnet.SupportsIPv6ListenLocal() { + if disablev6 && ipVer == "ipv4" && ctrldnet.SupportsIPv6ListenLocal() { // Disable IPv6 DNS, so the query will be fallback to IPv4. _, _ = netsh("interface", "ipv6", "set", "dnsserver", idx, "static", "::1", "primary") } @@ -147,6 +174,37 @@ func currentDNS(iface *net.Interface) []string { return ns } +// currentStaticDNS returns the current static DNS settings of given interface. +func currentStaticDNS(iface *net.Interface) []string { + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + mainLog.Load().Error().Err(err).Msg("could not get interface LUID") + return nil + } + guid, err := luid.GUID() + if err != nil { + mainLog.Load().Error().Err(err).Msg("could not get interface GUID") + return nil + } + var ns []string + for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} { + interfaceKeyPath := path + guid.String() + found := false + for _, key := range []string{"NameServer", "ProfileNameServer"} { + if found { + continue + } + cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key) + out, err := powershell(cmd) + if err == nil && len(out) > 0 { + found = true + ns = append(ns, strings.Split(string(out), ",")...) + } + } + } + return ns +} + // addDnsServerForwarders adds given nameservers to DNS server forwarders list. func addDnsServerForwarders(nameservers []string) error { for _, ns := range nameservers { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 2f2bd0b..3042248 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -725,32 +725,40 @@ func requiredMultiNICsConfig() bool { } } -// saveCurrentDNS saves the current DNS settings for restoring later. +// saveCurrentStaticDNS saves the current static DNS settings for restoring later. // Only works on Windows and Mac. -func saveCurrentDNS(iface *net.Interface) { +func saveCurrentStaticDNS(iface *net.Interface) error { switch runtime.GOOS { case "windows", "darwin": default: - return + return nil } - ns := currentDNS(iface) + file := savedStaticDnsSettingsFilePath(iface) + if err := os.Remove(file); err != nil && !errors.Is(err, os.ErrNotExist) { + mainLog.Load().Warn().Err(err).Msg("could not remove old static DNS settings file") + } + ns := currentStaticDNS(iface) if len(ns) == 0 { - return + return nil } - file := savedDnsSettingsFilePath(iface) + mainLog.Load().Debug().Msgf("DNS settings for %s is static, saving ...", iface.Name) if err := os.WriteFile(file, []byte(strings.Join(ns, ",")), 0600); err != nil { mainLog.Load().Err(err).Msgf("could not save DNS settings for iface: %s", iface.Name) + return err } + return nil } -// savedDnsSettingsFilePath returns the path to saved DNS settings of the given interface. -func savedDnsSettingsFilePath(iface *net.Interface) string { +// savedStaticDnsSettingsFilePath returns the path to saved DNS settings of the given interface. +func savedStaticDnsSettingsFilePath(iface *net.Interface) string { return absHomeDir(".dns_" + iface.Name) } -// savedNameservers returns the static DNS nameservers of the given interface. -func savedNameservers(iface *net.Interface) []string { - file := savedDnsSettingsFilePath(iface) +// savedStaticNameservers returns the static DNS nameservers of the given interface. +// +//lint:ignore U1000 use in os_windows.go and os_darwin.go +func savedStaticNameservers(iface *net.Interface) []string { + file := savedStaticDnsSettingsFilePath(iface) if data, _ := os.ReadFile(file); len(data) > 0 { return strings.Split(string(data), ",") }