diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index f2b9906..bf6803f 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -239,7 +239,7 @@ func initCLI() { if nextdns != "" { removeNextDNSFromArgs(sc) generateNextDNSConfig() - updateListenerConfig() + updateListenerConfig(&cfg) if err := writeConfigFile(); err != nil { mainLog.Load().Error().Err(err).Msg("failed to write config with NextDNS resolver") } @@ -383,6 +383,10 @@ func initCLI() { mainLog.Load().Error().Msg(err.Error()) return } + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } initLogging() tasks := []task{ @@ -404,6 +408,50 @@ func initCLI() { }, } + reloadCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + initConsoleLogging() + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + dir, err := userHomeDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(reloadPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld") + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK: + mainLog.Load().Notice().Msg("Service reloaded") + case http.StatusCreated: + s, err := newService(&prog{}, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.") + mainLog.Load().Warn().Msg("Restarting service") + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("Service not installed") + return + } + restartCmd.Run(cmd, args) + default: + buf, err := io.ReadAll(resp.Body) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("could not read response from control server") + } + mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf)) + } + }, + } statusCmd := &cobra.Command{ Use: "status", Short: "Show status of the ctrld service", @@ -519,9 +567,10 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, Short: "Manage ctrld service", Args: cobra.OnlyValidArgs, ValidArgs: []string{ - statusCmd.Use, + startCmd.Use, stopCmd.Use, restartCmd.Use, + reloadCmd.Use, statusCmd.Use, uninstallCmd.Use, interfacesCmd.Use, @@ -530,6 +579,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, serviceCmd.AddCommand(startCmd) serviceCmd.AddCommand(stopCmd) serviceCmd.AddCommand(restartCmd) + serviceCmd.AddCommand(reloadCmd) serviceCmd.AddCommand(statusCmd) serviceCmd.AddCommand(uninstallCmd) serviceCmd.AddCommand(interfacesCmd) @@ -584,6 +634,19 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, } rootCmd.AddCommand(restartCmdAlias) + reloadCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + initConsoleLogging() + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + Run: func(cmd *cobra.Command, args []string) { + reloadCmd.Run(cmd, args) + }, + } + rootCmd.AddCommand(reloadCmdAlias) + statusCmdAlias := &cobra.Command{ Use: "status", Short: "Show status of the ctrld service", @@ -716,10 +779,12 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } waitCh := make(chan struct{}) p := &prog{ - waitCh: waitCh, - stopCh: stopCh, - cfg: &cfg, - appCallback: appCallback, + waitCh: waitCh, + stopCh: stopCh, + reloadCh: make(chan struct{}), + reloadDoneCh: make(chan struct{}), + cfg: &cfg, + appCallback: appCallback, } if homedir == "" { if dir, err := userHomeDir(); err == nil { @@ -757,9 +822,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { readBase64Config(configBase64) processNoConfigFlags(noConfigStart) + p.mu.Lock() if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) } + p.mu.Unlock() processLogAndCacheFlags() @@ -795,14 +862,46 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } if cdUID != "" { validateCdUpstreamProtocol() - err := processCDFlags() - if err != nil { - appCallback.Exit(err.Error()) - return + if err := processCDFlags(&cfg); err != nil { + if isMobile() { + appCallback.Exit(err.Error()) + return + } + + uninstallIfInvalidCdUID := func() { + cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() + if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode { + s, err := newService(&prog{}, svcConfig) + if err != nil { + cdLogger.Warn().Err(err).Msg("failed to create new service") + return + } + if netIface, _ := netInterface(iface); netIface != nil { + if err := restoreNetworkManager(); err != nil { + cdLogger.Error().Err(err).Msg("could not restore NetworkManager") + return + } + cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface") + if err := resetDNS(netIface); err != nil { + cdLogger.Warn().Err(err).Msg("something went wrong while restoring DNS") + } else { + cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully") + } + } + + tasks := []task{{s.Uninstall, true}} + if doTasks(tasks) { + cdLogger.Info().Msg("uninstalled service") + } + cdLogger.Fatal().Err(uer).Msg("failed to fetch resolver config") + return + } + } + uninstallIfInvalidCdUID() } } - updated := updateListenerConfig() + updated := updateListenerConfig(&cfg) if cdUID != "" { processLogAndCacheFlags() @@ -830,7 +929,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { initLoggingWithBackup(false) } - validateConfig(&cfg) + if err := validateConfig(&cfg); err != nil { + os.Exit(1) + } initCache() if daemon { @@ -943,7 +1044,7 @@ func readConfigFile(writeDefaultConfig bool) bool { if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) } - _ = updateListenerConfig() + _ = updateListenerConfig(&cfg) if err := writeConfigFile(); err != nil { mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) } else { @@ -1033,7 +1134,7 @@ func processNoConfigFlags(noConfigStart bool) { v.Set("upstream", upstream) } -func processCDFlags() error { +func processCDFlags(cfg *ctrld.Config) error { logger := mainLog.Load().With().Str("mode", "cd").Logger() logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) @@ -1049,47 +1150,17 @@ func processCDFlags() error { } break } - if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode { - s, err := newService(&prog{}, svcConfig) - if err != nil { - logger.Warn().Err(err).Msg("failed to create new service") - return nil - } - if netIface, _ := netInterface(iface); netIface != nil { - if err := restoreNetworkManager(); err != nil { - logger.Error().Err(err).Msg("could not restore NetworkManager") - return nil - } - logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface") - if err := resetDNS(netIface); err != nil { - logger.Warn().Err(err).Msg("something went wrong while restoring DNS") - } else { - logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully") - } - } - - tasks := []task{{s.Uninstall, true}} - if doTasks(tasks) { - logger.Info().Msg("uninstalled service") - } - event := logger.Fatal() - if isMobile() { - event = logger.Warn() - } - event.Err(uer).Msg("failed to fetch resolver config") - return uer - } if err != nil { if isMobile() { return errors.New("could not fetch resolver config") } logger.Warn().Err(err).Msg("could not fetch resolver config") - return nil + return err } logger.Info().Msg("generating ctrld config from Control-D configuration") - cfg = ctrld.Config{} + *cfg = ctrld.Config{} // Fetch config, unmarshal to cfg. if resolverConfig.Ctrld.CustomConfig != "" { logger.Info().Msg("using defined custom config of Control-D resolver") @@ -1427,18 +1498,17 @@ func uninstall(p *prog, s service.Service) { } } -func validateConfig(cfg *ctrld.Config) { - err := ctrld.ValidateConfig(validator.New(), cfg) - if err == nil { - return - } - var ve validator.ValidationErrors - if errors.As(err, &ve) { - for _, fe := range ve { - mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe)) +func validateConfig(cfg *ctrld.Config) error { + if err := ctrld.ValidateConfig(validator.New(), cfg); err != nil { + var ve validator.ValidationErrors + if errors.As(err, &ve) { + for _, fe := range ve { + mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe)) + } } + return err } - os.Exit(1) + return nil } // NOTE: Add more case here once new validation tag is used in ctrld.Config struct. @@ -1509,7 +1579,16 @@ func mobileListenerPort() int { // updateListenerConfig updates the config for listeners if not defined, // or defined but invalid to be used, e.g: using loopback address other // than 127.0.0.1 with systemd-resolved. -func updateListenerConfig() (updated bool) { +func updateListenerConfig(cfg *ctrld.Config) bool { + updated, _ := tryUpdateListenerConfig(cfg, true) + return updated +} + +// tryUpdateListenerConfig tries updating listener config with a working one. +// If fatal is true, and there's listen address conflicted, the function do +// fatal error. +func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { + ok = true lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" nextdnsMode := nextdns != "" @@ -1622,7 +1701,11 @@ func updateListenerConfig() (updated bool) { break } if !check.IP && !check.Port { - logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err) + if fatal { + logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err) + } + ok = false + break } if tryAllPort53 { tryAllPort53 = false @@ -1683,12 +1766,19 @@ func updateListenerConfig() (updated bool) { listener.Port = oldPort } if listener.IP == oldIP && listener.Port == oldPort { - logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err) + if fatal { + logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err) + } + ok = false + break } logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, pick a random ip+port", addr) attempts++ } } + if !ok { + return + } // Specific case for systemd-resolved. if useSystemdResolved { diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 5f5ac51..80bc1ab 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -8,12 +8,15 @@ import ( "os" "sort" "time" + + "github.com/Control-D-Inc/ctrld" ) const ( contentTypeJson = "application/json" listClientsPath = "/clients" startedPath = "/started" + reloadPath = "/reload" ) type controlServer struct { @@ -75,6 +78,39 @@ func (p *prog) registerControlServerHandler() { w.WriteHeader(http.StatusRequestTimeout) } })) + p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + listeners := make(map[string]*ctrld.ListenerConfig) + p.mu.Lock() + for k, v := range p.cfg.Listener { + listeners[k] = &ctrld.ListenerConfig{ + IP: v.IP, + Port: v.Port, + } + } + p.mu.Unlock() + if err := p.sendReloadSignal(); err != nil { + mainLog.Load().Err(err).Msg("could not send reload signal") + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + select { + case <-p.reloadDoneCh: + case <-time.After(5 * time.Second): + http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError) + return + } + + p.mu.Lock() + defer p.mu.Unlock() + for k, v := range p.cfg.Listener { + l := listeners[k] + if l == nil || l.IP != v.IP || l.Port != v.Port { + w.WriteHeader(http.StatusCreated) + return + } + } + w.WriteHeader(http.StatusOK) + })) } func jsonResponse(next http.Handler) http.Handler { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index c5271a9..be0b731 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/hex" + "errors" "fmt" "net" "net/netip" @@ -37,7 +38,9 @@ var osUpstreamConfig = &ctrld.UpstreamConfig{ Timeout: 2000, } -func (p *prog) serveDNS(listenerNum string) error { +var errReload = errors.New("reload") + +func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) error { listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { @@ -78,6 +81,12 @@ func (p *prog) serveDNS(listenerNum string) error { }) g, ctx := errgroup.WithContext(context.Background()) + // When receiving reload signal, return a non-nil error so other + // goroutines in errgroup.Group could be terminated. + g.Go(func() error { + <-reloadCh + return errReload + }) for _, proto := range []string{"udp", "tcp"} { proto := proto if needLocalIPv6Listener() { @@ -121,11 +130,13 @@ func (p *prog) serveDNS(listenerNum string) error { addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) s, errCh := runDNSServer(addr, proto, handler) defer s.Shutdown() - select { - case err := <-errCh: - return err - case <-time.After(5 * time.Second): - p.started <- struct{}{} + if !reload { + select { + case err := <-errCh: + return err + case <-time.After(5 * time.Second): + p.started <- struct{}{} + } } select { case <-p.stopCh: @@ -136,7 +147,11 @@ func (p *prog) serveDNS(listenerNum string) error { return nil }) } - return g.Wait() + err := g.Wait() + if errors.Is(err, errReload) { // This is an error for trigger reload, not a real error. + return nil + } + return err } // upstreamFor returns the list of upstreams for resolving the given domain, diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index 87dabf8..5e6d911 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -79,13 +79,15 @@ func (p *prog) checkDnsLoop() { } // checkDnsLoopTicker performs p.checkDnsLoop every minute. -func (p *prog) checkDnsLoopTicker() { +func (p *prog) checkDnsLoopTicker(ctx context.Context) { timer := time.NewTicker(time.Minute) defer timer.Stop() for { select { case <-p.stopCh: return + case <-ctx.Done(): + return case <-timer.C: p.checkDnsLoop() } diff --git a/cmd/cli/netlink_linux.go b/cmd/cli/netlink_linux.go index 0faae84..d757f8b 100644 --- a/cmd/cli/netlink_linux.go +++ b/cmd/cli/netlink_linux.go @@ -1,11 +1,13 @@ package cli import ( + "context" + "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) -func (p *prog) watchLinkState() { +func (p *prog) watchLinkState(ctx context.Context) { ch := make(chan netlink.LinkUpdate) done := make(chan struct{}) defer close(done) @@ -13,14 +15,19 @@ func (p *prog) watchLinkState() { mainLog.Load().Warn().Err(err).Msg("could not subscribe link") return } - for lu := range ch { - if lu.Change == 0xFFFFFFFF { - continue - } - if lu.Change&unix.IFF_UP != 0 { - mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") - for _, uc := range p.cfg.Upstream { - uc.ReBootstrap() + for { + select { + case <-ctx.Done(): + return + case lu := <-ch: + if lu.Change == 0xFFFFFFFF { + continue + } + if lu.Change&unix.IFF_UP != 0 { + mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") + for _, uc := range p.cfg.Upstream { + uc.ReBootstrap() + } } } } diff --git a/cmd/cli/netlink_others.go b/cmd/cli/netlink_others.go index f0afd21..5a298b9 100644 --- a/cmd/cli/netlink_others.go +++ b/cmd/cli/netlink_others.go @@ -2,4 +2,6 @@ package cli -func (p *prog) watchLinkState() {} +import "context" + +func (p *prog) watchLinkState(ctx context.Context) {} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index e30a03d..a475a77 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -2,6 +2,7 @@ package cli import ( "bytes" + "context" "errors" "fmt" "math/rand" @@ -16,6 +17,7 @@ import ( "syscall" "github.com/kardianos/service" + "github.com/spf13/viper" "tailscale.com/net/interfaces" "github.com/Control-D-Inc/ctrld" @@ -45,11 +47,13 @@ var svcConfig = &service.Config{ var useSystemdResolved = false type prog struct { - mu sync.Mutex - waitCh chan struct{} - stopCh chan struct{} - logConn net.Conn - cs *controlServer + mu sync.Mutex + waitCh chan struct{} + stopCh chan struct{} + reloadCh chan struct{} // For Windows. + reloadDoneCh chan struct{} + logConn net.Conn + cs *controlServer cfg *ctrld.Config appCallback *AppCallback @@ -69,11 +73,90 @@ type prog struct { } func (p *prog) Start(s service.Service) error { - p.cfg = &cfg - go p.run() + go p.runWait() return nil } +// runWait runs ctrld components, waiting for signal to reload. +func (p *prog) runWait() { + p.mu.Lock() + p.cfg = &cfg + p.mu.Unlock() + reloadSigCh := make(chan os.Signal, 1) + notifyReloadSigCh(reloadSigCh) + + reload := false + logger := mainLog.Load() + for { + reloadCh := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(done) + p.run(reload, reloadCh) + reload = true + }() + select { + case sig := <-reloadSigCh: + logger.Notice().Msgf("got signal: %s, reloading...", sig.String()) + case <-p.reloadCh: + logger.Notice().Msg("reloading...") + case <-p.stopCh: + close(reloadCh) + return + } + + waitOldRunDone := func() { + close(reloadCh) + <-done + } + newCfg := &ctrld.Config{} + v := viper.NewWithOptions(viper.KeyDelimiter("::")) + ctrld.InitConfig(v, "ctrld") + if configPath != "" { + v.SetConfigFile(configPath) + } + if err := v.ReadInConfig(); err != nil { + logger.Err(err).Msg("could not read new config") + waitOldRunDone() + continue + } + if err := v.Unmarshal(&newCfg); err != nil { + logger.Err(err).Msg("could not unmarshal new config") + waitOldRunDone() + continue + } + if cdUID != "" { + if err := processCDFlags(newCfg); err != nil { + logger.Err(err).Msg("could not fetch ControlD config") + waitOldRunDone() + continue + } + } + + waitOldRunDone() + + _, ok := tryUpdateListenerConfig(newCfg, false) + if !ok { + logger.Error().Msg("could not update listener config") + continue + } + if err := validateConfig(newCfg); err != nil { + logger.Err(err).Msg("invalid config") + continue + } + + p.mu.Lock() + *p.cfg = *newCfg + p.mu.Unlock() + + logger.Notice().Msg("reloading config successfully") + select { + case p.reloadDoneCh <- struct{}{}: + default: + } + } +} + func (p *prog) preRun() { if !service.Interactive() { p.setDNS() @@ -87,7 +170,15 @@ func (p *prog) preRun() { } } -func (p *prog) run() { +// run runs the ctrld main components. +// +// The reload boolean indicates that the function is run when ctrld first start +// or when ctrld receive reloading signal. Platform specifics setup is only done +// on started, mean reload is "false". +// +// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded, +// so all listeners could be terminated and re-spawned again. +func (p *prog) run(reload bool, reloadCh chan struct{}) { // Wait the caller to signal that we can do our logic. <-p.waitCh p.preRun() @@ -146,19 +237,29 @@ func (p *prog) run() { format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) p.ciTable.AddLeaseFile(leaseFile, format) } + + // context for managing spawn goroutines. + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + // Newer versions of android and iOS denies permission which breaks connectivity. if !isMobile() { + wg.Add(1) go func() { + defer wg.Done() p.ciTable.Init() - p.ciTable.RefreshLoop(p.stopCh) + p.ciTable.RefreshLoop(ctx) }() - go p.watchLinkState() + go p.watchLinkState(ctx) } for listenerNum := range p.cfg.Listener { p.cfg.Listener[listenerNum].Init() go func(listenerNum string) { - defer wg.Done() + defer func() { + cancelFunc() + wg.Done() + }() listenerConfig := p.cfg.Listener[listenerNum] upstreamConfig := p.cfg.Upstream[listenerNum] if upstreamConfig == nil { @@ -166,35 +267,44 @@ func (p *prog) run() { } addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) - if err := p.serveDNS(listenerNum); err != nil { + if err := p.serveDNS(listenerNum, reload, reloadCh); err != nil { mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) } }(listenerNum) } - for i := 0; i < numListeners; i++ { - <-p.started - } - for _, f := range p.onStarted { - f() + if !reload { + for i := 0; i < numListeners; i++ { + <-p.started + } + for _, f := range p.onStarted { + f() + } } + // Check for possible DNS loop. p.checkDnsLoop() close(p.onStartedDone) // Start check DNS loop ticker. - go p.checkDnsLoopTicker() + wg.Add(1) + go func() { + defer wg.Done() + p.checkDnsLoopTicker(ctx) + }() - // Stop writing log to unix socket. - consoleWriter.Out = os.Stdout - initLoggingWithBackup(false) - if p.logConn != nil { - _ = p.logConn.Close() - } - if p.cs != nil { - p.registerControlServerHandler() - if err := p.cs.start(); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not start control server") + if !reload { + // Stop writing log to unix socket. + consoleWriter.Out = os.Stdout + initLoggingWithBackup(false) + if p.logConn != nil { + _ = p.logConn.Close() + } + if p.cs != nil { + p.registerControlServerHandler() + if err := p.cs.start(); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not start control server") + } } } wg.Wait() diff --git a/cmd/cli/reload_others.go b/cmd/cli/reload_others.go new file mode 100644 index 0000000..0977af9 --- /dev/null +++ b/cmd/cli/reload_others.go @@ -0,0 +1,17 @@ +//go:build !windows + +package cli + +import ( + "os" + "os/signal" + "syscall" +) + +func notifyReloadSigCh(ch chan os.Signal) { + signal.Notify(ch, syscall.SIGUSR1) +} + +func (p *prog) sendReloadSignal() error { + return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1) +} diff --git a/cmd/cli/reload_windows.go b/cmd/cli/reload_windows.go new file mode 100644 index 0000000..0e817e4 --- /dev/null +++ b/cmd/cli/reload_windows.go @@ -0,0 +1,18 @@ +package cli + +import ( + "errors" + "os" + "time" +) + +func notifyReloadSigCh(ch chan os.Signal) {} + +func (p *prog) sendReloadSignal() error { + select { + case p.reloadCh <- struct{}{}: + return nil + case <-time.After(5 * time.Second): + } + return errors.New("timeout while sending reload signal") +} diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 3e92fd1..6d6cbf9 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -1,6 +1,7 @@ package clientinfo import ( + "context" "fmt" "net/netip" "strings" @@ -75,7 +76,7 @@ type Table struct { mdns *mdns hf *hostsFile vni *virtualNetworkIface - cfg *ctrld.Config + svcCfg ctrld.ServiceConfig quitCh chan struct{} selfIP string cdUID string @@ -83,7 +84,7 @@ type Table struct { func NewTable(cfg *ctrld.Config, selfIP, cdUID string) *Table { return &Table{ - cfg: cfg, + svcCfg: cfg.Service, quitCh: make(chan struct{}), selfIP: selfIP, cdUID: cdUID, @@ -97,7 +98,7 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) { clientInfoFiles[name] = format } -func (t *Table) RefreshLoop(stopCh chan struct{}) { +func (t *Table) RefreshLoop(ctx context.Context) { timer := time.NewTicker(time.Minute * 5) defer timer.Stop() for { @@ -106,7 +107,7 @@ func (t *Table) RefreshLoop(stopCh chan struct{}) { for _, r := range t.refreshers { _ = r.refresh() } - case <-stopCh: + case <-ctx.Done(): close(t.quitCh) return } @@ -339,38 +340,38 @@ func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) { } func (t *Table) discoverDHCP() bool { - if t.cfg.Service.DiscoverDHCP == nil { + if t.svcCfg.DiscoverDHCP == nil { return true } - return *t.cfg.Service.DiscoverDHCP + return *t.svcCfg.DiscoverDHCP } func (t *Table) discoverARP() bool { - if t.cfg.Service.DiscoverARP == nil { + if t.svcCfg.DiscoverARP == nil { return true } - return *t.cfg.Service.DiscoverARP + return *t.svcCfg.DiscoverARP } func (t *Table) discoverMDNS() bool { - if t.cfg.Service.DiscoverMDNS == nil { + if t.svcCfg.DiscoverMDNS == nil { return true } - return *t.cfg.Service.DiscoverMDNS + return *t.svcCfg.DiscoverMDNS } func (t *Table) discoverPTR() bool { - if t.cfg.Service.DiscoverPtr == nil { + if t.svcCfg.DiscoverPtr == nil { return true } - return *t.cfg.Service.DiscoverPtr + return *t.svcCfg.DiscoverPtr } func (t *Table) discoverHosts() bool { - if t.cfg.Service.DiscoverHosts == nil { + if t.svcCfg.DiscoverHosts == nil { return true } - return *t.cfg.Service.DiscoverHosts + return *t.svcCfg.DiscoverHosts } // normalizeIP normalizes the ip parsed from dnsmasq/dhcpd lease file.