diff --git a/dot.go b/dot.go index 1fef409..c0fe102 100644 --- a/dot.go +++ b/dot.go @@ -18,7 +18,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro // dns.controld.dev first. By using a dialer with custom resolver, // we ensure that we can always resolve the bootstrap domain // regardless of the machine DNS status. - dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53")) + dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53")) dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype diff --git a/resolver.go b/resolver.go index 49ac652..1e5a371 100644 --- a/resolver.go +++ b/resolver.go @@ -30,18 +30,18 @@ const ( ResolverTypePrivate = "private" ) -const bootstrapDNS = "76.76.2.22" +const ( + controldBootstrapDns = "76.76.2.22" + controldPublicDns = "76.76.2.0" +) // or is the Resolver used for ResolverTypeOS. var or = &osResolver{nameservers: defaultNameservers()} -// defaultNameservers returns nameservers used by the OS. -// If no nameservers can be found, ctrld bootstrap nameserver will be used. +// defaultNameservers returns OS nameservers plus ControlD public DNS. func defaultNameservers() []string { ns := nameservers() - if len(ns) == 0 { - ns = append(ns, net.JoinHostPort(bootstrapDNS, "53")) - } + ns = append(ns, net.JoinHostPort(controldPublicDns, "53")) return ns } @@ -120,15 +120,21 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error }(server) } + var nonSuccessAnswer *dns.Msg errs := make([]error, 0, numServers) for res := range ch { - if res.err == nil { - cancel() - return res.answer, res.err + if res.answer != nil { + if res.answer.Rcode == dns.RcodeSuccess { + cancel() + return res.answer, nil + } + nonSuccessAnswer = res.answer } errs = append(errs, res.err) } - + if nonSuccessAnswer != nil { + return nonSuccessAnswer, nil + } return nil, errors.Join(errs...) } @@ -138,7 +144,7 @@ type legacyResolver struct { func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { // See comment in (*dotResolver).resolve method. - dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53")) + dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53")) dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype @@ -176,7 +182,7 @@ func LookupIP(domain string) []string { func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { resolver := &osResolver{nameservers: nameservers()} if withBootstrapDNS { - resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...) } ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers) timeoutMs := 2000 @@ -252,7 +258,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) // - Input servers. func NewBootstrapResolver(servers ...string) Resolver { resolver := &osResolver{nameservers: nameservers()} - resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + resolver.nameservers = append([]string{net.JoinHostPort(controldPublicDns, "53")}, resolver.nameservers...) for _, ns := range servers { resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...) } diff --git a/resolver_test.go b/resolver_test.go index 531570b..23c27ae 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -2,6 +2,8 @@ package ctrld import ( "context" + "net" + "sync" "testing" "time" @@ -28,6 +30,57 @@ func Test_osResolver_Resolve(t *testing.T) { } } +func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { + ns := make([]string, 0, 2) + servers := make([]*dns.Server, 0, 2) + successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, dns.RcodeSuccess) + w.WriteMsg(m) + }) + nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc { + return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, rcode) + w.WriteMsg(m) + }) + } + + handlers := []dns.Handler{ + nonSuccessHandlerWithRcode(dns.RcodeRefused), + nonSuccessHandlerWithRcode(dns.RcodeNameError), + successHandler, + } + for i := range handlers { + pc, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i]) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ns = append(ns, addr) + servers = append(servers, s) + } + defer func() { + for _, server := range servers { + server.Shutdown() + } + }() + resolver := &osResolver{nameservers: ns} + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + answer, err := resolver.Resolve(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + if answer.Rcode != dns.RcodeSuccess { + t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode]) + } +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string @@ -51,3 +104,33 @@ func Test_upstreamTypeFromEndpoint(t *testing.T) { }) } } + +func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.Handler, opts ...func(*dns.Server)) (*dns.Server, string, error) { + t.Helper() + + server := &dns.Server{ + PacketConn: pc, + ReadTimeout: time.Hour, + WriteTimeout: time.Hour, + Handler: handler, + } + + waitLock := sync.Mutex{} + waitLock.Lock() + server.NotifyStartedFunc = waitLock.Unlock + + for _, opt := range opts { + opt(server) + } + + addr, closer := pc.LocalAddr().String(), pc + go func() { + if err := server.ActivateAndServe(); err != nil { + t.Error(err) + } + closer.Close() + }() + + waitLock.Lock() + return server, addr, nil +}