From c4cf4331a78546f9cff39d535eacd3108d4ab871 Mon Sep 17 00:00:00 2001 From: Codescribe Date: Tue, 3 Mar 2026 13:25:36 -0500 Subject: [PATCH] Fix dnsFromResolvConf not filtering loopback IPs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The continue statement only broke out of the inner loop, so loopback/local IPs (e.g. 127.0.0.1) were never filtered. This caused ctrld to use itself as bootstrap DNS when already installed as the system resolver — a self-referential loop. Use the same isLocal flag pattern as getDNSFromScutil() and getAllDHCPNameservers(). --- nameservers_unix.go | 45 ++++++++++------- nameservers_unix_test.go | 105 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 18 deletions(-) create mode 100644 nameservers_unix_test.go diff --git a/nameservers_unix.go b/nameservers_unix.go index 6022f7a..d813bf4 100644 --- a/nameservers_unix.go +++ b/nameservers_unix.go @@ -5,12 +5,38 @@ package ctrld import ( "context" "net" + "net/netip" "slices" "time" "tailscale.com/net/netmon" ) +// localNameservers filters a list of nameserver strings, returning only those +// that are not loopback or local machine IP addresses. +func localNameservers(nss []string, regularIPs, loopbackIPs []netip.Addr) []string { + var result []string + seen := make(map[string]bool) + + for _, ns := range nss { + if ip := net.ParseIP(ns); ip != nil { + // skip loopback and local IPs + isLocal := false + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + if ip.String() == v.String() { + isLocal = true + break + } + } + if !isLocal && !seen[ip.String()] { + seen[ip.String()] = true + result = append(result, ip.String()) + } + } + } + return result +} + // dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file. // A nameserver is usable if it's not one of current machine's IP addresses // and loopback IP addresses. @@ -29,24 +55,7 @@ func dnsFromResolvConf(_ context.Context) []string { } nss := CurrentNameserversFromResolvconf() - var localDNS []string - seen := make(map[string]bool) - - for _, ns := range nss { - if ip := net.ParseIP(ns); ip != nil { - // skip loopback IPs - for _, v := range slices.Concat(regularIPs, loopbackIPs) { - ipStr := v.String() - if ip.String() == ipStr { - continue - } - } - if !seen[ip.String()] { - seen[ip.String()] = true - localDNS = append(localDNS, ip.String()) - } - } - } + localDNS := localNameservers(nss, regularIPs, loopbackIPs) // If we successfully read the file and found nameservers, return them if len(localDNS) > 0 { diff --git a/nameservers_unix_test.go b/nameservers_unix_test.go new file mode 100644 index 0000000..a771dc1 --- /dev/null +++ b/nameservers_unix_test.go @@ -0,0 +1,105 @@ +//go:build unix + +package ctrld + +import ( + "net/netip" + "testing" +) + +func Test_localNameservers(t *testing.T) { + loopbackIPs := []netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("::1"), + } + regularIPs := []netip.Addr{ + netip.MustParseAddr("192.168.1.100"), + netip.MustParseAddr("10.0.0.5"), + } + + tests := []struct { + name string + nss []string + regularIPs []netip.Addr + loopbackIPs []netip.Addr + want []string + }{ + { + name: "filters loopback IPv4", + nss: []string{"127.0.0.1", "8.8.8.8"}, + regularIPs: nil, + loopbackIPs: loopbackIPs, + want: []string{"8.8.8.8"}, + }, + { + name: "filters loopback IPv6", + nss: []string{"::1", "1.1.1.1"}, + regularIPs: nil, + loopbackIPs: loopbackIPs, + want: []string{"1.1.1.1"}, + }, + { + name: "filters local machine IPs", + nss: []string{"192.168.1.100", "8.8.4.4"}, + regularIPs: regularIPs, + loopbackIPs: nil, + want: []string{"8.8.4.4"}, + }, + { + name: "filters both loopback and local IPs", + nss: []string{"127.0.0.1", "192.168.1.100", "8.8.8.8"}, + regularIPs: regularIPs, + loopbackIPs: loopbackIPs, + want: []string{"8.8.8.8"}, + }, + { + name: "deduplicates results", + nss: []string{"8.8.8.8", "8.8.8.8", "1.1.1.1"}, + regularIPs: regularIPs, + loopbackIPs: loopbackIPs, + want: []string{"8.8.8.8", "1.1.1.1"}, + }, + { + name: "all filtered returns nil", + nss: []string{"127.0.0.1", "::1", "192.168.1.100"}, + regularIPs: regularIPs, + loopbackIPs: loopbackIPs, + want: nil, + }, + { + name: "empty input returns nil", + nss: nil, + regularIPs: regularIPs, + loopbackIPs: loopbackIPs, + want: nil, + }, + { + name: "skips unparseable entries", + nss: []string{"not-an-ip", "8.8.8.8"}, + regularIPs: regularIPs, + loopbackIPs: loopbackIPs, + want: []string{"8.8.8.8"}, + }, + { + name: "no local IPs filters nothing", + nss: []string{"8.8.8.8", "1.1.1.1"}, + regularIPs: nil, + loopbackIPs: nil, + want: []string{"8.8.8.8", "1.1.1.1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := localNameservers(tt.nss, tt.regularIPs, tt.loopbackIPs) + if len(got) != len(tt.want) { + t.Fatalf("localNameservers() = %v, want %v", got, tt.want) + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("localNameservers()[%d] = %q, want %q", i, got[i], tt.want[i]) + } + } + }) + } +}