all: implement reload command

This commit adds reload command to ctrld for re-fetch new config from
ContorlD API or re-read the current config on disk.
This commit is contained in:
Cuong Manh Le
2023-10-24 00:22:03 +07:00
committed by Cuong Manh Le
parent 712b23a4bb
commit 58a00ea24a
10 changed files with 417 additions and 119 deletions

View File

@@ -239,7 +239,7 @@ func initCLI() {
if nextdns != "" {
removeNextDNSFromArgs(sc)
generateNextDNSConfig()
updateListenerConfig()
updateListenerConfig(&cfg)
if err := writeConfigFile(); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to write config with NextDNS resolver")
}
@@ -383,6 +383,10 @@ func initCLI() {
mainLog.Load().Error().Msg(err.Error())
return
}
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Warn().Msg("service not installed")
return
}
initLogging()
tasks := []task{
@@ -404,6 +408,50 @@ func initCLI() {
},
}
reloadCmd := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "reload",
Short: "Reload the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
dir, err := userHomeDir()
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir")
}
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
resp, err := cc.post(reloadPath, nil)
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld")
}
defer resp.Body.Close()
switch resp.StatusCode {
case http.StatusOK:
mainLog.Load().Notice().Msg("Service reloaded")
case http.StatusCreated:
s, err := newService(&prog{}, svcConfig)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return
}
mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.")
mainLog.Load().Warn().Msg("Restarting service")
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Warn().Msg("Service not installed")
return
}
restartCmd.Run(cmd, args)
default:
buf, err := io.ReadAll(resp.Body)
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("could not read response from control server")
}
mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf))
}
},
}
statusCmd := &cobra.Command{
Use: "status",
Short: "Show status of the ctrld service",
@@ -519,9 +567,10 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
Short: "Manage ctrld service",
Args: cobra.OnlyValidArgs,
ValidArgs: []string{
statusCmd.Use,
startCmd.Use,
stopCmd.Use,
restartCmd.Use,
reloadCmd.Use,
statusCmd.Use,
uninstallCmd.Use,
interfacesCmd.Use,
@@ -530,6 +579,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
serviceCmd.AddCommand(startCmd)
serviceCmd.AddCommand(stopCmd)
serviceCmd.AddCommand(restartCmd)
serviceCmd.AddCommand(reloadCmd)
serviceCmd.AddCommand(statusCmd)
serviceCmd.AddCommand(uninstallCmd)
serviceCmd.AddCommand(interfacesCmd)
@@ -584,6 +634,19 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
}
rootCmd.AddCommand(restartCmdAlias)
reloadCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "reload",
Short: "Reload the ctrld service",
Run: func(cmd *cobra.Command, args []string) {
reloadCmd.Run(cmd, args)
},
}
rootCmd.AddCommand(reloadCmdAlias)
statusCmdAlias := &cobra.Command{
Use: "status",
Short: "Show status of the ctrld service",
@@ -716,10 +779,12 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
}
waitCh := make(chan struct{})
p := &prog{
waitCh: waitCh,
stopCh: stopCh,
cfg: &cfg,
appCallback: appCallback,
waitCh: waitCh,
stopCh: stopCh,
reloadCh: make(chan struct{}),
reloadDoneCh: make(chan struct{}),
cfg: &cfg,
appCallback: appCallback,
}
if homedir == "" {
if dir, err := userHomeDir(); err == nil {
@@ -757,9 +822,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
readBase64Config(configBase64)
processNoConfigFlags(noConfigStart)
p.mu.Lock()
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
}
p.mu.Unlock()
processLogAndCacheFlags()
@@ -795,14 +862,46 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
}
if cdUID != "" {
validateCdUpstreamProtocol()
err := processCDFlags()
if err != nil {
appCallback.Exit(err.Error())
return
if err := processCDFlags(&cfg); err != nil {
if isMobile() {
appCallback.Exit(err.Error())
return
}
uninstallIfInvalidCdUID := func() {
cdLogger := mainLog.Load().With().Str("mode", "cd").Logger()
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
s, err := newService(&prog{}, svcConfig)
if err != nil {
cdLogger.Warn().Err(err).Msg("failed to create new service")
return
}
if netIface, _ := netInterface(iface); netIface != nil {
if err := restoreNetworkManager(); err != nil {
cdLogger.Error().Err(err).Msg("could not restore NetworkManager")
return
}
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
if err := resetDNS(netIface); err != nil {
cdLogger.Warn().Err(err).Msg("something went wrong while restoring DNS")
} else {
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
}
}
tasks := []task{{s.Uninstall, true}}
if doTasks(tasks) {
cdLogger.Info().Msg("uninstalled service")
}
cdLogger.Fatal().Err(uer).Msg("failed to fetch resolver config")
return
}
}
uninstallIfInvalidCdUID()
}
}
updated := updateListenerConfig()
updated := updateListenerConfig(&cfg)
if cdUID != "" {
processLogAndCacheFlags()
@@ -830,7 +929,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
initLoggingWithBackup(false)
}
validateConfig(&cfg)
if err := validateConfig(&cfg); err != nil {
os.Exit(1)
}
initCache()
if daemon {
@@ -943,7 +1044,7 @@ func readConfigFile(writeDefaultConfig bool) bool {
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
}
_ = updateListenerConfig()
_ = updateListenerConfig(&cfg)
if err := writeConfigFile(); err != nil {
mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
} else {
@@ -1033,7 +1134,7 @@ func processNoConfigFlags(noConfigStart bool) {
v.Set("upstream", upstream)
}
func processCDFlags() error {
func processCDFlags(cfg *ctrld.Config) error {
logger := mainLog.Load().With().Str("mode", "cd").Logger()
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second)
@@ -1049,47 +1150,17 @@ func processCDFlags() error {
}
break
}
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
s, err := newService(&prog{}, svcConfig)
if err != nil {
logger.Warn().Err(err).Msg("failed to create new service")
return nil
}
if netIface, _ := netInterface(iface); netIface != nil {
if err := restoreNetworkManager(); err != nil {
logger.Error().Err(err).Msg("could not restore NetworkManager")
return nil
}
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
if err := resetDNS(netIface); err != nil {
logger.Warn().Err(err).Msg("something went wrong while restoring DNS")
} else {
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
}
}
tasks := []task{{s.Uninstall, true}}
if doTasks(tasks) {
logger.Info().Msg("uninstalled service")
}
event := logger.Fatal()
if isMobile() {
event = logger.Warn()
}
event.Err(uer).Msg("failed to fetch resolver config")
return uer
}
if err != nil {
if isMobile() {
return errors.New("could not fetch resolver config")
}
logger.Warn().Err(err).Msg("could not fetch resolver config")
return nil
return err
}
logger.Info().Msg("generating ctrld config from Control-D configuration")
cfg = ctrld.Config{}
*cfg = ctrld.Config{}
// Fetch config, unmarshal to cfg.
if resolverConfig.Ctrld.CustomConfig != "" {
logger.Info().Msg("using defined custom config of Control-D resolver")
@@ -1427,18 +1498,17 @@ func uninstall(p *prog, s service.Service) {
}
}
func validateConfig(cfg *ctrld.Config) {
err := ctrld.ValidateConfig(validator.New(), cfg)
if err == nil {
return
}
var ve validator.ValidationErrors
if errors.As(err, &ve) {
for _, fe := range ve {
mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe))
func validateConfig(cfg *ctrld.Config) error {
if err := ctrld.ValidateConfig(validator.New(), cfg); err != nil {
var ve validator.ValidationErrors
if errors.As(err, &ve) {
for _, fe := range ve {
mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe))
}
}
return err
}
os.Exit(1)
return nil
}
// NOTE: Add more case here once new validation tag is used in ctrld.Config struct.
@@ -1509,7 +1579,16 @@ func mobileListenerPort() int {
// updateListenerConfig updates the config for listeners if not defined,
// or defined but invalid to be used, e.g: using loopback address other
// than 127.0.0.1 with systemd-resolved.
func updateListenerConfig() (updated bool) {
func updateListenerConfig(cfg *ctrld.Config) bool {
updated, _ := tryUpdateListenerConfig(cfg, true)
return updated
}
// tryUpdateListenerConfig tries updating listener config with a working one.
// If fatal is true, and there's listen address conflicted, the function do
// fatal error.
func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) {
ok = true
lcc := make(map[string]*listenerConfigCheck)
cdMode := cdUID != ""
nextdnsMode := nextdns != ""
@@ -1622,7 +1701,11 @@ func updateListenerConfig() (updated bool) {
break
}
if !check.IP && !check.Port {
logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err)
if fatal {
logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err)
}
ok = false
break
}
if tryAllPort53 {
tryAllPort53 = false
@@ -1683,12 +1766,19 @@ func updateListenerConfig() (updated bool) {
listener.Port = oldPort
}
if listener.IP == oldIP && listener.Port == oldPort {
logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
if fatal {
logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
}
ok = false
break
}
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, pick a random ip+port", addr)
attempts++
}
}
if !ok {
return
}
// Specific case for systemd-resolved.
if useSystemdResolved {