all: unify code to handle static DNS file path

This commit is contained in:
Cuong Manh Le
2025-05-05 17:36:02 +07:00
committed by Cuong Manh Le
parent 51e58b64a5
commit 31517ce750
8 changed files with 27 additions and 46 deletions

View File

@@ -435,7 +435,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
p.resetDNS(false, true)
// Iterate over all physical interfaces and restore static DNS if a saved static config exists.
withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error {
file := savedStaticDnsSettingsFilePath(i)
file := ctrld.SavedStaticDnsSettingsFilePath(i)
if _, err := os.Stat(file); err == nil {
if err := restoreDNS(i); err != nil {
mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name)
@@ -1077,7 +1077,7 @@ func uninstall(p *prog, s service.Service) {
// Iterate over all physical interfaces and restore DNS if a saved static config exists.
withEachPhysicalInterfaces(p.runningIface, "restore static DNS", func(i *net.Interface) error {
file := savedStaticDnsSettingsFilePath(i)
file := ctrld.SavedStaticDnsSettingsFilePath(i)
if _, err := os.Stat(file); err == nil {
if err := restoreDNS(i); err != nil {
mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name)

View File

@@ -977,7 +977,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
}
// Static DNS settings files.
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
file := savedStaticDnsSettingsFilePath(i)
file := ctrld.SavedStaticDnsSettingsFilePath(i)
if _, err := os.Stat(file); err == nil {
files = append(files, file)
}

View File

@@ -8,6 +8,7 @@ import (
"os/exec"
"strings"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
@@ -84,7 +85,7 @@ func resetDNS(iface *net.Interface) error {
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
if ns := savedStaticNameservers(iface); len(ns) > 0 {
if ns := ctrld.SavedStaticNameservers(iface); len(ns) > 0 {
err = setDNS(iface, ns)
}
return err

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/sys/windows/registry"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/Control-D-Inc/ctrld"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
@@ -161,7 +162,7 @@ func resetDNS(iface *net.Interface) error {
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
if nss := savedStaticNameservers(iface); len(nss) > 0 {
if nss := ctrld.SavedStaticNameservers(iface); len(nss) > 0 {
v4ns := make([]string, 0, 2)
v6ns := make([]string, 0, 2)
for _, ns := range nss {

View File

@@ -868,7 +868,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
return net.ParseIP(s).IsLoopback()
})
// if we have a static config and no saved IPs already, save them
if len(staticDNS) > 0 && len(savedStaticNameservers(iface)) == 0 {
if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(iface)) == 0 {
// Save these static DNS values so that they can be restored later.
if err := saveCurrentStaticDNS(iface); err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name)
@@ -898,7 +898,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
return net.ParseIP(s).IsLoopback()
})
// if we have a static config and no saved IPs already, save them
if len(staticDNS) > 0 && len(savedStaticNameservers(i)) == 0 {
if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(i)) == 0 {
// Save these static DNS values so that they can be restored later.
if err := saveCurrentStaticDNS(i); err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name)
@@ -976,7 +976,7 @@ func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runnin
}
// Default logic: if there is a saved static DNS configuration, restore it.
saved := savedStaticNameservers(netIface)
saved := ctrld.SavedStaticNameservers(netIface)
if len(saved) > 0 && restoreStatic {
logger.Debug().Msgf("Restoring interface %q from saved static config: %v", netIface.Name, saved)
if err := setDNS(netIface, saved); err != nil {
@@ -1373,7 +1373,7 @@ func saveCurrentStaticDNS(iface *net.Interface) error {
default:
return errSaveCurrentStaticDNSNotSupported
}
file := savedStaticDnsSettingsFilePath(iface)
file := ctrld.SavedStaticDnsSettingsFilePath(iface)
ns, err := currentStaticDNS(iface)
if err != nil {
mainLog.Load().Warn().Err(err).Msgf("could not get current static DNS settings for %q", iface.Name)
@@ -1407,38 +1407,6 @@ func saveCurrentStaticDNS(iface *net.Interface) error {
return nil
}
// savedStaticDnsSettingsFilePath returns the path to saved DNS settings of the given interface.
func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
if iface == nil {
return ""
}
return absHomeDir(".dns_" + iface.Name)
}
// savedStaticNameservers returns the static DNS nameservers of the given interface.
//
//lint:ignore U1000 use in os_windows.go and os_darwin.go
func savedStaticNameservers(iface *net.Interface) []string {
if iface == nil {
mainLog.Load().Debug().Msg("could not get saved static DNS settings for nil interface")
return nil
}
file := savedStaticDnsSettingsFilePath(iface)
if data, _ := os.ReadFile(file); len(data) > 0 {
saveValues := strings.Split(string(data), ",")
returnValues := []string{}
// check each one, if its in loopback range, remove it
for _, v := range saveValues {
if net.ParseIP(v).IsLoopback() {
continue
}
returnValues = append(returnValues, v)
}
return returnValues
}
return nil
}
// dnsChanged reports whether DNS settings for given interface was changed.
// It returns false for a nil iface.
//

View File

@@ -186,7 +186,7 @@ func getAllDHCPNameservers() []string {
Log(context.Background(), logger.Debug(),
"Failed to patch interface name %s: %v", drIfaceName, err)
}
staticNs, file := SavedStaticNameservers(drIface)
staticNs, file := SavedStaticNameserversAndPath(drIface)
Log(context.Background(), logger.Debug(),
"static dns servers from %s: %v", file, staticNs)
if len(staticNs) > 0 {

View File

@@ -158,7 +158,7 @@ func getDNSServers(ctx context.Context) ([]string, error) {
0, // DomainGuid - not needed
0, // SiteName - not needed
uintptr(flags), // Flags
uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output
uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output
if ret != 0 {
switch ret {
@@ -309,7 +309,7 @@ func getDNSServers(ctx context.Context) ([]string, error) {
Log(context.Background(), logger.Debug(),
"Failed to get interface by name %s: %v", drIfaceName, err)
} else {
staticNs, file := SavedStaticNameservers(drIface)
staticNs, file := SavedStaticNameserversAndPath(drIface)
Log(context.Background(), logger.Debug(),
"static dns servers from %s: %v", file, staticNs)
if len(staticNs) > 0 {

View File

@@ -54,13 +54,18 @@ func userHomeDir() (string, error) {
// SavedStaticDnsSettingsFilePath returns the file path where the static DNS settings
// for the provided interface are saved.
//
// The caller must ensure iface is non-nil.
func SavedStaticDnsSettingsFilePath(iface *net.Interface) string {
// The file is stored in the user home directory under a hidden file.
return absHomeDir(".dns_" + iface.Name)
}
// SavedStaticNameservers returns the stored static nameservers for the given interface.
func SavedStaticNameservers(iface *net.Interface) ([]string, string) {
// SavedStaticNameserversAndPath returns the stored static nameservers for the given interface,
// and the absolute path to file that stored the settings.
//
// The caller must ensure iface is non-nil.
func SavedStaticNameserversAndPath(iface *net.Interface) ([]string, string) {
file := SavedStaticDnsSettingsFilePath(iface)
data, err := os.ReadFile(file)
if err != nil || len(data) == 0 {
@@ -77,3 +82,9 @@ func SavedStaticNameservers(iface *net.Interface) ([]string, string) {
}
return ns, file
}
// SavedStaticNameservers is like SavedStaticNameserversAndPath, but only returns the static nameservers.
func SavedStaticNameservers(iface *net.Interface) []string {
nss, _ := SavedStaticNameserversAndPath(iface)
return nss
}