diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index de2e93c..e97b53c 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -270,10 +270,7 @@ func initCLI() { {func() error { return ensureUninstall(s) }, false}, {func() error { // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error { - saveCurrentDNS(i) - return nil - }) + withEachPhysicalInterfaces("", "save DNS settings", saveCurrentStaticDNS) return nil }, false}, {s.Install, false}, @@ -2197,7 +2194,8 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M // powershell runs the given powershell command. func powershell(cmd string) ([]byte, error) { - return exec.Command("powershell", "-Command", cmd).CombinedOutput() + out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput() + return bytes.TrimSpace(out), err } // windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. diff --git a/cmd/cli/os_darwin.go b/cmd/cli/os_darwin.go index 1c61efd..aa39094 100644 --- a/cmd/cli/os_darwin.go +++ b/cmd/cli/os_darwin.go @@ -1,6 +1,8 @@ package cli import ( + "bufio" + "bytes" "net" "os/exec" @@ -44,7 +46,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { // TODO(cuonglm): use system API func resetDNS(iface *net.Interface) error { - if ns := savedNameservers(iface); len(ns) > 0 { + if ns := savedStaticNameservers(iface); len(ns) > 0 { if err := setDNS(iface, ns); err == nil { return nil } @@ -62,3 +64,23 @@ func resetDNS(iface *net.Interface) error { func currentDNS(_ *net.Interface) []string { return resolvconffile.NameServers("") } + +// currentStaticDNS returns the current static DNS settings of given interface. +func currentStaticDNS(iface *net.Interface) []string { + cmd := "networksetup" + args := []string{"-getdnsservers", iface.Name} + out, err := exec.Command(cmd, args...).Output() + if err != nil { + mainLog.Load().Error().Err(err).Msg("could not get current static DNS") + return nil + } + scanner := bufio.NewScanner(bytes.NewReader(out)) + var ns []string + for scanner.Scan() { + line := scanner.Text() + if ip := net.ParseIP(line); ip != nil { + ns = append(ns, ip.String()) + } + } + return ns +} diff --git a/cmd/cli/os_freebsd.go b/cmd/cli/os_freebsd.go index a6d6dde..a8de0c6 100644 --- a/cmd/cli/os_freebsd.go +++ b/cmd/cli/os_freebsd.go @@ -66,3 +66,8 @@ func resetDNS(iface *net.Interface) error { func currentDNS(_ *net.Interface) []string { return resolvconffile.NameServers("") } + +// currentStaticDNS returns the current static DNS settings of given interface. +func currentStaticDNS(iface *net.Interface) []string { + return currentDNS(iface) +} diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index 3036d03..c7661f0 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -203,6 +203,11 @@ func currentDNS(iface *net.Interface) []string { return nil } +// currentStaticDNS returns the current static DNS settings of given interface. +func currentStaticDNS(iface *net.Interface) []string { + return currentDNS(iface) +} + func getDNSByResolvectl(iface string) []string { b, err := exec.Command("resolvectl", "dns", "-i", iface).Output() if err != nil { diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 4bcbc1b..bfb8cba 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -15,7 +15,11 @@ import ( ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) -const forwardersFilename = ".forwarders.txt" +const ( + forwardersFilename = ".forwarders.txt" + v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` +) var ( setDNSOnce sync.Once @@ -47,7 +51,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { } }) primaryDNS := nameservers[0] - if err := setPrimaryDNS(iface, primaryDNS); err != nil { + if err := setPrimaryDNS(iface, primaryDNS, true); err != nil { return err } if len(nameservers) > 1 { @@ -76,25 +80,48 @@ func resetDNS(iface *net.Interface) error { } }) - if ns := savedNameservers(iface); len(ns) > 0 { - if err := setDNS(iface, ns); err == nil { - return nil - } - } + // Restoring ipv6 first. if ctrldnet.SupportsIPv6ListenLocal() { if output, err := netsh("interface", "ipv6", "set", "dnsserver", strconv.Itoa(iface.Index), "dhcp"); err != nil { mainLog.Load().Warn().Err(err).Msgf("failed to reset ipv6 DNS: %s", string(output)) } } + // Restoring ipv4 DHCP. output, err := netsh("interface", "ipv4", "set", "dnsserver", strconv.Itoa(iface.Index), "dhcp") if err != nil { mainLog.Load().Error().Err(err).Msgf("failed to reset ipv4 DNS: %s", string(output)) return err } + // If there's static DNS saved, restoring it. + if nss := savedStaticNameservers(iface); len(nss) > 0 { + v4ns := make([]string, 0, 2) + v6ns := make([]string, 0, 2) + for _, ns := range nss { + if ctrldnet.IsIPv6(ns) { + v6ns = append(v6ns, ns) + } else { + v4ns = append(v4ns, ns) + } + } + + for _, ns := range [][]string{v4ns, v6ns} { + if len(ns) == 0 { + continue + } + primaryDNS := ns[0] + if err := setPrimaryDNS(iface, primaryDNS, false); err != nil { + return err + } + if len(ns) > 1 { + secondaryDNS := ns[1] + _ = addSecondaryDNS(iface, secondaryDNS) + } + } + } return nil } -func setPrimaryDNS(iface *net.Interface, dns string) error { +func setPrimaryDNS(iface *net.Interface, dns string, disablev6 bool) error { ipVer := "ipv4" if ctrldnet.IsIPv6(dns) { ipVer = "ipv6" @@ -105,7 +132,7 @@ func setPrimaryDNS(iface *net.Interface, dns string) error { mainLog.Load().Error().Err(err).Msgf("failed to set primary DNS: %s", string(output)) return err } - if ipVer == "ipv4" && ctrldnet.SupportsIPv6ListenLocal() { + if disablev6 && ipVer == "ipv4" && ctrldnet.SupportsIPv6ListenLocal() { // Disable IPv6 DNS, so the query will be fallback to IPv4. _, _ = netsh("interface", "ipv6", "set", "dnsserver", idx, "static", "::1", "primary") } @@ -147,6 +174,37 @@ func currentDNS(iface *net.Interface) []string { return ns } +// currentStaticDNS returns the current static DNS settings of given interface. +func currentStaticDNS(iface *net.Interface) []string { + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + mainLog.Load().Error().Err(err).Msg("could not get interface LUID") + return nil + } + guid, err := luid.GUID() + if err != nil { + mainLog.Load().Error().Err(err).Msg("could not get interface GUID") + return nil + } + var ns []string + for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} { + interfaceKeyPath := path + guid.String() + found := false + for _, key := range []string{"NameServer", "ProfileNameServer"} { + if found { + continue + } + cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key) + out, err := powershell(cmd) + if err == nil && len(out) > 0 { + found = true + ns = append(ns, strings.Split(string(out), ",")...) + } + } + } + return ns +} + // addDnsServerForwarders adds given nameservers to DNS server forwarders list. func addDnsServerForwarders(nameservers []string) error { for _, ns := range nameservers { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 2f2bd0b..3042248 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -725,32 +725,40 @@ func requiredMultiNICsConfig() bool { } } -// saveCurrentDNS saves the current DNS settings for restoring later. +// saveCurrentStaticDNS saves the current static DNS settings for restoring later. // Only works on Windows and Mac. -func saveCurrentDNS(iface *net.Interface) { +func saveCurrentStaticDNS(iface *net.Interface) error { switch runtime.GOOS { case "windows", "darwin": default: - return + return nil } - ns := currentDNS(iface) + file := savedStaticDnsSettingsFilePath(iface) + if err := os.Remove(file); err != nil && !errors.Is(err, os.ErrNotExist) { + mainLog.Load().Warn().Err(err).Msg("could not remove old static DNS settings file") + } + ns := currentStaticDNS(iface) if len(ns) == 0 { - return + return nil } - file := savedDnsSettingsFilePath(iface) + mainLog.Load().Debug().Msgf("DNS settings for %s is static, saving ...", iface.Name) if err := os.WriteFile(file, []byte(strings.Join(ns, ",")), 0600); err != nil { mainLog.Load().Err(err).Msgf("could not save DNS settings for iface: %s", iface.Name) + return err } + return nil } -// savedDnsSettingsFilePath returns the path to saved DNS settings of the given interface. -func savedDnsSettingsFilePath(iface *net.Interface) string { +// savedStaticDnsSettingsFilePath returns the path to saved DNS settings of the given interface. +func savedStaticDnsSettingsFilePath(iface *net.Interface) string { return absHomeDir(".dns_" + iface.Name) } -// savedNameservers returns the static DNS nameservers of the given interface. -func savedNameservers(iface *net.Interface) []string { - file := savedDnsSettingsFilePath(iface) +// savedStaticNameservers returns the static DNS nameservers of the given interface. +// +//lint:ignore U1000 use in os_windows.go and os_darwin.go +func savedStaticNameservers(iface *net.Interface) []string { + file := savedStaticDnsSettingsFilePath(iface) if data, _ := os.ReadFile(file); len(data) > 0 { return strings.Split(string(data), ",") }