diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 0aa22df..9b6af8a 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -304,14 +304,14 @@ func initCLI() { // Report any error if occurred. if err != nil { _, _ = mainLog.Load().Write(marker) - msg := fmt.Sprintf("An error happened when performing test query: %s", err) + msg := fmt.Sprintf("An error occurred while performing test query: %s", err) mainLog.Load().Write([]byte(msg)) } // If ctrld service is running but selfCheckStatus failed, it could be related // to user's system firewall configuration, notice users about it. if status == service.StatusRunning { _, _ = mainLog.Load().Write(marker) - mainLog.Load().Write([]byte(`ctrld service was running, but somehow DNS query could not be sent to its listener`)) + mainLog.Load().Write([]byte(`ctrld service was running, but a DNS query could not be sent to its listener`)) mainLog.Load().Write([]byte(`Please check your system firewall if it is configured to block/intercept/redirect DNS queries`)) } @@ -1731,10 +1731,17 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" nextdnsMode := nextdns != "" + // For Windows server with local Dns server running, we can only try on random local IP. + hasLocalDnsServer := windowsHasLocalDnsServerRunning() for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { listener.IP = "0.0.0.0" + if hasLocalDnsServer { + // Windows Server lies to us that we could listen on 0.0.0.0:53 + // even there's a process already done that, stick to local IP only. + listener.IP = "127.0.0.1" + } lcc[n].IP = true } if listener.Port == 0 { @@ -1743,9 +1750,15 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata } // In cd mode, we always try to pick an ip:port pair to work. // Same if nextdns resolver is used. + // + // Except on Windows Server with local Dns running, + // we could only listen on random local IP port 53. if cdMode || nextdnsMode { lcc[n].IP = true lcc[n].Port = true + if hasLocalDnsServer { + lcc[n].Port = false + } } updated = updated || lcc[n].IP || lcc[n].Port } @@ -1831,6 +1844,11 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata tryAllPort53 := true tryOldIPPort5354 := true tryPort5354 := true + if hasLocalDnsServer { + tryAllPort53 = false + tryOldIPPort5354 = false + tryPort5354 = false + } attempts := 0 maxAttempts := 10 for { @@ -2168,3 +2186,34 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M defer cancel() return c.ExchangeContext(ctx, msg, addr) } + +// powershell runs the given powershell command. +func powershell(cmd string) ([]byte, error) { + return exec.Command("powershell", "-Command", cmd).CombinedOutput() +} + +// windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. +func windowsHasLocalDnsServerRunning() bool { + if runtime.GOOS == "windows" { + out, _ := powershell("Get-WindowsFeature -Name DNS") + if !bytes.Contains(bytes.ToLower(out), []byte("installed")) { + return false + } + + _, err := powershell("Get-Process -Name DNS") + return err == nil + } + return false +} + +// absHomeDir returns the absolute path to given filename using home directory as root dir. +func absHomeDir(filename string) string { + if homedir != "" { + return filepath.Join(homedir, filename) + } + dir, err := userHomeDir() + if err != nil { + return filename + } + return filepath.Join(dir, filename) +} diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index a58411e..694643f 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -2,19 +2,43 @@ package cli import ( "errors" + "fmt" "net" + "os" "os/exec" "strconv" + "strings" + "sync" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) +const forwardersFilename = ".forwarders.txt" + +var ( + setDNSOnce sync.Once + resetDNSOnce sync.Once +) + func setDNS(iface *net.Interface, nameservers []string) error { if len(nameservers) == 0 { return errors.New("empty DNS nameservers") } + setDNSOnce.Do(func() { + // If there's a Dns server running, that means we are on AD with Dns feature enabled. + // Configuring the Dns server to forward queries to ctrld instead. + if windowsHasLocalDnsServerRunning() { + file := absHomeDir(forwardersFilename) + if err := os.WriteFile(file, []byte(strings.Join(nameservers, ",")), 0600); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings") + } + if err := addDnsServerForwarders(nameservers); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings") + } + } + }) primaryDNS := nameservers[0] if err := setPrimaryDNS(iface, primaryDNS); err != nil { return err @@ -28,6 +52,23 @@ func setDNS(iface *net.Interface, nameservers []string) error { // TODO(cuonglm): should we use system API? func resetDNS(iface *net.Interface) error { + resetDNSOnce.Do(func() { + // See corresponding comment in setDNS. + if windowsHasLocalDnsServerRunning() { + file := absHomeDir(forwardersFilename) + content, err := os.ReadFile(file) + if err != nil { + mainLog.Load().Error().Err(err).Msg("could not read forwarders settings") + return + } + nameservers := strings.Split(string(content), ",") + if err := removeDnsServerForwarders(nameservers); err != nil { + mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings") + return + } + } + }) + if ctrldnet.SupportsIPv6ListenLocal() { if output, err := netsh("interface", "ipv6", "set", "dnsserver", strconv.Itoa(iface.Index), "dhcp"); err != nil { mainLog.Load().Warn().Err(err).Msgf("failed to reset ipv6 DNS: %s", string(output)) @@ -93,3 +134,25 @@ func currentDNS(iface *net.Interface) []string { } return ns } + +// addDnsServerForwarders adds given nameservers to DNS server forwarders list. +func addDnsServerForwarders(nameservers []string) error { + for _, ns := range nameservers { + cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", ns) + if out, err := powershell(cmd); err != nil { + return fmt.Errorf("%w: %s", err, string(out)) + } + } + return nil +} + +// removeDnsServerForwarders removes given nameservers from DNS server forwarders list. +func removeDnsServerForwarders(nameservers []string) error { + for _, ns := range nameservers { + cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns) + if out, err := powershell(cmd); err != nil { + return fmt.Errorf("%w: %s", err, string(out)) + } + } + return nil +}