mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
cmd/cli: check local DNS using Windows API
This commit is contained in:
committed by
Cuong Manh Le
parent
a56711796f
commit
71e327653a
@@ -709,7 +709,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
return nil
|
||||
})
|
||||
// Windows forwarders file.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
if hasLocalDnsServerRunning() {
|
||||
files = append(files, absHomeDir(windowsForwardersFilename))
|
||||
}
|
||||
// Binary itself.
|
||||
@@ -2107,7 +2107,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
|
||||
cdMode := cdUID != ""
|
||||
nextdnsMode := nextdns != ""
|
||||
// For Windows server with local Dns server running, we can only try on random local IP.
|
||||
hasLocalDnsServer := windowsHasLocalDnsServerRunning()
|
||||
hasLocalDnsServer := hasLocalDnsServerRunning()
|
||||
for n, listener := range cfg.Listener {
|
||||
lcc[n] = &listenerConfigCheck{}
|
||||
if listener.IP == "" {
|
||||
@@ -2614,21 +2614,6 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M
|
||||
return c.ExchangeContext(ctx, msg, addr)
|
||||
}
|
||||
|
||||
// powershell runs the given powershell command.
|
||||
func powershell(cmd string) ([]byte, error) {
|
||||
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
|
||||
return bytes.TrimSpace(out), err
|
||||
}
|
||||
|
||||
// windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func windowsHasLocalDnsServerRunning() bool {
|
||||
if runtime.GOOS == "windows" {
|
||||
_, err := powershell("Get-Process -Name DNS")
|
||||
return err == nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// absHomeDir returns the absolute path to given filename using home directory as root dir.
|
||||
func absHomeDir(filename string) string {
|
||||
if homedir != "" {
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -39,7 +41,7 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
setDNSOnce.Do(func() {
|
||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||
// Configuring the Dns server to forward queries to ctrld instead.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
if hasLocalDnsServerRunning() {
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
oldForwardersContent, _ := os.ReadFile(file)
|
||||
hasLocalIPv6Listener := needLocalIPv6Listener()
|
||||
@@ -101,7 +103,7 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
resetDNSOnce.Do(func() {
|
||||
// See corresponding comment in setDNS.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
if hasLocalDnsServerRunning() {
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
@@ -241,3 +243,9 @@ func removeDnsServerForwarders(nameservers []string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// powershell runs the given powershell command.
|
||||
func powershell(cmd string) ([]byte, error) {
|
||||
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
|
||||
return bytes.TrimSpace(out), err
|
||||
}
|
||||
|
||||
@@ -13,3 +13,6 @@ func hasElevatedPrivilege() (bool, error) {
|
||||
func openLogFile(path string, flags int) (*os.File, error) {
|
||||
return os.OpenFile(path, flags, os.FileMode(0o600))
|
||||
}
|
||||
|
||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func hasLocalDnsServerRunning() bool { return false }
|
||||
|
||||
@@ -2,7 +2,9 @@ package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
@@ -79,3 +81,23 @@ func openLogFile(path string, mode int) (*os.File, error) {
|
||||
|
||||
return os.NewFile(uintptr(handle), path), nil
|
||||
}
|
||||
|
||||
const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{}))
|
||||
|
||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func hasLocalDnsServerRunning() bool {
|
||||
h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
p := windows.ProcessEntry32{Size: processEntrySize}
|
||||
for {
|
||||
e := windows.Process32Next(h, &p)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
25
cmd/cli/service_windows_test.go
Normal file
25
cmd/cli/service_windows_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_hasLocalDnsServerRunning(t *testing.T) {
|
||||
start := time.Now()
|
||||
hasDns := hasLocalDnsServerRunning()
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
hasDnsPowershell := hasLocalDnsServerRunningPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if hasDns != hasDnsPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns)
|
||||
}
|
||||
}
|
||||
|
||||
func hasLocalDnsServerRunningPowershell() bool {
|
||||
_, err := powershell("Get-Process -Name DNS")
|
||||
return err == nil
|
||||
}
|
||||
Reference in New Issue
Block a user