diff --git a/nameservers_bsd.go b/nameservers_bsd.go index b835060..09c9516 100644 --- a/nameservers_bsd.go +++ b/nameservers_bsd.go @@ -10,7 +10,7 @@ import ( ) func dnsFns() []dnsFn { - return []dnsFn{dnsFromRIB} + return []dnsFn{dnsFromResolvConf, dnsFromRIB} } func dnsFromRIB() []string { diff --git a/nameservers_darwin.go b/nameservers_darwin.go index b6b1543..1bf4574 100644 --- a/nameservers_darwin.go +++ b/nameservers_darwin.go @@ -16,58 +16,12 @@ import ( "time" "tailscale.com/net/netmon" - - "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) func dnsFns() []dnsFn { return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers} } -// dnsFromResolvConf reads nameservers from /etc/resolv.conf -func dnsFromResolvConf() []string { - const ( - maxRetries = 10 - retryInterval = 100 * time.Millisecond - ) - - regularIPs, loopbackIPs, _ := netmon.LocalAddresses() - - var dns []string - for attempt := 0; attempt < maxRetries; attempt++ { - if attempt > 0 { - time.Sleep(retryInterval) - } - - nss := resolvconffile.NameServers("") - var localDNS []string - seen := make(map[string]bool) - - for _, ns := range nss { - if ip := net.ParseIP(ns); ip != nil { - // skip loopback IPs - for _, v := range slices.Concat(regularIPs, loopbackIPs) { - ipStr := v.String() - if ip.String() == ipStr { - continue - } - } - if !seen[ip.String()] { - seen[ip.String()] = true - localDNS = append(localDNS, ip.String()) - } - } - } - - // If we successfully read the file and found nameservers, return them - if len(localDNS) > 0 { - return localDNS - } - } - - return dns -} - func getDNSFromScutil() []string { logger := *ProxyLogger.Load() diff --git a/nameservers_linux.go b/nameservers_linux.go index 1fad95b..13a5507 100644 --- a/nameservers_linux.go +++ b/nameservers_linux.go @@ -17,7 +17,7 @@ const ( ) func dnsFns() []dnsFn { - return []dnsFn{dns4, dns6, dnsFromSystemdResolver} + return []dnsFn{dnsFromResolvConf, dns4, dns6, dnsFromSystemdResolver} } func dns4() []string { diff --git a/nameservers_unix.go b/nameservers_unix.go index 39cc971..d7af521 100644 --- a/nameservers_unix.go +++ b/nameservers_unix.go @@ -2,8 +2,63 @@ package ctrld -import "github.com/Control-D-Inc/ctrld/internal/resolvconffile" +import ( + "net" + "slices" + "time" -func nameserversFromResolvconf() []string { + "tailscale.com/net/netmon" + + "github.com/Control-D-Inc/ctrld/internal/resolvconffile" +) + +// currentNameserversFromResolvconf returns the current nameservers set from /etc/resolv.conf file. +func currentNameserversFromResolvconf() []string { return resolvconffile.NameServers("") } + +// dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file. +// A nameserver is usable if it's not one of current machine's IP addresses +// and loopback IP addresses. +func dnsFromResolvConf() []string { + const ( + maxRetries = 10 + retryInterval = 100 * time.Millisecond + ) + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var dns []string + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(retryInterval) + } + + nss := resolvconffile.NameServers("") + var localDNS []string + seen := make(map[string]bool) + + for _, ns := range nss { + if ip := net.ParseIP(ns); ip != nil { + // skip loopback IPs + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + if ip.String() == ipStr { + continue + } + } + if !seen[ip.String()] { + seen[ip.String()] = true + localDNS = append(localDNS, ip.String()) + } + } + } + + // If we successfully read the file and found nameservers, return them + if len(localDNS) > 0 { + return localDNS + } + } + + return dns +} diff --git a/nameservers_windows.go b/nameservers_windows.go index 0c47e58..eb4f2b5 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -158,7 +158,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { 0, // DomainGuid - not needed 0, // SiteName - not needed uintptr(flags), // Flags - uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output + uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output if ret != 0 { switch ret { @@ -330,7 +330,8 @@ func getDNSServers(ctx context.Context) ([]string, error) { return ns, nil } -func nameserversFromResolvconf() []string { +// currentNameserversFromResolvconf returns a nil slice of strings. +func currentNameserversFromResolvconf() []string { return nil } diff --git a/resolver.go b/resolver.go index 401a7f9..a44ddb2 100644 --- a/resolver.go +++ b/resolver.go @@ -584,7 +584,7 @@ func NewPrivateResolver() Resolver { } nss := *or.lanServers.Load() resolverMutex.Unlock() - resolveConfNss := nameserversFromResolvconf() + resolveConfNss := currentNameserversFromResolvconf() localRfc1918Addrs := Rfc1918Addresses() n := 0 for _, ns := range nss {