diff --git a/resolver.go b/resolver.go index a44ddb2..52515f9 100644 --- a/resolver.go +++ b/resolver.go @@ -9,12 +9,14 @@ import ( "net/netip" "runtime" "slices" + "strings" "sync" "sync/atomic" "time" "github.com/miekg/dns" "github.com/rs/zerolog" + "golang.org/x/sync/singleflight" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" ) @@ -216,6 +218,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { type osResolver struct { lanServers atomic.Pointer[[]string] publicServers atomic.Pointer[[]string] + group *singleflight.Group + cache *sync.Map } type osResolverResult struct { @@ -273,10 +277,75 @@ func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desired return dnsClient.ExchangeContext(ctx, msg, server) } +const hotCacheTTL = time.Second + // Resolve resolves DNS queries using pre-configured nameservers. -// Query is sent to all nameservers concurrently, and the first +// The Query is sent to all nameservers concurrently, and the first // success response will be returned. +// +// To guard against unexpected DoS to upstreams, multiple queries of +// the same Qtype to a domain will be shared, so there's only 1 qps +// for each upstream at any time. +// +// Further, a hot cache will be used, so repeated queries will be cached +// for a short period (currently 1 second), reducing unnecessary traffics +// sent to upstreams. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if len(msg.Question) == 0 { + return nil, errors.New("no question found") + } + domain := strings.TrimSuffix(msg.Question[0].Name, ".") + qtype := msg.Question[0].Qtype + + // Unique key for the singleflight group. + key := fmt.Sprintf("%s:%d:", domain, qtype) + + // Checking the cache first. + if val, ok := o.cache.Load(key); ok { + if val, ok := val.(*dns.Msg); ok { + Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) + res := val.Copy() + res.SetRcode(msg, val.Rcode) + return res, nil + } + } + + // Ensure only one DNS query is in flight for the key. + v, err, shared := o.group.Do(key, func() (interface{}, error) { + msg, err := o.resolve(ctx, msg) + if err != nil { + return nil, err + } + // If we got an answer, storing it to the hot cache for hotCacheTTL + // This prevents possible DoS to upstream, ensuring there's only 1 QPS. + o.cache.Store(key, msg) + // Depends on go runtime scheduling, the result may end up in hot cache longer + // than hotCacheTTL duration. However, this is fine since we only want to guard + // against DoS attack. The result will be cleaned from the cache eventually. + time.AfterFunc(hotCacheTTL, func() { + o.removeCache(key) + }) + return msg, nil + }) + if err != nil { + return nil, err + } + + sharedMsg, ok := v.(*dns.Msg) + if !ok { + return nil, fmt.Errorf("invalid answer for key: %s", key) + } + res := sharedMsg.Copy() + res.SetRcode(msg, sharedMsg.Rcode) + if shared { + Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) + } + + return res, nil +} + +// resolve sends the query to current nameservers. +func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { publicServers := *o.publicServers.Load() var nss []string if p := o.lanServers.Load(); p != nil { @@ -431,6 +500,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error return nil, errors.Join(errs...) } +func (o *osResolver) removeCache(key string) { + o.cache.Delete(key) +} + type legacyResolver struct { uc *UpstreamConfig } @@ -627,10 +700,10 @@ func NewResolverWithNameserver(nameservers []string) Resolver { // newResolverWithNameserver returns an OS resolver from given nameservers list. // The caller must ensure each server in list is formed "ip:53". func newResolverWithNameserver(nameservers []string) *osResolver { - logger := *ProxyLogger.Load() - - Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers) - r := &osResolver{} + r := &osResolver{ + group: &singleflight.Group{}, + cache: &sync.Map{}, + } var publicNss []string var lanNss []string for _, ns := range slices.Sorted(slices.Values(nameservers)) { diff --git a/resolver_test.go b/resolver_test.go index fb6831b..a75e748 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -4,6 +4,7 @@ import ( "context" "net" "sync" + "sync/atomic" "testing" "time" @@ -16,8 +17,7 @@ func Test_osResolver_Resolve(t *testing.T) { go func() { defer cancel() - resolver := &osResolver{} - resolver.publicServers.Store(&[]string{"127.0.0.127:5353"}) + resolver := newResolverWithNameserver([]string{"127.0.0.127:5353"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -50,8 +50,7 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) { t.Error("not a LAN query") return } - resolver := &osResolver{} - resolver.publicServers.Store(&[]string{"76.76.2.0:53"}) + resolver := newResolverWithNameserver([]string{"76.76.2.0:53"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -107,11 +106,9 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { }() // We now create an osResolver which has both a LAN and public nameserver. - resolver := &osResolver{} - // Explicitly store the LAN nameserver. - resolver.lanServers.Store(&[]string{lanAddr}) - // And store the public nameservers. - resolver.publicServers.Store(&publicNS) + nss := []string{lanAddr} + nss = append(nss, publicNS...) + resolver := newResolverWithNameserver(nss) msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) @@ -139,6 +136,102 @@ func Test_osResolver_InitializationRace(t *testing.T) { wg.Wait() } +func Test_osResolver_Singleflight(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + call := &atomic.Int64{} + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + or := newResolverWithNameserver([]string{lanAddr}) + domain := "controld.com" + n := 10 + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + _, err := or.Resolve(context.Background(), m) + if err != nil { + t.Error(err) + } + }() + } + wg.Wait() + + // All above queries should only make 1 call to server. + if call.Load() != 1 { + t.Fatalf("expected 1 result from singleflight lookup, got %d", call) + } +} + +func Test_osResolver_HotCache(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + call := &atomic.Int64{} + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + or := newResolverWithNameserver([]string{lanAddr}) + domain := "controld.com" + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + + // Make 2 repeated queries to server, should hit hot cache. + for i := 0; i < 2; i++ { + if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { + t.Fatal(err) + } + } + if call.Load() != 1 { + t.Fatalf("cache not hit, server was called: %d", call.Load()) + } + + timeoutChan := make(chan struct{}) + time.AfterFunc(5*time.Second, func() { + close(timeoutChan) + }) + + for { + select { + case <-timeoutChan: + t.Fatal("timed out waiting for cache cleaned") + default: + count := 0 + or.cache.Range(func(key, value interface{}) bool { + count++ + return true + }) + if count != 0 { + t.Logf("hot cache is not empty: %d elements", count) + continue + } + } + break + } + + if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { + t.Fatal(err) + } + if call.Load() != 2 { + t.Fatal("cache hit unexpectedly") + } +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string @@ -208,3 +301,12 @@ func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc { w.WriteMsg(m) } } + +func countHandler(call *atomic.Int64) dns.HandlerFunc { + return func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, dns.RcodeSuccess) + w.WriteMsg(m) + call.Add(1) + } +}