cmd/cli: check local DNS using Windows API

This commit is contained in:
Cuong Manh Le
2024-11-21 16:52:59 +07:00
committed by Cuong Manh Le
parent a56711796f
commit 71e327653a
5 changed files with 62 additions and 19 deletions

View File

@@ -709,7 +709,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
return nil
})
// Windows forwarders file.
if windowsHasLocalDnsServerRunning() {
if hasLocalDnsServerRunning() {
files = append(files, absHomeDir(windowsForwardersFilename))
}
// Binary itself.
@@ -2107,7 +2107,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
cdMode := cdUID != ""
nextdnsMode := nextdns != ""
// For Windows server with local Dns server running, we can only try on random local IP.
hasLocalDnsServer := windowsHasLocalDnsServerRunning()
hasLocalDnsServer := hasLocalDnsServerRunning()
for n, listener := range cfg.Listener {
lcc[n] = &listenerConfigCheck{}
if listener.IP == "" {
@@ -2614,21 +2614,6 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M
return c.ExchangeContext(ctx, msg, addr)
}
// powershell runs the given powershell command.
func powershell(cmd string) ([]byte, error) {
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
return bytes.TrimSpace(out), err
}
// windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
func windowsHasLocalDnsServerRunning() bool {
if runtime.GOOS == "windows" {
_, err := powershell("Get-Process -Name DNS")
return err == nil
}
return false
}
// absHomeDir returns the absolute path to given filename using home directory as root dir.
func absHomeDir(filename string) string {
if homedir != "" {

View File

@@ -1,11 +1,13 @@
package cli
import (
"bytes"
"errors"
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"slices"
"strings"
"sync"
@@ -39,7 +41,7 @@ func setDNS(iface *net.Interface, nameservers []string) error {
setDNSOnce.Do(func() {
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
// Configuring the Dns server to forward queries to ctrld instead.
if windowsHasLocalDnsServerRunning() {
if hasLocalDnsServerRunning() {
file := absHomeDir(windowsForwardersFilename)
oldForwardersContent, _ := os.ReadFile(file)
hasLocalIPv6Listener := needLocalIPv6Listener()
@@ -101,7 +103,7 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
func resetDNS(iface *net.Interface) error {
resetDNSOnce.Do(func() {
// See corresponding comment in setDNS.
if windowsHasLocalDnsServerRunning() {
if hasLocalDnsServerRunning() {
file := absHomeDir(windowsForwardersFilename)
content, err := os.ReadFile(file)
if err != nil {
@@ -241,3 +243,9 @@ func removeDnsServerForwarders(nameservers []string) error {
}
return nil
}
// powershell runs the given powershell command.
func powershell(cmd string) ([]byte, error) {
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
return bytes.TrimSpace(out), err
}

View File

@@ -13,3 +13,6 @@ func hasElevatedPrivilege() (bool, error) {
func openLogFile(path string, flags int) (*os.File, error) {
return os.OpenFile(path, flags, os.FileMode(0o600))
}
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
func hasLocalDnsServerRunning() bool { return false }

View File

@@ -2,7 +2,9 @@ package cli
import (
"os"
"strings"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
@@ -79,3 +81,23 @@ func openLogFile(path string, mode int) (*os.File, error) {
return os.NewFile(uintptr(handle), path), nil
}
const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{}))
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
func hasLocalDnsServerRunning() bool {
h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
if e != nil {
return false
}
p := windows.ProcessEntry32{Size: processEntrySize}
for {
e := windows.Process32Next(h, &p)
if e != nil {
return false
}
if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" {
return true
}
}
}

View File

@@ -0,0 +1,25 @@
package cli
import (
"testing"
"time"
)
func Test_hasLocalDnsServerRunning(t *testing.T) {
start := time.Now()
hasDns := hasLocalDnsServerRunning()
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
start = time.Now()
hasDnsPowershell := hasLocalDnsServerRunningPowershell()
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
if hasDns != hasDnsPowershell {
t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns)
}
}
func hasLocalDnsServerRunningPowershell() bool {
_, err := powershell("Get-Process -Name DNS")
return err == nil
}