mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Implementing new initializing OS resolver logic
Since the nameservers that we got during startup are the good ones that work, saving it for later usage if we could not find available ones.
This commit is contained in:
committed by
Cuong Manh Le
parent
09426dcd36
commit
ed39269c80
90
resolver.go
90
resolver.go
@@ -85,60 +85,41 @@ func availableNameservers() []string {
|
||||
func InitializeOsResolver() []string {
|
||||
return initializeOsResolver(availableNameservers())
|
||||
}
|
||||
|
||||
// initializeOsResolver performs logic for choosing OS resolver nameserver.
|
||||
// The logic:
|
||||
//
|
||||
// - First available LAN servers are saved and store.
|
||||
// - Later calls, if no LAN servers available, the saved servers above will be used.
|
||||
func initializeOsResolver(servers []string) []string {
|
||||
var (
|
||||
nss []string
|
||||
lanNss []string
|
||||
publicNss []string
|
||||
)
|
||||
var (
|
||||
lastLanServer netip.Addr
|
||||
curLanServer netip.Addr
|
||||
curLanServerAvailable bool
|
||||
)
|
||||
if p := or.currentLanServer.Load(); p != nil {
|
||||
curLanServer = *p
|
||||
or.currentLanServer.Store(nil)
|
||||
}
|
||||
if p := or.lastLanServer.Load(); p != nil {
|
||||
lastLanServer = *p
|
||||
or.lastLanServer.Store(nil)
|
||||
}
|
||||
|
||||
for _, ns := range servers {
|
||||
addr, err := netip.ParseAddr(ns)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
server := net.JoinHostPort(ns, "53")
|
||||
// Always use new public nameserver.
|
||||
if !isLanAddr(addr) {
|
||||
publicNss = append(publicNss, server)
|
||||
nss = append(nss, server)
|
||||
continue
|
||||
}
|
||||
// For LAN server, storing only current and last LAN server if any.
|
||||
if addr.Compare(curLanServer) == 0 {
|
||||
curLanServerAvailable = true
|
||||
if isLanAddr(addr) {
|
||||
lanNss = append(lanNss, server)
|
||||
} else {
|
||||
if addr.Compare(lastLanServer) == 0 {
|
||||
or.lastLanServer.Store(&addr)
|
||||
} else {
|
||||
if or.currentLanServer.CompareAndSwap(nil, &addr) {
|
||||
nss = append(nss, server)
|
||||
}
|
||||
}
|
||||
publicNss = append(publicNss, server)
|
||||
}
|
||||
}
|
||||
// Store current LAN server as last one only if it's still available.
|
||||
if curLanServerAvailable && curLanServer.IsValid() {
|
||||
or.lastLanServer.Store(&curLanServer)
|
||||
nss = append(nss, net.JoinHostPort(curLanServer.String(), "53"))
|
||||
if len(lanNss) > 0 {
|
||||
// Saved first initialized LAN servers.
|
||||
or.initializedLanServers.CompareAndSwap(nil, &lanNss)
|
||||
}
|
||||
if len(publicNss) == 0 {
|
||||
publicNss = append(publicNss, controldPublicDnsWithPort)
|
||||
nss = append(nss, controldPublicDnsWithPort)
|
||||
if len(lanNss) == 0 {
|
||||
or.lanServers.Store(or.initializedLanServers.Load())
|
||||
} else {
|
||||
or.lanServers.Store(&lanNss)
|
||||
}
|
||||
or.publicServer.Store(&publicNss)
|
||||
return nss
|
||||
or.publicServers.Store(&publicNss)
|
||||
return slices.Concat(lanNss, publicNss)
|
||||
}
|
||||
|
||||
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
|
||||
@@ -185,9 +166,9 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
}
|
||||
|
||||
type osResolver struct {
|
||||
currentLanServer atomic.Pointer[netip.Addr]
|
||||
lastLanServer atomic.Pointer[netip.Addr]
|
||||
publicServer atomic.Pointer[[]string]
|
||||
initializedLanServers atomic.Pointer[[]string]
|
||||
lanServers atomic.Pointer[[]string]
|
||||
publicServers atomic.Pointer[[]string]
|
||||
}
|
||||
|
||||
type osResolverResult struct {
|
||||
@@ -201,13 +182,10 @@ type osResolverResult struct {
|
||||
// Query is sent to all nameservers concurrently, and the first
|
||||
// success response will be returned.
|
||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
publicServers := *o.publicServer.Load()
|
||||
nss := make([]string, 0, 2)
|
||||
if p := o.currentLanServer.Load(); p != nil {
|
||||
nss = append(nss, net.JoinHostPort(p.String(), "53"))
|
||||
}
|
||||
if p := o.lastLanServer.Load(); p != nil {
|
||||
nss = append(nss, net.JoinHostPort(p.String(), "53"))
|
||||
publicServers := *o.publicServers.Load()
|
||||
var nss []string
|
||||
if p := o.lanServers.Load(); p != nil {
|
||||
nss = append(nss, (*p)...)
|
||||
}
|
||||
numServers := len(nss) + len(publicServers)
|
||||
if numServers == 0 {
|
||||
@@ -467,17 +445,19 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
|
||||
// The caller must ensure each server in list is formed "ip:53".
|
||||
func newResolverWithNameserver(nameservers []string) *osResolver {
|
||||
r := &osResolver{}
|
||||
nss := slices.Sorted(slices.Values(nameservers))
|
||||
for i, ns := range nss {
|
||||
var publicNss []string
|
||||
var lanNss []string
|
||||
for _, ns := range slices.Sorted(slices.Values(nameservers)) {
|
||||
ip, _, _ := net.SplitHostPort(ns)
|
||||
addr, _ := netip.ParseAddr(ip)
|
||||
if isLanAddr(addr) {
|
||||
r.currentLanServer.Store(&addr)
|
||||
nss = slices.Delete(nss, i, i+1)
|
||||
break
|
||||
lanNss = append(lanNss, ns)
|
||||
} else {
|
||||
publicNss = append(publicNss, ns)
|
||||
}
|
||||
}
|
||||
r.publicServer.Store(&nss)
|
||||
r.lanServers.Store(&lanNss)
|
||||
r.publicServers.Store(&publicNss)
|
||||
return r
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ func Test_osResolver_Resolve(t *testing.T) {
|
||||
go func() {
|
||||
defer cancel()
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServer.Store(&[]string{"127.0.0.127:5353"})
|
||||
resolver.publicServers.Store(&[]string{"127.0.0.127:5353"})
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("controld.com.", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
@@ -74,7 +74,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServer.Store(&ns)
|
||||
resolver.publicServers.Store(&ns)
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
answer, err := resolver.Resolve(context.Background(), msg)
|
||||
@@ -156,38 +156,43 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H
|
||||
func Test_initializeOsResolver(t *testing.T) {
|
||||
lanServer1 := "192.168.1.1"
|
||||
lanServer2 := "10.0.10.69"
|
||||
lanServer3 := "192.168.40.1"
|
||||
wanServer := "1.1.1.1"
|
||||
lanServers := []string{net.JoinHostPort(lanServer1, "53"), net.JoinHostPort(lanServer2, "53")}
|
||||
publicServers := []string{net.JoinHostPort(wanServer, "53")}
|
||||
|
||||
// First initialization.
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
|
||||
// First initialization, initialized servers are saved.
|
||||
initializeOsResolver([]string{lanServer1, lanServer2, wanServer})
|
||||
p := or.initializedLanServers.Load()
|
||||
assert.NotNil(t, p)
|
||||
t.Logf("%v - %v", *p, lanServers)
|
||||
assert.True(t, slices.Equal(*p, lanServers))
|
||||
assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
|
||||
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
|
||||
|
||||
// No new LAN servers, but lanServer2 gone, initialized servers not changed.
|
||||
initializeOsResolver([]string{lanServer1, wanServer})
|
||||
p := or.currentLanServer.Load()
|
||||
p = or.initializedLanServers.Load()
|
||||
assert.NotNil(t, p)
|
||||
assert.Equal(t, lanServer1, p.String())
|
||||
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
|
||||
assert.True(t, slices.Equal(*p, lanServers))
|
||||
assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer1, "53")}))
|
||||
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
|
||||
|
||||
// No new LAN server, current LAN server -> last LAN server.
|
||||
initializeOsResolver([]string{lanServer1, wanServer})
|
||||
p = or.currentLanServer.Load()
|
||||
assert.Nil(t, p)
|
||||
p = or.lastLanServer.Load()
|
||||
// New LAN servers, they are used, initialized servers not changed.
|
||||
initializeOsResolver([]string{lanServer3, wanServer})
|
||||
p = or.initializedLanServers.Load()
|
||||
assert.NotNil(t, p)
|
||||
assert.Equal(t, lanServer1, p.String())
|
||||
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
|
||||
assert.True(t, slices.Equal(*p, lanServers))
|
||||
assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer3, "53")}))
|
||||
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
|
||||
|
||||
// New LAN server detected.
|
||||
initializeOsResolver([]string{lanServer2, lanServer1, wanServer})
|
||||
p = or.currentLanServer.Load()
|
||||
assert.NotNil(t, p)
|
||||
assert.Equal(t, lanServer2, p.String())
|
||||
p = or.lastLanServer.Load()
|
||||
assert.NotNil(t, p)
|
||||
assert.Equal(t, lanServer1, p.String())
|
||||
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
|
||||
|
||||
// No LAN server available.
|
||||
// No LAN server available, initialized servers will be used.
|
||||
initializeOsResolver([]string{wanServer})
|
||||
assert.Nil(t, or.currentLanServer.Load())
|
||||
assert.Nil(t, or.lastLanServer.Load())
|
||||
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
|
||||
p = or.initializedLanServers.Load()
|
||||
assert.NotNil(t, p)
|
||||
assert.True(t, slices.Equal(*p, lanServers))
|
||||
assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers))
|
||||
assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user