From 1804e6db673731210ef48bf37e472b80fffc4cec Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 14 Jan 2026 17:17:55 +0700 Subject: [PATCH] fix(windows): improve DNS server discovery for domain-joined machines Add DNS suffix matching for non-physical adapters when domain-joined. This allows interfaces with matching DNS suffix to be considered valid even if not in validInterfacesMap, improving DNS server discovery for remote VPN scenarios. While at it, also replacing context.Background() with proper ctx parameter throughout the function for consistent context propagation. --- nameservers_windows.go | 86 ++++++++++++++++++++++-------------------- 1 file changed, 46 insertions(+), 40 deletions(-) diff --git a/nameservers_windows.go b/nameservers_windows.go index 6817393..e02b1f5 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -17,6 +17,7 @@ import ( "github.com/microsoft/wmi/pkg/base/query" "github.com/microsoft/wmi/pkg/constant" "github.com/microsoft/wmi/pkg/hardware/network/netadapter" + "github.com/miekg/dns" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "tailscale.com/net/netmon" @@ -128,13 +129,15 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Try to get domain controller info if domain-joined var dcServers []string + var adDomain string isDomain := checkDomainJoined() if isDomain { domainName, err := system.GetActiveDirectoryDomain() if err != nil { - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Failed to get local AD domain: %v", err) } else { + adDomain = domainName // Load netapi32.dll netapi32 := windows.NewLazySystemDLL("netapi32.dll") dsDcName := netapi32.NewProc("DsGetDcNameW") @@ -144,10 +147,9 @@ func getDNSServers(ctx context.Context) ([]string, error) { domainUTF16, err := windows.UTF16PtrFromString(domainName) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to convert domain name to UTF16: %v", err) + Log(ctx, logger.Debug(), "Failed to convert domain name to UTF16: %v", err) } else { - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) // Call DsGetDcNameW with domain name @@ -162,19 +164,19 @@ func getDNSServers(ctx context.Context) ([]string, error) { if ret != 0 { switch ret { case 1355: // ERROR_NO_SUCH_DOMAIN - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Domain not found: %s (%d)", domainName, ret) case 1311: // ERROR_NO_LOGON_SERVERS - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "No logon servers available for domain: %s (%d)", domainName, ret) case 1004: // ERROR_DC_NOT_FOUND - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Domain controller not found for domain: %s (%d)", domainName, ret) case 1722: // RPC_S_SERVER_UNAVAILABLE - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "RPC server unavailable for domain: %s (%d)", domainName, ret) default: - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) } } else if info != nil { @@ -183,17 +185,16 @@ func getDNSServers(ctx context.Context) ([]string, error) { if info.DomainControllerAddress != nil { dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress) dcAddr = strings.TrimPrefix(dcAddr, "\\\\") - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Found domain controller address: %s", dcAddr) if ip := net.ParseIP(dcAddr); ip != nil { dcServers = append(dcServers, ip.String()) - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Added domain controller DNS servers: %v", dcServers) } } else { - Log(context.Background(), logger.Debug(), - "No domain controller address found") + Log(ctx, logger.Debug(), "No domain controller address found") } } } @@ -208,7 +209,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Collect all local IPs for _, aa := range aas { if aa.OperStatus != winipcfg.IfOperStatusUp { - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) continue } @@ -216,24 +217,25 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Skip if software loopback or other non-physical types // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows if aa.IfType == winipcfg.IfTypeSoftwareLoopback { - Log(context.Background(), logger.Debug(), - "Skipping %s (software loopback)", aa.FriendlyName()) + Log(ctx, logger.Debug(), "Skipping %s (software loopback)", aa.FriendlyName()) continue } - Log(context.Background(), logger.Debug(), - "Processing adapter %s", aa.FriendlyName()) + Log(ctx, logger.Debug(), "Processing adapter %s", aa.FriendlyName()) for a := aa.FirstUnicastAddress; a != nil; a = a.Next { ip := a.Address.IP().String() addressMap[ip] = struct{}{} - Log(context.Background(), logger.Debug(), - "Added local IP %s from adapter %s", ip, aa.FriendlyName()) + Log(ctx, logger.Debug(), "Added local IP %s from adapter %s", ip, aa.FriendlyName()) } } validInterfacesMap := validInterfaces() + if isDomain && adDomain == "" { + Log(ctx, logger.Warn(), "The machine is joined domain, but domain name is empty") + } + checkDnsSuffix := isDomain && adDomain != "" // Collect DNS servers for _, aa := range aas { if aa.OperStatus != winipcfg.IfOperStatusUp { @@ -243,23 +245,33 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Skip if software loopback or other non-physical types // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows if aa.IfType == winipcfg.IfTypeSoftwareLoopback { - Log(context.Background(), logger.Debug(), - "Skipping %s (software loopback)", aa.FriendlyName()) + Log(ctx, logger.Debug(), "Skipping %s (software loopback)", aa.FriendlyName()) continue } - // if not in the validInterfacesMap, skip - if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok { - Log(context.Background(), logger.Debug(), - "Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) + _, valid := validInterfacesMap[aa.FriendlyName()] + if !valid && checkDnsSuffix { + for suffix := aa.FirstDNSSuffix; suffix != nil; suffix = suffix.Next { + // For non-physical adapters but have the DNS suffix that matches the domain name, + // (or vice versa) consider it valid. This can happen when remote VPN machines. + ds := strings.TrimSpace(suffix.String()) + if dns.IsSubDomain(adDomain, ds) || dns.IsSubDomain(ds, adDomain) { + Log(ctx, logger.Debug(), "Found valid interface %s with DNS suffix %s", aa.FriendlyName(), suffix.String()) + valid = true + break + } + } + } + // if not a valid interface, skip it + if !valid { + Log(ctx, logger.Debug(), "Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) continue } for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { ip := dns.Address.IP() if ip == nil { - Log(context.Background(), logger.Debug(), - "Skipping nil IP from adapter %s", aa.FriendlyName()) + Log(ctx, logger.Debug(), "Skipping nil IP from adapter %s", aa.FriendlyName()) continue } @@ -292,28 +304,23 @@ func getDNSServers(ctx context.Context) ([]string, error) { if !seen[dcServer] { seen[dcServer] = true ns = append(ns, dcServer) - Log(context.Background(), logger.Debug(), - "Added additional domain controller DNS server: %s", dcServer) + Log(ctx, logger.Debug(), "Added additional domain controller DNS server: %s", dcServer) } } // if we have static DNS servers saved for the current default route, we should add them to the list drIfaceName, err := netmon.DefaultRouteInterface() if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get default route interface: %v", err) + Log(ctx, logger.Debug(), "Failed to get default route interface: %v", err) } else { drIface, err := net.InterfaceByName(drIfaceName) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get interface by name %s: %v", drIfaceName, err) + Log(ctx, logger.Debug(), "Failed to get interface by name %s: %v", drIfaceName, err) } else { staticNs, file := SavedStaticNameservers(drIface) - Log(context.Background(), logger.Debug(), - "static dns servers from %s: %v", file, staticNs) + Log(ctx, logger.Debug(), "static dns servers from %s: %v", file, staticNs) if len(staticNs) > 0 { - Log(context.Background(), logger.Debug(), - "Adding static DNS servers from %s: %v", drIfaceName, staticNs) + Log(ctx, logger.Debug(), "Adding static DNS servers from %s: %v", drIfaceName, staticNs) ns = append(ns, staticNs...) } } @@ -323,8 +330,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("no valid DNS servers found") } - Log(context.Background(), logger.Debug(), - "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", + Log(ctx, logger.Debug(), "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", len(ns), ns, len(dcServers)) return ns, nil }