From 79476add1267c172a1d9a5030f9f936df8537759 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 8 Aug 2024 01:03:30 +0700 Subject: [PATCH] Testing nameserver when initializing OS resolver There are several issues with OS resolver right now: - The list of nameservers are obtained un-conditionally from all running interfaces. - ControlD public DNS query is always be used if response ok. This could lead to slow query time, and also incorrect result if a domain is resolved differently between internal DNS and ControlD public DNS. To fix these problems: - While initializing OS resolver, sending a test query to the nameserver to ensure it will response. Unreachable nameserver will not be used. - Only use ControlD public DNS success response as last one, preferring ok response from internal DNS servers. While at it, also using standard package slices, since ctrld now requires go1.21 as the minimum version. --- resolver.go | 72 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 30 deletions(-) diff --git a/resolver.go b/resolver.go index 1e5a371..d8b7f8d 100644 --- a/resolver.go +++ b/resolver.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/netip" + "slices" "sync" "time" @@ -35,13 +36,14 @@ const ( controldPublicDns = "76.76.2.0" ) +var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") + // or is the Resolver used for ResolverTypeOS. var or = &osResolver{nameservers: defaultNameservers()} // defaultNameservers returns OS nameservers plus ControlD public DNS. func defaultNameservers() []string { ns := nameservers() - ns = append(ns, net.JoinHostPort(controldPublicDns, "53")) return ns } @@ -51,10 +53,27 @@ func defaultNameservers() []string { // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. func InitializeOsResolver() []string { - or.nameservers = defaultNameservers() + or.nameservers = or.nameservers[:0] + for _, ns := range defaultNameservers() { + if testNameserver(ns) { + or.nameservers = append(or.nameservers, ns) + } + } + or.nameservers = append(or.nameservers, controldPublicDnsWithPort) return or.nameservers } +// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. +func testNameserver(addr string) bool { + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + client := new(dns.Client) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, _, err := client.ExchangeContext(ctx, msg, addr) + return err == nil +} + // Resolver is the interface that wraps the basic DNS operations. // // Resolve resolves the DNS query, return the result and the corresponding error. @@ -89,8 +108,9 @@ type osResolver struct { } type osResolverResult struct { - answer *dns.Msg - err error + answer *dns.Msg + err error + isControlDPublicDNS bool } // Resolve resolves DNS queries using pre-configured nameservers. @@ -116,24 +136,33 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error go func(server string) { defer wg.Done() answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) - ch <- &osResolverResult{answer: answer, err: err} + ch <- &osResolverResult{answer: answer, err: err, isControlDPublicDNS: server == controldPublicDnsWithPort} }(server) } - var nonSuccessAnswer *dns.Msg + var ( + nonSuccessAnswer *dns.Msg + controldSuccessAnswer *dns.Msg + ) errs := make([]error, 0, numServers) for res := range ch { - if res.answer != nil { - if res.answer.Rcode == dns.RcodeSuccess { + switch { + case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: + if res.isControlDPublicDNS { + controldSuccessAnswer = res.answer // only use ControlD answer as last one. + } else { cancel() return res.answer, nil } + case res.answer != nil: nonSuccessAnswer = res.answer } errs = append(errs, res.err) } - if nonSuccessAnswer != nil { - return nonSuccessAnswer, nil + for _, answer := range []*dns.Msg{controldSuccessAnswer, nonSuccessAnswer} { + if answer != nil { + return answer, nil + } } return nil, errors.Join(errs...) } @@ -258,7 +287,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) // - Input servers. func NewBootstrapResolver(servers ...string) Resolver { resolver := &osResolver{nameservers: nameservers()} - resolver.nameservers = append([]string{net.JoinHostPort(controldPublicDns, "53")}, resolver.nameservers...) + resolver.nameservers = append([]string{controldPublicDnsWithPort}, resolver.nameservers...) for _, ns := range servers { resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...) } @@ -285,11 +314,11 @@ func NewPrivateResolver() Resolver { // - Direct listener that has ctrld as an upstream (e.g: dnsmasq). // // causing the query always succeed. - if sliceContains(resolveConfNss, host) { + if slices.Contains(resolveConfNss, host) { continue } // Ignoring local RFC 1918 addresses. - if sliceContains(localRfc1918Addrs, host) { + if slices.Contains(localRfc1918Addrs, host) { continue } ip := net.ParseIP(host) @@ -341,20 +370,3 @@ func newDialer(dnsAddress string) *net.Dialer { }, } } - -// TODO(cuonglm): use slices.Contains once upgrading to go1.21 -// sliceContains reports whether v is present in s. -func sliceContains[S ~[]E, E comparable](s S, v E) bool { - return sliceIndex(s, v) >= 0 -} - -// sliceIndex returns the index of the first occurrence of v in s, -// or -1 if not present. -func sliceIndex[S ~[]E, E comparable](s S, v E) int { - for i := range s { - if v == s[i] { - return i - } - } - return -1 -}