From 71e327653a65f239f124e75d9f3b94374a6350b4 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 21 Nov 2024 16:52:59 +0700 Subject: [PATCH] cmd/cli: check local DNS using Windows API --- cmd/cli/cli.go | 19 ++----------------- cmd/cli/os_windows.go | 12 ++++++++++-- cmd/cli/service_others.go | 3 +++ cmd/cli/service_windows.go | 22 ++++++++++++++++++++++ cmd/cli/service_windows_test.go | 25 +++++++++++++++++++++++++ 5 files changed, 62 insertions(+), 19 deletions(-) create mode 100644 cmd/cli/service_windows_test.go diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 502014e..b0ae022 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -709,7 +709,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, return nil }) // Windows forwarders file. - if windowsHasLocalDnsServerRunning() { + if hasLocalDnsServerRunning() { files = append(files, absHomeDir(windowsForwardersFilename)) } // Binary itself. @@ -2107,7 +2107,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata cdMode := cdUID != "" nextdnsMode := nextdns != "" // For Windows server with local Dns server running, we can only try on random local IP. - hasLocalDnsServer := windowsHasLocalDnsServerRunning() + hasLocalDnsServer := hasLocalDnsServerRunning() for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { @@ -2614,21 +2614,6 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M return c.ExchangeContext(ctx, msg, addr) } -// powershell runs the given powershell command. -func powershell(cmd string) ([]byte, error) { - out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput() - return bytes.TrimSpace(out), err -} - -// windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. -func windowsHasLocalDnsServerRunning() bool { - if runtime.GOOS == "windows" { - _, err := powershell("Get-Process -Name DNS") - return err == nil - } - return false -} - // absHomeDir returns the absolute path to given filename using home directory as root dir. func absHomeDir(filename string) string { if homedir != "" { diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 1a22b0f..aa44418 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -1,11 +1,13 @@ package cli import ( + "bytes" "errors" "fmt" "net" "net/netip" "os" + "os/exec" "slices" "strings" "sync" @@ -39,7 +41,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { setDNSOnce.Do(func() { // If there's a Dns server running, that means we are on AD with Dns feature enabled. // Configuring the Dns server to forward queries to ctrld instead. - if windowsHasLocalDnsServerRunning() { + if hasLocalDnsServerRunning() { file := absHomeDir(windowsForwardersFilename) oldForwardersContent, _ := os.ReadFile(file) hasLocalIPv6Listener := needLocalIPv6Listener() @@ -101,7 +103,7 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { func resetDNS(iface *net.Interface) error { resetDNSOnce.Do(func() { // See corresponding comment in setDNS. - if windowsHasLocalDnsServerRunning() { + if hasLocalDnsServerRunning() { file := absHomeDir(windowsForwardersFilename) content, err := os.ReadFile(file) if err != nil { @@ -241,3 +243,9 @@ func removeDnsServerForwarders(nameservers []string) error { } return nil } + +// powershell runs the given powershell command. +func powershell(cmd string) ([]byte, error) { + out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput() + return bytes.TrimSpace(out), err +} diff --git a/cmd/cli/service_others.go b/cmd/cli/service_others.go index f4d73e5..2303e30 100644 --- a/cmd/cli/service_others.go +++ b/cmd/cli/service_others.go @@ -13,3 +13,6 @@ func hasElevatedPrivilege() (bool, error) { func openLogFile(path string, flags int) (*os.File, error) { return os.OpenFile(path, flags, os.FileMode(0o600)) } + +// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. +func hasLocalDnsServerRunning() bool { return false } diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index d4e2449..af4f317 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -2,7 +2,9 @@ package cli import ( "os" + "strings" "syscall" + "unsafe" "golang.org/x/sys/windows" ) @@ -79,3 +81,23 @@ func openLogFile(path string, mode int) (*os.File, error) { return os.NewFile(uintptr(handle), path), nil } + +const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{})) + +// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. +func hasLocalDnsServerRunning() bool { + h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) + if e != nil { + return false + } + p := windows.ProcessEntry32{Size: processEntrySize} + for { + e := windows.Process32Next(h, &p) + if e != nil { + return false + } + if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" { + return true + } + } +} diff --git a/cmd/cli/service_windows_test.go b/cmd/cli/service_windows_test.go new file mode 100644 index 0000000..67c2725 --- /dev/null +++ b/cmd/cli/service_windows_test.go @@ -0,0 +1,25 @@ +package cli + +import ( + "testing" + "time" +) + +func Test_hasLocalDnsServerRunning(t *testing.T) { + start := time.Now() + hasDns := hasLocalDnsServerRunning() + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + hasDnsPowershell := hasLocalDnsServerRunningPowershell() + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + if hasDns != hasDnsPowershell { + t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns) + } +} + +func hasLocalDnsServerRunningPowershell() bool { + _, err := powershell("Get-Process -Name DNS") + return err == nil +}