Don't automatically restore saved DNS settings when switching networks

smol tweaks to nameserver test queries

fix restoreDNS errors

add some debugging information

fix wront type in log msg

set send logs command timeout to 5 mins

when the runningIface is no longer up, attempt to find a new interface

prefer default route, ignore non physical interfaces

prefer default route, ignore non physical interfaces

add max context timeout on performLeakingQuery with more debug logs
This commit is contained in:
Alex Paguis
2025-01-15 17:31:10 -05:00
committed by Cuong Manh Le
parent e9e63b0983
commit 7833132917
12 changed files with 387 additions and 36 deletions

View File

@@ -1029,6 +1029,16 @@ func uninstall(p *prog, s service.Service) {
return
}
p.resetDNS()
// if present restore the original DNS settings
if netIface, err := netInterface(p.runningIface); err == nil {
if err := restoreDNS(netIface); err != nil {
mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface")
} else {
mainLog.Load().Debug().Msg("Restored DNS on interface successfully")
}
}
if router.Name() != "" {
mainLog.Load().Debug().Msg("Router cleanup")
}

View File

@@ -541,6 +541,16 @@ func initStopCmd() *cobra.Command {
if doTasks([]task{{s.Stop, true}}) {
p.router.Cleanup()
p.resetDNS()
// restore DNS settings
if netIface, err := netInterface(p.runningIface); err == nil {
if err := restoreDNS(netIface); err != nil {
mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface")
} else {
mainLog.Load().Debug().Msg("Restored DNS on interface successfully")
}
}
if router.WaitProcessExited() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

View File

@@ -25,6 +25,10 @@ func newControlClient(addr string) *controlClient {
}
func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) {
// for log/send, set the timeout to 5 minutes
if path == sendLogsPath {
c.c.Timeout = time.Minute * 5
}
return c.c.Post("http://unix"+path, contentTypeJson, data)
}

View File

@@ -27,8 +27,8 @@ const (
deactivationPath = "/deactivation"
cdPath = "/cd"
ifacePath = "/iface"
viewLogsPath = "/logs/view"
sendLogsPath = "/logs/send"
viewLogsPath = "/log/view"
sendLogsPath = "/log/send"
)
type ifaceResponse struct {

View File

@@ -542,8 +542,10 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
if upstreamConfig == nil {
continue
}
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting upstream [ %s ] at index: %d, upstream at index: %s", upstreamConfig.String(), n, upstreams[n])
if p.isLoop(upstreamConfig) {
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint)
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %s", upstreamConfig.String())
continue
}
if p.um.isDown(upstreams[n]) {
@@ -929,6 +931,11 @@ func (p *prog) selfUninstallCoolOfPeriod() {
// performLeakingQuery performs necessary works to leak queries to OS resolver.
func (p *prog) performLeakingQuery() {
mainLog.Load().Warn().Msg("leaking query to OS resolver")
// Create a context with timeout for the entire operation
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Signal dns watchers to stop, so changes made below won't be reverted.
p.leakingQuery.Store(true)
defer func() {
@@ -936,20 +943,81 @@ func (p *prog) performLeakingQuery() {
p.leakingQueryMu.Lock()
p.leakingQueryWasRun = false
p.leakingQueryMu.Unlock()
mainLog.Load().Warn().Msg("stop leaking query")
}()
// Reset DNS, so queries are forwarded to OS resolver normally.
p.resetDNS()
// Check remote upstream in background, so ctrld could be back to normal
// operation as long as the network is back online.
for name, uc := range p.cfg.Upstream {
p.checkUpstream(name, uc)
// Create channels to coordinate operations
resetDone := make(chan struct{})
checkDone := make(chan struct{})
// Reset DNS with timeout
go func() {
defer close(resetDone)
mainLog.Load().Debug().Msg("attempting to reset DNS")
p.resetDNS()
mainLog.Load().Debug().Msg("DNS reset completed")
}()
// Wait for reset with timeout
select {
case <-resetDone:
mainLog.Load().Debug().Msg("DNS reset successful")
case <-ctx.Done():
mainLog.Load().Error().Msg("DNS reset timed out")
return
}
// After all upstream back, re-initializing OS resolver.
// Check upstream in background with progress tracking
go func() {
defer close(checkDone)
mainLog.Load().Debug().Msg("starting upstream checks")
for name, uc := range p.cfg.Upstream {
select {
case <-ctx.Done():
return
default:
mainLog.Load().Debug().
Str("upstream", name).
Msg("checking upstream")
p.checkUpstream(name, uc)
}
}
mainLog.Load().Debug().Msg("upstream checks completed")
}()
// Wait for upstream checks
select {
case <-checkDone:
mainLog.Load().Debug().Msg("upstream checks successful")
case <-ctx.Done():
mainLog.Load().Error().Msg("upstream checks timed out")
return
}
// Initialize OS resolver with timeout
mainLog.Load().Debug().Msg("initializing OS resolver")
ns := ctrld.InitializeOsResolver()
mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns)
p.dnsWg.Wait()
// Wait for DNS operations to complete
waitCh := make(chan struct{})
go func() {
p.dnsWg.Wait()
close(waitCh)
}()
select {
case <-waitCh:
mainLog.Load().Debug().Msg("DNS operations completed")
case <-ctx.Done():
mainLog.Load().Error().Msg("DNS operations timed out")
return
}
// Set DNS with timeout
mainLog.Load().Debug().Msg("setting DNS configuration")
p.setDNS()
mainLog.Load().Warn().Msg("stop leaking query")
mainLog.Load().Debug().Msg("DNS configuration set successfully")
}
// forceFetchingAPI sends signal to force syncing API config if run in cd mode,

View File

@@ -70,11 +70,6 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
// TODO(cuonglm): use system API
func resetDNS(iface *net.Interface) error {
if ns := savedStaticNameservers(iface); len(ns) > 0 {
if err := setDNS(iface, ns); err == nil {
return nil
}
}
cmd := "networksetup"
args := []string{"-setdnsservers", iface.Name, "empty"}
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
@@ -83,6 +78,15 @@ func resetDNS(iface *net.Interface) error {
return nil
}
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
if ns := savedStaticNameservers(iface); len(ns) > 0 {
err = setDNS(iface, ns)
}
return err
}
func currentDNS(_ *net.Interface) []string {
return resolvconffile.NameServers("")
}

View File

@@ -76,6 +76,12 @@ func resetDNS(iface *net.Interface) error {
return nil
}
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
return err
}
func currentDNS(_ *net.Interface) []string {
return resolvconffile.NameServers("")
}

View File

@@ -195,6 +195,12 @@ func resetDNS(iface *net.Interface) (err error) {
})
}
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
return err
}
func currentDNS(iface *net.Interface) []string {
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} {
if ns := fn(iface.Name); len(ns) > 0 {

View File

@@ -130,8 +130,12 @@ func resetDNS(iface *net.Interface) error {
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
return fmt.Errorf("could not reset DNS ipv6: %w", err)
}
return nil
}
// If there's static DNS saved, restoring it.
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
if nss := savedStaticNameservers(iface); len(nss) > 0 {
v4ns := make([]string, 0, 2)
v6ns := make([]string, 0, 2)
@@ -148,12 +152,14 @@ func resetDNS(iface *net.Interface) error {
continue
}
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
if err := setDNS(iface, ns); err != nil {
err = setDNS(iface, ns)
if err != nil {
return err
}
}
}
return nil
return err
}
func currentDNS(iface *net.Interface) []string {

View File

@@ -626,9 +626,31 @@ func (p *prog) setDNS() {
return
}
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
netIface, err := netInterface(p.runningIface)
if err != nil {
logger.Error().Err(err).Msg("could not get interface")
const maxDNSRetryAttempts = 3
const retryDelay = 1 * time.Second
var netIface *net.Interface
var err error
for attempt := 1; attempt <= maxDNSRetryAttempts; attempt++ {
netIface, err = netInterface(p.runningIface)
if err == nil {
break
}
if attempt < maxDNSRetryAttempts {
// Try to find a different working interface
newIface := findWorkingInterface(p.runningIface)
if newIface != p.runningIface {
p.runningIface = newIface
logger = mainLog.Load().With().Str("iface", p.runningIface).Logger()
logger.Info().Msg("switched to new interface")
continue
}
logger.Warn().Err(err).Int("attempt", attempt).Msg("could not get interface, retrying...")
time.Sleep(retryDelay)
continue
}
logger.Error().Err(err).Msg("could not get interface after all attempts")
return
}
if err := setupNetworkManager(); err != nil {
@@ -766,6 +788,7 @@ func (p *prog) resetDNS() {
logger.Error().Err(err).Msg("could not get interface")
return
}
if err := restoreNetworkManager(); err != nil {
logger.Error().Err(err).Msg("could not restore NetworkManager")
return
@@ -781,6 +804,131 @@ func (p *prog) resetDNS() {
}
}
// findWorkingInterface looks for a network interface with a valid IP configuration
func findWorkingInterface(currentIface string) string {
// Helper to check if IP is valid (not link-local)
isValidIP := func(ip net.IP) bool {
return ip != nil &&
!ip.IsLinkLocalUnicast() &&
!ip.IsLinkLocalMulticast() &&
!ip.IsLoopback() &&
!ip.IsUnspecified()
}
// Helper to check if interface has valid IP configuration
hasValidIPConfig := func(iface *net.Interface) bool {
if iface == nil || iface.Flags&net.FlagUp == 0 {
return false
}
addrs, err := iface.Addrs()
if err != nil {
mainLog.Load().Debug().
Str("interface", iface.Name).
Err(err).
Msg("failed to get interface addresses")
return false
}
for _, addr := range addrs {
// Check for IP network
if ipNet, ok := addr.(*net.IPNet); ok {
if isValidIP(ipNet.IP) {
return true
}
}
}
return false
}
// Get default route interface
defaultRoute, err := netmon.DefaultRoute()
if err != nil {
mainLog.Load().Debug().
Err(err).
Msg("failed to get default route")
} else {
mainLog.Load().Debug().
Str("default_route_iface", defaultRoute.InterfaceName).
Msg("found default route")
}
// Get all interfaces
ifaces, err := net.Interfaces()
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to list network interfaces")
return currentIface // Return current interface as fallback
}
var firstWorkingIface string
var currentIfaceValid bool
// Single pass through interfaces
for _, iface := range ifaces {
// Must be physical (has MAC address)
if len(iface.HardwareAddr) == 0 {
continue
}
// Skip interfaces that are:
// - Loopback
// - Not up
// - Point-to-point (like VPN tunnels)
if iface.Flags&net.FlagLoopback != 0 ||
iface.Flags&net.FlagUp == 0 ||
iface.Flags&net.FlagPointToPoint != 0 {
continue
}
if !hasValidIPConfig(&iface) {
continue
}
// Found working physical interface
if err == nil && defaultRoute.InterfaceName == iface.Name {
// Found interface with default route - use it immediately
mainLog.Load().Info().
Str("old_iface", currentIface).
Str("new_iface", iface.Name).
Msg("switching to interface with default route")
return iface.Name
}
// Keep track of first working interface as fallback
if firstWorkingIface == "" {
firstWorkingIface = iface.Name
}
// Check if this is our current interface
if iface.Name == currentIface {
currentIfaceValid = true
}
}
// Return interfaces in order of preference:
// 1. Current interface if it's still valid
if currentIfaceValid {
mainLog.Load().Debug().
Str("interface", currentIface).
Msg("keeping current interface")
return currentIface
}
// 2. First working interface found
if firstWorkingIface != "" {
mainLog.Load().Info().
Str("old_iface", currentIface).
Str("new_iface", firstWorkingIface).
Msg("switching to first working physical interface")
return firstWorkingIface
}
// 3. Fall back to current interface if nothing else works
mainLog.Load().Warn().
Str("current_iface", currentIface).
Msg("no working physical interface found, keeping current")
return currentIface
}
// leakOnUpstreamFailure reports whether ctrld should leak query to OS resolver when failed to connect all upstreams.
func (p *prog) leakOnUpstreamFailure() bool {
if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil {
@@ -1049,7 +1197,16 @@ func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
func savedStaticNameservers(iface *net.Interface) []string {
file := savedStaticDnsSettingsFilePath(iface)
if data, _ := os.ReadFile(file); len(data) > 0 {
return strings.Split(string(data), ",")
saveValues := strings.Split(string(data), ",")
returnValues := []string{}
// check each one, if its in loopback range, remove it
for _, v := range saveValues {
if net.ParseIP(v).IsLoopback() {
continue
}
returnValues = append(returnValues, v)
}
return returnValues
}
return nil
}

View File

@@ -886,3 +886,12 @@ func upstreamUID() string {
return hex.EncodeToString(b)
}
}
// String returns a string representation of the UpstreamConfig for logging.
func (uc *UpstreamConfig) String() string {
if uc == nil {
return "<nil>"
}
return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}",
uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack)
}

View File

@@ -147,16 +147,82 @@ var testNameServerFn = testNameserver
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
func testNameserver(addr string) bool {
msg := new(dns.Msg)
msg.SetQuestion("controld.com.", dns.TypeNS)
client := new(dns.Client)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(addr, "53"))
if err != nil {
ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr)
// Skip link-local addresses without scope IDs and deprecated site-local addresses
if ip, err := netip.ParseAddr(addr); err == nil {
if ip.Is6() {
if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") {
ProxyLogger.Load().Debug().
Str("nameserver", addr).
Msg("skipping link-local IPv6 address without scope ID")
return false
}
// Skip deprecated site-local addresses (fec0::/10)
if strings.HasPrefix(ip.String(), "fec0:") {
ProxyLogger.Load().Debug().
Str("nameserver", addr).
Msg("skipping deprecated site-local IPv6 address")
return false
}
}
}
return err == nil
ProxyLogger.Load().Debug().
Str("input_addr", addr).
Msg("testing nameserver")
// Handle both IPv4 and IPv6 addresses
serverAddr := addr
host, port, err := net.SplitHostPort(addr)
if err != nil {
// No port in address, add default port 53
serverAddr = net.JoinHostPort(addr, "53")
} else if port == "" {
// Has split markers but empty port
serverAddr = net.JoinHostPort(host, "53")
}
ProxyLogger.Load().Debug().
Str("server_addr", serverAddr).
Msg("using server address")
// Test domains that are likely to exist and respond quickly
testDomains := []struct {
name string
qtype uint16
}{
{".", dns.TypeNS}, // Root NS query - should always work
{"controld.com.", dns.TypeA}, // Fallback to a reliable domain
{"google.com.", dns.TypeA}, // Fallback to a reliable domain
}
client := &dns.Client{
Timeout: 2 * time.Second,
Net: "udp",
}
// Try each test query until one succeeds
for _, test := range testDomains {
msg := new(dns.Msg)
msg.SetQuestion(test.name, test.qtype)
msg.RecursionDesired = true
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
resp, _, err := client.ExchangeContext(ctx, msg, serverAddr)
cancel()
if err == nil && resp != nil {
return true
}
ProxyLogger.Load().Error().
Err(err).
Str("nameserver", serverAddr).
Str("test_domain", test.name).
Str("query_type", dns.TypeToString[test.qtype]).
Msg("DNS availability test failed")
}
return false
}
// Resolver is the interface that wraps the basic DNS operations.
@@ -222,7 +288,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
ctx, cancel := context.WithCancel(ctx)
defer cancel()
dnsClient := &dns.Client{Net: "udp"}
dnsClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
ch := make(chan *osResolverResult, numServers)
wg := &sync.WaitGroup{}
wg.Add(numServers)
@@ -264,11 +330,14 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
switch {
case res.server == controldPublicDnsWithPort:
controldSuccessAnswer = res.answer // only use ControlD answer as last one.
Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", res.server)
controldSuccessAnswer = res.answer
case !res.lan && publicServerAnswer == nil:
publicServerAnswer = res.answer // use public DNS answer after LAN server..
Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", res.server)
publicServerAnswer = res.answer
publicServer = res.server
default:
Log(ctx, ProxyLogger.Load().Debug(), "got LAN answer from: %s", res.server)
cancel()
logAnswer(res.server)
return res.answer, nil
@@ -276,6 +345,8 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
case res.answer != nil:
nonSuccessAnswer = res.answer
nonSuccessServer = res.server
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d",
res.server, res.answer.Rcode)
}
errs = append(errs, res.err)
}