From b65a5ac2839660676b66aa510978749f0fa1765c Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 30 Mar 2023 18:25:36 +0700 Subject: [PATCH] all: fix bug that causes ctrld stop working if bootstrap failed The bootstrap process has two issues that can make ctrld stop resolving after restarting machine host. ctrld uses bootstrap DNS and os nameservers for resolving upstream. On unix, /etc/resolv.conf content is used to get available nameservers. This works well when installing ctrld. However, after being installed, ctrld may modify the content of /etc/resolv.conf itself, to make other apps use its listener as DNS resolver. So when ctrld starts after OS restart, it ends up using [bootstrap DNS + ctrld's listener], for resolving upstream. At this moment, if ctrld could not contact bootstrap DNS for any reason, upstream domain will not be resolved. For above reason, an upstream may not have bootstrap IPs after ctrld starts. When re-bootstrapping, if there's no bootstrap IPs, ctrld should call the setup bootstrap process again. Currently, it does not, causing all queries failed. This commit fixes above issue by adding mechanism for retrieving OS nameservers properly, by querying routing table information: - Parsing /proc/net subsystem on Linux. - For BSD variants, just fetching routing information base from OS. - On Windows, just include the gateway information when reading iface. The fixing for second issue is trivial, just kickoff a bootstrap process if there's no bootstrap IPs when re-boostrapping. While at it, also ensure that fetching resolver information from ControlD API is also used the same approach. Fixes #34 --- config.go | 16 ++++++- config_internal_test.go | 20 ++++++++ go.mod | 2 +- go.sum | 2 - internal/controld/config.go | 48 +++++++++++++++++++ nameservers_bsd.go | 58 +++++++++++++++++++++++ nameservers_linux.go | 88 +++++++++++++++++++++++++++++++++++ nameservers_test.go | 11 +++++ nameservers_unix.go | 8 +--- nameservers_windows.go | 91 ++++++++++++++++--------------------- 10 files changed, 280 insertions(+), 64 deletions(-) create mode 100644 config_internal_test.go create mode 100644 nameservers_bsd.go create mode 100644 nameservers_linux.go create mode 100644 nameservers_test.go diff --git a/config.go b/config.go index 47c2315..2cb0e38 100644 --- a/config.go +++ b/config.go @@ -155,6 +155,12 @@ func (uc *UpstreamConfig) Init() { // SetupBootstrapIP manually find all available IPs of the upstream. // The first usable IP will be used as bootstrap IP of the upstream. func (uc *UpstreamConfig) SetupBootstrapIP() { + uc.setupBootstrapIP(true) +} + +// SetupBootstrapIP manually find all available IPs of the upstream. +// The first usable IP will be used as bootstrap IP of the upstream. +func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) { bootstrapIP := func(record dns.RR) string { switch ar := record.(type) { case *dns.A: @@ -166,7 +172,9 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { } resolver := &osResolver{nameservers: availableNameservers()} - resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + if withBootstrapDNS { + resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + } ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", uc.Domain, resolver.nameservers) timeoutMs := 2000 if uc.Timeout > 0 && uc.Timeout < timeoutMs { @@ -228,9 +236,13 @@ func (uc *UpstreamConfig) ReBootstrap() { default: return } - _, _, _ = uc.g.Do("rebootstrap", func() (any, error) { + _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { ProxyLog.Debug().Msg("re-bootstrapping upstream ip") n := uint32(len(uc.bootstrapIPs)) + if n == 0 { + uc.SetupBootstrapIP() + uc.setupTransportWithoutPingUpstream() + } timeoutMs := 1000 if uc.Timeout > 0 && uc.Timeout < timeoutMs { diff --git a/config_internal_test.go b/config_internal_test.go new file mode 100644 index 0000000..608b0ec --- /dev/null +++ b/config_internal_test.go @@ -0,0 +1,20 @@ +package ctrld + +import ( + "testing" +) + +func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { + uc := &UpstreamConfig{ + Name: "test", + Type: ResolverTypeDOH, + Endpoint: "https://freedns.controld.com/p2", + Timeout: 5000, + } + uc.Init() + uc.setupBootstrapIP(false) + if uc.BootstrapIP == "" { + t.Fatal("could not bootstrap ip without bootstrap DNS") + } + t.Log(uc) +} diff --git a/go.mod b/go.mod index 25f5cda..da86f6d 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/spf13/cobra v1.4.0 github.com/spf13/viper v1.14.0 github.com/stretchr/testify v1.8.1 + golang.org/x/net v0.7.0 golang.org/x/sync v0.1.0 golang.org/x/sys v0.5.0 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -68,7 +69,6 @@ require ( golang.org/x/crypto v0.4.0 // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.6.0 // indirect - golang.org/x/net v0.7.0 // indirect golang.org/x/text v0.7.0 // indirect golang.org/x/tools v0.2.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index 57a40c0..5b863e7 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,6 @@ github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534 h1:rtAn27 github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/cuonglm/osinfo v0.0.0-20230329052356-117e0ee9d353 h1:PFKlvMrKAUendoPEiJxSMkYeGG/G/5k7vu2ldGBnq3I= -github.com/cuonglm/osinfo v0.0.0-20230329052356-117e0ee9d353/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs= github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/controld/config.go b/internal/controld/config.go index 3994837..f323f99 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -7,16 +7,26 @@ import ( "fmt" "net" "net/http" + "sync" "time" + "github.com/miekg/dns" + + "github.com/Control-D-Inc/ctrld" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) const ( + apiDomain = "api.controld.com" resolverDataURL = "https://api.controld.com/utility" InvalidConfigCode = 40401 ) +var ( + resolveAPIDomainOnce sync.Once + apiDomainIP string +) + // ResolverConfig represents Control D resolver data. type ResolverConfig struct { DOH string `json:"doh"` @@ -64,6 +74,44 @@ func FetchResolverConfig(uid string) (*ResolverConfig, error) { if ctrldnet.SupportsIPv4() { proto = "tcp4" } + resolveAPIDomainOnce.Do(func() { + r, err := ctrld.NewResolver(&ctrld.UpstreamConfig{Type: ctrld.ResolverTypeOS}) + if err != nil { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + msg := new(dns.Msg) + dnsType := dns.TypeAAAA + if proto == "tcp4" { + dnsType = dns.TypeA + } + msg.SetQuestion(apiDomain+".", dnsType) + msg.RecursionDesired = true + answer, err := r.Resolve(ctx, msg) + if err != nil { + return + } + if answer.Rcode != dns.RcodeSuccess || len(answer.Answer) == 0 { + return + } + for _, record := range answer.Answer { + switch ar := record.(type) { + case *dns.A: + apiDomainIP = ar.A.String() + return + case *dns.AAAA: + apiDomainIP = ar.AAAA.String() + return + } + } + }) + if apiDomainIP != "" { + if _, port, _ := net.SplitHostPort(addr); port != "" { + return ctrldnet.Dialer.DialContext(ctx, proto, net.JoinHostPort(apiDomainIP, port)) + } + } return ctrldnet.Dialer.DialContext(ctx, proto, addr) } client := http.Client{ diff --git a/nameservers_bsd.go b/nameservers_bsd.go new file mode 100644 index 0000000..5ecc5e6 --- /dev/null +++ b/nameservers_bsd.go @@ -0,0 +1,58 @@ +//go:build darwin || dragonfly || freebsd || netbsd || openbsd + +package ctrld + +import ( + "net" + "syscall" + + "golang.org/x/net/route" +) + +func osNameservers() []string { + var dns []string + seen := make(map[string]bool) + rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) + if err != nil { + return nil + } + messages, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return nil + } + for _, message := range messages { + message, ok := message.(*route.RouteMessage) + if !ok { + continue + } + addresses := message.Addrs + if len(addresses) < 2 { + continue + } + dst, gw := toNetIP(addresses[0]), toNetIP(addresses[1]) + if dst == nil || gw == nil { + continue + } + if gw.IsLoopback() || seen[gw.String()] { + continue + } + if dst.Equal(net.IPv4zero) || dst.Equal(net.IPv6zero) { + seen[gw.String()] = true + dns = append(dns, net.JoinHostPort(gw.String(), "53")) + } + } + return dns +} + +func toNetIP(addr route.Addr) net.IP { + switch t := addr.(type) { + case *route.Inet4Addr: + return net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) + case *route.Inet6Addr: + ip := make(net.IP, net.IPv6len) + copy(ip, t.IP[:]) + return ip + default: + return nil + } +} diff --git a/nameservers_linux.go b/nameservers_linux.go new file mode 100644 index 0000000..deeff7e --- /dev/null +++ b/nameservers_linux.go @@ -0,0 +1,88 @@ +package ctrld + +import ( + "bufio" + "bytes" + "encoding/hex" + "net" + "os" +) + +const ( + v4RouteFile = "/proc/net/route" + v6RouteFile = "/proc/net/ipv6_route" +) + +func osNameservers() []string { + ns4 := dns4() + ns6 := dns6() + ns := make([]string, len(ns4)+len(ns6)) + ns = append(ns, ns4...) + ns = append(ns, ns6...) + return ns +} + +func dns4() []string { + f, err := os.Open(v4RouteFile) + if err != nil { + return nil + } + defer f.Close() + + var dns []string + seen := make(map[string]bool) + s := bufio.NewScanner(f) + first := true + for s.Scan() { + if first { + first = false + continue + } + fields := bytes.Fields(s.Bytes()) + if len(fields) < 2 { + continue + } + + gw := make([]byte, net.IPv4len) + // Third fields is gateway. + if _, err := hex.Decode(gw, fields[2]); err != nil { + continue + } + ip := net.IPv4(gw[3], gw[2], gw[1], gw[0]) + if ip.Equal(net.IPv4zero) || seen[ip.String()] { + continue + } + seen[ip.String()] = true + dns = append(dns, net.JoinHostPort(ip.String(), "53")) + } + return dns +} + +func dns6() []string { + f, err := os.Open(v6RouteFile) + if err != nil { + return nil + } + defer f.Close() + + var dns []string + s := bufio.NewScanner(f) + for s.Scan() { + fields := bytes.Fields(s.Bytes()) + if len(fields) < 4 { + continue + } + + gw := make([]byte, net.IPv6len) + // Fifth fields is gateway. + if _, err := hex.Decode(gw, fields[4]); err != nil { + continue + } + ip := net.IP(gw) + if ip.Equal(net.IPv6zero) { + continue + } + dns = append(dns, net.JoinHostPort(ip.String(), "53")) + } + return dns +} diff --git a/nameservers_test.go b/nameservers_test.go new file mode 100644 index 0000000..166cced --- /dev/null +++ b/nameservers_test.go @@ -0,0 +1,11 @@ +package ctrld + +import "testing" + +func TestNameservers(t *testing.T) { + ns := nameservers() + if len(ns) == 0 { + t.Fatal("failed to get nameservers") + } + t.Log(ns) +} diff --git a/nameservers_unix.go b/nameservers_unix.go index 5c765d3..fd9ebfc 100644 --- a/nameservers_unix.go +++ b/nameservers_unix.go @@ -1,11 +1,7 @@ -//go:build !js && !windows +//go:build unix package ctrld -import ( - "github.com/Control-D-Inc/ctrld/internal/resolvconffile" -) - func nameservers() []string { - return resolvconffile.NameServersWithPort() + return osNameservers() } diff --git a/nameservers_windows.go b/nameservers_windows.go index 7812f2a..1863a6e 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -2,70 +2,55 @@ package ctrld import ( "net" - "os" "syscall" - "unsafe" + + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.org/x/sys/windows" ) func nameservers() []string { - aas, err := adapterAddresses() + aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix) if err != nil { return nil } - ns := make([]string, 0, len(aas)) + ns := make([]string, 0, len(aas)*2) + seen := make(map[string]bool) + do := func(addr windows.SocketAddress) { + sa, err := addr.Sockaddr.Sockaddr() + if err != nil { + return + } + var ip net.IP + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + ip = net.IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]) + case *syscall.SockaddrInet6: + ip = make(net.IP, net.IPv6len) + copy(ip, sa.Addr[:]) + if ip[0] == 0xfe && ip[1] == 0xc0 { + // Ignore these fec0/10 ones. Windows seems to + // populate them as defaults on its misc rando + // interfaces. + return + } + default: + return + + } + if ip.IsLoopback() || seen[ip.String()] { + return + } + seen[ip.String()] = true + ns = append(ns, net.JoinHostPort(ip.String(), "53")) + } for _, aa := range aas { - for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next { - sa, err := dns.Address.Sockaddr.Sockaddr() - if err != nil { - continue - } - var ip net.IP - switch sa := sa.(type) { - case *syscall.SockaddrInet4: - ip = net.IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]) - case *syscall.SockaddrInet6: - ip = make(net.IP, net.IPv6len) - copy(ip, sa.Addr[:]) - if ip[0] == 0xfe && ip[1] == 0xc0 { - // Ignore these fec0/10 ones. Windows seems to - // populate them as defaults on its misc rando - // interfaces. - continue - } - default: - // Unexpected type. - continue - } - ns = append(ns, net.JoinHostPort(ip.String(), "53")) + for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { + do(dns.Address) + } + for gw := aa.FirstGatewayAddress; gw != nil; gw = gw.Next { + do(gw.Address) } } return ns } - -func adapterAddresses() ([]*windows.IpAdapterAddresses, error) { - var b []byte - l := uint32(15000) // recommended initial size - for { - b = make([]byte, l) - err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, windows.GAA_FLAG_INCLUDE_PREFIX, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l) - if err == nil { - if l == 0 { - return nil, nil - } - break - } - if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW { - return nil, os.NewSyscallError("getadaptersaddresses", err) - } - if l <= uint32(len(b)) { - return nil, os.NewSyscallError("getadaptersaddresses", err) - } - } - var aas []*windows.IpAdapterAddresses - for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next { - aas = append(aas, aa) - } - return aas, nil -}