mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Don't automatically restore saved DNS settings when switching networks
smol tweaks to nameserver test queries fix restoreDNS errors add some debugging information fix wront type in log msg set send logs command timeout to 5 mins when the runningIface is no longer up, attempt to find a new interface prefer default route, ignore non physical interfaces prefer default route, ignore non physical interfaces add max context timeout on performLeakingQuery with more debug logs
This commit is contained in:
committed by
Cuong Manh Le
parent
e9e63b0983
commit
7833132917
@@ -1029,6 +1029,16 @@ func uninstall(p *prog, s service.Service) {
|
||||
return
|
||||
}
|
||||
p.resetDNS()
|
||||
|
||||
// if present restore the original DNS settings
|
||||
if netIface, err := netInterface(p.runningIface); err == nil {
|
||||
if err := restoreDNS(netIface); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Restored DNS on interface successfully")
|
||||
}
|
||||
}
|
||||
|
||||
if router.Name() != "" {
|
||||
mainLog.Load().Debug().Msg("Router cleanup")
|
||||
}
|
||||
|
||||
@@ -541,6 +541,16 @@ func initStopCmd() *cobra.Command {
|
||||
if doTasks([]task{{s.Stop, true}}) {
|
||||
p.router.Cleanup()
|
||||
p.resetDNS()
|
||||
|
||||
// restore DNS settings
|
||||
if netIface, err := netInterface(p.runningIface); err == nil {
|
||||
if err := restoreDNS(netIface); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Restored DNS on interface successfully")
|
||||
}
|
||||
}
|
||||
|
||||
if router.WaitProcessExited() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
@@ -25,6 +25,10 @@ func newControlClient(addr string) *controlClient {
|
||||
}
|
||||
|
||||
func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) {
|
||||
// for log/send, set the timeout to 5 minutes
|
||||
if path == sendLogsPath {
|
||||
c.c.Timeout = time.Minute * 5
|
||||
}
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ const (
|
||||
deactivationPath = "/deactivation"
|
||||
cdPath = "/cd"
|
||||
ifacePath = "/iface"
|
||||
viewLogsPath = "/logs/view"
|
||||
sendLogsPath = "/logs/send"
|
||||
viewLogsPath = "/log/view"
|
||||
sendLogsPath = "/log/send"
|
||||
)
|
||||
|
||||
type ifaceResponse struct {
|
||||
|
||||
@@ -542,8 +542,10 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
if upstreamConfig == nil {
|
||||
continue
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting upstream [ %s ] at index: %d, upstream at index: %s", upstreamConfig.String(), n, upstreams[n])
|
||||
|
||||
if p.isLoop(upstreamConfig) {
|
||||
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint)
|
||||
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %s", upstreamConfig.String())
|
||||
continue
|
||||
}
|
||||
if p.um.isDown(upstreams[n]) {
|
||||
@@ -929,6 +931,11 @@ func (p *prog) selfUninstallCoolOfPeriod() {
|
||||
// performLeakingQuery performs necessary works to leak queries to OS resolver.
|
||||
func (p *prog) performLeakingQuery() {
|
||||
mainLog.Load().Warn().Msg("leaking query to OS resolver")
|
||||
|
||||
// Create a context with timeout for the entire operation
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Signal dns watchers to stop, so changes made below won't be reverted.
|
||||
p.leakingQuery.Store(true)
|
||||
defer func() {
|
||||
@@ -936,20 +943,81 @@ func (p *prog) performLeakingQuery() {
|
||||
p.leakingQueryMu.Lock()
|
||||
p.leakingQueryWasRun = false
|
||||
p.leakingQueryMu.Unlock()
|
||||
mainLog.Load().Warn().Msg("stop leaking query")
|
||||
}()
|
||||
// Reset DNS, so queries are forwarded to OS resolver normally.
|
||||
p.resetDNS()
|
||||
// Check remote upstream in background, so ctrld could be back to normal
|
||||
// operation as long as the network is back online.
|
||||
for name, uc := range p.cfg.Upstream {
|
||||
p.checkUpstream(name, uc)
|
||||
|
||||
// Create channels to coordinate operations
|
||||
resetDone := make(chan struct{})
|
||||
checkDone := make(chan struct{})
|
||||
|
||||
// Reset DNS with timeout
|
||||
go func() {
|
||||
defer close(resetDone)
|
||||
mainLog.Load().Debug().Msg("attempting to reset DNS")
|
||||
p.resetDNS()
|
||||
mainLog.Load().Debug().Msg("DNS reset completed")
|
||||
}()
|
||||
|
||||
// Wait for reset with timeout
|
||||
select {
|
||||
case <-resetDone:
|
||||
mainLog.Load().Debug().Msg("DNS reset successful")
|
||||
case <-ctx.Done():
|
||||
mainLog.Load().Error().Msg("DNS reset timed out")
|
||||
return
|
||||
}
|
||||
// After all upstream back, re-initializing OS resolver.
|
||||
|
||||
// Check upstream in background with progress tracking
|
||||
go func() {
|
||||
defer close(checkDone)
|
||||
mainLog.Load().Debug().Msg("starting upstream checks")
|
||||
for name, uc := range p.cfg.Upstream {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
mainLog.Load().Debug().
|
||||
Str("upstream", name).
|
||||
Msg("checking upstream")
|
||||
p.checkUpstream(name, uc)
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msg("upstream checks completed")
|
||||
}()
|
||||
|
||||
// Wait for upstream checks
|
||||
select {
|
||||
case <-checkDone:
|
||||
mainLog.Load().Debug().Msg("upstream checks successful")
|
||||
case <-ctx.Done():
|
||||
mainLog.Load().Error().Msg("upstream checks timed out")
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize OS resolver with timeout
|
||||
mainLog.Load().Debug().Msg("initializing OS resolver")
|
||||
ns := ctrld.InitializeOsResolver()
|
||||
mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns)
|
||||
p.dnsWg.Wait()
|
||||
|
||||
// Wait for DNS operations to complete
|
||||
waitCh := make(chan struct{})
|
||||
go func() {
|
||||
p.dnsWg.Wait()
|
||||
close(waitCh)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-waitCh:
|
||||
mainLog.Load().Debug().Msg("DNS operations completed")
|
||||
case <-ctx.Done():
|
||||
mainLog.Load().Error().Msg("DNS operations timed out")
|
||||
return
|
||||
}
|
||||
|
||||
// Set DNS with timeout
|
||||
mainLog.Load().Debug().Msg("setting DNS configuration")
|
||||
p.setDNS()
|
||||
mainLog.Load().Warn().Msg("stop leaking query")
|
||||
mainLog.Load().Debug().Msg("DNS configuration set successfully")
|
||||
}
|
||||
|
||||
// forceFetchingAPI sends signal to force syncing API config if run in cd mode,
|
||||
|
||||
@@ -70,11 +70,6 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
|
||||
// TODO(cuonglm): use system API
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
if ns := savedStaticNameservers(iface); len(ns) > 0 {
|
||||
if err := setDNS(iface, ns); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name, "empty"}
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
@@ -83,6 +78,15 @@ func resetDNS(iface *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if ns := savedStaticNameservers(iface); len(ns) > 0 {
|
||||
err = setDNS(iface, ns)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
}
|
||||
|
||||
@@ -76,6 +76,12 @@ func resetDNS(iface *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
}
|
||||
|
||||
@@ -195,6 +195,12 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
})
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} {
|
||||
if ns := fn(iface.Name); len(ns) > 0 {
|
||||
|
||||
@@ -130,8 +130,12 @@ func resetDNS(iface *net.Interface) error {
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv6: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If there's static DNS saved, restoring it.
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if nss := savedStaticNameservers(iface); len(nss) > 0 {
|
||||
v4ns := make([]string, 0, 2)
|
||||
v6ns := make([]string, 0, 2)
|
||||
@@ -148,12 +152,14 @@ func resetDNS(iface *net.Interface) error {
|
||||
continue
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
|
||||
if err := setDNS(iface, ns); err != nil {
|
||||
err = setDNS(iface, ns)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
|
||||
165
cmd/cli/prog.go
165
cmd/cli/prog.go
@@ -626,9 +626,31 @@ func (p *prog) setDNS() {
|
||||
return
|
||||
}
|
||||
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
|
||||
netIface, err := netInterface(p.runningIface)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("could not get interface")
|
||||
|
||||
const maxDNSRetryAttempts = 3
|
||||
const retryDelay = 1 * time.Second
|
||||
var netIface *net.Interface
|
||||
var err error
|
||||
for attempt := 1; attempt <= maxDNSRetryAttempts; attempt++ {
|
||||
netIface, err = netInterface(p.runningIface)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if attempt < maxDNSRetryAttempts {
|
||||
// Try to find a different working interface
|
||||
newIface := findWorkingInterface(p.runningIface)
|
||||
if newIface != p.runningIface {
|
||||
p.runningIface = newIface
|
||||
logger = mainLog.Load().With().Str("iface", p.runningIface).Logger()
|
||||
logger.Info().Msg("switched to new interface")
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Warn().Err(err).Int("attempt", attempt).Msg("could not get interface, retrying...")
|
||||
time.Sleep(retryDelay)
|
||||
continue
|
||||
}
|
||||
logger.Error().Err(err).Msg("could not get interface after all attempts")
|
||||
return
|
||||
}
|
||||
if err := setupNetworkManager(); err != nil {
|
||||
@@ -766,6 +788,7 @@ func (p *prog) resetDNS() {
|
||||
logger.Error().Err(err).Msg("could not get interface")
|
||||
return
|
||||
}
|
||||
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return
|
||||
@@ -781,6 +804,131 @@ func (p *prog) resetDNS() {
|
||||
}
|
||||
}
|
||||
|
||||
// findWorkingInterface looks for a network interface with a valid IP configuration
|
||||
func findWorkingInterface(currentIface string) string {
|
||||
// Helper to check if IP is valid (not link-local)
|
||||
isValidIP := func(ip net.IP) bool {
|
||||
return ip != nil &&
|
||||
!ip.IsLinkLocalUnicast() &&
|
||||
!ip.IsLinkLocalMulticast() &&
|
||||
!ip.IsLoopback() &&
|
||||
!ip.IsUnspecified()
|
||||
}
|
||||
|
||||
// Helper to check if interface has valid IP configuration
|
||||
hasValidIPConfig := func(iface *net.Interface) bool {
|
||||
if iface == nil || iface.Flags&net.FlagUp == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
Str("interface", iface.Name).
|
||||
Err(err).
|
||||
Msg("failed to get interface addresses")
|
||||
return false
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
// Check for IP network
|
||||
if ipNet, ok := addr.(*net.IPNet); ok {
|
||||
if isValidIP(ipNet.IP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Get default route interface
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
Err(err).
|
||||
Msg("failed to get default route")
|
||||
} else {
|
||||
mainLog.Load().Debug().
|
||||
Str("default_route_iface", defaultRoute.InterfaceName).
|
||||
Msg("found default route")
|
||||
}
|
||||
|
||||
// Get all interfaces
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to list network interfaces")
|
||||
return currentIface // Return current interface as fallback
|
||||
}
|
||||
|
||||
var firstWorkingIface string
|
||||
var currentIfaceValid bool
|
||||
|
||||
// Single pass through interfaces
|
||||
for _, iface := range ifaces {
|
||||
// Must be physical (has MAC address)
|
||||
if len(iface.HardwareAddr) == 0 {
|
||||
continue
|
||||
}
|
||||
// Skip interfaces that are:
|
||||
// - Loopback
|
||||
// - Not up
|
||||
// - Point-to-point (like VPN tunnels)
|
||||
if iface.Flags&net.FlagLoopback != 0 ||
|
||||
iface.Flags&net.FlagUp == 0 ||
|
||||
iface.Flags&net.FlagPointToPoint != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hasValidIPConfig(&iface) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Found working physical interface
|
||||
if err == nil && defaultRoute.InterfaceName == iface.Name {
|
||||
// Found interface with default route - use it immediately
|
||||
mainLog.Load().Info().
|
||||
Str("old_iface", currentIface).
|
||||
Str("new_iface", iface.Name).
|
||||
Msg("switching to interface with default route")
|
||||
return iface.Name
|
||||
}
|
||||
|
||||
// Keep track of first working interface as fallback
|
||||
if firstWorkingIface == "" {
|
||||
firstWorkingIface = iface.Name
|
||||
}
|
||||
|
||||
// Check if this is our current interface
|
||||
if iface.Name == currentIface {
|
||||
currentIfaceValid = true
|
||||
}
|
||||
}
|
||||
|
||||
// Return interfaces in order of preference:
|
||||
// 1. Current interface if it's still valid
|
||||
if currentIfaceValid {
|
||||
mainLog.Load().Debug().
|
||||
Str("interface", currentIface).
|
||||
Msg("keeping current interface")
|
||||
return currentIface
|
||||
}
|
||||
|
||||
// 2. First working interface found
|
||||
if firstWorkingIface != "" {
|
||||
mainLog.Load().Info().
|
||||
Str("old_iface", currentIface).
|
||||
Str("new_iface", firstWorkingIface).
|
||||
Msg("switching to first working physical interface")
|
||||
return firstWorkingIface
|
||||
}
|
||||
|
||||
// 3. Fall back to current interface if nothing else works
|
||||
mainLog.Load().Warn().
|
||||
Str("current_iface", currentIface).
|
||||
Msg("no working physical interface found, keeping current")
|
||||
return currentIface
|
||||
}
|
||||
|
||||
// leakOnUpstreamFailure reports whether ctrld should leak query to OS resolver when failed to connect all upstreams.
|
||||
func (p *prog) leakOnUpstreamFailure() bool {
|
||||
if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil {
|
||||
@@ -1049,7 +1197,16 @@ func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
|
||||
func savedStaticNameservers(iface *net.Interface) []string {
|
||||
file := savedStaticDnsSettingsFilePath(iface)
|
||||
if data, _ := os.ReadFile(file); len(data) > 0 {
|
||||
return strings.Split(string(data), ",")
|
||||
saveValues := strings.Split(string(data), ",")
|
||||
returnValues := []string{}
|
||||
// check each one, if its in loopback range, remove it
|
||||
for _, v := range saveValues {
|
||||
if net.ParseIP(v).IsLoopback() {
|
||||
continue
|
||||
}
|
||||
returnValues = append(returnValues, v)
|
||||
}
|
||||
return returnValues
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user