diff --git a/resolver.go b/resolver.go index 226ad1f..1e896c1 100644 --- a/resolver.go +++ b/resolver.go @@ -9,6 +9,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" "tailscale.com/net/netmon" @@ -44,7 +45,7 @@ const ( var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") // or is the Resolver used for ResolverTypeOS. -var or = &osResolver{nameservers: defaultNameservers()} +var or = newResolverWithNameserver(defaultNameservers()) // defaultNameservers returns OS nameservers plus ControlD public DNS. func defaultNameservers() []string { @@ -58,7 +59,7 @@ func defaultNameservers() []string { // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. func InitializeOsResolver() []string { - or.nameservers = or.nameservers[:0] + var nss []string // Ignore local addresses to prevent loop. regularIPs, loopbackIPs, _ := netmon.LocalAddresses() machineIPsMap := make(map[string]struct{}, len(regularIPs)) @@ -70,11 +71,12 @@ func InitializeOsResolver() []string { continue } if testNameserver(ns) { - or.nameservers = append(or.nameservers, ns) + nss = append(nss, ns) } } - or.nameservers = append(or.nameservers, controldPublicDnsWithPort) - return or.nameservers + nss = append(nss, controldPublicDnsWithPort) + or.nameservers.Store(&nss) + return nss } // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. @@ -121,7 +123,7 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { } type osResolver struct { - nameservers []string + nameservers atomic.Pointer[[]string] } type osResolverResult struct { @@ -134,7 +136,8 @@ type osResolverResult struct { // Query is sent to all nameservers concurrently, and the first // success response will be returned. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - numServers := len(o.nameservers) + nss := *o.nameservers.Load() + numServers := len(nss) if numServers == 0 { return nil, errors.New("no nameservers available") } @@ -144,12 +147,12 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error dnsClient := &dns.Client{Net: "udp"} ch := make(chan *osResolverResult, numServers) var wg sync.WaitGroup - wg.Add(len(o.nameservers)) + wg.Add(len(nss)) go func() { wg.Wait() close(ch) }() - for _, server := range o.nameservers { + for _, server := range nss { go func(server string) { defer wg.Done() answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) @@ -238,11 +241,12 @@ func LookupIP(domain string) []string { } func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { - resolver := &osResolver{nameservers: nameservers()} + nss := nameservers() if withBootstrapDNS { - resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...) + nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...) } - ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers) + resolver := newResolverWithNameserver(nss) + ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, nss) timeoutMs := 2000 if timeout > 0 && timeout < timeoutMs { timeoutMs = timeout @@ -315,12 +319,12 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) // - Gateway IP address (depends on OS). // - Input servers. func NewBootstrapResolver(servers ...string) Resolver { - resolver := &osResolver{nameservers: nameservers()} - resolver.nameservers = append([]string{controldPublicDnsWithPort}, resolver.nameservers...) + nss := nameservers() + nss = append([]string{controldPublicDnsWithPort}, nss...) for _, ns := range servers { - resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...) + nss = append([]string{net.JoinHostPort(ns, "53")}, nss...) } - return resolver + return NewResolverWithNameserver(nss) } // NewPrivateResolver returns an OS resolver, which includes only private DNS servers, @@ -357,10 +361,10 @@ func NewPrivateResolver() Resolver { } } nss = nss[:n] - return NewResolverWithNameserver(nss) + return newResolverWithNameserver(nss) } -// NewResolverWithNameserver returns an OS resolver which uses the given nameservers +// NewResolverWithNameserver returns a Resolver which uses the given nameservers // for resolving DNS queries. If nameservers is empty, a dummy resolver will be returned. // // Each nameserver must be form "host:port". It's the caller responsibility to ensure all @@ -369,7 +373,13 @@ func NewResolverWithNameserver(nameservers []string) Resolver { if len(nameservers) == 0 { return &dummyResolver{} } - return &osResolver{nameservers: nameservers} + return newResolverWithNameserver(nameservers) +} + +func newResolverWithNameserver(nameservers []string) *osResolver { + r := &osResolver{} + r.nameservers.Store(&nameservers) + return r } // Rfc1918Addresses returns the list of local interfaces private IP addresses diff --git a/resolver_test.go b/resolver_test.go index 23c27ae..9d1cb34 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -16,7 +16,8 @@ func Test_osResolver_Resolve(t *testing.T) { go func() { defer cancel() - resolver := &osResolver{nameservers: []string{"127.0.0.127:5353"}} + resolver := &osResolver{} + resolver.nameservers.Store(&[]string{"127.0.0.127:5353"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -69,7 +70,8 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { server.Shutdown() } }() - resolver := &osResolver{nameservers: ns} + resolver := &osResolver{} + resolver.nameservers.Store(&ns) msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) answer, err := resolver.Resolve(context.Background(), msg) @@ -81,6 +83,19 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { } } +func Test_osResolver_InitializationRace(t *testing.T) { + var wg sync.WaitGroup + n := 10 + wg.Add(n) + for range n { + go func() { + defer wg.Done() + InitializeOsResolver() + }() + } + wg.Wait() +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string