From 56f9c725691f1635cdee09ba0f3cddf830ff3996 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 12 Jul 2024 17:35:34 +0700 Subject: [PATCH] Add ControlD public DNS to OS resolver Since the OS resolver only returns response with NOERROR first, it's safe to use ControlD public DNS in parallel with system DNS. Local domains would resolve only though local resolvers, because public ones will return NXDOMAIN response. --- dot.go | 2 +- resolver.go | 32 +++++++++++-------- resolver_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 14 deletions(-) diff --git a/dot.go b/dot.go index 1fef409..c0fe102 100644 --- a/dot.go +++ b/dot.go @@ -18,7 +18,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro // dns.controld.dev first. By using a dialer with custom resolver, // we ensure that we can always resolve the bootstrap domain // regardless of the machine DNS status. - dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53")) + dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53")) dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype diff --git a/resolver.go b/resolver.go index 49ac652..1e5a371 100644 --- a/resolver.go +++ b/resolver.go @@ -30,18 +30,18 @@ const ( ResolverTypePrivate = "private" ) -const bootstrapDNS = "76.76.2.22" +const ( + controldBootstrapDns = "76.76.2.22" + controldPublicDns = "76.76.2.0" +) // or is the Resolver used for ResolverTypeOS. var or = &osResolver{nameservers: defaultNameservers()} -// defaultNameservers returns nameservers used by the OS. -// If no nameservers can be found, ctrld bootstrap nameserver will be used. +// defaultNameservers returns OS nameservers plus ControlD public DNS. func defaultNameservers() []string { ns := nameservers() - if len(ns) == 0 { - ns = append(ns, net.JoinHostPort(bootstrapDNS, "53")) - } + ns = append(ns, net.JoinHostPort(controldPublicDns, "53")) return ns } @@ -120,15 +120,21 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error }(server) } + var nonSuccessAnswer *dns.Msg errs := make([]error, 0, numServers) for res := range ch { - if res.err == nil { - cancel() - return res.answer, res.err + if res.answer != nil { + if res.answer.Rcode == dns.RcodeSuccess { + cancel() + return res.answer, nil + } + nonSuccessAnswer = res.answer } errs = append(errs, res.err) } - + if nonSuccessAnswer != nil { + return nonSuccessAnswer, nil + } return nil, errors.Join(errs...) } @@ -138,7 +144,7 @@ type legacyResolver struct { func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { // See comment in (*dotResolver).resolve method. - dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53")) + dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53")) dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype @@ -176,7 +182,7 @@ func LookupIP(domain string) []string { func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { resolver := &osResolver{nameservers: nameservers()} if withBootstrapDNS { - resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...) } ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers) timeoutMs := 2000 @@ -252,7 +258,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) // - Input servers. func NewBootstrapResolver(servers ...string) Resolver { resolver := &osResolver{nameservers: nameservers()} - resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + resolver.nameservers = append([]string{net.JoinHostPort(controldPublicDns, "53")}, resolver.nameservers...) for _, ns := range servers { resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...) } diff --git a/resolver_test.go b/resolver_test.go index 531570b..23c27ae 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -2,6 +2,8 @@ package ctrld import ( "context" + "net" + "sync" "testing" "time" @@ -28,6 +30,57 @@ func Test_osResolver_Resolve(t *testing.T) { } } +func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { + ns := make([]string, 0, 2) + servers := make([]*dns.Server, 0, 2) + successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, dns.RcodeSuccess) + w.WriteMsg(m) + }) + nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc { + return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, rcode) + w.WriteMsg(m) + }) + } + + handlers := []dns.Handler{ + nonSuccessHandlerWithRcode(dns.RcodeRefused), + nonSuccessHandlerWithRcode(dns.RcodeNameError), + successHandler, + } + for i := range handlers { + pc, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i]) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ns = append(ns, addr) + servers = append(servers, s) + } + defer func() { + for _, server := range servers { + server.Shutdown() + } + }() + resolver := &osResolver{nameservers: ns} + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + answer, err := resolver.Resolve(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + if answer.Rcode != dns.RcodeSuccess { + t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode]) + } +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string @@ -51,3 +104,33 @@ func Test_upstreamTypeFromEndpoint(t *testing.T) { }) } } + +func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.Handler, opts ...func(*dns.Server)) (*dns.Server, string, error) { + t.Helper() + + server := &dns.Server{ + PacketConn: pc, + ReadTimeout: time.Hour, + WriteTimeout: time.Hour, + Handler: handler, + } + + waitLock := sync.Mutex{} + waitLock.Lock() + server.NotifyStartedFunc = waitLock.Unlock + + for _, opt := range opts { + opt(server) + } + + addr, closer := pc.LocalAddr().String(), pc + go func() { + if err := server.ActivateAndServe(); err != nil { + t.Error(err) + } + closer.Close() + }() + + waitLock.Lock() + return server, addr, nil +}