From 5a88a7c22c4259e77e51e4b95c35821b4a58d259 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 20 Sep 2024 21:33:44 +0700 Subject: [PATCH] cmd/cli: decouple reset DNS task from ctrld status So it can be run regardless of ctrld's current status. This prevents a race where the reset DNS task restores the system DNS settings, but the currently running ctrld process immediately reverts them. --- cmd/cli/cli.go | 50 +++++++++++------ cmd/cli/dns_proxy.go | 2 + cmd/cli/os_windows.go | 1 + cmd/cli/prog.go | 124 +++++++++++++++++++++++------------------- cmd/cli/resolvconf.go | 3 + 5 files changed, 106 insertions(+), 74 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 007f45e..d07c145 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -194,11 +194,15 @@ NOTE: running "ctrld start" without any arguments will start already installed c isCtrldRunning := status == service.StatusRunning isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) + // Get current running iface, if any. + var currentIface string + // If pin code was set, do not allow running start command. if isCtrldRunning { if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) } + currentIface = runningIface(s) } if !startOnly { @@ -213,12 +217,15 @@ NOTE: running "ctrld start" without any arguments will start already installed c initLogging() tasks := []task{ - resetDnsTask(p, s), {s.Stop, false}, + resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { // Save current DNS so we can restore later. 
- withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error { - return saveCurrentStaticDNS(i) + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { + return err + } + return nil }) return nil }, false}, @@ -334,14 +341,17 @@ NOTE: running "ctrld start" without any arguments will start already installed c } tasks := []task{ - resetDnsTask(p, s), {s.Stop, false}, {func() error { return doGenerateNextDNSConfig(nextdns) }, true}, {func() error { return ensureUninstall(s) }, false}, + resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error { - return saveCurrentStaticDNS(i) + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { + return err + } + return nil }) return nil }, false}, @@ -1340,9 +1350,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { close(waitCh) <-stopCh - // Wait goroutines which watches/manipulates DNS settings terminated, - // ensuring that changes to DNS since here won't be reverted. - p.dnsWg.Wait() + p.stopDnsWatchers() for _, f := range p.onStopped { f() } @@ -2642,17 +2650,20 @@ func runningIface(s service.Service) string { // resetDnsNoLog performs resetting DNS with logging disable. func resetDnsNoLog(p *prog) { - lvl := zerolog.GlobalLevel() - zerolog.SetGlobalLevel(zerolog.Disabled) + // Normally, disable log to prevent annoying users. + if verbose < 3 { + lvl := zerolog.GlobalLevel() + zerolog.SetGlobalLevel(zerolog.Disabled) + p.resetDNS() + zerolog.SetGlobalLevel(lvl) + return + } + // For debugging purpose, still emit log. 
p.resetDNS() - zerolog.SetGlobalLevel(lvl) } // resetDnsTask returns a task which perform reset DNS operation. -func resetDnsTask(p *prog, s service.Service) task { - status, err := s.Status() - isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) - isCtrldRunning := status == service.StatusRunning +func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, currentRunningIface string) task { return task{func() error { if iface == "" { return nil @@ -2662,11 +2673,14 @@ func resetDnsTask(p *prog, s service.Service) task { // process to reset what setDNS has done properly. oldIface := iface iface = "auto" - if isCtrldRunning { - iface = runningIface(s) + if currentRunningIface != "" { + iface = currentRunningIface } if isCtrldInstalled { mainLog.Load().Debug().Msg("restore system DNS settings") + if status, _ := s.Status(); status == service.StatusRunning { + mainLog.Load().Fatal().Msg("reset DNS while ctrld still running is not safe") + } resetDnsNoLog(p) } iface = oldIface diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index b9eb8f5..5652f07 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -915,6 +915,8 @@ func (p *prog) performCaptivePortalDetection() { if found { resetDnsOnce.Do(func() { mainLog.Load().Warn().Msg("found captive portal, leaking query to OS resolver") + // Store the result once here, so changes made below won't be reverted by DNS watchers. 
+ p.captivePortalDetected.Store(found) p.resetDNS() }) } diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 234764f..b9412b6 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -119,6 +119,7 @@ func resetDNS(iface *net.Interface) error { if len(ns) == 0 { continue } + mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name) if err := setDNS(iface, ns); err != nil { return err } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 781edd1..a87d7e8 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -69,21 +69,21 @@ var svcConfig = &service.Config{ var useSystemdResolved = false type prog struct { - mu sync.Mutex - waitCh chan struct{} - stopCh chan struct{} - reloadCh chan struct{} // For Windows. - reloadDoneCh chan struct{} - apiReloadCh chan *ctrld.Config - apiForceReloadCh chan struct{} - apiForceReloadGroup singleflight.Group - logConn net.Conn - cs *controlServer - csSetDnsDone chan struct{} - csSetDnsOk bool - dnsWatchDogOnce sync.Once - dnsWg sync.WaitGroup - dnsWatcherStopCh chan struct{} + mu sync.Mutex + waitCh chan struct{} + stopCh chan struct{} + reloadCh chan struct{} // For Windows. 
+ reloadDoneCh chan struct{} + apiReloadCh chan *ctrld.Config + apiForceReloadCh chan struct{} + apiForceReloadGroup singleflight.Group + logConn net.Conn + cs *controlServer + csSetDnsDone chan struct{} + csSetDnsOk bool + dnsWg sync.WaitGroup + dnsWatcherClosedOnce sync.Once + dnsWatcherStopCh chan struct{} cfg *ctrld.Config localUpstreams []string @@ -512,6 +512,8 @@ func (p *prog) metricsEnabled() bool { } func (p *prog) Stop(s service.Service) error { + p.stopDnsWatchers() + mainLog.Load().Debug().Msg("dns watchers stopped") mainLog.Load().Info().Msg("Service stopped") close(p.stopCh) if err := p.deAllocateIP(); err != nil { @@ -521,6 +523,15 @@ func (p *prog) Stop(s service.Service) error { return nil } +func (p *prog) stopDnsWatchers() { + // Ensure all DNS watchers goroutine are terminated, + // so it won't mess up with other DNS changes. + p.dnsWatcherClosedOnce.Do(func() { + close(p.dnsWatcherStopCh) + }) + p.dnsWg.Wait() +} + func (p *prog) allocateIP(ip string) error { p.mu.Lock() defer p.mu.Unlock() @@ -611,6 +622,11 @@ func (p *prog) setDNS() { } setDnsOK = true logger.Debug().Msg("setting DNS successfully") + if allIfaces { + withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error { + return setDnsIgnoreUnusableInterface(i, nameservers) + }) + } if shouldWatchResolvconf() { servers := make([]netip.Addr, len(nameservers)) for i := range nameservers { @@ -622,11 +638,6 @@ func (p *prog) setDNS() { p.watchResolvConf(netIface, servers, setResolvConf) }() } - if allIfaces { - withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error { - return setDnsIgnoreUnusableInterface(i, nameservers) - }) - } if p.dnsWatchdogEnabled() { p.dnsWg.Add(1) go func() { @@ -661,41 +672,42 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces return } - p.dnsWatchDogOnce.Do(func() { - mainLog.Load().Debug().Msg("start DNS settings watchdog") - ns := nameservers - slices.Sort(ns) - ticker := 
time.NewTicker(p.dnsWatchdogDuration()) - logger := mainLog.Load().With().Str("iface", iface.Name).Logger() - for { - select { - case <-p.dnsWatcherStopCh: + mainLog.Load().Debug().Msg("start DNS settings watchdog") + ns := nameservers + slices.Sort(ns) + ticker := time.NewTicker(p.dnsWatchdogDuration()) + logger := mainLog.Load().With().Str("iface", iface.Name).Logger() + for { + select { + case <-p.dnsWatcherStopCh: + return + case <-p.stopCh: + mainLog.Load().Debug().Msg("stop dns watchdog") + return + case <-ticker.C: + if p.captivePortalDetected.Load() { return - case <-p.stopCh: - mainLog.Load().Debug().Msg("stop dns watchdog") - return - case <-ticker.C: - if dnsChanged(iface, ns) { - logger.Debug().Msg("DNS settings were changed, re-applying settings") - if err := setDNS(iface, ns); err != nil { - mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings") - } - } - if allIfaces { - withEachPhysicalInterfaces(iface.Name, "", func(i *net.Interface) error { - if dnsChanged(i, ns) { - if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil { - mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings") - } else { - mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name) - } - } - return nil - }) + } + if dnsChanged(iface, ns) { + logger.Debug().Msg("DNS settings were changed, re-applying settings") + if err := setDNS(iface, ns); err != nil { + mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings") } } + if allIfaces { + withEachPhysicalInterfaces(iface.Name, "", func(i *net.Interface) error { + if dnsChanged(i, ns) { + if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil { + mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings") + } else { + mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name) + } + } + return nil + }) + } } - 
}) + } } func (p *prog) resetDNS() { @@ -965,11 +977,13 @@ func saveCurrentStaticDNS(iface *net.Interface) error { if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) { mainLog.Load().Warn().Err(err).Msg("could not remove old static DNS settings file") } - mainLog.Load().Debug().Msgf("DNS settings for %s is static, saving ...", iface.Name) - if err := os.WriteFile(file, []byte(strings.Join(ns, ",")), 0600); err != nil { + nss := strings.Join(ns, ",") + mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, nss) + if err := os.WriteFile(file, []byte(nss), 0600); err != nil { mainLog.Load().Err(err).Msgf("could not save DNS settings for iface: %s", iface.Name) return err } + mainLog.Load().Debug().Msgf("save DNS settings for interface %q successfully", iface.Name) return nil } @@ -1005,9 +1019,7 @@ func dnsChanged(iface *net.Interface, nameservers []string) bool { func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { var uer *controld.UtilityErrorResponse if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { - // Ensure all DNS watchers goroutine are terminated, so it won't mess up with self-uninstall. - close(p.dnsWatcherStopCh) - p.dnsWg.Wait() + p.stopDnsWatchers() // Perform self-uninstall now. selfUninstall(p, logger) diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 5be34fc..21e435d 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -40,6 +40,9 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) return case event, ok := <-watcher.Events: + if p.captivePortalDetected.Load() { + return + } if !ok { return }