mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Add ControlD public DNS to OS resolver
Since the OS resolver only returns response with NOERROR first, it's safe to use ControlD public DNS in parallel with system DNS. Local domains would resolve only though local resolvers, because public ones will return NXDOMAIN response.
This commit is contained in:
committed by
Cuong Manh Le
parent
dc48c908b8
commit
56f9c72569
2
dot.go
2
dot.go
@@ -18,7 +18,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
// dns.controld.dev first. By using a dialer with custom resolver,
|
||||
// we ensure that we can always resolve the bootstrap domain
|
||||
// regardless of the machine DNS status.
|
||||
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
|
||||
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
|
||||
32
resolver.go
32
resolver.go
@@ -30,18 +30,18 @@ const (
|
||||
ResolverTypePrivate = "private"
|
||||
)
|
||||
|
||||
const bootstrapDNS = "76.76.2.22"
|
||||
const (
|
||||
controldBootstrapDns = "76.76.2.22"
|
||||
controldPublicDns = "76.76.2.0"
|
||||
)
|
||||
|
||||
// or is the Resolver used for ResolverTypeOS.
|
||||
var or = &osResolver{nameservers: defaultNameservers()}
|
||||
|
||||
// defaultNameservers returns nameservers used by the OS.
|
||||
// If no nameservers can be found, ctrld bootstrap nameserver will be used.
|
||||
// defaultNameservers returns OS nameservers plus ControlD public DNS.
|
||||
func defaultNameservers() []string {
|
||||
ns := nameservers()
|
||||
if len(ns) == 0 {
|
||||
ns = append(ns, net.JoinHostPort(bootstrapDNS, "53"))
|
||||
}
|
||||
ns = append(ns, net.JoinHostPort(controldPublicDns, "53"))
|
||||
return ns
|
||||
}
|
||||
|
||||
@@ -120,15 +120,21 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
}(server)
|
||||
}
|
||||
|
||||
var nonSuccessAnswer *dns.Msg
|
||||
errs := make([]error, 0, numServers)
|
||||
for res := range ch {
|
||||
if res.err == nil {
|
||||
cancel()
|
||||
return res.answer, res.err
|
||||
if res.answer != nil {
|
||||
if res.answer.Rcode == dns.RcodeSuccess {
|
||||
cancel()
|
||||
return res.answer, nil
|
||||
}
|
||||
nonSuccessAnswer = res.answer
|
||||
}
|
||||
errs = append(errs, res.err)
|
||||
}
|
||||
|
||||
if nonSuccessAnswer != nil {
|
||||
return nonSuccessAnswer, nil
|
||||
}
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
@@ -138,7 +144,7 @@ type legacyResolver struct {
|
||||
|
||||
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
// See comment in (*dotResolver).resolve method.
|
||||
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
|
||||
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
@@ -176,7 +182,7 @@ func LookupIP(domain string) []string {
|
||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||
resolver := &osResolver{nameservers: nameservers()}
|
||||
if withBootstrapDNS {
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...)
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
|
||||
timeoutMs := 2000
|
||||
@@ -252,7 +258,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
|
||||
// - Input servers.
|
||||
func NewBootstrapResolver(servers ...string) Resolver {
|
||||
resolver := &osResolver{nameservers: nameservers()}
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(controldPublicDns, "53")}, resolver.nameservers...)
|
||||
for _, ns := range servers {
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -28,6 +30,57 @@ func Test_osResolver_Resolve(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
ns := make([]string, 0, 2)
|
||||
servers := make([]*dns.Server, 0, 2)
|
||||
successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(msg, dns.RcodeSuccess)
|
||||
w.WriteMsg(m)
|
||||
})
|
||||
nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc {
|
||||
return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(msg, rcode)
|
||||
w.WriteMsg(m)
|
||||
})
|
||||
}
|
||||
|
||||
handlers := []dns.Handler{
|
||||
nonSuccessHandlerWithRcode(dns.RcodeRefused),
|
||||
nonSuccessHandlerWithRcode(dns.RcodeNameError),
|
||||
successHandler,
|
||||
}
|
||||
for i := range handlers {
|
||||
pc, err := net.ListenPacket("udp", ":0")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i])
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
ns = append(ns, addr)
|
||||
servers = append(servers, s)
|
||||
}
|
||||
defer func() {
|
||||
for _, server := range servers {
|
||||
server.Shutdown()
|
||||
}
|
||||
}()
|
||||
resolver := &osResolver{nameservers: ns}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
answer, err := resolver.Resolve(context.Background(), msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode])
|
||||
}
|
||||
}
|
||||
|
||||
func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -51,3 +104,33 @@ func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.Handler, opts ...func(*dns.Server)) (*dns.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
server := &dns.Server{
|
||||
PacketConn: pc,
|
||||
ReadTimeout: time.Hour,
|
||||
WriteTimeout: time.Hour,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
server.NotifyStartedFunc = waitLock.Unlock
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(server)
|
||||
}
|
||||
|
||||
addr, closer := pc.LocalAddr().String(), pc
|
||||
go func() {
|
||||
if err := server.ActivateAndServe(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
closer.Close()
|
||||
}()
|
||||
|
||||
waitLock.Lock()
|
||||
return server, addr, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user