From 595071b6089a2843279598fe42fb09a1d461d2aa Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 4 Feb 2025 17:58:05 +0700 Subject: [PATCH] all: update client info table on network changes So the client metadata will be updated correctly when the device roaming between networks. --- cmd/cli/dns_proxy.go | 19 ++++++++-- cmd/cli/prog.go | 59 ++++++++++++++++++++---------- internal/clientinfo/client_info.go | 40 ++++++++++++++++---- 3 files changed, 87 insertions(+), 31 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 853e77a..18ac373 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -84,9 +84,9 @@ type upstreamForResult struct { srcAddr string } -func (p *prog) serveDNS(listenerNum string) error { +func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { // Start network monitoring - if err := p.monitorNetworkChanges(); err != nil { + if err := p.monitorNetworkChanges(mainCtx); err != nil { mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring") // Don't return here as we still want DNS service to run } @@ -1316,7 +1316,7 @@ func FlushDNSCache() error { } // monitorNetworkChanges starts monitoring for network interface changes -func (p *prog) monitorNetworkChanges() error { +func (p *prog) monitorNetworkChanges(ctx context.Context) error { mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) if err != nil { return fmt.Errorf("creating network monitor: %w", err) @@ -1336,6 +1336,19 @@ func (p *prog) monitorNetworkChanges() error { oldIfs := parseInterfaceState(delta.Old) newIfs := parseInterfaceState(delta.New) + // Client info discover only run on non-mobile platforms. + if !isMobile() { + // If this is major change, re-init client info table if its self IP changes. + if delta.Monitor.IsMajorChangeFrom(delta.Old, delta.New) { + selfIP := defaultRouteIP() + if currentSelfIP := p.ciTable.SelfIP(); currentSelfIP != selfIP && selfIP != "" { + p.stopClientInfoDiscover() + p.setupClientInfoDiscover(selfIP) + p.runClientInfoDiscover(ctx) + } + } + } + // Check for changes in valid interfaces changed := false activeInterfaceExists := false diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index fd49764..4c9270c 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -72,6 +72,7 @@ var useSystemdResolved = false type prog struct { mu sync.Mutex + wg sync.WaitGroup waitCh chan struct{} stopCh chan struct{} reloadCh chan struct{} // For Windows. @@ -451,7 +452,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } var wg sync.WaitGroup - wg.Add(len(p.cfg.Listener)) + p.wg = wg + p.wg.Add(len(p.cfg.Listener)) for _, nc := range p.cfg.Network { for _, cidr := range nc.Cidrs { @@ -477,12 +479,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } p.setupUpstream(p.cfg) - p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers) - if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { - mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) - format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) - p.ciTable.AddLeaseFile(leaseFile, format) - } + p.setupClientInfoDiscover(defaultRouteIP()) } // context for managing spawn goroutines. @@ -491,12 +488,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { // Newer versions of android and iOS denies permission which breaks connectivity. if !isMobile() && !reload { - wg.Add(1) - go func() { - defer wg.Done() - p.ciTable.Init() - p.ciTable.RefreshLoop(ctx) - }() + p.runClientInfoDiscover(ctx) go p.watchLinkState(ctx) } @@ -511,7 +503,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) - if err := p.serveDNS(listenerNum); err != nil { + if err := p.serveDNS(ctx, listenerNum); err != nil { mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) } }(listenerNum) @@ -519,7 +511,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { go func() { defer func() { cancelFunc() - wg.Done() + p.wg.Done() }() select { case <-p.stopCh: @@ -540,19 +532,19 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { close(p.onStartedDone) - wg.Add(1) + p.wg.Add(1) go func() { - defer wg.Done() + defer p.wg.Done() // Check for possible DNS loop. p.checkDnsLoop() // Start check DNS loop ticker. p.checkDnsLoopTicker(ctx) }() - wg.Add(1) + p.wg.Add(1) // Prometheus exporter goroutine. go func() { - defer wg.Done() + defer p.wg.Done() p.runMetricsServer(ctx, reloadCh) }() @@ -567,7 +559,34 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.postRun() p.initInternalLogging(logWriters) } - wg.Wait() + p.wg.Wait() +} + +// setupClientInfoDiscover performs necessary works for running client info discover. +func (p *prog) setupClientInfoDiscover(selfIP string) { + p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers) + if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { + mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) + format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) + p.ciTable.AddLeaseFile(leaseFile, format) + } +} + +// runClientInfoDiscover runs the client info discover in background. +func (p *prog) runClientInfoDiscover(ctx context.Context) { + p.wg.Add(1) + go func() { + defer p.wg.Done() + p.ciTable.Init() + p.ciTable.RefreshLoop(ctx) + }() +} + +// stopClientInfoDiscover stops the current client info discover goroutine. +// It blocks until the goroutine terminated. +func (p *prog) stopClientInfoDiscover() { + p.ciTable.Stop() + mainLog.Load().Debug().Msg("stopped client info discover") } // metricsEnabled reports whether prometheus exporter is enabled/disabled. diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 780334b..04ec4c3 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -77,6 +77,7 @@ type Table struct { hostnameResolvers []HostnameResolver refreshers []refresher initOnce sync.Once + stopOnce sync.Once refreshInterval int dhcp *dhcp @@ -90,6 +91,7 @@ type Table struct { vni *virtualNetworkIface svcCfg ctrld.ServiceConfig quitCh chan struct{} + stopCh chan struct{} selfIP string cdUID string ptrNameservers []string @@ -103,6 +105,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { return &Table{ svcCfg: cfg.Service, quitCh: make(chan struct{}), + stopCh: make(chan struct{}), selfIP: selfIP, cdUID: cdUID, ptrNameservers: ns, @@ -120,24 +123,47 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) { // RefreshLoop runs all the refresher to update new client info data. func (t *Table) RefreshLoop(ctx context.Context) { timer := time.NewTicker(time.Second * time.Duration(t.refreshInterval)) - defer timer.Stop() + defer func() { + timer.Stop() + close(t.quitCh) + }() for { select { case <-timer.C: - for _, r := range t.refreshers { - _ = r.refresh() - } + t.Refresh() + case <-t.stopCh: + return case <-ctx.Done(): - close(t.quitCh) return } } } +// Init initializes all client info discovers. func (t *Table) Init() { t.initOnce.Do(t.init) } +// Refresh forces all discovers to retrieve new data. +func (t *Table) Refresh() { + for _, r := range t.refreshers { + _ = r.refresh() + } +} + +// Stop stops all the discovers. +// It blocks until all the discovers done. +func (t *Table) Stop() { + t.stopOnce.Do(func() { + close(t.stopCh) + }) + <-t.quitCh +} + +func (t *Table) SelfIP() string { + return t.selfIP +} + func (t *Table) init() { // Custom client ID presents, use it as the only source. if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { @@ -381,9 +407,7 @@ func (t *Table) lookupHostnameAll(ip, mac string) []*hostnameEntry { // ListClients returns list of clients discovered by ctrld. func (t *Table) ListClients() []*Client { - for _, r := range t.refreshers { - _ = r.refresh() - } + t.Refresh() ipMap := make(map[string]*Client) il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni} for _, ir := range il {