Add ControlD public DNS to OS resolver

Since the OS resolver only returns response with NOERROR first, it's
safe to use ControlD public DNS in parallel with system DNS. Local
domains would resolve only though local resolvers, because public ones
will return NXDOMAIN response.
This commit is contained in:
Cuong Manh Le
2024-07-12 17:35:34 +07:00
committed by Cuong Manh Le
parent dc48c908b8
commit 56f9c72569
3 changed files with 103 additions and 14 deletions

2
dot.go
View File

@@ -18,7 +18,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
// dns.controld.dev first. By using a dialer with custom resolver,
// we ensure that we can always resolve the bootstrap domain
// regardless of the machine DNS status.
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
dnsTyp := uint16(0)
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype

View File

@@ -30,18 +30,18 @@ const (
ResolverTypePrivate = "private"
)
const bootstrapDNS = "76.76.2.22"
const (
controldBootstrapDns = "76.76.2.22"
controldPublicDns = "76.76.2.0"
)
// or is the Resolver used for ResolverTypeOS.
var or = &osResolver{nameservers: defaultNameservers()}
// defaultNameservers returns nameservers used by the OS.
// If no nameservers can be found, ctrld bootstrap nameserver will be used.
// defaultNameservers returns OS nameservers plus ControlD public DNS.
func defaultNameservers() []string {
ns := nameservers()
if len(ns) == 0 {
ns = append(ns, net.JoinHostPort(bootstrapDNS, "53"))
}
ns = append(ns, net.JoinHostPort(controldPublicDns, "53"))
return ns
}
@@ -120,15 +120,21 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
}(server)
}
var nonSuccessAnswer *dns.Msg
errs := make([]error, 0, numServers)
for res := range ch {
if res.err == nil {
cancel()
return res.answer, res.err
if res.answer != nil {
if res.answer.Rcode == dns.RcodeSuccess {
cancel()
return res.answer, nil
}
nonSuccessAnswer = res.answer
}
errs = append(errs, res.err)
}
if nonSuccessAnswer != nil {
return nonSuccessAnswer, nil
}
return nil, errors.Join(errs...)
}
@@ -138,7 +144,7 @@ type legacyResolver struct {
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
// See comment in (*dotResolver).resolve method.
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
dnsTyp := uint16(0)
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
@@ -176,7 +182,7 @@ func LookupIP(domain string) []string {
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
resolver := &osResolver{nameservers: nameservers()}
if withBootstrapDNS {
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...)
}
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
timeoutMs := 2000
@@ -252,7 +258,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
// - Input servers.
func NewBootstrapResolver(servers ...string) Resolver {
resolver := &osResolver{nameservers: nameservers()}
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
resolver.nameservers = append([]string{net.JoinHostPort(controldPublicDns, "53")}, resolver.nameservers...)
for _, ns := range servers {
resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...)
}

View File

@@ -2,6 +2,8 @@ package ctrld
import (
"context"
"net"
"sync"
"testing"
"time"
@@ -28,6 +30,57 @@ func Test_osResolver_Resolve(t *testing.T) {
}
}
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
ns := make([]string, 0, 2)
servers := make([]*dns.Server, 0, 2)
successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(msg, dns.RcodeSuccess)
w.WriteMsg(m)
})
nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc {
return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(msg, rcode)
w.WriteMsg(m)
})
}
handlers := []dns.Handler{
nonSuccessHandlerWithRcode(dns.RcodeRefused),
nonSuccessHandlerWithRcode(dns.RcodeNameError),
successHandler,
}
for i := range handlers {
pc, err := net.ListenPacket("udp", ":0")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i])
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ns = append(ns, addr)
servers = append(servers, s)
}
defer func() {
for _, server := range servers {
server.Shutdown()
}
}()
resolver := &osResolver{nameservers: ns}
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
answer, err := resolver.Resolve(context.Background(), msg)
if err != nil {
t.Fatal(err)
}
if answer.Rcode != dns.RcodeSuccess {
t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode])
}
}
func Test_upstreamTypeFromEndpoint(t *testing.T) {
tests := []struct {
name string
@@ -51,3 +104,33 @@ func Test_upstreamTypeFromEndpoint(t *testing.T) {
})
}
}
func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.Handler, opts ...func(*dns.Server)) (*dns.Server, string, error) {
t.Helper()
server := &dns.Server{
PacketConn: pc,
ReadTimeout: time.Hour,
WriteTimeout: time.Hour,
Handler: handler,
}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
for _, opt := range opts {
opt(server)
}
addr, closer := pc.LocalAddr().String(), pc
go func() {
if err := server.ActivateAndServe(); err != nil {
t.Error(err)
}
closer.Close()
}()
waitLock.Lock()
return server, addr, nil
}