diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 9513e45..080bebc 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -275,7 +275,13 @@ macRules: } func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg { - ip := ipFromARPA(msg.Question[0].Name) + cDomainName := msg.Question[0].Name + locked := p.ptrLoopGuard.TryLock(cDomainName) + defer p.ptrLoopGuard.Unlock(cDomainName) + if !locked { + return nil + } + ip := ipFromARPA(cDomainName) if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" { answer := new(dns.Msg) answer.SetReply(msg) @@ -302,6 +308,11 @@ func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg { q := msg.Question[0] hostname := strings.TrimSuffix(q.Name, ".") + locked := p.lanLoopGuard.TryLock(hostname) + defer p.lanLoopGuard.Unlock(hostname) + if !locked { + return nil + } if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil { answer := new(dns.Msg) answer.SetReply(msg) diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 281d59c..82c4f63 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -69,6 +69,8 @@ func Test_prog_upstreamFor(t *testing.T) { cfg := testhelper.SampleConfig(t) p := &prog{cfg: cfg} p.um = newUpstreamMonitor(p.cfg) + p.lanLoopGuard = newLoopGuard() + p.ptrLoopGuard = newLoopGuard() for _, nc := range p.cfg.Network { for _, cidr := range nc.Cidrs { _, ipNet, err := net.ParseCIDR(cidr) diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index ec25840..06a7e03 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -3,6 +3,7 @@ package cli import ( "context" "strings" + "sync" "time" "github.com/miekg/dns" @@ -15,6 +16,36 @@ const ( loopTestQtype = dns.TypeTXT ) +// newLoopGuard returns new loopGuard. +func newLoopGuard() *loopGuard { + return &loopGuard{inflight: make(map[string]struct{})} +} + +// loopGuard guards against DNS loop, ensuring only one query +// for a given domain is processed at a time. +type loopGuard struct { + mu sync.Mutex + inflight map[string]struct{} +} + +// TryLock marks the domain as being processed. +func (lg *loopGuard) TryLock(domain string) bool { + lg.mu.Lock() + defer lg.mu.Unlock() + if _, inflight := lg.inflight[domain]; !inflight { + lg.inflight[domain] = struct{}{} + return true + } + return false +} + +// Unlock marks the domain as being done. +func (lg *loopGuard) Unlock(domain string) { + lg.mu.Lock() + defer lg.mu.Unlock() + delete(lg.inflight, domain) +} + // isLoop reports whether the given upstream config is detected as having DNS loop. func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool { p.loopMu.Lock() diff --git a/cmd/cli/loop_test.go b/cmd/cli/loop_test.go new file mode 100644 index 0000000..e8cfb2a --- /dev/null +++ b/cmd/cli/loop_test.go @@ -0,0 +1,38 @@ +package cli + +import ( + "sync" + "testing" +) + +func Test_loopGuard(t *testing.T) { + lg := newLoopGuard() + key := "foo" + + var mu sync.Mutex + i := 0 + n := 1000 + do := func() { + locked := lg.TryLock(key) + defer lg.Unlock(key) + if locked { + mu.Lock() + i++ + mu.Unlock() + } + } + + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + do() + }() + } + wg.Wait() + + if i == n { + t.Fatalf("i must not be increased %d times", n) + } +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 9fcb42f..55dfafc 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -66,6 +66,8 @@ type prog struct { ciTable *clientinfo.Table um *upstreamMonitor router router.Router + ptrLoopGuard *loopGuard + lanLoopGuard *loopGuard loopMu sync.Mutex loop map[string]bool @@ -236,6 +238,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } p.onStartedDone = make(chan struct{}) p.loop = make(map[string]bool) + p.lanLoopGuard = newLoopGuard() + p.ptrLoopGuard = newLoopGuard() if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil {