mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
This lets it run regardless of ctrld's current status. It prevents a race in which the reset-DNS task restores the system's DNS settings, only for the currently running ctrld process to revert them immediately.
219 lines
6.7 KiB
Go
219 lines
6.7 KiB
Go
package cli
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
|
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
|
|
|
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
|
)
|
|
|
|
// Registry key prefixes (in PowerShell's HKLM: drive syntax) for the
// per-interface TCP/IP settings of IPv4 (Tcpip) and IPv6 (Tcpip6).
//
// Despite the "Format" suffix these are plain prefixes, not fmt format
// strings: the interface GUID is concatenated onto them to form the full key.
const (
	v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
	v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
)
|
|
|
|
// Process-wide guards that make the local DNS server forwarder handling in
// setDNS and resetDNS run at most once per process, however many times those
// functions are called.
var (
	setDNSOnce   sync.Once // guards forwarder setup in setDNS
	resetDNSOnce sync.Once // guards forwarder removal in resetDNS
)
|
|
|
|
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
|
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
|
return setDNS(iface, nameservers)
|
|
}
|
|
|
|
// setDnsPowershellCmd builds the PowerShell command that sets the DNS server
// addresses of iface (selected by its index) to nameservers, in order.
//
// Each address is rendered with strconv.Quote and the quoted addresses are
// comma-joined inside the parenthesized -ServerAddresses list.
func setDnsPowershellCmd(iface *net.Interface, nameservers []string) string {
	quoted := make([]string, len(nameservers))
	for i := range nameservers {
		quoted[i] = strconv.Quote(nameservers[i])
	}
	serverList := strings.Join(quoted, ",")
	return fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ServerAddresses (%s)", iface.Index, serverList)
}
|
|
|
|
// setDNS sets the dns server for the provided network interface
|
|
func setDNS(iface *net.Interface, nameservers []string) error {
|
|
if len(nameservers) == 0 {
|
|
return errors.New("empty DNS nameservers")
|
|
}
|
|
setDNSOnce.Do(func() {
|
|
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
|
// Configuring the Dns server to forward queries to ctrld instead.
|
|
if windowsHasLocalDnsServerRunning() {
|
|
file := absHomeDir(windowsForwardersFilename)
|
|
oldForwardersContent, _ := os.ReadFile(file)
|
|
hasLocalIPv6Listener := needLocalIPv6Listener()
|
|
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
|
|
if !hasLocalIPv6Listener {
|
|
return false
|
|
}
|
|
return s == "::1"
|
|
})
|
|
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
|
}
|
|
oldForwarders := strings.Split(string(oldForwardersContent), ",")
|
|
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
|
}
|
|
}
|
|
})
|
|
out, err := powershell(setDnsPowershellCmd(iface, nameservers))
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %s", err, string(out))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
|
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
|
return resetDNS(iface)
|
|
}
|
|
|
|
// TODO(cuonglm): should we use system API?
|
|
func resetDNS(iface *net.Interface) error {
|
|
resetDNSOnce.Do(func() {
|
|
// See corresponding comment in setDNS.
|
|
if windowsHasLocalDnsServerRunning() {
|
|
file := absHomeDir(windowsForwardersFilename)
|
|
content, err := os.ReadFile(file)
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
|
|
return
|
|
}
|
|
nameservers := strings.Split(string(content), ",")
|
|
if err := removeDnsServerForwarders(nameservers); err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings")
|
|
return
|
|
}
|
|
}
|
|
})
|
|
|
|
// Restoring DHCP settings.
|
|
cmd := fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ResetServerAddresses", iface.Index)
|
|
out, err := powershell(cmd)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %s", err, string(out))
|
|
}
|
|
|
|
// If there's static DNS saved, restoring it.
|
|
if nss := savedStaticNameservers(iface); len(nss) > 0 {
|
|
v4ns := make([]string, 0, 2)
|
|
v6ns := make([]string, 0, 2)
|
|
for _, ns := range nss {
|
|
if ctrldnet.IsIPv6(ns) {
|
|
v6ns = append(v6ns, ns)
|
|
} else {
|
|
v4ns = append(v4ns, ns)
|
|
}
|
|
}
|
|
|
|
for _, ns := range [][]string{v4ns, v6ns} {
|
|
if len(ns) == 0 {
|
|
continue
|
|
}
|
|
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
|
|
if err := setDNS(iface, ns); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func currentDNS(iface *net.Interface) []string {
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("failed to get interface LUID")
|
|
return nil
|
|
}
|
|
nameservers, err := luid.DNS()
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("failed to get interface DNS")
|
|
return nil
|
|
}
|
|
ns := make([]string, 0, len(nameservers))
|
|
for _, nameserver := range nameservers {
|
|
ns = append(ns, nameserver.String())
|
|
}
|
|
return ns
|
|
}
|
|
|
|
// currentStaticDNS returns the current static DNS settings of given interface.
|
|
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
guid, err := luid.GUID()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var ns []string
|
|
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
|
|
interfaceKeyPath := path + guid.String()
|
|
found := false
|
|
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
|
if found {
|
|
continue
|
|
}
|
|
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
|
out, err := powershell(cmd)
|
|
if err == nil && len(out) > 0 {
|
|
found = true
|
|
for _, e := range strings.Split(string(out), ",") {
|
|
ns = append(ns, strings.TrimRight(e, "\x00"))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return ns, nil
|
|
}
|
|
|
|
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
|
|
// and also removing old forwarders if provided.
|
|
func addDnsServerForwarders(nameservers, old []string) error {
|
|
newForwardersMap := make(map[string]struct{})
|
|
newForwarders := make([]string, len(nameservers))
|
|
for i := range nameservers {
|
|
newForwardersMap[nameservers[i]] = struct{}{}
|
|
newForwarders[i] = fmt.Sprintf("%q", nameservers[i])
|
|
}
|
|
oldForwarders := old[:0]
|
|
for _, fwd := range old {
|
|
if _, ok := newForwardersMap[fwd]; !ok {
|
|
oldForwarders = append(oldForwarders, fwd)
|
|
}
|
|
}
|
|
// NOTE: It is important to add new forwarder before removing old one.
|
|
// Testing on Windows Server 2022 shows that removing forwarder1
|
|
// then adding forwarder2 sometimes ends up adding both of them
|
|
// to the forwarders list.
|
|
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ","))
|
|
if len(oldForwarders) > 0 {
|
|
cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ","))
|
|
}
|
|
if out, err := powershell(cmd); err != nil {
|
|
return fmt.Errorf("%w: %s", err, string(out))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// removeDnsServerForwarders removes given nameservers from DNS server forwarders list.
|
|
func removeDnsServerForwarders(nameservers []string) error {
|
|
for _, ns := range nameservers {
|
|
cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns)
|
|
if out, err := powershell(cmd); err != nil {
|
|
return fmt.Errorf("%w: %s", err, string(out))
|
|
}
|
|
}
|
|
return nil
|
|
}
|