remove leaking timeout, fix blocking upstream checks, leaking is tracked per listener's upstream group, OS resolvers are tested in parallel, DNS reset is only done when the OS resolver is down (see the sketch after this change list)

fix test

use upstreamOS var

init map, fix watcher flag

attempt to detect network changes

attempt to detect network changes

cancel and rerun reinitializeOSResolver

cancel and rerun reinitializeOSResolver

cancel and rerun reinitializeOSResolver

ignore invalid interfaces

ignore invalid interfaces

allow OS resolver upstream to fail

don't wait for the dnsWg wait group on reinit, check for active interfaces to trigger reinit

fix unused var

simpler active iface check, debug logs

don't spam network service name patching on Mac

don't wait for OS resolver nameserver testing

remove test for OS resolvers for now

async nameserver testing

remove unused test
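The change list above replaces the single leakingQuery boolean with a leakingQueryRunning map keyed by the joined upstream names, so one failing upstream group no longer forces every listener to leak. A minimal standalone sketch of that pattern (illustrative names, not the actual ctrld code):

package main

import (
	"fmt"
	"strings"
	"sync"
	"time"
)

// leakTracker mimics the per-upstream-group flag: the key is the upstream
// names joined with "_", matching how upstreamMapKey is built in the diff.
type leakTracker struct {
	mu      sync.Mutex
	running map[string]bool
}

// startLeak marks the group as leaking and reports whether the caller should
// spawn the background recovery check (only the first failure does).
func (t *leakTracker) startLeak(upstreams []string) bool {
	key := strings.Join(upstreams, "_")
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.running[key] {
		return false
	}
	t.running[key] = true
	return true
}

// stopLeak clears the flag once any upstream in the group recovers.
func (t *leakTracker) stopLeak(upstreams []string) {
	key := strings.Join(upstreams, "_")
	t.mu.Lock()
	defer t.mu.Unlock()
	t.running[key] = false
}

// isLeaking reports whether queries for this group should go to the OS resolver.
func (t *leakTracker) isLeaking(upstreams []string) bool {
	key := strings.Join(upstreams, "_")
	t.mu.Lock()
	defer t.mu.Unlock()
	return t.running[key]
}

func main() {
	t := &leakTracker{running: make(map[string]bool)}
	group := []string{"upstream.0", "upstream.1"}
	if t.startLeak(group) {
		// stand-in for the parallel recovery checks performLeakingQuery runs
		go func() {
			time.Sleep(100 * time.Millisecond)
			t.stopLeak(group)
		}()
	}
	fmt.Println("leaking:", t.isLeaking(group)) // true while the group is down
	time.Sleep(200 * time.Millisecond)
	fmt.Println("leaking:", t.isLeaking(group)) // false after recovery
}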
Alex
2025-01-16 19:27:24 -05:00
committed by Cuong Manh Le
parent 2d9c60dea1
commit 2687a4a018
7 changed files with 313 additions and 190 deletions


@@ -19,6 +19,7 @@ import (
"golang.org/x/sync/errgroup"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/controld"
@@ -77,6 +78,12 @@ type upstreamForResult struct {
}
func (p *prog) serveDNS(listenerNum string) error {
// Start network monitoring
if err := p.monitorNetworkChanges(); err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring")
// Don't return here as we still want DNS service to run
}
listenerConfig := p.cfg.Listener[listenerNum]
// make sure ip is allocated
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
@@ -418,11 +425,17 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
upstreamMapKey := strings.Join(upstreams, "_")
leaked := false
if len(upstreamConfigs) > 0 && p.leakingQuery.Load() {
upstreamConfigs = nil
leaked = true
ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
if len(upstreamConfigs) > 0 {
p.leakingQueryMu.Lock()
if p.leakingQueryRunning[upstreamMapKey] {
upstreamConfigs = nil
leaked = true
ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
}
p.leakingQueryMu.Unlock()
}
if len(upstreamConfigs) == 0 {
@@ -601,9 +614,15 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
if p.leakOnUpstreamFailure() {
p.leakingQueryMu.Lock()
if !p.leakingQueryWasRun {
p.leakingQueryWasRun = true
go p.performLeakingQuery()
// the map key is the concatenation of the upstream names
if !p.leakingQueryRunning[upstreamMapKey] {
p.leakingQueryRunning[upstreamMapKey] = true
// get a map of the failed upstreams
failedUpstreams := make(map[string]*ctrld.UpstreamConfig)
for n, upstream := range upstreamConfigs {
failedUpstreams[upstreams[n]] = upstream
}
go p.performLeakingQuery(failedUpstreams, upstreamMapKey)
}
p.leakingQueryMu.Unlock()
}
@@ -929,95 +948,66 @@ func (p *prog) selfUninstallCoolOfPeriod() {
}
// performLeakingQuery performs the necessary work to leak queries to the OS resolver.
func (p *prog) performLeakingQuery() {
mainLog.Load().Warn().Msg("leaking query to OS resolver")
// once the leakingQueryRunning flag is set, queries for these upstreams leak to the OS resolver
// we then keep testing the failed upstreams in parallel until one of them recovers
func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamConfig, upstreamMapKey string) {
// Create a context with timeout for the entire operation
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
mainLog.Load().Warn().Msgf("leaking queries for failed upstreams [%v] to OS resolver", failedUpstreams)
// Signal dns watchers to stop, so changes made below won't be reverted.
p.leakingQuery.Store(true)
p.leakingQueryMu.Lock()
p.leakingQueryRunning[upstreamMapKey] = true
p.leakingQueryMu.Unlock()
defer func() {
p.leakingQuery.Store(false)
p.leakingQueryMu.Lock()
p.leakingQueryWasRun = false
p.leakingQueryRunning[upstreamMapKey] = false
p.leakingQueryMu.Unlock()
mainLog.Load().Warn().Msg("stop leaking query")
}()
// Create channels to coordinate operations
resetDone := make(chan struct{})
checkDone := make(chan struct{})
// we only want to reset DNS when our resolver is broken
// this allows us to find the new OS resolver nameservers
if p.um.isDown(upstreamOS) {
// Reset DNS with timeout
go func() {
defer close(resetDone)
mainLog.Load().Debug().Msg("attempting to reset DNS")
p.resetDNS()
mainLog.Load().Debug().Msg("DNS reset completed")
}()
mainLog.Load().Debug().Msg("OS resolver is down, reinitializing")
p.reinitializeOSResolver()
// Wait for reset with timeout
select {
case <-resetDone:
mainLog.Load().Debug().Msg("DNS reset successful")
case <-ctx.Done():
mainLog.Load().Error().Msg("DNS reset timed out")
return
}
// Check upstream in background with progress tracking
go func() {
defer close(checkDone)
mainLog.Load().Debug().Msg("starting upstream checks")
for name, uc := range p.cfg.Upstream {
select {
case <-ctx.Done():
return
default:
mainLog.Load().Debug().
Str("upstream", name).
Msg("checking upstream")
p.checkUpstream(name, uc)
// Test all failed upstreams in parallel
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
upstreamCh := make(chan string, len(failedUpstreams))
for name, uc := range failedUpstreams {
go func(name string, uc *ctrld.UpstreamConfig) {
mainLog.Load().Debug().
Str("upstream", name).
Msg("checking upstream")
for {
select {
case <-ctx.Done():
return
default:
p.checkUpstream(name, uc)
mainLog.Load().Debug().
Str("upstream", name).
Msg("upstream recovered")
upstreamCh <- name
return
}
}
}
mainLog.Load().Debug().Msg("upstream checks completed")
}()
// Wait for upstream checks
select {
case <-checkDone:
mainLog.Load().Debug().Msg("upstream checks successful")
case <-ctx.Done():
mainLog.Load().Error().Msg("upstream checks timed out")
return
}(name, uc)
}
// Initialize OS resolver with timeout
mainLog.Load().Debug().Msg("initializing OS resolver")
ns := ctrld.InitializeOsResolver()
mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns)
// Wait for any upstream to recover
name := <-upstreamCh
// Wait for DNS operations to complete
waitCh := make(chan struct{})
go func() {
p.dnsWg.Wait()
close(waitCh)
}()
mainLog.Load().Info().
Str("upstream", name).
Msg("stopping leak as upstream recovered")
select {
case <-waitCh:
mainLog.Load().Debug().Msg("DNS operations completed")
case <-ctx.Done():
mainLog.Load().Error().Msg("DNS operations timed out")
return
}
// Set DNS with timeout
mainLog.Load().Debug().Msg("setting DNS configuration")
p.setDNS()
mainLog.Load().Debug().Msg("DNS configuration set successfully")
}
// forceFetchingAPI sends signal to force syncing API config if run in cd mode,
@@ -1190,3 +1180,157 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
answer.SetReply(m)
return answer
}
// reinitializeOSResolver reinitializes the OS resolver
// by removing the ctrld listener from the interface, collecting the network nameservers,
// re-initializing the OS resolver with those nameservers,
// and applying the listener back to the interface.
func (p *prog) reinitializeOSResolver() {
// Cancel any existing operations
p.resetCtxMu.Lock()
if p.resetCancel != nil {
p.resetCancel()
}
// Create new context for this operation
ctx, cancel := context.WithCancel(context.Background())
p.resetCtx = ctx
p.resetCancel = cancel
p.resetCtxMu.Unlock()
// Ensure cleanup
defer cancel()
p.leakingQueryReset.Store(true)
defer p.leakingQueryReset.Store(false)
select {
case <-ctx.Done():
mainLog.Load().Debug().Msg("DNS reset cancelled by new network change")
return
default:
mainLog.Load().Debug().Msg("attempting to reset DNS")
p.resetDNS()
mainLog.Load().Debug().Msg("DNS reset completed")
}
select {
case <-ctx.Done():
mainLog.Load().Debug().Msg("DNS reset cancelled by new network change")
return
default:
mainLog.Load().Debug().Msg("initializing OS resolver")
ns := ctrld.InitializeOsResolver()
mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns)
}
select {
case <-ctx.Done():
mainLog.Load().Debug().Msg("DNS reset cancelled by new network change")
return
default:
mainLog.Load().Debug().Msg("setting DNS configuration")
p.setDNS()
mainLog.Load().Debug().Msg("DNS configuration set successfully")
}
}
// monitorNetworkChanges starts monitoring for network interface changes
func (p *prog) monitorNetworkChanges() error {
// Create network monitor
mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: "))
if err != nil {
return fmt.Errorf("creating network monitor: %w", err)
}
mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
// Get map of valid interfaces
validIfaces := validInterfacesMap()
// Parse old and new interface states
oldIfs := parseInterfaceState(delta.Old)
newIfs := parseInterfaceState(delta.New)
// Check for changes in valid interfaces
changed := false
activeInterfaceExists := false
for ifaceName := range validIfaces {
oldState, oldExists := oldIfs[strings.ToLower(ifaceName)]
newState, newExists := newIfs[strings.ToLower(ifaceName)]
if newState != "" && newState != "down" {
activeInterfaceExists = true
}
if oldExists != newExists || oldState != newState {
changed = true
mainLog.Load().Debug().
Str("interface", ifaceName).
Str("old_state", oldState).
Str("new_state", newState).
Msg("Valid interface changed state")
break
} else {
mainLog.Load().Debug().
Str("interface", ifaceName).
Str("old_state", oldState).
Str("new_state", newState).
Msg("Valid interface unchanged")
}
}
if !changed {
mainLog.Load().Debug().Msgf("Ignoring interface change - no valid interfaces affected")
return
}
mainLog.Load().Debug().Msgf("Network change detected: from %v to %v", delta.Old, delta.New)
if activeInterfaceExists {
p.reinitializeOSResolver()
} else {
mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization")
}
})
mon.Start()
mainLog.Load().Debug().Msg("Network monitor started")
return nil
}
// parseInterfaceState parses the interface state string into a map of interface name -> state
func parseInterfaceState(state *netmon.State) map[string]string {
if state == nil {
return nil
}
result := make(map[string]string)
// Extract ifs={...} section
stateStr := state.String()
ifsStart := strings.Index(stateStr, "ifs={")
if ifsStart == -1 {
return result
}
ifsStr := stateStr[ifsStart+5:]
ifsEnd := strings.Index(ifsStr, "}")
if ifsEnd == -1 {
return result
}
// Parse each interface entry
ifaces := strings.Split(ifsStr[:ifsEnd], " ")
for _, iface := range ifaces {
parts := strings.Split(iface, ":")
if len(parts) != 2 {
continue
}
name := strings.ToLower(parts[0])
state := parts[1]
result[name] = state
}
return result
}
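The parser above assumes netmon.State.String() renders interfaces as an ifs={name:state ...} section. Here is a standalone sketch of the same parsing against a hypothetical state string, to illustrate the expected input and output (not ctrld code; the sample format is an assumption):

package main

import (
	"fmt"
	"strings"
)

// parseIfs mirrors parseInterfaceState: it extracts the ifs={...} section and
// splits each "name:state" entry into a lowercase-name -> state map.
func parseIfs(stateStr string) map[string]string {
	result := make(map[string]string)
	ifsStart := strings.Index(stateStr, "ifs={")
	if ifsStart == -1 {
		return result
	}
	ifsStr := stateStr[ifsStart+5:]
	ifsEnd := strings.Index(ifsStr, "}")
	if ifsEnd == -1 {
		return result
	}
	for _, entry := range strings.Split(ifsStr[:ifsEnd], " ") {
		parts := strings.Split(entry, ":")
		if len(parts) != 2 {
			continue
		}
		result[strings.ToLower(parts[0])] = parts[1]
	}
	return result
}

func main() {
	// Hypothetical example of the format the change callback expects.
	sample := "interfaces.State{defaultRoute=en0 ifs={en0:up utun4:down} v4=true}"
	fmt.Println(parseIfs(sample)) // map[en0:up utun4:down]
}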


@@ -17,9 +17,8 @@ func patchNetIfaceName(iface *net.Interface) (bool, error) {
patched := false
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
iface.Name = name
mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface")
patched = true
iface.Name = name
}
return patched, nil
}


@@ -115,9 +115,13 @@ type prog struct {
loopMu sync.Mutex
loop map[string]bool
leakingQueryMu sync.Mutex
leakingQueryWasRun bool
leakingQuery atomic.Bool
leakingQueryMu sync.Mutex
leakingQueryRunning map[string]bool
leakingQueryReset atomic.Bool
resetCtx context.Context
resetCancel context.CancelFunc
resetCtxMu sync.Mutex
started chan struct{}
onStartedDone chan struct{}
@@ -420,6 +424,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
}
p.onStartedDone = make(chan struct{})
p.loop = make(map[string]bool)
p.leakingQueryRunning = make(map[string]bool)
p.lanLoopGuard = newLoopGuard()
p.ptrLoopGuard = newLoopGuard()
p.cacheFlushDomainsMap = nil
@@ -737,12 +742,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
if !requiredMultiNICsConfig() {
return
}
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
logger.Debug().Msg("start DNS settings watchdog")
mainLog.Load().Debug().Msg("start DNS settings watchdog")
ns := nameservers
slices.Sort(ns)
ticker := time.NewTicker(p.dnsWatchdogDuration())
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
for {
select {
case <-p.dnsWatcherStopCh:
@@ -751,7 +757,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
mainLog.Load().Debug().Msg("stop dns watchdog")
return
case <-ticker.C:
if p.leakingQuery.Load() {
if p.leakingQueryReset.Load() {
return
}
if dnsChanged(iface, ns) {


@@ -40,7 +40,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
return
case event, ok := <-watcher.Events:
if p.leakingQuery.Load() {
if p.leakingQueryReset.Load() {
return
}
if !ok {


@@ -44,10 +44,6 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
// increaseFailureCount increase failed queries count for an upstream by 1.
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
// Do not count "upstream.os", since it must not be down for leaking queries.
if upstream == upstreamOS {
return
}
um.mu.Lock()
defer um.mu.Unlock()


@@ -78,9 +78,7 @@ func availableNameservers() []string {
if _, ok := machineIPsMap[ns]; ok {
continue
}
if testNameServerFn(ns) {
nss = append(nss, ns)
}
nss = append(nss, ns)
}
return nss
}
@@ -100,11 +98,9 @@ func InitializeOsResolver() []string {
// - First available LAN servers are saved and stored.
// - On later calls, if no LAN servers are available, the saved servers above will be used.
func initializeOsResolver(servers []string) []string {
var (
lanNss []string
publicNss []string
)
var lanNss, publicNss []string
// First categorize servers
for _, ns := range servers {
addr, err := netip.ParseAddr(ns)
if err != nil {
@@ -117,28 +113,84 @@ func initializeOsResolver(servers []string) []string {
publicNss = append(publicNss, server)
}
}
// Store initial servers immediately
if len(lanNss) > 0 {
// Save the first initialized LAN servers.
or.initializedLanServers.CompareAndSwap(nil, &lanNss)
}
if len(lanNss) == 0 {
var nss []string
p := or.initializedLanServers.Load()
if p != nil {
for _, ns := range *p {
if testNameServerFn(ns) {
nss = append(nss, ns)
}
}
}
or.lanServers.Store(&nss)
} else {
or.lanServers.Store(&lanNss)
}
if len(publicNss) == 0 {
publicNss = append(publicNss, controldPublicDnsWithPort)
publicNss = []string{controldPublicDnsWithPort}
}
or.publicServers.Store(&publicNss)
// Test servers in background and remove failures
go func() {
// Test servers in parallel but maintain order
type result struct {
index int
server string
valid bool
}
testServers := func(servers []string) []string {
if len(servers) == 0 {
return nil
}
results := make(chan result, len(servers))
var wg sync.WaitGroup
for i, server := range servers {
wg.Add(1)
go func(idx int, s string) {
defer wg.Done()
results <- result{
index: idx,
server: s,
valid: testNameServerFn(s),
}
}(i, server)
}
go func() {
wg.Wait()
close(results)
}()
// Collect results maintaining original order
validServers := make([]string, 0, len(servers))
ordered := make([]result, 0, len(servers))
for r := range results {
ordered = append(ordered, r)
}
slices.SortFunc(ordered, func(a, b result) int {
return a.index - b.index
})
for _, r := range ordered {
if r.valid {
validServers = append(validServers, r.server)
} else {
ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing")
}
}
return validServers
}
// Test and update LAN servers
if validLanNss := testServers(lanNss); len(validLanNss) > 0 {
or.lanServers.Store(&validLanNss)
}
// Test and update public servers
validPublicNss := testServers(publicNss)
if len(validPublicNss) == 0 {
validPublicNss = []string{controldPublicDnsWithPort}
}
or.publicServers.Store(&validPublicNss)
}()
return slices.Concat(lanNss, publicNss)
}
@@ -192,7 +244,6 @@ func testNameserver(addr string) bool {
}{
{".", dns.TypeNS}, // Root NS query - should always work
{"controld.com.", dns.TypeA}, // Fallback to a reliable domain
{"google.com.", dns.TypeA}, // Fallback to a reliable domain
}
client := &dns.Client{
@@ -330,10 +381,8 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
switch {
case res.server == controldPublicDnsWithPort:
Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", res.server)
controldSuccessAnswer = res.answer
case !res.lan && publicServerAnswer == nil:
Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", res.server)
publicServerAnswer = res.answer
publicServer = res.server
default:
@@ -351,14 +400,17 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
errs = append(errs, res.err)
}
if publicServerAnswer != nil {
Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", publicServer)
logAnswer(publicServer)
return publicServerAnswer, nil
}
if controldSuccessAnswer != nil {
Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort)
logAnswer(controldPublicDnsWithPort)
return controldSuccessAnswer, nil
}
if nonSuccessAnswer != nil {
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer)
logAnswer(nonSuccessServer)
return nonSuccessAnswer, nil
}


@@ -3,13 +3,10 @@ package ctrld
import (
"context"
"net"
"slices"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/miekg/dns"
)
@@ -178,71 +175,6 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H
return server, addr, nil
}
func Test_initializeOsResolver(t *testing.T) {
testNameServerFn = testNameserverTest
lanServer1 := "192.168.1.1"
lanServer1WithPort := net.JoinHostPort("192.168.1.1", "53")
lanServer2 := "10.0.10.69"
lanServer2WithPort := net.JoinHostPort("10.0.10.69", "53")
lanServer3 := "192.168.40.1"
lanServer3WithPort := net.JoinHostPort("192.168.40.1", "53")
wanServer := "1.1.1.1"
lanServers := []string{lanServer1WithPort, lanServer2WithPort}
publicServers := []string{net.JoinHostPort(wanServer, "53")}
or = newResolverWithNameserver(defaultNameservers())
// First initialization, initialized servers are saved.
initializeOsResolver([]string{lanServer1, lanServer2, wanServer})
p := or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
// No new LAN servers, but lanServer2 gone, initialized servers not changed.
initializeOsResolver([]string{lanServer1, wanServer})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), []string{lanServer1WithPort}))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
// New LAN servers, they are used, initialized servers not changed.
initializeOsResolver([]string{lanServer3, wanServer})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), []string{lanServer3WithPort}))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
// No LAN server available, initialized servers will be used.
initializeOsResolver([]string{wanServer})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
// No Public server, ControlD Public DNS will be used.
initializeOsResolver([]string{})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
assert.True(t, slices.Equal(*or.publicServers.Load(), []string{controldPublicDnsWithPort}))
// No LAN server available, initialized servers is unavailable, nothing will be used.
nonSuccessTestServerMap[lanServer1WithPort] = true
nonSuccessTestServerMap[lanServer2WithPort] = true
initializeOsResolver([]string{wanServer})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.Empty(t, *or.lanServers.Load())
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
}
func successHandler() dns.HandlerFunc {
return func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)
@@ -258,9 +190,3 @@ func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc {
w.WriteMsg(m)
}
}
var nonSuccessTestServerMap = map[string]bool{}
func testNameserverTest(addr string) bool {
return !nonSuccessTestServerMap[addr]
}