mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-13 10:26:06 +00:00
Fix dnsFromResolvConf not filtering loopback IPs
The continue statement only broke out of the inner loop, so loopback/local IPs (e.g. 127.0.0.1) were never filtered. This caused ctrld to use itself as bootstrap DNS when already installed as the system resolver — a self-referential loop. Use the same isLocal flag pattern as getDNSFromScutil() and getAllDHCPNameservers().
This commit is contained in:
committed by
Cuong Manh Le
parent
f44169c8b2
commit
c4cf4331a7
@@ -5,12 +5,38 @@ package ctrld
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
// localNameservers filters a list of nameserver strings, returning only those
|
||||
// that are not loopback or local machine IP addresses.
|
||||
func localNameservers(nss []string, regularIPs, loopbackIPs []netip.Addr) []string {
|
||||
var result []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, ns := range nss {
|
||||
if ip := net.ParseIP(ns); ip != nil {
|
||||
// skip loopback and local IPs
|
||||
isLocal := false
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
if ip.String() == v.String() {
|
||||
isLocal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isLocal && !seen[ip.String()] {
|
||||
seen[ip.String()] = true
|
||||
result = append(result, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file.
|
||||
// A nameserver is usable if it's not one of current machine's IP addresses
|
||||
// and loopback IP addresses.
|
||||
@@ -29,24 +55,7 @@ func dnsFromResolvConf(_ context.Context) []string {
|
||||
}
|
||||
|
||||
nss := CurrentNameserversFromResolvconf()
|
||||
var localDNS []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, ns := range nss {
|
||||
if ip := net.ParseIP(ns); ip != nil {
|
||||
// skip loopback IPs
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
ipStr := v.String()
|
||||
if ip.String() == ipStr {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !seen[ip.String()] {
|
||||
seen[ip.String()] = true
|
||||
localDNS = append(localDNS, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
localDNS := localNameservers(nss, regularIPs, loopbackIPs)
|
||||
|
||||
// If we successfully read the file and found nameservers, return them
|
||||
if len(localDNS) > 0 {
|
||||
|
||||
105
nameservers_unix_test.go
Normal file
105
nameservers_unix_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
//go:build unix
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_localNameservers(t *testing.T) {
|
||||
loopbackIPs := []netip.Addr{
|
||||
netip.MustParseAddr("127.0.0.1"),
|
||||
netip.MustParseAddr("::1"),
|
||||
}
|
||||
regularIPs := []netip.Addr{
|
||||
netip.MustParseAddr("192.168.1.100"),
|
||||
netip.MustParseAddr("10.0.0.5"),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nss []string
|
||||
regularIPs []netip.Addr
|
||||
loopbackIPs []netip.Addr
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "filters loopback IPv4",
|
||||
nss: []string{"127.0.0.1", "8.8.8.8"},
|
||||
regularIPs: nil,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: []string{"8.8.8.8"},
|
||||
},
|
||||
{
|
||||
name: "filters loopback IPv6",
|
||||
nss: []string{"::1", "1.1.1.1"},
|
||||
regularIPs: nil,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: []string{"1.1.1.1"},
|
||||
},
|
||||
{
|
||||
name: "filters local machine IPs",
|
||||
nss: []string{"192.168.1.100", "8.8.4.4"},
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: nil,
|
||||
want: []string{"8.8.4.4"},
|
||||
},
|
||||
{
|
||||
name: "filters both loopback and local IPs",
|
||||
nss: []string{"127.0.0.1", "192.168.1.100", "8.8.8.8"},
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: []string{"8.8.8.8"},
|
||||
},
|
||||
{
|
||||
name: "deduplicates results",
|
||||
nss: []string{"8.8.8.8", "8.8.8.8", "1.1.1.1"},
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: []string{"8.8.8.8", "1.1.1.1"},
|
||||
},
|
||||
{
|
||||
name: "all filtered returns nil",
|
||||
nss: []string{"127.0.0.1", "::1", "192.168.1.100"},
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "empty input returns nil",
|
||||
nss: nil,
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "skips unparseable entries",
|
||||
nss: []string{"not-an-ip", "8.8.8.8"},
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: []string{"8.8.8.8"},
|
||||
},
|
||||
{
|
||||
name: "no local IPs filters nothing",
|
||||
nss: []string{"8.8.8.8", "1.1.1.1"},
|
||||
regularIPs: nil,
|
||||
loopbackIPs: nil,
|
||||
want: []string{"8.8.8.8", "1.1.1.1"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := localNameservers(tt.nss, tt.regularIPs, tt.loopbackIPs)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("localNameservers() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("localNameservers()[%d] = %q, want %q", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user