Files
ctrld/nameservers_windows.go
2025-12-18 17:10:39 +07:00

430 lines
13 KiB
Go

package ctrld
import (
"context"
"fmt"
"io"
"log"
"net"
"os"
"strings"
"syscall"
"time"
"unsafe"
"github.com/microsoft/wmi/pkg/base/host"
"github.com/microsoft/wmi/pkg/base/instance"
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/constant"
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/net/netmon"
"github.com/Control-D-Inc/ctrld/internal/system"
)
const (
maxDNSAdapterRetries = 5
retryDelayDNSAdapter = 1 * time.Second
defaultDNSAdapterTimeout = 10 * time.Second
minDNSServers = 1 // Minimum number of DNS servers we want to find
DS_FORCE_REDISCOVERY = 0x00000001
DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010
DS_BACKGROUND_ONLY = 0x00000100
DS_IP_REQUIRED = 0x00000200
DS_IS_DNS_NAME = 0x00020000
DS_RETURN_DNS_NAME = 0x40000000
)
type DomainControllerInfo struct {
DomainControllerName *uint16
DomainControllerAddress *uint16
DomainControllerAddressType uint32
DomainGuid windows.GUID
DomainName *uint16
DnsForestName *uint16
Flags uint32
DcSiteName *uint16
ClientSiteName *uint16
}
func dnsFns() []dnsFn {
return []dnsFn{dnsFromAdapter}
}
func dnsFromAdapter() []string {
ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout)
defer cancel()
var ns []string
var err error
logger := *ProxyLogger.Load()
for i := 0; i < maxDNSAdapterRetries; i++ {
if ctx.Err() != nil {
Log(context.Background(), logger.Debug(),
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
return nil
}
ns, err = getDNSServers(ctx)
if err == nil && len(ns) >= minDNSServers {
if i > 0 {
Log(context.Background(), logger.Debug(),
"Successfully got DNS servers after %d attempts, found %d servers",
i+1, len(ns))
}
return ns
}
// if osResolver is not initialized, this is likely a command line run
// and ctrld is already on the interface, abort retries
if or == nil {
return ns
}
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get DNS servers, attempt %d: %v", i+1, err)
} else {
Log(context.Background(), logger.Debug(),
"Got insufficient DNS servers, retrying, found %d servers", len(ns))
}
select {
case <-ctx.Done():
return nil
case <-time.After(retryDelayDNSAdapter):
}
}
Log(context.Background(), logger.Debug(),
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries)
return ns
}
func getDNSServers(ctx context.Context) ([]string, error) {
logger := *ProxyLogger.Load()
// Check context before making the call
if ctx.Err() != nil {
return nil, ctx.Err()
}
// Get DNS servers from adapters (existing method)
flags := winipcfg.GAAFlagIncludeGateways |
winipcfg.GAAFlagIncludePrefix
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
if err != nil {
return nil, fmt.Errorf("getting adapters: %w", err)
}
Log(context.Background(), logger.Debug(),
"Found network adapters, count=%d", len(aas))
// Try to get domain controller info if domain-joined
var dcServers []string
isDomain := checkDomainJoined()
if isDomain {
domainName, err := system.GetActiveDirectoryDomain()
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get local AD domain: %v", err)
} else {
// Load netapi32.dll
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
dsDcName := netapi32.NewProc("DsGetDcNameW")
var info *DomainControllerInfo
flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME)
domainUTF16, err := windows.UTF16PtrFromString(domainName)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to convert domain name to UTF16: %v", err)
} else {
Log(context.Background(), logger.Debug(),
"Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
// Call DsGetDcNameW with domain name
ret, _, err := dsDcName.Call(
0, // ComputerName - can be NULL
uintptr(unsafe.Pointer(domainUTF16)), // DomainName
0, // DomainGuid - not needed
0, // SiteName - not needed
uintptr(flags), // Flags
uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output
if ret != 0 {
switch ret {
case 1355: // ERROR_NO_SUCH_DOMAIN
Log(context.Background(), logger.Debug(),
"Domain not found: %s (%d)", domainName, ret)
case 1311: // ERROR_NO_LOGON_SERVERS
Log(context.Background(), logger.Debug(),
"No logon servers available for domain: %s (%d)", domainName, ret)
case 1004: // ERROR_DC_NOT_FOUND
Log(context.Background(), logger.Debug(),
"Domain controller not found for domain: %s (%d)", domainName, ret)
case 1722: // RPC_S_SERVER_UNAVAILABLE
Log(context.Background(), logger.Debug(),
"RPC server unavailable for domain: %s (%d)", domainName, ret)
default:
Log(context.Background(), logger.Debug(),
"Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
}
} else if info != nil {
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
if info.DomainControllerAddress != nil {
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
Log(context.Background(), logger.Debug(),
"Found domain controller address: %s", dcAddr)
if ip := net.ParseIP(dcAddr); ip != nil {
dcServers = append(dcServers, ip.String())
Log(context.Background(), logger.Debug(),
"Added domain controller DNS servers: %v", dcServers)
}
} else {
Log(context.Background(), logger.Debug(),
"No domain controller address found")
}
}
}
}
}
// Continue with existing adapter DNS collection
ns := make([]string, 0, len(aas)*2)
seen := make(map[string]bool)
addressMap := make(map[string]struct{})
// Collect all local IPs
for _, aa := range aas {
if aa.OperStatus != winipcfg.IfOperStatusUp {
Log(context.Background(), logger.Debug(),
"Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
continue
}
// Skip if software loopback or other non-physical types
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
Log(context.Background(), logger.Debug(),
"Skipping %s (software loopback)", aa.FriendlyName())
continue
}
Log(context.Background(), logger.Debug(),
"Processing adapter %s", aa.FriendlyName())
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
ip := a.Address.IP().String()
addressMap[ip] = struct{}{}
Log(context.Background(), logger.Debug(),
"Added local IP %s from adapter %s", ip, aa.FriendlyName())
}
}
validInterfacesMap := validInterfaces()
// Collect DNS servers
for _, aa := range aas {
if aa.OperStatus != winipcfg.IfOperStatusUp {
continue
}
// Skip if software loopback or other non-physical types
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
Log(context.Background(), logger.Debug(),
"Skipping %s (software loopback)", aa.FriendlyName())
continue
}
// if not in the validInterfacesMap, skip
if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok {
Log(context.Background(), logger.Debug(),
"Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
continue
}
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
ip := dns.Address.IP()
if ip == nil {
Log(context.Background(), logger.Debug(),
"Skipping nil IP from adapter %s", aa.FriendlyName())
continue
}
ipStr := ip.String()
l := logger.Debug().
Str("ip", ipStr).
Str("adapter", aa.FriendlyName())
if ip.IsLoopback() {
l.Msg("Skipping loopback IP")
continue
}
if seen[ipStr] {
l.Msg("Skipping duplicate IP")
continue
}
if _, ok := addressMap[ipStr]; ok {
l.Msg("Skipping local interface IP")
continue
}
seen[ipStr] = true
ns = append(ns, ipStr)
l.Msg("Added DNS server")
}
}
// Add DC servers if they're not already in the list
for _, dcServer := range dcServers {
if !seen[dcServer] {
seen[dcServer] = true
ns = append(ns, dcServer)
Log(context.Background(), logger.Debug(),
"Added additional domain controller DNS server: %s", dcServer)
}
}
// if we have static DNS servers saved for the current default route, we should add them to the list
drIfaceName, err := netmon.DefaultRouteInterface()
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get default route interface: %v", err)
} else {
drIface, err := net.InterfaceByName(drIfaceName)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get interface by name %s: %v", drIfaceName, err)
} else {
staticNs, file := SavedStaticNameservers(drIface)
Log(context.Background(), logger.Debug(),
"static dns servers from %s: %v", file, staticNs)
if len(staticNs) > 0 {
Log(context.Background(), logger.Debug(),
"Adding static DNS servers from %s: %v", drIfaceName, staticNs)
ns = append(ns, staticNs...)
}
}
}
if len(ns) == 0 {
return nil, fmt.Errorf("no valid DNS servers found")
}
Log(context.Background(), logger.Debug(),
"DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
len(ns), ns, len(dcServers))
return ns, nil
}
// currentNameserversFromResolvconf returns a nil slice of strings.
func currentNameserversFromResolvconf() []string {
return nil
}
// checkDomainJoined checks if the machine is joined to an Active Directory domain
// Returns whether it's domain joined and the domain name if available
func checkDomainJoined() bool {
logger := *ProxyLogger.Load()
status, err := system.DomainJoinedStatus()
if err != nil {
logger.Debug().Msgf("Failed to get domain joined status: %v", err)
return false
}
isDomain := status == syscall.NetSetupDomainName
logger.Debug().Msg("Domain join status: (UnknownStatus=0, Unjoined=1, WorkgroupName=2, DomainName=3)")
logger.Debug().Msgf("Is domain joined? status=%d, result=%v", status, isDomain)
return isDomain
}
// validInterfaces returns a list of all physical interfaces.
// this is a duplicate of what is in net_windows.go, we should
// clean this up so there is only one version
func validInterfaces() map[string]struct{} {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
//load the logger
logger := *ProxyLogger.Load()
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("MSFT_NetAdapter")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
if instances != nil {
defer instances.Close()
}
if err != nil {
Log(context.Background(), logger.Warn(),
"failed to get wmi network adapter: %v", err)
return nil
}
var adapters []string
for _, i := range instances {
adapter, err := netadapter.NewNetworkAdapter(i)
if err != nil {
Log(context.Background(), logger.Warn(),
"failed to get network adapter: %v", err)
continue
}
name, err := adapter.GetPropertyName()
if err != nil {
Log(context.Background(), logger.Warn(),
"failed to get interface name: %v", err)
continue
}
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
//
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
// if this is a physical adapter or FALSE if this is not a physical adapter."
physical, err := adapter.GetPropertyConnectorPresent()
if err != nil {
Log(context.Background(), logger.Debug(),
"failed to get network adapter connector present property: %v", err)
continue
}
if !physical {
Log(context.Background(), logger.Debug(),
"skipping non-physical adapter: %s", name)
continue
}
// Check if it's a hardware interface. Checking only for connector present is not enough
// because some interfaces are not physical but have a connector.
hardware, err := adapter.GetPropertyHardwareInterface()
if err != nil {
Log(context.Background(), logger.Debug(),
"failed to get network adapter hardware interface property: %v", err)
continue
}
if !hardware {
Log(context.Background(), logger.Debug(),
"skipping non-hardware interface: %s", name)
continue
}
adapters = append(adapters, name)
}
m := make(map[string]struct{})
for _, ifaceName := range adapters {
m[ifaceName] = struct{}{}
}
return m
}