Mirror of https://github.com/Control-D-Inc/ctrld.git, synced 2026-02-03 22:18:39 +00:00
remove leaking timeout, fix blocking upstream checks, leaking is per listener, OS resolvers are tested in parallel, reset is only done if OS is down

fix test: use upstreamOS var, init map
fix watcher flag
attempt to detect network changes
cancel and rerun reinitializeOSResolver
ignore invalid interfaces
allow OS resolver upstream to fail
don't wait for dnsWg wait group on reinit, check for active interfaces to trigger reinit
fix unused var
simpler active iface check, debug logs
don't spam network service name patching on Mac
don't wait for OS resolver nameserver testing
remove test for OS resolvers for now
async nameserver testing
remove unused test
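A minimal, self-contained sketch of the gating described in the commit message: queries for a failed upstream combination are leaked to the OS resolver at most once at a time, keyed by the joined upstream list (as in strings.Join(upstreams, "_") in the diff below). The leakGuard type and its names are illustrative stand-ins, not the actual ctrld types.

package main

import (
	"fmt"
	"strings"
	"sync"
)

// leakGuard is a hypothetical stand-in for prog's leakingQueryRunning map:
// it ensures only one leak/recovery routine runs per upstream combination.
type leakGuard struct {
	mu      sync.Mutex
	running map[string]bool
}

func (g *leakGuard) tryStart(upstreams []string) (key string, started bool) {
	key = strings.Join(upstreams, "_")
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.running[key] {
		return key, false // already leaking for this combination
	}
	g.running[key] = true
	return key, true
}

func (g *leakGuard) stop(key string) {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.running[key] = false
}

func main() {
	g := &leakGuard{running: map[string]bool{}}
	key, ok := g.tryStart([]string{"upstream.0", "upstream.1"})
	fmt.Println(key, ok) // upstream.0_upstream.1 true
	_, ok = g.tryStart([]string{"upstream.0", "upstream.1"})
	fmt.Println(ok) // false: a recovery check is already running
	g.stop(key)
}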
@@ -19,6 +19,7 @@ import (
"golang.org/x/sync/errgroup"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"

"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/controld"

@@ -77,6 +78,12 @@ type upstreamForResult struct {
}

func (p *prog) serveDNS(listenerNum string) error {
// Start network monitoring
if err := p.monitorNetworkChanges(); err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring")
// Don't return here as we still want DNS service to run
}

listenerConfig := p.cfg.Listener[listenerNum]
// make sure ip is allocated
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {

@@ -418,11 +425,17 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)

upstreamMapKey := strings.Join(upstreams, "_")

leaked := false
if len(upstreamConfigs) > 0 && p.leakingQuery.Load() {
upstreamConfigs = nil
leaked = true
ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
if len(upstreamConfigs) > 0 {
p.leakingQueryMu.Lock()
if p.leakingQueryRunning[upstreamMapKey] {
upstreamConfigs = nil
leaked = true
ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
}
p.leakingQueryMu.Unlock()
}

if len(upstreamConfigs) == 0 {

@@ -601,9 +614,15 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
if p.leakOnUpstreamFailure() {
p.leakingQueryMu.Lock()
if !p.leakingQueryWasRun {
p.leakingQueryWasRun = true
go p.performLeakingQuery()
// get the map key as concat of upstreams
if !p.leakingQueryRunning[upstreamMapKey] {
p.leakingQueryRunning[upstreamMapKey] = true
// get a map of the failed upstreams
failedUpstreams := make(map[string]*ctrld.UpstreamConfig)
for n, upstream := range upstreamConfigs {
failedUpstreams[upstreams[n]] = upstream
}
go p.performLeakingQuery(failedUpstreams, upstreamMapKey)
}
p.leakingQueryMu.Unlock()
}

@@ -929,95 +948,66 @@ func (p *prog) selfUninstallCoolOfPeriod() {
}

// performLeakingQuery performs necessary work to leak queries to OS resolver.
func (p *prog) performLeakingQuery() {
mainLog.Load().Warn().Msg("leaking query to OS resolver")
// once we store the leakingQuery flag, we are leaking queries to OS resolver
// we then start testing all the upstreams forever, waiting for success, but in parallel
func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamConfig, upstreamMapKey string) {

// Create a context with timeout for the entire operation
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
mainLog.Load().Warn().Msgf("leaking queries for failed upstreams [%v] to OS resolver", failedUpstreams)

// Signal dns watchers to stop, so changes made below won't be reverted.
p.leakingQuery.Store(true)
p.leakingQueryMu.Lock()
p.leakingQueryRunning[upstreamMapKey] = true
p.leakingQueryMu.Unlock()
defer func() {
p.leakingQuery.Store(false)
p.leakingQueryMu.Lock()
p.leakingQueryWasRun = false
p.leakingQueryRunning[upstreamMapKey] = false
p.leakingQueryMu.Unlock()
mainLog.Load().Warn().Msg("stop leaking query")
}()

// Create channels to coordinate operations
resetDone := make(chan struct{})
checkDone := make(chan struct{})
// we only want to reset DNS when our resolver is broken
// this allows us to find the new OS resolver nameservers
if p.um.isDown(upstreamOS) {

// Reset DNS with timeout
go func() {
defer close(resetDone)
mainLog.Load().Debug().Msg("attempting to reset DNS")
p.resetDNS()
mainLog.Load().Debug().Msg("DNS reset completed")
}()
mainLog.Load().Debug().Msg("OS resolver is down, reinitializing")
p.reinitializeOSResolver()

// Wait for reset with timeout
select {
case <-resetDone:
mainLog.Load().Debug().Msg("DNS reset successful")
case <-ctx.Done():
mainLog.Load().Error().Msg("DNS reset timed out")
return
}

// Check upstream in background with progress tracking
go func() {
defer close(checkDone)
mainLog.Load().Debug().Msg("starting upstream checks")
for name, uc := range p.cfg.Upstream {
select {
case <-ctx.Done():
return
default:
mainLog.Load().Debug().
Str("upstream", name).
Msg("checking upstream")
p.checkUpstream(name, uc)
// Test all failed upstreams in parallel
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

upstreamCh := make(chan string, len(failedUpstreams))
for name, uc := range failedUpstreams {
go func(name string, uc *ctrld.UpstreamConfig) {
mainLog.Load().Debug().
Str("upstream", name).
Msg("checking upstream")

for {
select {
case <-ctx.Done():
return
default:
p.checkUpstream(name, uc)
mainLog.Load().Debug().
Str("upstream", name).
Msg("upstream recovered")
upstreamCh <- name
return
}
}
}
mainLog.Load().Debug().Msg("upstream checks completed")
}()

// Wait for upstream checks
select {
case <-checkDone:
mainLog.Load().Debug().Msg("upstream checks successful")
case <-ctx.Done():
mainLog.Load().Error().Msg("upstream checks timed out")
return
}(name, uc)
}

// Initialize OS resolver with timeout
mainLog.Load().Debug().Msg("initializing OS resolver")
ns := ctrld.InitializeOsResolver()
mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns)
// Wait for any upstream to recover
name := <-upstreamCh

// Wait for DNS operations to complete
waitCh := make(chan struct{})
go func() {
p.dnsWg.Wait()
close(waitCh)
}()
mainLog.Load().Info().
Str("upstream", name).
Msg("stopping leak as upstream recovered")

select {
case <-waitCh:
mainLog.Load().Debug().Msg("DNS operations completed")
case <-ctx.Done():
mainLog.Load().Error().Msg("DNS operations timed out")
return
}

// Set DNS with timeout
mainLog.Load().Debug().Msg("setting DNS configuration")
p.setDNS()
mainLog.Load().Debug().Msg("DNS configuration set successfully")
}

// forceFetchingAPI sends signal to force syncing API config if run in cd mode,

@@ -1190,3 +1180,157 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
answer.SetReply(m)
return answer
}

// reinitializeOSResolver reinitializes the OS resolver
// by removing ctrld listener from the interface, collecting the network nameservers
// and re-initializing the OS resolver with the nameservers
// applying listener back to the interface
func (p *prog) reinitializeOSResolver() {
// Cancel any existing operations
p.resetCtxMu.Lock()
if p.resetCancel != nil {
p.resetCancel()
}

// Create new context for this operation
ctx, cancel := context.WithCancel(context.Background())
p.resetCtx = ctx
p.resetCancel = cancel
p.resetCtxMu.Unlock()

// Ensure cleanup
defer cancel()

p.leakingQueryReset.Store(true)
defer p.leakingQueryReset.Store(false)

select {
case <-ctx.Done():
mainLog.Load().Debug().Msg("DNS reset cancelled by new network change")
return
default:
mainLog.Load().Debug().Msg("attempting to reset DNS")
p.resetDNS()
mainLog.Load().Debug().Msg("DNS reset completed")
}

select {
case <-ctx.Done():
mainLog.Load().Debug().Msg("DNS reset cancelled by new network change")
return
default:
mainLog.Load().Debug().Msg("initializing OS resolver")
ns := ctrld.InitializeOsResolver()
mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns)
}

select {
case <-ctx.Done():
mainLog.Load().Debug().Msg("DNS reset cancelled by new network change")
return
default:
mainLog.Load().Debug().Msg("setting DNS configuration")
p.setDNS()
mainLog.Load().Debug().Msg("DNS configuration set successfully")
}
}

// monitorNetworkChanges starts monitoring for network interface changes
func (p *prog) monitorNetworkChanges() error {
// Create network monitor
mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: "))
if err != nil {
return fmt.Errorf("creating network monitor: %w", err)
}

mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
// Get map of valid interfaces
validIfaces := validInterfacesMap()

// Parse old and new interface states
oldIfs := parseInterfaceState(delta.Old)
newIfs := parseInterfaceState(delta.New)

// Check for changes in valid interfaces
changed := false
activeInterfaceExists := false

for ifaceName := range validIfaces {
oldState, oldExists := oldIfs[strings.ToLower(ifaceName)]
newState, newExists := newIfs[strings.ToLower(ifaceName)]

if newState != "" && newState != "down" {
activeInterfaceExists = true
}

if oldExists != newExists || oldState != newState {
changed = true
mainLog.Load().Debug().
Str("interface", ifaceName).
Str("old_state", oldState).
Str("new_state", newState).
Msg("Valid interface changed state")
break
} else {
mainLog.Load().Debug().
Str("interface", ifaceName).
Str("old_state", oldState).
Str("new_state", newState).
Msg("Valid interface unchanged")
}
}

if !changed {
mainLog.Load().Debug().Msgf("Ignoring interface change - no valid interfaces affected")
return
}

mainLog.Load().Debug().Msgf("Network change detected: from %v to %v", delta.Old, delta.New)
if activeInterfaceExists {
p.reinitializeOSResolver()
} else {
mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization")
}
})

mon.Start()
mainLog.Load().Debug().Msg("Network monitor started")
return nil
}

// parseInterfaceState parses the interface state string into a map of interface name -> state
func parseInterfaceState(state *netmon.State) map[string]string {
if state == nil {
return nil
}

result := make(map[string]string)

// Extract ifs={...} section
stateStr := state.String()
ifsStart := strings.Index(stateStr, "ifs={")
if ifsStart == -1 {
return result
}

ifsStr := stateStr[ifsStart+5:]
ifsEnd := strings.Index(ifsStr, "}")
if ifsEnd == -1 {
return result
}

// Parse each interface entry
ifaces := strings.Split(ifsStr[:ifsEnd], " ")
for _, iface := range ifaces {
parts := strings.Split(iface, ":")
if len(parts) != 2 {
continue
}
name := strings.ToLower(parts[0])
state := parts[1]
result[name] = state
}

return result
}
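For reference, the parsing above can be exercised on a plain string. The input below is a hypothetical example of the ifs={...} form the function looks for; the exact output of netmon.State.String() may differ.

package main

import (
	"fmt"
	"strings"
)

// parseIfs mirrors the parsing logic of parseInterfaceState for a plain
// string, so it can be tried without constructing a netmon.State value.
func parseIfs(stateStr string) map[string]string {
	result := make(map[string]string)
	ifsStart := strings.Index(stateStr, "ifs={")
	if ifsStart == -1 {
		return result
	}
	ifsStr := stateStr[ifsStart+5:]
	ifsEnd := strings.Index(ifsStr, "}")
	if ifsEnd == -1 {
		return result
	}
	for _, iface := range strings.Split(ifsStr[:ifsEnd], " ") {
		parts := strings.Split(iface, ":")
		if len(parts) != 2 {
			continue
		}
		result[strings.ToLower(parts[0])] = parts[1]
	}
	return result
}

func main() {
	// Hypothetical state string; real netmon output may differ.
	fmt.Println(parseIfs("ifs={en0:up awdl0:down}"))
	// map[awdl0:down en0:up]
}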

@@ -17,9 +17,8 @@ func patchNetIfaceName(iface *net.Interface) (bool, error) {

patched := false
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
iface.Name = name
mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface")
patched = true
iface.Name = name
}
return patched, nil
}

@@ -115,9 +115,13 @@ type prog struct {
loopMu sync.Mutex
loop map[string]bool

leakingQueryMu sync.Mutex
leakingQueryWasRun bool
leakingQuery atomic.Bool
leakingQueryMu sync.Mutex
leakingQueryRunning map[string]bool
leakingQueryReset atomic.Bool

resetCtx context.Context
resetCancel context.CancelFunc
resetCtxMu sync.Mutex

started chan struct{}
onStartedDone chan struct{}

@@ -420,6 +424,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
}
p.onStartedDone = make(chan struct{})
p.loop = make(map[string]bool)
p.leakingQueryRunning = make(map[string]bool)
p.lanLoopGuard = newLoopGuard()
p.ptrLoopGuard = newLoopGuard()
p.cacheFlushDomainsMap = nil

@@ -737,12 +742,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
if !requiredMultiNICsConfig() {
return
}
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
logger.Debug().Msg("start DNS settings watchdog")

mainLog.Load().Debug().Msg("start DNS settings watchdog")
ns := nameservers
slices.Sort(ns)
ticker := time.NewTicker(p.dnsWatchdogDuration())
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()

for {
select {
case <-p.dnsWatcherStopCh:

@@ -751,7 +757,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
mainLog.Load().Debug().Msg("stop dns watchdog")
return
case <-ticker.C:
if p.leakingQuery.Load() {
if p.leakingQueryReset.Load() {
return
}
if dnsChanged(iface, ns) {

@@ -40,7 +40,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
return
case event, ok := <-watcher.Events:
if p.leakingQuery.Load() {
if p.leakingQueryReset.Load() {
return
}
if !ok {

@@ -44,10 +44,6 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {

// increaseFailureCount increases failed queries count for an upstream by 1.
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
// Do not count "upstream.os", since it must not be down for leaking queries.
if upstream == upstreamOS {
return
}
um.mu.Lock()
defer um.mu.Unlock()

resolver.go

@@ -78,9 +78,7 @@ func availableNameservers() []string {
if _, ok := machineIPsMap[ns]; ok {
continue
}
if testNameServerFn(ns) {
nss = append(nss, ns)
}
nss = append(nss, ns)
}
return nss
}

@@ -100,11 +98,9 @@ func InitializeOsResolver() []string {
// - First available LAN servers are saved and stored.
// - Later calls, if no LAN servers available, the saved servers above will be used.
func initializeOsResolver(servers []string) []string {
var (
lanNss []string
publicNss []string
)
var lanNss, publicNss []string

// First categorize servers
for _, ns := range servers {
addr, err := netip.ParseAddr(ns)
if err != nil {

@@ -117,28 +113,84 @@ func initializeOsResolver(servers []string) []string {
publicNss = append(publicNss, server)
}
}

// Store initial servers immediately
if len(lanNss) > 0 {
// Saved first initialized LAN servers.
or.initializedLanServers.CompareAndSwap(nil, &lanNss)
}
if len(lanNss) == 0 {
var nss []string
p := or.initializedLanServers.Load()
if p != nil {
for _, ns := range *p {
if testNameServerFn(ns) {
nss = append(nss, ns)
}
}
}
or.lanServers.Store(&nss)
} else {
or.lanServers.Store(&lanNss)
}

if len(publicNss) == 0 {
publicNss = append(publicNss, controldPublicDnsWithPort)
publicNss = []string{controldPublicDnsWithPort}
}
or.publicServers.Store(&publicNss)

// Test servers in background and remove failures
go func() {
// Test servers in parallel but maintain order
type result struct {
index int
server string
valid bool
}

testServers := func(servers []string) []string {
if len(servers) == 0 {
return nil
}

results := make(chan result, len(servers))
var wg sync.WaitGroup

for i, server := range servers {
wg.Add(1)
go func(idx int, s string) {
defer wg.Done()
results <- result{
index: idx,
server: s,
valid: testNameServerFn(s),
}
}(i, server)
}

go func() {
wg.Wait()
close(results)
}()

// Collect results maintaining original order
validServers := make([]string, 0, len(servers))
ordered := make([]result, 0, len(servers))
for r := range results {
ordered = append(ordered, r)
}
slices.SortFunc(ordered, func(a, b result) int {
return a.index - b.index
})
for _, r := range ordered {
if r.valid {
validServers = append(validServers, r.server)
} else {
ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing")
}
}
return validServers
}

// Test and update LAN servers
if validLanNss := testServers(lanNss); len(validLanNss) > 0 {
or.lanServers.Store(&validLanNss)
}

// Test and update public servers
validPublicNss := testServers(publicNss)
if len(validPublicNss) == 0 {
validPublicNss = []string{controldPublicDnsWithPort}
}
or.publicServers.Store(&validPublicNss)
}()

return slices.Concat(lanNss, publicNss)
}
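The background testing added above follows a simple pattern: fan the checks out to goroutines, then keep only the servers that passed, in their original order. A standalone sketch of that pattern, with checkFn standing in for testNameServerFn (the wiring here is illustrative, not the ctrld implementation):

package main

import (
	"fmt"
	"sort"
	"sync"
)

// keepHealthyOrdered runs checkFn on every server concurrently and
// returns the servers that passed, preserving the input order.
func keepHealthyOrdered(servers []string, checkFn func(string) bool) []string {
	type result struct {
		index int
		ok    bool
	}
	results := make(chan result, len(servers))
	var wg sync.WaitGroup
	for i, s := range servers {
		wg.Add(1)
		go func(i int, s string) {
			defer wg.Done()
			results <- result{index: i, ok: checkFn(s)}
		}(i, s)
	}
	wg.Wait()
	close(results)

	collected := make([]result, 0, len(servers))
	for r := range results {
		collected = append(collected, r)
	}
	sort.Slice(collected, func(a, b int) bool { return collected[a].index < collected[b].index })

	out := make([]string, 0, len(servers))
	for _, r := range collected {
		if r.ok {
			out = append(out, servers[r.index])
		}
	}
	return out
}

func main() {
	ok := keepHealthyOrdered([]string{"192.168.1.1:53", "1.1.1.1:53"}, func(s string) bool {
		return s != "192.168.1.1:53" // pretend the LAN server is down
	})
	fmt.Println(ok) // [1.1.1.1:53]
}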

@@ -192,7 +244,6 @@ func testNameserver(addr string) bool {
}{
{".", dns.TypeNS}, // Root NS query - should always work
{"controld.com.", dns.TypeA}, // Fallback to a reliable domain
{"google.com.", dns.TypeA}, // Fallback to a reliable domain
}

client := &dns.Client{

@@ -330,10 +381,8 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
switch {
case res.server == controldPublicDnsWithPort:
Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", res.server)
controldSuccessAnswer = res.answer
case !res.lan && publicServerAnswer == nil:
Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", res.server)
publicServerAnswer = res.answer
publicServer = res.server
default:

@@ -351,14 +400,17 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
errs = append(errs, res.err)
}
if publicServerAnswer != nil {
Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", publicServer)
logAnswer(publicServer)
return publicServerAnswer, nil
}
if controldSuccessAnswer != nil {
Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort)
logAnswer(controldPublicDnsWithPort)
return controldSuccessAnswer, nil
}
if nonSuccessAnswer != nil {
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer)
logAnswer(nonSuccessServer)
return nonSuccessAnswer, nil
}

@@ -3,13 +3,10 @@ package ctrld
import (
"context"
"net"
"slices"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/miekg/dns"
)

@@ -178,71 +175,6 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H
return server, addr, nil
}

func Test_initializeOsResolver(t *testing.T) {
testNameServerFn = testNameserverTest
lanServer1 := "192.168.1.1"
lanServer1WithPort := net.JoinHostPort("192.168.1.1", "53")
lanServer2 := "10.0.10.69"
lanServer2WithPort := net.JoinHostPort("10.0.10.69", "53")
lanServer3 := "192.168.40.1"
lanServer3WithPort := net.JoinHostPort("192.168.40.1", "53")
wanServer := "1.1.1.1"
lanServers := []string{lanServer1WithPort, lanServer2WithPort}
publicServers := []string{net.JoinHostPort(wanServer, "53")}

or = newResolverWithNameserver(defaultNameservers())

// First initialization, initialized servers are saved.
initializeOsResolver([]string{lanServer1, lanServer2, wanServer})
p := or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))

// No new LAN servers, but lanServer2 gone, initialized servers not changed.
initializeOsResolver([]string{lanServer1, wanServer})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), []string{lanServer1WithPort}))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))

// New LAN servers, they are used, initialized servers not changed.
initializeOsResolver([]string{lanServer3, wanServer})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), []string{lanServer3WithPort}))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))

// No LAN server available, initialized servers will be used.
initializeOsResolver([]string{wanServer})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))

// No Public server, ControlD Public DNS will be used.
initializeOsResolver([]string{})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
assert.True(t, slices.Equal(*or.publicServers.Load(), []string{controldPublicDnsWithPort}))

// No LAN server available, initialized servers are unavailable, nothing will be used.
nonSuccessTestServerMap[lanServer1WithPort] = true
nonSuccessTestServerMap[lanServer2WithPort] = true
initializeOsResolver([]string{wanServer})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.Empty(t, *or.lanServers.Load())
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
}

func successHandler() dns.HandlerFunc {
return func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)

@@ -258,9 +190,3 @@ func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc {
w.WriteMsg(m)
}
}

var nonSuccessTestServerMap = map[string]bool{}

func testNameserverTest(addr string) bool {
return !nonSuccessTestServerMap[addr]
}