From b50cccac85da927e42780fc3a0ee804ebded3f56 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 14 Mar 2024 20:02:44 +0700 Subject: [PATCH] all: add flush cache domains config --- cmd/cli/cli.go | 5 +++++ cmd/cli/dns_proxy.go | 4 ++++ cmd/cli/prog.go | 28 ++++++++++++++++---------- config.go | 41 +++++++++++++++++++------------------- config_test.go | 11 ++++++++++ docs/config.md | 9 ++++++++- internal/dnscache/cache.go | 8 ++++++++ 7 files changed, 74 insertions(+), 32 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 093192a..f9e8c68 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1831,6 +1831,11 @@ func fieldErrorMsg(fe validator.FieldError) string { return fmt.Sprintf("must define at least %s element", fe.Param()) } return fmt.Sprintf("minimum value: %q", fe.Param()) + case "max": + if fe.Kind() == reflect.Map || fe.Kind() == reflect.Slice { + return fmt.Sprintf("exceeded maximum number of elements: %s", fe.Param()) + } + return fmt.Sprintf("maximum value: %q", fe.Param()) case "len": if fe.Kind() == reflect.Slice { return fmt.Sprintf("must have at least %s element", fe.Param()) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index fb5f903..4cd4641 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -101,6 +101,10 @@ func (p *prog) serveDNS(listenerNum string) error { go p.detectLoop(m) q := m.Question[0] domain := canonicalName(q.Name) + if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil { + p.cache.Purge() + ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain) + } remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) ci := p.getClientInfo(remoteIP, m) ci.ClientIDPref = p.cfg.Service.ClientIDPref diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index afe297c..b3f3abf 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -70,17 +70,18 @@ type prog struct { logConn net.Conn cs *controlServer - cfg *ctrld.Config - localUpstreams []string - ptrNameservers []string - appCallback *AppCallback - cache dnscache.Cacher - sema semaphore - ciTable *clientinfo.Table - um *upstreamMonitor - router router.Router - ptrLoopGuard *loopGuard - lanLoopGuard *loopGuard + cfg *ctrld.Config + localUpstreams []string + ptrNameservers []string + appCallback *AppCallback + cache dnscache.Cacher + cacheFlushDomainsMap map[string]struct{} + sema semaphore + ciTable *clientinfo.Table + um *upstreamMonitor + router router.Router + ptrLoopGuard *loopGuard + lanLoopGuard *loopGuard loopMu sync.Mutex loop map[string]bool @@ -253,12 +254,17 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.loop = make(map[string]bool) p.lanLoopGuard = newLoopGuard() p.ptrLoopGuard = newLoopGuard() + p.cacheFlushDomainsMap = nil if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled") } else { p.cache = cacher + p.cacheFlushDomainsMap = make(map[string]struct{}, 256) + for _, domain := range p.cfg.Service.CacheFlushDomains { + p.cacheFlushDomainsMap[canonicalName(domain)] = struct{}{} + } } } diff --git a/config.go b/config.go index 56bb68d..582069c 100644 --- a/config.go +++ b/config.go @@ -179,26 +179,27 @@ func (c *Config) FirstUpstream() *UpstreamConfig { // ServiceConfig specifies the general ctrld config. type ServiceConfig struct { - LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` - LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` - CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` - CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` - CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` - CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` - MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` - DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` - DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` - DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` - DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"` - DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` - DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` - DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` - DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"` - ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"` - MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"` - MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"` - Daemon bool `mapstructure:"-" toml:"-"` - AllocateIP bool `mapstructure:"-" toml:"-"` + LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` + LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` + CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` + CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` + CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` + CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` + CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"` + MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` + DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` + DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` + DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` + DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"` + DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` + DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` + DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` + DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"` + ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"` + MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"` + MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"` + Daemon bool `mapstructure:"-" toml:"-"` + AllocateIP bool `mapstructure:"-" toml:"-"` } // NetworkConfig specifies configuration for networks where ctrld will handle requests. diff --git a/config_test.go b/config_test.go index 55a19f3..83a1e13 100644 --- a/config_test.go +++ b/config_test.go @@ -1,6 +1,7 @@ package ctrld_test import ( + "fmt" "os" "strings" "testing" @@ -103,6 +104,7 @@ func TestConfigValidation(t *testing.T) { {"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true}, {"invalid client id pref", configWithInvalidClientIDPref(t), true}, {"doh endpoint without scheme", dohUpstreamEndpointWithoutScheme(t), false}, + {"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true}, } for _, tc := range tests { @@ -275,3 +277,12 @@ func configWithInvalidClientIDPref(t *testing.T) *ctrld.Config { cfg.Service.ClientIDPref = "foo" return cfg } + +func configWithInvalidFlushCacheDomain(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Service.CacheFlushDomains = make([]string, 257) + for i := range cfg.Service.CacheFlushDomains { + cfg.Service.CacheFlushDomains[i] = fmt.Sprintf("%d.com", i) + } + return cfg +} diff --git a/docs/config.md b/docs/config.md index 1862316..d9c1dae 100644 --- a/docs/config.md +++ b/docs/config.md @@ -157,6 +157,13 @@ stale cached records (regardless of their TTLs) until upstream comes online. - Required: no - Default: false +### cache_flush_domains +When `ctrld` receives query with domain name in `cache_flush_domains`, the local cache will be discarded +before serving the query. + +- Type: array of strings +- Required: no + ### max_concurrent_requests The number of concurrent requests that will be handled, must be a non-negative integer. Tweaking this value depends on the capacity of your system. @@ -531,7 +538,7 @@ And within each policy, the rules are processed from top to bottom. ### failover_rcodes For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`. -- Type: array of string +- Type: array of strings - Required: no - Default: [] - diff --git a/internal/dnscache/cache.go b/internal/dnscache/cache.go index 4aa7f69..af8883e 100644 --- a/internal/dnscache/cache.go +++ b/internal/dnscache/cache.go @@ -12,6 +12,7 @@ import ( type Cacher interface { Get(Key) *Value Add(Key, *Value) + Purge() } // Key is the caching key for DNS message. @@ -34,15 +35,22 @@ type LRUCache struct { cacher *lru.ARCCache[Key, *Value] } +// Get looks up key's value from cache. func (l *LRUCache) Get(key Key) *Value { v, _ := l.cacher.Get(key) return v } +// Add adds a value to cache. func (l *LRUCache) Add(key Key, value *Value) { l.cacher.Add(key, value) } +// Purge clears the cache. +func (l *LRUCache) Purge() { + l.cacher.Purge() +} + // NewLRUCache creates a new LRUCache instance with given size. func NewLRUCache(size int) (*LRUCache, error) { cacher, err := lru.NewARC[Key, *Value](size)