mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
all: implement reload command
This commit adds reload command to ctrld for re-fetch new config from ContorlD API or re-read the current config on disk.
This commit is contained in:
committed by
Cuong Manh Le
parent
712b23a4bb
commit
58a00ea24a
208
cmd/cli/cli.go
208
cmd/cli/cli.go
@@ -239,7 +239,7 @@ func initCLI() {
|
||||
if nextdns != "" {
|
||||
removeNextDNSFromArgs(sc)
|
||||
generateNextDNSConfig()
|
||||
updateListenerConfig()
|
||||
updateListenerConfig(&cfg)
|
||||
if err := writeConfigFile(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to write config with NextDNS resolver")
|
||||
}
|
||||
@@ -383,6 +383,10 @@ func initCLI() {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("service not installed")
|
||||
return
|
||||
}
|
||||
initLogging()
|
||||
|
||||
tasks := []task{
|
||||
@@ -404,6 +408,50 @@ func initCLI() {
|
||||
},
|
||||
}
|
||||
|
||||
reloadCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "reload",
|
||||
Short: "Reload the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
dir, err := userHomeDir()
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir")
|
||||
}
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
resp, err := cc.post(reloadPath, nil)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
mainLog.Load().Notice().Msg("Service reloaded")
|
||||
case http.StatusCreated:
|
||||
s, err := newService(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.")
|
||||
mainLog.Load().Warn().Msg("Restarting service")
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return
|
||||
}
|
||||
restartCmd.Run(cmd, args)
|
||||
default:
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("could not read response from control server")
|
||||
}
|
||||
mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf))
|
||||
}
|
||||
},
|
||||
}
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
@@ -519,9 +567,10 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
Short: "Manage ctrld service",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
statusCmd.Use,
|
||||
startCmd.Use,
|
||||
stopCmd.Use,
|
||||
restartCmd.Use,
|
||||
reloadCmd.Use,
|
||||
statusCmd.Use,
|
||||
uninstallCmd.Use,
|
||||
interfacesCmd.Use,
|
||||
@@ -530,6 +579,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
serviceCmd.AddCommand(startCmd)
|
||||
serviceCmd.AddCommand(stopCmd)
|
||||
serviceCmd.AddCommand(restartCmd)
|
||||
serviceCmd.AddCommand(reloadCmd)
|
||||
serviceCmd.AddCommand(statusCmd)
|
||||
serviceCmd.AddCommand(uninstallCmd)
|
||||
serviceCmd.AddCommand(interfacesCmd)
|
||||
@@ -584,6 +634,19 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
}
|
||||
rootCmd.AddCommand(restartCmdAlias)
|
||||
|
||||
reloadCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "reload",
|
||||
Short: "Reload the ctrld service",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
reloadCmd.Run(cmd, args)
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(reloadCmdAlias)
|
||||
|
||||
statusCmdAlias := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
@@ -716,10 +779,12 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
}
|
||||
waitCh := make(chan struct{})
|
||||
p := &prog{
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
cfg: &cfg,
|
||||
appCallback: appCallback,
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
reloadCh: make(chan struct{}),
|
||||
reloadDoneCh: make(chan struct{}),
|
||||
cfg: &cfg,
|
||||
appCallback: appCallback,
|
||||
}
|
||||
if homedir == "" {
|
||||
if dir, err := userHomeDir(); err == nil {
|
||||
@@ -757,9 +822,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
|
||||
readBase64Config(configBase64)
|
||||
processNoConfigFlags(noConfigStart)
|
||||
p.mu.Lock()
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
processLogAndCacheFlags()
|
||||
|
||||
@@ -795,14 +862,46 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
}
|
||||
if cdUID != "" {
|
||||
validateCdUpstreamProtocol()
|
||||
err := processCDFlags()
|
||||
if err != nil {
|
||||
appCallback.Exit(err.Error())
|
||||
return
|
||||
if err := processCDFlags(&cfg); err != nil {
|
||||
if isMobile() {
|
||||
appCallback.Exit(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
uninstallIfInvalidCdUID := func() {
|
||||
cdLogger := mainLog.Load().With().Str("mode", "cd").Logger()
|
||||
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
s, err := newService(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
cdLogger.Warn().Err(err).Msg("failed to create new service")
|
||||
return
|
||||
}
|
||||
if netIface, _ := netInterface(iface); netIface != nil {
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
cdLogger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return
|
||||
}
|
||||
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
cdLogger.Warn().Err(err).Msg("something went wrong while restoring DNS")
|
||||
} else {
|
||||
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
|
||||
}
|
||||
}
|
||||
|
||||
tasks := []task{{s.Uninstall, true}}
|
||||
if doTasks(tasks) {
|
||||
cdLogger.Info().Msg("uninstalled service")
|
||||
}
|
||||
cdLogger.Fatal().Err(uer).Msg("failed to fetch resolver config")
|
||||
return
|
||||
}
|
||||
}
|
||||
uninstallIfInvalidCdUID()
|
||||
}
|
||||
}
|
||||
|
||||
updated := updateListenerConfig()
|
||||
updated := updateListenerConfig(&cfg)
|
||||
|
||||
if cdUID != "" {
|
||||
processLogAndCacheFlags()
|
||||
@@ -830,7 +929,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
initLoggingWithBackup(false)
|
||||
}
|
||||
|
||||
validateConfig(&cfg)
|
||||
if err := validateConfig(&cfg); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
initCache()
|
||||
|
||||
if daemon {
|
||||
@@ -943,7 +1044,7 @@ func readConfigFile(writeDefaultConfig bool) bool {
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
|
||||
}
|
||||
_ = updateListenerConfig()
|
||||
_ = updateListenerConfig(&cfg)
|
||||
if err := writeConfigFile(); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
|
||||
} else {
|
||||
@@ -1033,7 +1134,7 @@ func processNoConfigFlags(noConfigStart bool) {
|
||||
v.Set("upstream", upstream)
|
||||
}
|
||||
|
||||
func processCDFlags() error {
|
||||
func processCDFlags(cfg *ctrld.Config) error {
|
||||
logger := mainLog.Load().With().Str("mode", "cd").Logger()
|
||||
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
|
||||
bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second)
|
||||
@@ -1049,47 +1150,17 @@ func processCDFlags() error {
|
||||
}
|
||||
break
|
||||
}
|
||||
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
s, err := newService(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("failed to create new service")
|
||||
return nil
|
||||
}
|
||||
if netIface, _ := netInterface(iface); netIface != nil {
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return nil
|
||||
}
|
||||
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
logger.Warn().Err(err).Msg("something went wrong while restoring DNS")
|
||||
} else {
|
||||
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
|
||||
}
|
||||
}
|
||||
|
||||
tasks := []task{{s.Uninstall, true}}
|
||||
if doTasks(tasks) {
|
||||
logger.Info().Msg("uninstalled service")
|
||||
}
|
||||
event := logger.Fatal()
|
||||
if isMobile() {
|
||||
event = logger.Warn()
|
||||
}
|
||||
event.Err(uer).Msg("failed to fetch resolver config")
|
||||
return uer
|
||||
}
|
||||
if err != nil {
|
||||
if isMobile() {
|
||||
return errors.New("could not fetch resolver config")
|
||||
}
|
||||
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info().Msg("generating ctrld config from Control-D configuration")
|
||||
cfg = ctrld.Config{}
|
||||
|
||||
*cfg = ctrld.Config{}
|
||||
// Fetch config, unmarshal to cfg.
|
||||
if resolverConfig.Ctrld.CustomConfig != "" {
|
||||
logger.Info().Msg("using defined custom config of Control-D resolver")
|
||||
@@ -1427,18 +1498,17 @@ func uninstall(p *prog, s service.Service) {
|
||||
}
|
||||
}
|
||||
|
||||
func validateConfig(cfg *ctrld.Config) {
|
||||
err := ctrld.ValidateConfig(validator.New(), cfg)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
var ve validator.ValidationErrors
|
||||
if errors.As(err, &ve) {
|
||||
for _, fe := range ve {
|
||||
mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe))
|
||||
func validateConfig(cfg *ctrld.Config) error {
|
||||
if err := ctrld.ValidateConfig(validator.New(), cfg); err != nil {
|
||||
var ve validator.ValidationErrors
|
||||
if errors.As(err, &ve) {
|
||||
for _, fe := range ve {
|
||||
mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe))
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
os.Exit(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// NOTE: Add more case here once new validation tag is used in ctrld.Config struct.
|
||||
@@ -1509,7 +1579,16 @@ func mobileListenerPort() int {
|
||||
// updateListenerConfig updates the config for listeners if not defined,
|
||||
// or defined but invalid to be used, e.g: using loopback address other
|
||||
// than 127.0.0.1 with systemd-resolved.
|
||||
func updateListenerConfig() (updated bool) {
|
||||
func updateListenerConfig(cfg *ctrld.Config) bool {
|
||||
updated, _ := tryUpdateListenerConfig(cfg, true)
|
||||
return updated
|
||||
}
|
||||
|
||||
// tryUpdateListenerConfig tries updating listener config with a working one.
|
||||
// If fatal is true, and there's listen address conflicted, the function do
|
||||
// fatal error.
|
||||
func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) {
|
||||
ok = true
|
||||
lcc := make(map[string]*listenerConfigCheck)
|
||||
cdMode := cdUID != ""
|
||||
nextdnsMode := nextdns != ""
|
||||
@@ -1622,7 +1701,11 @@ func updateListenerConfig() (updated bool) {
|
||||
break
|
||||
}
|
||||
if !check.IP && !check.Port {
|
||||
logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err)
|
||||
if fatal {
|
||||
logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err)
|
||||
}
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
if tryAllPort53 {
|
||||
tryAllPort53 = false
|
||||
@@ -1683,12 +1766,19 @@ func updateListenerConfig() (updated bool) {
|
||||
listener.Port = oldPort
|
||||
}
|
||||
if listener.IP == oldIP && listener.Port == oldPort {
|
||||
logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
|
||||
if fatal {
|
||||
logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
|
||||
}
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, pick a random ip+port", addr)
|
||||
attempts++
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Specific case for systemd-resolved.
|
||||
if useSystemdResolved {
|
||||
|
||||
@@ -8,12 +8,15 @@ import (
|
||||
"os"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
contentTypeJson = "application/json"
|
||||
listClientsPath = "/clients"
|
||||
startedPath = "/started"
|
||||
reloadPath = "/reload"
|
||||
)
|
||||
|
||||
type controlServer struct {
|
||||
@@ -75,6 +78,39 @@ func (p *prog) registerControlServerHandler() {
|
||||
w.WriteHeader(http.StatusRequestTimeout)
|
||||
}
|
||||
}))
|
||||
p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
listeners := make(map[string]*ctrld.ListenerConfig)
|
||||
p.mu.Lock()
|
||||
for k, v := range p.cfg.Listener {
|
||||
listeners[k] = &ctrld.ListenerConfig{
|
||||
IP: v.IP,
|
||||
Port: v.Port,
|
||||
}
|
||||
}
|
||||
p.mu.Unlock()
|
||||
if err := p.sendReloadSignal(); err != nil {
|
||||
mainLog.Load().Err(err).Msg("could not send reload signal")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.reloadDoneCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
for k, v := range p.cfg.Listener {
|
||||
l := listeners[k]
|
||||
if l == nil || l.IP != v.IP || l.Port != v.Port {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
}
|
||||
|
||||
func jsonResponse(next http.Handler) http.Handler {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -37,7 +38,9 @@ var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
func (p *prog) serveDNS(listenerNum string) error {
|
||||
var errReload = errors.New("reload")
|
||||
|
||||
func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) error {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
// make sure ip is allocated
|
||||
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
|
||||
@@ -78,6 +81,12 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
})
|
||||
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
// When receiving reload signal, return a non-nil error so other
|
||||
// goroutines in errgroup.Group could be terminated.
|
||||
g.Go(func() error {
|
||||
<-reloadCh
|
||||
return errReload
|
||||
})
|
||||
for _, proto := range []string{"udp", "tcp"} {
|
||||
proto := proto
|
||||
if needLocalIPv6Listener() {
|
||||
@@ -121,11 +130,13 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
s, errCh := runDNSServer(addr, proto, handler)
|
||||
defer s.Shutdown()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-time.After(5 * time.Second):
|
||||
p.started <- struct{}{}
|
||||
if !reload {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-time.After(5 * time.Second):
|
||||
p.started <- struct{}{}
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
@@ -136,7 +147,11 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
err := g.Wait()
|
||||
if errors.Is(err, errReload) { // This is an error for trigger reload, not a real error.
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// upstreamFor returns the list of upstreams for resolving the given domain,
|
||||
|
||||
@@ -79,13 +79,15 @@ func (p *prog) checkDnsLoop() {
|
||||
}
|
||||
|
||||
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
|
||||
func (p *prog) checkDnsLoopTicker() {
|
||||
func (p *prog) checkDnsLoopTicker(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Minute)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
p.checkDnsLoop()
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (p *prog) watchLinkState() {
|
||||
func (p *prog) watchLinkState(ctx context.Context) {
|
||||
ch := make(chan netlink.LinkUpdate)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
@@ -13,14 +15,19 @@ func (p *prog) watchLinkState() {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not subscribe link")
|
||||
return
|
||||
}
|
||||
for lu := range ch {
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case lu := <-ch:
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,4 +2,6 @@
|
||||
|
||||
package cli
|
||||
|
||||
func (p *prog) watchLinkState() {}
|
||||
import "context"
|
||||
|
||||
func (p *prog) watchLinkState(ctx context.Context) {}
|
||||
|
||||
166
cmd/cli/prog.go
166
cmd/cli/prog.go
@@ -2,6 +2,7 @@ package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/viper"
|
||||
"tailscale.com/net/interfaces"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
@@ -45,11 +47,13 @@ var svcConfig = &service.Config{
|
||||
var useSystemdResolved = false
|
||||
|
||||
type prog struct {
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
reloadCh chan struct{} // For Windows.
|
||||
reloadDoneCh chan struct{}
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
|
||||
cfg *ctrld.Config
|
||||
appCallback *AppCallback
|
||||
@@ -69,11 +73,90 @@ type prog struct {
|
||||
}
|
||||
|
||||
func (p *prog) Start(s service.Service) error {
|
||||
p.cfg = &cfg
|
||||
go p.run()
|
||||
go p.runWait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// runWait runs ctrld components, waiting for signal to reload.
|
||||
func (p *prog) runWait() {
|
||||
p.mu.Lock()
|
||||
p.cfg = &cfg
|
||||
p.mu.Unlock()
|
||||
reloadSigCh := make(chan os.Signal, 1)
|
||||
notifyReloadSigCh(reloadSigCh)
|
||||
|
||||
reload := false
|
||||
logger := mainLog.Load()
|
||||
for {
|
||||
reloadCh := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
p.run(reload, reloadCh)
|
||||
reload = true
|
||||
}()
|
||||
select {
|
||||
case sig := <-reloadSigCh:
|
||||
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
|
||||
case <-p.reloadCh:
|
||||
logger.Notice().Msg("reloading...")
|
||||
case <-p.stopCh:
|
||||
close(reloadCh)
|
||||
return
|
||||
}
|
||||
|
||||
waitOldRunDone := func() {
|
||||
close(reloadCh)
|
||||
<-done
|
||||
}
|
||||
newCfg := &ctrld.Config{}
|
||||
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
logger.Err(err).Msg("could not read new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if err := v.Unmarshal(&newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not unmarshal new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if cdUID != "" {
|
||||
if err := processCDFlags(newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not fetch ControlD config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
waitOldRunDone()
|
||||
|
||||
_, ok := tryUpdateListenerConfig(newCfg, false)
|
||||
if !ok {
|
||||
logger.Error().Msg("could not update listener config")
|
||||
continue
|
||||
}
|
||||
if err := validateConfig(newCfg); err != nil {
|
||||
logger.Err(err).Msg("invalid config")
|
||||
continue
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
*p.cfg = *newCfg
|
||||
p.mu.Unlock()
|
||||
|
||||
logger.Notice().Msg("reloading config successfully")
|
||||
select {
|
||||
case p.reloadDoneCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prog) preRun() {
|
||||
if !service.Interactive() {
|
||||
p.setDNS()
|
||||
@@ -87,7 +170,15 @@ func (p *prog) preRun() {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prog) run() {
|
||||
// run runs the ctrld main components.
|
||||
//
|
||||
// The reload boolean indicates that the function is run when ctrld first start
|
||||
// or when ctrld receive reloading signal. Platform specifics setup is only done
|
||||
// on started, mean reload is "false".
|
||||
//
|
||||
// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded,
|
||||
// so all listeners could be terminated and re-spawned again.
|
||||
func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
// Wait the caller to signal that we can do our logic.
|
||||
<-p.waitCh
|
||||
p.preRun()
|
||||
@@ -146,19 +237,29 @@ func (p *prog) run() {
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
p.ciTable.AddLeaseFile(leaseFile, format)
|
||||
}
|
||||
|
||||
// context for managing spawn goroutines.
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
// Newer versions of android and iOS denies permission which breaks connectivity.
|
||||
if !isMobile() {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
p.ciTable.Init()
|
||||
p.ciTable.RefreshLoop(p.stopCh)
|
||||
p.ciTable.RefreshLoop(ctx)
|
||||
}()
|
||||
go p.watchLinkState()
|
||||
go p.watchLinkState(ctx)
|
||||
}
|
||||
|
||||
for listenerNum := range p.cfg.Listener {
|
||||
p.cfg.Listener[listenerNum].Init()
|
||||
go func(listenerNum string) {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
cancelFunc()
|
||||
wg.Done()
|
||||
}()
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
upstreamConfig := p.cfg.Upstream[listenerNum]
|
||||
if upstreamConfig == nil {
|
||||
@@ -166,35 +267,44 @@ func (p *prog) run() {
|
||||
}
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
|
||||
if err := p.serveDNS(listenerNum); err != nil {
|
||||
if err := p.serveDNS(listenerNum, reload, reloadCh); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
||||
}
|
||||
}(listenerNum)
|
||||
}
|
||||
|
||||
for i := 0; i < numListeners; i++ {
|
||||
<-p.started
|
||||
}
|
||||
for _, f := range p.onStarted {
|
||||
f()
|
||||
if !reload {
|
||||
for i := 0; i < numListeners; i++ {
|
||||
<-p.started
|
||||
}
|
||||
for _, f := range p.onStarted {
|
||||
f()
|
||||
}
|
||||
}
|
||||
|
||||
// Check for possible DNS loop.
|
||||
p.checkDnsLoop()
|
||||
close(p.onStartedDone)
|
||||
|
||||
// Start check DNS loop ticker.
|
||||
go p.checkDnsLoopTicker()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
p.checkDnsLoopTicker(ctx)
|
||||
}()
|
||||
|
||||
// Stop writing log to unix socket.
|
||||
consoleWriter.Out = os.Stdout
|
||||
initLoggingWithBackup(false)
|
||||
if p.logConn != nil {
|
||||
_ = p.logConn.Close()
|
||||
}
|
||||
if p.cs != nil {
|
||||
p.registerControlServerHandler()
|
||||
if err := p.cs.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
||||
if !reload {
|
||||
// Stop writing log to unix socket.
|
||||
consoleWriter.Out = os.Stdout
|
||||
initLoggingWithBackup(false)
|
||||
if p.logConn != nil {
|
||||
_ = p.logConn.Close()
|
||||
}
|
||||
if p.cs != nil {
|
||||
p.registerControlServerHandler()
|
||||
if err := p.cs.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
||||
}
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
17
cmd/cli/reload_others.go
Normal file
17
cmd/cli/reload_others.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func notifyReloadSigCh(ch chan os.Signal) {
|
||||
signal.Notify(ch, syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
|
||||
}
|
||||
18
cmd/cli/reload_windows.go
Normal file
18
cmd/cli/reload_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func notifyReloadSigCh(ch chan os.Signal) {}
|
||||
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
select {
|
||||
case p.reloadCh <- struct{}{}:
|
||||
return nil
|
||||
case <-time.After(5 * time.Second):
|
||||
}
|
||||
return errors.New("timeout while sending reload signal")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
@@ -75,7 +76,7 @@ type Table struct {
|
||||
mdns *mdns
|
||||
hf *hostsFile
|
||||
vni *virtualNetworkIface
|
||||
cfg *ctrld.Config
|
||||
svcCfg ctrld.ServiceConfig
|
||||
quitCh chan struct{}
|
||||
selfIP string
|
||||
cdUID string
|
||||
@@ -83,7 +84,7 @@ type Table struct {
|
||||
|
||||
func NewTable(cfg *ctrld.Config, selfIP, cdUID string) *Table {
|
||||
return &Table{
|
||||
cfg: cfg,
|
||||
svcCfg: cfg.Service,
|
||||
quitCh: make(chan struct{}),
|
||||
selfIP: selfIP,
|
||||
cdUID: cdUID,
|
||||
@@ -97,7 +98,7 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) {
|
||||
clientInfoFiles[name] = format
|
||||
}
|
||||
|
||||
func (t *Table) RefreshLoop(stopCh chan struct{}) {
|
||||
func (t *Table) RefreshLoop(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Minute * 5)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
@@ -106,7 +107,7 @@ func (t *Table) RefreshLoop(stopCh chan struct{}) {
|
||||
for _, r := range t.refreshers {
|
||||
_ = r.refresh()
|
||||
}
|
||||
case <-stopCh:
|
||||
case <-ctx.Done():
|
||||
close(t.quitCh)
|
||||
return
|
||||
}
|
||||
@@ -339,38 +340,38 @@ func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) {
|
||||
}
|
||||
|
||||
func (t *Table) discoverDHCP() bool {
|
||||
if t.cfg.Service.DiscoverDHCP == nil {
|
||||
if t.svcCfg.DiscoverDHCP == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverDHCP
|
||||
return *t.svcCfg.DiscoverDHCP
|
||||
}
|
||||
|
||||
func (t *Table) discoverARP() bool {
|
||||
if t.cfg.Service.DiscoverARP == nil {
|
||||
if t.svcCfg.DiscoverARP == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverARP
|
||||
return *t.svcCfg.DiscoverARP
|
||||
}
|
||||
|
||||
func (t *Table) discoverMDNS() bool {
|
||||
if t.cfg.Service.DiscoverMDNS == nil {
|
||||
if t.svcCfg.DiscoverMDNS == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverMDNS
|
||||
return *t.svcCfg.DiscoverMDNS
|
||||
}
|
||||
|
||||
func (t *Table) discoverPTR() bool {
|
||||
if t.cfg.Service.DiscoverPtr == nil {
|
||||
if t.svcCfg.DiscoverPtr == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverPtr
|
||||
return *t.svcCfg.DiscoverPtr
|
||||
}
|
||||
|
||||
func (t *Table) discoverHosts() bool {
|
||||
if t.cfg.Service.DiscoverHosts == nil {
|
||||
if t.svcCfg.DiscoverHosts == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverHosts
|
||||
return *t.svcCfg.DiscoverHosts
|
||||
}
|
||||
|
||||
// normalizeIP normalizes the ip parsed from dnsmasq/dhcpd lease file.
|
||||
|
||||
Reference in New Issue
Block a user