diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 80bc1ab..5ee7112 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "os" + "reflect" "sort" "time" @@ -87,6 +88,7 @@ func (p *prog) registerControlServerHandler() { Port: v.Port, } } + oldSvc := p.cfg.Service p.mu.Unlock() if err := p.sendReloadSignal(); err != nil { mainLog.Load().Err(err).Msg("could not send reload signal") @@ -102,6 +104,10 @@ func (p *prog) registerControlServerHandler() { p.mu.Lock() defer p.mu.Unlock() + + // Checking for cases where we could not do a reload. + + // 1. Listener config ip or port changes. for k, v := range p.cfg.Listener { l := listeners[k] if l == nil || l.IP != v.IP || l.Port != v.Port { @@ -109,6 +115,14 @@ func (p *prog) registerControlServerHandler() { return } } + + // 2. Service config changes. + if !reflect.DeepEqual(oldSvc, p.cfg.Service) { + w.WriteHeader(http.StatusCreated) + return + } + + // Otherwise, reload is done. 
w.WriteHeader(http.StatusOK) })) } diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 666bf50..de8aef7 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -4,7 +4,6 @@ import ( "context" "crypto/rand" "encoding/hex" - "errors" "fmt" "net" "net/netip" @@ -44,19 +43,14 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{ Timeout: 2000, } -var errReload = errors.New("reload") - -func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) error { +func (p *prog) serveDNS(listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") return allocErr } - var failoverRcodes []int - if listenerConfig.Policy != nil { - failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers - } + handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { p.sema.acquire() defer p.sema.release() @@ -83,7 +77,7 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) answer = new(dns.Msg) answer.SetRcode(m, dns.RcodeRefused) } else { - answer = p.proxy(ctx, upstreams, failoverRcodes, m, ci) + answer = p.proxy(ctx, upstreams, listenerConfig.Policy.FailoverRcodeNumbers, m, ci) rtt := time.Since(t) ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) } @@ -93,12 +87,6 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) }) g, ctx := errgroup.WithContext(context.Background()) - // When receiving reload signal, return a non-nil error so other - // goroutines in errgroup.Group could be terminated. 
- g.Go(func() error { - <-reloadCh - return errReload - }) for _, proto := range []string{"udp", "tcp"} { proto := proto if needLocalIPv6Listener() { @@ -142,13 +130,11 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) s, errCh := runDNSServer(addr, proto, handler) defer s.Shutdown() - if !reload { - select { - case err := <-errCh: - return err - case <-time.After(5 * time.Second): - p.started <- struct{}{} - } + select { + case err := <-errCh: + return err + case <-time.After(5 * time.Second): + p.started <- struct{}{} } select { case <-p.stopCh: @@ -159,11 +145,7 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) return nil }) } - err := g.Wait() - if errors.Is(err, errReload) { // This is an error for trigger reload, not a real error. - return nil - } - return err + return g.Wait() } // upstreamFor returns the list of upstreams for resolving the given domain, diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index d29f374..be50ea6 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -135,16 +135,32 @@ func (p *prog) runWait() { waitOldRunDone() - _, ok := tryUpdateListenerConfig(newCfg, nil, false) - if !ok { - logger.Error().Msg("could not update listener config") - continue + p.mu.Lock() + curListener := p.cfg.Listener + p.mu.Unlock() + + for n, lc := range newCfg.Listener { + curLc := curListener[n] + if curLc == nil { + continue + } + if lc.IP == "" { + lc.IP = curLc.IP + } + if lc.Port == 0 { + lc.Port = curLc.Port + } } if err := validateConfig(newCfg); err != nil { logger.Err(err).Msg("invalid config") continue } + // This needs to be done here, otherwise, the DNS handler may observe an invalid + // upstream config because its initialization function has not been called yet. 
+ mainLog.Load().Debug().Msg("setup upstream with new config") + setupUpstream(newCfg) + p.mu.Lock() *p.cfg = *newCfg p.mu.Unlock() @@ -170,6 +186,21 @@ func (p *prog) preRun() { } } +func setupUpstream(cfg *ctrld.Config) { + for n := range cfg.Upstream { + uc := cfg.Upstream[n] + uc.Init() + if uc.BootstrapIP == "" { + uc.SetupBootstrapIP() + mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) + } else { + mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) + } + uc.SetCertPool(rootCertPool) + go uc.Ping() + } +} + // run runs the ctrld main components. // // The reload boolean indicates that the function is run when ctrld first start @@ -183,7 +214,9 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { <-p.waitCh p.preRun() numListeners := len(p.cfg.Listener) - p.started = make(chan struct{}, numListeners) + if !reload { + p.started = make(chan struct{}, numListeners) + } p.onStartedDone = make(chan struct{}) p.loop = make(map[string]bool) if p.cfg.Service.CacheEnable { @@ -194,15 +227,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.cache = cacher } } - p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)} - if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil { - n := *mcr - if n == 0 { - p.sema = &noopSemaphore{} - } else { - p.sema = &chanSemaphore{ready: make(chan struct{}, n)} - } - } + var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) @@ -218,24 +243,24 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } p.um = newUpstreamMonitor(p.cfg) - for n := range p.cfg.Upstream { - uc := p.cfg.Upstream[n] - uc.Init() - if uc.BootstrapIP == "" { - uc.SetupBootstrapIP() - mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) - } else { - mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) - } - uc.SetCertPool(rootCertPool) - go uc.Ping() - } - 
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID) - if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { - mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) - format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) - p.ciTable.AddLeaseFile(leaseFile, format) + if !reload { + p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)} + if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil { + n := *mcr + if n == 0 { + p.sema = &noopSemaphore{} + } else { + p.sema = &chanSemaphore{ready: make(chan struct{}, n)} + } + } + setupUpstream(p.cfg) + p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID) + if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { + mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) + format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) + p.ciTable.AddLeaseFile(leaseFile, format) + } } // context for managing spawn goroutines. @@ -243,7 +268,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { defer cancelFunc() // Newer versions of android and iOS denies permission which breaks connectivity. 
- if !isMobile() { + if !isMobile() && !reload { wg.Add(1) go func() { defer wg.Done() @@ -255,22 +280,32 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { for listenerNum := range p.cfg.Listener { p.cfg.Listener[listenerNum].Init() - go func(listenerNum string) { + if !reload { + go func(listenerNum string) { + listenerConfig := p.cfg.Listener[listenerNum] + upstreamConfig := p.cfg.Upstream[listenerNum] + if upstreamConfig == nil { + mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) + } + addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) + mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) + if err := p.serveDNS(listenerNum); err != nil { + mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) + } + }(listenerNum) + } + go func() { defer func() { cancelFunc() wg.Done() }() - listenerConfig := p.cfg.Listener[listenerNum] - upstreamConfig := p.cfg.Upstream[listenerNum] - if upstreamConfig == nil { - mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) + select { + case <-p.stopCh: + case <-ctx.Done(): + case <-reloadCh: } - addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) - mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) - if err := p.serveDNS(listenerNum, reload, reloadCh); err != nil { - mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) - } - }(listenerNum) + return + }() } if !reload {