diff --git a/resolver.go b/resolver.go index f54edfb..fd97e48 100644 --- a/resolver.go +++ b/resolver.go @@ -85,60 +85,41 @@ func availableNameservers() []string { func InitializeOsResolver() []string { return initializeOsResolver(availableNameservers()) } + +// initializeOsResolver performs logic for choosing OS resolver nameserver. +// The logic: +// +// - First available LAN servers are saved and store. +// - Later calls, if no LAN servers available, the saved servers above will be used. func initializeOsResolver(servers []string) []string { var ( - nss []string + lanNss []string publicNss []string ) - var ( - lastLanServer netip.Addr - curLanServer netip.Addr - curLanServerAvailable bool - ) - if p := or.currentLanServer.Load(); p != nil { - curLanServer = *p - or.currentLanServer.Store(nil) - } - if p := or.lastLanServer.Load(); p != nil { - lastLanServer = *p - or.lastLanServer.Store(nil) - } + for _, ns := range servers { addr, err := netip.ParseAddr(ns) if err != nil { continue } server := net.JoinHostPort(ns, "53") - // Always use new public nameserver. - if !isLanAddr(addr) { - publicNss = append(publicNss, server) - nss = append(nss, server) - continue - } - // For LAN server, storing only current and last LAN server if any. - if addr.Compare(curLanServer) == 0 { - curLanServerAvailable = true + if isLanAddr(addr) { + lanNss = append(lanNss, server) } else { - if addr.Compare(lastLanServer) == 0 { - or.lastLanServer.Store(&addr) - } else { - if or.currentLanServer.CompareAndSwap(nil, &addr) { - nss = append(nss, server) - } - } + publicNss = append(publicNss, server) } } - // Store current LAN server as last one only if it's still available. - if curLanServerAvailable && curLanServer.IsValid() { - or.lastLanServer.Store(&curLanServer) - nss = append(nss, net.JoinHostPort(curLanServer.String(), "53")) + if len(lanNss) > 0 { + // Saved first initialized LAN servers. + or.initializedLanServers.CompareAndSwap(nil, &lanNss) } - if len(publicNss) == 0 { - publicNss = append(publicNss, controldPublicDnsWithPort) - nss = append(nss, controldPublicDnsWithPort) + if len(lanNss) == 0 { + or.lanServers.Store(or.initializedLanServers.Load()) + } else { + or.lanServers.Store(&lanNss) } - or.publicServer.Store(&publicNss) - return nss + or.publicServers.Store(&publicNss) + return slices.Concat(lanNss, publicNss) } // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. @@ -185,9 +166,9 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { } type osResolver struct { - currentLanServer atomic.Pointer[netip.Addr] - lastLanServer atomic.Pointer[netip.Addr] - publicServer atomic.Pointer[[]string] + initializedLanServers atomic.Pointer[[]string] + lanServers atomic.Pointer[[]string] + publicServers atomic.Pointer[[]string] } type osResolverResult struct { @@ -201,13 +182,10 @@ type osResolverResult struct { // Query is sent to all nameservers concurrently, and the first // success response will be returned. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - publicServers := *o.publicServer.Load() - nss := make([]string, 0, 2) - if p := o.currentLanServer.Load(); p != nil { - nss = append(nss, net.JoinHostPort(p.String(), "53")) - } - if p := o.lastLanServer.Load(); p != nil { - nss = append(nss, net.JoinHostPort(p.String(), "53")) + publicServers := *o.publicServers.Load() + var nss []string + if p := o.lanServers.Load(); p != nil { + nss = append(nss, (*p)...) } numServers := len(nss) + len(publicServers) if numServers == 0 { @@ -467,17 +445,19 @@ func NewResolverWithNameserver(nameservers []string) Resolver { // The caller must ensure each server in list is formed "ip:53". func newResolverWithNameserver(nameservers []string) *osResolver { r := &osResolver{} - nss := slices.Sorted(slices.Values(nameservers)) - for i, ns := range nss { + var publicNss []string + var lanNss []string + for _, ns := range slices.Sorted(slices.Values(nameservers)) { ip, _, _ := net.SplitHostPort(ns) addr, _ := netip.ParseAddr(ip) if isLanAddr(addr) { - r.currentLanServer.Store(&addr) - nss = slices.Delete(nss, i, i+1) - break + lanNss = append(lanNss, ns) + } else { + publicNss = append(publicNss, ns) } } - r.publicServer.Store(&nss) + r.lanServers.Store(&lanNss) + r.publicServers.Store(&publicNss) return r } diff --git a/resolver_test.go b/resolver_test.go index 7b1a49d..0db05f6 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -20,7 +20,7 @@ func Test_osResolver_Resolve(t *testing.T) { go func() { defer cancel() resolver := &osResolver{} - resolver.publicServer.Store(&[]string{"127.0.0.127:5353"}) + resolver.publicServers.Store(&[]string{"127.0.0.127:5353"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -74,7 +74,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { } }() resolver := &osResolver{} - resolver.publicServer.Store(&ns) + resolver.publicServers.Store(&ns) msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) answer, err := resolver.Resolve(context.Background(), msg) @@ -156,38 +156,43 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H func Test_initializeOsResolver(t *testing.T) { lanServer1 := "192.168.1.1" lanServer2 := "10.0.10.69" + lanServer3 := "192.168.40.1" wanServer := "1.1.1.1" + lanServers := []string{net.JoinHostPort(lanServer1, "53"), net.JoinHostPort(lanServer2, "53")} publicServers := []string{net.JoinHostPort(wanServer, "53")} - // First initialization. + or = newResolverWithNameserver(defaultNameservers()) + + // First initialization, initialized servers are saved. + initializeOsResolver([]string{lanServer1, lanServer2, wanServer}) + p := or.initializedLanServers.Load() + assert.NotNil(t, p) + t.Logf("%v - %v", *p, lanServers) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) + + // No new LAN servers, but lanServer2 gone, initialized servers not changed. initializeOsResolver([]string{lanServer1, wanServer}) - p := or.currentLanServer.Load() + p = or.initializedLanServers.Load() assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer1, "53")})) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) - // No new LAN server, current LAN server -> last LAN server. - initializeOsResolver([]string{lanServer1, wanServer}) - p = or.currentLanServer.Load() - assert.Nil(t, p) - p = or.lastLanServer.Load() + // New LAN servers, they are used, initialized servers not changed. + initializeOsResolver([]string{lanServer3, wanServer}) + p = or.initializedLanServers.Load() assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer3, "53")})) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) - // New LAN server detected. - initializeOsResolver([]string{lanServer2, lanServer1, wanServer}) - p = or.currentLanServer.Load() - assert.NotNil(t, p) - assert.Equal(t, lanServer2, p.String()) - p = or.lastLanServer.Load() - assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) - - // No LAN server available. + // No LAN server available, initialized servers will be used. initializeOsResolver([]string{wanServer}) - assert.Nil(t, or.currentLanServer.Load()) - assert.Nil(t, or.lastLanServer.Load()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + p = or.initializedLanServers.Load() + assert.NotNil(t, p) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) }