diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 1c7cbd5..5f18933 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -78,7 +78,7 @@ var rootCmd = &cobra.Command{ } func curVersion() string { - if version != "dev" { + if version != "dev" && !strings.HasPrefix(version, "v") { version = "v" + version } if len(commit) > 7 { diff --git a/config.go b/config.go index b30d8a4..7afaac9 100644 --- a/config.go +++ b/config.go @@ -177,71 +177,20 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { // SetupBootstrapIP manually find all available IPs of the upstream. // The first usable IP will be used as bootstrap IP of the upstream. func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) { - bootstrapIP := func(record dns.RR) string { - switch ar := record.(type) { - case *dns.A: - return ar.A.String() - case *dns.AAAA: - return ar.AAAA.String() - } - return "" - } + uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS) + for _, ip := range uc.bootstrapIPs { + if uc.BootstrapIP == "" { + // Remember what's the current IP in bootstrap IPs list, + // so we can select next one upon re-bootstrapping. + uc.nextBootstrapIP.Add(1) - resolver := &osResolver{nameservers: availableNameservers()} - if withBootstrapDNS { - resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) - } - ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", uc.Domain, resolver.nameservers) - timeoutMs := 2000 - if uc.Timeout > 0 && uc.Timeout < timeoutMs { - timeoutMs = uc.Timeout - } - do := func(dnsType uint16) { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) - defer cancel() - m := new(dns.Msg) - m.SetQuestion(uc.Domain+".", dnsType) - m.RecursionDesired = true - - r, err := resolver.Resolve(ctx, m) - if err != nil { - ProxyLog.Error().Err(err).Str("type", dns.TypeToString[dnsType]).Msgf("could not resolve domain %s for upstream", uc.Domain) - return - } - if r.Rcode != dns.RcodeSuccess { - ProxyLog.Error().Msgf("could not resolve domain %q, return code: %s", uc.Domain, dns.RcodeToString[r.Rcode]) - return - } - if len(r.Answer) == 0 { - ProxyLog.Error().Msg("no answer from bootstrap DNS server") - return - } - for _, a := range r.Answer { - ip := bootstrapIP(a) - if ip == "" { + // If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip. + if !ctrldnet.SupportsIPv6() && ctrldnet.IsIPv6(ip) { continue } - - // Storing the ip to uc.bootstrapIPs list, so it can be selected later - // when retrying failed request due to network stack changed. - uc.bootstrapIPs = append(uc.bootstrapIPs, ip) - if uc.BootstrapIP == "" { - // Remember what's the current IP in bootstrap IPs list, - // so we can select next one upon re-bootstrapping. - uc.nextBootstrapIP.Add(1) - - // If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip. - if !ctrldnet.SupportsIPv6() && ctrldnet.IsIPv6(ip) { - continue - } - uc.BootstrapIP = ip - } + uc.BootstrapIP = ip } } - // Find all A, AAAA records of the upstream. - for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} { - do(dnsType) - } ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs) } diff --git a/internal/controld/config.go b/internal/controld/config.go index 7092d51..48efef9 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -8,11 +8,8 @@ import ( "fmt" "net" "net/http" - "sync" "time" - "github.com/miekg/dns" - "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/certs" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" @@ -25,11 +22,6 @@ const ( InvalidConfigCode = 40401 ) -var ( - resolveAPIDomainOnce sync.Once - apiDomainIP string -) - // ResolverConfig represents Control D resolver data. type ResolverConfig struct { DOH string `json:"doh"` @@ -71,51 +63,19 @@ func FetchResolverConfig(uid string) (*ResolverConfig, error) { req.Header.Add("Content-Type", "application/json") transport := http.DefaultTransport.(*http.Transport).Clone() transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - // We experiment hanging in TLS handshake when connecting to ControlD API - // with ipv6. So prefer ipv4 if available. - proto := "tcp6" - if ctrldnet.SupportsIPv4() { - proto = "tcp4" + ips := ctrld.LookupIP(apiDomain) + if len(ips) == 0 { + ctrld.ProxyLog.Warn().Msgf("No IPs found for %s, connecting to %s", apiDomain, addr) + return ctrldnet.Dialer.DialContext(ctx, network, addr) } - resolveAPIDomainOnce.Do(func() { - r, err := ctrld.NewResolver(&ctrld.UpstreamConfig{Type: ctrld.ResolverTypeOS}) - if err != nil { - return - } - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - msg := new(dns.Msg) - dnsType := dns.TypeAAAA - if proto == "tcp4" { - dnsType = dns.TypeA - } - msg.SetQuestion(apiDomain+".", dnsType) - msg.RecursionDesired = true - answer, err := r.Resolve(ctx, msg) - if err != nil { - return - } - if answer.Rcode != dns.RcodeSuccess || len(answer.Answer) == 0 { - return - } - for _, record := range answer.Answer { - switch ar := record.(type) { - case *dns.A: - apiDomainIP = ar.A.String() - return - case *dns.AAAA: - apiDomainIP = ar.AAAA.String() - return - } - } - }) - if apiDomainIP != "" { - if _, port, _ := net.SplitHostPort(addr); port != "" { - return ctrldnet.Dialer.DialContext(ctx, proto, net.JoinHostPort(apiDomainIP, port)) - } + ctrld.ProxyLog.Debug().Msgf("API IPs: %v", ips) + _, port, _ := net.SplitHostPort(addr) + addrs := make([]string, len(ips)) + for i := range ips { + addrs[i] = net.JoinHostPort(ips[i], port) } - return ctrldnet.Dialer.DialContext(ctx, proto, addr) + d := &ctrldnet.ParallelDialer{} + return d.DialContext(ctx, network, addrs) } if router.Name() == router.DDWrt { diff --git a/internal/controld/config_test.go b/internal/controld/config_test.go index cd6ea06..13d937a 100644 --- a/internal/controld/config_test.go +++ b/internal/controld/config_test.go @@ -9,8 +9,6 @@ import ( "github.com/stretchr/testify/require" ) -const utilityURL = "https://api.controld.com/utility" - func TestFetchResolverConfig(t *testing.T) { tests := []struct { name string diff --git a/internal/net/net.go b/internal/net/net.go index 4e71206..e64a908 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -2,6 +2,7 @@ package net import ( "context" + "errors" "net" "sync" "sync/atomic" @@ -37,7 +38,6 @@ var probeStackDialer = &net.Dialer{ var ( stackOnce atomic.Pointer[sync.Once] - ipv4Enabled bool ipv6Enabled bool canListenIPv6Local bool hasNetworkUp bool @@ -75,7 +75,6 @@ func probeStack() { b.BackOff(context.Background(), err) } } - ipv4Enabled = supportIPv4() ipv6Enabled = supportIPv6(context.Background()) canListenIPv6Local = supportListenIPv6Local() } @@ -85,11 +84,6 @@ func Up() bool { return hasNetworkUp } -func SupportsIPv4() bool { - stackOnce.Load().Do(probeStack) - return ipv4Enabled -} - func SupportsIPv6() bool { stackOnce.Load().Do(probeStack) return ipv6Enabled @@ -112,3 +106,47 @@ func IsIPv6(ip string) bool { parsedIP := net.ParseIP(ip) return parsedIP != nil && parsedIP.To4() == nil && parsedIP.To16() != nil } + +type parallelDialerResult struct { + conn net.Conn + err error +} + +type ParallelDialer struct { + net.Dialer +} + +func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string) (net.Conn, error) { + if len(addrs) == 0 { + return nil, errors.New("empty addresses") + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + ch := make(chan *parallelDialerResult, len(addrs)) + var wg sync.WaitGroup + wg.Add(len(addrs)) + go func() { + wg.Wait() + close(ch) + }() + + for _, addr := range addrs { + go func(addr string) { + defer wg.Done() + conn, err := d.Dialer.DialContext(ctx, network, addr) + ch <- ¶llelDialerResult{conn: conn, err: err} + }(addr) + } + + errs := make([]error, 0, len(addrs)) + for res := range ch { + if res.err == nil { + cancel() + return res.conn, res.err + } + errs = append(errs, res.err) + } + + return nil, errors.Join(errs...) +} diff --git a/resolver.go b/resolver.go index 77517e6..a1b8efa 100644 --- a/resolver.go +++ b/resolver.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sync" + "time" "github.com/miekg/dns" ) @@ -79,7 +80,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error for _, server := range o.nameservers { go func(server string) { defer wg.Done() - answer, _, err := dnsClient.ExchangeContext(ctx, msg, server) + answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) ch <- &osResolverResult{answer: answer, err: err} }(server) } @@ -122,3 +123,62 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.endpoint) return answer, err } + +// LookupIP looks up host using OS resolver. +// It returns a slice of that host's IPv4 and IPv6 addresses. +func LookupIP(domain string) []string { + return lookupIP(domain, -1, true) +} + +func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { + resolver := &osResolver{nameservers: availableNameservers()} + if withBootstrapDNS { + resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + } + ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", domain, resolver.nameservers) + timeoutMs := 2000 + if timeout > 0 && timeout < timeoutMs { + timeoutMs = timeoutMs + } + ipFromRecord := func(record dns.RR) string { + switch ar := record.(type) { + case *dns.A: + return ar.A.String() + case *dns.AAAA: + return ar.AAAA.String() + } + return "" + } + + lookup := func(dnsType uint16) { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) + defer cancel() + m := new(dns.Msg) + m.SetQuestion(domain+".", dnsType) + m.RecursionDesired = true + + r, err := resolver.Resolve(ctx, m) + if err != nil { + ProxyLog.Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain) + return + } + if r.Rcode != dns.RcodeSuccess { + ProxyLog.Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode]) + return + } + if len(r.Answer) == 0 { + ProxyLog.Error().Msg("no answer from OS resolver") + return + } + for _, a := range r.Answer { + if ip := ipFromRecord(a); ip != "" { + ips = append(ips, ip) + } + } + } + // Find all A, AAAA records of the domain. + for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} { + lookup(dnsType) + } + return ips +}