From 5641aab5bd96f5bcd146b284fd52128cb17bdc90 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 5 May 2025 23:28:49 +0700 Subject: [PATCH] all: unify handling user home directory logic --- cmd/cli/cli.go | 31 +------------------------------ cmd/cli/commands.go | 2 +- cmd/cli/os_windows.go | 4 ++-- staticdns.go | 16 ++++++---------- 4 files changed, 10 insertions(+), 43 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index f1439e0..3caa3bb 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -966,28 +966,11 @@ func userHomeDir() (string, error) { if dir != "" { return dir, nil } - // viper will expand for us. - if runtime.GOOS == "windows" { - // If we're on windows, use the install path for this. - exePath, err := os.Executable() - if err != nil { - return "", err - } - - return filepath.Dir(exePath), nil - } // Mobile platform should provide a rw dir path for this. if isMobile() { return homedir, nil } - dir = "/etc/controld" - if err := os.MkdirAll(dir, 0750); err != nil { - return os.UserHomeDir() // fallback to user home directory - } - if ok, _ := dirWritable(dir); !ok { - return os.UserHomeDir() - } - return dir, nil + return ctrld.UserHomeDir() } // socketDir returns directory that ctrld will create socket file for running controlServer. @@ -1754,18 +1737,6 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M return c.ExchangeContext(ctx, msg, addr) } -// absHomeDir returns the absolute path to given filename using home directory as root dir. -func absHomeDir(filename string) string { - if homedir != "" { - return filepath.Join(homedir, filename) - } - dir, err := userHomeDir() - if err != nil { - return filename - } - return filepath.Join(dir, filename) -} - // runInCdMode reports whether ctrld service is running in cd mode. func runInCdMode() bool { return curCdUID() != "" diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 18cf00b..d610463 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -985,7 +985,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, }) // Windows forwarders file. if hasLocalDnsServerRunning() { - files = append(files, absHomeDir(windowsForwardersFilename)) + files = append(files, ctrld.AbsHomeDir(windowsForwardersFilename)) } // Binary itself. bin, _ := os.Executable() diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 68c5107..c0cd787 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -46,7 +46,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { if hasLocalDnsServerRunning() { mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders") - file := absHomeDir(windowsForwardersFilename) + file := ctrld.AbsHomeDir(windowsForwardersFilename) mainLog.Load().Debug().Msgf("Using forwarders file: %s", file) oldForwardersContent, err := os.ReadFile(file) @@ -131,7 +131,7 @@ func resetDNS(iface *net.Interface) error { resetDNSOnce.Do(func() { // See corresponding comment in setDNS. if hasLocalDnsServerRunning() { - file := absHomeDir(windowsForwardersFilename) + file := ctrld.AbsHomeDir(windowsForwardersFilename) content, err := os.ReadFile(file) if err != nil { mainLog.Load().Error().Err(err).Msg("could not read forwarders settings") diff --git a/staticdns.go b/staticdns.go index ce24fe8..b1de8ec 100644 --- a/staticdns.go +++ b/staticdns.go @@ -8,14 +8,9 @@ import ( "strings" ) -var homedir string - -// absHomeDir returns the absolute path to given filename using home directory as root dir. -func absHomeDir(filename string) string { - if homedir != "" { - return filepath.Join(homedir, filename) - } - dir, err := userHomeDir() +// AbsHomeDir returns the absolute path to given filename using home directory as root dir. +func AbsHomeDir(filename string) string { + dir, err := UserHomeDir() if err != nil { return filename } @@ -31,7 +26,8 @@ func dirWritable(dir string) (bool, error) { return true, f.Close() } -func userHomeDir() (string, error) { +// UserHomeDir returns the home directory for user who is running ctrld. +func UserHomeDir() (string, error) { // viper will expand for us. if runtime.GOOS == "windows" { // If we're on windows, use the install path for this. @@ -58,7 +54,7 @@ func userHomeDir() (string, error) { // The caller must ensure iface is non-nil. func SavedStaticDnsSettingsFilePath(iface *net.Interface) string { // The file is stored in the user home directory under a hidden file. - return absHomeDir(".dns_" + iface.Name) + return AbsHomeDir(".dns_" + iface.Name) } // SavedStaticNameserversAndPath returns the stored static nameservers for the given interface,