diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 4cc4f29..445ae70 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -50,6 +50,7 @@ func (p *prog) serveDNS(listenerNum string) error { handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { p.sema.acquire() defer p.sema.release() + go p.detectLoop(m) q := m.Question[0] domain := canonicalName(q.Name) reqId := requestID() @@ -287,6 +288,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i if upstreamConfig == nil { continue } + if p.isLoop(upstreamConfig) { + mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint) + continue + } if p.um.isDown(upstreams[n]) { ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n]) continue diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go new file mode 100644 index 0000000..87dabf8 --- /dev/null +++ b/cmd/cli/loop.go @@ -0,0 +1,100 @@ +package cli + +import ( + "context" + "strings" + "time" + + "github.com/miekg/dns" + + "github.com/Control-D-Inc/ctrld" +) + +const ( + loopTestDomain = ".test" + loopTestQtype = dns.TypeTXT +) + +// isLoop reports whether the given upstream config is detected as having DNS loop. +func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool { + p.loopMu.Lock() + defer p.loopMu.Unlock() + return p.loop[uc.UID()] +} + +// detectLoop checks if the given DNS message is initialized sent by ctrld. +// If yes, marking the corresponding upstream as loop, prevent infinite DNS +// forwarding loop. +// +// See p.checkDnsLoop for more details how it works. +func (p *prog) detectLoop(msg *dns.Msg) { + if len(msg.Question) != 1 { + return + } + q := msg.Question[0] + if q.Qtype != loopTestQtype { + return + } + unFQDNname := strings.TrimSuffix(q.Name, ".") + uid := strings.TrimSuffix(unFQDNname, loopTestDomain) + p.loopMu.Lock() + if _, loop := p.loop[uid]; loop { + p.loop[uid] = loop + } + p.loopMu.Unlock() +} + +// checkDnsLoop sends a message to check if there's any DNS forwarding loop +// with all the upstreams. The way it works based on dnsmasq --dns-loop-detect. +// +// - Generating a TXT test query and sending it to all upstream. +// - The test query is formed by upstream UID and test domain: .test +// - If the test query returns to ctrld, mark the corresponding upstream as loop (see p.detectLoop). +// +// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html +func (p *prog) checkDnsLoop() { + mainLog.Load().Debug().Msg("start checking DNS loop") + upstream := make(map[string]*ctrld.UpstreamConfig) + p.loopMu.Lock() + for _, uc := range p.cfg.Upstream { + uid := uc.UID() + p.loop[uid] = false + upstream[uid] = uc + } + p.loopMu.Unlock() + + for uid := range p.loop { + msg := loopTestMsg(uid) + uc := upstream[uid] + resolver, err := ctrld.NewResolver(uc) + if err != nil { + mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) + continue + } + if _, err := resolver.Resolve(context.Background(), msg); err != nil { + mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) + } + } + mainLog.Load().Debug().Msg("end checking DNS loop") +} + +// checkDnsLoopTicker performs p.checkDnsLoop every minute. +func (p *prog) checkDnsLoopTicker() { + timer := time.NewTicker(time.Minute) + defer timer.Stop() + for { + select { + case <-p.stopCh: + return + case <-timer.C: + p.checkDnsLoop() + } + } +} + +// loopTestMsg generates DNS message for checking loop. +func loopTestMsg(uid string) *dns.Msg { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype) + return msg +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 47e2304..e30a03d 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -59,6 +59,9 @@ type prog struct { um *upstreamMonitor router router.Router + loopMu sync.Mutex + loop map[string]bool + started chan struct{} onStartedDone chan struct{} onStarted []func() @@ -91,6 +94,7 @@ func (p *prog) run() { numListeners := len(p.cfg.Listener) p.started = make(chan struct{}, numListeners) p.onStartedDone = make(chan struct{}) + p.loop = make(map[string]bool) if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { @@ -174,8 +178,13 @@ func (p *prog) run() { for _, f := range p.onStarted { f() } + // Check for possible DNS loop. + p.checkDnsLoop() close(p.onStartedDone) + // Start check DNS loop ticker. + go p.checkDnsLoopTicker() + // Stop writing log to unix socket. consoleWriter.Out = os.Stdout initLoggingWithBackup(false) diff --git a/config.go b/config.go index 50fb76b..21d636c 100644 --- a/config.go +++ b/config.go @@ -2,8 +2,10 @@ package ctrld import ( "context" + crand "crypto/rand" "crypto/tls" "crypto/x509" + "encoding/hex" "errors" "io" "math/rand" @@ -217,6 +219,7 @@ type UpstreamConfig struct { http3RoundTripper6 http.RoundTripper certPool *x509.CertPool u *url.URL + uid string } // ListenerConfig specifies the networks configuration that ctrld will run on. @@ -261,6 +264,7 @@ type Rule map[string][]string // Init initialized necessary values for an UpstreamConfig. func (uc *UpstreamConfig) Init() { + uc.uid = upstreamUID() if u, err := url.Parse(uc.Endpoint); err == nil { uc.Domain = u.Host switch uc.Type { @@ -341,6 +345,11 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { uc.setupBootstrapIP(true) } +// UID returns the unique identifier of the upstream. +func (uc *UpstreamConfig) UID() string { + return uc.uid +} + // SetupBootstrapIP manually find all available IPs of the upstream. // The first usable IP will be used as bootstrap IP of the upstream. func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) { @@ -680,3 +689,15 @@ func ResolverTypeFromEndpoint(endpoint string) string { func pick(s []string) string { return s[rand.Intn(len(s))] } + +// upstreamUID generates an unique identifier for an upstream. +func upstreamUID() string { + b := make([]byte, 4) + for { + if _, err := crand.Read(b); err != nil { + ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...") + continue + } + return hex.EncodeToString(b) + } +} diff --git a/config_internal_test.go b/config_internal_test.go index 6fc1844..89cec19 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -185,6 +185,7 @@ func TestUpstreamConfig_Init(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() tc.uc.Init() + tc.uc.uid = "" // we don't care about the uid. assert.Equal(t, tc.expected, tc.uc) }) }