Query all possible nameservers for os resolver

So we don't have to worry about network stack changes causes an upstream
to be broken. Just send requests to all nameservers concurrently, and
get the first success response.
This commit is contained in:
Cuong Manh Le
2023-03-09 23:48:53 +07:00
committed by Cuong Manh Le
parent 4f6c2032a1
commit e6800fbc82
2 changed files with 70 additions and 8 deletions

43
errors.go Normal file
View File

@@ -0,0 +1,43 @@
package ctrld
// TODO(cuonglm): use stdlib once we bump minimum version to 1.20
func joinErrors(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
e := &joinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}
type joinError struct {
errs []error
}
func (e *joinError) Error() string {
var b []byte
for i, err := range e.errs {
if i > 0 {
b = append(b, '\n')
}
b = append(b, err.Error()...)
}
return string(b)
}
func (e *joinError) Unwrap() []error {
return e.errs
}

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net"
"sync/atomic"
"github.com/miekg/dns"
)
@@ -51,22 +50,42 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
type osResolver struct {
nameservers []string
next atomic.Uint32
}
type osResolverResult struct {
answer *dns.Msg
err error
}
// Resolve performs DNS resolvers using OS default nameservers. Nameserver is chosen from
// available nameservers with a roundrobin algorithm.
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
numServers := uint32(len(o.nameservers))
numServers := len(o.nameservers)
if numServers == 0 {
return nil, errors.New("no nameservers available")
}
next := o.next.Add(1)
server := o.nameservers[(next-1)%numServers]
dnsClient := &dns.Client{Net: "udp"}
answer, _, err := dnsClient.ExchangeContext(ctx, msg, server)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
return answer, err
dnsClient := &dns.Client{Net: "udp"}
ch := make(chan *osResolverResult, numServers)
for _, server := range o.nameservers {
go func(server string) {
answer, _, err := dnsClient.ExchangeContext(ctx, msg, server)
ch <- &osResolverResult{answer: answer, err: err}
}(server)
}
errs := make([]error, 0, numServers)
for res := range ch {
if res.err == nil {
cancel()
return res.answer, res.err
}
errs = append(errs, res.err)
}
return nil, joinErrors(errs...)
}
func newDialer(dnsAddress string) *net.Dialer {