From 891b7cb2c6f572201c05ac59f53bc1cbb029513d Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 7 Feb 2024 13:01:21 +0700 Subject: [PATCH] cmd/cli: integrating with Windows Server DNS feature Windows Server which is running Active Directory will have its own DNS server running. For typical setup, this DNS server will listen on all interfaces, and receiving queries from others to be able to resolve computer name in domain. That would make ctrld default setup never works, since ctrld can listen on port 53, but requests are never be routed to its listeners. To integrate ctrld in this case, we need to listen on a local IP address, then configure this IP as a Forwarder of local DNS server. With this setup, computer name on domain can still be resolved, and other queries can still be resolved by ctrld upstream as usual. --- cmd/cli/cli.go | 53 ++++++++++++++++++++++++++++++++++-- cmd/cli/os_windows.go | 63 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 2 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 0aa22df..9b6af8a 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -304,14 +304,14 @@ func initCLI() { // Report any error if occurred. if err != nil { _, _ = mainLog.Load().Write(marker) - msg := fmt.Sprintf("An error happened when performing test query: %s", err) + msg := fmt.Sprintf("An error occurred while performing test query: %s", err) mainLog.Load().Write([]byte(msg)) } // If ctrld service is running but selfCheckStatus failed, it could be related // to user's system firewall configuration, notice users about it. if status == service.StatusRunning { _, _ = mainLog.Load().Write(marker) - mainLog.Load().Write([]byte(`ctrld service was running, but somehow DNS query could not be sent to its listener`)) + mainLog.Load().Write([]byte(`ctrld service was running, but a DNS query could not be sent to its listener`)) mainLog.Load().Write([]byte(`Please check your system firewall if it is configured to block/intercept/redirect DNS queries`)) } @@ -1731,10 +1731,17 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" nextdnsMode := nextdns != "" + // For Windows server with local Dns server running, we can only try on random local IP. + hasLocalDnsServer := windowsHasLocalDnsServerRunning() for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { listener.IP = "0.0.0.0" + if hasLocalDnsServer { + // Windows Server lies to us that we could listen on 0.0.0.0:53 + // even there's a process already done that, stick to local IP only. + listener.IP = "127.0.0.1" + } lcc[n].IP = true } if listener.Port == 0 { @@ -1743,9 +1750,15 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata } // In cd mode, we always try to pick an ip:port pair to work. // Same if nextdns resolver is used. + // + // Except on Windows Server with local Dns running, + // we could only listen on random local IP port 53. if cdMode || nextdnsMode { lcc[n].IP = true lcc[n].Port = true + if hasLocalDnsServer { + lcc[n].Port = false + } } updated = updated || lcc[n].IP || lcc[n].Port } @@ -1831,6 +1844,11 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata tryAllPort53 := true tryOldIPPort5354 := true tryPort5354 := true + if hasLocalDnsServer { + tryAllPort53 = false + tryOldIPPort5354 = false + tryPort5354 = false + } attempts := 0 maxAttempts := 10 for { @@ -2168,3 +2186,34 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M defer cancel() return c.ExchangeContext(ctx, msg, addr) } + +// powershell runs the given powershell command. +func powershell(cmd string) ([]byte, error) { + return exec.Command("powershell", "-Command", cmd).CombinedOutput() +} + +// windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. +func windowsHasLocalDnsServerRunning() bool { + if runtime.GOOS == "windows" { + out, _ := powershell("Get-WindowsFeature -Name DNS") + if !bytes.Contains(bytes.ToLower(out), []byte("installed")) { + return false + } + + _, err := powershell("Get-Process -Name DNS") + return err == nil + } + return false +} + +// absHomeDir returns the absolute path to given filename using home directory as root dir. +func absHomeDir(filename string) string { + if homedir != "" { + return filepath.Join(homedir, filename) + } + dir, err := userHomeDir() + if err != nil { + return filename + } + return filepath.Join(dir, filename) +} diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index a58411e..694643f 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -2,19 +2,43 @@ package cli import ( "errors" + "fmt" "net" + "os" "os/exec" "strconv" + "strings" + "sync" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) +const forwardersFilename = ".forwarders.txt" + +var ( + setDNSOnce sync.Once + resetDNSOnce sync.Once +) + func setDNS(iface *net.Interface, nameservers []string) error { if len(nameservers) == 0 { return errors.New("empty DNS nameservers") } + setDNSOnce.Do(func() { + // If there's a Dns server running, that means we are on AD with Dns feature enabled. + // Configuring the Dns server to forward queries to ctrld instead. + if windowsHasLocalDnsServerRunning() { + file := absHomeDir(forwardersFilename) + if err := os.WriteFile(file, []byte(strings.Join(nameservers, ",")), 0600); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings") + } + if err := addDnsServerForwarders(nameservers); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings") + } + } + }) primaryDNS := nameservers[0] if err := setPrimaryDNS(iface, primaryDNS); err != nil { return err @@ -28,6 +52,23 @@ func setDNS(iface *net.Interface, nameservers []string) error { // TODO(cuonglm): should we use system API? func resetDNS(iface *net.Interface) error { + resetDNSOnce.Do(func() { + // See corresponding comment in setDNS. + if windowsHasLocalDnsServerRunning() { + file := absHomeDir(forwardersFilename) + content, err := os.ReadFile(file) + if err != nil { + mainLog.Load().Error().Err(err).Msg("could not read forwarders settings") + return + } + nameservers := strings.Split(string(content), ",") + if err := removeDnsServerForwarders(nameservers); err != nil { + mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings") + return + } + } + }) + if ctrldnet.SupportsIPv6ListenLocal() { if output, err := netsh("interface", "ipv6", "set", "dnsserver", strconv.Itoa(iface.Index), "dhcp"); err != nil { mainLog.Load().Warn().Err(err).Msgf("failed to reset ipv6 DNS: %s", string(output)) @@ -93,3 +134,25 @@ func currentDNS(iface *net.Interface) []string { } return ns } + +// addDnsServerForwarders adds given nameservers to DNS server forwarders list. +func addDnsServerForwarders(nameservers []string) error { + for _, ns := range nameservers { + cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", ns) + if out, err := powershell(cmd); err != nil { + return fmt.Errorf("%w: %s", err, string(out)) + } + } + return nil +} + +// removeDnsServerForwarders removes given nameservers from DNS server forwarders list. +func removeDnsServerForwarders(nameservers []string) error { + for _, ns := range nameservers { + cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns) + if out, err := powershell(cmd); err != nil { + return fmt.Errorf("%w: %s", err, string(out)) + } + } + return nil +}