diff --git a/resolver.go b/resolver.go index 1e5a371..d8b7f8d 100644 --- a/resolver.go +++ b/resolver.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/netip" + "slices" "sync" "time" @@ -35,13 +36,14 @@ const ( controldPublicDns = "76.76.2.0" ) +var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") + // or is the Resolver used for ResolverTypeOS. var or = &osResolver{nameservers: defaultNameservers()} // defaultNameservers returns OS nameservers plus ControlD public DNS. func defaultNameservers() []string { ns := nameservers() - ns = append(ns, net.JoinHostPort(controldPublicDns, "53")) return ns } @@ -51,10 +53,27 @@ func defaultNameservers() []string { // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. func InitializeOsResolver() []string { - or.nameservers = defaultNameservers() + or.nameservers = or.nameservers[:0] + for _, ns := range defaultNameservers() { + if testNameserver(ns) { + or.nameservers = append(or.nameservers, ns) + } + } + or.nameservers = append(or.nameservers, controldPublicDnsWithPort) return or.nameservers } +// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. +func testNameserver(addr string) bool { + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + client := new(dns.Client) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, _, err := client.ExchangeContext(ctx, msg, addr) + return err == nil +} + // Resolver is the interface that wraps the basic DNS operations. // // Resolve resolves the DNS query, return the result and the corresponding error. @@ -89,8 +108,9 @@ type osResolver struct { } type osResolverResult struct { - answer *dns.Msg - err error + answer *dns.Msg + err error + isControlDPublicDNS bool } // Resolve resolves DNS queries using pre-configured nameservers. @@ -116,24 +136,33 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error go func(server string) { defer wg.Done() answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) - ch <- &osResolverResult{answer: answer, err: err} + ch <- &osResolverResult{answer: answer, err: err, isControlDPublicDNS: server == controldPublicDnsWithPort} }(server) } - var nonSuccessAnswer *dns.Msg + var ( + nonSuccessAnswer *dns.Msg + controldSuccessAnswer *dns.Msg + ) errs := make([]error, 0, numServers) for res := range ch { - if res.answer != nil { - if res.answer.Rcode == dns.RcodeSuccess { + switch { + case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: + if res.isControlDPublicDNS { + controldSuccessAnswer = res.answer // only use ControlD answer as last one. + } else { cancel() return res.answer, nil } + case res.answer != nil: nonSuccessAnswer = res.answer } errs = append(errs, res.err) } - if nonSuccessAnswer != nil { - return nonSuccessAnswer, nil + for _, answer := range []*dns.Msg{controldSuccessAnswer, nonSuccessAnswer} { + if answer != nil { + return answer, nil + } } return nil, errors.Join(errs...) } @@ -258,7 +287,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) // - Input servers. func NewBootstrapResolver(servers ...string) Resolver { resolver := &osResolver{nameservers: nameservers()} - resolver.nameservers = append([]string{net.JoinHostPort(controldPublicDns, "53")}, resolver.nameservers...) + resolver.nameservers = append([]string{controldPublicDnsWithPort}, resolver.nameservers...) for _, ns := range servers { resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...) } @@ -285,11 +314,11 @@ func NewPrivateResolver() Resolver { // - Direct listener that has ctrld as an upstream (e.g: dnsmasq). // // causing the query always succeed. - if sliceContains(resolveConfNss, host) { + if slices.Contains(resolveConfNss, host) { continue } // Ignoring local RFC 1918 addresses. - if sliceContains(localRfc1918Addrs, host) { + if slices.Contains(localRfc1918Addrs, host) { continue } ip := net.ParseIP(host) @@ -341,20 +370,3 @@ func newDialer(dnsAddress string) *net.Dialer { }, } } - -// TODO(cuonglm): use slices.Contains once upgrading to go1.21 -// sliceContains reports whether v is present in s. -func sliceContains[S ~[]E, E comparable](s S, v E) bool { - return sliceIndex(s, v) >= 0 -} - -// sliceIndex returns the index of the first occurrence of v in s, -// or -1 if not present. -func sliceIndex[S ~[]E, E comparable](s S, v E) int { - for i := range s { - if v == s[i] { - return i - } - } - return -1 -}