mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
set new dialer on every request
debugging debugging debugging debugging use default route interface IP for OS resolver queries remove retries fix resolv.conf clobbering on MacOS, set custom local addr for os resolver queries remove the client info discovery logic on network change, this was overkill just for the IP, and was causing service failure after switching networks many times rapidly handle ipv6 local addresses guard ciTable from nil pointer debugging failure count
This commit is contained in:
@@ -99,6 +99,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
}
|
||||
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
mainLog.Load().Debug().Msgf("serveDNS handler called")
|
||||
p.sema.acquire()
|
||||
defer p.sema.release()
|
||||
if len(m.Question) == 0 {
|
||||
@@ -1238,7 +1239,10 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
|
||||
defer p.resetCtxMu.Unlock()
|
||||
|
||||
p.leakingQueryReset.Store(true)
|
||||
defer p.leakingQueryReset.Store(false)
|
||||
defer func() {
|
||||
time.Sleep(time.Second)
|
||||
p.leakingQueryReset.Store(false)
|
||||
}()
|
||||
|
||||
mainLog.Load().Debug().Msg("attempting to reset DNS")
|
||||
p.resetDNS()
|
||||
@@ -1260,7 +1264,6 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
|
||||
if err := FlushDNSCache(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
// delay putting back the ctrld listener to allow for captive portal to trigger
|
||||
time.Sleep(5 * time.Second)
|
||||
@@ -1316,21 +1319,9 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
oldIfs := parseInterfaceState(delta.Old)
|
||||
newIfs := parseInterfaceState(delta.New)
|
||||
|
||||
// Client info discover only run on non-mobile platforms.
|
||||
if !isMobile() {
|
||||
// If this is major change, re-init client info table if its self IP changes.
|
||||
if delta.Monitor.IsMajorChangeFrom(delta.Old, delta.New) {
|
||||
selfIP := defaultRouteIP()
|
||||
if currentSelfIP := p.ciTable.SelfIP(); currentSelfIP != selfIP && selfIP != "" {
|
||||
p.stopClientInfoDiscover()
|
||||
p.setupClientInfoDiscover(selfIP)
|
||||
p.runClientInfoDiscover(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for changes in valid interfaces
|
||||
changed := false
|
||||
var changedIface, changedIfaceState string
|
||||
activeInterfaceExists := false
|
||||
|
||||
for ifaceName := range validIfaces {
|
||||
@@ -1343,7 +1334,14 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
|
||||
// Compare states directly
|
||||
if oldExists != newExists || oldState != newState {
|
||||
changed = true
|
||||
|
||||
// If the interface is up, we need to reinitialize the OS resolver
|
||||
if newState != "" && !strings.Contains(newState, "down") {
|
||||
changed = true
|
||||
changedIface = ifaceName
|
||||
changedIfaceState = newState
|
||||
}
|
||||
|
||||
mainLog.Load().Warn().
|
||||
Str("interface", ifaceName).
|
||||
Str("old_state", oldState).
|
||||
@@ -1364,11 +1362,33 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
return
|
||||
}
|
||||
|
||||
if activeInterfaceExists {
|
||||
p.reinitializeOSResolver(true)
|
||||
} else {
|
||||
if !activeInterfaceExists {
|
||||
mainLog.Load().Warn().Msg("No active interfaces found, skipping reinitialization")
|
||||
return
|
||||
}
|
||||
|
||||
// Use the defaultRouteIP() result or fallback to the changed interface's IP from the delta.
|
||||
selfIP := defaultRouteIP()
|
||||
if selfIP == "" && changedIface != "" {
|
||||
selfIP = extractIPv4FromState(changedIfaceState)
|
||||
mainLog.Load().Info().Msgf("defaultRouteIP returned empty, using changed iface '%s' IP: %s", changedIface, selfIP)
|
||||
}
|
||||
|
||||
// Extract IPv6 from the changed interface state.
|
||||
ipv6 := extractIPv6FromState(changedIfaceState)
|
||||
|
||||
if ip := net.ParseIP(selfIP); ip != nil {
|
||||
ctrld.SetDefaultLocalIPv4(ip)
|
||||
// if we have a new IP, set the client info to the new IP
|
||||
if !isMobile() && p.ciTable != nil {
|
||||
p.ciTable.SetSelfIP(selfIP)
|
||||
}
|
||||
}
|
||||
if ip := net.ParseIP(ipv6); ip != nil {
|
||||
ctrld.SetDefaultLocalIPv6(ip)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
||||
p.reinitializeOSResolver(true)
|
||||
})
|
||||
|
||||
mon.Start()
|
||||
@@ -1423,3 +1443,33 @@ func parseInterfaceState(state *netmon.State) map[string]string {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractIPv4FromState extracts an IPv4 address from an interface state string.
|
||||
// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239".
|
||||
// If no valid IP can be found, it returns an empty string.
|
||||
func extractIPv4FromState(state string) string {
|
||||
trimmed := strings.Trim(state, "[]")
|
||||
parts := strings.Fields(trimmed)
|
||||
for _, part := range parts {
|
||||
ipPart := strings.Split(part, "/")[0]
|
||||
if ip := net.ParseIP(ipPart); ip != nil && ip.To4() != nil {
|
||||
return ipPart
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractIPv6FromState extracts an IPv6 address from an interface state string.
|
||||
// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239".
|
||||
// If no valid IP can be found, it returns an empty string.
|
||||
func extractIPv6FromState(state string) string {
|
||||
trimmed := strings.Trim(state, "[]")
|
||||
parts := strings.Fields(trimmed)
|
||||
for _, part := range parts {
|
||||
ipPart := strings.Split(part, "/")[0]
|
||||
if ip := net.ParseIP(ipPart); ip != nil && ip.To4() == nil {
|
||||
return ipPart
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -504,6 +504,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
if err := p.serveDNS(ctx, listenerNum); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr)
|
||||
}(listenerNum)
|
||||
}
|
||||
go func() {
|
||||
|
||||
@@ -3,11 +3,38 @@ package cli
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
|
||||
// Returns nil if no nameservers are found.
|
||||
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the file for "nameserver" lines
|
||||
var currentNS []string
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "nameserver") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) >= 2 {
|
||||
currentNS = append(currentNS, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentNS, nil
|
||||
}
|
||||
|
||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
@@ -50,17 +77,81 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
|
||||
// Convert expected nameservers to strings for comparison
|
||||
expectedNS := make([]string, len(ns))
|
||||
for i, addr := range ns {
|
||||
expectedNS[i] = addr.String()
|
||||
}
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
|
||||
var foundNS []string
|
||||
var err error
|
||||
|
||||
maxRetries := 1
|
||||
for retry := 0; retry < maxRetries; retry++ {
|
||||
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content")
|
||||
break
|
||||
}
|
||||
|
||||
// If we found nameservers, break out of retry loop
|
||||
if len(foundNS) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Only retry if we found no nameservers
|
||||
if retry < maxRetries-1 {
|
||||
mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries")
|
||||
}
|
||||
}
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
|
||||
// If we found nameservers, check if they match what we expect
|
||||
if len(foundNS) > 0 {
|
||||
// Check if the nameservers match exactly what we expect
|
||||
matches := len(foundNS) == len(expectedNS)
|
||||
if matches {
|
||||
for i := range foundNS {
|
||||
if foundNS[i] != expectedNS[i] {
|
||||
matches = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Strs("found", foundNS).
|
||||
Strs("expected", expectedNS).
|
||||
Bool("matches", matches).
|
||||
Msg("checking nameservers")
|
||||
|
||||
// Only revert if the nameservers don't match
|
||||
if !matches {
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
|
||||
@@ -42,14 +42,24 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increase failed queries count for an upstream by 1.
|
||||
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
um.failureReq[upstream] += 1
|
||||
failedCount := um.failureReq[upstream]
|
||||
um.down[upstream] = failedCount >= maxFailureRequest
|
||||
|
||||
// Log the updated failure count
|
||||
mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
|
||||
|
||||
// Check if the failure count has reached the threshold to mark the upstream as down.
|
||||
if failedCount >= maxFailureRequest {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down (failure count: %d)", upstream, failedCount)
|
||||
} else {
|
||||
um.down[upstream] = false
|
||||
}
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
|
||||
@@ -458,7 +458,7 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
}
|
||||
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
|
||||
if uc.rebootstrap.CompareAndSwap(false, true) {
|
||||
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
|
||||
ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
|
||||
@@ -93,6 +93,7 @@ type Table struct {
|
||||
quitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
selfIP string
|
||||
selfIPLock sync.RWMutex
|
||||
cdUID string
|
||||
ptrNameservers []string
|
||||
}
|
||||
@@ -160,10 +161,20 @@ func (t *Table) Stop() {
|
||||
<-t.quitCh
|
||||
}
|
||||
|
||||
// SelfIP returns the selfIP value of the Table in a thread-safe manner.
|
||||
func (t *Table) SelfIP() string {
|
||||
t.selfIPLock.RLock()
|
||||
defer t.selfIPLock.RUnlock()
|
||||
return t.selfIP
|
||||
}
|
||||
|
||||
// SetSelfIP sets the selfIP value of the Table in a thread-safe manner.
|
||||
func (t *Table) SetSelfIP(ip string) {
|
||||
t.selfIPLock.Lock()
|
||||
defer t.selfIPLock.Unlock()
|
||||
t.selfIP = ip
|
||||
}
|
||||
|
||||
func (t *Table) init() {
|
||||
// Custom client ID presents, use it as the only source.
|
||||
if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" {
|
||||
|
||||
127
resolver.go
127
resolver.go
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -50,8 +51,10 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||
var localResolver = newLocalResolver()
|
||||
|
||||
var (
|
||||
resolverMutex sync.Mutex
|
||||
or *osResolver
|
||||
resolverMutex sync.Mutex
|
||||
or *osResolver
|
||||
defaultLocalIPv4 atomic.Value // holds net.IP (IPv4)
|
||||
defaultLocalIPv6 atomic.Value // holds net.IP (IPv6)
|
||||
)
|
||||
|
||||
func newLocalResolver() Resolver {
|
||||
@@ -216,6 +219,108 @@ type publicResponse struct {
|
||||
server string
|
||||
}
|
||||
|
||||
// SetDefaultLocalIPv4 updates the stored local IPv4.
|
||||
func SetDefaultLocalIPv4(ip net.IP) {
|
||||
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip)
|
||||
defaultLocalIPv4.Store(ip)
|
||||
}
|
||||
|
||||
// SetDefaultLocalIPv6 updates the stored local IPv6.
|
||||
func SetDefaultLocalIPv6(ip net.IP) {
|
||||
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip)
|
||||
defaultLocalIPv6.Store(ip)
|
||||
}
|
||||
|
||||
// GetDefaultLocalIPv4 returns the stored local IPv4 or nil if none.
|
||||
func GetDefaultLocalIPv4() net.IP {
|
||||
if v := defaultLocalIPv4.Load(); v != nil {
|
||||
return v.(net.IP)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDefaultLocalIPv6 returns the stored local IPv6 or nil if none.
|
||||
func GetDefaultLocalIPv6() net.IP {
|
||||
if v := defaultLocalIPv6.Load(); v != nil {
|
||||
return v.(net.IP)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// debugDialer is a helper type that wraps a net.Dialer and logs
|
||||
// the local IP address used when dialing out.
|
||||
type debugDialer struct {
|
||||
*net.Dialer
|
||||
}
|
||||
|
||||
// DialContext wraps the underlying DialContext and logs the local address of the connection.
|
||||
func (d *debugDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
// Log the error even before a connection is established.
|
||||
if d.Dialer.LocalAddr != nil {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s failed: %v (local addr: %v)", addr, err, d.Dialer.LocalAddr)
|
||||
} else {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s failed: %v", addr, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
// Log the local address (source IP) used for this connection.
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s succeeded; local address: %s",
|
||||
addr, conn.LocalAddr().String())
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// customDNSExchange wraps the DNS exchange to use our debug dialer.
|
||||
// It uses dns.ExchangeWithConn so that our custom dialer is used directly.
|
||||
func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desiredLocalIP net.IP) (*dns.Msg, error) {
|
||||
baseDialer := &net.Dialer{
|
||||
Timeout: 3 * time.Second,
|
||||
Resolver: &net.Resolver{PreferGo: true},
|
||||
}
|
||||
if desiredLocalIP != nil {
|
||||
baseDialer.LocalAddr = &net.UDPAddr{IP: desiredLocalIP, Port: 0}
|
||||
}
|
||||
dd := &debugDialer{Dialer: baseDialer}
|
||||
|
||||
// Attempt UDP first.
|
||||
udpConn, err := dd.DialContext(ctx, "udp", server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
udpDnsConn := &dns.Conn{Conn: udpConn}
|
||||
if err = udpDnsConn.WriteMsg(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply, err := udpDnsConn.ReadMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the UDP reply is not truncated, return it.
|
||||
if !reply.Truncated {
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// If truncated, retry over TCP once.
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "UDP response truncated, switching to TCP (1 retry)")
|
||||
tcpConn, err := dd.DialContext(ctx, "tcp", server)
|
||||
if err != nil {
|
||||
return reply, nil // fallback to UDP reply if TCP dial fails.
|
||||
}
|
||||
defer tcpConn.Close()
|
||||
tcpDnsConn := &dns.Conn{Conn: tcpConn}
|
||||
if err = tcpDnsConn.WriteMsg(msg); err != nil {
|
||||
return reply, nil // fallback if TCP write fails.
|
||||
}
|
||||
tcpReply, err := tcpDnsConn.ReadMsg()
|
||||
if err != nil {
|
||||
return reply, nil // fallback if TCP read fails.
|
||||
}
|
||||
return tcpReply, nil
|
||||
}
|
||||
|
||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||
// Query is sent to all nameservers concurrently, and the first
|
||||
// success response will be returned.
|
||||
@@ -237,7 +342,6 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
dnsClient := &dns.Client{Net: "udp", Timeout: 3 * time.Second}
|
||||
ch := make(chan *osResolverResult, numServers)
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(numServers)
|
||||
@@ -250,7 +354,22 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
for _, server := range servers {
|
||||
go func(server string) {
|
||||
defer wg.Done()
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
||||
var answer *dns.Msg
|
||||
var err error
|
||||
var localOSResolverIP net.IP
|
||||
if runtime.GOOS == "darwin" {
|
||||
host, _, err := net.SplitHostPort(server)
|
||||
if err == nil {
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && ip.To4() == nil {
|
||||
// IPv6 nameserver; use default IPv6 address (if set)
|
||||
localOSResolverIP = GetDefaultLocalIPv6()
|
||||
} else {
|
||||
localOSResolverIP = GetDefaultLocalIPv4()
|
||||
}
|
||||
}
|
||||
}
|
||||
answer, err = customDNSExchange(ctx, msg.Copy(), server, localOSResolverIP)
|
||||
ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan}
|
||||
}(server)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user