From 0cdff0d368b005e415f64da246bf0edf18365e99 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 18 Oct 2024 01:31:40 +0700 Subject: [PATCH] Prefer LAN server answer over public one While at it, also implementing new OS resolver chosing logic, keeping only 2 LAN servers at any time, 1 for current one, and 1 for last used one. --- config_internal_test.go | 2 +- nameservers.go | 5 +- resolver.go | 153 +++++++++++++++++++++++++++++++--------- resolver_test.go | 4 +- 4 files changed, 124 insertions(+), 40 deletions(-) diff --git a/config_internal_test.go b/config_internal_test.go index 7b09da3..6823686 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -17,7 +17,7 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { uc.Init() uc.setupBootstrapIP(false) if len(uc.bootstrapIPs) == 0 { - t.Log(nameservers()) + t.Log(defaultNameservers()) t.Fatal("could not bootstrap ip without bootstrap DNS") } t.Log(uc) diff --git a/nameservers.go b/nameservers.go index ce99a3b..0aebf9e 100644 --- a/nameservers.go +++ b/nameservers.go @@ -1,9 +1,8 @@ package ctrld -import "net" - type dnsFn func() []string +// nameservers returns DNS nameservers from system settings. func nameservers() []string { var dns []string seen := make(map[string]bool) @@ -21,7 +20,7 @@ func nameservers() []string { continue } seen[ns] = true - dns = append(dns, net.JoinHostPort(ns, "53")) + dns = append(dns, ns) } } diff --git a/resolver.go b/resolver.go index 1e896c1..b38504c 100644 --- a/resolver.go +++ b/resolver.go @@ -12,9 +12,9 @@ import ( "sync/atomic" "time" - "tailscale.com/net/netmon" - "github.com/miekg/dns" + "tailscale.com/net/netmon" + "tailscale.com/net/tsaddr" ) const ( @@ -47,10 +47,34 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") // or is the Resolver used for ResolverTypeOS. var or = newResolverWithNameserver(defaultNameservers()) -// defaultNameservers returns OS nameservers plus ControlD public DNS. +// defaultNameservers is like nameservers with each element formed "ip:53". func defaultNameservers() []string { ns := nameservers() - return ns + nss := make([]string, len(ns)) + for i := range ns { + nss[i] = net.JoinHostPort(ns[i], "53") + } + return nss +} + +// availableNameservers returns list of current available DNS servers of the system. +func availableNameservers() []string { + var nss []string + // Ignore local addresses to prevent loop. + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + machineIPsMap := make(map[string]struct{}, len(regularIPs)) + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + machineIPsMap[v.String()] = struct{}{} + } + for _, ns := range nameservers() { + if _, ok := machineIPsMap[ns]; ok { + continue + } + if testNameserver(ns) { + nss = append(nss, ns) + } + } + return nss } // InitializeOsResolver initializes OS resolver using the current system DNS settings. @@ -59,23 +83,39 @@ func defaultNameservers() []string { // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. func InitializeOsResolver() []string { - var nss []string - // Ignore local addresses to prevent loop. - regularIPs, loopbackIPs, _ := netmon.LocalAddresses() - machineIPsMap := make(map[string]struct{}, len(regularIPs)) - for _, v := range slices.Concat(regularIPs, loopbackIPs) { - machineIPsMap[net.JoinHostPort(v.String(), "53")] = struct{}{} + var ( + nss []string + publicNss []string + ) + var curLanServer netip.Addr + if p := or.currentLanServer.Load(); p != nil { + curLanServer = *p + or.currentLanServer.Store(nil) } - for _, ns := range defaultNameservers() { - if _, ok := machineIPsMap[ns]; ok { + for _, ns := range availableNameservers() { + addr, err := netip.ParseAddr(ns) + if err != nil { continue } - if testNameserver(ns) { - nss = append(nss, ns) + server := net.JoinHostPort(ns, "53") + if isLanAddr(addr) { + if addr.Compare(curLanServer) != 0 && or.currentLanServer.CompareAndSwap(nil, &addr) { + nss = append(nss, server) + } + } else { + publicNss = append(publicNss, server) + nss = append(nss, server) } } - nss = append(nss, controldPublicDnsWithPort) - or.nameservers.Store(&nss) + if curLanServer.IsValid() { + or.lastLanServer.Store(&curLanServer) + nss = append(nss, net.JoinHostPort(curLanServer.String(), "53")) + } + if len(publicNss) == 0 { + publicNss = append(publicNss, controldPublicDnsWithPort) + nss = append(nss, controldPublicDnsWithPort) + } + or.publicServer.Store(&publicNss) return nss } @@ -86,7 +126,7 @@ func testNameserver(addr string) bool { client := new(dns.Client) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - _, _, err := client.ExchangeContext(ctx, msg, addr) + _, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(addr, "53")) if err != nil { ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr) } @@ -123,21 +163,31 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { } type osResolver struct { - nameservers atomic.Pointer[[]string] + currentLanServer atomic.Pointer[netip.Addr] + lastLanServer atomic.Pointer[netip.Addr] + publicServer atomic.Pointer[[]string] } type osResolverResult struct { answer *dns.Msg err error server string + lan bool } // Resolve resolves DNS queries using pre-configured nameservers. // Query is sent to all nameservers concurrently, and the first // success response will be returned. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - nss := *o.nameservers.Load() - numServers := len(nss) + publicServers := *o.publicServer.Load() + nss := make([]string, 0, 2) + if p := o.currentLanServer.Load(); p != nil { + nss = append(nss, net.JoinHostPort(p.String(), "53")) + } + if p := o.lastLanServer.Load(); p != nil { + nss = append(nss, net.JoinHostPort(p.String(), "53")) + } + numServers := len(nss) + len(publicServers) if numServers == 0 { return nil, errors.New("no nameservers available") } @@ -146,19 +196,24 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error dnsClient := &dns.Client{Net: "udp"} ch := make(chan *osResolverResult, numServers) - var wg sync.WaitGroup - wg.Add(len(nss)) + wg := &sync.WaitGroup{} + wg.Add(numServers) go func() { wg.Wait() close(ch) }() - for _, server := range nss { - go func(server string) { - defer wg.Done() - answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) - ch <- &osResolverResult{answer: answer, err: err, server: server} - }(server) + + do := func(servers []string, isLan bool) { + for _, server := range servers { + go func(server string) { + defer wg.Done() + answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) + ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan} + }(server) + } } + do(nss, true) + do(publicServers, false) logAnswer := func(server string) { if before, _, found := strings.Cut(server, ":"); found { @@ -170,14 +225,20 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error nonSuccessAnswer *dns.Msg nonSuccessServer string controldSuccessAnswer *dns.Msg + publicServerAnswer *dns.Msg + publicServer string ) errs := make([]error, 0, numServers) for res := range ch { switch { case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: - if res.server == controldPublicDnsWithPort { + switch { + case res.server == controldPublicDnsWithPort: controldSuccessAnswer = res.answer // only use ControlD answer as last one. - } else { + case !res.lan && publicServerAnswer == nil: + publicServerAnswer = res.answer // use public DNS answer after LAN server.. + publicServer = res.server + default: cancel() logAnswer(res.server) return res.answer, nil @@ -188,6 +249,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } errs = append(errs, res.err) } + if publicServerAnswer != nil { + logAnswer(publicServer) + return publicServerAnswer, nil + } if controldSuccessAnswer != nil { logAnswer(controldPublicDnsWithPort) return controldSuccessAnswer, nil @@ -241,7 +306,7 @@ func LookupIP(domain string) []string { } func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { - nss := nameservers() + nss := defaultNameservers() if withBootstrapDNS { nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...) } @@ -319,7 +384,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) // - Gateway IP address (depends on OS). // - Input servers. func NewBootstrapResolver(servers ...string) Resolver { - nss := nameservers() + nss := defaultNameservers() nss = append([]string{controldPublicDnsWithPort}, nss...) for _, ns := range servers { nss = append([]string{net.JoinHostPort(ns, "53")}, nss...) @@ -335,7 +400,7 @@ func NewBootstrapResolver(servers ...string) Resolver { // // This is useful for doing PTR lookup in LAN network. func NewPrivateResolver() Resolver { - nss := nameservers() + nss := defaultNameservers() resolveConfNss := nameserversFromResolvconf() localRfc1918Addrs := Rfc1918Addresses() n := 0 @@ -376,9 +441,21 @@ func NewResolverWithNameserver(nameservers []string) Resolver { return newResolverWithNameserver(nameservers) } +// newResolverWithNameserver returns an OS resolver from given nameservers list. +// The caller must ensure each server in list is formed "ip:53". func newResolverWithNameserver(nameservers []string) *osResolver { r := &osResolver{} - r.nameservers.Store(&nameservers) + nss := slices.Sorted(slices.Values(nameservers)) + for i, ns := range nss { + ip, _, _ := net.SplitHostPort(ns) + addr, _ := netip.ParseAddr(ip) + if isLanAddr(addr) { + r.currentLanServer.Store(&addr) + nss = slices.Delete(nss, i, i+1) + break + } + } + r.publicServer.Store(&nss) return r } @@ -409,3 +486,11 @@ func newDialer(dnsAddress string) *net.Dialer { }, } } + +// isLanAddr reports whether addr is considered a LAN ip address. +func isLanAddr(addr netip.Addr) bool { + return addr.IsPrivate() || + addr.IsLoopback() || + addr.IsLinkLocalUnicast() || + tsaddr.CGNATRange().Contains(addr) +} diff --git a/resolver_test.go b/resolver_test.go index 9d1cb34..44b170a 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -17,7 +17,7 @@ func Test_osResolver_Resolve(t *testing.T) { go func() { defer cancel() resolver := &osResolver{} - resolver.nameservers.Store(&[]string{"127.0.0.127:5353"}) + resolver.publicServer.Store(&[]string{"127.0.0.127:5353"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -71,7 +71,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { } }() resolver := &osResolver{} - resolver.nameservers.Store(&ns) + resolver.publicServer.Store(&ns) msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) answer, err := resolver.Resolve(context.Background(), msg)