Implement new flow for LAN and private PTR resolution

- Use client info table.
 - If no sufficient data, use gateway/os/defined local upstreams.
 - If no data is returned, use remote upstream
This commit is contained in:
Cuong Manh Le
2023-11-23 23:56:49 +07:00
committed by Cuong Manh Le
parent a2cb895cdc
commit f9a3f4c045
11 changed files with 396 additions and 107 deletions

View File

@@ -17,6 +17,7 @@ import (
"golang.org/x/sync/errgroup"
"tailscale.com/net/interfaces"
"tailscale.com/net/netaddr"
"tailscale.com/net/tsaddr"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
@@ -25,6 +26,7 @@ import (
const (
staleTTL = 60 * time.Second
localTTL = 3600 * time.Second
// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
@@ -81,7 +83,7 @@ func (p *prog) serveDNS(listenerNum string) error {
if listenerConfig.Policy != nil {
failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers
}
answer = p.proxy(ctx, upstreams, failoverRcode, m, ci)
answer = p.proxy(ctx, upstreams, failoverRcode, m, ci, matched)
rtt := time.Since(t)
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
}
@@ -251,7 +253,7 @@ macRules:
return upstreams, matched
}
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo) *dns.Msg {
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo, matched bool) *dns.Msg {
var staleAnswer *dns.Msg
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
@@ -259,11 +261,84 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
upstreams = []string{upstreamOS}
}
if isPrivatePtrLookup(msg) {
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup -> [%s]", upstreamOS)
upstreamConfigs = []*ctrld.UpstreamConfig{privateUpstreamConfig}
upstreams = []string{upstreamOS}
// LAN/PTR lookup flow:
//
// 1. If there's matching rule, follow it.
// 2. Try from client info table.
// 3. Try private resolver.
// 4. Try remote upstream.
isLanOrPtrQuery := false
if !matched {
switch {
case isPrivatePtrLookup(msg):
isLanOrPtrQuery = true
ip := ipFromARPA(msg.Question[0].Name)
if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" {
answer := new(dns.Msg)
answer.SetReply(msg)
answer.Compress = true
answer.Answer = []dns.RR{&dns.PTR{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: dns.Fqdn(name),
}}
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table")
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
Mac: p.ciTable.LookupMac(ip.String()),
IP: ip.String(),
Hostname: name,
})
return answer
}
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using upstreams: %v", upstreams)
case isLanHostnameQuery(msg):
isLanOrPtrQuery = true
q := msg.Question[0]
hostname := strings.TrimSuffix(q.Name, ".")
if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil {
answer := new(dns.Msg)
answer.SetReply(msg)
answer.Compress = true
switch {
case ip.Is4():
answer.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(localTTL.Seconds()),
},
A: ip.AsSlice(),
}}
case ip.Is6():
answer.Answer = []dns.RR{&dns.AAAA{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: uint32(localTTL.Seconds()),
},
AAAA: ip.AsSlice(),
}}
}
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table")
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
Mac: p.ciTable.LookupMac(ip.String()),
IP: ip.String(),
Hostname: hostname,
})
return answer
}
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using upstreams: %v", upstreams)
}
}
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
for _, upstream := range upstreams {
@@ -285,12 +360,6 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
if upstreamConfig.Type == ctrld.ResolverTypePrivate {
if r := p.ptrResolver; r != nil {
ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR resolver", p.cfg.Service.DiscoverPtrEndpoints)
dnsResolver = r
}
}
if err != nil {
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver")
return nil, err
@@ -344,6 +413,11 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
}
continue
}
// We are doing LAN/PTR lookup using private resolver, so always process next one.
// Except for the last, we want to send response instead of saying all upstream failed.
if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 {
continue
}
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) {
ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream")
continue
@@ -352,7 +426,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
// set compression, as it is not set by default when unpacking
answer.Compress = true
if p.cache != nil {
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
ttl := ttlFromMsg(answer)
now := time.Now()
expired := now.Add(time.Duration(ttl) * time.Second)
@@ -371,6 +445,16 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
return answer
}
func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
if len(p.localUpstreams) > 0 {
tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
tmp = append(tmp, p.localUpstreams...)
tmp = append(tmp, upstreams...)
return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp)
}
return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...)
}
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
for _, upstream := range upstreams {
@@ -705,14 +789,34 @@ func ipFromARPA(arpa string) net.IP {
return nil
}
// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN network.
// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN/CGNAT network.
func isPrivatePtrLookup(m *dns.Msg) bool {
if m == nil || len(m.Question) == 0 {
return false
}
q := m.Question[0]
if ip := ipFromARPA(q.Name); ip != nil {
return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
if addr, ok := netip.AddrFromSlice(ip); ok {
return addr.IsPrivate() ||
addr.IsLoopback() ||
addr.IsLinkLocalUnicast() ||
tsaddr.CGNATRange().Contains(addr)
}
}
return false
}
func isLanHostnameQuery(m *dns.Msg) bool {
if m == nil || len(m.Question) == 0 {
return false
}
q := m.Question[0]
switch q.Qtype {
case dns.TypeA, dns.TypeAAAA:
default:
return false
}
return !strings.Contains(q.Name, ".") ||
strings.HasSuffix(q.Name, ".domain") ||
strings.HasSuffix(q.Name, ".lan")
}

View File

@@ -153,8 +153,8 @@ func TestCache(t *testing.T) {
answer2.SetRcode(msg, dns.RcodeRefused)
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil)
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil)
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil, false)
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil, false)
assert.NotSame(t, got1, got2)
assert.Equal(t, answer1.Rcode, got1.Rcode)
assert.Equal(t, answer2.Rcode, got2.Rcode)

View File

@@ -19,6 +19,7 @@ import (
"github.com/kardianos/service"
"github.com/spf13/viper"
"tailscale.com/net/interfaces"
"tailscale.com/net/tsaddr"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
@@ -32,6 +33,7 @@ const (
ctrldControlUnixSock = "ctrld_control.sock"
upstreamPrefix = "upstream."
upstreamOS = upstreamPrefix + "os"
upstreamPrivate = upstreamPrefix + "private"
)
var logf = func(format string, args ...any) {
@@ -55,14 +57,15 @@ type prog struct {
logConn net.Conn
cs *controlServer
cfg *ctrld.Config
appCallback *AppCallback
cache dnscache.Cacher
sema semaphore
ciTable *clientinfo.Table
um *upstreamMonitor
router router.Router
ptrResolver ctrld.Resolver
cfg *ctrld.Config
localUpstreams []string
ptrNameservers []string
appCallback *AppCallback
cache dnscache.Cacher
sema semaphore
ciTable *clientinfo.Table
um *upstreamMonitor
router router.Router
loopMu sync.Mutex
loop map[string]bool
@@ -160,7 +163,7 @@ func (p *prog) runWait() {
// This needs to be done here, otherwise, the DNS handler may observe an invalid
// upstream config because its initialization function have not been called yet.
mainLog.Load().Debug().Msg("setup upstream with new config")
setupUpstream(newCfg)
p.setupUpstream(newCfg)
p.mu.Lock()
*p.cfg = *newCfg
@@ -187,7 +190,9 @@ func (p *prog) preRun() {
}
}
func setupUpstream(cfg *ctrld.Config) {
func (p *prog) setupUpstream(cfg *ctrld.Config) {
localUpstreams := make([]string, 0, len(cfg.Upstream))
ptrNameservers := make([]string, 0, len(cfg.Upstream))
for n := range cfg.Upstream {
uc := cfg.Upstream[n]
uc.Init()
@@ -199,7 +204,16 @@ func setupUpstream(cfg *ctrld.Config) {
}
uc.SetCertPool(rootCertPool)
go uc.Ping()
if canBeLocalUpstream(uc.Domain) {
localUpstreams = append(localUpstreams, upstreamPrefix+n)
}
if uc.IsDiscoverable() {
ptrNameservers = append(ptrNameservers, uc.Endpoint)
}
}
p.localUpstreams = localUpstreams
p.ptrNameservers = ptrNameservers
}
// run runs the ctrld main components.
@@ -230,9 +244,6 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
p.cache = cacher
}
}
if r := p.cfg.Service.PtrResolver(); r != nil {
p.ptrResolver = r
}
var wg sync.WaitGroup
wg.Add(len(p.cfg.Listener))
@@ -260,8 +271,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
}
}
setupUpstream(p.cfg)
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID)
p.setupUpstream(p.cfg)
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers)
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
@@ -613,3 +624,11 @@ func defaultRouteIP() string {
mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP")
return ip
}
// canBeLocalUpstream reports whether the IP address can be used as a local upstream.
func canBeLocalUpstream(addr string) bool {
if ip, err := netip.ParseAddr(addr); err == nil {
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
}
return false
}