Prefer LAN server answer over public one

While at it, also implementing new OS resolver chosing logic, keeping
only 2 LAN servers at any time, 1 for current one, and 1 for last used
one.
This commit is contained in:
Cuong Manh Le
2024-10-18 01:31:40 +07:00
committed by Cuong Manh Le
parent f87220a908
commit 0cdff0d368
4 changed files with 124 additions and 40 deletions
+1 -1
View File
@@ -17,7 +17,7 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
uc.Init() uc.Init()
uc.setupBootstrapIP(false) uc.setupBootstrapIP(false)
if len(uc.bootstrapIPs) == 0 { if len(uc.bootstrapIPs) == 0 {
t.Log(nameservers()) t.Log(defaultNameservers())
t.Fatal("could not bootstrap ip without bootstrap DNS") t.Fatal("could not bootstrap ip without bootstrap DNS")
} }
t.Log(uc) t.Log(uc)
+2 -3
View File
@@ -1,9 +1,8 @@
package ctrld package ctrld
import "net"
type dnsFn func() []string type dnsFn func() []string
// nameservers returns DNS nameservers from system settings.
func nameservers() []string { func nameservers() []string {
var dns []string var dns []string
seen := make(map[string]bool) seen := make(map[string]bool)
@@ -21,7 +20,7 @@ func nameservers() []string {
continue continue
} }
seen[ns] = true seen[ns] = true
dns = append(dns, net.JoinHostPort(ns, "53")) dns = append(dns, ns)
} }
} }
+119 -34
View File
@@ -12,9 +12,9 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"tailscale.com/net/netmon"
"github.com/miekg/dns" "github.com/miekg/dns"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
) )
const ( const (
@@ -47,10 +47,34 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
// or is the Resolver used for ResolverTypeOS. // or is the Resolver used for ResolverTypeOS.
var or = newResolverWithNameserver(defaultNameservers()) var or = newResolverWithNameserver(defaultNameservers())
// defaultNameservers returns OS nameservers plus ControlD public DNS. // defaultNameservers is like nameservers with each element formed "ip:53".
func defaultNameservers() []string { func defaultNameservers() []string {
ns := nameservers() ns := nameservers()
return ns nss := make([]string, len(ns))
for i := range ns {
nss[i] = net.JoinHostPort(ns[i], "53")
}
return nss
}
// availableNameservers returns list of current available DNS servers of the system.
func availableNameservers() []string {
var nss []string
// Ignore local addresses to prevent loop.
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
machineIPsMap := make(map[string]struct{}, len(regularIPs))
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
machineIPsMap[v.String()] = struct{}{}
}
for _, ns := range nameservers() {
if _, ok := machineIPsMap[ns]; ok {
continue
}
if testNameserver(ns) {
nss = append(nss, ns)
}
}
return nss
} }
// InitializeOsResolver initializes OS resolver using the current system DNS settings. // InitializeOsResolver initializes OS resolver using the current system DNS settings.
@@ -59,23 +83,39 @@ func defaultNameservers() []string {
// It's the caller's responsibility to ensure the system DNS is in a clean state before // It's the caller's responsibility to ensure the system DNS is in a clean state before
// calling this function. // calling this function.
func InitializeOsResolver() []string { func InitializeOsResolver() []string {
var nss []string var (
// Ignore local addresses to prevent loop. nss []string
regularIPs, loopbackIPs, _ := netmon.LocalAddresses() publicNss []string
machineIPsMap := make(map[string]struct{}, len(regularIPs)) )
for _, v := range slices.Concat(regularIPs, loopbackIPs) { var curLanServer netip.Addr
machineIPsMap[net.JoinHostPort(v.String(), "53")] = struct{}{} if p := or.currentLanServer.Load(); p != nil {
curLanServer = *p
or.currentLanServer.Store(nil)
} }
for _, ns := range defaultNameservers() { for _, ns := range availableNameservers() {
if _, ok := machineIPsMap[ns]; ok { addr, err := netip.ParseAddr(ns)
if err != nil {
continue continue
} }
if testNameserver(ns) { server := net.JoinHostPort(ns, "53")
nss = append(nss, ns) if isLanAddr(addr) {
if addr.Compare(curLanServer) != 0 && or.currentLanServer.CompareAndSwap(nil, &addr) {
nss = append(nss, server)
}
} else {
publicNss = append(publicNss, server)
nss = append(nss, server)
} }
} }
nss = append(nss, controldPublicDnsWithPort) if curLanServer.IsValid() {
or.nameservers.Store(&nss) or.lastLanServer.Store(&curLanServer)
nss = append(nss, net.JoinHostPort(curLanServer.String(), "53"))
}
if len(publicNss) == 0 {
publicNss = append(publicNss, controldPublicDnsWithPort)
nss = append(nss, controldPublicDnsWithPort)
}
or.publicServer.Store(&publicNss)
return nss return nss
} }
@@ -86,7 +126,7 @@ func testNameserver(addr string) bool {
client := new(dns.Client) client := new(dns.Client)
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
_, _, err := client.ExchangeContext(ctx, msg, addr) _, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(addr, "53"))
if err != nil { if err != nil {
ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr) ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr)
} }
@@ -123,21 +163,31 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
} }
type osResolver struct { type osResolver struct {
nameservers atomic.Pointer[[]string] currentLanServer atomic.Pointer[netip.Addr]
lastLanServer atomic.Pointer[netip.Addr]
publicServer atomic.Pointer[[]string]
} }
type osResolverResult struct { type osResolverResult struct {
answer *dns.Msg answer *dns.Msg
err error err error
server string server string
lan bool
} }
// Resolve resolves DNS queries using pre-configured nameservers. // Resolve resolves DNS queries using pre-configured nameservers.
// Query is sent to all nameservers concurrently, and the first // Query is sent to all nameservers concurrently, and the first
// success response will be returned. // success response will be returned.
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
nss := *o.nameservers.Load() publicServers := *o.publicServer.Load()
numServers := len(nss) nss := make([]string, 0, 2)
if p := o.currentLanServer.Load(); p != nil {
nss = append(nss, net.JoinHostPort(p.String(), "53"))
}
if p := o.lastLanServer.Load(); p != nil {
nss = append(nss, net.JoinHostPort(p.String(), "53"))
}
numServers := len(nss) + len(publicServers)
if numServers == 0 { if numServers == 0 {
return nil, errors.New("no nameservers available") return nil, errors.New("no nameservers available")
} }
@@ -146,19 +196,24 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
dnsClient := &dns.Client{Net: "udp"} dnsClient := &dns.Client{Net: "udp"}
ch := make(chan *osResolverResult, numServers) ch := make(chan *osResolverResult, numServers)
var wg sync.WaitGroup wg := &sync.WaitGroup{}
wg.Add(len(nss)) wg.Add(numServers)
go func() { go func() {
wg.Wait() wg.Wait()
close(ch) close(ch)
}() }()
for _, server := range nss {
go func(server string) { do := func(servers []string, isLan bool) {
defer wg.Done() for _, server := range servers {
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) go func(server string) {
ch <- &osResolverResult{answer: answer, err: err, server: server} defer wg.Done()
}(server) answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan}
}(server)
}
} }
do(nss, true)
do(publicServers, false)
logAnswer := func(server string) { logAnswer := func(server string) {
if before, _, found := strings.Cut(server, ":"); found { if before, _, found := strings.Cut(server, ":"); found {
@@ -170,14 +225,20 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
nonSuccessAnswer *dns.Msg nonSuccessAnswer *dns.Msg
nonSuccessServer string nonSuccessServer string
controldSuccessAnswer *dns.Msg controldSuccessAnswer *dns.Msg
publicServerAnswer *dns.Msg
publicServer string
) )
errs := make([]error, 0, numServers) errs := make([]error, 0, numServers)
for res := range ch { for res := range ch {
switch { switch {
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
if res.server == controldPublicDnsWithPort { switch {
case res.server == controldPublicDnsWithPort:
controldSuccessAnswer = res.answer // only use ControlD answer as last one. controldSuccessAnswer = res.answer // only use ControlD answer as last one.
} else { case !res.lan && publicServerAnswer == nil:
publicServerAnswer = res.answer // use public DNS answer after LAN server..
publicServer = res.server
default:
cancel() cancel()
logAnswer(res.server) logAnswer(res.server)
return res.answer, nil return res.answer, nil
@@ -188,6 +249,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
} }
errs = append(errs, res.err) errs = append(errs, res.err)
} }
if publicServerAnswer != nil {
logAnswer(publicServer)
return publicServerAnswer, nil
}
if controldSuccessAnswer != nil { if controldSuccessAnswer != nil {
logAnswer(controldPublicDnsWithPort) logAnswer(controldPublicDnsWithPort)
return controldSuccessAnswer, nil return controldSuccessAnswer, nil
@@ -241,7 +306,7 @@ func LookupIP(domain string) []string {
} }
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
nss := nameservers() nss := defaultNameservers()
if withBootstrapDNS { if withBootstrapDNS {
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...) nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
} }
@@ -319,7 +384,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
// - Gateway IP address (depends on OS). // - Gateway IP address (depends on OS).
// - Input servers. // - Input servers.
func NewBootstrapResolver(servers ...string) Resolver { func NewBootstrapResolver(servers ...string) Resolver {
nss := nameservers() nss := defaultNameservers()
nss = append([]string{controldPublicDnsWithPort}, nss...) nss = append([]string{controldPublicDnsWithPort}, nss...)
for _, ns := range servers { for _, ns := range servers {
nss = append([]string{net.JoinHostPort(ns, "53")}, nss...) nss = append([]string{net.JoinHostPort(ns, "53")}, nss...)
@@ -335,7 +400,7 @@ func NewBootstrapResolver(servers ...string) Resolver {
// //
// This is useful for doing PTR lookup in LAN network. // This is useful for doing PTR lookup in LAN network.
func NewPrivateResolver() Resolver { func NewPrivateResolver() Resolver {
nss := nameservers() nss := defaultNameservers()
resolveConfNss := nameserversFromResolvconf() resolveConfNss := nameserversFromResolvconf()
localRfc1918Addrs := Rfc1918Addresses() localRfc1918Addrs := Rfc1918Addresses()
n := 0 n := 0
@@ -376,9 +441,21 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
return newResolverWithNameserver(nameservers) return newResolverWithNameserver(nameservers)
} }
// newResolverWithNameserver returns an OS resolver from given nameservers list.
// The caller must ensure each server in list is formed "ip:53".
func newResolverWithNameserver(nameservers []string) *osResolver { func newResolverWithNameserver(nameservers []string) *osResolver {
r := &osResolver{} r := &osResolver{}
r.nameservers.Store(&nameservers) nss := slices.Sorted(slices.Values(nameservers))
for i, ns := range nss {
ip, _, _ := net.SplitHostPort(ns)
addr, _ := netip.ParseAddr(ip)
if isLanAddr(addr) {
r.currentLanServer.Store(&addr)
nss = slices.Delete(nss, i, i+1)
break
}
}
r.publicServer.Store(&nss)
return r return r
} }
@@ -409,3 +486,11 @@ func newDialer(dnsAddress string) *net.Dialer {
}, },
} }
} }
// isLanAddr reports whether addr is considered a LAN ip address.
func isLanAddr(addr netip.Addr) bool {
return addr.IsPrivate() ||
addr.IsLoopback() ||
addr.IsLinkLocalUnicast() ||
tsaddr.CGNATRange().Contains(addr)
}
+2 -2
View File
@@ -17,7 +17,7 @@ func Test_osResolver_Resolve(t *testing.T) {
go func() { go func() {
defer cancel() defer cancel()
resolver := &osResolver{} resolver := &osResolver{}
resolver.nameservers.Store(&[]string{"127.0.0.127:5353"}) resolver.publicServer.Store(&[]string{"127.0.0.127:5353"})
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("controld.com.", dns.TypeA) m.SetQuestion("controld.com.", dns.TypeA)
m.RecursionDesired = true m.RecursionDesired = true
@@ -71,7 +71,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
} }
}() }()
resolver := &osResolver{} resolver := &osResolver{}
resolver.nameservers.Store(&ns) resolver.publicServer.Store(&ns)
msg := new(dns.Msg) msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS) msg.SetQuestion(".", dns.TypeNS)
answer, err := resolver.Resolve(context.Background(), msg) answer, err := resolver.Resolve(context.Background(), msg)