diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 2317840..2f4916c 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1135,13 +1135,14 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } waitCh := make(chan struct{}) p := &prog{ - waitCh: waitCh, - stopCh: stopCh, - reloadCh: make(chan struct{}), - reloadDoneCh: make(chan struct{}), - apiReloadCh: make(chan *ctrld.Config), - cfg: &cfg, - appCallback: appCallback, + waitCh: waitCh, + stopCh: stopCh, + reloadCh: make(chan struct{}), + reloadDoneCh: make(chan struct{}), + dnsWatcherStopCh: make(chan struct{}), + apiReloadCh: make(chan *ctrld.Config), + cfg: &cfg, + appCallback: appCallback, } if homedir == "" { if dir, err := userHomeDir(); err == nil { @@ -1232,7 +1233,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() - _ = uninstallIfInvalidCdUID(err, p, cdLogger) + // Performs self-uninstallation if the ControlD device does not exist. + var uer *controld.UtilityErrorResponse + if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { + _ = uninstallInvalidCdUID(p, cdLogger, false) + } cdLogger.Fatal().Err(err).Msg("failed to fetch resolver config") } } @@ -2696,23 +2701,23 @@ func doValidateCdRemoteConfig(cdUID string) { v = oldV } -// uninstallIfInvalidCdUID performs self-uninstallation if the ControlD device does not exist. -func uninstallIfInvalidCdUID(err error, p *prog, logger zerolog.Logger) bool { - var uer *controld.UtilityErrorResponse - if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { - s, err := newService(p, svcConfig) - if err != nil { - logger.Warn().Err(err).Msg("failed to create new service") - return false - } +// uninstallInvalidCdUID performs self-uninstallation because the ControlD device does not exist. +func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool { + s, err := newService(p, svcConfig) + if err != nil { + logger.Warn().Err(err).Msg("failed to create new service") + return false + } - p.resetDNS() + p.resetDNS() - tasks := []task{{s.Uninstall, true}} - if doTasks(tasks) { - logger.Info().Msg("uninstalled service") - return true + tasks := []task{{s.Uninstall, true}} + if doTasks(tasks) { + logger.Info().Msg("uninstalled service") + if doStop { + _ = s.Stop() } + return true } return false } diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 33e5ebb..a7c62af 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -863,7 +863,7 @@ func (p *prog) doSelfUninstall(answer *dns.Msg) { p.checkingSelfUninstall = true _, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) logger.Debug().Msg("maximum number of refused queries reached, checking device status") - selfUninstall(err, p, logger) + selfUninstallCheck(err, p, logger) if err != nil { logger.Warn().Err(err).Msg("could not fetch resolver config") diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index d2b9ed3..82daa24 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -22,6 +22,7 @@ import ( "time" "github.com/kardianos/service" + "github.com/rs/zerolog" "github.com/spf13/viper" "tailscale.com/net/interfaces" "tailscale.com/net/tsaddr" @@ -67,18 +68,19 @@ var svcConfig = &service.Config{ var useSystemdResolved = false type prog struct { - mu sync.Mutex - waitCh chan struct{} - stopCh chan struct{} - reloadCh chan struct{} // For Windows. - reloadDoneCh chan struct{} - apiReloadCh chan *ctrld.Config - logConn net.Conn - cs *controlServer - csSetDnsDone chan struct{} - csSetDnsOk bool - dnsWatchDogOnce sync.Once - dnsWg sync.WaitGroup + mu sync.Mutex + waitCh chan struct{} + stopCh chan struct{} + reloadCh chan struct{} // For Windows. + reloadDoneCh chan struct{} + apiReloadCh chan *ctrld.Config + logConn net.Conn + cs *controlServer + csSetDnsDone chan struct{} + csSetDnsOk bool + dnsWatchDogOnce sync.Once + dnsWg sync.WaitGroup + dnsWatcherStopCh chan struct{} cfg *ctrld.Config localUpstreams []string @@ -261,7 +263,7 @@ func (p *prog) apiConfigReload() { select { case <-ticker.C: resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) - selfUninstall(err, p, logger) + selfUninstallCheck(err, p, logger) if err != nil { logger.Warn().Err(err).Msg("could not fetch resolver config") continue @@ -650,6 +652,8 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces logger := mainLog.Load().With().Str("iface", iface.Name).Logger() for { select { + case <-p.dnsWatcherStopCh: + return case <-p.stopCh: mainLog.Load().Debug().Msg("stop dns watchdog") return @@ -975,3 +979,16 @@ func dnsChanged(iface *net.Interface, nameservers []string) bool { slices.Sort(curNameservers) return !slices.Equal(curNameservers, nameservers) } + +// selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then. +func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { + var uer *controld.UtilityErrorResponse + if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { + // Ensure all DNS watchers goroutine are terminated, so it won't mess up with self-uninstall. + close(p.dnsWatcherStopCh) + p.dnsWg.Wait() + + // Perform self-uninstall now. + selfUninstall(p, logger) + } +} diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 6196487..5be34fc 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -34,6 +34,8 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f for { select { + case <-p.dnsWatcherStopCh: + return case <-p.stopCh: mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) return diff --git a/cmd/cli/self_kill_others.go b/cmd/cli/self_kill_others.go index 1fe9b8a..e9fb1f8 100644 --- a/cmd/cli/self_kill_others.go +++ b/cmd/cli/self_kill_others.go @@ -8,8 +8,8 @@ import ( "github.com/rs/zerolog" ) -func selfUninstall(err error, p *prog, logger zerolog.Logger) { - if uninstallIfInvalidCdUID(err, p, logger) { +func selfUninstall(p *prog, logger zerolog.Logger) { + if uninstallInvalidCdUID(p, logger, false) { logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) os.Exit(0) } diff --git a/cmd/cli/self_kill_unix.go b/cmd/cli/self_kill_unix.go index a7dc1f1..9e494b4 100644 --- a/cmd/cli/self_kill_unix.go +++ b/cmd/cli/self_kill_unix.go @@ -3,54 +3,43 @@ package cli import ( - "errors" "fmt" "os" "os/exec" "runtime" "syscall" - "github.com/Control-D-Inc/ctrld/internal/controld" "github.com/rs/zerolog" ) -func selfUninstall(uninstallErr error, p *prog, logger zerolog.Logger) { - var uer *controld.UtilityErrorResponse - if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { - if runtime.GOOS == "linux" { - s, err := newService(p, svcConfig) - if err != nil { - logger.Warn().Err(err).Msg("failed to create new service") - } else { - selfUninstallLinux(uninstallErr, p, logger) - _ = s.Stop() - os.Exit(0) - } - } +func selfUninstall(p *prog, logger zerolog.Logger) { + if runtime.GOOS == "linux" { + selfUninstallLinux(p, logger) + } - bin, err := os.Executable() - if err != nil { - logger.Fatal().Err(err).Msg("could not determine executable") - } - args := []string{"uninstall"} - if !deactivationPinNotSet() { - args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin)) - } - cmd := exec.Command(bin, args...) - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - if err := cmd.Start(); err != nil { - logger.Fatal().Err(err).Msg("could not start self uninstall command") - } - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr + bin, err := os.Executable() + if err != nil { + logger.Fatal().Err(err).Msg("could not determine executable") + } + args := []string{"uninstall"} + if !deactivationPinNotSet() { + args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin)) + } + cmd := exec.Command(bin, args...) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + if err := cmd.Start(); err != nil { + logger.Fatal().Err(err).Msg("could not start self uninstall command") + } + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) + _ = cmd.Wait() + os.Exit(0) +} + +func selfUninstallLinux(p *prog, logger zerolog.Logger) { + if uninstallInvalidCdUID(p, logger, true) { logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) - _ = cmd.Wait() os.Exit(0) } } - -func selfUninstallLinux(err error, p *prog, logger zerolog.Logger) { - if uninstallIfInvalidCdUID(err, p, logger) { - logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) - } -}