Avoid data race when initializing OS resolver

With the new leaking queries feature, initializing the OS resolver can
now lead to a data race if queries are being resolved while
re-initialization happens.

To fix it, use an atomic pointer to store the list of initialized
nameservers, making concurrent reads and writes of the list safe.
This commit is contained in:
Cuong Manh Le
2024-10-17 18:09:16 +07:00
committed by Cuong Manh Le
parent 30ea0c6499
commit f87220a908
2 changed files with 46 additions and 21 deletions

View File

@@ -9,6 +9,7 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"tailscale.com/net/netmon"
@@ -44,7 +45,7 @@ const (
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
// or is the Resolver used for ResolverTypeOS.
var or = &osResolver{nameservers: defaultNameservers()}
var or = newResolverWithNameserver(defaultNameservers())
// defaultNameservers returns OS nameservers plus ControlD public DNS.
func defaultNameservers() []string {
@@ -58,7 +59,7 @@ func defaultNameservers() []string {
// It's the caller's responsibility to ensure the system DNS is in a clean state before
// calling this function.
func InitializeOsResolver() []string {
or.nameservers = or.nameservers[:0]
var nss []string
// Ignore local addresses to prevent loop.
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
machineIPsMap := make(map[string]struct{}, len(regularIPs))
@@ -70,11 +71,12 @@ func InitializeOsResolver() []string {
continue
}
if testNameserver(ns) {
or.nameservers = append(or.nameservers, ns)
nss = append(nss, ns)
}
}
or.nameservers = append(or.nameservers, controldPublicDnsWithPort)
return or.nameservers
nss = append(nss, controldPublicDnsWithPort)
or.nameservers.Store(&nss)
return nss
}
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
@@ -121,7 +123,7 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
}
type osResolver struct {
nameservers []string
nameservers atomic.Pointer[[]string]
}
type osResolverResult struct {
@@ -134,7 +136,8 @@ type osResolverResult struct {
// Query is sent to all nameservers concurrently, and the first
// success response will be returned.
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
numServers := len(o.nameservers)
nss := *o.nameservers.Load()
numServers := len(nss)
if numServers == 0 {
return nil, errors.New("no nameservers available")
}
@@ -144,12 +147,12 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
dnsClient := &dns.Client{Net: "udp"}
ch := make(chan *osResolverResult, numServers)
var wg sync.WaitGroup
wg.Add(len(o.nameservers))
wg.Add(len(nss))
go func() {
wg.Wait()
close(ch)
}()
for _, server := range o.nameservers {
for _, server := range nss {
go func(server string) {
defer wg.Done()
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
@@ -238,11 +241,12 @@ func LookupIP(domain string) []string {
}
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
resolver := &osResolver{nameservers: nameservers()}
nss := nameservers()
if withBootstrapDNS {
resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...)
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
}
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
resolver := newResolverWithNameserver(nss)
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, nss)
timeoutMs := 2000
if timeout > 0 && timeout < timeoutMs {
timeoutMs = timeout
@@ -315,12 +319,12 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
// - Gateway IP address (depends on OS).
// - Input servers.
func NewBootstrapResolver(servers ...string) Resolver {
resolver := &osResolver{nameservers: nameservers()}
resolver.nameservers = append([]string{controldPublicDnsWithPort}, resolver.nameservers...)
nss := nameservers()
nss = append([]string{controldPublicDnsWithPort}, nss...)
for _, ns := range servers {
resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...)
nss = append([]string{net.JoinHostPort(ns, "53")}, nss...)
}
return resolver
return NewResolverWithNameserver(nss)
}
// NewPrivateResolver returns an OS resolver, which includes only private DNS servers,
@@ -357,10 +361,10 @@ func NewPrivateResolver() Resolver {
}
}
nss = nss[:n]
return NewResolverWithNameserver(nss)
return newResolverWithNameserver(nss)
}
// NewResolverWithNameserver returns an OS resolver which uses the given nameservers
// NewResolverWithNameserver returns a Resolver which uses the given nameservers
// for resolving DNS queries. If nameservers is empty, a dummy resolver will be returned.
//
// Each nameserver must be of the form "host:port". It's the caller's responsibility to ensure all
@@ -369,7 +373,13 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
if len(nameservers) == 0 {
return &dummyResolver{}
}
return &osResolver{nameservers: nameservers}
return newResolverWithNameserver(nameservers)
}
// newResolverWithNameserver builds an *osResolver and publishes the given
// nameserver list through the resolver's atomic pointer.
//
// The list is stored as-is: no emptiness check is performed here, and the
// slice is not copied, so it shares its backing array with the caller's slice.
func newResolverWithNameserver(nameservers []string) *osResolver {
	resolver := new(osResolver)
	resolver.nameservers.Store(&nameservers)
	return resolver
}
// Rfc1918Addresses returns the list of local interfaces private IP addresses

View File

@@ -16,7 +16,8 @@ func Test_osResolver_Resolve(t *testing.T) {
go func() {
defer cancel()
resolver := &osResolver{nameservers: []string{"127.0.0.127:5353"}}
resolver := &osResolver{}
resolver.nameservers.Store(&[]string{"127.0.0.127:5353"})
m := new(dns.Msg)
m.SetQuestion("controld.com.", dns.TypeA)
m.RecursionDesired = true
@@ -69,7 +70,8 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
server.Shutdown()
}
}()
resolver := &osResolver{nameservers: ns}
resolver := &osResolver{}
resolver.nameservers.Store(&ns)
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
answer, err := resolver.Resolve(context.Background(), msg)
@@ -81,6 +83,19 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
}
}
func Test_osResolver_InitializationRace(t *testing.T) {
var wg sync.WaitGroup
n := 10
wg.Add(n)
for range n {
go func() {
defer wg.Done()
InitializeOsResolver()
}()
}
wg.Wait()
}
func Test_upstreamTypeFromEndpoint(t *testing.T) {
tests := []struct {
name string