diff --git a/nameservers_windows.go b/nameservers_windows.go index 6817393..e02b1f5 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -17,6 +17,7 @@ import ( "github.com/microsoft/wmi/pkg/base/query" "github.com/microsoft/wmi/pkg/constant" "github.com/microsoft/wmi/pkg/hardware/network/netadapter" + "github.com/miekg/dns" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "tailscale.com/net/netmon" @@ -128,13 +129,15 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Try to get domain controller info if domain-joined var dcServers []string + var adDomain string isDomain := checkDomainJoined() if isDomain { domainName, err := system.GetActiveDirectoryDomain() if err != nil { - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Failed to get local AD domain: %v", err) } else { + adDomain = domainName // Load netapi32.dll netapi32 := windows.NewLazySystemDLL("netapi32.dll") dsDcName := netapi32.NewProc("DsGetDcNameW") @@ -144,10 +147,9 @@ func getDNSServers(ctx context.Context) ([]string, error) { domainUTF16, err := windows.UTF16PtrFromString(domainName) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to convert domain name to UTF16: %v", err) + Log(ctx, logger.Debug(), "Failed to convert domain name to UTF16: %v", err) } else { - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) // Call DsGetDcNameW with domain name @@ -162,19 +164,19 @@ func getDNSServers(ctx context.Context) ([]string, error) { if ret != 0 { switch ret { case 1355: // ERROR_NO_SUCH_DOMAIN - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Domain not found: %s (%d)", domainName, ret) case 1311: // ERROR_NO_LOGON_SERVERS - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "No logon servers available for domain: %s (%d)", domainName, ret) case 1004: // ERROR_DC_NOT_FOUND - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Domain controller not found for domain: %s (%d)", domainName, ret) case 1722: // RPC_S_SERVER_UNAVAILABLE - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "RPC server unavailable for domain: %s (%d)", domainName, ret) default: - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) } } else if info != nil { @@ -183,17 +185,16 @@ func getDNSServers(ctx context.Context) ([]string, error) { if info.DomainControllerAddress != nil { dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress) dcAddr = strings.TrimPrefix(dcAddr, "\\\\") - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Found domain controller address: %s", dcAddr) if ip := net.ParseIP(dcAddr); ip != nil { dcServers = append(dcServers, ip.String()) - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Added domain controller DNS servers: %v", dcServers) } } else { - Log(context.Background(), logger.Debug(), - "No domain controller address found") + Log(ctx, logger.Debug(), "No domain controller address found") } } } @@ -208,7 +209,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Collect all local IPs for _, aa := range aas { if aa.OperStatus != winipcfg.IfOperStatusUp { - Log(context.Background(), logger.Debug(), + Log(ctx, logger.Debug(), "Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) continue } @@ -216,24 +217,25 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Skip if software loopback or other non-physical types // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows if aa.IfType == winipcfg.IfTypeSoftwareLoopback { - Log(context.Background(), logger.Debug(), - "Skipping %s (software loopback)", aa.FriendlyName()) + Log(ctx, logger.Debug(), "Skipping %s (software loopback)", aa.FriendlyName()) continue } - Log(context.Background(), logger.Debug(), - "Processing adapter %s", aa.FriendlyName()) + Log(ctx, logger.Debug(), "Processing adapter %s", aa.FriendlyName()) for a := aa.FirstUnicastAddress; a != nil; a = a.Next { ip := a.Address.IP().String() addressMap[ip] = struct{}{} - Log(context.Background(), logger.Debug(), - "Added local IP %s from adapter %s", ip, aa.FriendlyName()) + Log(ctx, logger.Debug(), "Added local IP %s from adapter %s", ip, aa.FriendlyName()) } } validInterfacesMap := validInterfaces() + if isDomain && adDomain == "" { + Log(ctx, logger.Warn(), "The machine is joined domain, but domain name is empty") + } + checkDnsSuffix := isDomain && adDomain != "" // Collect DNS servers for _, aa := range aas { if aa.OperStatus != winipcfg.IfOperStatusUp { @@ -243,23 +245,33 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Skip if software loopback or other non-physical types // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows if aa.IfType == winipcfg.IfTypeSoftwareLoopback { - Log(context.Background(), logger.Debug(), - "Skipping %s (software loopback)", aa.FriendlyName()) + Log(ctx, logger.Debug(), "Skipping %s (software loopback)", aa.FriendlyName()) continue } - // if not in the validInterfacesMap, skip - if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok { - Log(context.Background(), logger.Debug(), - "Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) + _, valid := validInterfacesMap[aa.FriendlyName()] + if !valid && checkDnsSuffix { + for suffix := aa.FirstDNSSuffix; suffix != nil; suffix = suffix.Next { + // For non-physical adapters but have the DNS suffix that matches the domain name, + // (or vice versa) consider it valid. This can happen when remote VPN machines. + ds := strings.TrimSpace(suffix.String()) + if dns.IsSubDomain(adDomain, ds) || dns.IsSubDomain(ds, adDomain) { + Log(ctx, logger.Debug(), "Found valid interface %s with DNS suffix %s", aa.FriendlyName(), suffix.String()) + valid = true + break + } + } + } + // if not a valid interface, skip it + if !valid { + Log(ctx, logger.Debug(), "Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) continue } for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { ip := dns.Address.IP() if ip == nil { - Log(context.Background(), logger.Debug(), - "Skipping nil IP from adapter %s", aa.FriendlyName()) + Log(ctx, logger.Debug(), "Skipping nil IP from adapter %s", aa.FriendlyName()) continue } @@ -292,28 +304,23 @@ func getDNSServers(ctx context.Context) ([]string, error) { if !seen[dcServer] { seen[dcServer] = true ns = append(ns, dcServer) - Log(context.Background(), logger.Debug(), - "Added additional domain controller DNS server: %s", dcServer) + Log(ctx, logger.Debug(), "Added additional domain controller DNS server: %s", dcServer) } } // if we have static DNS servers saved for the current default route, we should add them to the list drIfaceName, err := netmon.DefaultRouteInterface() if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get default route interface: %v", err) + Log(ctx, logger.Debug(), "Failed to get default route interface: %v", err) } else { drIface, err := net.InterfaceByName(drIfaceName) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get interface by name %s: %v", drIfaceName, err) + Log(ctx, logger.Debug(), "Failed to get interface by name %s: %v", drIfaceName, err) } else { staticNs, file := SavedStaticNameservers(drIface) - Log(context.Background(), logger.Debug(), - "static dns servers from %s: %v", file, staticNs) + Log(ctx, logger.Debug(), "static dns servers from %s: %v", file, staticNs) if len(staticNs) > 0 { - Log(context.Background(), logger.Debug(), - "Adding static DNS servers from %s: %v", drIfaceName, staticNs) + Log(ctx, logger.Debug(), "Adding static DNS servers from %s: %v", drIfaceName, staticNs) ns = append(ns, staticNs...) } } @@ -323,8 +330,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("no valid DNS servers found") } - Log(context.Background(), logger.Debug(), - "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", + Log(ctx, logger.Debug(), "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", len(ns), ns, len(dcServers)) return ns, nil }