diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 855e5d3..06d0702 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -285,6 +285,12 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(upstreamConfig) + if upstreamConfig.Type == ctrld.ResolverTypePrivate { + if r := p.ptrResolver; r != nil { + ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR resolver", p.cfg.Service.DiscoverPtrEndpoints) + dnsResolver = r + } + } if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") return nil, err diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 867c08a..fb88c81 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -62,6 +62,7 @@ type prog struct { ciTable *clientinfo.Table um *upstreamMonitor router router.Router + ptrResolver ctrld.Resolver loopMu sync.Mutex loop map[string]bool @@ -229,6 +230,9 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.cache = cacher } } + if r := p.cfg.Service.PtrResolver(); r != nil { + p.ptrResolver = r + } var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) diff --git a/config.go b/config.go index 051fb80..8b37078 100644 --- a/config.go +++ b/config.go @@ -196,6 +196,29 @@ type ServiceConfig struct { AllocateIP bool `mapstructure:"-" toml:"-"` } +// PtrResolver returns a Resolver used for PTR lookup, based on ServiceConfig.DiscoverPtrEndpoints value. +func (s ServiceConfig) PtrResolver() Resolver { + if len(s.DiscoverPtrEndpoints) > 0 { + nss := make([]string, 0, len(s.DiscoverPtrEndpoints)) + for _, ns := range s.DiscoverPtrEndpoints { + host, port := ns, "53" + if h, p, err := net.SplitHostPort(ns); err == nil { + host, port = h, p + } + // Only use valid ip:port pair. + if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { + nss = append(nss, net.JoinHostPort(host, port)) + } else { + ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for PTR resolver: %q", ns) + } + } + if len(nss) > 0 { + return NewResolverWithNameserver(nss) + } + } + return nil +} + // NetworkConfig specifies configuration for networks where ctrld will handle requests. type NetworkConfig struct { Name string `mapstructure:"name" toml:"name,omitempty"` diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index ee1a14f..0e60643 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -3,9 +3,7 @@ package clientinfo import ( "context" "fmt" - "net" "net/netip" - "strconv" "strings" "sync" "time" @@ -185,24 +183,9 @@ func (t *Table) init() { // PTR lookup. if t.discoverPTR() { t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} - if len(t.svcCfg.DiscoverPtrEndpoints) > 0 { - nss := make([]string, 0, len(t.svcCfg.DiscoverPtrEndpoints)) - for _, ns := range t.svcCfg.DiscoverPtrEndpoints { - host, port := ns, "53" - if h, p, err := net.SplitHostPort(ns); err == nil { - host, port = h, p - } - // Only use valid ip:port pair. - if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { - nss = append(nss, net.JoinHostPort(host, port)) - } else { - ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) - } - } - if len(nss) > 0 { - t.ptr.resolver = ctrld.NewResolverWithNameserver(nss) - ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss) - } + if r := t.svcCfg.PtrResolver(); r != nil { + ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR discover", t.svcCfg.DiscoverPtrEndpoints) + t.ptr.resolver = r } ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") if err := t.ptr.refresh(); err != nil {