mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Implement new flow for LAN and private PTR resolution
- Use client info table. - If no sufficient data, use gateway/os/defined local upstreams. - If no data is returned, use remote upstream
This commit is contained in:
committed by
Cuong Manh Le
parent
a2cb895cdc
commit
f9a3f4c045
@@ -17,6 +17,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/netaddr"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
|
||||
const (
|
||||
staleTTL = 60 * time.Second
|
||||
localTTL = 3600 * time.Second
|
||||
// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
|
||||
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
|
||||
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
|
||||
@@ -81,7 +83,7 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
if listenerConfig.Policy != nil {
|
||||
failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers
|
||||
}
|
||||
answer = p.proxy(ctx, upstreams, failoverRcode, m, ci)
|
||||
answer = p.proxy(ctx, upstreams, failoverRcode, m, ci, matched)
|
||||
rtt := time.Since(t)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
||||
}
|
||||
@@ -251,7 +253,7 @@ macRules:
|
||||
return upstreams, matched
|
||||
}
|
||||
|
||||
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo) *dns.Msg {
|
||||
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo, matched bool) *dns.Msg {
|
||||
var staleAnswer *dns.Msg
|
||||
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
|
||||
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
||||
@@ -259,11 +261,84 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{upstreamOS}
|
||||
}
|
||||
if isPrivatePtrLookup(msg) {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup -> [%s]", upstreamOS)
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{privateUpstreamConfig}
|
||||
upstreams = []string{upstreamOS}
|
||||
|
||||
// LAN/PTR lookup flow:
|
||||
//
|
||||
// 1. If there's matching rule, follow it.
|
||||
// 2. Try from client info table.
|
||||
// 3. Try private resolver.
|
||||
// 4. Try remote upstream.
|
||||
isLanOrPtrQuery := false
|
||||
if !matched {
|
||||
switch {
|
||||
case isPrivatePtrLookup(msg):
|
||||
isLanOrPtrQuery = true
|
||||
ip := ipFromARPA(msg.Question[0].Name)
|
||||
if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetReply(msg)
|
||||
answer.Compress = true
|
||||
answer.Answer = []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Ptr: dns.Fqdn(name),
|
||||
}}
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table")
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
|
||||
Mac: p.ciTable.LookupMac(ip.String()),
|
||||
IP: ip.String(),
|
||||
Hostname: name,
|
||||
})
|
||||
return answer
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using upstreams: %v", upstreams)
|
||||
case isLanHostnameQuery(msg):
|
||||
isLanOrPtrQuery = true
|
||||
q := msg.Question[0]
|
||||
hostname := strings.TrimSuffix(q.Name, ".")
|
||||
if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetReply(msg)
|
||||
answer.Compress = true
|
||||
switch {
|
||||
case ip.Is4():
|
||||
answer.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: uint32(localTTL.Seconds()),
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
}}
|
||||
case ip.Is6():
|
||||
answer.Answer = []dns.RR{&dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: uint32(localTTL.Seconds()),
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
}}
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table")
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
|
||||
Mac: p.ciTable.LookupMac(ip.String()),
|
||||
IP: ip.String(),
|
||||
Hostname: hostname,
|
||||
})
|
||||
return answer
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using upstreams: %v", upstreams)
|
||||
}
|
||||
}
|
||||
|
||||
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
|
||||
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
|
||||
for _, upstream := range upstreams {
|
||||
@@ -285,12 +360,6 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
|
||||
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
|
||||
if upstreamConfig.Type == ctrld.ResolverTypePrivate {
|
||||
if r := p.ptrResolver; r != nil {
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR resolver", p.cfg.Service.DiscoverPtrEndpoints)
|
||||
dnsResolver = r
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver")
|
||||
return nil, err
|
||||
@@ -344,6 +413,11 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
}
|
||||
continue
|
||||
}
|
||||
// We are doing LAN/PTR lookup using private resolver, so always process next one.
|
||||
// Except for the last, we want to send response instead of saying all upstream failed.
|
||||
if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 {
|
||||
continue
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream")
|
||||
continue
|
||||
@@ -352,7 +426,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
// set compression, as it is not set by default when unpacking
|
||||
answer.Compress = true
|
||||
|
||||
if p.cache != nil {
|
||||
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
|
||||
ttl := ttlFromMsg(answer)
|
||||
now := time.Now()
|
||||
expired := now.Add(time.Duration(ttl) * time.Second)
|
||||
@@ -371,6 +445,16 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
return answer
|
||||
}
|
||||
|
||||
func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
|
||||
if len(p.localUpstreams) > 0 {
|
||||
tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
|
||||
tmp = append(tmp, p.localUpstreams...)
|
||||
tmp = append(tmp, upstreams...)
|
||||
return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp)
|
||||
}
|
||||
return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...)
|
||||
}
|
||||
|
||||
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
|
||||
upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
|
||||
for _, upstream := range upstreams {
|
||||
@@ -705,14 +789,34 @@ func ipFromARPA(arpa string) net.IP {
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN network.
|
||||
// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN/CGNAT network.
|
||||
func isPrivatePtrLookup(m *dns.Msg) bool {
|
||||
if m == nil || len(m.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
q := m.Question[0]
|
||||
if ip := ipFromARPA(q.Name); ip != nil {
|
||||
return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
|
||||
if addr, ok := netip.AddrFromSlice(ip); ok {
|
||||
return addr.IsPrivate() ||
|
||||
addr.IsLoopback() ||
|
||||
addr.IsLinkLocalUnicast() ||
|
||||
tsaddr.CGNATRange().Contains(addr)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isLanHostnameQuery(m *dns.Msg) bool {
|
||||
if m == nil || len(m.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
q := m.Question[0]
|
||||
switch q.Qtype {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return !strings.Contains(q.Name, ".") ||
|
||||
strings.HasSuffix(q.Name, ".domain") ||
|
||||
strings.HasSuffix(q.Name, ".lan")
|
||||
}
|
||||
|
||||
@@ -153,8 +153,8 @@ func TestCache(t *testing.T) {
|
||||
answer2.SetRcode(msg, dns.RcodeRefused)
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
||||
|
||||
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil)
|
||||
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil)
|
||||
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil, false)
|
||||
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil, false)
|
||||
assert.NotSame(t, got1, got2)
|
||||
assert.Equal(t, answer1.Rcode, got1.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.Rcode)
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/viper"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
||||
@@ -32,6 +33,7 @@ const (
|
||||
ctrldControlUnixSock = "ctrld_control.sock"
|
||||
upstreamPrefix = "upstream."
|
||||
upstreamOS = upstreamPrefix + "os"
|
||||
upstreamPrivate = upstreamPrefix + "private"
|
||||
)
|
||||
|
||||
var logf = func(format string, args ...any) {
|
||||
@@ -55,14 +57,15 @@ type prog struct {
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
|
||||
cfg *ctrld.Config
|
||||
appCallback *AppCallback
|
||||
cache dnscache.Cacher
|
||||
sema semaphore
|
||||
ciTable *clientinfo.Table
|
||||
um *upstreamMonitor
|
||||
router router.Router
|
||||
ptrResolver ctrld.Resolver
|
||||
cfg *ctrld.Config
|
||||
localUpstreams []string
|
||||
ptrNameservers []string
|
||||
appCallback *AppCallback
|
||||
cache dnscache.Cacher
|
||||
sema semaphore
|
||||
ciTable *clientinfo.Table
|
||||
um *upstreamMonitor
|
||||
router router.Router
|
||||
|
||||
loopMu sync.Mutex
|
||||
loop map[string]bool
|
||||
@@ -160,7 +163,7 @@ func (p *prog) runWait() {
|
||||
// This needs to be done here, otherwise, the DNS handler may observe an invalid
|
||||
// upstream config because its initialization function have not been called yet.
|
||||
mainLog.Load().Debug().Msg("setup upstream with new config")
|
||||
setupUpstream(newCfg)
|
||||
p.setupUpstream(newCfg)
|
||||
|
||||
p.mu.Lock()
|
||||
*p.cfg = *newCfg
|
||||
@@ -187,7 +190,9 @@ func (p *prog) preRun() {
|
||||
}
|
||||
}
|
||||
|
||||
func setupUpstream(cfg *ctrld.Config) {
|
||||
func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
||||
localUpstreams := make([]string, 0, len(cfg.Upstream))
|
||||
ptrNameservers := make([]string, 0, len(cfg.Upstream))
|
||||
for n := range cfg.Upstream {
|
||||
uc := cfg.Upstream[n]
|
||||
uc.Init()
|
||||
@@ -199,7 +204,16 @@ func setupUpstream(cfg *ctrld.Config) {
|
||||
}
|
||||
uc.SetCertPool(rootCertPool)
|
||||
go uc.Ping()
|
||||
|
||||
if canBeLocalUpstream(uc.Domain) {
|
||||
localUpstreams = append(localUpstreams, upstreamPrefix+n)
|
||||
}
|
||||
if uc.IsDiscoverable() {
|
||||
ptrNameservers = append(ptrNameservers, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
p.localUpstreams = localUpstreams
|
||||
p.ptrNameservers = ptrNameservers
|
||||
}
|
||||
|
||||
// run runs the ctrld main components.
|
||||
@@ -230,9 +244,6 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
p.cache = cacher
|
||||
}
|
||||
}
|
||||
if r := p.cfg.Service.PtrResolver(); r != nil {
|
||||
p.ptrResolver = r
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(p.cfg.Listener))
|
||||
@@ -260,8 +271,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
|
||||
}
|
||||
}
|
||||
setupUpstream(p.cfg)
|
||||
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID)
|
||||
p.setupUpstream(p.cfg)
|
||||
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers)
|
||||
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
||||
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
@@ -613,3 +624,11 @@ func defaultRouteIP() string {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP")
|
||||
return ip
|
||||
}
|
||||
|
||||
// canBeLocalUpstream reports whether the IP address can be used as a local upstream.
|
||||
func canBeLocalUpstream(addr string) bool {
|
||||
if ip, err := netip.ParseAddr(addr); err == nil {
|
||||
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
76
config.go
76
config.go
@@ -11,6 +11,7 @@ import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnsrcode"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
@@ -177,46 +179,22 @@ func (c *Config) FirstUpstream() *UpstreamConfig {
|
||||
|
||||
// ServiceConfig specifies the general ctrld config.
|
||||
type ServiceConfig struct {
|
||||
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
||||
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
||||
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
||||
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
||||
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
||||
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
DiscoverPtrEndpoints []string `mapstructure:"discover_ptr_endpoints" toml:"discover_ptr_endpoints,omitempty"`
|
||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
|
||||
// PtrResolver returns a Resolver used for PTR lookup, based on ServiceConfig.DiscoverPtrEndpoints value.
|
||||
func (s ServiceConfig) PtrResolver() Resolver {
|
||||
if len(s.DiscoverPtrEndpoints) > 0 {
|
||||
nss := make([]string, 0, len(s.DiscoverPtrEndpoints))
|
||||
for _, ns := range s.DiscoverPtrEndpoints {
|
||||
host, port := ns, "53"
|
||||
if h, p, err := net.SplitHostPort(ns); err == nil {
|
||||
host, port = h, p
|
||||
}
|
||||
// Only use valid ip:port pair.
|
||||
if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil {
|
||||
nss = append(nss, net.JoinHostPort(host, port))
|
||||
} else {
|
||||
ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for PTR resolver: %q", ns)
|
||||
}
|
||||
}
|
||||
if len(nss) > 0 {
|
||||
return NewResolverWithNameserver(nss)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
||||
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
||||
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
||||
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
||||
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
||||
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
|
||||
// NetworkConfig specifies configuration for networks where ctrld will handle requests.
|
||||
@@ -238,6 +216,9 @@ type UpstreamConfig struct {
|
||||
// The caller should not access this field directly.
|
||||
// Use UpstreamSendClientInfo instead.
|
||||
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
|
||||
// The caller should not access this field directly.
|
||||
// Use IsDiscoverable instead.
|
||||
Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"`
|
||||
|
||||
g singleflight.Group
|
||||
rebootstrap atomic.Bool
|
||||
@@ -364,6 +345,21 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDiscoverable reports whether the upstream can be used for PTR discovery.
|
||||
// The caller must ensure uc.Init() was called before calling this.
|
||||
func (uc *UpstreamConfig) IsDiscoverable() bool {
|
||||
if uc.Discoverable != nil {
|
||||
return *uc.Discoverable
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate:
|
||||
if ip, err := netip.ParseAddr(uc.Domain); err == nil {
|
||||
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BootstrapIPs returns the bootstrap IPs list of upstreams.
|
||||
func (uc *UpstreamConfig) BootstrapIPs() []string {
|
||||
return uc.bootstrapIPs
|
||||
|
||||
@@ -279,6 +279,61 @@ func TestUpstreamConfig_UpstreamSendClientInfo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_IsDiscoverable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
discoverable bool
|
||||
}{
|
||||
{
|
||||
"loopback",
|
||||
&UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"rfc1918",
|
||||
&UpstreamConfig{Endpoint: "192.168.1.1", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"CGNAT",
|
||||
&UpstreamConfig{Endpoint: "100.66.67.68", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Public IP",
|
||||
&UpstreamConfig{Endpoint: "8.8.8.8", Type: ResolverTypeLegacy},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"override discoverable",
|
||||
&UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(false)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"override non-public",
|
||||
&UpstreamConfig{Endpoint: "1.1.1.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(true)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-legacy upstream",
|
||||
&UpstreamConfig{Endpoint: "https://192.168.1.1/custom-doh", Type: ResolverTypeDOH},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
if got := tc.uc.IsDiscoverable(); got != tc.discoverable {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ptrBool(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
@@ -193,22 +193,6 @@ Perform LAN client discovery using PTR queries.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
### discover_ptr_endpoints
|
||||
List of DNS nameservers used for PTR discovery.
|
||||
|
||||
Each entry can be either "ip" (default port 53) or "ip:port" pair. Invalid entry will be ignored.
|
||||
|
||||
- Type: array of string
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
Example:
|
||||
|
||||
```toml
|
||||
[service]
|
||||
discover_ptr_endpoints = ["192.168.1.1", "192.168.2.1:5354"]
|
||||
```
|
||||
|
||||
### discover_hosts
|
||||
Perform LAN client discovery using hosts file.
|
||||
|
||||
@@ -335,6 +319,24 @@ If `ip_stack` is empty, or undefined:
|
||||
- Default value is `both` for non-Control D resolvers.
|
||||
- Default value is `split` for Control D resolvers.
|
||||
|
||||
### send_client_info
|
||||
Specifying whether to include client info when sending query to upstream.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default:
|
||||
- `true` for ControlD upstreams.
|
||||
- `false` for other upstreams.
|
||||
|
||||
### discoverable
|
||||
Specifying whether the upstream can be used for PTR discovery.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default:
|
||||
- `true` for loopback/RFC1918/CGNAT IP address.
|
||||
- `false` for public IP address.
|
||||
|
||||
## Network
|
||||
The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs.
|
||||
|
||||
|
||||
@@ -3,7 +3,9 @@ package clientinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -69,25 +71,27 @@ type Table struct {
|
||||
refreshers []refresher
|
||||
initOnce sync.Once
|
||||
|
||||
dhcp *dhcp
|
||||
merlin *merlinDiscover
|
||||
arp *arpDiscover
|
||||
ptr *ptrDiscover
|
||||
mdns *mdns
|
||||
hf *hostsFile
|
||||
vni *virtualNetworkIface
|
||||
svcCfg ctrld.ServiceConfig
|
||||
quitCh chan struct{}
|
||||
selfIP string
|
||||
cdUID string
|
||||
dhcp *dhcp
|
||||
merlin *merlinDiscover
|
||||
arp *arpDiscover
|
||||
ptr *ptrDiscover
|
||||
mdns *mdns
|
||||
hf *hostsFile
|
||||
vni *virtualNetworkIface
|
||||
svcCfg ctrld.ServiceConfig
|
||||
quitCh chan struct{}
|
||||
selfIP string
|
||||
cdUID string
|
||||
ptrNameservers []string
|
||||
}
|
||||
|
||||
func NewTable(cfg *ctrld.Config, selfIP, cdUID string) *Table {
|
||||
func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
|
||||
return &Table{
|
||||
svcCfg: cfg.Service,
|
||||
quitCh: make(chan struct{}),
|
||||
selfIP: selfIP,
|
||||
cdUID: cdUID,
|
||||
svcCfg: cfg.Service,
|
||||
quitCh: make(chan struct{}),
|
||||
selfIP: selfIP,
|
||||
cdUID: cdUID,
|
||||
ptrNameservers: ns,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,9 +187,25 @@ func (t *Table) init() {
|
||||
// PTR lookup.
|
||||
if t.discoverPTR() {
|
||||
t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()}
|
||||
if r := t.svcCfg.PtrResolver(); r != nil {
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR discover", t.svcCfg.DiscoverPtrEndpoints)
|
||||
t.ptr.resolver = r
|
||||
if len(t.ptrNameservers) > 0 {
|
||||
nss := make([]string, 0, len(t.ptrNameservers))
|
||||
for _, ns := range t.ptrNameservers {
|
||||
host, port := ns, "53"
|
||||
if h, p, err := net.SplitHostPort(ns); err == nil {
|
||||
host, port = h, p
|
||||
}
|
||||
// Only use valid ip:port pair.
|
||||
if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil {
|
||||
nss = append(nss, net.JoinHostPort(host, port))
|
||||
} else {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns)
|
||||
}
|
||||
}
|
||||
if len(nss) > 0 {
|
||||
t.ptr.resolver = ctrld.NewResolverWithNameserver(nss)
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss)
|
||||
}
|
||||
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery")
|
||||
if err := t.ptr.refresh(); err != nil {
|
||||
@@ -358,6 +378,27 @@ func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) {
|
||||
t.vni.ip2name.Store(ci.IP, ci.Hostname)
|
||||
}
|
||||
|
||||
// ipFinder is the interface for retrieving IP address from hostname.
|
||||
type ipFinder interface {
|
||||
lookupIPByHostname(name string, v6 bool) string
|
||||
}
|
||||
|
||||
// LookupIPByHostname returns the ip address of given hostname.
|
||||
// If v6 is true, return IPv6 instead of default IPv4.
|
||||
func (t *Table) LookupIPByHostname(hostname string, v6 bool) *netip.Addr {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
for _, finder := range []ipFinder{t.hf, t.ptr, t.mdns, t.dhcp} {
|
||||
if addr := finder.lookupIPByHostname(hostname, v6); addr != "" {
|
||||
if ip, err := netip.ParseAddr(addr); err == nil {
|
||||
return &ip
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Table) discoverDHCP() bool {
|
||||
if t.svcCfg.DiscoverDHCP == nil {
|
||||
return true
|
||||
|
||||
@@ -134,6 +134,23 @@ func (d *dhcp) List() []string {
|
||||
return ips
|
||||
}
|
||||
|
||||
func (d *dhcp) lookupIPByHostname(name string, v6 bool) string {
|
||||
if d == nil {
|
||||
return ""
|
||||
}
|
||||
var ip string
|
||||
d.ip2name.Range(func(key, value any) bool {
|
||||
if value == name {
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
ip = addr.String()
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return ip
|
||||
}
|
||||
|
||||
// AddLeaseFile adds given lease file for reading/watching clients info.
|
||||
func (d *dhcp) addLeaseFile(name string, format ctrld.LeaseFileFormat) error {
|
||||
if d.watcher == nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
@@ -109,6 +110,24 @@ func (hf *hostsFile) String() string {
|
||||
return "hosts"
|
||||
}
|
||||
|
||||
func (hf *hostsFile) lookupIPByHostname(name string, v6 bool) string {
|
||||
if hf == nil {
|
||||
return ""
|
||||
}
|
||||
hf.mu.Lock()
|
||||
defer hf.mu.Unlock()
|
||||
for addr, names := range hf.m {
|
||||
if ip, err := netip.ParseAddr(addr); err == nil && !ip.IsLoopback() {
|
||||
for _, n := range names {
|
||||
if n == name && ip.Is6() == v6 {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// isLocalhostName reports whether the given hostname represents localhost.
|
||||
func isLocalhostName(hostname string) bool {
|
||||
switch hostname {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -59,6 +60,23 @@ func (m *mdns) List() []string {
|
||||
return ips
|
||||
}
|
||||
|
||||
func (m *mdns) lookupIPByHostname(name string, v6 bool) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
var ip string
|
||||
m.name.Range(func(key, value any) bool {
|
||||
if value == name {
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
ip = addr.String()
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return ip
|
||||
}
|
||||
|
||||
func (m *mdns) init(quitCh chan struct{}) error {
|
||||
ifaces, err := multicastInterfaces()
|
||||
if err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package clientinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -94,6 +95,23 @@ func (p *ptrDiscover) lookupHostname(ip string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
var ip string
|
||||
p.hostname.Range(func(key, value any) bool {
|
||||
if value == name {
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
ip = addr.String()
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return ip
|
||||
}
|
||||
|
||||
// checkServer monitors if the resolver can reach its nameserver. When the nameserver
|
||||
// is reachable, set p.serverDown to false, so p.lookupHostname can continue working.
|
||||
func (p *ptrDiscover) checkServer() {
|
||||
|
||||
Reference in New Issue
Block a user