diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 646bafb..01e1673 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -99,6 +99,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { } handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + mainLog.Load().Debug().Msgf("serveDNS handler called") p.sema.acquire() defer p.sema.release() if len(m.Question) == 0 { @@ -1238,7 +1239,10 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { defer p.resetCtxMu.Unlock() p.leakingQueryReset.Store(true) - defer p.leakingQueryReset.Store(false) + defer func() { + time.Sleep(time.Second) + p.leakingQueryReset.Store(false) + }() mainLog.Load().Debug().Msg("attempting to reset DNS") p.resetDNS() @@ -1260,7 +1264,6 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { if err := FlushDNSCache(); err != nil { mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache") } - if runtime.GOOS == "darwin" { // delay putting back the ctrld listener to allow for captive portal to trigger time.Sleep(5 * time.Second) @@ -1316,21 +1319,9 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { oldIfs := parseInterfaceState(delta.Old) newIfs := parseInterfaceState(delta.New) - // Client info discover only run on non-mobile platforms. - if !isMobile() { - // If this is major change, re-init client info table if its self IP changes. - if delta.Monitor.IsMajorChangeFrom(delta.Old, delta.New) { - selfIP := defaultRouteIP() - if currentSelfIP := p.ciTable.SelfIP(); currentSelfIP != selfIP && selfIP != "" { - p.stopClientInfoDiscover() - p.setupClientInfoDiscover(selfIP) - p.runClientInfoDiscover(ctx) - } - } - } - // Check for changes in valid interfaces changed := false + var changedIface, changedIfaceState string activeInterfaceExists := false for ifaceName := range validIfaces { @@ -1343,7 +1334,14 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // Compare states directly if oldExists != newExists || oldState != newState { - changed = true + + // If the interface is up, we need to reinitialize the OS resolver + if newState != "" && !strings.Contains(newState, "down") { + changed = true + changedIface = ifaceName + changedIfaceState = newState + } + mainLog.Load().Warn(). Str("interface", ifaceName). Str("old_state", oldState). @@ -1364,11 +1362,33 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { return } - if activeInterfaceExists { - p.reinitializeOSResolver(true) - } else { + if !activeInterfaceExists { mainLog.Load().Warn().Msg("No active interfaces found, skipping reinitialization") + return } + + // Use the defaultRouteIP() result or fallback to the changed interface's IP from the delta. + selfIP := defaultRouteIP() + if selfIP == "" && changedIface != "" { + selfIP = extractIPv4FromState(changedIfaceState) + mainLog.Load().Info().Msgf("defaultRouteIP returned empty, using changed iface '%s' IP: %s", changedIface, selfIP) + } + + // Extract IPv6 from the changed interface state. + ipv6 := extractIPv6FromState(changedIfaceState) + + if ip := net.ParseIP(selfIP); ip != nil { + ctrld.SetDefaultLocalIPv4(ip) + // if we have a new IP, set the client info to the new IP + if !isMobile() && p.ciTable != nil { + p.ciTable.SetSelfIP(selfIP) + } + } + if ip := net.ParseIP(ipv6); ip != nil { + ctrld.SetDefaultLocalIPv6(ip) + } + mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) + p.reinitializeOSResolver(true) }) mon.Start() @@ -1423,3 +1443,33 @@ func parseInterfaceState(state *netmon.State) map[string]string { return result } + +// extractIPv4FromState extracts an IPv4 address from an interface state string. +// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239". +// If no valid IP can be found, it returns an empty string. +func extractIPv4FromState(state string) string { + trimmed := strings.Trim(state, "[]") + parts := strings.Fields(trimmed) + for _, part := range parts { + ipPart := strings.Split(part, "/")[0] + if ip := net.ParseIP(ipPart); ip != nil && ip.To4() != nil { + return ipPart + } + } + return "" +} + +// extractIPv6FromState extracts an IPv6 address from an interface state string. +// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239". +// If no valid IP can be found, it returns an empty string. +func extractIPv6FromState(state string) string { + trimmed := strings.Trim(state, "[]") + parts := strings.Fields(trimmed) + for _, part := range parts { + ipPart := strings.Split(part, "/")[0] + if ip := net.ParseIP(ipPart); ip != nil && ip.To4() == nil { + return ipPart + } + } + return "" +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 41dc2c4..c7eba13 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -504,6 +504,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if err := p.serveDNS(ctx, listenerNum); err != nil { mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) } + mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr) }(listenerNum) } go func() { diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 367ffe7..9d37d68 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -3,11 +3,38 @@ package cli import ( "net" "net/netip" + "os" "path/filepath" + "strings" + "time" "github.com/fsnotify/fsnotify" ) +// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found. +// Returns nil if no nameservers are found. +func (p *prog) parseResolvConfNameservers(path string) ([]string, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + // Parse the file for "nameserver" lines + var currentNS []string + lines := strings.Split(string(content), "\n") + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "nameserver") { + parts := strings.Fields(trimmed) + if len(parts) >= 2 { + currentNS = append(currentNS, parts[1]) + } + } + } + + return currentNS, nil +} + // watchResolvConf watches any changes to /etc/resolv.conf file, // and reverting to the original config set by ctrld. func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) { @@ -50,17 +77,81 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f continue } if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { - mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting") - if err := watcher.Remove(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to pause watcher") - continue + mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...") + + // Convert expected nameservers to strings for comparison + expectedNS := make([]string, len(ns)) + for i, addr := range ns { + expectedNS[i] = addr.String() } - if err := setDnsFn(iface, ns); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + + var foundNS []string + var err error + + maxRetries := 1 + for retry := 0; retry < maxRetries; retry++ { + foundNS, err = p.parseResolvConfNameservers(resolvConfPath) + if err != nil { + mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content") + break + } + + // If we found nameservers, break out of retry loop + if len(foundNS) > 0 { + break + } + + // Only retry if we found no nameservers + if retry < maxRetries-1 { + mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries) + select { + case <-p.stopCh: + return + case <-p.dnsWatcherStopCh: + return + case <-time.After(2 * time.Second): + continue + } + } else { + mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries") + } } - if err := watcher.Add(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") - return + + // If we found nameservers, check if they match what we expect + if len(foundNS) > 0 { + // Check if the nameservers match exactly what we expect + matches := len(foundNS) == len(expectedNS) + if matches { + for i := range foundNS { + if foundNS[i] != expectedNS[i] { + matches = false + break + } + } + } + + mainLog.Load().Debug(). + Strs("found", foundNS). + Strs("expected", expectedNS). + Bool("matches", matches). + Msg("checking nameservers") + + // Only revert if the nameservers don't match + if !matches { + if err := watcher.Remove(watchDir); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to pause watcher") + continue + } + + if err := setDnsFn(iface, ns); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + } + + if err := watcher.Add(watchDir); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") + return + } + } } } case err, ok := <-watcher.Errors: diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index e37db4d..fc5d65d 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -42,14 +42,24 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { return um } -// increaseFailureCount increase failed queries count for an upstream by 1. +// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information. func (um *upstreamMonitor) increaseFailureCount(upstream string) { um.mu.Lock() defer um.mu.Unlock() um.failureReq[upstream] += 1 failedCount := um.failureReq[upstream] - um.down[upstream] = failedCount >= maxFailureRequest + + // Log the updated failure count + mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount) + + // Check if the failure count has reached the threshold to mark the upstream as down. + if failedCount >= maxFailureRequest { + um.down[upstream] = true + mainLog.Load().Warn().Msgf("upstream %q marked as down (failure count: %d)", upstream, failedCount) + } else { + um.down[upstream] = false + } } // isDown reports whether the given upstream is being marked as down. diff --git a/config.go b/config.go index 099f75b..e1454f9 100644 --- a/config.go +++ b/config.go @@ -458,7 +458,7 @@ func (uc *UpstreamConfig) ReBootstrap() { } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { if uc.rebootstrap.CompareAndSwap(false, true) { - ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip") + ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc) } return true, nil }) diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 04ec4c3..e6bda79 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -93,6 +93,7 @@ type Table struct { quitCh chan struct{} stopCh chan struct{} selfIP string + selfIPLock sync.RWMutex cdUID string ptrNameservers []string } @@ -160,10 +161,20 @@ func (t *Table) Stop() { <-t.quitCh } +// SelfIP returns the selfIP value of the Table in a thread-safe manner. func (t *Table) SelfIP() string { + t.selfIPLock.RLock() + defer t.selfIPLock.RUnlock() return t.selfIP } +// SetSelfIP sets the selfIP value of the Table in a thread-safe manner. +func (t *Table) SetSelfIP(ip string) { + t.selfIPLock.Lock() + defer t.selfIPLock.Unlock() + t.selfIP = ip +} + func (t *Table) init() { // Custom client ID presents, use it as the only source. if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { diff --git a/resolver.go b/resolver.go index 34a6cdd..01348dc 100644 --- a/resolver.go +++ b/resolver.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/netip" + "runtime" "slices" "sync" "sync/atomic" @@ -50,8 +51,10 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") var localResolver = newLocalResolver() var ( - resolverMutex sync.Mutex - or *osResolver + resolverMutex sync.Mutex + or *osResolver + defaultLocalIPv4 atomic.Value // holds net.IP (IPv4) + defaultLocalIPv6 atomic.Value // holds net.IP (IPv6) ) func newLocalResolver() Resolver { @@ -216,6 +219,108 @@ type publicResponse struct { server string } +// SetDefaultLocalIPv4 updates the stored local IPv4. +func SetDefaultLocalIPv4(ip net.IP) { + Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip) + defaultLocalIPv4.Store(ip) +} + +// SetDefaultLocalIPv6 updates the stored local IPv6. +func SetDefaultLocalIPv6(ip net.IP) { + Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip) + defaultLocalIPv6.Store(ip) +} + +// GetDefaultLocalIPv4 returns the stored local IPv4 or nil if none. +func GetDefaultLocalIPv4() net.IP { + if v := defaultLocalIPv4.Load(); v != nil { + return v.(net.IP) + } + return nil +} + +// GetDefaultLocalIPv6 returns the stored local IPv6 or nil if none. +func GetDefaultLocalIPv6() net.IP { + if v := defaultLocalIPv6.Load(); v != nil { + return v.(net.IP) + } + return nil +} + +// debugDialer is a helper type that wraps a net.Dialer and logs +// the local IP address used when dialing out. +type debugDialer struct { + *net.Dialer +} + +// DialContext wraps the underlying DialContext and logs the local address of the connection. +func (d *debugDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := d.Dialer.DialContext(ctx, network, addr) + if err != nil { + // Log the error even before a connection is established. + if d.Dialer.LocalAddr != nil { + Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s failed: %v (local addr: %v)", addr, err, d.Dialer.LocalAddr) + } else { + Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s failed: %v", addr, err) + } + return nil, err + } + // Log the local address (source IP) used for this connection. + Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s succeeded; local address: %s", + addr, conn.LocalAddr().String()) + return conn, nil +} + +// customDNSExchange wraps the DNS exchange to use our debug dialer. +// It uses dns.ExchangeWithConn so that our custom dialer is used directly. +func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desiredLocalIP net.IP) (*dns.Msg, error) { + baseDialer := &net.Dialer{ + Timeout: 3 * time.Second, + Resolver: &net.Resolver{PreferGo: true}, + } + if desiredLocalIP != nil { + baseDialer.LocalAddr = &net.UDPAddr{IP: desiredLocalIP, Port: 0} + } + dd := &debugDialer{Dialer: baseDialer} + + // Attempt UDP first. + udpConn, err := dd.DialContext(ctx, "udp", server) + if err != nil { + return nil, err + } + defer udpConn.Close() + udpDnsConn := &dns.Conn{Conn: udpConn} + if err = udpDnsConn.WriteMsg(msg); err != nil { + return nil, err + } + reply, err := udpDnsConn.ReadMsg() + if err != nil { + return nil, err + } + + // If the UDP reply is not truncated, return it. + if !reply.Truncated { + return reply, nil + } + + // If truncated, retry over TCP once. + Log(ctx, ProxyLogger.Load().Debug(), "UDP response truncated, switching to TCP (1 retry)") + tcpConn, err := dd.DialContext(ctx, "tcp", server) + if err != nil { + return reply, nil // fallback to UDP reply if TCP dial fails. + } + defer tcpConn.Close() + tcpDnsConn := &dns.Conn{Conn: tcpConn} + if err = tcpDnsConn.WriteMsg(msg); err != nil { + return reply, nil // fallback if TCP write fails. + } + tcpReply, err := tcpDnsConn.ReadMsg() + if err != nil { + return reply, nil // fallback if TCP read fails. + } + return tcpReply, nil +} + // Resolve resolves DNS queries using pre-configured nameservers. // Query is sent to all nameservers concurrently, and the first // success response will be returned. @@ -237,7 +342,6 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error ctx, cancel := context.WithCancel(ctx) defer cancel() - dnsClient := &dns.Client{Net: "udp", Timeout: 3 * time.Second} ch := make(chan *osResolverResult, numServers) wg := &sync.WaitGroup{} wg.Add(numServers) @@ -250,7 +354,22 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error for _, server := range servers { go func(server string) { defer wg.Done() - answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) + var answer *dns.Msg + var err error + var localOSResolverIP net.IP + if runtime.GOOS == "darwin" { + host, _, err := net.SplitHostPort(server) + if err == nil { + ip := net.ParseIP(host) + if ip != nil && ip.To4() == nil { + // IPv6 nameserver; use default IPv6 address (if set) + localOSResolverIP = GetDefaultLocalIPv6() + } else { + localOSResolverIP = GetDefaultLocalIPv4() + } + } + } + answer, err = customDNSExchange(ctx, msg.Copy(), server, localOSResolverIP) ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan} }(server) }