much more debugging, improved nameserver detection, no more testing nameservers

fix logging

fix logging

try to enable nameserver logs

try to enable nameserver logs

handle flags in interface state changes

debugging

debugging

debugging

fix state detection, AD status fix

fix debugging line

more dc info

always log state changes

remove unused method

windows AD IP discovery

windows AD IP discovery

windows AD IP discovery
This commit is contained in:
Alex
2025-01-25 01:26:48 -05:00
committed by Cuong Manh Le
parent 0fbfd160c9
commit ce3281e70d
5 changed files with 538 additions and 158 deletions

View File

@@ -450,6 +450,9 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
}
if p.isAdDomainQuery(req.msg) {
ctrld.Log(ctx, mainLog.Load().Debug(),
"AD domain query detected for %s in domain %s",
req.msg.Question[0].Name, p.adDomain)
upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig}
upstreams = []string{upstreamOS}
}
@@ -566,14 +569,20 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
if upstreamConfig == nil {
continue
}
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting upstream [ %s ] at index: %d, upstream at index: %s", upstreamConfig.String(), n, upstreams[n])
logger := mainLog.Load().Debug().
Str("upstream", upstreamConfig.String()).
Str("query", req.msg.Question[0].Name).
Bool("is_ad_query", p.isAdDomainQuery(req.msg)).
Bool("is_lan_query", isLanOrPtrQuery)
if p.isLoop(upstreamConfig) {
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %s", upstreamConfig.String())
logger.Msg("DNS loop detected")
continue
}
if p.um.isDown(upstreams[n]) {
ctrld.Log(ctx, mainLog.Load().Debug(), "%s is down", upstreams[n])
logger.
Bool("is_os_resolver", upstreams[n] == upstreamOS).
Msg("Upstream is down")
continue
}
answer := resolve(n, upstreamConfig, req.msg)
@@ -1257,7 +1266,6 @@ func (p *prog) reinitializeOSResolver() {
// monitorNetworkChanges starts monitoring for network interface changes
func (p *prog) monitorNetworkChanges() error {
// Create network monitor
mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: "))
if err != nil {
return fmt.Errorf("creating network monitor: %w", err)
@@ -1267,6 +1275,12 @@ func (p *prog) monitorNetworkChanges() error {
// Get map of valid interfaces
validIfaces := validInterfacesMap()
// log the delta for debugging
mainLog.Load().Debug().
Interface("old_state", delta.Old).
Interface("new_state", delta.New).
Msg("Network change detected")
// Parse old and new interface states
oldIfs := parseInterfaceState(delta.Old)
newIfs := parseInterfaceState(delta.New)
@@ -1276,14 +1290,14 @@ func (p *prog) monitorNetworkChanges() error {
activeInterfaceExists := false
for ifaceName := range validIfaces {
oldState, oldExists := oldIfs[strings.ToLower(ifaceName)]
newState, newExists := newIfs[strings.ToLower(ifaceName)]
if newState != "" && newState != "down" {
if newState != "" && !strings.Contains(newState, "down") {
activeInterfaceExists = true
}
// Compare states directly
if oldExists != newExists || oldState != newState {
changed = true
mainLog.Load().Debug().
@@ -1302,11 +1316,10 @@ func (p *prog) monitorNetworkChanges() error {
}
if !changed {
mainLog.Load().Debug().Msgf("Ignoring interface change - no valid interfaces affected")
mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected")
return
}
mainLog.Load().Debug().Msgf("Network change detected: from %v to %v", delta.Old, delta.New)
if activeInterfaceExists {
p.reinitializeOSResolver()
} else {
@@ -1326,9 +1339,10 @@ func parseInterfaceState(state *netmon.State) map[string]string {
}
result := make(map[string]string)
// Extract ifs={...} section
stateStr := state.String()
// Extract interface information
ifsStart := strings.Index(stateStr, "ifs={")
if ifsStart == -1 {
return result
@@ -1340,17 +1354,28 @@ func parseInterfaceState(state *netmon.State) map[string]string {
return result
}
// Parse each interface entry
ifaces := strings.Split(ifsStr[:ifsEnd], " ")
for _, iface := range ifaces {
parts := strings.Split(iface, ":")
// Get the content between ifs={ }
ifsContent := strings.TrimSpace(ifsStr[:ifsEnd])
// Split on "] " to get each interface entry
entries := strings.Split(ifsContent, "] ")
for _, entry := range entries {
if entry == "" {
continue
}
// Split on ":["
parts := strings.Split(entry, ":[")
if len(parts) != 2 {
continue
}
name := strings.ToLower(parts[0])
state := parts[1]
result[name] = state
name := strings.TrimSpace(parts[0])
state := "[" + strings.TrimSuffix(parts[1], "]") + "]"
result[strings.ToLower(name)] = state
}
return result
}
}

1
go.mod
View File

@@ -45,6 +45,7 @@ require (
require (
aead.dev/minisign v0.2.0 // indirect
github.com/StackExchange/wmi v1.2.1 // indirect
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect

4
go.sum
View File

@@ -42,6 +42,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA=
github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE=
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
@@ -93,6 +95,7 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
@@ -449,6 +452,7 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@@ -1,44 +1,364 @@
package ctrld
import (
"context"
"fmt"
"net"
"strings"
"syscall"
"time"
"unsafe"
"io"
"os"
"github.com/rs/zerolog"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/StackExchange/wmi"
)
const (
maxRetries = 3
retryDelay = 500 * time.Millisecond
defaultTimeout = 5 * time.Second
minDNSServers = 1 // Minimum number of DNS servers we want to find
NetSetupUnknown uint32 = 0
NetSetupWorkgroup uint32 = 1
NetSetupDomain uint32 = 2
NetSetupCloudDomain uint32 = 3
DS_FORCE_REDISCOVERY = 0x00000001
DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010
DS_BACKGROUND_ONLY = 0x00000100
DS_IP_REQUIRED = 0x00000200
DS_IS_DNS_NAME = 0x00020000
DS_RETURN_DNS_NAME = 0x40000000
)
type DomainControllerInfo struct {
DomainControllerName *uint16
DomainControllerAddress *uint16
DomainControllerAddressType uint32
DomainGuid windows.GUID
DomainName *uint16
DnsForestName *uint16
Flags uint32
DcSiteName *uint16
ClientSiteName *uint16
}
func dnsFns() []dnsFn {
return []dnsFn{dnsFromAdapter}
}
func dnsFromAdapter() []string {
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix)
if err != nil {
return nil
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
var ns []string
var err error
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
for i := 0; i < maxRetries; i++ {
if ctx.Err() != nil {
Log(context.Background(), logger.Debug(),
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
return nil
}
ns, err = getDNSServers(ctx)
if err == nil && len(ns) >= minDNSServers {
if i > 0 {
Log(context.Background(), logger.Debug(),
"Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns))
}
return ns
}
// Log the specific failure reason
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get DNS servers, attempt %d: %v", i+1, err)
} else {
Log(context.Background(), logger.Debug(),
"Got insufficient DNS servers, retrying, found %d servers", len(ns))
}
select {
case <-ctx.Done():
return nil
case <-time.After(retryDelay):
}
}
Log(context.Background(), logger.Debug(),
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxRetries)
return ns // Return whatever we got, even if insufficient
}
func getDNSServers(ctx context.Context) ([]string, error) {
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
// Check context before making the call
if ctx.Err() != nil {
return nil, ctx.Err()
}
// Get DNS servers from adapters (existing method)
flags := winipcfg.GAAFlagIncludeGateways |
winipcfg.GAAFlagIncludePrefix
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
if err != nil {
return nil, fmt.Errorf("getting adapters: %w", err)
}
Log(context.Background(), logger.Debug(),
"Found network adapters, count=%d", len(aas))
// Try to get domain controller info if domain-joined
var dcServers []string
isDomain := checkDomainJoined()
if isDomain {
domainName, err := getLocalADDomain()
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get local AD domain: %v", err)
} else {
// Load netapi32.dll
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
dsDcName := netapi32.NewProc("DsGetDcNameW")
var info *DomainControllerInfo
flags := uint32(DS_RETURN_DNS_NAME |
DS_IP_REQUIRED |
DS_IS_DNS_NAME)
// Convert domain name to UTF16 pointer
domainUTF16, err := windows.UTF16PtrFromString(domainName)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to convert domain name to UTF16: %v", err)
} else {
Log(context.Background(), logger.Debug(),
"Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
// Call DsGetDcNameW with domain name
ret, _, err := dsDcName.Call(
0, // ComputerName - can be NULL
uintptr(unsafe.Pointer(domainUTF16)), // DomainName
0, // DomainGuid - not needed
0, // SiteName - not needed
uintptr(flags), // Flags
uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output
if ret != 0 {
switch ret {
case 1355: // ERROR_NO_SUCH_DOMAIN
Log(context.Background(), logger.Debug(),
"Domain not found: %s (%d)", domainName, ret)
case 1311: // ERROR_NO_LOGON_SERVERS
Log(context.Background(), logger.Debug(),
"No logon servers available for domain: %s (%d)", domainName, ret)
case 1004: // ERROR_DC_NOT_FOUND
Log(context.Background(), logger.Debug(),
"Domain controller not found for domain: %s (%d)", domainName, ret)
case 1722: // RPC_S_SERVER_UNAVAILABLE
Log(context.Background(), logger.Debug(),
"RPC server unavailable for domain: %s (%d)", domainName, ret)
default:
Log(context.Background(), logger.Debug(),
"Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
}
} else if info != nil {
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
// Get DC address
if info.DomainControllerAddress != nil {
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
Log(context.Background(), logger.Debug(),
"Found domain controller address: %s", dcAddr)
// Try to resolve DC
if ip := net.ParseIP(dcAddr); ip != nil {
dcServers = append(dcServers, ip.String())
Log(context.Background(), logger.Debug(),
"Added domain controller DNS servers: %v", dcServers)
}
} else {
Log(context.Background(), logger.Debug(),
"No domain controller address found")
}
}
}
}
}
// Continue with existing adapter DNS collection
ns := make([]string, 0, len(aas)*2)
seen := make(map[string]bool)
addressMap := make(map[string]struct{})
// Collect all local IPs
for _, aa := range aas {
if aa.OperStatus != winipcfg.IfOperStatusUp {
Log(context.Background(), logger.Debug(),
"Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
continue
}
Log(context.Background(), logger.Debug(),
"Processing adapter %s", aa.FriendlyName())
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
addressMap[a.Address.IP().String()] = struct{}{}
ip := a.Address.IP().String()
addressMap[ip] = struct{}{}
Log(context.Background(), logger.Debug(),
"Added local IP %s from adapter %s", ip, aa.FriendlyName())
}
}
// Collect DNS servers
for _, aa := range aas {
if aa.OperStatus != winipcfg.IfOperStatusUp {
continue
}
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
ip := dns.Address.IP()
if ip == nil || ip.IsLoopback() || seen[ip.String()] {
if ip == nil {
Log(context.Background(), logger.Debug(),
"Skipping nil IP from adapter %s", aa.FriendlyName())
continue
}
if _, ok := addressMap[ip.String()]; ok {
ipStr := ip.String()
logger := logger.Debug().
Str("ip", ipStr).
Str("adapter", aa.FriendlyName())
if ip.IsLoopback() {
logger.Msg("Skipping loopback IP")
continue
}
seen[ip.String()] = true
ns = append(ns, ip.String())
if seen[ipStr] {
logger.Msg("Skipping duplicate IP")
continue
}
if _, ok := addressMap[ipStr]; ok {
logger.Msg("Skipping local interface IP")
continue
}
seen[ipStr] = true
ns = append(ns, ipStr)
logger.Msg("Added DNS server")
}
}
return ns
// Add DC servers if they're not already in the list
for _, dcServer := range dcServers {
if !seen[dcServer] {
seen[dcServer] = true
ns = append(ns, dcServer)
Log(context.Background(), logger.Debug(),
"Added additional domain controller DNS server: %s", dcServer)
}
}
if len(ns) == 0 {
return nil, fmt.Errorf("no valid DNS servers found")
}
Log(context.Background(), logger.Debug(),
"DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
len(ns), ns, len(dcServers))
return ns, nil
}
func nameserversFromResolvconf() []string {
return nil
}
// checkDomainJoined checks if the machine is joined to an Active Directory domain
// Returns whether it's domain joined and the domain name if available
func checkDomainJoined() bool {
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
var domain *uint16
var status uint32
err := windows.NetGetJoinInformation(nil, &domain, &status)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get domain join status: %v", err)
return false
}
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain)))
domainName := windows.UTF16PtrToString(domain)
Log(context.Background(), logger.Debug(),
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", domainName, status)
// Consider both traditional and cloud domains as valid domain joins
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
Log(context.Background(), logger.Debug(),
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
status,
status == NetSetupDomain,
status == NetSetupCloudDomain,
isDomain)
return isDomain
}
// Win32_ComputerSystem is the minimal struct for WMI query
type Win32_ComputerSystem struct {
Domain string
}
// getLocalADDomain tries to detect the AD domain in two ways:
// 1) USERDNSDOMAIN env var (often set in AD logon sessions)
// 2) WMI Win32_ComputerSystem.Domain
func getLocalADDomain() (string, error) {
// 1) Check environment variable
envDomain := os.Getenv("USERDNSDOMAIN")
if envDomain != "" {
return strings.TrimSpace(envDomain), nil
}
// 2) Check WMI (requires Windows + admin privileges or sufficient access)
var result []Win32_ComputerSystem
err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result)
if err != nil {
return "", fmt.Errorf("WMI query failed: %v", err)
}
if len(result) == 0 {
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
}
domain := strings.TrimSpace(result[0].Domain)
if domain == "" {
return "", fmt.Errorf("machine does not appear to have a domain set")
}
return domain, nil
}

View File

@@ -11,7 +11,9 @@ import (
"sync"
"sync/atomic"
"time"
"io"
"github.com/rs/zerolog"
"github.com/miekg/dns"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
@@ -83,15 +85,40 @@ func availableNameservers() []string {
// Ignore local addresses to prevent loop.
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
machineIPsMap := make(map[string]struct{}, len(regularIPs))
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
machineIPsMap[v.String()] = struct{}{}
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
for _, ns := range nameservers() {
Log(context.Background(), logger.Debug(),
"Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs)
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
ipStr := v.String()
machineIPsMap[ipStr] = struct{}{}
Log(context.Background(), logger.Debug(),
"Added local IP to OS resolverexclusion map: %s", ipStr)
}
systemNameservers := nameservers()
Log(context.Background(), logger.Debug(),
"Got system nameservers: %v", systemNameservers)
for _, ns := range systemNameservers {
if _, ok := machineIPsMap[ns]; ok {
Log(context.Background(), logger.Debug(),
"Skipping local nameserver: %s", ns)
continue
}
nss = append(nss, ns)
Log(context.Background(), logger.Debug(),
"Added non-local nameserver: %s", ns)
}
Log(context.Background(), logger.Debug(),
"Final available nameservers: %v", nss)
return nss
}
@@ -138,156 +165,159 @@ func initializeOsResolver(servers []string) []string {
}
or.publicServers.Store(&publicNss)
// Test servers in background and remove failures
go func() {
// Test servers in parallel but maintain order
type result struct {
index int
server string
valid bool
}
// no longer testing servers in the background
// if DCHP nameservers are not working, this is outside of our control
testServers := func(servers []string) []string {
if len(servers) == 0 {
return nil
}
// // Test servers in background and remove failures
// go func() {
// // Test servers in parallel but maintain order
// type result struct {
// index int
// server string
// valid bool
// }
results := make(chan result, len(servers))
var wg sync.WaitGroup
// testServers := func(servers []string) []string {
// if len(servers) == 0 {
// return nil
// }
for i, server := range servers {
wg.Add(1)
go func(idx int, s string) {
defer wg.Done()
results <- result{
index: idx,
server: s,
valid: testNameServerFn(s),
}
}(i, server)
}
// results := make(chan result, len(servers))
// var wg sync.WaitGroup
go func() {
wg.Wait()
close(results)
}()
// for i, server := range servers {
// wg.Add(1)
// go func(idx int, s string) {
// defer wg.Done()
// results <- result{
// index: idx,
// server: s,
// valid: testNameServerFn(s),
// }
// }(i, server)
// }
// Collect results maintaining original order
validServers := make([]string, 0, len(servers))
ordered := make([]result, 0, len(servers))
for r := range results {
ordered = append(ordered, r)
}
slices.SortFunc(ordered, func(a, b result) int {
return a.index - b.index
})
for _, r := range ordered {
if r.valid {
validServers = append(validServers, r.server)
} else {
ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing")
}
}
return validServers
}
// go func() {
// wg.Wait()
// close(results)
// }()
// Test and update LAN servers
if validLanNss := testServers(lanNss); len(validLanNss) > 0 {
or.lanServers.Store(&validLanNss)
}
// // Collect results maintaining original order
// validServers := make([]string, 0, len(servers))
// ordered := make([]result, 0, len(servers))
// for r := range results {
// ordered = append(ordered, r)
// }
// slices.SortFunc(ordered, func(a, b result) int {
// return a.index - b.index
// })
// for _, r := range ordered {
// if r.valid {
// validServers = append(validServers, r.server)
// } else {
// ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing")
// }
// }
// return validServers
// }
// Test and update public servers
validPublicNss := testServers(publicNss)
if len(validPublicNss) == 0 {
validPublicNss = []string{controldPublicDnsWithPort}
}
or.publicServers.Store(&validPublicNss)
}()
// // Test and update LAN servers
// if validLanNss := testServers(lanNss); len(validLanNss) > 0 {
// or.lanServers.Store(&validLanNss)
// }
// // Test and update public servers
// validPublicNss := testServers(publicNss)
// if len(validPublicNss) == 0 {
// validPublicNss = []string{controldPublicDnsWithPort}
// }
// or.publicServers.Store(&validPublicNss)
// }()
return slices.Concat(lanNss, publicNss)
}
// testNameserverFn sends a test query to DNS nameserver to check if the server is available.
var testNameServerFn = testNameserver
// // testNameserverFn sends a test query to DNS nameserver to check if the server is available.
// var testNameServerFn = testNameserver
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
func testNameserver(addr string) bool {
// Skip link-local addresses without scope IDs and deprecated site-local addresses
if ip, err := netip.ParseAddr(addr); err == nil {
if ip.Is6() {
if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") {
ProxyLogger.Load().Debug().
Str("nameserver", addr).
Msg("skipping link-local IPv6 address without scope ID")
return false
}
// Skip deprecated site-local addresses (fec0::/10)
if strings.HasPrefix(ip.String(), "fec0:") {
ProxyLogger.Load().Debug().
Str("nameserver", addr).
Msg("skipping deprecated site-local IPv6 address")
return false
}
}
}
// // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
// func testNameserver(addr string) bool {
// // Skip link-local addresses without scope IDs and deprecated site-local addresses
// if ip, err := netip.ParseAddr(addr); err == nil {
// if ip.Is6() {
// if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") {
// ProxyLogger.Load().Debug().
// Str("nameserver", addr).
// Msg("skipping link-local IPv6 address without scope ID")
// return false
// }
// // Skip deprecated site-local addresses (fec0::/10)
// if strings.HasPrefix(ip.String(), "fec0:") {
// ProxyLogger.Load().Debug().
// Str("nameserver", addr).
// Msg("skipping deprecated site-local IPv6 address")
// return false
// }
// }
// }
ProxyLogger.Load().Debug().
Str("input_addr", addr).
Msg("testing nameserver")
// ProxyLogger.Load().Debug().
// Str("input_addr", addr).
// Msg("testing nameserver")
// Handle both IPv4 and IPv6 addresses
serverAddr := addr
host, port, err := net.SplitHostPort(addr)
if err != nil {
// No port in address, add default port 53
serverAddr = net.JoinHostPort(addr, "53")
} else if port == "" {
// Has split markers but empty port
serverAddr = net.JoinHostPort(host, "53")
}
// // Handle both IPv4 and IPv6 addresses
// serverAddr := addr
// host, port, err := net.SplitHostPort(addr)
// if err != nil {
// // No port in address, add default port 53
// serverAddr = net.JoinHostPort(addr, "53")
// } else if port == "" {
// // Has split markers but empty port
// serverAddr = net.JoinHostPort(host, "53")
// }
ProxyLogger.Load().Debug().
Str("server_addr", serverAddr).
Msg("using server address")
// ProxyLogger.Load().Debug().
// Str("server_addr", serverAddr).
// Msg("using server address")
// Test domains that are likely to exist and respond quickly
testDomains := []struct {
name string
qtype uint16
}{
{".", dns.TypeNS}, // Root NS query - should always work
{"controld.com.", dns.TypeA}, // Fallback to a reliable domain
}
// // Test domains that are likely to exist and respond quickly
// testDomains := []struct {
// name string
// qtype uint16
// }{
// {".", dns.TypeNS}, // Root NS query - should always work
// {"controld.com.", dns.TypeA}, // Fallback to a reliable domain
// }
client := &dns.Client{
Timeout: 2 * time.Second,
Net: "udp",
}
// client := &dns.Client{
// Timeout: 2 * time.Second,
// Net: "udp",
// }
// Try each test query until one succeeds
for _, test := range testDomains {
msg := new(dns.Msg)
msg.SetQuestion(test.name, test.qtype)
msg.RecursionDesired = true
// // Try each test query until one succeeds
// for _, test := range testDomains {
// msg := new(dns.Msg)
// msg.SetQuestion(test.name, test.qtype)
// msg.RecursionDesired = true
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
resp, _, err := client.ExchangeContext(ctx, msg, serverAddr)
cancel()
// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
// resp, _, err := client.ExchangeContext(ctx, msg, serverAddr)
// cancel()
if err == nil && resp != nil {
return true
}
// if err == nil && resp != nil {
// return true
// }
ProxyLogger.Load().Error().
Err(err).
Str("nameserver", serverAddr).
Str("test_domain", test.name).
Str("query_type", dns.TypeToString[test.qtype]).
Msg("DNS availability test failed")
}
// ProxyLogger.Load().Error().
// Err(err).
// Str("nameserver", serverAddr).
// Str("test_domain", test.name).
// Str("query_type", dns.TypeToString[test.qtype]).
// Msg("DNS availability test failed")
// }
return false
}
// return false
// }
// Resolver is the interface that wraps the basic DNS operations.
//