diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index d7eb28a..0d67e88 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -450,6 +450,9 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } if p.isAdDomainQuery(req.msg) { + ctrld.Log(ctx, mainLog.Load().Debug(), + "AD domain query detected for %s in domain %s", + req.msg.Question[0].Name, p.adDomain) upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} upstreams = []string{upstreamOS} } @@ -566,14 +569,20 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { if upstreamConfig == nil { continue } - ctrld.Log(ctx, mainLog.Load().Debug(), "attempting upstream [ %s ] at index: %d, upstream at index: %s", upstreamConfig.String(), n, upstreams[n]) + logger := mainLog.Load().Debug(). + Str("upstream", upstreamConfig.String()). + Str("query", req.msg.Question[0].Name). + Bool("is_ad_query", p.isAdDomainQuery(req.msg)). + Bool("is_lan_query", isLanOrPtrQuery) if p.isLoop(upstreamConfig) { - mainLog.Load().Warn().Msgf("dns loop detected, upstream: %s", upstreamConfig.String()) + logger.Msg("DNS loop detected") continue } if p.um.isDown(upstreams[n]) { - ctrld.Log(ctx, mainLog.Load().Debug(), "%s is down", upstreams[n]) + logger. + Bool("is_os_resolver", upstreams[n] == upstreamOS). + Msg("Upstream is down") continue } answer := resolve(n, upstreamConfig, req.msg) @@ -1257,7 +1266,6 @@ func (p *prog) reinitializeOSResolver() { // monitorNetworkChanges starts monitoring for network interface changes func (p *prog) monitorNetworkChanges() error { - // Create network monitor mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) if err != nil { return fmt.Errorf("creating network monitor: %w", err) @@ -1267,6 +1275,12 @@ func (p *prog) monitorNetworkChanges() error { // Get map of valid interfaces validIfaces := validInterfacesMap() + // log the delta for debugging + mainLog.Load().Debug(). + Interface("old_state", delta.Old). + Interface("new_state", delta.New). + Msg("Network change detected") + // Parse old and new interface states oldIfs := parseInterfaceState(delta.Old) newIfs := parseInterfaceState(delta.New) @@ -1276,14 +1290,14 @@ func (p *prog) monitorNetworkChanges() error { activeInterfaceExists := false for ifaceName := range validIfaces { - oldState, oldExists := oldIfs[strings.ToLower(ifaceName)] newState, newExists := newIfs[strings.ToLower(ifaceName)] - if newState != "" && newState != "down" { + if newState != "" && !strings.Contains(newState, "down") { activeInterfaceExists = true } + // Compare states directly if oldExists != newExists || oldState != newState { changed = true mainLog.Load().Debug(). @@ -1302,11 +1316,10 @@ func (p *prog) monitorNetworkChanges() error { } if !changed { - mainLog.Load().Debug().Msgf("Ignoring interface change - no valid interfaces affected") + mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected") return } - mainLog.Load().Debug().Msgf("Network change detected: from %v to %v", delta.Old, delta.New) if activeInterfaceExists { p.reinitializeOSResolver() } else { @@ -1326,9 +1339,10 @@ func parseInterfaceState(state *netmon.State) map[string]string { } result := make(map[string]string) - - // Extract ifs={...} section + stateStr := state.String() + + // Extract interface information ifsStart := strings.Index(stateStr, "ifs={") if ifsStart == -1 { return result @@ -1340,17 +1354,28 @@ func parseInterfaceState(state *netmon.State) map[string]string { return result } - // Parse each interface entry - ifaces := strings.Split(ifsStr[:ifsEnd], " ") - for _, iface := range ifaces { - parts := strings.Split(iface, ":") + // Get the content between ifs={ } + ifsContent := strings.TrimSpace(ifsStr[:ifsEnd]) + + // Split on "] " to get each interface entry + entries := strings.Split(ifsContent, "] ") + + for _, entry := range entries { + if entry == "" { + continue + } + + // Split on ":[" + parts := strings.Split(entry, ":[") if len(parts) != 2 { continue } - name := strings.ToLower(parts[0]) - state := parts[1] - result[name] = state + + name := strings.TrimSpace(parts[0]) + state := "[" + strings.TrimSuffix(parts[1], "]") + "]" + + result[strings.ToLower(name)] = state } return result -} +} \ No newline at end of file diff --git a/go.mod b/go.mod index a86557e..8e9a8f7 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( require ( aead.dev/minisign v0.2.0 // indirect + github.com/StackExchange/wmi v1.2.1 // indirect github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/go.sum b/go.sum index 3eb268a..fcf2ac7 100644 --- a/go.sum +++ b/go.sum @@ -42,6 +42,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= +github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA= +github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE= github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= @@ -93,6 +95,7 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= @@ -449,6 +452,7 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/nameservers_windows.go b/nameservers_windows.go index 150f252..a8c5191 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -1,44 +1,364 @@ package ctrld import ( + "context" + "fmt" + "net" + "strings" "syscall" + "time" + "unsafe" + "io" + "os" + "github.com/rs/zerolog" + "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "github.com/StackExchange/wmi" ) +const ( + maxRetries = 3 + retryDelay = 500 * time.Millisecond + defaultTimeout = 5 * time.Second + minDNSServers = 1 // Minimum number of DNS servers we want to find + NetSetupUnknown uint32 = 0 + NetSetupWorkgroup uint32 = 1 + NetSetupDomain uint32 = 2 + NetSetupCloudDomain uint32 = 3 + DS_FORCE_REDISCOVERY = 0x00000001 + DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010 + DS_BACKGROUND_ONLY = 0x00000100 + DS_IP_REQUIRED = 0x00000200 + DS_IS_DNS_NAME = 0x00020000 + DS_RETURN_DNS_NAME = 0x40000000 +) + +type DomainControllerInfo struct { + DomainControllerName *uint16 + DomainControllerAddress *uint16 + DomainControllerAddressType uint32 + DomainGuid windows.GUID + DomainName *uint16 + DnsForestName *uint16 + Flags uint32 + DcSiteName *uint16 + ClientSiteName *uint16 +} + func dnsFns() []dnsFn { return []dnsFn{dnsFromAdapter} } func dnsFromAdapter() []string { - aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix) - if err != nil { - return nil + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + + var ns []string + var err error + + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() } + + for i := 0; i < maxRetries; i++ { + if ctx.Err() != nil { + Log(context.Background(), logger.Debug(), + "dnsFromAdapter lookup cancelled or timed out, attempt %d", i) + return nil + } + + ns, err = getDNSServers(ctx) + if err == nil && len(ns) >= minDNSServers { + if i > 0 { + Log(context.Background(), logger.Debug(), + "Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns)) + } + return ns + } + + // Log the specific failure reason + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get DNS servers, attempt %d: %v", i+1, err) + } else { + Log(context.Background(), logger.Debug(), + "Got insufficient DNS servers, retrying, found %d servers", len(ns)) + } + + select { + case <-ctx.Done(): + return nil + case <-time.After(retryDelay): + } + } + + Log(context.Background(), logger.Debug(), + "Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxRetries) + return ns // Return whatever we got, even if insufficient +} + +func getDNSServers(ctx context.Context) ([]string, error) { + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + // Check context before making the call + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // Get DNS servers from adapters (existing method) + flags := winipcfg.GAAFlagIncludeGateways | + winipcfg.GAAFlagIncludePrefix + + aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags) + if err != nil { + return nil, fmt.Errorf("getting adapters: %w", err) + } + + Log(context.Background(), logger.Debug(), + "Found network adapters, count=%d", len(aas)) + + // Try to get domain controller info if domain-joined + var dcServers []string + isDomain := checkDomainJoined() + if isDomain { + + domainName, err := getLocalADDomain() + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get local AD domain: %v", err) + + } else { + + // Load netapi32.dll + netapi32 := windows.NewLazySystemDLL("netapi32.dll") + dsDcName := netapi32.NewProc("DsGetDcNameW") + + var info *DomainControllerInfo + + flags := uint32(DS_RETURN_DNS_NAME | + DS_IP_REQUIRED | + DS_IS_DNS_NAME) + + // Convert domain name to UTF16 pointer + domainUTF16, err := windows.UTF16PtrFromString(domainName) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to convert domain name to UTF16: %v", err) + } else { + Log(context.Background(), logger.Debug(), + "Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) + + // Call DsGetDcNameW with domain name + ret, _, err := dsDcName.Call( + 0, // ComputerName - can be NULL + uintptr(unsafe.Pointer(domainUTF16)), // DomainName + 0, // DomainGuid - not needed + 0, // SiteName - not needed + uintptr(flags), // Flags + uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output + + if ret != 0 { + switch ret { + case 1355: // ERROR_NO_SUCH_DOMAIN + Log(context.Background(), logger.Debug(), + "Domain not found: %s (%d)", domainName, ret) + case 1311: // ERROR_NO_LOGON_SERVERS + Log(context.Background(), logger.Debug(), + "No logon servers available for domain: %s (%d)", domainName, ret) + case 1004: // ERROR_DC_NOT_FOUND + Log(context.Background(), logger.Debug(), + "Domain controller not found for domain: %s (%d)", domainName, ret) + case 1722: // RPC_S_SERVER_UNAVAILABLE + Log(context.Background(), logger.Debug(), + "RPC server unavailable for domain: %s (%d)", domainName, ret) + default: + Log(context.Background(), logger.Debug(), + "Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) + } + } else if info != nil { + defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info))) + + // Get DC address + if info.DomainControllerAddress != nil { + dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress) + dcAddr = strings.TrimPrefix(dcAddr, "\\\\") + + Log(context.Background(), logger.Debug(), + "Found domain controller address: %s", dcAddr) + + // Try to resolve DC + if ip := net.ParseIP(dcAddr); ip != nil { + dcServers = append(dcServers, ip.String()) + Log(context.Background(), logger.Debug(), + "Added domain controller DNS servers: %v", dcServers) + } + } else { + Log(context.Background(), logger.Debug(), + "No domain controller address found") + } + } + } + + } + } + + // Continue with existing adapter DNS collection ns := make([]string, 0, len(aas)*2) seen := make(map[string]bool) addressMap := make(map[string]struct{}) + + // Collect all local IPs for _, aa := range aas { + if aa.OperStatus != winipcfg.IfOperStatusUp { + Log(context.Background(), logger.Debug(), + "Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) + continue + } + + Log(context.Background(), logger.Debug(), + "Processing adapter %s", aa.FriendlyName()) + for a := aa.FirstUnicastAddress; a != nil; a = a.Next { - addressMap[a.Address.IP().String()] = struct{}{} + ip := a.Address.IP().String() + addressMap[ip] = struct{}{} + Log(context.Background(), logger.Debug(), + "Added local IP %s from adapter %s", ip, aa.FriendlyName()) } } + + // Collect DNS servers for _, aa := range aas { + if aa.OperStatus != winipcfg.IfOperStatusUp { + continue + } + for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { ip := dns.Address.IP() - if ip == nil || ip.IsLoopback() || seen[ip.String()] { + if ip == nil { + Log(context.Background(), logger.Debug(), + "Skipping nil IP from adapter %s", aa.FriendlyName()) continue } - if _, ok := addressMap[ip.String()]; ok { + + ipStr := ip.String() + logger := logger.Debug(). + Str("ip", ipStr). + Str("adapter", aa.FriendlyName()) + + if ip.IsLoopback() { + logger.Msg("Skipping loopback IP") continue } - seen[ip.String()] = true - ns = append(ns, ip.String()) + + if seen[ipStr] { + logger.Msg("Skipping duplicate IP") + continue + } + + if _, ok := addressMap[ipStr]; ok { + logger.Msg("Skipping local interface IP") + continue + } + + seen[ipStr] = true + ns = append(ns, ipStr) + logger.Msg("Added DNS server") } } - return ns + + // Add DC servers if they're not already in the list + for _, dcServer := range dcServers { + if !seen[dcServer] { + seen[dcServer] = true + ns = append(ns, dcServer) + Log(context.Background(), logger.Debug(), + "Added additional domain controller DNS server: %s", dcServer) + } + } + + if len(ns) == 0 { + return nil, fmt.Errorf("no valid DNS servers found") + } + + Log(context.Background(), logger.Debug(), + "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", + len(ns), ns, len(dcServers)) + return ns, nil } func nameserversFromResolvconf() []string { return nil } + +// checkDomainJoined checks if the machine is joined to an Active Directory domain +// Returns whether it's domain joined and the domain name if available +func checkDomainJoined() bool { + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + var domain *uint16 + var status uint32 + + err := windows.NetGetJoinInformation(nil, &domain, &status) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get domain join status: %v", err) + return false + } + defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) + + domainName := windows.UTF16PtrToString(domain) + Log(context.Background(), logger.Debug(), + "Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", domainName, status) + + // Consider both traditional and cloud domains as valid domain joins + isDomain := status == NetSetupDomain || status == NetSetupCloudDomain + Log(context.Background(), logger.Debug(), + "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", + status, + status == NetSetupDomain, + status == NetSetupCloudDomain, + isDomain) + + return isDomain +} + +// Win32_ComputerSystem is the minimal struct for WMI query +type Win32_ComputerSystem struct { + Domain string +} + +// getLocalADDomain tries to detect the AD domain in two ways: +// 1) USERDNSDOMAIN env var (often set in AD logon sessions) +// 2) WMI Win32_ComputerSystem.Domain +func getLocalADDomain() (string, error) { + // 1) Check environment variable + envDomain := os.Getenv("USERDNSDOMAIN") + if envDomain != "" { + return strings.TrimSpace(envDomain), nil + } + + // 2) Check WMI (requires Windows + admin privileges or sufficient access) + var result []Win32_ComputerSystem + err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result) + if err != nil { + return "", fmt.Errorf("WMI query failed: %v", err) + } + if len(result) == 0 { + return "", fmt.Errorf("no rows returned from Win32_ComputerSystem") + } + + domain := strings.TrimSpace(result[0].Domain) + if domain == "" { + return "", fmt.Errorf("machine does not appear to have a domain set") + } + return domain, nil +} diff --git a/resolver.go b/resolver.go index 7dc76b0..e82b763 100644 --- a/resolver.go +++ b/resolver.go @@ -11,7 +11,9 @@ import ( "sync" "sync/atomic" "time" + "io" + "github.com/rs/zerolog" "github.com/miekg/dns" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" @@ -83,15 +85,40 @@ func availableNameservers() []string { // Ignore local addresses to prevent loop. regularIPs, loopbackIPs, _ := netmon.LocalAddresses() machineIPsMap := make(map[string]struct{}, len(regularIPs)) - for _, v := range slices.Concat(regularIPs, loopbackIPs) { - machineIPsMap[v.String()] = struct{}{} + + + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() } - for _, ns := range nameservers() { + Log(context.Background(), logger.Debug(), + "Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs) + + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + machineIPsMap[ipStr] = struct{}{} + Log(context.Background(), logger.Debug(), + "Added local IP to OS resolverexclusion map: %s", ipStr) + } + + systemNameservers := nameservers() + Log(context.Background(), logger.Debug(), + "Got system nameservers: %v", systemNameservers) + + for _, ns := range systemNameservers { if _, ok := machineIPsMap[ns]; ok { + Log(context.Background(), logger.Debug(), + "Skipping local nameserver: %s", ns) continue } nss = append(nss, ns) + Log(context.Background(), logger.Debug(), + "Added non-local nameserver: %s", ns) } + + Log(context.Background(), logger.Debug(), + "Final available nameservers: %v", nss) return nss } @@ -138,156 +165,159 @@ func initializeOsResolver(servers []string) []string { } or.publicServers.Store(&publicNss) - // Test servers in background and remove failures - go func() { - // Test servers in parallel but maintain order - type result struct { - index int - server string - valid bool - } + // no longer testing servers in the background + // if DCHP nameservers are not working, this is outside of our control - testServers := func(servers []string) []string { - if len(servers) == 0 { - return nil - } + // // Test servers in background and remove failures + // go func() { + // // Test servers in parallel but maintain order + // type result struct { + // index int + // server string + // valid bool + // } - results := make(chan result, len(servers)) - var wg sync.WaitGroup + // testServers := func(servers []string) []string { + // if len(servers) == 0 { + // return nil + // } - for i, server := range servers { - wg.Add(1) - go func(idx int, s string) { - defer wg.Done() - results <- result{ - index: idx, - server: s, - valid: testNameServerFn(s), - } - }(i, server) - } + // results := make(chan result, len(servers)) + // var wg sync.WaitGroup - go func() { - wg.Wait() - close(results) - }() + // for i, server := range servers { + // wg.Add(1) + // go func(idx int, s string) { + // defer wg.Done() + // results <- result{ + // index: idx, + // server: s, + // valid: testNameServerFn(s), + // } + // }(i, server) + // } - // Collect results maintaining original order - validServers := make([]string, 0, len(servers)) - ordered := make([]result, 0, len(servers)) - for r := range results { - ordered = append(ordered, r) - } - slices.SortFunc(ordered, func(a, b result) int { - return a.index - b.index - }) - for _, r := range ordered { - if r.valid { - validServers = append(validServers, r.server) - } else { - ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing") - } - } - return validServers - } + // go func() { + // wg.Wait() + // close(results) + // }() - // Test and update LAN servers - if validLanNss := testServers(lanNss); len(validLanNss) > 0 { - or.lanServers.Store(&validLanNss) - } + // // Collect results maintaining original order + // validServers := make([]string, 0, len(servers)) + // ordered := make([]result, 0, len(servers)) + // for r := range results { + // ordered = append(ordered, r) + // } + // slices.SortFunc(ordered, func(a, b result) int { + // return a.index - b.index + // }) + // for _, r := range ordered { + // if r.valid { + // validServers = append(validServers, r.server) + // } else { + // ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing") + // } + // } + // return validServers + // } - // Test and update public servers - validPublicNss := testServers(publicNss) - if len(validPublicNss) == 0 { - validPublicNss = []string{controldPublicDnsWithPort} - } - or.publicServers.Store(&validPublicNss) - }() + // // Test and update LAN servers + // if validLanNss := testServers(lanNss); len(validLanNss) > 0 { + // or.lanServers.Store(&validLanNss) + // } + + // // Test and update public servers + // validPublicNss := testServers(publicNss) + // if len(validPublicNss) == 0 { + // validPublicNss = []string{controldPublicDnsWithPort} + // } + // or.publicServers.Store(&validPublicNss) + // }() return slices.Concat(lanNss, publicNss) } -// testNameserverFn sends a test query to DNS nameserver to check if the server is available. -var testNameServerFn = testNameserver +// // testNameserverFn sends a test query to DNS nameserver to check if the server is available. +// var testNameServerFn = testNameserver -// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. -func testNameserver(addr string) bool { - // Skip link-local addresses without scope IDs and deprecated site-local addresses - if ip, err := netip.ParseAddr(addr); err == nil { - if ip.Is6() { - if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") { - ProxyLogger.Load().Debug(). - Str("nameserver", addr). - Msg("skipping link-local IPv6 address without scope ID") - return false - } - // Skip deprecated site-local addresses (fec0::/10) - if strings.HasPrefix(ip.String(), "fec0:") { - ProxyLogger.Load().Debug(). - Str("nameserver", addr). - Msg("skipping deprecated site-local IPv6 address") - return false - } - } - } +// // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. +// func testNameserver(addr string) bool { +// // Skip link-local addresses without scope IDs and deprecated site-local addresses +// if ip, err := netip.ParseAddr(addr); err == nil { +// if ip.Is6() { +// if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") { +// ProxyLogger.Load().Debug(). +// Str("nameserver", addr). +// Msg("skipping link-local IPv6 address without scope ID") +// return false +// } +// // Skip deprecated site-local addresses (fec0::/10) +// if strings.HasPrefix(ip.String(), "fec0:") { +// ProxyLogger.Load().Debug(). +// Str("nameserver", addr). +// Msg("skipping deprecated site-local IPv6 address") +// return false +// } +// } +// } - ProxyLogger.Load().Debug(). - Str("input_addr", addr). - Msg("testing nameserver") +// ProxyLogger.Load().Debug(). +// Str("input_addr", addr). +// Msg("testing nameserver") - // Handle both IPv4 and IPv6 addresses - serverAddr := addr - host, port, err := net.SplitHostPort(addr) - if err != nil { - // No port in address, add default port 53 - serverAddr = net.JoinHostPort(addr, "53") - } else if port == "" { - // Has split markers but empty port - serverAddr = net.JoinHostPort(host, "53") - } +// // Handle both IPv4 and IPv6 addresses +// serverAddr := addr +// host, port, err := net.SplitHostPort(addr) +// if err != nil { +// // No port in address, add default port 53 +// serverAddr = net.JoinHostPort(addr, "53") +// } else if port == "" { +// // Has split markers but empty port +// serverAddr = net.JoinHostPort(host, "53") +// } - ProxyLogger.Load().Debug(). - Str("server_addr", serverAddr). - Msg("using server address") +// ProxyLogger.Load().Debug(). +// Str("server_addr", serverAddr). +// Msg("using server address") - // Test domains that are likely to exist and respond quickly - testDomains := []struct { - name string - qtype uint16 - }{ - {".", dns.TypeNS}, // Root NS query - should always work - {"controld.com.", dns.TypeA}, // Fallback to a reliable domain - } +// // Test domains that are likely to exist and respond quickly +// testDomains := []struct { +// name string +// qtype uint16 +// }{ +// {".", dns.TypeNS}, // Root NS query - should always work +// {"controld.com.", dns.TypeA}, // Fallback to a reliable domain +// } - client := &dns.Client{ - Timeout: 2 * time.Second, - Net: "udp", - } +// client := &dns.Client{ +// Timeout: 2 * time.Second, +// Net: "udp", +// } - // Try each test query until one succeeds - for _, test := range testDomains { - msg := new(dns.Msg) - msg.SetQuestion(test.name, test.qtype) - msg.RecursionDesired = true +// // Try each test query until one succeeds +// for _, test := range testDomains { +// msg := new(dns.Msg) +// msg.SetQuestion(test.name, test.qtype) +// msg.RecursionDesired = true - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - resp, _, err := client.ExchangeContext(ctx, msg, serverAddr) - cancel() +// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) +// resp, _, err := client.ExchangeContext(ctx, msg, serverAddr) +// cancel() - if err == nil && resp != nil { - return true - } +// if err == nil && resp != nil { +// return true +// } - ProxyLogger.Load().Error(). - Err(err). - Str("nameserver", serverAddr). - Str("test_domain", test.name). - Str("query_type", dns.TypeToString[test.qtype]). - Msg("DNS availability test failed") - } +// ProxyLogger.Load().Error(). +// Err(err). +// Str("nameserver", serverAddr). +// Str("test_domain", test.name). +// Str("query_type", dns.TypeToString[test.qtype]). +// Msg("DNS availability test failed") +// } - return false -} +// return false +// } // Resolver is the interface that wraps the basic DNS operations. //