Only store last LAN server if available

Otherwise, queries may still be forwarded to this un-available LAN
server, causing slow query time.
This commit is contained in:
Cuong Manh Le
2024-10-22 00:47:03 +07:00
committed by Cuong Manh Le
parent 0cdff0d368
commit 65de7edcde
2 changed files with 72 additions and 8 deletions

View File

@@ -83,31 +83,53 @@ func availableNameservers() []string {
// It's the caller's responsibility to ensure the system DNS is in a clean state before
// calling this function.
func InitializeOsResolver() []string {
return initializeOsResolver(availableNameservers())
}
func initializeOsResolver(servers []string) []string {
var (
nss []string
publicNss []string
)
var curLanServer netip.Addr
var (
lastLanServer netip.Addr
curLanServer netip.Addr
curLanServerAvailable bool
)
if p := or.currentLanServer.Load(); p != nil {
curLanServer = *p
or.currentLanServer.Store(nil)
}
for _, ns := range availableNameservers() {
if p := or.lastLanServer.Load(); p != nil {
lastLanServer = *p
or.lastLanServer.Store(nil)
}
for _, ns := range servers {
addr, err := netip.ParseAddr(ns)
if err != nil {
continue
}
server := net.JoinHostPort(ns, "53")
if isLanAddr(addr) {
if addr.Compare(curLanServer) != 0 && or.currentLanServer.CompareAndSwap(nil, &addr) {
nss = append(nss, server)
}
} else {
// Always use new public nameserver.
if !isLanAddr(addr) {
publicNss = append(publicNss, server)
nss = append(nss, server)
continue
}
// For LAN server, storing only current and last LAN server if any.
if addr.Compare(curLanServer) == 0 {
curLanServerAvailable = true
} else {
if addr.Compare(lastLanServer) == 0 {
or.lastLanServer.Store(&addr)
} else {
if or.currentLanServer.CompareAndSwap(nil, &addr) {
nss = append(nss, server)
}
}
}
}
if curLanServer.IsValid() {
// Store current LAN server as last one only if it's still available.
if curLanServerAvailable && curLanServer.IsValid() {
or.lastLanServer.Store(&curLanServer)
nss = append(nss, net.JoinHostPort(curLanServer.String(), "53"))
}

View File

@@ -3,10 +3,13 @@ package ctrld
import (
"context"
"net"
"slices"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/miekg/dns"
)
@@ -149,3 +152,42 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H
waitLock.Lock()
return server, addr, nil
}
func Test_initializeOsResolver(t *testing.T) {
lanServer1 := "192.168.1.1"
lanServer2 := "10.0.10.69"
wanServer := "1.1.1.1"
publicServers := []string{net.JoinHostPort(wanServer, "53")}
// First initialization.
initializeOsResolver([]string{lanServer1, wanServer})
p := or.currentLanServer.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer1, p.String())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
// No new LAN server, current LAN server -> last LAN server.
initializeOsResolver([]string{lanServer1, wanServer})
p = or.currentLanServer.Load()
assert.Nil(t, p)
p = or.lastLanServer.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer1, p.String())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
// New LAN server detected.
initializeOsResolver([]string{lanServer2, lanServer1, wanServer})
p = or.currentLanServer.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer2, p.String())
p = or.lastLanServer.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer1, p.String())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
// No LAN server available.
initializeOsResolver([]string{wanServer})
assert.Nil(t, or.currentLanServer.Load())
assert.Nil(t, or.lastLanServer.Load())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
}