diff --git a/config_internal_test.go b/config_internal_test.go index 608b0ec..0a457d3 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -14,6 +14,7 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { uc.Init() uc.setupBootstrapIP(false) if uc.BootstrapIP == "" { + t.Log(availableNameservers()) t.Fatal("could not bootstrap ip without bootstrap DNS") } t.Log(uc) diff --git a/nameservers.go b/nameservers.go new file mode 100644 index 0000000..ce99a3b --- /dev/null +++ b/nameservers.go @@ -0,0 +1,29 @@ +package ctrld + +import "net" + +type dnsFn func() []string + +func nameservers() []string { + var dns []string + seen := make(map[string]bool) + ch := make(chan []string) + fns := dnsFns() + + for _, fn := range fns { + go func(fn dnsFn) { + ch <- fn() + }(fn) + } + for range fns { + for _, ns := range <-ch { + if seen[ns] { + continue + } + seen[ns] = true + dns = append(dns, net.JoinHostPort(ns, "53")) + } + } + + return dns +} diff --git a/nameservers_bsd.go b/nameservers_bsd.go index 5ecc5e6..2beebd0 100644 --- a/nameservers_bsd.go +++ b/nameservers_bsd.go @@ -4,14 +4,20 @@ package ctrld import ( "net" + "os/exec" + "runtime" + "strings" "syscall" "golang.org/x/net/route" ) -func osNameservers() []string { +func dnsFns() []dnsFn { + return []dnsFn{dnsFromRIB, dnsFromIPConfig} +} + +func dnsFromRIB() []string { var dns []string - seen := make(map[string]bool) rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) if err != nil { return nil @@ -33,17 +39,28 @@ func osNameservers() []string { if dst == nil || gw == nil { continue } - if gw.IsLoopback() || seen[gw.String()] { + if gw.IsLoopback() { continue } if dst.Equal(net.IPv4zero) || dst.Equal(net.IPv6zero) { - seen[gw.String()] = true - dns = append(dns, net.JoinHostPort(gw.String(), "53")) + dns = append(dns, gw.String()) } } return dns } +func dnsFromIPConfig() []string { + if runtime.GOOS != "darwin" { + return nil + } + cmd := exec.Command("ipconfig", "getoption", "", "domain_name_server") + out, _ := cmd.Output() + if ip := net.ParseIP(strings.TrimSpace(string(out))); ip != nil { + return []string{ip.String()} + } + return nil +} + func toNetIP(addr route.Addr) net.IP { switch t := addr.(type) { case *route.Inet4Addr: diff --git a/nameservers_linux.go b/nameservers_linux.go index deeff7e..8859ea5 100644 --- a/nameservers_linux.go +++ b/nameservers_linux.go @@ -6,6 +6,8 @@ import ( "encoding/hex" "net" "os" + + "github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile" ) const ( @@ -13,13 +15,8 @@ const ( v6RouteFile = "/proc/net/ipv6_route" ) -func osNameservers() []string { - ns4 := dns4() - ns6 := dns6() - ns := make([]string, len(ns4)+len(ns6)) - ns = append(ns, ns4...) - ns = append(ns, ns6...) - return ns +func dnsFns() []dnsFn { + return []dnsFn{dns4, dns6, dnsFromSystemdResolver} } func dns4() []string { @@ -53,7 +50,7 @@ func dns4() []string { continue } seen[ip.String()] = true - dns = append(dns, net.JoinHostPort(ip.String(), "53")) + dns = append(dns, ip.String()) } return dns } @@ -82,7 +79,19 @@ func dns6() []string { if ip.Equal(net.IPv6zero) { continue } - dns = append(dns, net.JoinHostPort(ip.String(), "53")) + dns = append(dns, ip.String()) } return dns } + +func dnsFromSystemdResolver() []string { + c, err := resolvconffile.ParseFile("/run/systemd/resolve/resolv.conf") + if err != nil { + return nil + } + ns := make([]string, 0, len(c.Nameservers)) + for _, nameserver := range c.Nameservers { + ns = append(ns, nameserver.String()) + } + return ns +} diff --git a/nameservers_unix.go b/nameservers_unix.go deleted file mode 100644 index fd9ebfc..0000000 --- a/nameservers_unix.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build unix - -package ctrld - -func nameservers() []string { - return osNameservers() -} diff --git a/nameservers_windows.go b/nameservers_windows.go index 1863a6e..5cd7811 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -9,7 +9,11 @@ import ( "golang.org/x/sys/windows" ) -func nameservers() []string { +func dnsFns() []dnsFn { + return []dnsFn{dnsFromAdapter} +} + +func dnsFromAdapter() []string { aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix) if err != nil { return nil @@ -42,7 +46,7 @@ func nameservers() []string { return } seen[ip.String()] = true - ns = append(ns, net.JoinHostPort(ip.String(), "53")) + ns = append(ns, ip.String()) } for _, aa := range aas { for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {