diff --git a/config.go b/config.go index 47c2315..2cb0e38 100644 --- a/config.go +++ b/config.go @@ -155,6 +155,12 @@ func (uc *UpstreamConfig) Init() { // SetupBootstrapIP manually find all available IPs of the upstream. // The first usable IP will be used as bootstrap IP of the upstream. func (uc *UpstreamConfig) SetupBootstrapIP() { + uc.setupBootstrapIP(true) +} + +// SetupBootstrapIP manually find all available IPs of the upstream. +// The first usable IP will be used as bootstrap IP of the upstream. +func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) { bootstrapIP := func(record dns.RR) string { switch ar := record.(type) { case *dns.A: @@ -166,7 +172,9 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { } resolver := &osResolver{nameservers: availableNameservers()} - resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + if withBootstrapDNS { + resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + } ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", uc.Domain, resolver.nameservers) timeoutMs := 2000 if uc.Timeout > 0 && uc.Timeout < timeoutMs { @@ -228,9 +236,13 @@ func (uc *UpstreamConfig) ReBootstrap() { default: return } - _, _, _ = uc.g.Do("rebootstrap", func() (any, error) { + _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { ProxyLog.Debug().Msg("re-bootstrapping upstream ip") n := uint32(len(uc.bootstrapIPs)) + if n == 0 { + uc.SetupBootstrapIP() + uc.setupTransportWithoutPingUpstream() + } timeoutMs := 1000 if uc.Timeout > 0 && uc.Timeout < timeoutMs { diff --git a/config_internal_test.go b/config_internal_test.go new file mode 100644 index 0000000..608b0ec --- /dev/null +++ b/config_internal_test.go @@ -0,0 +1,20 @@ +package ctrld + +import ( + "testing" +) + +func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { + uc := &UpstreamConfig{ + Name: "test", + Type: ResolverTypeDOH, + Endpoint: "https://freedns.controld.com/p2", + Timeout: 5000, + } + uc.Init() + uc.setupBootstrapIP(false) + if uc.BootstrapIP == "" { + t.Fatal("could not bootstrap ip without bootstrap DNS") + } + t.Log(uc) +} diff --git a/go.mod b/go.mod index 25f5cda..da86f6d 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/spf13/cobra v1.4.0 github.com/spf13/viper v1.14.0 github.com/stretchr/testify v1.8.1 + golang.org/x/net v0.7.0 golang.org/x/sync v0.1.0 golang.org/x/sys v0.5.0 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -68,7 +69,6 @@ require ( golang.org/x/crypto v0.4.0 // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.6.0 // indirect - golang.org/x/net v0.7.0 // indirect golang.org/x/text v0.7.0 // indirect golang.org/x/tools v0.2.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index 57a40c0..5b863e7 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,6 @@ github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534 h1:rtAn27 github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/cuonglm/osinfo v0.0.0-20230329052356-117e0ee9d353 h1:PFKlvMrKAUendoPEiJxSMkYeGG/G/5k7vu2ldGBnq3I= -github.com/cuonglm/osinfo v0.0.0-20230329052356-117e0ee9d353/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs= github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/controld/config.go b/internal/controld/config.go index 3994837..f323f99 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -7,16 +7,26 @@ import ( "fmt" "net" "net/http" + "sync" "time" + "github.com/miekg/dns" + + "github.com/Control-D-Inc/ctrld" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) const ( + apiDomain = "api.controld.com" resolverDataURL = "https://api.controld.com/utility" InvalidConfigCode = 40401 ) +var ( + resolveAPIDomainOnce sync.Once + apiDomainIP string +) + // ResolverConfig represents Control D resolver data. type ResolverConfig struct { DOH string `json:"doh"` @@ -64,6 +74,44 @@ func FetchResolverConfig(uid string) (*ResolverConfig, error) { if ctrldnet.SupportsIPv4() { proto = "tcp4" } + resolveAPIDomainOnce.Do(func() { + r, err := ctrld.NewResolver(&ctrld.UpstreamConfig{Type: ctrld.ResolverTypeOS}) + if err != nil { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + msg := new(dns.Msg) + dnsType := dns.TypeAAAA + if proto == "tcp4" { + dnsType = dns.TypeA + } + msg.SetQuestion(apiDomain+".", dnsType) + msg.RecursionDesired = true + answer, err := r.Resolve(ctx, msg) + if err != nil { + return + } + if answer.Rcode != dns.RcodeSuccess || len(answer.Answer) == 0 { + return + } + for _, record := range answer.Answer { + switch ar := record.(type) { + case *dns.A: + apiDomainIP = ar.A.String() + return + case *dns.AAAA: + apiDomainIP = ar.AAAA.String() + return + } + } + }) + if apiDomainIP != "" { + if _, port, _ := net.SplitHostPort(addr); port != "" { + return ctrldnet.Dialer.DialContext(ctx, proto, net.JoinHostPort(apiDomainIP, port)) + } + } return ctrldnet.Dialer.DialContext(ctx, proto, addr) } client := http.Client{ diff --git a/nameservers_bsd.go b/nameservers_bsd.go new file mode 100644 index 0000000..5ecc5e6 --- /dev/null +++ b/nameservers_bsd.go @@ -0,0 +1,58 @@ +//go:build darwin || dragonfly || freebsd || netbsd || openbsd + +package ctrld + +import ( + "net" + "syscall" + + "golang.org/x/net/route" +) + +func osNameservers() []string { + var dns []string + seen := make(map[string]bool) + rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) + if err != nil { + return nil + } + messages, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return nil + } + for _, message := range messages { + message, ok := message.(*route.RouteMessage) + if !ok { + continue + } + addresses := message.Addrs + if len(addresses) < 2 { + continue + } + dst, gw := toNetIP(addresses[0]), toNetIP(addresses[1]) + if dst == nil || gw == nil { + continue + } + if gw.IsLoopback() || seen[gw.String()] { + continue + } + if dst.Equal(net.IPv4zero) || dst.Equal(net.IPv6zero) { + seen[gw.String()] = true + dns = append(dns, net.JoinHostPort(gw.String(), "53")) + } + } + return dns +} + +func toNetIP(addr route.Addr) net.IP { + switch t := addr.(type) { + case *route.Inet4Addr: + return net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) + case *route.Inet6Addr: + ip := make(net.IP, net.IPv6len) + copy(ip, t.IP[:]) + return ip + default: + return nil + } +} diff --git a/nameservers_linux.go b/nameservers_linux.go new file mode 100644 index 0000000..deeff7e --- /dev/null +++ b/nameservers_linux.go @@ -0,0 +1,88 @@ +package ctrld + +import ( + "bufio" + "bytes" + "encoding/hex" + "net" + "os" +) + +const ( + v4RouteFile = "/proc/net/route" + v6RouteFile = "/proc/net/ipv6_route" +) + +func osNameservers() []string { + ns4 := dns4() + ns6 := dns6() + ns := make([]string, len(ns4)+len(ns6)) + ns = append(ns, ns4...) + ns = append(ns, ns6...) + return ns +} + +func dns4() []string { + f, err := os.Open(v4RouteFile) + if err != nil { + return nil + } + defer f.Close() + + var dns []string + seen := make(map[string]bool) + s := bufio.NewScanner(f) + first := true + for s.Scan() { + if first { + first = false + continue + } + fields := bytes.Fields(s.Bytes()) + if len(fields) < 2 { + continue + } + + gw := make([]byte, net.IPv4len) + // Third fields is gateway. + if _, err := hex.Decode(gw, fields[2]); err != nil { + continue + } + ip := net.IPv4(gw[3], gw[2], gw[1], gw[0]) + if ip.Equal(net.IPv4zero) || seen[ip.String()] { + continue + } + seen[ip.String()] = true + dns = append(dns, net.JoinHostPort(ip.String(), "53")) + } + return dns +} + +func dns6() []string { + f, err := os.Open(v6RouteFile) + if err != nil { + return nil + } + defer f.Close() + + var dns []string + s := bufio.NewScanner(f) + for s.Scan() { + fields := bytes.Fields(s.Bytes()) + if len(fields) < 4 { + continue + } + + gw := make([]byte, net.IPv6len) + // Fifth fields is gateway. + if _, err := hex.Decode(gw, fields[4]); err != nil { + continue + } + ip := net.IP(gw) + if ip.Equal(net.IPv6zero) { + continue + } + dns = append(dns, net.JoinHostPort(ip.String(), "53")) + } + return dns +} diff --git a/nameservers_test.go b/nameservers_test.go new file mode 100644 index 0000000..166cced --- /dev/null +++ b/nameservers_test.go @@ -0,0 +1,11 @@ +package ctrld + +import "testing" + +func TestNameservers(t *testing.T) { + ns := nameservers() + if len(ns) == 0 { + t.Fatal("failed to get nameservers") + } + t.Log(ns) +} diff --git a/nameservers_unix.go b/nameservers_unix.go index 5c765d3..fd9ebfc 100644 --- a/nameservers_unix.go +++ b/nameservers_unix.go @@ -1,11 +1,7 @@ -//go:build !js && !windows +//go:build unix package ctrld -import ( - "github.com/Control-D-Inc/ctrld/internal/resolvconffile" -) - func nameservers() []string { - return resolvconffile.NameServersWithPort() + return osNameservers() } diff --git a/nameservers_windows.go b/nameservers_windows.go index 7812f2a..1863a6e 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -2,70 +2,55 @@ package ctrld import ( "net" - "os" "syscall" - "unsafe" + + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.org/x/sys/windows" ) func nameservers() []string { - aas, err := adapterAddresses() + aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix) if err != nil { return nil } - ns := make([]string, 0, len(aas)) + ns := make([]string, 0, len(aas)*2) + seen := make(map[string]bool) + do := func(addr windows.SocketAddress) { + sa, err := addr.Sockaddr.Sockaddr() + if err != nil { + return + } + var ip net.IP + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + ip = net.IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]) + case *syscall.SockaddrInet6: + ip = make(net.IP, net.IPv6len) + copy(ip, sa.Addr[:]) + if ip[0] == 0xfe && ip[1] == 0xc0 { + // Ignore these fec0/10 ones. Windows seems to + // populate them as defaults on its misc rando + // interfaces. + return + } + default: + return + + } + if ip.IsLoopback() || seen[ip.String()] { + return + } + seen[ip.String()] = true + ns = append(ns, net.JoinHostPort(ip.String(), "53")) + } for _, aa := range aas { - for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next { - sa, err := dns.Address.Sockaddr.Sockaddr() - if err != nil { - continue - } - var ip net.IP - switch sa := sa.(type) { - case *syscall.SockaddrInet4: - ip = net.IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]) - case *syscall.SockaddrInet6: - ip = make(net.IP, net.IPv6len) - copy(ip, sa.Addr[:]) - if ip[0] == 0xfe && ip[1] == 0xc0 { - // Ignore these fec0/10 ones. Windows seems to - // populate them as defaults on its misc rando - // interfaces. - continue - } - default: - // Unexpected type. - continue - } - ns = append(ns, net.JoinHostPort(ip.String(), "53")) + for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { + do(dns.Address) + } + for gw := aa.FirstGatewayAddress; gw != nil; gw = gw.Next { + do(gw.Address) } } return ns } - -func adapterAddresses() ([]*windows.IpAdapterAddresses, error) { - var b []byte - l := uint32(15000) // recommended initial size - for { - b = make([]byte, l) - err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, windows.GAA_FLAG_INCLUDE_PREFIX, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l) - if err == nil { - if l == 0 { - return nil, nil - } - break - } - if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW { - return nil, os.NewSyscallError("getadaptersaddresses", err) - } - if l <= uint32(len(b)) { - return nil, os.NewSyscallError("getadaptersaddresses", err) - } - } - var aas []*windows.IpAdapterAddresses - for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next { - aas = append(aas, aa) - } - return aas, nil -}