cmd/cli: do not stop listener when reloading

We could not do a reload if the listener config changes, so do not turn
them off to try updating new listener config.
This commit is contained in:
Cuong Manh Le
2023-11-09 22:10:24 +07:00
committed by Cuong Manh Le
parent 294a90a807
commit d01f5c2777
3 changed files with 101 additions and 70 deletions

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/http"
"os"
"reflect"
"sort"
"time"
@@ -87,6 +88,7 @@ func (p *prog) registerControlServerHandler() {
Port: v.Port,
}
}
oldSvc := p.cfg.Service
p.mu.Unlock()
if err := p.sendReloadSignal(); err != nil {
mainLog.Load().Err(err).Msg("could not send reload signal")
@@ -102,6 +104,10 @@ func (p *prog) registerControlServerHandler() {
p.mu.Lock()
defer p.mu.Unlock()
// Checking for cases that we could not do a reload.
// 1. Listener config ip or port changes.
for k, v := range p.cfg.Listener {
l := listeners[k]
if l == nil || l.IP != v.IP || l.Port != v.Port {
@@ -109,6 +115,14 @@ func (p *prog) registerControlServerHandler() {
return
}
}
// 2. Service config changes.
if !reflect.DeepEqual(oldSvc, p.cfg.Service) {
w.WriteHeader(http.StatusCreated)
return
}
// Otherwise, reload is done.
w.WriteHeader(http.StatusOK)
}))
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
@@ -44,19 +43,14 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{
Timeout: 2000,
}
var errReload = errors.New("reload")
func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) error {
func (p *prog) serveDNS(listenerNum string) error {
listenerConfig := p.cfg.Listener[listenerNum]
// make sure ip is allocated
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
return allocErr
}
var failoverRcodes []int
if listenerConfig.Policy != nil {
failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers
}
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
p.sema.acquire()
defer p.sema.release()
@@ -83,7 +77,7 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{})
answer = new(dns.Msg)
answer.SetRcode(m, dns.RcodeRefused)
} else {
answer = p.proxy(ctx, upstreams, failoverRcodes, m, ci)
answer = p.proxy(ctx, upstreams, listenerConfig.Policy.FailoverRcodeNumbers, m, ci)
rtt := time.Since(t)
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
}
@@ -93,12 +87,6 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{})
})
g, ctx := errgroup.WithContext(context.Background())
// When receiving reload signal, return a non-nil error so other
// goroutines in errgroup.Group could be terminated.
g.Go(func() error {
<-reloadCh
return errReload
})
for _, proto := range []string{"udp", "tcp"} {
proto := proto
if needLocalIPv6Listener() {
@@ -142,13 +130,11 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{})
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
s, errCh := runDNSServer(addr, proto, handler)
defer s.Shutdown()
if !reload {
select {
case err := <-errCh:
return err
case <-time.After(5 * time.Second):
p.started <- struct{}{}
}
select {
case err := <-errCh:
return err
case <-time.After(5 * time.Second):
p.started <- struct{}{}
}
select {
case <-p.stopCh:
@@ -159,11 +145,7 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{})
return nil
})
}
err := g.Wait()
if errors.Is(err, errReload) { // This is an error for trigger reload, not a real error.
return nil
}
return err
return g.Wait()
}
// upstreamFor returns the list of upstreams for resolving the given domain,

View File

@@ -135,16 +135,32 @@ func (p *prog) runWait() {
waitOldRunDone()
_, ok := tryUpdateListenerConfig(newCfg, nil, false)
if !ok {
logger.Error().Msg("could not update listener config")
continue
p.mu.Lock()
curListener := p.cfg.Listener
p.mu.Unlock()
for n, lc := range newCfg.Listener {
curLc := curListener[n]
if curLc == nil {
continue
}
if lc.IP == "" {
lc.IP = curLc.IP
}
if lc.Port == 0 {
lc.Port = curLc.Port
}
}
if err := validateConfig(newCfg); err != nil {
logger.Err(err).Msg("invalid config")
continue
}
// This needs to be done here, otherwise, the DNS handler may observe an invalid
// upstream config because its initialization function have not been called yet.
mainLog.Load().Debug().Msg("setup upstream with new config")
setupUpstream(newCfg)
p.mu.Lock()
*p.cfg = *newCfg
p.mu.Unlock()
@@ -170,6 +186,21 @@ func (p *prog) preRun() {
}
}
func setupUpstream(cfg *ctrld.Config) {
for n := range cfg.Upstream {
uc := cfg.Upstream[n]
uc.Init()
if uc.BootstrapIP == "" {
uc.SetupBootstrapIP()
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
} else {
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
}
uc.SetCertPool(rootCertPool)
go uc.Ping()
}
}
// run runs the ctrld main components.
//
// The reload boolean indicates that the function is run when ctrld first start
@@ -183,7 +214,9 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
<-p.waitCh
p.preRun()
numListeners := len(p.cfg.Listener)
p.started = make(chan struct{}, numListeners)
if !reload {
p.started = make(chan struct{}, numListeners)
}
p.onStartedDone = make(chan struct{})
p.loop = make(map[string]bool)
if p.cfg.Service.CacheEnable {
@@ -194,15 +227,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
p.cache = cacher
}
}
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
n := *mcr
if n == 0 {
p.sema = &noopSemaphore{}
} else {
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
}
}
var wg sync.WaitGroup
wg.Add(len(p.cfg.Listener))
@@ -218,24 +243,24 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
}
p.um = newUpstreamMonitor(p.cfg)
for n := range p.cfg.Upstream {
uc := p.cfg.Upstream[n]
uc.Init()
if uc.BootstrapIP == "" {
uc.SetupBootstrapIP()
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
} else {
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
}
uc.SetCertPool(rootCertPool)
go uc.Ping()
}
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID)
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
if !reload {
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
n := *mcr
if n == 0 {
p.sema = &noopSemaphore{}
} else {
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
}
}
setupUpstream(p.cfg)
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID)
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
}
}
// context for managing spawn goroutines.
@@ -243,7 +268,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
defer cancelFunc()
// Newer versions of android and iOS denies permission which breaks connectivity.
if !isMobile() {
if !isMobile() && !reload {
wg.Add(1)
go func() {
defer wg.Done()
@@ -255,22 +280,32 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
for listenerNum := range p.cfg.Listener {
p.cfg.Listener[listenerNum].Init()
go func(listenerNum string) {
if !reload {
go func(listenerNum string) {
listenerConfig := p.cfg.Listener[listenerNum]
upstreamConfig := p.cfg.Upstream[listenerNum]
if upstreamConfig == nil {
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
}
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
if err := p.serveDNS(listenerNum); err != nil {
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
}
}(listenerNum)
}
go func() {
defer func() {
cancelFunc()
wg.Done()
}()
listenerConfig := p.cfg.Listener[listenerNum]
upstreamConfig := p.cfg.Upstream[listenerNum]
if upstreamConfig == nil {
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
select {
case <-p.stopCh:
case <-ctx.Done():
case <-reloadCh:
}
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
if err := p.serveDNS(listenerNum, reload, reloadCh); err != nil {
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
}
}(listenerNum)
return
}()
}
if !reload {