From 4ea1e6479526fa3d1e07bcb6a616dee9d3522f1b Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 31 Jan 2023 03:16:41 +0700 Subject: [PATCH] all: make cache scope to upstream --- cmd/ctrld/dns_proxy.go | 20 ++++++++++++-------- cmd/ctrld/dns_proxy_test.go | 37 +++++++++++++++++++++++++++++++++++++ internal/dnscache/cache.go | 11 ++++++----- 3 files changed, 55 insertions(+), 13 deletions(-) diff --git a/cmd/ctrld/dns_proxy.go b/cmd/ctrld/dns_proxy.go index bd5176b..4cdfab0 100644 --- a/cmd/ctrld/dns_proxy.go +++ b/cmd/ctrld/dns_proxy.go @@ -143,11 +143,20 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg) *dns.Msg { var staleAnswer *dns.Msg serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale + upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) + if len(upstreamConfigs) == 0 { + upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + upstreams = []string{"upstream.os"} + } // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { - if cachedValue := p.cache.Get(dnscache.NewKey(msg)); cachedValue != nil { + for _, upstream := range upstreams { + cachedValue := p.cache.Get(dnscache.NewKey(msg, upstream)) + if cachedValue == nil { + continue + } answer := cachedValue.Msg.Copy() - answer.SetReply(msg) + answer.SetRcode(msg, answer.Rcode) now := time.Now() if cachedValue.Expire.After(now) { ctrld.Log(ctx, proxyLog.Debug(), "hit cached response") @@ -157,11 +166,6 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i staleAnswer = answer } } - upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) - if len(upstreamConfigs) == 0 { - upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - upstreams = []string{"upstream.os"} - } resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(upstreamConfig) @@ -206,7 +210,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i expired = now.Add(time.Duration(cachedTTL) * time.Second) } setCachedAnswerTTL(answer, now, expired) - p.cache.Add(dnscache.NewKey(msg), dnscache.NewValue(answer, expired)) + p.cache.Add(dnscache.NewKey(msg, upstreams[n]), dnscache.NewValue(answer, expired)) ctrld.Log(ctx, proxyLog.Debug(), "add cached response") } return answer diff --git a/cmd/ctrld/dns_proxy_test.go b/cmd/ctrld/dns_proxy_test.go index 8435f7d..82c0c95 100644 --- a/cmd/ctrld/dns_proxy_test.go +++ b/cmd/ctrld/dns_proxy_test.go @@ -4,11 +4,14 @@ import ( "context" "net" "testing" + "time" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/dnscache" "github.com/Control-D-Inc/ctrld/testhelper" ) @@ -115,3 +118,37 @@ func Test_prog_upstreamFor(t *testing.T) { }) } } + +func TestCache(t *testing.T) { + cfg := testhelper.SampleConfig(t) + prog := &prog{cfg: cfg} + for _, nc := range prog.cfg.Network { + for _, cidr := range nc.Cidrs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + t.Fatal(err) + } + nc.IPNets = append(nc.IPNets, ipNet) + } + } + cacher, err := dnscache.NewLRUCache(4096) + require.NoError(t, err) + prog.cache = cacher + + msg := new(dns.Msg) + msg.SetQuestion("example.com", dns.TypeA) + msg.MsgHdr.RecursionDesired = true + answer1 := new(dns.Msg) + answer1.SetRcode(msg, dns.RcodeSuccess) + + prog.cache.Add(dnscache.NewKey(msg, "upstream.1"), dnscache.NewValue(answer1, time.Now().Add(time.Minute))) + answer2 := new(dns.Msg) + answer2.SetRcode(msg, dns.RcodeRefused) + prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute))) + + got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg) + got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg) + assert.NotSame(t, got1, got2) + assert.Equal(t, answer1.Rcode, got1.Rcode) + assert.Equal(t, answer2.Rcode, got2.Rcode) +} diff --git a/internal/dnscache/cache.go b/internal/dnscache/cache.go index efbd8e3..4aa7f69 100644 --- a/internal/dnscache/cache.go +++ b/internal/dnscache/cache.go @@ -16,9 +16,10 @@ type Cacher interface { // Key is the caching key for DNS message. type Key struct { - Qtype uint16 - Qclass uint16 - Name string + Qtype uint16 + Qclass uint16 + Name string + Upstream string } type Value struct { @@ -49,9 +50,9 @@ func NewLRUCache(size int) (*LRUCache, error) { } // NewKey creates a new cache key for given DNS message. -func NewKey(msg *dns.Msg) Key { +func NewKey(msg *dns.Msg, upstream string) Key { q := msg.Question[0] - return Key{Qtype: q.Qtype, Qclass: q.Qclass, Name: normalizeQname(q.Name)} + return Key{Qtype: q.Qtype, Qclass: q.Qclass, Name: normalizeQname(q.Name), Upstream: upstream} } // NewValue creates a new cache value for given DNS message.