From f9a3f4c045d3cfaaa5cd51bf3053eab184a8f393 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 23 Nov 2023 23:56:49 +0700 Subject: [PATCH] Implement new flow for LAN and private PTR resolution - Use client info table. - If no sufficient data, use gateway/os/defined local upstreams. - If no data is returned, use remote upstream --- cmd/cli/dns_proxy.go | 134 +++++++++++++++++++++++++---- cmd/cli/dns_proxy_test.go | 4 +- cmd/cli/prog.go | 49 +++++++---- config.go | 76 ++++++++-------- config_internal_test.go | 55 ++++++++++++ docs/config.md | 34 ++++---- internal/clientinfo/client_info.go | 79 +++++++++++++---- internal/clientinfo/dhcp.go | 17 ++++ internal/clientinfo/hostsfile.go | 19 ++++ internal/clientinfo/mdns.go | 18 ++++ internal/clientinfo/ptr_lookup.go | 18 ++++ 11 files changed, 396 insertions(+), 107 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 06d0702..1be818f 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -17,6 +17,7 @@ import ( "golang.org/x/sync/errgroup" "tailscale.com/net/interfaces" "tailscale.com/net/netaddr" + "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/dnscache" @@ -25,6 +26,7 @@ import ( const ( staleTTL = 60 * time.Second + localTTL = 3600 * time.Second // EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option. // https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81 // This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification. @@ -81,7 +83,7 @@ func (p *prog) serveDNS(listenerNum string) error { if listenerConfig.Policy != nil { failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers } - answer = p.proxy(ctx, upstreams, failoverRcode, m, ci) + answer = p.proxy(ctx, upstreams, failoverRcode, m, ci, matched) rtt := time.Since(t) ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) } @@ -251,7 +253,7 @@ macRules: return upstreams, matched } -func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo) *dns.Msg { +func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo, matched bool) *dns.Msg { var staleAnswer *dns.Msg serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) @@ -259,11 +261,84 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} upstreams = []string{upstreamOS} } - if isPrivatePtrLookup(msg) { - ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup -> [%s]", upstreamOS) - upstreamConfigs = []*ctrld.UpstreamConfig{privateUpstreamConfig} - upstreams = []string{upstreamOS} + + // LAN/PTR lookup flow: + // + // 1. If there's matching rule, follow it. + // 2. Try from client info table. + // 3. Try private resolver. + // 4. Try remote upstream. + isLanOrPtrQuery := false + if !matched { + switch { + case isPrivatePtrLookup(msg): + isLanOrPtrQuery = true + ip := ipFromARPA(msg.Question[0].Name) + if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" { + answer := new(dns.Msg) + answer.SetReply(msg) + answer.Compress = true + answer.Answer = []dns.RR{&dns.PTR{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + }, + Ptr: dns.Fqdn(name), + }} + ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table") + ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + Mac: p.ciTable.LookupMac(ip.String()), + IP: ip.String(), + Hostname: name, + }) + return answer + } + upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) + ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using upstreams: %v", upstreams) + case isLanHostnameQuery(msg): + isLanOrPtrQuery = true + q := msg.Question[0] + hostname := strings.TrimSuffix(q.Name, ".") + if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil { + answer := new(dns.Msg) + answer.SetReply(msg) + answer.Compress = true + switch { + case ip.Is4(): + answer.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: uint32(localTTL.Seconds()), + }, + A: ip.AsSlice(), + }} + case ip.Is6(): + answer.Answer = []dns.RR{&dns.AAAA{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: uint32(localTTL.Seconds()), + }, + AAAA: ip.AsSlice(), + }} + } + ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table") + ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + Mac: p.ciTable.LookupMac(ip.String()), + IP: ip.String(), + Hostname: hostname, + }) + return answer + } + upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) + ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using upstreams: %v", upstreams) + } } + // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { for _, upstream := range upstreams { @@ -285,12 +360,6 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(upstreamConfig) - if upstreamConfig.Type == ctrld.ResolverTypePrivate { - if r := p.ptrResolver; r != nil { - ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR resolver", p.cfg.Service.DiscoverPtrEndpoints) - dnsResolver = r - } - } if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") return nil, err @@ -344,6 +413,11 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i } continue } + // We are doing LAN/PTR lookup using private resolver, so always process next one. + // Except for the last, we want to send response instead of saying all upstream failed. + if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 { + continue + } if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) { ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream") continue @@ -352,7 +426,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i // set compression, as it is not set by default when unpacking answer.Compress = true - if p.cache != nil { + if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { ttl := ttlFromMsg(answer) now := time.Now() expired := now.Add(time.Duration(ttl) * time.Second) @@ -371,6 +445,16 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i return answer } +func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) { + if len(p.localUpstreams) > 0 { + tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams)) + tmp = append(tmp, p.localUpstreams...) + tmp = append(tmp, upstreams...) + return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp) + } + return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...) +} + func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig { upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams)) for _, upstream := range upstreams { @@ -705,14 +789,34 @@ func ipFromARPA(arpa string) net.IP { return nil } -// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN network. +// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN/CGNAT network. func isPrivatePtrLookup(m *dns.Msg) bool { if m == nil || len(m.Question) == 0 { return false } q := m.Question[0] if ip := ipFromARPA(q.Name); ip != nil { - return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() + if addr, ok := netip.AddrFromSlice(ip); ok { + return addr.IsPrivate() || + addr.IsLoopback() || + addr.IsLinkLocalUnicast() || + tsaddr.CGNATRange().Contains(addr) + } } return false } + +func isLanHostnameQuery(m *dns.Msg) bool { + if m == nil || len(m.Question) == 0 { + return false + } + q := m.Question[0] + switch q.Qtype { + case dns.TypeA, dns.TypeAAAA: + default: + return false + } + return !strings.Contains(q.Name, ".") || + strings.HasSuffix(q.Name, ".domain") || + strings.HasSuffix(q.Name, ".lan") +} diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index d0e5c74..70197ad 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -153,8 +153,8 @@ func TestCache(t *testing.T) { answer2.SetRcode(msg, dns.RcodeRefused) prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute))) - got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil) - got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil) + got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil, false) + got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil, false) assert.NotSame(t, got1, got2) assert.Equal(t, answer1.Rcode, got1.Rcode) assert.Equal(t, answer2.Rcode, got2.Rcode) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index fb88c81..f828426 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -19,6 +19,7 @@ import ( "github.com/kardianos/service" "github.com/spf13/viper" "tailscale.com/net/interfaces" + "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/clientinfo" @@ -32,6 +33,7 @@ const ( ctrldControlUnixSock = "ctrld_control.sock" upstreamPrefix = "upstream." upstreamOS = upstreamPrefix + "os" + upstreamPrivate = upstreamPrefix + "private" ) var logf = func(format string, args ...any) { @@ -55,14 +57,15 @@ type prog struct { logConn net.Conn cs *controlServer - cfg *ctrld.Config - appCallback *AppCallback - cache dnscache.Cacher - sema semaphore - ciTable *clientinfo.Table - um *upstreamMonitor - router router.Router - ptrResolver ctrld.Resolver + cfg *ctrld.Config + localUpstreams []string + ptrNameservers []string + appCallback *AppCallback + cache dnscache.Cacher + sema semaphore + ciTable *clientinfo.Table + um *upstreamMonitor + router router.Router loopMu sync.Mutex loop map[string]bool @@ -160,7 +163,7 @@ func (p *prog) runWait() { // This needs to be done here, otherwise, the DNS handler may observe an invalid // upstream config because its initialization function have not been called yet. mainLog.Load().Debug().Msg("setup upstream with new config") - setupUpstream(newCfg) + p.setupUpstream(newCfg) p.mu.Lock() *p.cfg = *newCfg @@ -187,7 +190,9 @@ func (p *prog) preRun() { } } -func setupUpstream(cfg *ctrld.Config) { +func (p *prog) setupUpstream(cfg *ctrld.Config) { + localUpstreams := make([]string, 0, len(cfg.Upstream)) + ptrNameservers := make([]string, 0, len(cfg.Upstream)) for n := range cfg.Upstream { uc := cfg.Upstream[n] uc.Init() @@ -199,7 +204,16 @@ func setupUpstream(cfg *ctrld.Config) { } uc.SetCertPool(rootCertPool) go uc.Ping() + + if canBeLocalUpstream(uc.Domain) { + localUpstreams = append(localUpstreams, upstreamPrefix+n) + } + if uc.IsDiscoverable() { + ptrNameservers = append(ptrNameservers, uc.Endpoint) + } } + p.localUpstreams = localUpstreams + p.ptrNameservers = ptrNameservers } // run runs the ctrld main components. @@ -230,9 +244,6 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.cache = cacher } } - if r := p.cfg.Service.PtrResolver(); r != nil { - p.ptrResolver = r - } var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) @@ -260,8 +271,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.sema = &chanSemaphore{ready: make(chan struct{}, n)} } } - setupUpstream(p.cfg) - p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID) + p.setupUpstream(p.cfg) + p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers) if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) @@ -613,3 +624,11 @@ func defaultRouteIP() string { mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP") return ip } + +// canBeLocalUpstream reports whether the IP address can be used as a local upstream. +func canBeLocalUpstream(addr string) bool { + if ip, err := netip.ParseAddr(addr); err == nil { + return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip) + } + return false +} diff --git a/config.go b/config.go index 8b37078..1bd9043 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,7 @@ import ( "math/rand" "net" "net/http" + "net/netip" "net/url" "os" "runtime" @@ -26,6 +27,7 @@ import ( "github.com/spf13/viper" "golang.org/x/sync/singleflight" "tailscale.com/logtail/backoff" + "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld/internal/dnsrcode" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" @@ -177,46 +179,22 @@ func (c *Config) FirstUpstream() *UpstreamConfig { // ServiceConfig specifies the general ctrld config. type ServiceConfig struct { - LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` - LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` - CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` - CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` - CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` - CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` - MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` - DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` - DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` - DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` - DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"` - DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` - DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` - DiscoverPtrEndpoints []string `mapstructure:"discover_ptr_endpoints" toml:"discover_ptr_endpoints,omitempty"` - DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` - Daemon bool `mapstructure:"-" toml:"-"` - AllocateIP bool `mapstructure:"-" toml:"-"` -} - -// PtrResolver returns a Resolver used for PTR lookup, based on ServiceConfig.DiscoverPtrEndpoints value. -func (s ServiceConfig) PtrResolver() Resolver { - if len(s.DiscoverPtrEndpoints) > 0 { - nss := make([]string, 0, len(s.DiscoverPtrEndpoints)) - for _, ns := range s.DiscoverPtrEndpoints { - host, port := ns, "53" - if h, p, err := net.SplitHostPort(ns); err == nil { - host, port = h, p - } - // Only use valid ip:port pair. - if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { - nss = append(nss, net.JoinHostPort(host, port)) - } else { - ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for PTR resolver: %q", ns) - } - } - if len(nss) > 0 { - return NewResolverWithNameserver(nss) - } - } - return nil + LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` + LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` + CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` + CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` + CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` + CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` + MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` + DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` + DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` + DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` + DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"` + DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` + DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` + DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` + Daemon bool `mapstructure:"-" toml:"-"` + AllocateIP bool `mapstructure:"-" toml:"-"` } // NetworkConfig specifies configuration for networks where ctrld will handle requests. @@ -238,6 +216,9 @@ type UpstreamConfig struct { // The caller should not access this field directly. // Use UpstreamSendClientInfo instead. SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"` + // The caller should not access this field directly. + // Use IsDiscoverable instead. + Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"` g singleflight.Group rebootstrap atomic.Bool @@ -364,6 +345,21 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool { return false } +// IsDiscoverable reports whether the upstream can be used for PTR discovery. +// The caller must ensure uc.Init() was called before calling this. +func (uc *UpstreamConfig) IsDiscoverable() bool { + if uc.Discoverable != nil { + return *uc.Discoverable + } + switch uc.Type { + case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate: + if ip, err := netip.ParseAddr(uc.Domain); err == nil { + return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip) + } + } + return false +} + // BootstrapIPs returns the bootstrap IPs list of upstreams. func (uc *UpstreamConfig) BootstrapIPs() []string { return uc.bootstrapIPs diff --git a/config_internal_test.go b/config_internal_test.go index 89cec19..96beddc 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -279,6 +279,61 @@ func TestUpstreamConfig_UpstreamSendClientInfo(t *testing.T) { } } +func TestUpstreamConfig_IsDiscoverable(t *testing.T) { + tests := []struct { + name string + uc *UpstreamConfig + discoverable bool + }{ + { + "loopback", + &UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy}, + true, + }, + { + "rfc1918", + &UpstreamConfig{Endpoint: "192.168.1.1", Type: ResolverTypeLegacy}, + true, + }, + { + "CGNAT", + &UpstreamConfig{Endpoint: "100.66.67.68", Type: ResolverTypeLegacy}, + true, + }, + { + "Public IP", + &UpstreamConfig{Endpoint: "8.8.8.8", Type: ResolverTypeLegacy}, + false, + }, + { + "override discoverable", + &UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(false)}, + false, + }, + { + "override non-public", + &UpstreamConfig{Endpoint: "1.1.1.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(true)}, + true, + }, + { + "non-legacy upstream", + &UpstreamConfig{Endpoint: "https://192.168.1.1/custom-doh", Type: ResolverTypeDOH}, + false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tc.uc.Init() + if got := tc.uc.IsDiscoverable(); got != tc.discoverable { + t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got) + } + }) + } +} + func ptrBool(b bool) *bool { return &b } diff --git a/docs/config.md b/docs/config.md index 57a794b..bc0ead2 100644 --- a/docs/config.md +++ b/docs/config.md @@ -193,22 +193,6 @@ Perform LAN client discovery using PTR queries. - Required: no - Default: true -### discover_ptr_endpoints -List of DNS nameservers used for PTR discovery. - -Each entry can be either "ip" (default port 53) or "ip:port" pair. Invalid entry will be ignored. - -- Type: array of string -- Required: no -- Default: [] - -Example: - -```toml -[service] -discover_ptr_endpoints = ["192.168.1.1", "192.168.2.1:5354"] -``` - ### discover_hosts Perform LAN client discovery using hosts file. @@ -335,6 +319,24 @@ If `ip_stack` is empty, or undefined: - Default value is `both` for non-Control D resolvers. - Default value is `split` for Control D resolvers. +### send_client_info +Specifying whether to include client info when sending query to upstream. + +- Type: boolean +- Required: no +- Default: + - `true` for ControlD upstreams. + - `false` for other upstreams. + +### discoverable +Specifying whether the upstream can be used for PTR discovery. + +- Type: boolean +- Required: no +- Default: + - `true` for loopback/RFC1918/CGNAT IP address. + - `false` for public IP address. + ## Network The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs. diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 0e60643..07e4cf0 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -3,7 +3,9 @@ package clientinfo import ( "context" "fmt" + "net" "net/netip" + "strconv" "strings" "sync" "time" @@ -69,25 +71,27 @@ type Table struct { refreshers []refresher initOnce sync.Once - dhcp *dhcp - merlin *merlinDiscover - arp *arpDiscover - ptr *ptrDiscover - mdns *mdns - hf *hostsFile - vni *virtualNetworkIface - svcCfg ctrld.ServiceConfig - quitCh chan struct{} - selfIP string - cdUID string + dhcp *dhcp + merlin *merlinDiscover + arp *arpDiscover + ptr *ptrDiscover + mdns *mdns + hf *hostsFile + vni *virtualNetworkIface + svcCfg ctrld.ServiceConfig + quitCh chan struct{} + selfIP string + cdUID string + ptrNameservers []string } -func NewTable(cfg *ctrld.Config, selfIP, cdUID string) *Table { +func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { return &Table{ - svcCfg: cfg.Service, - quitCh: make(chan struct{}), - selfIP: selfIP, - cdUID: cdUID, + svcCfg: cfg.Service, + quitCh: make(chan struct{}), + selfIP: selfIP, + cdUID: cdUID, + ptrNameservers: ns, } } @@ -183,9 +187,25 @@ func (t *Table) init() { // PTR lookup. if t.discoverPTR() { t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} - if r := t.svcCfg.PtrResolver(); r != nil { - ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR discover", t.svcCfg.DiscoverPtrEndpoints) - t.ptr.resolver = r + if len(t.ptrNameservers) > 0 { + nss := make([]string, 0, len(t.ptrNameservers)) + for _, ns := range t.ptrNameservers { + host, port := ns, "53" + if h, p, err := net.SplitHostPort(ns); err == nil { + host, port = h, p + } + // Only use valid ip:port pair. + if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { + nss = append(nss, net.JoinHostPort(host, port)) + } else { + ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) + } + } + if len(nss) > 0 { + t.ptr.resolver = ctrld.NewResolverWithNameserver(nss) + ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss) + } + } ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") if err := t.ptr.refresh(); err != nil { @@ -358,6 +378,27 @@ func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) { t.vni.ip2name.Store(ci.IP, ci.Hostname) } +// ipFinder is the interface for retrieving IP address from hostname. +type ipFinder interface { + lookupIPByHostname(name string, v6 bool) string +} + +// LookupIPByHostname returns the ip address of given hostname. +// If v6 is true, return IPv6 instead of default IPv4. +func (t *Table) LookupIPByHostname(hostname string, v6 bool) *netip.Addr { + if t == nil { + return nil + } + for _, finder := range []ipFinder{t.hf, t.ptr, t.mdns, t.dhcp} { + if addr := finder.lookupIPByHostname(hostname, v6); addr != "" { + if ip, err := netip.ParseAddr(addr); err == nil { + return &ip + } + } + } + return nil +} + func (t *Table) discoverDHCP() bool { if t.svcCfg.DiscoverDHCP == nil { return true diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index 7c1b2cf..e036638 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -134,6 +134,23 @@ func (d *dhcp) List() []string { return ips } +func (d *dhcp) lookupIPByHostname(name string, v6 bool) string { + if d == nil { + return "" + } + var ip string + d.ip2name.Range(func(key, value any) bool { + if value == name { + if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + ip = addr.String() + return false + } + } + return true + }) + return ip +} + // AddLeaseFile adds given lease file for reading/watching clients info. func (d *dhcp) addLeaseFile(name string, format ctrld.LeaseFileFormat) error { if d.watcher == nil { diff --git a/internal/clientinfo/hostsfile.go b/internal/clientinfo/hostsfile.go index baf05fb..8c86987 100644 --- a/internal/clientinfo/hostsfile.go +++ b/internal/clientinfo/hostsfile.go @@ -1,6 +1,7 @@ package clientinfo import ( + "net/netip" "os" "sync" @@ -109,6 +110,24 @@ func (hf *hostsFile) String() string { return "hosts" } +func (hf *hostsFile) lookupIPByHostname(name string, v6 bool) string { + if hf == nil { + return "" + } + hf.mu.Lock() + defer hf.mu.Unlock() + for addr, names := range hf.m { + if ip, err := netip.ParseAddr(addr); err == nil && !ip.IsLoopback() { + for _, n := range names { + if n == name && ip.Is6() == v6 { + return ip.String() + } + } + } + } + return "" +} + // isLocalhostName reports whether the given hostname represents localhost. func isLocalhostName(hostname string) bool { switch hostname { diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index 9a5fa85..59e6e9c 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/netip" "os" "sync" "syscall" @@ -59,6 +60,23 @@ func (m *mdns) List() []string { return ips } +func (m *mdns) lookupIPByHostname(name string, v6 bool) string { + if m == nil { + return "" + } + var ip string + m.name.Range(func(key, value any) bool { + if value == name { + if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + ip = addr.String() + return false + } + } + return true + }) + return ip +} + func (m *mdns) init(quitCh chan struct{}) error { ifaces, err := multicastInterfaces() if err != nil { diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index fea79fb..b6204d5 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -2,6 +2,7 @@ package clientinfo import ( "context" + "net/netip" "sync" "sync/atomic" "time" @@ -94,6 +95,23 @@ func (p *ptrDiscover) lookupHostname(ip string) string { return "" } +func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string { + if p == nil { + return "" + } + var ip string + p.hostname.Range(func(key, value any) bool { + if value == name { + if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + ip = addr.String() + return false + } + } + return true + }) + return ip +} + // checkServer monitors if the resolver can reach its nameserver. When the nameserver // is reachable, set p.serverDown to false, so p.lookupHostname can continue working. func (p *ptrDiscover) checkServer() {