all: guarding against DNS forwarding loop

Based on how dnsmasq "--dns-loop-detect" mechanism.

See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
This commit is contained in:
Cuong Manh Le
2023-09-21 06:10:21 +00:00
committed by Cuong Manh Le
parent 511c4e696f
commit a9959a6f3d
5 changed files with 136 additions and 0 deletions

View File

@@ -50,6 +50,7 @@ func (p *prog) serveDNS(listenerNum string) error {
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
p.sema.acquire()
defer p.sema.release()
go p.detectLoop(m)
q := m.Question[0]
domain := canonicalName(q.Name)
reqId := requestID()
@@ -287,6 +288,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
if upstreamConfig == nil {
continue
}
if p.isLoop(upstreamConfig) {
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint)
continue
}
if p.um.isDown(upstreams[n]) {
ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n])
continue

100
cmd/cli/loop.go Normal file
View File

@@ -0,0 +1,100 @@
package cli
import (
"context"
"strings"
"time"
"github.com/miekg/dns"
"github.com/Control-D-Inc/ctrld"
)
const (
loopTestDomain = ".test"
loopTestQtype = dns.TypeTXT
)
// isLoop reports whether the given upstream config is detected as having DNS loop.
func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool {
p.loopMu.Lock()
defer p.loopMu.Unlock()
return p.loop[uc.UID()]
}
// detectLoop checks if the given DNS message is initialized sent by ctrld.
// If yes, marking the corresponding upstream as loop, prevent infinite DNS
// forwarding loop.
//
// See p.checkDnsLoop for more details how it works.
func (p *prog) detectLoop(msg *dns.Msg) {
if len(msg.Question) != 1 {
return
}
q := msg.Question[0]
if q.Qtype != loopTestQtype {
return
}
unFQDNname := strings.TrimSuffix(q.Name, ".")
uid := strings.TrimSuffix(unFQDNname, loopTestDomain)
p.loopMu.Lock()
if _, loop := p.loop[uid]; loop {
p.loop[uid] = loop
}
p.loopMu.Unlock()
}
// checkDnsLoop sends a message to check if there's any DNS forwarding loop
// with all the upstreams. The way it works based on dnsmasq --dns-loop-detect.
//
// - Generating a TXT test query and sending it to all upstream.
// - The test query is formed by upstream UID and test domain: <uid>.test
// - If the test query returns to ctrld, mark the corresponding upstream as loop (see p.detectLoop).
//
// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
func (p *prog) checkDnsLoop() {
mainLog.Load().Debug().Msg("start checking DNS loop")
upstream := make(map[string]*ctrld.UpstreamConfig)
p.loopMu.Lock()
for _, uc := range p.cfg.Upstream {
uid := uc.UID()
p.loop[uid] = false
upstream[uid] = uc
}
p.loopMu.Unlock()
for uid := range p.loop {
msg := loopTestMsg(uid)
uc := upstream[uid]
resolver, err := ctrld.NewResolver(uc)
if err != nil {
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
continue
}
if _, err := resolver.Resolve(context.Background(), msg); err != nil {
mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
}
}
mainLog.Load().Debug().Msg("end checking DNS loop")
}
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
func (p *prog) checkDnsLoopTicker() {
timer := time.NewTicker(time.Minute)
defer timer.Stop()
for {
select {
case <-p.stopCh:
return
case <-timer.C:
p.checkDnsLoop()
}
}
}
// loopTestMsg generates DNS message for checking loop.
func loopTestMsg(uid string) *dns.Msg {
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype)
return msg
}

View File

@@ -59,6 +59,9 @@ type prog struct {
um *upstreamMonitor
router router.Router
loopMu sync.Mutex
loop map[string]bool
started chan struct{}
onStartedDone chan struct{}
onStarted []func()
@@ -91,6 +94,7 @@ func (p *prog) run() {
numListeners := len(p.cfg.Listener)
p.started = make(chan struct{}, numListeners)
p.onStartedDone = make(chan struct{})
p.loop = make(map[string]bool)
if p.cfg.Service.CacheEnable {
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
if err != nil {
@@ -174,8 +178,13 @@ func (p *prog) run() {
for _, f := range p.onStarted {
f()
}
// Check for possible DNS loop.
p.checkDnsLoop()
close(p.onStartedDone)
// Start check DNS loop ticker.
go p.checkDnsLoopTicker()
// Stop writing log to unix socket.
consoleWriter.Out = os.Stdout
initLoggingWithBackup(false)

View File

@@ -2,8 +2,10 @@ package ctrld
import (
"context"
crand "crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"errors"
"io"
"math/rand"
@@ -217,6 +219,7 @@ type UpstreamConfig struct {
http3RoundTripper6 http.RoundTripper
certPool *x509.CertPool
u *url.URL
uid string
}
// ListenerConfig specifies the networks configuration that ctrld will run on.
@@ -261,6 +264,7 @@ type Rule map[string][]string
// Init initialized necessary values for an UpstreamConfig.
func (uc *UpstreamConfig) Init() {
uc.uid = upstreamUID()
if u, err := url.Parse(uc.Endpoint); err == nil {
uc.Domain = u.Host
switch uc.Type {
@@ -341,6 +345,11 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
uc.setupBootstrapIP(true)
}
// UID returns the unique identifier of the upstream.
func (uc *UpstreamConfig) UID() string {
return uc.uid
}
// SetupBootstrapIP manually find all available IPs of the upstream.
// The first usable IP will be used as bootstrap IP of the upstream.
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
@@ -680,3 +689,15 @@ func ResolverTypeFromEndpoint(endpoint string) string {
func pick(s []string) string {
return s[rand.Intn(len(s))]
}
// upstreamUID generates an unique identifier for an upstream.
func upstreamUID() string {
b := make([]byte, 4)
for {
if _, err := crand.Read(b); err != nil {
ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...")
continue
}
return hex.EncodeToString(b)
}
}

View File

@@ -185,6 +185,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tc.uc.Init()
tc.uc.uid = "" // we don't care about the uid.
assert.Equal(t, tc.expected, tc.uc)
})
}