diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 2a9a4e9..c0b3fe1 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -199,6 +199,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { p := &prog{ waitCh: waitCh, stopCh: stopCh, + pinCodeValidCh: make(chan struct{}, 1), reloadCh: make(chan struct{}), reloadDoneCh: make(chan struct{}), dnsWatcherStopCh: make(chan struct{}), @@ -421,19 +422,28 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if err := p.router.Cleanup(); err != nil { mainLog.Load().Error().Err(err).Msg("could not cleanup router") } - // restore static DNS settings or DHCP - p.resetDNS(false, true) }) } } + p.onStopped = append(p.onStopped, func() { + // restore static DNS settings or DHCP + p.resetDNS(false, true) + // Iterate over all physical interfaces and restore static DNS if a saved static config exists. + withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error { + file := savedStaticDnsSettingsFilePath(i) + if _, err := os.Stat(file); err == nil { + if err := restoreDNS(i); err != nil { + mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) + } else { + mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) + } + } + return nil + }) + }) close(waitCh) <-stopCh - - p.stopDnsWatchers() - for _, f := range p.onStopped { - f() - } } func writeConfigFile(cfg *ctrld.Config) error { @@ -609,9 +619,9 @@ func init() { cdDeactivationPin.Store(defaultDeactivationPin) } -// deactivationPinNotSet reports whether cdDeactivationPin was not set by processCDFlags. -func deactivationPinNotSet() bool { - return cdDeactivationPin.Load() == defaultDeactivationPin +// deactivationPinSet indicates if cdDeactivationPin is non-default.. +func deactivationPinSet() bool { + return cdDeactivationPin.Load() != defaultDeactivationPin } func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 96e264b..048212a 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -629,23 +629,6 @@ func initStopCmd() *cobra.Command { os.Exit(deactivationPinInvalidExitCode) } if doTasks([]task{{s.Stop, true, "Stop"}}) { - p.router.Cleanup() - // restore static DNS settings or DHCP - p.resetDNS(false, true) - - // Iterate over all physical interfaces and restore static DNS if a saved static config exists. - withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error { - file := savedStaticDnsSettingsFilePath(i) - if _, err := os.Stat(file); err == nil { - if err := restoreDNS(i); err != nil { - mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) - } else { - mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) - } - } - return nil - }) - if router.WaitProcessExited() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 17f585d..9281b90 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -228,7 +228,7 @@ func (p *prog) registerControlServerHandler() { } // If pin code not set, allowing deactivation. - if deactivationPinNotSet() { + if !deactivationPinSet() { w.WriteHeader(http.StatusOK) return } @@ -244,6 +244,10 @@ func (p *prog) registerControlServerHandler() { switch req.Pin { case cdDeactivationPin.Load(): code = http.StatusOK + select { + case p.pinCodeValidCh <- struct{}{}: + default: + } case defaultDeactivationPin: // If the pin code was set, but users do not provide --pin, return proper code to client. code = http.StatusBadRequest diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 9c2fb11..089bfd0 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -87,6 +87,7 @@ type prog struct { mu sync.Mutex waitCh chan struct{} stopCh chan struct{} + pinCodeValidCh chan struct{} reloadCh chan struct{} // For Windows. reloadDoneCh chan struct{} apiReloadCh chan *ctrld.Config @@ -268,13 +269,6 @@ func (p *prog) preRun() { p.requiredMultiNICsConfig = requiredMultiNICsConfig() } p.runningIface = iface - if runtime.GOOS == "darwin" { - p.onStopped = append(p.onStopped, func() { - if !service.Interactive() { - p.resetDNS(false, true) - } - }) - } } func (p *prog) postRun() { @@ -622,14 +616,41 @@ func (p *prog) metricsEnabled() bool { func (p *prog) Stop(s service.Service) error { p.stopDnsWatchers() mainLog.Load().Debug().Msg("dns watchers stopped") + for _, f := range p.onStopped { + f() + } + mainLog.Load().Debug().Msg("finish running onStopped functions") defer func() { mainLog.Load().Info().Msg("Service stopped") }() - close(p.stopCh) if err := p.deAllocateIP(); err != nil { mainLog.Load().Error().Err(err).Msg("de-allocate ip failed") return err } + if deactivationPinSet() { + select { + case <-p.pinCodeValidCh: + // Allow stopping the service, pinCodeValidCh is only filled + // after control server did validate the pin code. + case <-time.After(time.Millisecond * 100): + // No valid pin code was checked, that mean we are stopping + // because of OS signal sent directly from someone else. + // In this case, restarting ctrld service by ourselves. + mainLog.Load().Debug().Msgf("receiving stopping signal without valid pin code") + mainLog.Load().Debug().Msgf("self restarting ctrld service") + if exe, err := os.Executable(); err == nil { + cmd := exec.Command(exe, "restart") + cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() + if err := cmd.Start(); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to run self restart command") + } + } else { + mainLog.Load().Error().Err(err).Msg("failed to self restart ctrld service") + } + os.Exit(deactivationPinInvalidExitCode) + } + } + close(p.stopCh) return nil } @@ -1471,7 +1492,7 @@ func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) { return } cmd := exec.Command(exe, "upgrade", "prod", "-vv") - cmd.SysProcAttr = sysProcAttrForSelfUpgrade() + cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() if err := cmd.Start(); err != nil { mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade") return diff --git a/cmd/cli/self_kill_unix.go b/cmd/cli/self_kill_unix.go index 9a68e60..157425f 100644 --- a/cmd/cli/self_kill_unix.go +++ b/cmd/cli/self_kill_unix.go @@ -22,7 +22,7 @@ func selfUninstall(p *prog, logger zerolog.Logger) { logger.Fatal().Err(err).Msg("could not determine executable") } args := []string{"uninstall"} - if !deactivationPinNotSet() { + if deactivationPinSet() { args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin.Load())) } cmd := exec.Command(bin, args...) diff --git a/cmd/cli/self_upgrade_others.go b/cmd/cli/self_upgrade_others.go index f1ff140..0250c0e 100644 --- a/cmd/cli/self_upgrade_others.go +++ b/cmd/cli/self_upgrade_others.go @@ -6,7 +6,7 @@ import ( "syscall" ) -// sysProcAttrForSelfUpgrade returns *syscall.SysProcAttr instance for running self-upgrade command. -func sysProcAttrForSelfUpgrade() *syscall.SysProcAttr { +// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running a detached child command. +func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr { return &syscall.SysProcAttr{Setsid: true} } diff --git a/cmd/cli/self_upgrade_windows.go b/cmd/cli/self_upgrade_windows.go index 213aec9..a6f37be 100644 --- a/cmd/cli/self_upgrade_windows.go +++ b/cmd/cli/self_upgrade_windows.go @@ -9,8 +9,8 @@ import ( // SYSCALL_CREATE_NO_WINDOW set flag to run process without a console window. const SYSCALL_CREATE_NO_WINDOW = 0x08000000 -// sysProcAttrForSelfUpgrade returns *syscall.SysProcAttr instance for running self-upgrade command. -func sysProcAttrForSelfUpgrade() *syscall.SysProcAttr { +// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running self-upgrade command. +func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr { return &syscall.SysProcAttr{ CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | SYSCALL_CREATE_NO_WINDOW, HideWindow: true,