all: add pull API config based on special DNS query

For query domain that matches "uid.verify.controld.com" in cd mode, and
the uid has the same value with "--cd" flag, ctrld will fetch uid config
from ControlD API, using this config if valid.

This is useful for force syncing API without waiting until the API
reload ticker fire.
This commit is contained in:
Cuong Manh Le
2024-09-19 21:43:09 +07:00
committed by Cuong Manh Le
parent ede354166b
commit e6f256d640
4 changed files with 90 additions and 48 deletions
+37
View File
@@ -151,6 +151,7 @@ func (p *prog) serveDNS(listenerNum string) error {
ufr: ur,
})
go p.doSelfUninstall(pr.answer)
answer = pr.answer
rtt := time.Since(t)
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
@@ -168,6 +169,7 @@ func (p *prog) serveDNS(listenerNum string) error {
go func() {
p.WithLabelValuesInc(statsQueriesCount, labelValues...)
p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...)
p.forceFetchingAPI(domain)
}()
if err := w.WriteMsg(answer); err != nil {
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client")
@@ -926,6 +928,41 @@ func (p *prog) performCaptivePortalDetection() {
mainLog.Load().Warn().Msg("captive portal login finished, stop leaking query")
}
// forceFetchingAPI sends signal to force syncing API config if run in cd mode,
// and the domain == "cdUID.verify.controld.com"
func (p *prog) forceFetchingAPI(domain string) {
if cdUID == "" {
return
}
resolverID, parent, _ := strings.Cut(domain, ".")
if resolverID != cdUID {
return
}
switch {
case cdDev && parent == "verify.controld.dev":
// match ControlD dev
case parent == "verify.controld.com":
// match ControlD
default:
return
}
_ = p.apiForceReloadGroup.DoChan("force_sync_api", func() (interface{}, error) {
p.apiForceReloadCh <- struct{}{}
// Wait here to prevent abusing API if we are flooded.
time.Sleep(timeDurationOrDefault(p.cfg.Service.ForceRefetchWaitTime, 30) * time.Second)
return nil, nil
})
}
// timeDurationOrDefault returns time duration value from n if not nil.
// Otherwise, it returns time duration value defaultN.
func timeDurationOrDefault(n *int, defaultN int) time.Duration {
if n != nil && *n > 0 {
return time.Duration(*n)
}
return time.Duration(defaultN)
}
// queryFromSelf reports whether the input IP is from device running ctrld.
func queryFromSelf(ip string) bool {
netIP := netip.MustParseAddr(ip)