mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
smol tweaks to nameserver test queries; fix restoreDNS errors; add some debugging information; fix wrong type in log msg; set send logs command timeout to 5 mins; when the runningIface is no longer up, attempt to find a new interface; prefer default route, ignore non physical interfaces; prefer default route, ignore non physical interfaces; add max context timeout on performLeakingQuery with more debug logs
265 lines
7.9 KiB
Go
265 lines
7.9 KiB
Go
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"os/exec"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
|
|
"golang.org/x/sys/windows"
|
|
"golang.org/x/sys/windows/registry"
|
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
|
|
|
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
|
)
|
|
|
|
const (
	// v4InterfaceKeyPathFormat is the registry path prefix of the per-interface
	// IPv4 TCP/IP parameters. Despite the name it is not a format string: the
	// interface GUID is appended to it to build the full key path.
	v4InterfaceKeyPathFormat = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces\\"

	// v6InterfaceKeyPathFormat is the IPv6 counterpart of
	// v4InterfaceKeyPathFormat, used the same way.
	v6InterfaceKeyPathFormat = "SYSTEM\\CurrentControlSet\\Services\\Tcpip6\\Parameters\\Interfaces\\"
)
|
|
|
|
var (
	// setDNSOnce guards the one-time DNS server forwarders setup done by setDNS.
	setDNSOnce sync.Once

	// resetDNSOnce guards the one-time DNS server forwarders cleanup done by resetDNS.
	resetDNSOnce sync.Once
)
|
|
|
|
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
|
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
|
return setDNS(iface, nameservers)
|
|
}
|
|
|
|
// setDNS sets the dns server for the provided network interface
|
|
func setDNS(iface *net.Interface, nameservers []string) error {
|
|
if len(nameservers) == 0 {
|
|
return errors.New("empty DNS nameservers")
|
|
}
|
|
setDNSOnce.Do(func() {
|
|
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
|
// Configuring the Dns server to forward queries to ctrld instead.
|
|
if hasLocalDnsServerRunning() {
|
|
file := absHomeDir(windowsForwardersFilename)
|
|
oldForwardersContent, _ := os.ReadFile(file)
|
|
hasLocalIPv6Listener := needLocalIPv6Listener()
|
|
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
|
|
if !hasLocalIPv6Listener {
|
|
return false
|
|
}
|
|
return s == "::1"
|
|
})
|
|
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
|
}
|
|
oldForwarders := strings.Split(string(oldForwardersContent), ",")
|
|
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
|
}
|
|
}
|
|
})
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
return fmt.Errorf("setDNS: %w", err)
|
|
}
|
|
var (
|
|
serversV4 []netip.Addr
|
|
serversV6 []netip.Addr
|
|
)
|
|
for _, ns := range nameservers {
|
|
if addr, err := netip.ParseAddr(ns); err == nil {
|
|
if addr.Is4() {
|
|
serversV4 = append(serversV4, addr)
|
|
} else {
|
|
serversV6 = append(serversV6, addr)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(serversV4) == 0 && len(serversV6) == 0 {
|
|
return errors.New("invalid DNS nameservers")
|
|
}
|
|
if len(serversV4) > 0 {
|
|
if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil {
|
|
return fmt.Errorf("could not set DNS ipv4: %w", err)
|
|
}
|
|
}
|
|
if len(serversV6) > 0 {
|
|
if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil {
|
|
return fmt.Errorf("could not set DNS ipv6: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
|
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
|
return resetDNS(iface)
|
|
}
|
|
|
|
// TODO(cuonglm): should we use system API?
|
|
func resetDNS(iface *net.Interface) error {
|
|
resetDNSOnce.Do(func() {
|
|
// See corresponding comment in setDNS.
|
|
if hasLocalDnsServerRunning() {
|
|
file := absHomeDir(windowsForwardersFilename)
|
|
content, err := os.ReadFile(file)
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
|
|
return
|
|
}
|
|
nameservers := strings.Split(string(content), ",")
|
|
if err := removeDnsServerForwarders(nameservers); err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings")
|
|
return
|
|
}
|
|
}
|
|
})
|
|
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
return fmt.Errorf("resetDNS: %w", err)
|
|
}
|
|
// Restoring DHCP settings.
|
|
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
|
return fmt.Errorf("could not reset DNS ipv4: %w", err)
|
|
}
|
|
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
|
return fmt.Errorf("could not reset DNS ipv6: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// restoreDNS restores the DNS settings of the given interface.
|
|
// this should only be executed upon turning off the ctrld service.
|
|
func restoreDNS(iface *net.Interface) (err error) {
|
|
if nss := savedStaticNameservers(iface); len(nss) > 0 {
|
|
v4ns := make([]string, 0, 2)
|
|
v6ns := make([]string, 0, 2)
|
|
for _, ns := range nss {
|
|
if ctrldnet.IsIPv6(ns) {
|
|
v6ns = append(v6ns, ns)
|
|
} else {
|
|
v4ns = append(v4ns, ns)
|
|
}
|
|
}
|
|
|
|
for _, ns := range [][]string{v4ns, v6ns} {
|
|
if len(ns) == 0 {
|
|
continue
|
|
}
|
|
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
|
|
err = setDNS(iface, ns)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func currentDNS(iface *net.Interface) []string {
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("failed to get interface LUID")
|
|
return nil
|
|
}
|
|
nameservers, err := luid.DNS()
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("failed to get interface DNS")
|
|
return nil
|
|
}
|
|
ns := make([]string, 0, len(nameservers))
|
|
for _, nameserver := range nameservers {
|
|
ns = append(ns, nameserver.String())
|
|
}
|
|
return ns
|
|
}
|
|
|
|
// currentStaticDNS returns the current static DNS settings of given interface.
|
|
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err)
|
|
}
|
|
guid, err := luid.GUID()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("luid.GUID: %w", err)
|
|
}
|
|
var ns []string
|
|
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
|
|
found := false
|
|
interfaceKeyPath := path + guid.String()
|
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err)
|
|
}
|
|
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
|
if found {
|
|
continue
|
|
}
|
|
value, _, err := k.GetStringValue(key)
|
|
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
|
return nil, fmt.Errorf("%s: %w", key, err)
|
|
}
|
|
if len(value) > 0 {
|
|
found = true
|
|
for _, e := range strings.Split(value, ",") {
|
|
ns = append(ns, strings.TrimRight(e, "\x00"))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return ns, nil
|
|
}
|
|
|
|
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
|
|
// and also removing old forwarders if provided.
|
|
func addDnsServerForwarders(nameservers, old []string) error {
|
|
newForwardersMap := make(map[string]struct{})
|
|
newForwarders := make([]string, len(nameservers))
|
|
for i := range nameservers {
|
|
newForwardersMap[nameservers[i]] = struct{}{}
|
|
newForwarders[i] = fmt.Sprintf("%q", nameservers[i])
|
|
}
|
|
oldForwarders := old[:0]
|
|
for _, fwd := range old {
|
|
if _, ok := newForwardersMap[fwd]; !ok {
|
|
oldForwarders = append(oldForwarders, fwd)
|
|
}
|
|
}
|
|
// NOTE: It is important to add new forwarder before removing old one.
|
|
// Testing on Windows Server 2022 shows that removing forwarder1
|
|
// then adding forwarder2 sometimes ends up adding both of them
|
|
// to the forwarders list.
|
|
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ","))
|
|
if len(oldForwarders) > 0 {
|
|
cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ","))
|
|
}
|
|
if out, err := powershell(cmd); err != nil {
|
|
return fmt.Errorf("%w: %s", err, string(out))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// removeDnsServerForwarders removes given nameservers from DNS server forwarders list.
|
|
func removeDnsServerForwarders(nameservers []string) error {
|
|
for _, ns := range nameservers {
|
|
cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns)
|
|
if out, err := powershell(cmd); err != nil {
|
|
return fmt.Errorf("%w: %s", err, string(out))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// powershell executes cmd through the "powershell" executable using its
// -Command argument. It returns the combined standard output and standard
// error with surrounding whitespace trimmed, together with the execution error.
func powershell(cmd string) ([]byte, error) {
	ps := exec.Command("powershell", "-Command", cmd)
	output, err := ps.CombinedOutput()
	output = bytes.TrimSpace(output)
	return output, err
}
|