set new dialer on every request

debugging

debugging

debugging

debugging

use default route interface IP for OS resolver queries

remove retries

fix resolv.conf clobbering on MacOS, set custom local addr for os resolver queries

remove the client info discovery logic on network change, this was overkill just for the IP, and was causing service failure after switching networks many times rapidly

handle ipv6 local addresses

guard ciTable from nil pointer

debugging failure count
This commit is contained in:
Alex
2025-02-05 01:41:16 -05:00
committed by Cuong Manh Le
parent 60686f55ff
commit cf6d16b439
7 changed files with 317 additions and 35 deletions

View File

@@ -99,6 +99,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
}
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
mainLog.Load().Debug().Msgf("serveDNS handler called")
p.sema.acquire()
defer p.sema.release()
if len(m.Question) == 0 {
@@ -1238,7 +1239,10 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
defer p.resetCtxMu.Unlock()
p.leakingQueryReset.Store(true)
defer p.leakingQueryReset.Store(false)
defer func() {
time.Sleep(time.Second)
p.leakingQueryReset.Store(false)
}()
mainLog.Load().Debug().Msg("attempting to reset DNS")
p.resetDNS()
@@ -1260,7 +1264,6 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
if err := FlushDNSCache(); err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache")
}
if runtime.GOOS == "darwin" {
// delay putting back the ctrld listener to allow for captive portal to trigger
time.Sleep(5 * time.Second)
@@ -1316,21 +1319,9 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
oldIfs := parseInterfaceState(delta.Old)
newIfs := parseInterfaceState(delta.New)
// Client info discover only run on non-mobile platforms.
if !isMobile() {
// If this is major change, re-init client info table if its self IP changes.
if delta.Monitor.IsMajorChangeFrom(delta.Old, delta.New) {
selfIP := defaultRouteIP()
if currentSelfIP := p.ciTable.SelfIP(); currentSelfIP != selfIP && selfIP != "" {
p.stopClientInfoDiscover()
p.setupClientInfoDiscover(selfIP)
p.runClientInfoDiscover(ctx)
}
}
}
// Check for changes in valid interfaces
changed := false
var changedIface, changedIfaceState string
activeInterfaceExists := false
for ifaceName := range validIfaces {
@@ -1343,7 +1334,14 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
// Compare states directly
if oldExists != newExists || oldState != newState {
changed = true
// If the interface is up, we need to reinitialize the OS resolver
if newState != "" && !strings.Contains(newState, "down") {
changed = true
changedIface = ifaceName
changedIfaceState = newState
}
mainLog.Load().Warn().
Str("interface", ifaceName).
Str("old_state", oldState).
@@ -1364,11 +1362,33 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
return
}
if activeInterfaceExists {
p.reinitializeOSResolver(true)
} else {
if !activeInterfaceExists {
mainLog.Load().Warn().Msg("No active interfaces found, skipping reinitialization")
return
}
// Use the defaultRouteIP() result or fallback to the changed interface's IP from the delta.
selfIP := defaultRouteIP()
if selfIP == "" && changedIface != "" {
selfIP = extractIPv4FromState(changedIfaceState)
mainLog.Load().Info().Msgf("defaultRouteIP returned empty, using changed iface '%s' IP: %s", changedIface, selfIP)
}
// Extract IPv6 from the changed interface state.
ipv6 := extractIPv6FromState(changedIfaceState)
if ip := net.ParseIP(selfIP); ip != nil {
ctrld.SetDefaultLocalIPv4(ip)
// if we have a new IP, set the client info to the new IP
if !isMobile() && p.ciTable != nil {
p.ciTable.SetSelfIP(selfIP)
}
}
if ip := net.ParseIP(ipv6); ip != nil {
ctrld.SetDefaultLocalIPv6(ip)
}
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
p.reinitializeOSResolver(true)
})
mon.Start()
@@ -1423,3 +1443,33 @@ func parseInterfaceState(state *netmon.State) map[string]string {
return result
}
// extractIPv4FromState extracts an IPv4 address from an interface state string.
// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239".
// If no valid IP can be found, it returns an empty string.
func extractIPv4FromState(state string) string {
trimmed := strings.Trim(state, "[]")
parts := strings.Fields(trimmed)
for _, part := range parts {
ipPart := strings.Split(part, "/")[0]
if ip := net.ParseIP(ipPart); ip != nil && ip.To4() != nil {
return ipPart
}
}
return ""
}
// extractIPv6FromState extracts an IPv6 address from an interface state string.
// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239".
// If no valid IP can be found, it returns an empty string.
func extractIPv6FromState(state string) string {
trimmed := strings.Trim(state, "[]")
parts := strings.Fields(trimmed)
for _, part := range parts {
ipPart := strings.Split(part, "/")[0]
if ip := net.ParseIP(ipPart); ip != nil && ip.To4() == nil {
return ipPart
}
}
return ""
}

View File

@@ -504,6 +504,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
if err := p.serveDNS(ctx, listenerNum); err != nil {
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
}
mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr)
}(listenerNum)
}
go func() {

View File

@@ -3,11 +3,38 @@ package cli
import (
"net"
"net/netip"
"os"
"path/filepath"
"strings"
"time"
"github.com/fsnotify/fsnotify"
)
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
// Returns nil if no nameservers are found.
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
content, err := os.ReadFile(path)
if err != nil {
return nil, err
}
// Parse the file for "nameserver" lines
var currentNS []string
lines := strings.Split(string(content), "\n")
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "nameserver") {
parts := strings.Fields(trimmed)
if len(parts) >= 2 {
currentNS = append(currentNS, parts[1])
}
}
}
return currentNS, nil
}
// watchResolvConf watches any changes to /etc/resolv.conf file,
// and reverting to the original config set by ctrld.
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
@@ -50,17 +77,81 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
continue
}
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
if err := watcher.Remove(watchDir); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
continue
mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
// Convert expected nameservers to strings for comparison
expectedNS := make([]string, len(ns))
for i, addr := range ns {
expectedNS[i] = addr.String()
}
if err := setDnsFn(iface, ns); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
var foundNS []string
var err error
maxRetries := 1
for retry := 0; retry < maxRetries; retry++ {
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content")
break
}
// If we found nameservers, break out of retry loop
if len(foundNS) > 0 {
break
}
// Only retry if we found no nameservers
if retry < maxRetries-1 {
mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
select {
case <-p.stopCh:
return
case <-p.dnsWatcherStopCh:
return
case <-time.After(2 * time.Second):
continue
}
} else {
mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries")
}
}
if err := watcher.Add(watchDir); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
return
// If we found nameservers, check if they match what we expect
if len(foundNS) > 0 {
// Check if the nameservers match exactly what we expect
matches := len(foundNS) == len(expectedNS)
if matches {
for i := range foundNS {
if foundNS[i] != expectedNS[i] {
matches = false
break
}
}
}
mainLog.Load().Debug().
Strs("found", foundNS).
Strs("expected", expectedNS).
Bool("matches", matches).
Msg("checking nameservers")
// Only revert if the nameservers don't match
if !matches {
if err := watcher.Remove(watchDir); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
continue
}
if err := setDnsFn(iface, ns); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
}
if err := watcher.Add(watchDir); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
return
}
}
}
}
case err, ok := <-watcher.Errors:

View File

@@ -42,14 +42,24 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
return um
}
// increaseFailureCount increase failed queries count for an upstream by 1.
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
um.mu.Lock()
defer um.mu.Unlock()
um.failureReq[upstream] += 1
failedCount := um.failureReq[upstream]
um.down[upstream] = failedCount >= maxFailureRequest
// Log the updated failure count
mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
// Check if the failure count has reached the threshold to mark the upstream as down.
if failedCount >= maxFailureRequest {
um.down[upstream] = true
mainLog.Load().Warn().Msgf("upstream %q marked as down (failure count: %d)", upstream, failedCount)
} else {
um.down[upstream] = false
}
}
// isDown reports whether the given upstream is being marked as down.

View File

@@ -458,7 +458,7 @@ func (uc *UpstreamConfig) ReBootstrap() {
}
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
if uc.rebootstrap.CompareAndSwap(false, true) {
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc)
}
return true, nil
})

View File

@@ -93,6 +93,7 @@ type Table struct {
quitCh chan struct{}
stopCh chan struct{}
selfIP string
selfIPLock sync.RWMutex
cdUID string
ptrNameservers []string
}
@@ -160,10 +161,20 @@ func (t *Table) Stop() {
<-t.quitCh
}
// SelfIP returns the selfIP value of the Table in a thread-safe manner.
func (t *Table) SelfIP() string {
t.selfIPLock.RLock()
defer t.selfIPLock.RUnlock()
return t.selfIP
}
// SetSelfIP sets the selfIP value of the Table in a thread-safe manner.
func (t *Table) SetSelfIP(ip string) {
t.selfIPLock.Lock()
defer t.selfIPLock.Unlock()
t.selfIP = ip
}
func (t *Table) init() {
// Custom client ID presents, use it as the only source.
if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" {

View File

@@ -7,6 +7,7 @@ import (
"io"
"net"
"net/netip"
"runtime"
"slices"
"sync"
"sync/atomic"
@@ -50,8 +51,10 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
var localResolver = newLocalResolver()
var (
resolverMutex sync.Mutex
or *osResolver
resolverMutex sync.Mutex
or *osResolver
defaultLocalIPv4 atomic.Value // holds net.IP (IPv4)
defaultLocalIPv6 atomic.Value // holds net.IP (IPv6)
)
func newLocalResolver() Resolver {
@@ -216,6 +219,108 @@ type publicResponse struct {
server string
}
// SetDefaultLocalIPv4 updates the stored local IPv4.
func SetDefaultLocalIPv4(ip net.IP) {
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip)
defaultLocalIPv4.Store(ip)
}
// SetDefaultLocalIPv6 updates the stored local IPv6.
func SetDefaultLocalIPv6(ip net.IP) {
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip)
defaultLocalIPv6.Store(ip)
}
// GetDefaultLocalIPv4 returns the stored local IPv4 or nil if none.
func GetDefaultLocalIPv4() net.IP {
if v := defaultLocalIPv4.Load(); v != nil {
return v.(net.IP)
}
return nil
}
// GetDefaultLocalIPv6 returns the stored local IPv6 or nil if none.
func GetDefaultLocalIPv6() net.IP {
if v := defaultLocalIPv6.Load(); v != nil {
return v.(net.IP)
}
return nil
}
// debugDialer is a helper type that wraps a net.Dialer and logs
// the local IP address used when dialing out.
type debugDialer struct {
*net.Dialer
}
// DialContext wraps the underlying DialContext and logs the local address of the connection.
func (d *debugDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, addr)
if err != nil {
// Log the error even before a connection is established.
if d.Dialer.LocalAddr != nil {
Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s failed: %v (local addr: %v)", addr, err, d.Dialer.LocalAddr)
} else {
Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s failed: %v", addr, err)
}
return nil, err
}
// Log the local address (source IP) used for this connection.
Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s succeeded; local address: %s",
addr, conn.LocalAddr().String())
return conn, nil
}
// customDNSExchange wraps the DNS exchange to use our debug dialer.
// It uses dns.ExchangeWithConn so that our custom dialer is used directly.
func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desiredLocalIP net.IP) (*dns.Msg, error) {
baseDialer := &net.Dialer{
Timeout: 3 * time.Second,
Resolver: &net.Resolver{PreferGo: true},
}
if desiredLocalIP != nil {
baseDialer.LocalAddr = &net.UDPAddr{IP: desiredLocalIP, Port: 0}
}
dd := &debugDialer{Dialer: baseDialer}
// Attempt UDP first.
udpConn, err := dd.DialContext(ctx, "udp", server)
if err != nil {
return nil, err
}
defer udpConn.Close()
udpDnsConn := &dns.Conn{Conn: udpConn}
if err = udpDnsConn.WriteMsg(msg); err != nil {
return nil, err
}
reply, err := udpDnsConn.ReadMsg()
if err != nil {
return nil, err
}
// If the UDP reply is not truncated, return it.
if !reply.Truncated {
return reply, nil
}
// If truncated, retry over TCP once.
Log(ctx, ProxyLogger.Load().Debug(), "UDP response truncated, switching to TCP (1 retry)")
tcpConn, err := dd.DialContext(ctx, "tcp", server)
if err != nil {
return reply, nil // fallback to UDP reply if TCP dial fails.
}
defer tcpConn.Close()
tcpDnsConn := &dns.Conn{Conn: tcpConn}
if err = tcpDnsConn.WriteMsg(msg); err != nil {
return reply, nil // fallback if TCP write fails.
}
tcpReply, err := tcpDnsConn.ReadMsg()
if err != nil {
return reply, nil // fallback if TCP read fails.
}
return tcpReply, nil
}
// Resolve resolves DNS queries using pre-configured nameservers.
// Query is sent to all nameservers concurrently, and the first
// success response will be returned.
@@ -237,7 +342,6 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
ctx, cancel := context.WithCancel(ctx)
defer cancel()
dnsClient := &dns.Client{Net: "udp", Timeout: 3 * time.Second}
ch := make(chan *osResolverResult, numServers)
wg := &sync.WaitGroup{}
wg.Add(numServers)
@@ -250,7 +354,22 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
for _, server := range servers {
go func(server string) {
defer wg.Done()
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
var answer *dns.Msg
var err error
var localOSResolverIP net.IP
if runtime.GOOS == "darwin" {
host, _, err := net.SplitHostPort(server)
if err == nil {
ip := net.ParseIP(host)
if ip != nil && ip.To4() == nil {
// IPv6 nameserver; use default IPv6 address (if set)
localOSResolverIP = GetDefaultLocalIPv6()
} else {
localOSResolverIP = GetDefaultLocalIPv4()
}
}
}
answer, err = customDNSExchange(ctx, msg.Copy(), server, localOSResolverIP)
ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan}
}(server)
}