mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
cmd/cli: decouple reset DNS task from ctrld status
So it can be run regardless of ctrld current status. This prevents a racy behavior when reset DNS task restores DNS settings of the system, but current running ctrld process may revert it immediately.
This commit is contained in:
committed by
Cuong Manh Le
parent
8c661c4401
commit
5a88a7c22c
@@ -194,11 +194,15 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
isCtrldRunning := status == service.StatusRunning
|
||||
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
|
||||
|
||||
// Get current running iface, if any.
|
||||
var currentIface string
|
||||
|
||||
// If pin code was set, do not allow running start command.
|
||||
if isCtrldRunning {
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
currentIface = runningIface(s)
|
||||
}
|
||||
|
||||
if !startOnly {
|
||||
@@ -213,12 +217,15 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
|
||||
initLogging()
|
||||
tasks := []task{
|
||||
resetDnsTask(p, s),
|
||||
{s.Stop, false},
|
||||
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error {
|
||||
return saveCurrentStaticDNS(i)
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}, false},
|
||||
@@ -334,14 +341,17 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
}
|
||||
|
||||
tasks := []task{
|
||||
resetDnsTask(p, s),
|
||||
{s.Stop, false},
|
||||
{func() error { return doGenerateNextDNSConfig(nextdns) }, true},
|
||||
{func() error { return ensureUninstall(s) }, false},
|
||||
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error {
|
||||
return saveCurrentStaticDNS(i)
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}, false},
|
||||
@@ -1340,9 +1350,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
close(waitCh)
|
||||
<-stopCh
|
||||
|
||||
// Wait goroutines which watches/manipulates DNS settings terminated,
|
||||
// ensuring that changes to DNS since here won't be reverted.
|
||||
p.dnsWg.Wait()
|
||||
p.stopDnsWatchers()
|
||||
for _, f := range p.onStopped {
|
||||
f()
|
||||
}
|
||||
@@ -2642,17 +2650,20 @@ func runningIface(s service.Service) string {
|
||||
|
||||
// resetDnsNoLog performs resetting DNS with logging disable.
|
||||
func resetDnsNoLog(p *prog) {
|
||||
lvl := zerolog.GlobalLevel()
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
// Normally, disable log to prevent annoying users.
|
||||
if verbose < 3 {
|
||||
lvl := zerolog.GlobalLevel()
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
p.resetDNS()
|
||||
zerolog.SetGlobalLevel(lvl)
|
||||
return
|
||||
}
|
||||
// For debugging purpose, still emit log.
|
||||
p.resetDNS()
|
||||
zerolog.SetGlobalLevel(lvl)
|
||||
}
|
||||
|
||||
// resetDnsTask returns a task which perform reset DNS operation.
|
||||
func resetDnsTask(p *prog, s service.Service) task {
|
||||
status, err := s.Status()
|
||||
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
|
||||
isCtrldRunning := status == service.StatusRunning
|
||||
func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, currentRunningIface string) task {
|
||||
return task{func() error {
|
||||
if iface == "" {
|
||||
return nil
|
||||
@@ -2662,11 +2673,14 @@ func resetDnsTask(p *prog, s service.Service) task {
|
||||
// process to reset what setDNS has done properly.
|
||||
oldIface := iface
|
||||
iface = "auto"
|
||||
if isCtrldRunning {
|
||||
iface = runningIface(s)
|
||||
if currentRunningIface != "" {
|
||||
iface = currentRunningIface
|
||||
}
|
||||
if isCtrldInstalled {
|
||||
mainLog.Load().Debug().Msg("restore system DNS settings")
|
||||
if status, _ := s.Status(); status == service.StatusRunning {
|
||||
mainLog.Load().Fatal().Msg("reset DNS while ctrld still running is not safe")
|
||||
}
|
||||
resetDnsNoLog(p)
|
||||
}
|
||||
iface = oldIface
|
||||
|
||||
@@ -915,6 +915,8 @@ func (p *prog) performCaptivePortalDetection() {
|
||||
if found {
|
||||
resetDnsOnce.Do(func() {
|
||||
mainLog.Load().Warn().Msg("found captive portal, leaking query to OS resolver")
|
||||
// Store the result once here, so changes made below won't be reverted by DNS watchers.
|
||||
p.captivePortalDetected.Store(found)
|
||||
p.resetDNS()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -119,6 +119,7 @@ func resetDNS(iface *net.Interface) error {
|
||||
if len(ns) == 0 {
|
||||
continue
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
|
||||
if err := setDNS(iface, ns); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
124
cmd/cli/prog.go
124
cmd/cli/prog.go
@@ -69,21 +69,21 @@ var svcConfig = &service.Config{
|
||||
var useSystemdResolved = false
|
||||
|
||||
type prog struct {
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
reloadCh chan struct{} // For Windows.
|
||||
reloadDoneCh chan struct{}
|
||||
apiReloadCh chan *ctrld.Config
|
||||
apiForceReloadCh chan struct{}
|
||||
apiForceReloadGroup singleflight.Group
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
csSetDnsDone chan struct{}
|
||||
csSetDnsOk bool
|
||||
dnsWatchDogOnce sync.Once
|
||||
dnsWg sync.WaitGroup
|
||||
dnsWatcherStopCh chan struct{}
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
reloadCh chan struct{} // For Windows.
|
||||
reloadDoneCh chan struct{}
|
||||
apiReloadCh chan *ctrld.Config
|
||||
apiForceReloadCh chan struct{}
|
||||
apiForceReloadGroup singleflight.Group
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
csSetDnsDone chan struct{}
|
||||
csSetDnsOk bool
|
||||
dnsWg sync.WaitGroup
|
||||
dnsWatcherClosedOnce sync.Once
|
||||
dnsWatcherStopCh chan struct{}
|
||||
|
||||
cfg *ctrld.Config
|
||||
localUpstreams []string
|
||||
@@ -512,6 +512,8 @@ func (p *prog) metricsEnabled() bool {
|
||||
}
|
||||
|
||||
func (p *prog) Stop(s service.Service) error {
|
||||
p.stopDnsWatchers()
|
||||
mainLog.Load().Debug().Msg("dns watchers stopped")
|
||||
mainLog.Load().Info().Msg("Service stopped")
|
||||
close(p.stopCh)
|
||||
if err := p.deAllocateIP(); err != nil {
|
||||
@@ -521,6 +523,15 @@ func (p *prog) Stop(s service.Service) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *prog) stopDnsWatchers() {
|
||||
// Ensure all DNS watchers goroutine are terminated,
|
||||
// so it won't mess up with other DNS changes.
|
||||
p.dnsWatcherClosedOnce.Do(func() {
|
||||
close(p.dnsWatcherStopCh)
|
||||
})
|
||||
p.dnsWg.Wait()
|
||||
}
|
||||
|
||||
func (p *prog) allocateIP(ip string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
@@ -611,6 +622,11 @@ func (p *prog) setDNS() {
|
||||
}
|
||||
setDnsOK = true
|
||||
logger.Debug().Msg("setting DNS successfully")
|
||||
if allIfaces {
|
||||
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
|
||||
return setDnsIgnoreUnusableInterface(i, nameservers)
|
||||
})
|
||||
}
|
||||
if shouldWatchResolvconf() {
|
||||
servers := make([]netip.Addr, len(nameservers))
|
||||
for i := range nameservers {
|
||||
@@ -622,11 +638,6 @@ func (p *prog) setDNS() {
|
||||
p.watchResolvConf(netIface, servers, setResolvConf)
|
||||
}()
|
||||
}
|
||||
if allIfaces {
|
||||
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
|
||||
return setDnsIgnoreUnusableInterface(i, nameservers)
|
||||
})
|
||||
}
|
||||
if p.dnsWatchdogEnabled() {
|
||||
p.dnsWg.Add(1)
|
||||
go func() {
|
||||
@@ -661,41 +672,42 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
|
||||
return
|
||||
}
|
||||
|
||||
p.dnsWatchDogOnce.Do(func() {
|
||||
mainLog.Load().Debug().Msg("start DNS settings watchdog")
|
||||
ns := nameservers
|
||||
slices.Sort(ns)
|
||||
ticker := time.NewTicker(p.dnsWatchdogDuration())
|
||||
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
|
||||
for {
|
||||
select {
|
||||
case <-p.dnsWatcherStopCh:
|
||||
mainLog.Load().Debug().Msg("start DNS settings watchdog")
|
||||
ns := nameservers
|
||||
slices.Sort(ns)
|
||||
ticker := time.NewTicker(p.dnsWatchdogDuration())
|
||||
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
|
||||
for {
|
||||
select {
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-p.stopCh:
|
||||
mainLog.Load().Debug().Msg("stop dns watchdog")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if p.captivePortalDetected.Load() {
|
||||
return
|
||||
case <-p.stopCh:
|
||||
mainLog.Load().Debug().Msg("stop dns watchdog")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if dnsChanged(iface, ns) {
|
||||
logger.Debug().Msg("DNS settings were changed, re-applying settings")
|
||||
if err := setDNS(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
|
||||
}
|
||||
}
|
||||
if allIfaces {
|
||||
withEachPhysicalInterfaces(iface.Name, "", func(i *net.Interface) error {
|
||||
if dnsChanged(i, ns) {
|
||||
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if dnsChanged(iface, ns) {
|
||||
logger.Debug().Msg("DNS settings were changed, re-applying settings")
|
||||
if err := setDNS(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
|
||||
}
|
||||
}
|
||||
if allIfaces {
|
||||
withEachPhysicalInterfaces(iface.Name, "", func(i *net.Interface) error {
|
||||
if dnsChanged(i, ns) {
|
||||
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prog) resetDNS() {
|
||||
@@ -965,11 +977,13 @@ func saveCurrentStaticDNS(iface *net.Interface) error {
|
||||
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not remove old static DNS settings file")
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("DNS settings for %s is static, saving ...", iface.Name)
|
||||
if err := os.WriteFile(file, []byte(strings.Join(ns, ",")), 0600); err != nil {
|
||||
nss := strings.Join(ns, ",")
|
||||
mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, nss)
|
||||
if err := os.WriteFile(file, []byte(nss), 0600); err != nil {
|
||||
mainLog.Load().Err(err).Msgf("could not save DNS settings for iface: %s", iface.Name)
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("save DNS settings for interface %q successfully", iface.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1005,9 +1019,7 @@ func dnsChanged(iface *net.Interface, nameservers []string) bool {
|
||||
func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
|
||||
var uer *controld.UtilityErrorResponse
|
||||
if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
// Ensure all DNS watchers goroutine are terminated, so it won't mess up with self-uninstall.
|
||||
close(p.dnsWatcherStopCh)
|
||||
p.dnsWg.Wait()
|
||||
p.stopDnsWatchers()
|
||||
|
||||
// Perform self-uninstall now.
|
||||
selfUninstall(p, logger)
|
||||
|
||||
@@ -40,6 +40,9 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if p.captivePortalDetected.Load() {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user