mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
postRun should not restore static settings put back validInterface check better debug logs for os resolver init, use mutex to prevent duplicate initializations use WMI instead of registry keys for static DNS data on Windows use WMI instead of registry keys for static DNS data on Windows use winipcfg DNS method use WMI with registry fallback go back to registry method restore saved static configs on stop and uninstall restore ipv6 DHCP if no saved static ipv6 addresses do not save loopback IPs for static configs handle watchdog interface changed for new interfaces dont overwrite static file on start when staticdns is set to loopback dont overwrite static file on start when staticdns is set to loopback dont overwrite static file on start when staticdns is set to loopback no need to resetDNS on start, uninstall already takes care of this
304 lines
9.4 KiB
Go
304 lines
9.4 KiB
Go
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"os/exec"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
|
|
"golang.org/x/sys/windows"
|
|
"golang.org/x/sys/windows/registry"
|
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
|
|
|
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
|
)
|
|
|
|
// v4InterfaceKeyPathFormat is the HKEY_LOCAL_MACHINE path prefix of the
// per-interface IPv4 (Tcpip) parameters. Despite the name it is a plain
// prefix, not a format string: the interface GUID is appended to it.
const v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`

// v6InterfaceKeyPathFormat is the IPv6 (Tcpip6) counterpart of
// v4InterfaceKeyPathFormat; the interface GUID is appended to it as well.
const v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
|
|
|
// setDNSOnce makes the local DNS server forwarder setup in setDNS run at most
// once per process.
var setDNSOnce sync.Once

// resetDNSOnce makes the local DNS server forwarder cleanup in resetDNS run at
// most once per process.
var resetDNSOnce sync.Once
|
|
|
|
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
|
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
|
return setDNS(iface, nameservers)
|
|
}
|
|
|
|
// setDNS sets the dns server for the provided network interface
|
|
func setDNS(iface *net.Interface, nameservers []string) error {
|
|
if len(nameservers) == 0 {
|
|
return errors.New("empty DNS nameservers")
|
|
}
|
|
setDNSOnce.Do(func() {
|
|
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
|
// Configuring the Dns server to forward queries to ctrld instead.
|
|
if hasLocalDnsServerRunning() {
|
|
file := absHomeDir(windowsForwardersFilename)
|
|
oldForwardersContent, _ := os.ReadFile(file)
|
|
hasLocalIPv6Listener := needLocalIPv6Listener()
|
|
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
|
|
if !hasLocalIPv6Listener {
|
|
return false
|
|
}
|
|
return s == "::1"
|
|
})
|
|
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
|
}
|
|
oldForwarders := strings.Split(string(oldForwardersContent), ",")
|
|
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
|
}
|
|
}
|
|
})
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
return fmt.Errorf("setDNS: %w", err)
|
|
}
|
|
var (
|
|
serversV4 []netip.Addr
|
|
serversV6 []netip.Addr
|
|
)
|
|
for _, ns := range nameservers {
|
|
if addr, err := netip.ParseAddr(ns); err == nil {
|
|
if addr.Is4() {
|
|
serversV4 = append(serversV4, addr)
|
|
} else {
|
|
serversV6 = append(serversV6, addr)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(serversV4) == 0 && len(serversV6) == 0 {
|
|
return errors.New("invalid DNS nameservers")
|
|
}
|
|
if len(serversV4) > 0 {
|
|
if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil {
|
|
return fmt.Errorf("could not set DNS ipv4: %w", err)
|
|
}
|
|
}
|
|
if len(serversV6) > 0 {
|
|
if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil {
|
|
return fmt.Errorf("could not set DNS ipv6: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
|
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
|
return resetDNS(iface)
|
|
}
|
|
|
|
// TODO(cuonglm): should we use system API?
|
|
func resetDNS(iface *net.Interface) error {
|
|
resetDNSOnce.Do(func() {
|
|
// See corresponding comment in setDNS.
|
|
if hasLocalDnsServerRunning() {
|
|
file := absHomeDir(windowsForwardersFilename)
|
|
content, err := os.ReadFile(file)
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
|
|
return
|
|
}
|
|
nameservers := strings.Split(string(content), ",")
|
|
if err := removeDnsServerForwarders(nameservers); err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings")
|
|
return
|
|
}
|
|
}
|
|
})
|
|
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
return fmt.Errorf("resetDNS: %w", err)
|
|
}
|
|
// Restoring DHCP settings.
|
|
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
|
return fmt.Errorf("could not reset DNS ipv4: %w", err)
|
|
}
|
|
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
|
return fmt.Errorf("could not reset DNS ipv6: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// restoreDNS restores the DNS settings of the given interface.
|
|
// this should only be executed upon turning off the ctrld service.
|
|
func restoreDNS(iface *net.Interface) (err error) {
|
|
if nss := savedStaticNameservers(iface); len(nss) > 0 {
|
|
v4ns := make([]string, 0, 2)
|
|
v6ns := make([]string, 0, 2)
|
|
for _, ns := range nss {
|
|
if ctrldnet.IsIPv6(ns) {
|
|
v6ns = append(v6ns, ns)
|
|
} else {
|
|
v4ns = append(v4ns, ns)
|
|
}
|
|
}
|
|
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
return fmt.Errorf("restoreDNS: %w", err)
|
|
}
|
|
|
|
if len(v4ns) > 0 {
|
|
mainLog.Load().Debug().Msgf("restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
|
|
if err := setDNS(iface, v4ns); err != nil {
|
|
return fmt.Errorf("restoreDNS (IPv4): %w", err)
|
|
}
|
|
} else {
|
|
mainLog.Load().Debug().Msgf("restoring IPv4 DHCP for interface %q", iface.Name)
|
|
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
|
return fmt.Errorf("restoreDNS (IPv4 clear): %w", err)
|
|
}
|
|
}
|
|
|
|
if len(v6ns) > 0 {
|
|
mainLog.Load().Debug().Msgf("restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
|
|
if err := setDNS(iface, v6ns); err != nil {
|
|
return fmt.Errorf("restoreDNS (IPv6): %w", err)
|
|
}
|
|
} else {
|
|
mainLog.Load().Debug().Msgf("restoring IPv6 DHCP for interface %q", iface.Name)
|
|
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
|
return fmt.Errorf("restoreDNS (IPv6 clear): %w", err)
|
|
}
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func currentDNS(iface *net.Interface) []string {
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("failed to get interface LUID")
|
|
return nil
|
|
}
|
|
nameservers, err := luid.DNS()
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("failed to get interface DNS")
|
|
return nil
|
|
}
|
|
ns := make([]string, 0, len(nameservers))
|
|
for _, nameserver := range nameservers {
|
|
ns = append(ns, nameserver.String())
|
|
}
|
|
return ns
|
|
}
|
|
|
|
// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys
|
|
// like "NameServer" and "ProfileNameServer".
|
|
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err)
|
|
}
|
|
guid, err := luid.GUID()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("fallback luid.GUID: %w", err)
|
|
}
|
|
|
|
var ns []string
|
|
keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat}
|
|
for _, path := range keyPaths {
|
|
interfaceKeyPath := path + guid.String()
|
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
|
if err != nil {
|
|
mainLog.Load().Debug().Err(err).Msgf("failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
|
|
continue
|
|
}
|
|
func() {
|
|
defer k.Close()
|
|
for _, keyName := range []string{"NameServer", "ProfileNameServer"} {
|
|
value, _, err := k.GetStringValue(keyName)
|
|
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
|
mainLog.Load().Debug().Err(err).Msgf("error reading %s registry key", keyName)
|
|
continue
|
|
}
|
|
if len(value) > 0 {
|
|
mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value)
|
|
parsed := parseDNSServers(value)
|
|
ns = append(ns, parsed...)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
if len(ns) == 0 {
|
|
mainLog.Load().Debug().Msgf("no static DNS values found for interface %q", iface.Name)
|
|
}
|
|
return ns, nil
|
|
}
|
|
|
|
// parseDNSServers splits a DNS server list read from the registry into
// individual entries.
//
// Entries may be separated by commas or whitespace (space, tab, CR, LF). NUL
// characters are treated as separators too, so stray string terminators never
// end up inside an entry. Each entry is trimmed of any remaining surrounding
// whitespace and empty entries are dropped; a value with no entries yields nil.
func parseDNSServers(val string) []string {
	fields := strings.FieldsFunc(val, func(r rune) bool {
		switch r {
		case ',', ' ', '\t', '\r', '\n', '\x00':
			return true
		}
		return false
	})
	var servers []string
	for _, f := range fields {
		if trimmed := strings.TrimSpace(f); trimmed != "" {
			servers = append(servers, trimmed)
		}
	}
	return servers
}
|
|
|
|
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
|
|
// and also removing old forwarders if provided.
|
|
func addDnsServerForwarders(nameservers, old []string) error {
|
|
newForwardersMap := make(map[string]struct{})
|
|
newForwarders := make([]string, len(nameservers))
|
|
for i := range nameservers {
|
|
newForwardersMap[nameservers[i]] = struct{}{}
|
|
newForwarders[i] = fmt.Sprintf("%q", nameservers[i])
|
|
}
|
|
oldForwarders := old[:0]
|
|
for _, fwd := range old {
|
|
if _, ok := newForwardersMap[fwd]; !ok {
|
|
oldForwarders = append(oldForwarders, fwd)
|
|
}
|
|
}
|
|
// NOTE: It is important to add new forwarder before removing old one.
|
|
// Testing on Windows Server 2022 shows that removing forwarder1
|
|
// then adding forwarder2 sometimes ends up adding both of them
|
|
// to the forwarders list.
|
|
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ","))
|
|
if len(oldForwarders) > 0 {
|
|
cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ","))
|
|
}
|
|
if out, err := powershell(cmd); err != nil {
|
|
return fmt.Errorf("%w: %s", err, string(out))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// removeDnsServerForwarders removes given nameservers from DNS server forwarders list.
|
|
func removeDnsServerForwarders(nameservers []string) error {
|
|
for _, ns := range nameservers {
|
|
cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns)
|
|
if out, err := powershell(cmd); err != nil {
|
|
return fmt.Errorf("%w: %s", err, string(out))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// powershell runs cmd with PowerShell. It returns the combined stdout/stderr
// output with surrounding whitespace trimmed, together with the error from
// running the process. The error is non-nil when, for example, PowerShell
// exits with a non-zero status or cannot be started.
//
// -NoProfile keeps profile scripts from changing behavior or adding output.
// -NonInteractive makes a cmdlet that would prompt fail instead of blocking
// on input that a service can never provide.
func powershell(cmd string) ([]byte, error) {
	out, err := exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", cmd).CombinedOutput()
	return bytes.TrimSpace(out), err
}
|