Implementing new initializing OS resolver logic

Since the nameservers that we got during startup are the good ones that
work, saving it for later usage if we could not find available ones.
This commit is contained in:
Cuong Manh Le
2024-12-05 22:37:06 +07:00
committed by Cuong Manh Le
parent 09426dcd36
commit ed39269c80
2 changed files with 67 additions and 82 deletions

View File

@@ -85,60 +85,41 @@ func availableNameservers() []string {
func InitializeOsResolver() []string {
return initializeOsResolver(availableNameservers())
}
// initializeOsResolver performs logic for choosing OS resolver nameserver.
// The logic:
//
// - First available LAN servers are saved and store.
// - Later calls, if no LAN servers available, the saved servers above will be used.
func initializeOsResolver(servers []string) []string {
var (
nss []string
lanNss []string
publicNss []string
)
var (
lastLanServer netip.Addr
curLanServer netip.Addr
curLanServerAvailable bool
)
if p := or.currentLanServer.Load(); p != nil {
curLanServer = *p
or.currentLanServer.Store(nil)
}
if p := or.lastLanServer.Load(); p != nil {
lastLanServer = *p
or.lastLanServer.Store(nil)
}
for _, ns := range servers {
addr, err := netip.ParseAddr(ns)
if err != nil {
continue
}
server := net.JoinHostPort(ns, "53")
// Always use new public nameserver.
if !isLanAddr(addr) {
publicNss = append(publicNss, server)
nss = append(nss, server)
continue
}
// For LAN server, storing only current and last LAN server if any.
if addr.Compare(curLanServer) == 0 {
curLanServerAvailable = true
if isLanAddr(addr) {
lanNss = append(lanNss, server)
} else {
if addr.Compare(lastLanServer) == 0 {
or.lastLanServer.Store(&addr)
} else {
if or.currentLanServer.CompareAndSwap(nil, &addr) {
nss = append(nss, server)
}
}
publicNss = append(publicNss, server)
}
}
// Store current LAN server as last one only if it's still available.
if curLanServerAvailable && curLanServer.IsValid() {
or.lastLanServer.Store(&curLanServer)
nss = append(nss, net.JoinHostPort(curLanServer.String(), "53"))
if len(lanNss) > 0 {
// Saved first initialized LAN servers.
or.initializedLanServers.CompareAndSwap(nil, &lanNss)
}
if len(publicNss) == 0 {
publicNss = append(publicNss, controldPublicDnsWithPort)
nss = append(nss, controldPublicDnsWithPort)
if len(lanNss) == 0 {
or.lanServers.Store(or.initializedLanServers.Load())
} else {
or.lanServers.Store(&lanNss)
}
or.publicServer.Store(&publicNss)
return nss
or.publicServers.Store(&publicNss)
return slices.Concat(lanNss, publicNss)
}
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
@@ -185,9 +166,9 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
}
type osResolver struct {
currentLanServer atomic.Pointer[netip.Addr]
lastLanServer atomic.Pointer[netip.Addr]
publicServer atomic.Pointer[[]string]
initializedLanServers atomic.Pointer[[]string]
lanServers atomic.Pointer[[]string]
publicServers atomic.Pointer[[]string]
}
type osResolverResult struct {
@@ -201,13 +182,10 @@ type osResolverResult struct {
// Query is sent to all nameservers concurrently, and the first
// success response will be returned.
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
publicServers := *o.publicServer.Load()
nss := make([]string, 0, 2)
if p := o.currentLanServer.Load(); p != nil {
nss = append(nss, net.JoinHostPort(p.String(), "53"))
}
if p := o.lastLanServer.Load(); p != nil {
nss = append(nss, net.JoinHostPort(p.String(), "53"))
publicServers := *o.publicServers.Load()
var nss []string
if p := o.lanServers.Load(); p != nil {
nss = append(nss, (*p)...)
}
numServers := len(nss) + len(publicServers)
if numServers == 0 {
@@ -467,17 +445,19 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
// The caller must ensure each server in list is formed "ip:53".
func newResolverWithNameserver(nameservers []string) *osResolver {
r := &osResolver{}
nss := slices.Sorted(slices.Values(nameservers))
for i, ns := range nss {
var publicNss []string
var lanNss []string
for _, ns := range slices.Sorted(slices.Values(nameservers)) {
ip, _, _ := net.SplitHostPort(ns)
addr, _ := netip.ParseAddr(ip)
if isLanAddr(addr) {
r.currentLanServer.Store(&addr)
nss = slices.Delete(nss, i, i+1)
break
lanNss = append(lanNss, ns)
} else {
publicNss = append(publicNss, ns)
}
}
r.publicServer.Store(&nss)
r.lanServers.Store(&lanNss)
r.publicServers.Store(&publicNss)
return r
}

View File

@@ -20,7 +20,7 @@ func Test_osResolver_Resolve(t *testing.T) {
go func() {
defer cancel()
resolver := &osResolver{}
resolver.publicServer.Store(&[]string{"127.0.0.127:5353"})
resolver.publicServers.Store(&[]string{"127.0.0.127:5353"})
m := new(dns.Msg)
m.SetQuestion("controld.com.", dns.TypeA)
m.RecursionDesired = true
@@ -74,7 +74,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
}
}()
resolver := &osResolver{}
resolver.publicServer.Store(&ns)
resolver.publicServers.Store(&ns)
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
answer, err := resolver.Resolve(context.Background(), msg)
@@ -156,38 +156,43 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H
func Test_initializeOsResolver(t *testing.T) {
lanServer1 := "192.168.1.1"
lanServer2 := "10.0.10.69"
lanServer3 := "192.168.40.1"
wanServer := "1.1.1.1"
lanServers := []string{net.JoinHostPort(lanServer1, "53"), net.JoinHostPort(lanServer2, "53")}
publicServers := []string{net.JoinHostPort(wanServer, "53")}
// First initialization.
or = newResolverWithNameserver(defaultNameservers())
// First initialization, initialized servers are saved.
initializeOsResolver([]string{lanServer1, lanServer2, wanServer})
p := or.initializedLanServers.Load()
assert.NotNil(t, p)
t.Logf("%v - %v", *p, lanServers)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
// No new LAN servers, but lanServer2 gone, initialized servers not changed.
initializeOsResolver([]string{lanServer1, wanServer})
p := or.currentLanServer.Load()
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer1, p.String())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer1, "53")}))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
// No new LAN server, current LAN server -> last LAN server.
initializeOsResolver([]string{lanServer1, wanServer})
p = or.currentLanServer.Load()
assert.Nil(t, p)
p = or.lastLanServer.Load()
// New LAN servers, they are used, initialized servers not changed.
initializeOsResolver([]string{lanServer3, wanServer})
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer1, p.String())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer3, "53")}))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
// New LAN server detected.
initializeOsResolver([]string{lanServer2, lanServer1, wanServer})
p = or.currentLanServer.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer2, p.String())
p = or.lastLanServer.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer1, p.String())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
// No LAN server available.
// No LAN server available, initialized servers will be used.
initializeOsResolver([]string{wanServer})
assert.Nil(t, or.currentLanServer.Load())
assert.Nil(t, or.lastLanServer.Load())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
p = or.initializedLanServers.Load()
assert.NotNil(t, p)
assert.True(t, slices.Equal(*p, lanServers))
assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
}