mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-05-15 00:50:25 +02:00
set new dialer on every request
debugging; debugging; debugging; debugging; use default route interface IP for OS resolver queries; remove retries; fix resolv.conf clobbering on macOS; set custom local addr for OS resolver queries; remove the client info discovery logic on network change (this was overkill just for the IP, and was causing service failure after switching networks many times rapidly); handle IPv6 local addresses; guard ciTable from nil pointer; debugging failure count
This commit is contained in:
+69
-19
@@ -99,6 +99,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
}
|
||||
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
mainLog.Load().Debug().Msgf("serveDNS handler called")
|
||||
p.sema.acquire()
|
||||
defer p.sema.release()
|
||||
if len(m.Question) == 0 {
|
||||
@@ -1238,7 +1239,10 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
|
||||
defer p.resetCtxMu.Unlock()
|
||||
|
||||
p.leakingQueryReset.Store(true)
|
||||
defer p.leakingQueryReset.Store(false)
|
||||
defer func() {
|
||||
time.Sleep(time.Second)
|
||||
p.leakingQueryReset.Store(false)
|
||||
}()
|
||||
|
||||
mainLog.Load().Debug().Msg("attempting to reset DNS")
|
||||
p.resetDNS()
|
||||
@@ -1260,7 +1264,6 @@ func (p *prog) reinitializeOSResolver(networkChange bool) {
|
||||
if err := FlushDNSCache(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
// delay putting back the ctrld listener to allow for captive portal to trigger
|
||||
time.Sleep(5 * time.Second)
|
||||
@@ -1316,21 +1319,9 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
oldIfs := parseInterfaceState(delta.Old)
|
||||
newIfs := parseInterfaceState(delta.New)
|
||||
|
||||
// Client info discover only run on non-mobile platforms.
|
||||
if !isMobile() {
|
||||
// If this is major change, re-init client info table if its self IP changes.
|
||||
if delta.Monitor.IsMajorChangeFrom(delta.Old, delta.New) {
|
||||
selfIP := defaultRouteIP()
|
||||
if currentSelfIP := p.ciTable.SelfIP(); currentSelfIP != selfIP && selfIP != "" {
|
||||
p.stopClientInfoDiscover()
|
||||
p.setupClientInfoDiscover(selfIP)
|
||||
p.runClientInfoDiscover(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for changes in valid interfaces
|
||||
changed := false
|
||||
var changedIface, changedIfaceState string
|
||||
activeInterfaceExists := false
|
||||
|
||||
for ifaceName := range validIfaces {
|
||||
@@ -1343,7 +1334,14 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
|
||||
// Compare states directly
|
||||
if oldExists != newExists || oldState != newState {
|
||||
changed = true
|
||||
|
||||
// If the interface is up, we need to reinitialize the OS resolver
|
||||
if newState != "" && !strings.Contains(newState, "down") {
|
||||
changed = true
|
||||
changedIface = ifaceName
|
||||
changedIfaceState = newState
|
||||
}
|
||||
|
||||
mainLog.Load().Warn().
|
||||
Str("interface", ifaceName).
|
||||
Str("old_state", oldState).
|
||||
@@ -1364,11 +1362,33 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
return
|
||||
}
|
||||
|
||||
if activeInterfaceExists {
|
||||
p.reinitializeOSResolver(true)
|
||||
} else {
|
||||
if !activeInterfaceExists {
|
||||
mainLog.Load().Warn().Msg("No active interfaces found, skipping reinitialization")
|
||||
return
|
||||
}
|
||||
|
||||
// Use the defaultRouteIP() result or fallback to the changed interface's IP from the delta.
|
||||
selfIP := defaultRouteIP()
|
||||
if selfIP == "" && changedIface != "" {
|
||||
selfIP = extractIPv4FromState(changedIfaceState)
|
||||
mainLog.Load().Info().Msgf("defaultRouteIP returned empty, using changed iface '%s' IP: %s", changedIface, selfIP)
|
||||
}
|
||||
|
||||
// Extract IPv6 from the changed interface state.
|
||||
ipv6 := extractIPv6FromState(changedIfaceState)
|
||||
|
||||
if ip := net.ParseIP(selfIP); ip != nil {
|
||||
ctrld.SetDefaultLocalIPv4(ip)
|
||||
// if we have a new IP, set the client info to the new IP
|
||||
if !isMobile() && p.ciTable != nil {
|
||||
p.ciTable.SetSelfIP(selfIP)
|
||||
}
|
||||
}
|
||||
if ip := net.ParseIP(ipv6); ip != nil {
|
||||
ctrld.SetDefaultLocalIPv6(ip)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
||||
p.reinitializeOSResolver(true)
|
||||
})
|
||||
|
||||
mon.Start()
|
||||
@@ -1423,3 +1443,33 @@ func parseInterfaceState(state *netmon.State) map[string]string {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractIPv4FromState extracts the first IPv4 address from an interface state string.
// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239".
// If no valid IPv4 address can be found, it returns an empty string.
func extractIPv4FromState(state string) string {
	// The state is a bracketed, whitespace-separated list of tokens. Tokens
	// that are not addresses (such as "llu6") fail to parse and are skipped.
	for _, token := range strings.Fields(strings.Trim(state, "[]")) {
		// Drop the optional "/prefix-length" suffix before parsing.
		host, _, _ := strings.Cut(token, "/")
		if ip := net.ParseIP(host); ip != nil && ip.To4() != nil {
			return host
		}
	}
	return ""
}
|
||||
|
||||
// extractIPv6FromState extracts the first IPv6 address from an interface state string.
// For example, given "[172.16.226.239/22 2001:db8::1/64]", it returns "2001:db8::1".
// IPv4 addresses are skipped, so "[172.16.226.239/22 llu6]" yields an empty string.
// If no valid IPv6 address can be found, it returns an empty string.
func extractIPv6FromState(state string) string {
	// The state is a bracketed, whitespace-separated list of tokens. Tokens
	// that are not addresses (such as "llu6") fail to parse and are skipped.
	for _, token := range strings.Fields(strings.Trim(state, "[]")) {
		// Drop the optional "/prefix-length" suffix before parsing.
		host, _, _ := strings.Cut(token, "/")
		// To4() == nil means the parsed address has no IPv4 representation.
		if ip := net.ParseIP(host); ip != nil && ip.To4() == nil {
			return host
		}
	}
	return ""
}
|
||||
|
||||
@@ -504,6 +504,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
if err := p.serveDNS(ctx, listenerNum); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr)
|
||||
}(listenerNum)
|
||||
}
|
||||
go func() {
|
||||
|
||||
+100
-9
@@ -3,11 +3,38 @@ package cli
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
|
||||
// Returns nil if no nameservers are found.
|
||||
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the file for "nameserver" lines
|
||||
var currentNS []string
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "nameserver") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) >= 2 {
|
||||
currentNS = append(currentNS, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentNS, nil
|
||||
}
|
||||
|
||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
@@ -50,17 +77,81 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
|
||||
// Convert expected nameservers to strings for comparison
|
||||
expectedNS := make([]string, len(ns))
|
||||
for i, addr := range ns {
|
||||
expectedNS[i] = addr.String()
|
||||
}
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
|
||||
var foundNS []string
|
||||
var err error
|
||||
|
||||
maxRetries := 1
|
||||
for retry := 0; retry < maxRetries; retry++ {
|
||||
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content")
|
||||
break
|
||||
}
|
||||
|
||||
// If we found nameservers, break out of retry loop
|
||||
if len(foundNS) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Only retry if we found no nameservers
|
||||
if retry < maxRetries-1 {
|
||||
mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries")
|
||||
}
|
||||
}
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
|
||||
// If we found nameservers, check if they match what we expect
|
||||
if len(foundNS) > 0 {
|
||||
// Check if the nameservers match exactly what we expect
|
||||
matches := len(foundNS) == len(expectedNS)
|
||||
if matches {
|
||||
for i := range foundNS {
|
||||
if foundNS[i] != expectedNS[i] {
|
||||
matches = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Strs("found", foundNS).
|
||||
Strs("expected", expectedNS).
|
||||
Bool("matches", matches).
|
||||
Msg("checking nameservers")
|
||||
|
||||
// Only revert if the nameservers don't match
|
||||
if !matches {
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
|
||||
@@ -42,14 +42,24 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increase failed queries count for an upstream by 1.
|
||||
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
um.failureReq[upstream] += 1
|
||||
failedCount := um.failureReq[upstream]
|
||||
um.down[upstream] = failedCount >= maxFailureRequest
|
||||
|
||||
// Log the updated failure count
|
||||
mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
|
||||
|
||||
// Check if the failure count has reached the threshold to mark the upstream as down.
|
||||
if failedCount >= maxFailureRequest {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down (failure count: %d)", upstream, failedCount)
|
||||
} else {
|
||||
um.down[upstream] = false
|
||||
}
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
|
||||
Reference in New Issue
Block a user