diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index b99c48f..f1439e0 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -435,7 +435,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { p.resetDNS(false, true) // Iterate over all physical interfaces and restore static DNS if a saved static config exists. withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error { - file := savedStaticDnsSettingsFilePath(i) + file := ctrld.SavedStaticDnsSettingsFilePath(i) if _, err := os.Stat(file); err == nil { if err := restoreDNS(i); err != nil { mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) @@ -1077,7 +1077,7 @@ func uninstall(p *prog, s service.Service) { // Iterate over all physical interfaces and restore DNS if a saved static config exists. withEachPhysicalInterfaces(p.runningIface, "restore static DNS", func(i *net.Interface) error { - file := savedStaticDnsSettingsFilePath(i) + file := ctrld.SavedStaticDnsSettingsFilePath(i) if _, err := os.Stat(file); err == nil { if err := restoreDNS(i); err != nil { mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 048212a..18cf00b 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -977,7 +977,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, } // Static DNS settings files. withEachPhysicalInterfaces("", "", func(i *net.Interface) error { - file := savedStaticDnsSettingsFilePath(i) + file := ctrld.SavedStaticDnsSettingsFilePath(i) if _, err := os.Stat(file); err == nil { files = append(files, file) } diff --git a/cmd/cli/os_darwin.go b/cmd/cli/os_darwin.go index 4c358b0..ada1755 100644 --- a/cmd/cli/os_darwin.go +++ b/cmd/cli/os_darwin.go @@ -8,6 +8,7 @@ import ( "os/exec" "strings" + "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) @@ -84,7 +85,7 @@ func resetDNS(iface *net.Interface) error { // restoreDNS restores the DNS settings of the given interface. // this should only be executed upon turning off the ctrld service. func restoreDNS(iface *net.Interface) (err error) { - if ns := savedStaticNameservers(iface); len(ns) > 0 { + if ns := ctrld.SavedStaticNameservers(iface); len(ns) > 0 { err = setDNS(iface, ns) } return err diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 7ebc54a..68c5107 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -16,6 +16,7 @@ import ( "golang.org/x/sys/windows/registry" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "github.com/Control-D-Inc/ctrld" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) @@ -161,7 +162,7 @@ func resetDNS(iface *net.Interface) error { // restoreDNS restores the DNS settings of the given interface. // this should only be executed upon turning off the ctrld service. func restoreDNS(iface *net.Interface) (err error) { - if nss := savedStaticNameservers(iface); len(nss) > 0 { + if nss := ctrld.SavedStaticNameservers(iface); len(nss) > 0 { v4ns := make([]string, 0, 2) v6ns := make([]string, 0, 2) for _, ns := range nss { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index dd8de9f..3b159ee 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -868,7 +868,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { return net.ParseIP(s).IsLoopback() }) // if we have a static config and no saved IPs already, save them - if len(staticDNS) > 0 && len(savedStaticNameservers(iface)) == 0 { + if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(iface)) == 0 { // Save these static DNS values so that they can be restored later. if err := saveCurrentStaticDNS(iface); err != nil { mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name) @@ -898,7 +898,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { return net.ParseIP(s).IsLoopback() }) // if we have a static config and no saved IPs already, save them - if len(staticDNS) > 0 && len(savedStaticNameservers(i)) == 0 { + if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(i)) == 0 { // Save these static DNS values so that they can be restored later. if err := saveCurrentStaticDNS(i); err != nil { mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name) @@ -976,7 +976,7 @@ func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runnin } // Default logic: if there is a saved static DNS configuration, restore it. - saved := savedStaticNameservers(netIface) + saved := ctrld.SavedStaticNameservers(netIface) if len(saved) > 0 && restoreStatic { logger.Debug().Msgf("Restoring interface %q from saved static config: %v", netIface.Name, saved) if err := setDNS(netIface, saved); err != nil { @@ -1373,7 +1373,7 @@ func saveCurrentStaticDNS(iface *net.Interface) error { default: return errSaveCurrentStaticDNSNotSupported } - file := savedStaticDnsSettingsFilePath(iface) + file := ctrld.SavedStaticDnsSettingsFilePath(iface) ns, err := currentStaticDNS(iface) if err != nil { mainLog.Load().Warn().Err(err).Msgf("could not get current static DNS settings for %q", iface.Name) @@ -1407,38 +1407,6 @@ func saveCurrentStaticDNS(iface *net.Interface) error { return nil } -// savedStaticDnsSettingsFilePath returns the path to saved DNS settings of the given interface. -func savedStaticDnsSettingsFilePath(iface *net.Interface) string { - if iface == nil { - return "" - } - return absHomeDir(".dns_" + iface.Name) -} - -// savedStaticNameservers returns the static DNS nameservers of the given interface. -// -//lint:ignore U1000 use in os_windows.go and os_darwin.go -func savedStaticNameservers(iface *net.Interface) []string { - if iface == nil { - mainLog.Load().Debug().Msg("could not get saved static DNS settings for nil interface") - return nil - } - file := savedStaticDnsSettingsFilePath(iface) - if data, _ := os.ReadFile(file); len(data) > 0 { - saveValues := strings.Split(string(data), ",") - returnValues := []string{} - // check each one, if its in loopback range, remove it - for _, v := range saveValues { - if net.ParseIP(v).IsLoopback() { - continue - } - returnValues = append(returnValues, v) - } - return returnValues - } - return nil -} - // dnsChanged reports whether DNS settings for given interface was changed. // It returns false for a nil iface. // diff --git a/nameservers_darwin.go b/nameservers_darwin.go index 1bf4574..c8fa78d 100644 --- a/nameservers_darwin.go +++ b/nameservers_darwin.go @@ -186,7 +186,7 @@ func getAllDHCPNameservers() []string { Log(context.Background(), logger.Debug(), "Failed to patch interface name %s: %v", drIfaceName, err) } - staticNs, file := SavedStaticNameservers(drIface) + staticNs, file := SavedStaticNameserversAndPath(drIface) Log(context.Background(), logger.Debug(), "static dns servers from %s: %v", file, staticNs) if len(staticNs) > 0 { diff --git a/nameservers_windows.go b/nameservers_windows.go index eb4f2b5..4f6ca8e 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -158,7 +158,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { 0, // DomainGuid - not needed 0, // SiteName - not needed uintptr(flags), // Flags - uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output + uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output if ret != 0 { switch ret { @@ -309,7 +309,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { Log(context.Background(), logger.Debug(), "Failed to get interface by name %s: %v", drIfaceName, err) } else { - staticNs, file := SavedStaticNameservers(drIface) + staticNs, file := SavedStaticNameserversAndPath(drIface) Log(context.Background(), logger.Debug(), "static dns servers from %s: %v", file, staticNs) if len(staticNs) > 0 { diff --git a/staticdns.go b/staticdns.go index 1bfd556..ce24fe8 100644 --- a/staticdns.go +++ b/staticdns.go @@ -54,13 +54,18 @@ func userHomeDir() (string, error) { // SavedStaticDnsSettingsFilePath returns the file path where the static DNS settings // for the provided interface are saved. +// +// The caller must ensure iface is non-nil. func SavedStaticDnsSettingsFilePath(iface *net.Interface) string { // The file is stored in the user home directory under a hidden file. return absHomeDir(".dns_" + iface.Name) } -// SavedStaticNameservers returns the stored static nameservers for the given interface. -func SavedStaticNameservers(iface *net.Interface) ([]string, string) { +// SavedStaticNameserversAndPath returns the stored static nameservers for the given interface, +// and the absolute path to file that stored the settings. +// +// The caller must ensure iface is non-nil. +func SavedStaticNameserversAndPath(iface *net.Interface) ([]string, string) { file := SavedStaticDnsSettingsFilePath(iface) data, err := os.ReadFile(file) if err != nil || len(data) == 0 { @@ -77,3 +82,9 @@ func SavedStaticNameservers(iface *net.Interface) ([]string, string) { } return ns, file } + +// SavedStaticNameservers is like SavedStaticNameserversAndPath, but only returns the static nameservers. +func SavedStaticNameservers(iface *net.Interface) []string { + nss, _ := SavedStaticNameserversAndPath(iface) + return nss +}