mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Avoid data race when initializing OS resolver
With new leaking queries features, the initialization of OS resolver can now lead to data race if queries are resolving while re-initialization happens. To fix it, using an atomic pointer to store list of nameservers which were initialized, making read/write to the list concurrently safe.
This commit is contained in:
committed by
Cuong Manh Le
parent
30ea0c6499
commit
f87220a908
48
resolver.go
48
resolver.go
@@ -9,6 +9,7 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
@@ -44,7 +45,7 @@ const (
|
||||
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||
|
||||
// or is the Resolver used for ResolverTypeOS.
|
||||
var or = &osResolver{nameservers: defaultNameservers()}
|
||||
var or = newResolverWithNameserver(defaultNameservers())
|
||||
|
||||
// defaultNameservers returns OS nameservers plus ControlD public DNS.
|
||||
func defaultNameservers() []string {
|
||||
@@ -58,7 +59,7 @@ func defaultNameservers() []string {
|
||||
// It's the caller's responsibility to ensure the system DNS is in a clean state before
|
||||
// calling this function.
|
||||
func InitializeOsResolver() []string {
|
||||
or.nameservers = or.nameservers[:0]
|
||||
var nss []string
|
||||
// Ignore local addresses to prevent loop.
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
machineIPsMap := make(map[string]struct{}, len(regularIPs))
|
||||
@@ -70,11 +71,12 @@ func InitializeOsResolver() []string {
|
||||
continue
|
||||
}
|
||||
if testNameserver(ns) {
|
||||
or.nameservers = append(or.nameservers, ns)
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
}
|
||||
or.nameservers = append(or.nameservers, controldPublicDnsWithPort)
|
||||
return or.nameservers
|
||||
nss = append(nss, controldPublicDnsWithPort)
|
||||
or.nameservers.Store(&nss)
|
||||
return nss
|
||||
}
|
||||
|
||||
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
|
||||
@@ -121,7 +123,7 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
}
|
||||
|
||||
type osResolver struct {
|
||||
nameservers []string
|
||||
nameservers atomic.Pointer[[]string]
|
||||
}
|
||||
|
||||
type osResolverResult struct {
|
||||
@@ -134,7 +136,8 @@ type osResolverResult struct {
|
||||
// Query is sent to all nameservers concurrently, and the first
|
||||
// success response will be returned.
|
||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
numServers := len(o.nameservers)
|
||||
nss := *o.nameservers.Load()
|
||||
numServers := len(nss)
|
||||
if numServers == 0 {
|
||||
return nil, errors.New("no nameservers available")
|
||||
}
|
||||
@@ -144,12 +147,12 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
dnsClient := &dns.Client{Net: "udp"}
|
||||
ch := make(chan *osResolverResult, numServers)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(o.nameservers))
|
||||
wg.Add(len(nss))
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}()
|
||||
for _, server := range o.nameservers {
|
||||
for _, server := range nss {
|
||||
go func(server string) {
|
||||
defer wg.Done()
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
||||
@@ -238,11 +241,12 @@ func LookupIP(domain string) []string {
|
||||
}
|
||||
|
||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||
resolver := &osResolver{nameservers: nameservers()}
|
||||
nss := nameservers()
|
||||
if withBootstrapDNS {
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...)
|
||||
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
|
||||
resolver := newResolverWithNameserver(nss)
|
||||
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, nss)
|
||||
timeoutMs := 2000
|
||||
if timeout > 0 && timeout < timeoutMs {
|
||||
timeoutMs = timeout
|
||||
@@ -315,12 +319,12 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
|
||||
// - Gateway IP address (depends on OS).
|
||||
// - Input servers.
|
||||
func NewBootstrapResolver(servers ...string) Resolver {
|
||||
resolver := &osResolver{nameservers: nameservers()}
|
||||
resolver.nameservers = append([]string{controldPublicDnsWithPort}, resolver.nameservers...)
|
||||
nss := nameservers()
|
||||
nss = append([]string{controldPublicDnsWithPort}, nss...)
|
||||
for _, ns := range servers {
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...)
|
||||
nss = append([]string{net.JoinHostPort(ns, "53")}, nss...)
|
||||
}
|
||||
return resolver
|
||||
return NewResolverWithNameserver(nss)
|
||||
}
|
||||
|
||||
// NewPrivateResolver returns an OS resolver, which includes only private DNS servers,
|
||||
@@ -357,10 +361,10 @@ func NewPrivateResolver() Resolver {
|
||||
}
|
||||
}
|
||||
nss = nss[:n]
|
||||
return NewResolverWithNameserver(nss)
|
||||
return newResolverWithNameserver(nss)
|
||||
}
|
||||
|
||||
// NewResolverWithNameserver returns an OS resolver which uses the given nameservers
|
||||
// NewResolverWithNameserver returns a Resolver which uses the given nameservers
|
||||
// for resolving DNS queries. If nameservers is empty, a dummy resolver will be returned.
|
||||
//
|
||||
// Each nameserver must be form "host:port". It's the caller responsibility to ensure all
|
||||
@@ -369,7 +373,13 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
|
||||
if len(nameservers) == 0 {
|
||||
return &dummyResolver{}
|
||||
}
|
||||
return &osResolver{nameservers: nameservers}
|
||||
return newResolverWithNameserver(nameservers)
|
||||
}
|
||||
|
||||
func newResolverWithNameserver(nameservers []string) *osResolver {
|
||||
r := &osResolver{}
|
||||
r.nameservers.Store(&nameservers)
|
||||
return r
|
||||
}
|
||||
|
||||
// Rfc1918Addresses returns the list of local interfaces private IP addresses
|
||||
|
||||
@@ -16,7 +16,8 @@ func Test_osResolver_Resolve(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
resolver := &osResolver{nameservers: []string{"127.0.0.127:5353"}}
|
||||
resolver := &osResolver{}
|
||||
resolver.nameservers.Store(&[]string{"127.0.0.127:5353"})
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("controld.com.", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
@@ -69,7 +70,8 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
server.Shutdown()
|
||||
}
|
||||
}()
|
||||
resolver := &osResolver{nameservers: ns}
|
||||
resolver := &osResolver{}
|
||||
resolver.nameservers.Store(&ns)
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
answer, err := resolver.Resolve(context.Background(), msg)
|
||||
@@ -81,6 +83,19 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_osResolver_InitializationRace(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
n := 10
|
||||
wg.Add(n)
|
||||
for range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
InitializeOsResolver()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user