all: add local caching

This commit adds config params to enable local DNS response caching and
control its behavior, allow tweaking the cache size, ttl override and
serving stale response.
This commit is contained in:
Cuong Manh Le
2022-12-21 00:29:04 +07:00
committed by Cuong Manh Le
parent fa3c3e8a29
commit 30fefe7ab9
7 changed files with 192 additions and 10 deletions

View File

@@ -13,8 +13,11 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
) )
const staleTTL = 60 * time.Second
func (p *prog) serveUDP(listenerNum string) error { func (p *prog) serveUDP(listenerNum string) error {
listenerConfig := p.cfg.Listener[listenerNum] listenerConfig := p.cfg.Listener[listenerNum]
// make sure ip is allocated // make sure ip is allocated
@@ -123,6 +126,22 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c
} }
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg) *dns.Msg { func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg) *dns.Msg {
var staleAnswer *dns.Msg
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
if cachedValue := p.cache.Get(dnscache.NewKey(msg)); cachedValue != nil {
answer := cachedValue.Msg.Copy()
answer.SetReply(msg)
now := time.Now()
if cachedValue.Expire.After(now) {
ctrld.Log(ctx, proxyLog.Debug(), "hit cached response")
setCachedAnswerTTL(answer, now, cachedValue.Expire)
return answer
}
staleAnswer = answer
}
}
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
@@ -148,12 +167,29 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
for n, upstreamConfig := range upstreamConfigs { for n, upstreamConfig := range upstreamConfigs {
answer := resolve(n, upstreamConfig, msg) answer := resolve(n, upstreamConfig, msg)
if answer == nil { if answer == nil {
if serveStaleCache && staleAnswer != nil {
ctrld.Log(ctx, proxyLog.Debug(), "serving stale cached response")
now := time.Now()
setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL))
return staleAnswer
}
continue continue
} }
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) { if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) {
ctrld.Log(ctx, proxyLog.Debug(), "failover rcode matched, process to next upstream") ctrld.Log(ctx, proxyLog.Debug(), "failover rcode matched, process to next upstream")
continue continue
} }
if p.cache != nil {
ttl := ttlFromMsg(answer)
now := time.Now()
expired := now.Add(time.Duration(ttl) * time.Second)
if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 {
expired = now.Add(time.Duration(cachedTTL) * time.Second)
}
setCachedAnswerTTL(answer, now, expired)
p.cache.Add(dnscache.NewKey(msg), dnscache.NewValue(answer, expired))
ctrld.Log(ctx, proxyLog.Debug(), "add cached response")
}
return answer return answer
} }
ctrld.Log(ctx, proxyLog.Error(), "all upstreams failed") ctrld.Log(ctx, proxyLog.Error(), "all upstreams failed")
@@ -229,6 +265,35 @@ func containRcode(rcodes []int, rcode int) bool {
return false return false
} }
func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) {
ttl := uint32(expiredTime.Sub(now).Seconds())
if ttl < 0 {
return
}
for _, rr := range answer.Answer {
rr.Header().Ttl = ttl
}
for _, rr := range answer.Ns {
rr.Header().Ttl = ttl
}
for _, rr := range answer.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
rr.Header().Ttl = ttl
}
}
}
func ttlFromMsg(msg *dns.Msg) uint32 {
for _, rr := range msg.Answer {
return rr.Header().Ttl
}
for _, rr := range msg.Ns {
return rr.Header().Ttl
}
return 0
}
var osUpstreamConfig = &ctrld.UpstreamConfig{ var osUpstreamConfig = &ctrld.UpstreamConfig{
Name: "OS resolver", Name: "OS resolver",
Type: "os", Type: "os",

View File

@@ -12,6 +12,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
) )
var errWindowsAddrInUse = syscall.Errno(0x2740) var errWindowsAddrInUse = syscall.Errno(0x2740)
@@ -22,7 +23,8 @@ var svcConfig = &service.Config{
} }
type prog struct { type prog struct {
cfg *ctrld.Config cfg *ctrld.Config
cache dnscache.Cacher
} }
func (p *prog) Start(s service.Service) error { func (p *prog) Start(s service.Service) error {
@@ -32,6 +34,14 @@ func (p *prog) Start(s service.Service) error {
} }
func (p *prog) run() { func (p *prog) run() {
if p.cfg.Service.CacheEnable {
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
if err != nil {
mainLog.Error().Err(err).Msg("failed to create cacher, caching is disabled")
} else {
p.cache = cacher
}
}
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(p.cfg.Listener)) wg.Add(len(p.cfg.Listener))

View File

@@ -64,10 +64,14 @@ type Config struct {
// ServiceConfig specifies the general ctrld config. // ServiceConfig specifies the general ctrld config.
type ServiceConfig struct { type ServiceConfig struct {
LogLevel string `mapstructure:"log_level" toml:"log_level"` LogLevel string `mapstructure:"log_level" toml:"log_level"`
LogPath string `mapstructure:"log_path" toml:"log_path"` LogPath string `mapstructure:"log_path" toml:"log_path"`
Daemon bool `mapstructure:"-" toml:"-"` CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable"`
AllocateIP bool `mapstructure:"-" toml:"-"` CacheSize int `mapstructure:"cache_size" toml:"cache_size"`
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override"`
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale"`
Daemon bool `mapstructure:"-" toml:"-"`
AllocateIP bool `mapstructure:"-" toml:"-"`
} }
// NetworkConfig specifies configuration for networks where ctrld will handle requests. // NetworkConfig specifies configuration for networks where ctrld will handle requests.
@@ -179,11 +183,8 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
if err != nil { if err != nil {
return nil, err return nil, err
} }
localAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
if strings.Index(uc.BootstrapIP, ":") > -1 { udpConn, err := net.ListenUDP("udp", nil)
localAddr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
}
udpConn, err := net.ListenUDP("udp", localAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -33,6 +33,8 @@ If no configuration files found, a default `config.toml` file will be created in
[service] [service]
log_level = "info" log_level = "info"
log_path = "" log_path = ""
cache_enable = true
cache_size = 4096
[network.0] [network.0]
cidrs = ["0.0.0.0/0"] cidrs = ["0.0.0.0/0"]
@@ -109,6 +111,31 @@ Relative or absolute path of the log file.
- Type: string - Type: string
- Required: no - Required: no
### cache_enable
When `cache_enable = true`, all resolved DNS query responses will be cached for duration of the upstream record TTLs.
- Type: boolean
- Required: no
### cache_size
The number of cached records, must be a positive integer. Tweaking this value with care depends on your available RAM.
A minimum value `4096` should be enough for most use cases.
An invalid `cache_size` value will disable the cache, regardless of `cache_enable` value.
- Type: int
- Required: no
### cache_ttl_override
When `cache_ttl_override` is set to a positive value (in seconds), TTLs are overridden to this value and cached for this long.
- Type: int
- Required: no
### cache_serve_stale
When `cache_serve_stale = true`, in cases of upstream failures (upstreams not reachable), `ctrld` will keep serving
stale cached records (regardless of their TTLs) until upstream comes online.
The above config will look like this at query time. The above config will look like this at query time.
``` ```

1
go.mod
View File

@@ -4,6 +4,7 @@ go 1.19
require ( require (
github.com/go-playground/validator/v10 v10.11.1 github.com/go-playground/validator/v10 v10.11.1
github.com/hashicorp/golang-lru/v2 v2.0.1
github.com/kardianos/service v1.2.1 github.com/kardianos/service v1.2.1
github.com/lucas-clemente/quic-go v0.29.1 github.com/lucas-clemente/quic-go v0.29.1
github.com/miekg/dns v1.1.50 github.com/miekg/dns v1.1.50

2
go.sum
View File

@@ -112,6 +112,8 @@ github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/b
github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru/v2 v2.0.1 h1:5pv5N1lT1fjLg2VQ5KWc7kmucp2x/kvFOnxuVTqZ6x4=
github.com/hashicorp/golang-lru/v2 v2.0.1/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64=

View File

@@ -0,0 +1,76 @@
package dnscache
import (
"strings"
"time"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/miekg/dns"
)
// Cacher is the interface for caching DNS response.
type Cacher interface {
Get(Key) *Value
Add(Key, *Value)
}
// Key is the caching key for DNS message.
type Key struct {
Qtype uint16
Qclass uint16
Name string
}
type Value struct {
Expire time.Time
Msg *dns.Msg
}
var _ Cacher = (*LRUCache)(nil)
// LRUCache implements Cacher interface.
type LRUCache struct {
cacher *lru.ARCCache[Key, *Value]
}
func (l *LRUCache) Get(key Key) *Value {
v, _ := l.cacher.Get(key)
return v
}
func (l *LRUCache) Add(key Key, value *Value) {
l.cacher.Add(key, value)
}
// NewLRUCache creates a new LRUCache instance with given size.
func NewLRUCache(size int) (*LRUCache, error) {
cacher, err := lru.NewARC[Key, *Value](size)
return &LRUCache{cacher: cacher}, err
}
// NewKey creates a new cache key for given DNS message.
func NewKey(msg *dns.Msg) Key {
q := msg.Question[0]
return Key{Qtype: q.Qtype, Qclass: q.Qclass, Name: normalizeQname(q.Name)}
}
// NewValue creates a new cache value for given DNS message.
func NewValue(msg *dns.Msg, expire time.Time) *Value {
return &Value{
Expire: expire,
Msg: msg,
}
}
func normalizeQname(name string) string {
var b strings.Builder
b.Grow(len(name))
for i := 0; i < len(name); i++ {
c := name[i]
if 'A' <= c && c <= 'Z' {
c += 'a' - 'A'
}
b.WriteByte(c)
}
return b.String()
}