all: implement reload command

This commit adds reload command to ctrld for re-fetch new config from
ContorlD API or re-read the current config on disk.
This commit is contained in:
Cuong Manh Le
2023-10-24 00:22:03 +07:00
committed by Cuong Manh Le
parent 712b23a4bb
commit 58a00ea24a
10 changed files with 417 additions and 119 deletions
+138 -28
View File
@@ -2,6 +2,7 @@ package cli
import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
@@ -16,6 +17,7 @@ import (
"syscall"
"github.com/kardianos/service"
"github.com/spf13/viper"
"tailscale.com/net/interfaces"
"github.com/Control-D-Inc/ctrld"
@@ -45,11 +47,13 @@ var svcConfig = &service.Config{
var useSystemdResolved = false
type prog struct {
mu sync.Mutex
waitCh chan struct{}
stopCh chan struct{}
logConn net.Conn
cs *controlServer
mu sync.Mutex
waitCh chan struct{}
stopCh chan struct{}
reloadCh chan struct{} // For Windows.
reloadDoneCh chan struct{}
logConn net.Conn
cs *controlServer
cfg *ctrld.Config
appCallback *AppCallback
@@ -69,11 +73,90 @@ type prog struct {
}
func (p *prog) Start(s service.Service) error {
p.cfg = &cfg
go p.run()
go p.runWait()
return nil
}
// runWait runs ctrld components, waiting for signal to reload.
func (p *prog) runWait() {
p.mu.Lock()
p.cfg = &cfg
p.mu.Unlock()
reloadSigCh := make(chan os.Signal, 1)
notifyReloadSigCh(reloadSigCh)
reload := false
logger := mainLog.Load()
for {
reloadCh := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
p.run(reload, reloadCh)
reload = true
}()
select {
case sig := <-reloadSigCh:
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
case <-p.reloadCh:
logger.Notice().Msg("reloading...")
case <-p.stopCh:
close(reloadCh)
return
}
waitOldRunDone := func() {
close(reloadCh)
<-done
}
newCfg := &ctrld.Config{}
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
ctrld.InitConfig(v, "ctrld")
if configPath != "" {
v.SetConfigFile(configPath)
}
if err := v.ReadInConfig(); err != nil {
logger.Err(err).Msg("could not read new config")
waitOldRunDone()
continue
}
if err := v.Unmarshal(&newCfg); err != nil {
logger.Err(err).Msg("could not unmarshal new config")
waitOldRunDone()
continue
}
if cdUID != "" {
if err := processCDFlags(newCfg); err != nil {
logger.Err(err).Msg("could not fetch ControlD config")
waitOldRunDone()
continue
}
}
waitOldRunDone()
_, ok := tryUpdateListenerConfig(newCfg, false)
if !ok {
logger.Error().Msg("could not update listener config")
continue
}
if err := validateConfig(newCfg); err != nil {
logger.Err(err).Msg("invalid config")
continue
}
p.mu.Lock()
*p.cfg = *newCfg
p.mu.Unlock()
logger.Notice().Msg("reloading config successfully")
select {
case p.reloadDoneCh <- struct{}{}:
default:
}
}
}
func (p *prog) preRun() {
if !service.Interactive() {
p.setDNS()
@@ -87,7 +170,15 @@ func (p *prog) preRun() {
}
}
func (p *prog) run() {
// run runs the ctrld main components.
//
// The reload boolean indicates that the function is run when ctrld first start
// or when ctrld receive reloading signal. Platform specifics setup is only done
// on started, mean reload is "false".
//
// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded,
// so all listeners could be terminated and re-spawned again.
func (p *prog) run(reload bool, reloadCh chan struct{}) {
// Wait the caller to signal that we can do our logic.
<-p.waitCh
p.preRun()
@@ -146,19 +237,29 @@ func (p *prog) run() {
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
}
// context for managing spawn goroutines.
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
// Newer versions of android and iOS denies permission which breaks connectivity.
if !isMobile() {
wg.Add(1)
go func() {
defer wg.Done()
p.ciTable.Init()
p.ciTable.RefreshLoop(p.stopCh)
p.ciTable.RefreshLoop(ctx)
}()
go p.watchLinkState()
go p.watchLinkState(ctx)
}
for listenerNum := range p.cfg.Listener {
p.cfg.Listener[listenerNum].Init()
go func(listenerNum string) {
defer wg.Done()
defer func() {
cancelFunc()
wg.Done()
}()
listenerConfig := p.cfg.Listener[listenerNum]
upstreamConfig := p.cfg.Upstream[listenerNum]
if upstreamConfig == nil {
@@ -166,35 +267,44 @@ func (p *prog) run() {
}
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
if err := p.serveDNS(listenerNum); err != nil {
if err := p.serveDNS(listenerNum, reload, reloadCh); err != nil {
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
}
}(listenerNum)
}
for i := 0; i < numListeners; i++ {
<-p.started
}
for _, f := range p.onStarted {
f()
if !reload {
for i := 0; i < numListeners; i++ {
<-p.started
}
for _, f := range p.onStarted {
f()
}
}
// Check for possible DNS loop.
p.checkDnsLoop()
close(p.onStartedDone)
// Start check DNS loop ticker.
go p.checkDnsLoopTicker()
wg.Add(1)
go func() {
defer wg.Done()
p.checkDnsLoopTicker(ctx)
}()
// Stop writing log to unix socket.
consoleWriter.Out = os.Stdout
initLoggingWithBackup(false)
if p.logConn != nil {
_ = p.logConn.Close()
}
if p.cs != nil {
p.registerControlServerHandler()
if err := p.cs.start(); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not start control server")
if !reload {
// Stop writing log to unix socket.
consoleWriter.Out = os.Stdout
initLoggingWithBackup(false)
if p.logConn != nil {
_ = p.logConn.Close()
}
if p.cs != nil {
p.registerControlServerHandler()
if err := p.cs.start(); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not start control server")
}
}
}
wg.Wait()