From 30fefe7ab93bd17f2ed578a3eec6eb945891dc47 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 21 Dec 2022 00:29:04 +0700 Subject: [PATCH] all: add local caching This commit adds config params to enable local DNS response caching and control its behavior, allow tweaking the cache size, ttl override and serving stale response. --- cmd/ctrld/dns_proxy.go | 65 ++++++++++++++++++++++++++++++++ cmd/ctrld/prog.go | 12 +++++- config.go | 19 +++++----- docs/config.md | 27 ++++++++++++++ go.mod | 1 + go.sum | 2 + internal/dnscache/cache.go | 76 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 192 insertions(+), 10 deletions(-) create mode 100644 internal/dnscache/cache.go diff --git a/cmd/ctrld/dns_proxy.go b/cmd/ctrld/dns_proxy.go index 9d14c66..b1df992 100644 --- a/cmd/ctrld/dns_proxy.go +++ b/cmd/ctrld/dns_proxy.go @@ -13,8 +13,11 @@ import ( "github.com/miekg/dns" "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/dnscache" ) +const staleTTL = 60 * time.Second + func (p *prog) serveUDP(listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated @@ -123,6 +126,22 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c } func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg) *dns.Msg { + var staleAnswer *dns.Msg + serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale + // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 + if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { + if cachedValue := p.cache.Get(dnscache.NewKey(msg)); cachedValue != nil { + answer := cachedValue.Msg.Copy() + answer.SetReply(msg) + now := time.Now() + if cachedValue.Expire.After(now) { + ctrld.Log(ctx, proxyLog.Debug(), "hit cached response") + setCachedAnswerTTL(answer, now, cachedValue.Expire) + return answer + } + staleAnswer = answer + } + } upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) @@ -148,12 +167,29 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i for n, upstreamConfig := range upstreamConfigs { answer := resolve(n, upstreamConfig, msg) if answer == nil { + if serveStaleCache && staleAnswer != nil { + ctrld.Log(ctx, proxyLog.Debug(), "serving stale cached response") + now := time.Now() + setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL)) + return staleAnswer + } continue } if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) { ctrld.Log(ctx, proxyLog.Debug(), "failover rcode matched, process to next upstream") continue } + if p.cache != nil { + ttl := ttlFromMsg(answer) + now := time.Now() + expired := now.Add(time.Duration(ttl) * time.Second) + if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 { + expired = now.Add(time.Duration(cachedTTL) * time.Second) + } + setCachedAnswerTTL(answer, now, expired) + p.cache.Add(dnscache.NewKey(msg), dnscache.NewValue(answer, expired)) + ctrld.Log(ctx, proxyLog.Debug(), "add cached response") + } return answer } ctrld.Log(ctx, proxyLog.Error(), "all upstreams failed") @@ -229,6 +265,35 @@ func containRcode(rcodes []int, rcode int) bool { return false } +func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) { + ttl := uint32(expiredTime.Sub(now).Seconds()) + if ttl < 0 { + return + } + + for _, rr := range answer.Answer { + rr.Header().Ttl = ttl + } + for _, rr := range answer.Ns { + rr.Header().Ttl = ttl + } + for _, rr := range answer.Extra { + if rr.Header().Rrtype != dns.TypeOPT { + rr.Header().Ttl = ttl + } + } +} + +func ttlFromMsg(msg *dns.Msg) uint32 { + for _, rr := range msg.Answer { + return rr.Header().Ttl + } + for _, rr := range msg.Ns { + return rr.Header().Ttl + } + return 0 +} + var osUpstreamConfig = &ctrld.UpstreamConfig{ Name: "OS resolver", Type: "os", diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index 49b6c0e..9a889be 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -12,6 +12,7 @@ import ( "github.com/miekg/dns" "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/dnscache" ) var errWindowsAddrInUse = syscall.Errno(0x2740) @@ -22,7 +23,8 @@ var svcConfig = &service.Config{ } type prog struct { - cfg *ctrld.Config + cfg *ctrld.Config + cache dnscache.Cacher } func (p *prog) Start(s service.Service) error { @@ -32,6 +34,14 @@ func (p *prog) Start(s service.Service) error { } func (p *prog) run() { + if p.cfg.Service.CacheEnable { + cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) + if err != nil { + mainLog.Error().Err(err).Msg("failed to create cacher, caching is disabled") + } else { + p.cache = cacher + } + } var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) diff --git a/config.go b/config.go index c8ba10d..650b371 100644 --- a/config.go +++ b/config.go @@ -64,10 +64,14 @@ type Config struct { // ServiceConfig specifies the general ctrld config. type ServiceConfig struct { - LogLevel string `mapstructure:"log_level" toml:"log_level"` - LogPath string `mapstructure:"log_path" toml:"log_path"` - Daemon bool `mapstructure:"-" toml:"-"` - AllocateIP bool `mapstructure:"-" toml:"-"` + LogLevel string `mapstructure:"log_level" toml:"log_level"` + LogPath string `mapstructure:"log_path" toml:"log_path"` + CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable"` + CacheSize int `mapstructure:"cache_size" toml:"cache_size"` + CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override"` + CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale"` + Daemon bool `mapstructure:"-" toml:"-"` + AllocateIP bool `mapstructure:"-" toml:"-"` } // NetworkConfig specifies configuration for networks where ctrld will handle requests. @@ -179,11 +183,8 @@ func (uc *UpstreamConfig) setupDOH3Transport() { if err != nil { return nil, err } - localAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} - if strings.Index(uc.BootstrapIP, ":") > -1 { - localAddr = &net.UDPAddr{IP: net.IPv6zero, Port: 0} - } - udpConn, err := net.ListenUDP("udp", localAddr) + + udpConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, err } diff --git a/docs/config.md b/docs/config.md index 44c125e..483f0d1 100644 --- a/docs/config.md +++ b/docs/config.md @@ -33,6 +33,8 @@ If no configuration files found, a default `config.toml` file will be created in [service] log_level = "info" log_path = "" + cache_enable = true + cache_size = 4096 [network.0] cidrs = ["0.0.0.0/0"] @@ -109,6 +111,31 @@ Relative or absolute path of the log file. - Type: string - Required: no +### cache_enable +When `cache_enable = true`, all resolved DNS query responses will be cached for duration of the upstream record TTLs. + +- Type: boolean +- Required: no + +### cache_size +The number of cached records, must be a positive integer. Tweaking this value with care depends on your available RAM. +A minimum value `4096` should be enough for most use cases. + +An invalid `cache_size` value will disable the cache, regardless of `cache_enable` value. + +- Type: int +- Required: no + +### cache_ttl_override +When `cache_ttl_override` is set to a positive value (in seconds), TTLs are overridden to this value and cached for this long. + +- Type: int +- Required: no + +### cache_serve_stale +When `cache_serve_stale = true`, in cases of upstream failures (upstreams not reachable), `ctrld` will keep serving +stale cached records (regardless of their TTLs) until upstream comes online. + The above config will look like this at query time. ``` diff --git a/go.mod b/go.mod index 67ec79a..c3b636b 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.19 require ( github.com/go-playground/validator/v10 v10.11.1 + github.com/hashicorp/golang-lru/v2 v2.0.1 github.com/kardianos/service v1.2.1 github.com/lucas-clemente/quic-go v0.29.1 github.com/miekg/dns v1.1.50 diff --git a/go.sum b/go.sum index d971225..e123e72 100644 --- a/go.sum +++ b/go.sum @@ -112,6 +112,8 @@ github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/b github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru/v2 v2.0.1 h1:5pv5N1lT1fjLg2VQ5KWc7kmucp2x/kvFOnxuVTqZ6x4= +github.com/hashicorp/golang-lru/v2 v2.0.1/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= diff --git a/internal/dnscache/cache.go b/internal/dnscache/cache.go new file mode 100644 index 0000000..efbd8e3 --- /dev/null +++ b/internal/dnscache/cache.go @@ -0,0 +1,76 @@ +package dnscache + +import ( + "strings" + "time" + + lru "github.com/hashicorp/golang-lru/v2" + "github.com/miekg/dns" +) + +// Cacher is the interface for caching DNS response. +type Cacher interface { + Get(Key) *Value + Add(Key, *Value) +} + +// Key is the caching key for DNS message. +type Key struct { + Qtype uint16 + Qclass uint16 + Name string +} + +type Value struct { + Expire time.Time + Msg *dns.Msg +} + +var _ Cacher = (*LRUCache)(nil) + +// LRUCache implements Cacher interface. +type LRUCache struct { + cacher *lru.ARCCache[Key, *Value] +} + +func (l *LRUCache) Get(key Key) *Value { + v, _ := l.cacher.Get(key) + return v +} + +func (l *LRUCache) Add(key Key, value *Value) { + l.cacher.Add(key, value) +} + +// NewLRUCache creates a new LRUCache instance with given size. +func NewLRUCache(size int) (*LRUCache, error) { + cacher, err := lru.NewARC[Key, *Value](size) + return &LRUCache{cacher: cacher}, err +} + +// NewKey creates a new cache key for given DNS message. +func NewKey(msg *dns.Msg) Key { + q := msg.Question[0] + return Key{Qtype: q.Qtype, Qclass: q.Qclass, Name: normalizeQname(q.Name)} +} + +// NewValue creates a new cache value for given DNS message. +func NewValue(msg *dns.Msg, expire time.Time) *Value { + return &Value{ + Expire: expire, + Msg: msg, + } +} + +func normalizeQname(name string) string { + var b strings.Builder + b.Grow(len(name)) + for i := 0; i < len(name); i++ { + c := name[i] + if 'A' <= c && c <= 'Z' { + c += 'a' - 'A' + } + b.WriteByte(c) + } + return b.String() +}