all: update client info table on network changes

So the client metadata will be updated correctly when the device roaming
between networks.
This commit is contained in:
Cuong Manh Le
2025-02-04 17:58:05 +07:00
committed by Cuong Manh Le
parent 57ef717080
commit 595071b608
3 changed files with 87 additions and 31 deletions

View File

@@ -84,9 +84,9 @@ type upstreamForResult struct {
srcAddr string srcAddr string
} }
func (p *prog) serveDNS(listenerNum string) error { func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
// Start network monitoring // Start network monitoring
if err := p.monitorNetworkChanges(); err != nil { if err := p.monitorNetworkChanges(mainCtx); err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring") mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring")
// Don't return here as we still want DNS service to run // Don't return here as we still want DNS service to run
} }
@@ -1316,7 +1316,7 @@ func FlushDNSCache() error {
} }
// monitorNetworkChanges starts monitoring for network interface changes // monitorNetworkChanges starts monitoring for network interface changes
func (p *prog) monitorNetworkChanges() error { func (p *prog) monitorNetworkChanges(ctx context.Context) error {
mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: "))
if err != nil { if err != nil {
return fmt.Errorf("creating network monitor: %w", err) return fmt.Errorf("creating network monitor: %w", err)
@@ -1336,6 +1336,19 @@ func (p *prog) monitorNetworkChanges() error {
oldIfs := parseInterfaceState(delta.Old) oldIfs := parseInterfaceState(delta.Old)
newIfs := parseInterfaceState(delta.New) newIfs := parseInterfaceState(delta.New)
// Client info discover only run on non-mobile platforms.
if !isMobile() {
// If this is major change, re-init client info table if its self IP changes.
if delta.Monitor.IsMajorChangeFrom(delta.Old, delta.New) {
selfIP := defaultRouteIP()
if currentSelfIP := p.ciTable.SelfIP(); currentSelfIP != selfIP && selfIP != "" {
p.stopClientInfoDiscover()
p.setupClientInfoDiscover(selfIP)
p.runClientInfoDiscover(ctx)
}
}
}
// Check for changes in valid interfaces // Check for changes in valid interfaces
changed := false changed := false
activeInterfaceExists := false activeInterfaceExists := false

View File

@@ -72,6 +72,7 @@ var useSystemdResolved = false
type prog struct { type prog struct {
mu sync.Mutex mu sync.Mutex
wg sync.WaitGroup
waitCh chan struct{} waitCh chan struct{}
stopCh chan struct{} stopCh chan struct{}
reloadCh chan struct{} // For Windows. reloadCh chan struct{} // For Windows.
@@ -451,7 +452,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
} }
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(p.cfg.Listener)) p.wg = wg
p.wg.Add(len(p.cfg.Listener))
for _, nc := range p.cfg.Network { for _, nc := range p.cfg.Network {
for _, cidr := range nc.Cidrs { for _, cidr := range nc.Cidrs {
@@ -477,12 +479,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
} }
} }
p.setupUpstream(p.cfg) p.setupUpstream(p.cfg)
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers) p.setupClientInfoDiscover(defaultRouteIP())
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
}
} }
// context for managing spawn goroutines. // context for managing spawn goroutines.
@@ -491,12 +488,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
// Newer versions of android and iOS denies permission which breaks connectivity. // Newer versions of android and iOS denies permission which breaks connectivity.
if !isMobile() && !reload { if !isMobile() && !reload {
wg.Add(1) p.runClientInfoDiscover(ctx)
go func() {
defer wg.Done()
p.ciTable.Init()
p.ciTable.RefreshLoop(ctx)
}()
go p.watchLinkState(ctx) go p.watchLinkState(ctx)
} }
@@ -511,7 +503,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
} }
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
if err := p.serveDNS(listenerNum); err != nil { if err := p.serveDNS(ctx, listenerNum); err != nil {
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
} }
}(listenerNum) }(listenerNum)
@@ -519,7 +511,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
go func() { go func() {
defer func() { defer func() {
cancelFunc() cancelFunc()
wg.Done() p.wg.Done()
}() }()
select { select {
case <-p.stopCh: case <-p.stopCh:
@@ -540,19 +532,19 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
close(p.onStartedDone) close(p.onStartedDone)
wg.Add(1) p.wg.Add(1)
go func() { go func() {
defer wg.Done() defer p.wg.Done()
// Check for possible DNS loop. // Check for possible DNS loop.
p.checkDnsLoop() p.checkDnsLoop()
// Start check DNS loop ticker. // Start check DNS loop ticker.
p.checkDnsLoopTicker(ctx) p.checkDnsLoopTicker(ctx)
}() }()
wg.Add(1) p.wg.Add(1)
// Prometheus exporter goroutine. // Prometheus exporter goroutine.
go func() { go func() {
defer wg.Done() defer p.wg.Done()
p.runMetricsServer(ctx, reloadCh) p.runMetricsServer(ctx, reloadCh)
}() }()
@@ -567,7 +559,34 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
p.postRun() p.postRun()
p.initInternalLogging(logWriters) p.initInternalLogging(logWriters)
} }
wg.Wait() p.wg.Wait()
}
// setupClientInfoDiscover performs necessary works for running client info discover.
func (p *prog) setupClientInfoDiscover(selfIP string) {
p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers)
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
}
}
// runClientInfoDiscover runs the client info discover in background.
func (p *prog) runClientInfoDiscover(ctx context.Context) {
p.wg.Add(1)
go func() {
defer p.wg.Done()
p.ciTable.Init()
p.ciTable.RefreshLoop(ctx)
}()
}
// stopClientInfoDiscover stops the current client info discover goroutine.
// It blocks until the goroutine terminated.
func (p *prog) stopClientInfoDiscover() {
p.ciTable.Stop()
mainLog.Load().Debug().Msg("stopped client info discover")
} }
// metricsEnabled reports whether prometheus exporter is enabled/disabled. // metricsEnabled reports whether prometheus exporter is enabled/disabled.

View File

@@ -77,6 +77,7 @@ type Table struct {
hostnameResolvers []HostnameResolver hostnameResolvers []HostnameResolver
refreshers []refresher refreshers []refresher
initOnce sync.Once initOnce sync.Once
stopOnce sync.Once
refreshInterval int refreshInterval int
dhcp *dhcp dhcp *dhcp
@@ -90,6 +91,7 @@ type Table struct {
vni *virtualNetworkIface vni *virtualNetworkIface
svcCfg ctrld.ServiceConfig svcCfg ctrld.ServiceConfig
quitCh chan struct{} quitCh chan struct{}
stopCh chan struct{}
selfIP string selfIP string
cdUID string cdUID string
ptrNameservers []string ptrNameservers []string
@@ -103,6 +105,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
return &Table{ return &Table{
svcCfg: cfg.Service, svcCfg: cfg.Service,
quitCh: make(chan struct{}), quitCh: make(chan struct{}),
stopCh: make(chan struct{}),
selfIP: selfIP, selfIP: selfIP,
cdUID: cdUID, cdUID: cdUID,
ptrNameservers: ns, ptrNameservers: ns,
@@ -120,24 +123,47 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) {
// RefreshLoop runs all the refresher to update new client info data. // RefreshLoop runs all the refresher to update new client info data.
func (t *Table) RefreshLoop(ctx context.Context) { func (t *Table) RefreshLoop(ctx context.Context) {
timer := time.NewTicker(time.Second * time.Duration(t.refreshInterval)) timer := time.NewTicker(time.Second * time.Duration(t.refreshInterval))
defer timer.Stop() defer func() {
timer.Stop()
close(t.quitCh)
}()
for { for {
select { select {
case <-timer.C: case <-timer.C:
for _, r := range t.refreshers { t.Refresh()
_ = r.refresh() case <-t.stopCh:
} return
case <-ctx.Done(): case <-ctx.Done():
close(t.quitCh)
return return
} }
} }
} }
// Init initializes all client info discovers.
func (t *Table) Init() { func (t *Table) Init() {
t.initOnce.Do(t.init) t.initOnce.Do(t.init)
} }
// Refresh forces all discovers to retrieve new data.
func (t *Table) Refresh() {
for _, r := range t.refreshers {
_ = r.refresh()
}
}
// Stop stops all the discovers.
// It blocks until all the discovers done.
func (t *Table) Stop() {
t.stopOnce.Do(func() {
close(t.stopCh)
})
<-t.quitCh
}
func (t *Table) SelfIP() string {
return t.selfIP
}
func (t *Table) init() { func (t *Table) init() {
// Custom client ID presents, use it as the only source. // Custom client ID presents, use it as the only source.
if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" {
@@ -381,9 +407,7 @@ func (t *Table) lookupHostnameAll(ip, mac string) []*hostnameEntry {
// ListClients returns list of clients discovered by ctrld. // ListClients returns list of clients discovered by ctrld.
func (t *Table) ListClients() []*Client { func (t *Table) ListClients() []*Client {
for _, r := range t.refreshers { t.Refresh()
_ = r.refresh()
}
ipMap := make(map[string]*Client) ipMap := make(map[string]*Client)
il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni} il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni}
for _, ir := range il { for _, ir := range il {