mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-07-04 01:07:49 +02:00
Prefer LAN server answer over public one
While at it, also implementing new OS resolver chosing logic, keeping only 2 LAN servers at any time, 1 for current one, and 1 for last used one.
This commit is contained in:
committed by
Cuong Manh Le
parent
f87220a908
commit
0cdff0d368
@@ -17,7 +17,7 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
|
|||||||
uc.Init()
|
uc.Init()
|
||||||
uc.setupBootstrapIP(false)
|
uc.setupBootstrapIP(false)
|
||||||
if len(uc.bootstrapIPs) == 0 {
|
if len(uc.bootstrapIPs) == 0 {
|
||||||
t.Log(nameservers())
|
t.Log(defaultNameservers())
|
||||||
t.Fatal("could not bootstrap ip without bootstrap DNS")
|
t.Fatal("could not bootstrap ip without bootstrap DNS")
|
||||||
}
|
}
|
||||||
t.Log(uc)
|
t.Log(uc)
|
||||||
|
|||||||
+2
-3
@@ -1,9 +1,8 @@
|
|||||||
package ctrld
|
package ctrld
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
type dnsFn func() []string
|
type dnsFn func() []string
|
||||||
|
|
||||||
|
// nameservers returns DNS nameservers from system settings.
|
||||||
func nameservers() []string {
|
func nameservers() []string {
|
||||||
var dns []string
|
var dns []string
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
@@ -21,7 +20,7 @@ func nameservers() []string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
seen[ns] = true
|
seen[ns] = true
|
||||||
dns = append(dns, net.JoinHostPort(ns, "53"))
|
dns = append(dns, ns)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+119
-34
@@ -12,9 +12,9 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"tailscale.com/net/netmon"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"tailscale.com/net/netmon"
|
||||||
|
"tailscale.com/net/tsaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -47,10 +47,34 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
|||||||
// or is the Resolver used for ResolverTypeOS.
|
// or is the Resolver used for ResolverTypeOS.
|
||||||
var or = newResolverWithNameserver(defaultNameservers())
|
var or = newResolverWithNameserver(defaultNameservers())
|
||||||
|
|
||||||
// defaultNameservers returns OS nameservers plus ControlD public DNS.
|
// defaultNameservers is like nameservers with each element formed "ip:53".
|
||||||
func defaultNameservers() []string {
|
func defaultNameservers() []string {
|
||||||
ns := nameservers()
|
ns := nameservers()
|
||||||
return ns
|
nss := make([]string, len(ns))
|
||||||
|
for i := range ns {
|
||||||
|
nss[i] = net.JoinHostPort(ns[i], "53")
|
||||||
|
}
|
||||||
|
return nss
|
||||||
|
}
|
||||||
|
|
||||||
|
// availableNameservers returns list of current available DNS servers of the system.
|
||||||
|
func availableNameservers() []string {
|
||||||
|
var nss []string
|
||||||
|
// Ignore local addresses to prevent loop.
|
||||||
|
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||||
|
machineIPsMap := make(map[string]struct{}, len(regularIPs))
|
||||||
|
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||||
|
machineIPsMap[v.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, ns := range nameservers() {
|
||||||
|
if _, ok := machineIPsMap[ns]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if testNameserver(ns) {
|
||||||
|
nss = append(nss, ns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nss
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitializeOsResolver initializes OS resolver using the current system DNS settings.
|
// InitializeOsResolver initializes OS resolver using the current system DNS settings.
|
||||||
@@ -59,23 +83,39 @@ func defaultNameservers() []string {
|
|||||||
// It's the caller's responsibility to ensure the system DNS is in a clean state before
|
// It's the caller's responsibility to ensure the system DNS is in a clean state before
|
||||||
// calling this function.
|
// calling this function.
|
||||||
func InitializeOsResolver() []string {
|
func InitializeOsResolver() []string {
|
||||||
var nss []string
|
var (
|
||||||
// Ignore local addresses to prevent loop.
|
nss []string
|
||||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
publicNss []string
|
||||||
machineIPsMap := make(map[string]struct{}, len(regularIPs))
|
)
|
||||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
var curLanServer netip.Addr
|
||||||
machineIPsMap[net.JoinHostPort(v.String(), "53")] = struct{}{}
|
if p := or.currentLanServer.Load(); p != nil {
|
||||||
|
curLanServer = *p
|
||||||
|
or.currentLanServer.Store(nil)
|
||||||
}
|
}
|
||||||
for _, ns := range defaultNameservers() {
|
for _, ns := range availableNameservers() {
|
||||||
if _, ok := machineIPsMap[ns]; ok {
|
addr, err := netip.ParseAddr(ns)
|
||||||
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if testNameserver(ns) {
|
server := net.JoinHostPort(ns, "53")
|
||||||
nss = append(nss, ns)
|
if isLanAddr(addr) {
|
||||||
|
if addr.Compare(curLanServer) != 0 && or.currentLanServer.CompareAndSwap(nil, &addr) {
|
||||||
|
nss = append(nss, server)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
publicNss = append(publicNss, server)
|
||||||
|
nss = append(nss, server)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
nss = append(nss, controldPublicDnsWithPort)
|
if curLanServer.IsValid() {
|
||||||
or.nameservers.Store(&nss)
|
or.lastLanServer.Store(&curLanServer)
|
||||||
|
nss = append(nss, net.JoinHostPort(curLanServer.String(), "53"))
|
||||||
|
}
|
||||||
|
if len(publicNss) == 0 {
|
||||||
|
publicNss = append(publicNss, controldPublicDnsWithPort)
|
||||||
|
nss = append(nss, controldPublicDnsWithPort)
|
||||||
|
}
|
||||||
|
or.publicServer.Store(&publicNss)
|
||||||
return nss
|
return nss
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,7 +126,7 @@ func testNameserver(addr string) bool {
|
|||||||
client := new(dns.Client)
|
client := new(dns.Client)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
_, _, err := client.ExchangeContext(ctx, msg, addr)
|
_, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(addr, "53"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr)
|
ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr)
|
||||||
}
|
}
|
||||||
@@ -123,21 +163,31 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type osResolver struct {
|
type osResolver struct {
|
||||||
nameservers atomic.Pointer[[]string]
|
currentLanServer atomic.Pointer[netip.Addr]
|
||||||
|
lastLanServer atomic.Pointer[netip.Addr]
|
||||||
|
publicServer atomic.Pointer[[]string]
|
||||||
}
|
}
|
||||||
|
|
||||||
type osResolverResult struct {
|
type osResolverResult struct {
|
||||||
answer *dns.Msg
|
answer *dns.Msg
|
||||||
err error
|
err error
|
||||||
server string
|
server string
|
||||||
|
lan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||||
// Query is sent to all nameservers concurrently, and the first
|
// Query is sent to all nameservers concurrently, and the first
|
||||||
// success response will be returned.
|
// success response will be returned.
|
||||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||||
nss := *o.nameservers.Load()
|
publicServers := *o.publicServer.Load()
|
||||||
numServers := len(nss)
|
nss := make([]string, 0, 2)
|
||||||
|
if p := o.currentLanServer.Load(); p != nil {
|
||||||
|
nss = append(nss, net.JoinHostPort(p.String(), "53"))
|
||||||
|
}
|
||||||
|
if p := o.lastLanServer.Load(); p != nil {
|
||||||
|
nss = append(nss, net.JoinHostPort(p.String(), "53"))
|
||||||
|
}
|
||||||
|
numServers := len(nss) + len(publicServers)
|
||||||
if numServers == 0 {
|
if numServers == 0 {
|
||||||
return nil, errors.New("no nameservers available")
|
return nil, errors.New("no nameservers available")
|
||||||
}
|
}
|
||||||
@@ -146,19 +196,24 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
|||||||
|
|
||||||
dnsClient := &dns.Client{Net: "udp"}
|
dnsClient := &dns.Client{Net: "udp"}
|
||||||
ch := make(chan *osResolverResult, numServers)
|
ch := make(chan *osResolverResult, numServers)
|
||||||
var wg sync.WaitGroup
|
wg := &sync.WaitGroup{}
|
||||||
wg.Add(len(nss))
|
wg.Add(numServers)
|
||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
close(ch)
|
close(ch)
|
||||||
}()
|
}()
|
||||||
for _, server := range nss {
|
|
||||||
go func(server string) {
|
do := func(servers []string, isLan bool) {
|
||||||
defer wg.Done()
|
for _, server := range servers {
|
||||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
go func(server string) {
|
||||||
ch <- &osResolverResult{answer: answer, err: err, server: server}
|
defer wg.Done()
|
||||||
}(server)
|
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
||||||
|
ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan}
|
||||||
|
}(server)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
do(nss, true)
|
||||||
|
do(publicServers, false)
|
||||||
|
|
||||||
logAnswer := func(server string) {
|
logAnswer := func(server string) {
|
||||||
if before, _, found := strings.Cut(server, ":"); found {
|
if before, _, found := strings.Cut(server, ":"); found {
|
||||||
@@ -170,14 +225,20 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
|||||||
nonSuccessAnswer *dns.Msg
|
nonSuccessAnswer *dns.Msg
|
||||||
nonSuccessServer string
|
nonSuccessServer string
|
||||||
controldSuccessAnswer *dns.Msg
|
controldSuccessAnswer *dns.Msg
|
||||||
|
publicServerAnswer *dns.Msg
|
||||||
|
publicServer string
|
||||||
)
|
)
|
||||||
errs := make([]error, 0, numServers)
|
errs := make([]error, 0, numServers)
|
||||||
for res := range ch {
|
for res := range ch {
|
||||||
switch {
|
switch {
|
||||||
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
|
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
|
||||||
if res.server == controldPublicDnsWithPort {
|
switch {
|
||||||
|
case res.server == controldPublicDnsWithPort:
|
||||||
controldSuccessAnswer = res.answer // only use ControlD answer as last one.
|
controldSuccessAnswer = res.answer // only use ControlD answer as last one.
|
||||||
} else {
|
case !res.lan && publicServerAnswer == nil:
|
||||||
|
publicServerAnswer = res.answer // use public DNS answer after LAN server..
|
||||||
|
publicServer = res.server
|
||||||
|
default:
|
||||||
cancel()
|
cancel()
|
||||||
logAnswer(res.server)
|
logAnswer(res.server)
|
||||||
return res.answer, nil
|
return res.answer, nil
|
||||||
@@ -188,6 +249,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
|||||||
}
|
}
|
||||||
errs = append(errs, res.err)
|
errs = append(errs, res.err)
|
||||||
}
|
}
|
||||||
|
if publicServerAnswer != nil {
|
||||||
|
logAnswer(publicServer)
|
||||||
|
return publicServerAnswer, nil
|
||||||
|
}
|
||||||
if controldSuccessAnswer != nil {
|
if controldSuccessAnswer != nil {
|
||||||
logAnswer(controldPublicDnsWithPort)
|
logAnswer(controldPublicDnsWithPort)
|
||||||
return controldSuccessAnswer, nil
|
return controldSuccessAnswer, nil
|
||||||
@@ -241,7 +306,7 @@ func LookupIP(domain string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||||
nss := nameservers()
|
nss := defaultNameservers()
|
||||||
if withBootstrapDNS {
|
if withBootstrapDNS {
|
||||||
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
|
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
|
||||||
}
|
}
|
||||||
@@ -319,7 +384,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
|
|||||||
// - Gateway IP address (depends on OS).
|
// - Gateway IP address (depends on OS).
|
||||||
// - Input servers.
|
// - Input servers.
|
||||||
func NewBootstrapResolver(servers ...string) Resolver {
|
func NewBootstrapResolver(servers ...string) Resolver {
|
||||||
nss := nameservers()
|
nss := defaultNameservers()
|
||||||
nss = append([]string{controldPublicDnsWithPort}, nss...)
|
nss = append([]string{controldPublicDnsWithPort}, nss...)
|
||||||
for _, ns := range servers {
|
for _, ns := range servers {
|
||||||
nss = append([]string{net.JoinHostPort(ns, "53")}, nss...)
|
nss = append([]string{net.JoinHostPort(ns, "53")}, nss...)
|
||||||
@@ -335,7 +400,7 @@ func NewBootstrapResolver(servers ...string) Resolver {
|
|||||||
//
|
//
|
||||||
// This is useful for doing PTR lookup in LAN network.
|
// This is useful for doing PTR lookup in LAN network.
|
||||||
func NewPrivateResolver() Resolver {
|
func NewPrivateResolver() Resolver {
|
||||||
nss := nameservers()
|
nss := defaultNameservers()
|
||||||
resolveConfNss := nameserversFromResolvconf()
|
resolveConfNss := nameserversFromResolvconf()
|
||||||
localRfc1918Addrs := Rfc1918Addresses()
|
localRfc1918Addrs := Rfc1918Addresses()
|
||||||
n := 0
|
n := 0
|
||||||
@@ -376,9 +441,21 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
|
|||||||
return newResolverWithNameserver(nameservers)
|
return newResolverWithNameserver(nameservers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newResolverWithNameserver returns an OS resolver from given nameservers list.
|
||||||
|
// The caller must ensure each server in list is formed "ip:53".
|
||||||
func newResolverWithNameserver(nameservers []string) *osResolver {
|
func newResolverWithNameserver(nameservers []string) *osResolver {
|
||||||
r := &osResolver{}
|
r := &osResolver{}
|
||||||
r.nameservers.Store(&nameservers)
|
nss := slices.Sorted(slices.Values(nameservers))
|
||||||
|
for i, ns := range nss {
|
||||||
|
ip, _, _ := net.SplitHostPort(ns)
|
||||||
|
addr, _ := netip.ParseAddr(ip)
|
||||||
|
if isLanAddr(addr) {
|
||||||
|
r.currentLanServer.Store(&addr)
|
||||||
|
nss = slices.Delete(nss, i, i+1)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.publicServer.Store(&nss)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -409,3 +486,11 @@ func newDialer(dnsAddress string) *net.Dialer {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLanAddr reports whether addr is considered a LAN ip address.
|
||||||
|
func isLanAddr(addr netip.Addr) bool {
|
||||||
|
return addr.IsPrivate() ||
|
||||||
|
addr.IsLoopback() ||
|
||||||
|
addr.IsLinkLocalUnicast() ||
|
||||||
|
tsaddr.CGNATRange().Contains(addr)
|
||||||
|
}
|
||||||
|
|||||||
+2
-2
@@ -17,7 +17,7 @@ func Test_osResolver_Resolve(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
resolver := &osResolver{}
|
resolver := &osResolver{}
|
||||||
resolver.nameservers.Store(&[]string{"127.0.0.127:5353"})
|
resolver.publicServer.Store(&[]string{"127.0.0.127:5353"})
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion("controld.com.", dns.TypeA)
|
m.SetQuestion("controld.com.", dns.TypeA)
|
||||||
m.RecursionDesired = true
|
m.RecursionDesired = true
|
||||||
@@ -71,7 +71,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
resolver := &osResolver{}
|
resolver := &osResolver{}
|
||||||
resolver.nameservers.Store(&ns)
|
resolver.publicServer.Store(&ns)
|
||||||
msg := new(dns.Msg)
|
msg := new(dns.Msg)
|
||||||
msg.SetQuestion(".", dns.TypeNS)
|
msg.SetQuestion(".", dns.TypeNS)
|
||||||
answer, err := resolver.Resolve(context.Background(), msg)
|
answer, err := resolver.Resolve(context.Background(), msg)
|
||||||
|
|||||||
Reference in New Issue
Block a user