package cli import ( "context" "strings" "sync" "time" "github.com/miekg/dns" "github.com/Control-D-Inc/ctrld" ) const ( loopTestDomain = ".test" loopTestQtype = dns.TypeTXT ) // newLoopGuard returns new loopGuard. func newLoopGuard() *loopGuard { return &loopGuard{inflight: make(map[string]struct{})} } // loopGuard guards against DNS loop, ensuring only one query // for a given domain is processed at a time. type loopGuard struct { mu sync.Mutex inflight map[string]struct{} } // TryLock marks the domain as being processed. func (lg *loopGuard) TryLock(domain string) bool { lg.mu.Lock() defer lg.mu.Unlock() if _, inflight := lg.inflight[domain]; !inflight { lg.inflight[domain] = struct{}{} return true } return false } // Unlock marks the domain as being done. func (lg *loopGuard) Unlock(domain string) { lg.mu.Lock() defer lg.mu.Unlock() delete(lg.inflight, domain) } // isLoop reports whether the given upstream config is detected as having DNS loop. func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool { p.loopMu.Lock() defer p.loopMu.Unlock() return p.loop[uc.UID()] } // detectLoop checks if the given DNS message is initialized sent by ctrld. // If yes, marking the corresponding upstream as loop, prevent infinite DNS // forwarding loop. // // See p.checkDnsLoop for more details how it works. func (p *prog) detectLoop(msg *dns.Msg) { if len(msg.Question) != 1 { return } q := msg.Question[0] if q.Qtype != loopTestQtype { return } unFQDNname := strings.TrimSuffix(q.Name, ".") uid := strings.TrimSuffix(unFQDNname, loopTestDomain) p.loopMu.Lock() if _, loop := p.loop[uid]; loop { p.loop[uid] = loop } p.loopMu.Unlock() } // checkDnsLoop sends a message to check if there's any DNS forwarding loop // with all the upstreams. The way it works based on dnsmasq --dns-loop-detect. // // - Generating a TXT test query and sending it to all upstream. // - The test query is formed by upstream UID and test domain: .test // - If the test query returns to ctrld, mark the corresponding upstream as loop (see p.detectLoop). // // See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html func (p *prog) checkDnsLoop() { p.Debug().Msg("Start checking DNS loop") upstream := make(map[string]*ctrld.UpstreamConfig) p.loopMu.Lock() for n, uc := range p.cfg.Upstream { if p.um.isDown("upstream." + n) { continue } // Do not send test query to external upstream. if !canBeLocalUpstream(uc.Domain) { p.Debug().Msgf("Skipping external: upstream.%s", n) continue } uid := uc.UID() p.loop[uid] = false upstream[uid] = uc } p.loopMu.Unlock() loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) for uid := range p.loop { msg := loopTestMsg(uid) uc := upstream[uid] // Skipping upstream which is being marked as down. if uc == nil { continue } resolver, err := ctrld.NewResolver(loggerCtx, uc) if err != nil { p.Warn().Err(err).Msgf("Could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) continue } if _, err := resolver.Resolve(context.Background(), msg); err != nil { p.Warn().Err(err).Msgf("Could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) } } p.Debug().Msg("End checking DNS loop") } // checkDnsLoopTicker performs p.checkDnsLoop every minute. func (p *prog) checkDnsLoopTicker(ctx context.Context) { timer := time.NewTicker(time.Minute) defer timer.Stop() for { select { case <-p.stopCh: return case <-ctx.Done(): return case <-timer.C: p.checkDnsLoop() } } } // loopTestMsg creates a DNS test message for loop detection func loopTestMsg(uid string) *dns.Msg { msg := new(dns.Msg) msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype) return msg }