all: optimizing multiple queries to upstreams

To guard ctrld from possible DoS to remote upstreams, this commit
implements following things:

 - Optimizing multiple queries with the same domain and qtype to use
   singleflight group, so there's only 1 query to remote upstreams at
   any time.
 - Adding a hot cache with 1 second TTL, so repeated queries will re-use
   the result from cache if existed, preventing unnecessary requests to
   remote upstreams.
This commit is contained in:
Cuong Manh Le
2025-05-21 19:33:54 +07:00
committed by Cuong Manh Le
parent 62f73bcaa2
commit a983dfaee2
2 changed files with 189 additions and 14 deletions

View File

@@ -9,12 +9,14 @@ import (
"net/netip"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
"github.com/rs/zerolog"
"golang.org/x/sync/singleflight"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
)
@@ -216,6 +218,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
type osResolver struct {
lanServers atomic.Pointer[[]string]
publicServers atomic.Pointer[[]string]
group *singleflight.Group
cache *sync.Map
}
type osResolverResult struct {
@@ -273,10 +277,75 @@ func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desired
return dnsClient.ExchangeContext(ctx, msg, server)
}
const hotCacheTTL = time.Second
// Resolve resolves DNS queries using pre-configured nameservers.
// Query is sent to all nameservers concurrently, and the first
// The Query is sent to all nameservers concurrently, and the first
// success response will be returned.
//
// To guard against unexpected DoS to upstreams, multiple queries of
// the same Qtype to a domain will be shared, so there's only 1 qps
// for each upstream at any time.
//
// Further, a hot cache will be used, so repeated queries will be cached
// for a short period (currently 1 second), reducing unnecessary traffics
// sent to upstreams.
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
if len(msg.Question) == 0 {
return nil, errors.New("no question found")
}
domain := strings.TrimSuffix(msg.Question[0].Name, ".")
qtype := msg.Question[0].Qtype
// Unique key for the singleflight group.
key := fmt.Sprintf("%s:%d:", domain, qtype)
// Checking the cache first.
if val, ok := o.cache.Load(key); ok {
if val, ok := val.(*dns.Msg); ok {
Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype])
res := val.Copy()
res.SetRcode(msg, val.Rcode)
return res, nil
}
}
// Ensure only one DNS query is in flight for the key.
v, err, shared := o.group.Do(key, func() (interface{}, error) {
msg, err := o.resolve(ctx, msg)
if err != nil {
return nil, err
}
// If we got an answer, storing it to the hot cache for hotCacheTTL
// This prevents possible DoS to upstream, ensuring there's only 1 QPS.
o.cache.Store(key, msg)
// Depends on go runtime scheduling, the result may end up in hot cache longer
// than hotCacheTTL duration. However, this is fine since we only want to guard
// against DoS attack. The result will be cleaned from the cache eventually.
time.AfterFunc(hotCacheTTL, func() {
o.removeCache(key)
})
return msg, nil
})
if err != nil {
return nil, err
}
sharedMsg, ok := v.(*dns.Msg)
if !ok {
return nil, fmt.Errorf("invalid answer for key: %s", key)
}
res := sharedMsg.Copy()
res.SetRcode(msg, sharedMsg.Rcode)
if shared {
Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype])
}
return res, nil
}
// resolve sends the query to current nameservers.
func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
publicServers := *o.publicServers.Load()
var nss []string
if p := o.lanServers.Load(); p != nil {
@@ -431,6 +500,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
return nil, errors.Join(errs...)
}
func (o *osResolver) removeCache(key string) {
o.cache.Delete(key)
}
type legacyResolver struct {
uc *UpstreamConfig
}
@@ -627,10 +700,10 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
// newResolverWithNameserver returns an OS resolver from given nameservers list.
// The caller must ensure each server in list is formed "ip:53".
func newResolverWithNameserver(nameservers []string) *osResolver {
logger := *ProxyLogger.Load()
Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers)
r := &osResolver{}
r := &osResolver{
group: &singleflight.Group{},
cache: &sync.Map{},
}
var publicNss []string
var lanNss []string
for _, ns := range slices.Sorted(slices.Values(nameservers)) {