From 65de7edcdee8c500c1d8a56b316759e2770a24a8 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 22 Oct 2024 00:47:03 +0700 Subject: [PATCH] Only store last LAN server if available Otherwise, queries may still be forwarded to this un-available LAN server, causing slow query time. --- resolver.go | 38 ++++++++++++++++++++++++++++++-------- resolver_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 8 deletions(-) diff --git a/resolver.go b/resolver.go index b38504c..f54edfb 100644 --- a/resolver.go +++ b/resolver.go @@ -83,31 +83,53 @@ func availableNameservers() []string { // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. func InitializeOsResolver() []string { + return initializeOsResolver(availableNameservers()) +} +func initializeOsResolver(servers []string) []string { var ( nss []string publicNss []string ) - var curLanServer netip.Addr + var ( + lastLanServer netip.Addr + curLanServer netip.Addr + curLanServerAvailable bool + ) if p := or.currentLanServer.Load(); p != nil { curLanServer = *p or.currentLanServer.Store(nil) } - for _, ns := range availableNameservers() { + if p := or.lastLanServer.Load(); p != nil { + lastLanServer = *p + or.lastLanServer.Store(nil) + } + for _, ns := range servers { addr, err := netip.ParseAddr(ns) if err != nil { continue } server := net.JoinHostPort(ns, "53") - if isLanAddr(addr) { - if addr.Compare(curLanServer) != 0 && or.currentLanServer.CompareAndSwap(nil, &addr) { - nss = append(nss, server) - } - } else { + // Always use new public nameserver. + if !isLanAddr(addr) { publicNss = append(publicNss, server) nss = append(nss, server) + continue + } + // For LAN server, storing only current and last LAN server if any. + if addr.Compare(curLanServer) == 0 { + curLanServerAvailable = true + } else { + if addr.Compare(lastLanServer) == 0 { + or.lastLanServer.Store(&addr) + } else { + if or.currentLanServer.CompareAndSwap(nil, &addr) { + nss = append(nss, server) + } + } } } - if curLanServer.IsValid() { + // Store current LAN server as last one only if it's still available. + if curLanServerAvailable && curLanServer.IsValid() { or.lastLanServer.Store(&curLanServer) nss = append(nss, net.JoinHostPort(curLanServer.String(), "53")) } diff --git a/resolver_test.go b/resolver_test.go index 44b170a..7b1a49d 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -3,10 +3,13 @@ package ctrld import ( "context" "net" + "slices" "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/miekg/dns" ) @@ -149,3 +152,42 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H waitLock.Lock() return server, addr, nil } + +func Test_initializeOsResolver(t *testing.T) { + lanServer1 := "192.168.1.1" + lanServer2 := "10.0.10.69" + wanServer := "1.1.1.1" + publicServers := []string{net.JoinHostPort(wanServer, "53")} + + // First initialization. + initializeOsResolver([]string{lanServer1, wanServer}) + p := or.currentLanServer.Load() + assert.NotNil(t, p) + assert.Equal(t, lanServer1, p.String()) + assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + + // No new LAN server, current LAN server -> last LAN server. + initializeOsResolver([]string{lanServer1, wanServer}) + p = or.currentLanServer.Load() + assert.Nil(t, p) + p = or.lastLanServer.Load() + assert.NotNil(t, p) + assert.Equal(t, lanServer1, p.String()) + assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + + // New LAN server detected. + initializeOsResolver([]string{lanServer2, lanServer1, wanServer}) + p = or.currentLanServer.Load() + assert.NotNil(t, p) + assert.Equal(t, lanServer2, p.String()) + p = or.lastLanServer.Load() + assert.NotNil(t, p) + assert.Equal(t, lanServer1, p.String()) + assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + + // No LAN server available. + initializeOsResolver([]string{wanServer}) + assert.Nil(t, or.currentLanServer.Load()) + assert.Nil(t, or.lastLanServer.Load()) + assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) +}