diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..8b47c6c --- /dev/null +++ b/errors.go @@ -0,0 +1,43 @@ +package ctrld + +// TODO(cuonglm): use stdlib once we bump minimum version to 1.20 + +func joinErrors(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + e := &joinError{ + errs: make([]error, 0, n), + } + for _, err := range errs { + if err != nil { + e.errs = append(e.errs, err) + } + } + return e +} + +type joinError struct { + errs []error +} + +func (e *joinError) Error() string { + var b []byte + for i, err := range e.errs { + if i > 0 { + b = append(b, '\n') + } + b = append(b, err.Error()...) + } + return string(b) +} + +func (e *joinError) Unwrap() []error { + return e.errs +} diff --git a/resolver.go b/resolver.go index 5c04f37..a12c700 100644 --- a/resolver.go +++ b/resolver.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net" - "sync/atomic" "github.com/miekg/dns" ) @@ -51,22 +50,42 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { type osResolver struct { nameservers []string - next atomic.Uint32 +} + +type osResolverResult struct { + answer *dns.Msg + err error } // Resolve performs DNS resolvers using OS default nameservers. Nameserver is chosen from // available nameservers with a roundrobin algorithm. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - numServers := uint32(len(o.nameservers)) + numServers := len(o.nameservers) if numServers == 0 { return nil, errors.New("no nameservers available") } - next := o.next.Add(1) - server := o.nameservers[(next-1)%numServers] - dnsClient := &dns.Client{Net: "udp"} - answer, _, err := dnsClient.ExchangeContext(ctx, msg, server) + ctx, cancel := context.WithCancel(ctx) + defer cancel() - return answer, err + dnsClient := &dns.Client{Net: "udp"} + ch := make(chan *osResolverResult, numServers) + for _, server := range o.nameservers { + go func(server string) { + answer, _, err := dnsClient.ExchangeContext(ctx, msg, server) + ch <- &osResolverResult{answer: answer, err: err} + }(server) + } + + errs := make([]error, 0, numServers) + for res := range ch { + if res.err == nil { + cancel() + return res.answer, res.err + } + errs = append(errs, res.err) + } + + return nil, joinErrors(errs...) } func newDialer(dnsAddress string) *net.Dialer {