From ed39269c8061a989550856038343f96726b625af Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 5 Dec 2024 22:37:06 +0700 Subject: [PATCH] Implementing new initializing OS resolver logic Since the nameservers that we got during startup are the good ones that work, saving it for later usage if we could not find available ones. --- resolver.go | 90 +++++++++++++++++++----------------------------- resolver_test.go | 59 ++++++++++++++++--------------- 2 files changed, 67 insertions(+), 82 deletions(-) diff --git a/resolver.go b/resolver.go index f54edfb..fd97e48 100644 --- a/resolver.go +++ b/resolver.go @@ -85,60 +85,41 @@ func availableNameservers() []string { func InitializeOsResolver() []string { return initializeOsResolver(availableNameservers()) } + +// initializeOsResolver performs logic for choosing OS resolver nameserver. +// The logic: +// +// - First available LAN servers are saved and store. +// - Later calls, if no LAN servers available, the saved servers above will be used. func initializeOsResolver(servers []string) []string { var ( - nss []string + lanNss []string publicNss []string ) - var ( - lastLanServer netip.Addr - curLanServer netip.Addr - curLanServerAvailable bool - ) - if p := or.currentLanServer.Load(); p != nil { - curLanServer = *p - or.currentLanServer.Store(nil) - } - if p := or.lastLanServer.Load(); p != nil { - lastLanServer = *p - or.lastLanServer.Store(nil) - } + for _, ns := range servers { addr, err := netip.ParseAddr(ns) if err != nil { continue } server := net.JoinHostPort(ns, "53") - // Always use new public nameserver. - if !isLanAddr(addr) { - publicNss = append(publicNss, server) - nss = append(nss, server) - continue - } - // For LAN server, storing only current and last LAN server if any. - if addr.Compare(curLanServer) == 0 { - curLanServerAvailable = true + if isLanAddr(addr) { + lanNss = append(lanNss, server) } else { - if addr.Compare(lastLanServer) == 0 { - or.lastLanServer.Store(&addr) - } else { - if or.currentLanServer.CompareAndSwap(nil, &addr) { - nss = append(nss, server) - } - } + publicNss = append(publicNss, server) } } - // Store current LAN server as last one only if it's still available. - if curLanServerAvailable && curLanServer.IsValid() { - or.lastLanServer.Store(&curLanServer) - nss = append(nss, net.JoinHostPort(curLanServer.String(), "53")) + if len(lanNss) > 0 { + // Saved first initialized LAN servers. + or.initializedLanServers.CompareAndSwap(nil, &lanNss) } - if len(publicNss) == 0 { - publicNss = append(publicNss, controldPublicDnsWithPort) - nss = append(nss, controldPublicDnsWithPort) + if len(lanNss) == 0 { + or.lanServers.Store(or.initializedLanServers.Load()) + } else { + or.lanServers.Store(&lanNss) } - or.publicServer.Store(&publicNss) - return nss + or.publicServers.Store(&publicNss) + return slices.Concat(lanNss, publicNss) } // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. @@ -185,9 +166,9 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { } type osResolver struct { - currentLanServer atomic.Pointer[netip.Addr] - lastLanServer atomic.Pointer[netip.Addr] - publicServer atomic.Pointer[[]string] + initializedLanServers atomic.Pointer[[]string] + lanServers atomic.Pointer[[]string] + publicServers atomic.Pointer[[]string] } type osResolverResult struct { @@ -201,13 +182,10 @@ type osResolverResult struct { // Query is sent to all nameservers concurrently, and the first // success response will be returned. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - publicServers := *o.publicServer.Load() - nss := make([]string, 0, 2) - if p := o.currentLanServer.Load(); p != nil { - nss = append(nss, net.JoinHostPort(p.String(), "53")) - } - if p := o.lastLanServer.Load(); p != nil { - nss = append(nss, net.JoinHostPort(p.String(), "53")) + publicServers := *o.publicServers.Load() + var nss []string + if p := o.lanServers.Load(); p != nil { + nss = append(nss, (*p)...) } numServers := len(nss) + len(publicServers) if numServers == 0 { @@ -467,17 +445,19 @@ func NewResolverWithNameserver(nameservers []string) Resolver { // The caller must ensure each server in list is formed "ip:53". func newResolverWithNameserver(nameservers []string) *osResolver { r := &osResolver{} - nss := slices.Sorted(slices.Values(nameservers)) - for i, ns := range nss { + var publicNss []string + var lanNss []string + for _, ns := range slices.Sorted(slices.Values(nameservers)) { ip, _, _ := net.SplitHostPort(ns) addr, _ := netip.ParseAddr(ip) if isLanAddr(addr) { - r.currentLanServer.Store(&addr) - nss = slices.Delete(nss, i, i+1) - break + lanNss = append(lanNss, ns) + } else { + publicNss = append(publicNss, ns) } } - r.publicServer.Store(&nss) + r.lanServers.Store(&lanNss) + r.publicServers.Store(&publicNss) return r } diff --git a/resolver_test.go b/resolver_test.go index 7b1a49d..0db05f6 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -20,7 +20,7 @@ func Test_osResolver_Resolve(t *testing.T) { go func() { defer cancel() resolver := &osResolver{} - resolver.publicServer.Store(&[]string{"127.0.0.127:5353"}) + resolver.publicServers.Store(&[]string{"127.0.0.127:5353"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -74,7 +74,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { } }() resolver := &osResolver{} - resolver.publicServer.Store(&ns) + resolver.publicServers.Store(&ns) msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) answer, err := resolver.Resolve(context.Background(), msg) @@ -156,38 +156,43 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H func Test_initializeOsResolver(t *testing.T) { lanServer1 := "192.168.1.1" lanServer2 := "10.0.10.69" + lanServer3 := "192.168.40.1" wanServer := "1.1.1.1" + lanServers := []string{net.JoinHostPort(lanServer1, "53"), net.JoinHostPort(lanServer2, "53")} publicServers := []string{net.JoinHostPort(wanServer, "53")} - // First initialization. + or = newResolverWithNameserver(defaultNameservers()) + + // First initialization, initialized servers are saved. + initializeOsResolver([]string{lanServer1, lanServer2, wanServer}) + p := or.initializedLanServers.Load() + assert.NotNil(t, p) + t.Logf("%v - %v", *p, lanServers) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) + + // No new LAN servers, but lanServer2 gone, initialized servers not changed. initializeOsResolver([]string{lanServer1, wanServer}) - p := or.currentLanServer.Load() + p = or.initializedLanServers.Load() assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer1, "53")})) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) - // No new LAN server, current LAN server -> last LAN server. - initializeOsResolver([]string{lanServer1, wanServer}) - p = or.currentLanServer.Load() - assert.Nil(t, p) - p = or.lastLanServer.Load() + // New LAN servers, they are used, initialized servers not changed. + initializeOsResolver([]string{lanServer3, wanServer}) + p = or.initializedLanServers.Load() assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer3, "53")})) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) - // New LAN server detected. - initializeOsResolver([]string{lanServer2, lanServer1, wanServer}) - p = or.currentLanServer.Load() - assert.NotNil(t, p) - assert.Equal(t, lanServer2, p.String()) - p = or.lastLanServer.Load() - assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) - - // No LAN server available. + // No LAN server available, initialized servers will be used. initializeOsResolver([]string{wanServer}) - assert.Nil(t, or.currentLanServer.Load()) - assert.Nil(t, or.lastLanServer.Load()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + p = or.initializedLanServers.Load() + assert.NotNil(t, p) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) }