mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
all: optimizing multiple queries to upstreams
To guard ctrld from possible DoS to remote upstreams, this commit implements the following: - Optimizing multiple queries with the same domain and qtype to use a singleflight group, so there's only 1 query to remote upstreams at any time. - Adding a hot cache with a 1 second TTL, so repeated queries will re-use the result from the cache if it exists, preventing unnecessary requests to remote upstreams.
This commit is contained in:
committed by
Cuong Manh Le
parent
62f73bcaa2
commit
a983dfaee2
83
resolver.go
83
resolver.go
@@ -9,12 +9,14 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
@@ -216,6 +218,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
type osResolver struct {
|
||||
lanServers atomic.Pointer[[]string]
|
||||
publicServers atomic.Pointer[[]string]
|
||||
group *singleflight.Group
|
||||
cache *sync.Map
|
||||
}
|
||||
|
||||
type osResolverResult struct {
|
||||
@@ -273,10 +277,75 @@ func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desired
|
||||
return dnsClient.ExchangeContext(ctx, msg, server)
|
||||
}
|
||||
|
||||
// hotCacheTTL is how long an answer is kept in the osResolver hot cache
// before a timer removes it again.
const hotCacheTTL = time.Second
|
||||
|
||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||
// Query is sent to all nameservers concurrently, and the first
|
||||
// The Query is sent to all nameservers concurrently, and the first
|
||||
// success response will be returned.
|
||||
//
|
||||
// To guard against unexpected DoS to upstreams, multiple queries of
|
||||
// the same Qtype to a domain will be shared, so there's only 1 qps
|
||||
// for each upstream at any time.
|
||||
//
|
||||
// Further, a hot cache will be used, so repeated queries will be cached
|
||||
// for a short period (currently 1 second), reducing unnecessary traffics
|
||||
// sent to upstreams.
|
||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
if len(msg.Question) == 0 {
|
||||
return nil, errors.New("no question found")
|
||||
}
|
||||
domain := strings.TrimSuffix(msg.Question[0].Name, ".")
|
||||
qtype := msg.Question[0].Qtype
|
||||
|
||||
// Unique key for the singleflight group.
|
||||
key := fmt.Sprintf("%s:%d:", domain, qtype)
|
||||
|
||||
// Checking the cache first.
|
||||
if val, ok := o.cache.Load(key); ok {
|
||||
if val, ok := val.(*dns.Msg); ok {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype])
|
||||
res := val.Copy()
|
||||
res.SetRcode(msg, val.Rcode)
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure only one DNS query is in flight for the key.
|
||||
v, err, shared := o.group.Do(key, func() (interface{}, error) {
|
||||
msg, err := o.resolve(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If we got an answer, storing it to the hot cache for hotCacheTTL
|
||||
// This prevents possible DoS to upstream, ensuring there's only 1 QPS.
|
||||
o.cache.Store(key, msg)
|
||||
// Depends on go runtime scheduling, the result may end up in hot cache longer
|
||||
// than hotCacheTTL duration. However, this is fine since we only want to guard
|
||||
// against DoS attack. The result will be cleaned from the cache eventually.
|
||||
time.AfterFunc(hotCacheTTL, func() {
|
||||
o.removeCache(key)
|
||||
})
|
||||
return msg, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sharedMsg, ok := v.(*dns.Msg)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid answer for key: %s", key)
|
||||
}
|
||||
res := sharedMsg.Copy()
|
||||
res.SetRcode(msg, sharedMsg.Rcode)
|
||||
if shared {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype])
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// resolve sends the query to current nameservers.
|
||||
func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
publicServers := *o.publicServers.Load()
|
||||
var nss []string
|
||||
if p := o.lanServers.Load(); p != nil {
|
||||
@@ -431,6 +500,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (o *osResolver) removeCache(key string) {
|
||||
o.cache.Delete(key)
|
||||
}
|
||||
|
||||
type legacyResolver struct {
|
||||
uc *UpstreamConfig
|
||||
}
|
||||
@@ -627,10 +700,10 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
|
||||
// newResolverWithNameserver returns an OS resolver from given nameservers list.
|
||||
// The caller must ensure each server in list is formed "ip:53".
|
||||
func newResolverWithNameserver(nameservers []string) *osResolver {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers)
|
||||
r := &osResolver{}
|
||||
r := &osResolver{
|
||||
group: &singleflight.Group{},
|
||||
cache: &sync.Map{},
|
||||
}
|
||||
var publicNss []string
|
||||
var lanNss []string
|
||||
for _, ns := range slices.Sorted(slices.Values(nameservers)) {
|
||||
|
||||
120
resolver_test.go
120
resolver_test.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,8 +17,7 @@ func Test_osResolver_Resolve(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServers.Store(&[]string{"127.0.0.127:5353"})
|
||||
resolver := newResolverWithNameserver([]string{"127.0.0.127:5353"})
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("controld.com.", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
@@ -50,8 +50,7 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) {
|
||||
t.Error("not a LAN query")
|
||||
return
|
||||
}
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServers.Store(&[]string{"76.76.2.0:53"})
|
||||
resolver := newResolverWithNameserver([]string{"76.76.2.0:53"})
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("controld.com.", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
@@ -107,11 +106,9 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
}()
|
||||
|
||||
// We now create an osResolver which has both a LAN and public nameserver.
|
||||
resolver := &osResolver{}
|
||||
// Explicitly store the LAN nameserver.
|
||||
resolver.lanServers.Store(&[]string{lanAddr})
|
||||
// And store the public nameservers.
|
||||
resolver.publicServers.Store(&publicNS)
|
||||
nss := []string{lanAddr}
|
||||
nss = append(nss, publicNS...)
|
||||
resolver := newResolverWithNameserver(nss)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
@@ -139,6 +136,102 @@ func Test_osResolver_InitializationRace(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func Test_osResolver_Singleflight(t *testing.T) {
|
||||
lanPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on LAN address: %v", err)
|
||||
}
|
||||
call := &atomic.Int64{}
|
||||
lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run LAN test server: %v", err)
|
||||
}
|
||||
defer lanServer.Shutdown()
|
||||
|
||||
or := newResolverWithNameserver([]string{lanAddr})
|
||||
domain := "controld.com"
|
||||
n := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
_, err := or.Resolve(context.Background(), m)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// All above queries should only make 1 call to server.
|
||||
if call.Load() != 1 {
|
||||
t.Fatalf("expected 1 result from singleflight lookup, got %d", call)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_osResolver_HotCache(t *testing.T) {
|
||||
lanPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on LAN address: %v", err)
|
||||
}
|
||||
call := &atomic.Int64{}
|
||||
lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run LAN test server: %v", err)
|
||||
}
|
||||
defer lanServer.Shutdown()
|
||||
|
||||
or := newResolverWithNameserver([]string{lanAddr})
|
||||
domain := "controld.com"
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
|
||||
// Make 2 repeated queries to server, should hit hot cache.
|
||||
for i := 0; i < 2; i++ {
|
||||
if _, err := or.Resolve(context.Background(), m.Copy()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if call.Load() != 1 {
|
||||
t.Fatalf("cache not hit, server was called: %d", call.Load())
|
||||
}
|
||||
|
||||
timeoutChan := make(chan struct{})
|
||||
time.AfterFunc(5*time.Second, func() {
|
||||
close(timeoutChan)
|
||||
})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatal("timed out waiting for cache cleaned")
|
||||
default:
|
||||
count := 0
|
||||
or.cache.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count != 0 {
|
||||
t.Logf("hot cache is not empty: %d elements", count)
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if _, err := or.Resolve(context.Background(), m.Copy()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if call.Load() != 2 {
|
||||
t.Fatal("cache hit unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -208,3 +301,12 @@ func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc {
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
}
|
||||
|
||||
func countHandler(call *atomic.Int64) dns.HandlerFunc {
|
||||
return func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(msg, dns.RcodeSuccess)
|
||||
w.WriteMsg(m)
|
||||
call.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user