all: make cache scope to upstream

This commit is contained in:
Cuong Manh Le
2023-01-31 03:16:41 +07:00
committed by Cuong Manh Le
parent 06372031b5
commit 4ea1e64795
3 changed files with 55 additions and 13 deletions

View File

@@ -143,11 +143,20 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg) *dns.Msg {
var staleAnswer *dns.Msg
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
if len(upstreamConfigs) == 0 {
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
upstreams = []string{"upstream.os"}
}
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
if cachedValue := p.cache.Get(dnscache.NewKey(msg)); cachedValue != nil {
for _, upstream := range upstreams {
cachedValue := p.cache.Get(dnscache.NewKey(msg, upstream))
if cachedValue == nil {
continue
}
answer := cachedValue.Msg.Copy()
answer.SetReply(msg)
answer.SetRcode(msg, answer.Rcode)
now := time.Now()
if cachedValue.Expire.After(now) {
ctrld.Log(ctx, proxyLog.Debug(), "hit cached response")
@@ -157,11 +166,6 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
staleAnswer = answer
}
}
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
if len(upstreamConfigs) == 0 {
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
upstreams = []string{"upstream.os"}
}
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
@@ -206,7 +210,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
expired = now.Add(time.Duration(cachedTTL) * time.Second)
}
setCachedAnswerTTL(answer, now, expired)
p.cache.Add(dnscache.NewKey(msg), dnscache.NewValue(answer, expired))
p.cache.Add(dnscache.NewKey(msg, upstreams[n]), dnscache.NewValue(answer, expired))
ctrld.Log(ctx, proxyLog.Debug(), "add cached response")
}
return answer

View File

@@ -4,11 +4,14 @@ import (
"context"
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
"github.com/Control-D-Inc/ctrld/testhelper"
)
@@ -115,3 +118,37 @@ func Test_prog_upstreamFor(t *testing.T) {
})
}
}
func TestCache(t *testing.T) {
cfg := testhelper.SampleConfig(t)
prog := &prog{cfg: cfg}
for _, nc := range prog.cfg.Network {
for _, cidr := range nc.Cidrs {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
t.Fatal(err)
}
nc.IPNets = append(nc.IPNets, ipNet)
}
}
cacher, err := dnscache.NewLRUCache(4096)
require.NoError(t, err)
prog.cache = cacher
msg := new(dns.Msg)
msg.SetQuestion("example.com", dns.TypeA)
msg.MsgHdr.RecursionDesired = true
answer1 := new(dns.Msg)
answer1.SetRcode(msg, dns.RcodeSuccess)
prog.cache.Add(dnscache.NewKey(msg, "upstream.1"), dnscache.NewValue(answer1, time.Now().Add(time.Minute)))
answer2 := new(dns.Msg)
answer2.SetRcode(msg, dns.RcodeRefused)
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg)
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg)
assert.NotSame(t, got1, got2)
assert.Equal(t, answer1.Rcode, got1.Rcode)
assert.Equal(t, answer2.Rcode, got2.Rcode)
}

View File

@@ -16,9 +16,10 @@ type Cacher interface {
// Key is the caching key for DNS message.
type Key struct {
Qtype uint16
Qclass uint16
Name string
Qtype uint16
Qclass uint16
Name string
Upstream string
}
type Value struct {
@@ -49,9 +50,9 @@ func NewLRUCache(size int) (*LRUCache, error) {
}
// NewKey creates a new cache key for given DNS message.
func NewKey(msg *dns.Msg) Key {
func NewKey(msg *dns.Msg, upstream string) Key {
q := msg.Question[0]
return Key{Qtype: q.Qtype, Qclass: q.Qclass, Name: normalizeQname(q.Name)}
return Key{Qtype: q.Qtype, Qclass: q.Qclass, Name: normalizeQname(q.Name), Upstream: upstream}
}
// NewValue creates a new cache value for given DNS message.