mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-25 23:30:41 +01:00
fix(windows): improve DNS server discovery for domain-joined machines
Add DNS suffix matching for non-physical adapters when domain-joined. This allows interfaces with matching DNS suffix to be considered valid even if not in validInterfacesMap, improving DNS server discovery for remote VPN scenarios. While at it, also replacing context.Background() with proper ctx parameter throughout the function for consistent context propagation.
This commit is contained in:
committed by
Cuong Manh Le
parent
d0341497d1
commit
1804e6db67
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/net/netmon"
|
||||
@@ -128,13 +129,15 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
|
||||
// Try to get domain controller info if domain-joined
|
||||
var dcServers []string
|
||||
var adDomain string
|
||||
isDomain := checkDomainJoined()
|
||||
if isDomain {
|
||||
domainName, err := system.GetActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
Log(ctx, logger.Debug(),
|
||||
"Failed to get local AD domain: %v", err)
|
||||
} else {
|
||||
adDomain = domainName
|
||||
// Load netapi32.dll
|
||||
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
|
||||
dsDcName := netapi32.NewProc("DsGetDcNameW")
|
||||
@@ -144,10 +147,9 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
|
||||
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to convert domain name to UTF16: %v", err)
|
||||
Log(ctx, logger.Debug(), "Failed to convert domain name to UTF16: %v", err)
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
Log(ctx, logger.Debug(),
|
||||
"Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
|
||||
|
||||
// Call DsGetDcNameW with domain name
|
||||
@@ -162,19 +164,19 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
if ret != 0 {
|
||||
switch ret {
|
||||
case 1355: // ERROR_NO_SUCH_DOMAIN
|
||||
Log(context.Background(), logger.Debug(),
|
||||
Log(ctx, logger.Debug(),
|
||||
"Domain not found: %s (%d)", domainName, ret)
|
||||
case 1311: // ERROR_NO_LOGON_SERVERS
|
||||
Log(context.Background(), logger.Debug(),
|
||||
Log(ctx, logger.Debug(),
|
||||
"No logon servers available for domain: %s (%d)", domainName, ret)
|
||||
case 1004: // ERROR_DC_NOT_FOUND
|
||||
Log(context.Background(), logger.Debug(),
|
||||
Log(ctx, logger.Debug(),
|
||||
"Domain controller not found for domain: %s (%d)", domainName, ret)
|
||||
case 1722: // RPC_S_SERVER_UNAVAILABLE
|
||||
Log(context.Background(), logger.Debug(),
|
||||
Log(ctx, logger.Debug(),
|
||||
"RPC server unavailable for domain: %s (%d)", domainName, ret)
|
||||
default:
|
||||
Log(context.Background(), logger.Debug(),
|
||||
Log(ctx, logger.Debug(),
|
||||
"Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
|
||||
}
|
||||
} else if info != nil {
|
||||
@@ -183,17 +185,16 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
if info.DomainControllerAddress != nil {
|
||||
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
|
||||
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
|
||||
Log(context.Background(), logger.Debug(),
|
||||
Log(ctx, logger.Debug(),
|
||||
"Found domain controller address: %s", dcAddr)
|
||||
|
||||
if ip := net.ParseIP(dcAddr); ip != nil {
|
||||
dcServers = append(dcServers, ip.String())
|
||||
Log(context.Background(), logger.Debug(),
|
||||
Log(ctx, logger.Debug(),
|
||||
"Added domain controller DNS servers: %v", dcServers)
|
||||
}
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"No domain controller address found")
|
||||
Log(ctx, logger.Debug(), "No domain controller address found")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -208,7 +209,7 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
// Collect all local IPs
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
Log(ctx, logger.Debug(),
|
||||
"Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
|
||||
continue
|
||||
}
|
||||
@@ -216,24 +217,25 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (software loopback)", aa.FriendlyName())
|
||||
Log(ctx, logger.Debug(), "Skipping %s (software loopback)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Processing adapter %s", aa.FriendlyName())
|
||||
Log(ctx, logger.Debug(), "Processing adapter %s", aa.FriendlyName())
|
||||
|
||||
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
|
||||
ip := a.Address.IP().String()
|
||||
addressMap[ip] = struct{}{}
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added local IP %s from adapter %s", ip, aa.FriendlyName())
|
||||
Log(ctx, logger.Debug(), "Added local IP %s from adapter %s", ip, aa.FriendlyName())
|
||||
}
|
||||
}
|
||||
|
||||
validInterfacesMap := validInterfaces()
|
||||
|
||||
if isDomain && adDomain == "" {
|
||||
Log(ctx, logger.Warn(), "The machine is joined domain, but domain name is empty")
|
||||
}
|
||||
checkDnsSuffix := isDomain && adDomain != ""
|
||||
// Collect DNS servers
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
@@ -243,23 +245,33 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (software loopback)", aa.FriendlyName())
|
||||
Log(ctx, logger.Debug(), "Skipping %s (software loopback)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
// if not in the validInterfacesMap, skip
|
||||
if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
|
||||
_, valid := validInterfacesMap[aa.FriendlyName()]
|
||||
if !valid && checkDnsSuffix {
|
||||
for suffix := aa.FirstDNSSuffix; suffix != nil; suffix = suffix.Next {
|
||||
// For non-physical adapters but have the DNS suffix that matches the domain name,
|
||||
// (or vice versa) consider it valid. This can happen when remote VPN machines.
|
||||
ds := strings.TrimSpace(suffix.String())
|
||||
if dns.IsSubDomain(adDomain, ds) || dns.IsSubDomain(ds, adDomain) {
|
||||
Log(ctx, logger.Debug(), "Found valid interface %s with DNS suffix %s", aa.FriendlyName(), suffix.String())
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// if not a valid interface, skip it
|
||||
if !valid {
|
||||
Log(ctx, logger.Debug(), "Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
||||
ip := dns.Address.IP()
|
||||
if ip == nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping nil IP from adapter %s", aa.FriendlyName())
|
||||
Log(ctx, logger.Debug(), "Skipping nil IP from adapter %s", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -292,28 +304,23 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
if !seen[dcServer] {
|
||||
seen[dcServer] = true
|
||||
ns = append(ns, dcServer)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added additional domain controller DNS server: %s", dcServer)
|
||||
Log(ctx, logger.Debug(), "Added additional domain controller DNS server: %s", dcServer)
|
||||
}
|
||||
}
|
||||
|
||||
// if we have static DNS servers saved for the current default route, we should add them to the list
|
||||
drIfaceName, err := netmon.DefaultRouteInterface()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get default route interface: %v", err)
|
||||
Log(ctx, logger.Debug(), "Failed to get default route interface: %v", err)
|
||||
} else {
|
||||
drIface, err := net.InterfaceByName(drIfaceName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get interface by name %s: %v", drIfaceName, err)
|
||||
Log(ctx, logger.Debug(), "Failed to get interface by name %s: %v", drIfaceName, err)
|
||||
} else {
|
||||
staticNs, file := SavedStaticNameservers(drIface)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"static dns servers from %s: %v", file, staticNs)
|
||||
Log(ctx, logger.Debug(), "static dns servers from %s: %v", file, staticNs)
|
||||
if len(staticNs) > 0 {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Adding static DNS servers from %s: %v", drIfaceName, staticNs)
|
||||
Log(ctx, logger.Debug(), "Adding static DNS servers from %s: %v", drIfaceName, staticNs)
|
||||
ns = append(ns, staticNs...)
|
||||
}
|
||||
}
|
||||
@@ -323,8 +330,7 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
return nil, fmt.Errorf("no valid DNS servers found")
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
|
||||
Log(ctx, logger.Debug(), "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
|
||||
len(ns), ns, len(dcServers))
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user