From e6800fbc82deacc39d9aa69fb9194c40376b9499 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 9 Mar 2023 23:48:53 +0700 Subject: [PATCH] Query all possible nameservers for os resolver So we don't have to worry about network stack changes causes an upstream to be broken. Just send requests to all nameservers concurrently, and get the first success response. --- errors.go | 43 +++++++++++++++++++++++++++++++++++++++++++ resolver.go | 35 +++++++++++++++++++++++++++-------- 2 files changed, 70 insertions(+), 8 deletions(-) create mode 100644 errors.go diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..8b47c6c --- /dev/null +++ b/errors.go @@ -0,0 +1,43 @@ +package ctrld + +// TODO(cuonglm): use stdlib once we bump minimum version to 1.20 + +func joinErrors(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + e := &joinError{ + errs: make([]error, 0, n), + } + for _, err := range errs { + if err != nil { + e.errs = append(e.errs, err) + } + } + return e +} + +type joinError struct { + errs []error +} + +func (e *joinError) Error() string { + var b []byte + for i, err := range e.errs { + if i > 0 { + b = append(b, '\n') + } + b = append(b, err.Error()...) + } + return string(b) +} + +func (e *joinError) Unwrap() []error { + return e.errs +} diff --git a/resolver.go b/resolver.go index 5c04f37..a12c700 100644 --- a/resolver.go +++ b/resolver.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net" - "sync/atomic" "github.com/miekg/dns" ) @@ -51,22 +50,42 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { type osResolver struct { nameservers []string - next atomic.Uint32 +} + +type osResolverResult struct { + answer *dns.Msg + err error } // Resolve performs DNS resolvers using OS default nameservers. Nameserver is chosen from // available nameservers with a roundrobin algorithm. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - numServers := uint32(len(o.nameservers)) + numServers := len(o.nameservers) if numServers == 0 { return nil, errors.New("no nameservers available") } - next := o.next.Add(1) - server := o.nameservers[(next-1)%numServers] - dnsClient := &dns.Client{Net: "udp"} - answer, _, err := dnsClient.ExchangeContext(ctx, msg, server) + ctx, cancel := context.WithCancel(ctx) + defer cancel() - return answer, err + dnsClient := &dns.Client{Net: "udp"} + ch := make(chan *osResolverResult, numServers) + for _, server := range o.nameservers { + go func(server string) { + answer, _, err := dnsClient.ExchangeContext(ctx, msg, server) + ch <- &osResolverResult{answer: answer, err: err} + }(server) + } + + errs := make([]error, 0, numServers) + for res := range ch { + if res.err == nil { + cancel() + return res.answer, res.err + } + errs = append(errs, res.err) + } + + return nil, joinErrors(errs...) } func newDialer(dnsAddress string) *net.Dialer {