fix(windows): improve DNS server discovery for domain-joined machines

Add DNS suffix matching for non-physical adapters when domain-joined.
This allows interfaces with matching DNS suffix to be considered valid
even if not in validInterfacesMap, improving DNS server discovery for
remote VPN scenarios.

While at it, also replacing context.Background() with proper ctx
parameter throughout the function for consistent context propagation.
This commit is contained in:
Cuong Manh Le
2026-01-14 17:17:55 +07:00
committed by Cuong Manh Le
parent d0341497d1
commit 1804e6db67

View File

@@ -17,6 +17,7 @@ import (
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/constant"
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
"github.com/miekg/dns"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/net/netmon"
@@ -128,13 +129,15 @@ func getDNSServers(ctx context.Context) ([]string, error) {
// Try to get domain controller info if domain-joined
var dcServers []string
var adDomain string
isDomain := checkDomainJoined()
if isDomain {
domainName, err := system.GetActiveDirectoryDomain()
if err != nil {
Log(context.Background(), logger.Debug(),
Log(ctx, logger.Debug(),
"Failed to get local AD domain: %v", err)
} else {
adDomain = domainName
// Load netapi32.dll
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
dsDcName := netapi32.NewProc("DsGetDcNameW")
@@ -144,10 +147,9 @@ func getDNSServers(ctx context.Context) ([]string, error) {
domainUTF16, err := windows.UTF16PtrFromString(domainName)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to convert domain name to UTF16: %v", err)
Log(ctx, logger.Debug(), "Failed to convert domain name to UTF16: %v", err)
} else {
Log(context.Background(), logger.Debug(),
Log(ctx, logger.Debug(),
"Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
// Call DsGetDcNameW with domain name
@@ -162,19 +164,19 @@ func getDNSServers(ctx context.Context) ([]string, error) {
if ret != 0 {
switch ret {
case 1355: // ERROR_NO_SUCH_DOMAIN
Log(context.Background(), logger.Debug(),
Log(ctx, logger.Debug(),
"Domain not found: %s (%d)", domainName, ret)
case 1311: // ERROR_NO_LOGON_SERVERS
Log(context.Background(), logger.Debug(),
Log(ctx, logger.Debug(),
"No logon servers available for domain: %s (%d)", domainName, ret)
case 1004: // ERROR_DC_NOT_FOUND
Log(context.Background(), logger.Debug(),
Log(ctx, logger.Debug(),
"Domain controller not found for domain: %s (%d)", domainName, ret)
case 1722: // RPC_S_SERVER_UNAVAILABLE
Log(context.Background(), logger.Debug(),
Log(ctx, logger.Debug(),
"RPC server unavailable for domain: %s (%d)", domainName, ret)
default:
Log(context.Background(), logger.Debug(),
Log(ctx, logger.Debug(),
"Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
}
} else if info != nil {
@@ -183,17 +185,16 @@ func getDNSServers(ctx context.Context) ([]string, error) {
if info.DomainControllerAddress != nil {
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
Log(context.Background(), logger.Debug(),
Log(ctx, logger.Debug(),
"Found domain controller address: %s", dcAddr)
if ip := net.ParseIP(dcAddr); ip != nil {
dcServers = append(dcServers, ip.String())
Log(context.Background(), logger.Debug(),
Log(ctx, logger.Debug(),
"Added domain controller DNS servers: %v", dcServers)
}
} else {
Log(context.Background(), logger.Debug(),
"No domain controller address found")
Log(ctx, logger.Debug(), "No domain controller address found")
}
}
}
@@ -208,7 +209,7 @@ func getDNSServers(ctx context.Context) ([]string, error) {
// Collect all local IPs
for _, aa := range aas {
if aa.OperStatus != winipcfg.IfOperStatusUp {
Log(context.Background(), logger.Debug(),
Log(ctx, logger.Debug(),
"Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
continue
}
@@ -216,24 +217,25 @@ func getDNSServers(ctx context.Context) ([]string, error) {
// Skip if software loopback or other non-physical types
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
Log(context.Background(), logger.Debug(),
"Skipping %s (software loopback)", aa.FriendlyName())
Log(ctx, logger.Debug(), "Skipping %s (software loopback)", aa.FriendlyName())
continue
}
Log(context.Background(), logger.Debug(),
"Processing adapter %s", aa.FriendlyName())
Log(ctx, logger.Debug(), "Processing adapter %s", aa.FriendlyName())
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
ip := a.Address.IP().String()
addressMap[ip] = struct{}{}
Log(context.Background(), logger.Debug(),
"Added local IP %s from adapter %s", ip, aa.FriendlyName())
Log(ctx, logger.Debug(), "Added local IP %s from adapter %s", ip, aa.FriendlyName())
}
}
validInterfacesMap := validInterfaces()
if isDomain && adDomain == "" {
Log(ctx, logger.Warn(), "The machine is joined domain, but domain name is empty")
}
checkDnsSuffix := isDomain && adDomain != ""
// Collect DNS servers
for _, aa := range aas {
if aa.OperStatus != winipcfg.IfOperStatusUp {
@@ -243,23 +245,33 @@ func getDNSServers(ctx context.Context) ([]string, error) {
// Skip if software loopback or other non-physical types
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
Log(context.Background(), logger.Debug(),
"Skipping %s (software loopback)", aa.FriendlyName())
Log(ctx, logger.Debug(), "Skipping %s (software loopback)", aa.FriendlyName())
continue
}
// if not in the validInterfacesMap, skip
if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok {
Log(context.Background(), logger.Debug(),
"Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
_, valid := validInterfacesMap[aa.FriendlyName()]
if !valid && checkDnsSuffix {
for suffix := aa.FirstDNSSuffix; suffix != nil; suffix = suffix.Next {
// For non-physical adapters but have the DNS suffix that matches the domain name,
// (or vice versa) consider it valid. This can happen when remote VPN machines.
ds := strings.TrimSpace(suffix.String())
if dns.IsSubDomain(adDomain, ds) || dns.IsSubDomain(ds, adDomain) {
Log(ctx, logger.Debug(), "Found valid interface %s with DNS suffix %s", aa.FriendlyName(), suffix.String())
valid = true
break
}
}
}
// if not a valid interface, skip it
if !valid {
Log(ctx, logger.Debug(), "Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
continue
}
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
ip := dns.Address.IP()
if ip == nil {
Log(context.Background(), logger.Debug(),
"Skipping nil IP from adapter %s", aa.FriendlyName())
Log(ctx, logger.Debug(), "Skipping nil IP from adapter %s", aa.FriendlyName())
continue
}
@@ -292,28 +304,23 @@ func getDNSServers(ctx context.Context) ([]string, error) {
if !seen[dcServer] {
seen[dcServer] = true
ns = append(ns, dcServer)
Log(context.Background(), logger.Debug(),
"Added additional domain controller DNS server: %s", dcServer)
Log(ctx, logger.Debug(), "Added additional domain controller DNS server: %s", dcServer)
}
}
// if we have static DNS servers saved for the current default route, we should add them to the list
drIfaceName, err := netmon.DefaultRouteInterface()
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get default route interface: %v", err)
Log(ctx, logger.Debug(), "Failed to get default route interface: %v", err)
} else {
drIface, err := net.InterfaceByName(drIfaceName)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get interface by name %s: %v", drIfaceName, err)
Log(ctx, logger.Debug(), "Failed to get interface by name %s: %v", drIfaceName, err)
} else {
staticNs, file := SavedStaticNameservers(drIface)
Log(context.Background(), logger.Debug(),
"static dns servers from %s: %v", file, staticNs)
Log(ctx, logger.Debug(), "static dns servers from %s: %v", file, staticNs)
if len(staticNs) > 0 {
Log(context.Background(), logger.Debug(),
"Adding static DNS servers from %s: %v", drIfaceName, staticNs)
Log(ctx, logger.Debug(), "Adding static DNS servers from %s: %v", drIfaceName, staticNs)
ns = append(ns, staticNs...)
}
}
@@ -323,8 +330,7 @@ func getDNSServers(ctx context.Context) ([]string, error) {
return nil, fmt.Errorf("no valid DNS servers found")
}
Log(context.Background(), logger.Debug(),
"DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
Log(ctx, logger.Debug(), "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
len(ns), ns, len(dcServers))
return ns, nil
}