cmd/cli: decouple reset DNS task from ctrld status

So it can be run regardless of ctrld current status. This prevents a
racy behavior when reset DNS task restores DNS settings of the system,
but current running ctrld process may revert it immediately.
This commit is contained in:
Cuong Manh Le
2024-09-20 21:33:44 +07:00
committed by Cuong Manh Le
parent 8c661c4401
commit 5a88a7c22c
5 changed files with 106 additions and 74 deletions

View File

@@ -194,11 +194,15 @@ NOTE: running "ctrld start" without any arguments will start already installed c
isCtrldRunning := status == service.StatusRunning
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
// Get current running iface, if any.
var currentIface string
// If pin code was set, do not allow running start command.
if isCtrldRunning {
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
os.Exit(deactivationPinInvalidExitCode)
}
currentIface = runningIface(s)
}
if !startOnly {
@@ -213,12 +217,15 @@ NOTE: running "ctrld start" without any arguments will start already installed c
initLogging()
tasks := []task{
resetDnsTask(p, s),
{s.Stop, false},
resetDnsTask(p, s, isCtrldInstalled, currentIface),
{func() error {
// Save current DNS so we can restore later.
withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error {
return saveCurrentStaticDNS(i)
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
return err
}
return nil
})
return nil
}, false},
@@ -334,14 +341,17 @@ NOTE: running "ctrld start" without any arguments will start already installed c
}
tasks := []task{
resetDnsTask(p, s),
{s.Stop, false},
{func() error { return doGenerateNextDNSConfig(nextdns) }, true},
{func() error { return ensureUninstall(s) }, false},
resetDnsTask(p, s, isCtrldInstalled, currentIface),
{func() error {
// Save current DNS so we can restore later.
withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error {
return saveCurrentStaticDNS(i)
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
return err
}
return nil
})
return nil
}, false},
@@ -1340,9 +1350,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
close(waitCh)
<-stopCh
// Wait goroutines which watches/manipulates DNS settings terminated,
// ensuring that changes to DNS since here won't be reverted.
p.dnsWg.Wait()
p.stopDnsWatchers()
for _, f := range p.onStopped {
f()
}
@@ -2642,17 +2650,20 @@ func runningIface(s service.Service) string {
// resetDnsNoLog performs resetting DNS with logging disable.
func resetDnsNoLog(p *prog) {
lvl := zerolog.GlobalLevel()
zerolog.SetGlobalLevel(zerolog.Disabled)
// Normally, disable log to prevent annoying users.
if verbose < 3 {
lvl := zerolog.GlobalLevel()
zerolog.SetGlobalLevel(zerolog.Disabled)
p.resetDNS()
zerolog.SetGlobalLevel(lvl)
return
}
// For debugging purpose, still emit log.
p.resetDNS()
zerolog.SetGlobalLevel(lvl)
}
// resetDnsTask returns a task which perform reset DNS operation.
func resetDnsTask(p *prog, s service.Service) task {
status, err := s.Status()
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
isCtrldRunning := status == service.StatusRunning
func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, currentRunningIface string) task {
return task{func() error {
if iface == "" {
return nil
@@ -2662,11 +2673,14 @@ func resetDnsTask(p *prog, s service.Service) task {
// process to reset what setDNS has done properly.
oldIface := iface
iface = "auto"
if isCtrldRunning {
iface = runningIface(s)
if currentRunningIface != "" {
iface = currentRunningIface
}
if isCtrldInstalled {
mainLog.Load().Debug().Msg("restore system DNS settings")
if status, _ := s.Status(); status == service.StatusRunning {
mainLog.Load().Fatal().Msg("reset DNS while ctrld still running is not safe")
}
resetDnsNoLog(p)
}
iface = oldIface

View File

@@ -915,6 +915,8 @@ func (p *prog) performCaptivePortalDetection() {
if found {
resetDnsOnce.Do(func() {
mainLog.Load().Warn().Msg("found captive portal, leaking query to OS resolver")
// Store the result once here, so changes made below won't be reverted by DNS watchers.
p.captivePortalDetected.Store(found)
p.resetDNS()
})
}

View File

@@ -119,6 +119,7 @@ func resetDNS(iface *net.Interface) error {
if len(ns) == 0 {
continue
}
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
if err := setDNS(iface, ns); err != nil {
return err
}

View File

@@ -69,21 +69,21 @@ var svcConfig = &service.Config{
var useSystemdResolved = false
type prog struct {
mu sync.Mutex
waitCh chan struct{}
stopCh chan struct{}
reloadCh chan struct{} // For Windows.
reloadDoneCh chan struct{}
apiReloadCh chan *ctrld.Config
apiForceReloadCh chan struct{}
apiForceReloadGroup singleflight.Group
logConn net.Conn
cs *controlServer
csSetDnsDone chan struct{}
csSetDnsOk bool
dnsWatchDogOnce sync.Once
dnsWg sync.WaitGroup
dnsWatcherStopCh chan struct{}
mu sync.Mutex
waitCh chan struct{}
stopCh chan struct{}
reloadCh chan struct{} // For Windows.
reloadDoneCh chan struct{}
apiReloadCh chan *ctrld.Config
apiForceReloadCh chan struct{}
apiForceReloadGroup singleflight.Group
logConn net.Conn
cs *controlServer
csSetDnsDone chan struct{}
csSetDnsOk bool
dnsWg sync.WaitGroup
dnsWatcherClosedOnce sync.Once
dnsWatcherStopCh chan struct{}
cfg *ctrld.Config
localUpstreams []string
@@ -512,6 +512,8 @@ func (p *prog) metricsEnabled() bool {
}
func (p *prog) Stop(s service.Service) error {
p.stopDnsWatchers()
mainLog.Load().Debug().Msg("dns watchers stopped")
mainLog.Load().Info().Msg("Service stopped")
close(p.stopCh)
if err := p.deAllocateIP(); err != nil {
@@ -521,6 +523,15 @@ func (p *prog) Stop(s service.Service) error {
return nil
}
func (p *prog) stopDnsWatchers() {
// Ensure all DNS watchers goroutine are terminated,
// so it won't mess up with other DNS changes.
p.dnsWatcherClosedOnce.Do(func() {
close(p.dnsWatcherStopCh)
})
p.dnsWg.Wait()
}
func (p *prog) allocateIP(ip string) error {
p.mu.Lock()
defer p.mu.Unlock()
@@ -611,6 +622,11 @@ func (p *prog) setDNS() {
}
setDnsOK = true
logger.Debug().Msg("setting DNS successfully")
if allIfaces {
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
return setDnsIgnoreUnusableInterface(i, nameservers)
})
}
if shouldWatchResolvconf() {
servers := make([]netip.Addr, len(nameservers))
for i := range nameservers {
@@ -622,11 +638,6 @@ func (p *prog) setDNS() {
p.watchResolvConf(netIface, servers, setResolvConf)
}()
}
if allIfaces {
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
return setDnsIgnoreUnusableInterface(i, nameservers)
})
}
if p.dnsWatchdogEnabled() {
p.dnsWg.Add(1)
go func() {
@@ -661,41 +672,42 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
return
}
p.dnsWatchDogOnce.Do(func() {
mainLog.Load().Debug().Msg("start DNS settings watchdog")
ns := nameservers
slices.Sort(ns)
ticker := time.NewTicker(p.dnsWatchdogDuration())
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
for {
select {
case <-p.dnsWatcherStopCh:
mainLog.Load().Debug().Msg("start DNS settings watchdog")
ns := nameservers
slices.Sort(ns)
ticker := time.NewTicker(p.dnsWatchdogDuration())
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
for {
select {
case <-p.dnsWatcherStopCh:
return
case <-p.stopCh:
mainLog.Load().Debug().Msg("stop dns watchdog")
return
case <-ticker.C:
if p.captivePortalDetected.Load() {
return
case <-p.stopCh:
mainLog.Load().Debug().Msg("stop dns watchdog")
return
case <-ticker.C:
if dnsChanged(iface, ns) {
logger.Debug().Msg("DNS settings were changed, re-applying settings")
if err := setDNS(iface, ns); err != nil {
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
}
}
if allIfaces {
withEachPhysicalInterfaces(iface.Name, "", func(i *net.Interface) error {
if dnsChanged(i, ns) {
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
} else {
mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
}
}
return nil
})
}
if dnsChanged(iface, ns) {
logger.Debug().Msg("DNS settings were changed, re-applying settings")
if err := setDNS(iface, ns); err != nil {
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
}
}
if allIfaces {
withEachPhysicalInterfaces(iface.Name, "", func(i *net.Interface) error {
if dnsChanged(i, ns) {
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
} else {
mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
}
}
return nil
})
}
}
})
}
}
func (p *prog) resetDNS() {
@@ -965,11 +977,13 @@ func saveCurrentStaticDNS(iface *net.Interface) error {
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
mainLog.Load().Warn().Err(err).Msg("could not remove old static DNS settings file")
}
mainLog.Load().Debug().Msgf("DNS settings for %s is static, saving ...", iface.Name)
if err := os.WriteFile(file, []byte(strings.Join(ns, ",")), 0600); err != nil {
nss := strings.Join(ns, ",")
mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, nss)
if err := os.WriteFile(file, []byte(nss), 0600); err != nil {
mainLog.Load().Err(err).Msgf("could not save DNS settings for iface: %s", iface.Name)
return err
}
mainLog.Load().Debug().Msgf("save DNS settings for interface %q successfully", iface.Name)
return nil
}
@@ -1005,9 +1019,7 @@ func dnsChanged(iface *net.Interface, nameservers []string) bool {
func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
var uer *controld.UtilityErrorResponse
if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
// Ensure all DNS watchers goroutine are terminated, so it won't mess up with self-uninstall.
close(p.dnsWatcherStopCh)
p.dnsWg.Wait()
p.stopDnsWatchers()
// Perform self-uninstall now.
selfUninstall(p, logger)

View File

@@ -40,6 +40,9 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
return
case event, ok := <-watcher.Events:
if p.captivePortalDetected.Load() {
return
}
if !ok {
return
}