diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index c052f44..96e264b 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -1278,8 +1278,9 @@ func initUpgradeCmd() *cobra.Command { dlUrl := upgradeUrl(baseUrl) mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl) - resp, err := getWithRetry(dlUrl) + resp, err := getWithRetry(dlUrl, downloadServerIp) if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to download binary") } defer resp.Body.Close() diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index eecfd6d..6a214e5 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -519,13 +519,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") return nil, err } - resolveCtx, cancel := context.WithCancel(ctx) + resolveCtx, cancel := upstreamConfig.Context(ctx) defer cancel() - if upstreamConfig.Timeout > 0 { - timeoutCtx, cancel := context.WithTimeout(resolveCtx, time.Millisecond*time.Duration(upstreamConfig.Timeout)) - defer cancel() - resolveCtx = timeoutCtx - } return dnsResolver.Resolve(resolveCtx, msg) } resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { diff --git a/cmd/cli/library.go b/cmd/cli/library.go index a5ba389..3c1db1b 100644 --- a/cmd/cli/library.go +++ b/cmd/cli/library.go @@ -28,6 +28,7 @@ type AppConfig struct { const ( defaultHTTPTimeout = 30 * time.Second defaultMaxRetries = 3 + downloadServerIp = "23.171.240.151" ) // httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback @@ -46,10 +47,15 @@ func httpClientWithFallback(timeout time.Duration) *http.Client { } // doWithRetry performs an HTTP request with retries -func doWithRetry(req *http.Request, maxRetries int) (*http.Response, error) { +func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) { var lastErr error client := httpClientWithFallback(defaultHTTPTimeout) - + var ipReq *http.Request + if ip != "" { + ipReq = req.Clone(req.Context()) + ipReq.Host = ip + ipReq.URL.Host = ip + } for attempt := 0; attempt < maxRetries; attempt++ { if attempt > 0 { time.Sleep(time.Second * time.Duration(attempt+1)) // Exponential backoff @@ -59,6 +65,15 @@ func doWithRetry(req *http.Request, maxRetries int) (*http.Response, error) { if err == nil { return resp, nil } + if ipReq != nil { + mainLog.Load().Warn().Err(err).Msgf("dial to %q failed", req.Host) + mainLog.Load().Warn().Msgf("fallback to direct IP to download prod version: %q", ip) + resp, err = client.Do(ipReq) + if err == nil { + return resp, nil + } + } + lastErr = err mainLog.Load().Debug().Err(err). Str("method", req.Method). @@ -69,10 +84,10 @@ func doWithRetry(req *http.Request, maxRetries int) (*http.Response, error) { } // Helper for making GET requests with retries -func getWithRetry(url string) (*http.Response, error) { +func getWithRetry(url string, ip string) (*http.Response, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, err } - return doWithRetry(req, defaultMaxRetries) + return doWithRetry(req, defaultMaxRetries, ip) } diff --git a/config.go b/config.go index 4ace9f1..f208f0d 100644 --- a/config.go +++ b/config.go @@ -53,10 +53,27 @@ const ( FreeDnsDomain = "freedns.controld.com" // FreeDNSBoostrapIP is the IP address of freedns.controld.com. FreeDNSBoostrapIP = "76.76.2.11" + // FreeDNSBoostrapIPv6 is the IPv6 address of freedns.controld.com. + FreeDNSBoostrapIPv6 = "2606:1a40::11" // PremiumDnsDomain is the domain name of premium ControlD service. PremiumDnsDomain = "dns.controld.com" // PremiumDNSBoostrapIP is the IP address of dns.controld.com. PremiumDNSBoostrapIP = "76.76.2.22" + // PremiumDNSBoostrapIPv6 is the IPv6 address of dns.controld.com. + PremiumDNSBoostrapIPv6 = "2606:1a40::22" + + // freeDnsDomainDev is the domain name of free ControlD service on dev env. + freeDnsDomainDev = "freedns.controld.dev" + // freeDNSBoostrapIP is the IP address of freedns.controld.dev. + freeDNSBoostrapIP = "176.125.239.11" + // freeDNSBoostrapIPv6 is the IPv6 address of freedns.controld.com. + freeDNSBoostrapIPv6 = "2606:1a40:f000::11" + // premiumDnsDomainDev is the domain name of premium ControlD service on dev env. + premiumDnsDomainDev = "dns.controld.dev" + // premiumDNSBoostrapIP is the IP address of dns.controld.dev. + premiumDNSBoostrapIP = "176.125.239.22" + // premiumDNSBoostrapIPv6 is the IPv6 address of dns.controld.dev. + premiumDNSBoostrapIPv6 = "2606:1a40:f000::22" controlDComDomain = "controld.com" controlDNetDomain = "controld.net" @@ -261,6 +278,7 @@ type UpstreamConfig struct { http3RoundTripper6 http.RoundTripper certPool *x509.CertPool u *url.URL + fallbackOnce sync.Once uid string } @@ -426,6 +444,10 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { } } uc.bootstrapIPs = uc.bootstrapIPs[:n] + if len(uc.bootstrapIPs) == 0 { + uc.bootstrapIPs = bootstrapIPsFromControlDDomain(uc.Domain) + ProxyLogger.Load().Warn().Msgf("no bootstrap IPs found for %q, fallback to direct IPs", uc.Domain) + } } if len(uc.bootstrapIPs) > 0 { break @@ -538,7 +560,10 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { // Ping warms up the connection to DoH/DoH3 upstream. func (uc *UpstreamConfig) Ping() { - _ = uc.ping() + if err := uc.ping(); err != nil { + ProxyLogger.Load().Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint) + _ = uc.FallbackToDirectIP() + } } // ErrorPing is like Ping, but return an error if any. @@ -575,7 +600,6 @@ func (uc *UpstreamConfig) ping() error { for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} { switch uc.Type { case ResolverTypeDOH: - if err := ping(uc.dohTransport(typ)); err != nil { return err } @@ -743,6 +767,41 @@ func (uc *UpstreamConfig) initDnsStamps() error { return nil } +// Context returns a new context with timeout set from upstream config. +func (uc *UpstreamConfig) Context(ctx context.Context) (context.Context, context.CancelFunc) { + if uc.Timeout > 0 { + return context.WithTimeout(ctx, time.Millisecond*time.Duration(uc.Timeout)) + } + return context.WithCancel(ctx) +} + +// FallbackToDirectIP changes ControlD upstream endpoint to use direct IP instead of domain. +func (uc *UpstreamConfig) FallbackToDirectIP() bool { + if !uc.IsControlD() { + return false + } + if uc.u == nil || uc.Domain == "" { + return false + } + + done := false + uc.fallbackOnce.Do(func() { + var ip string + switch { + case dns.IsSubDomain(PremiumDnsDomain, uc.Domain): + ip = PremiumDNSBoostrapIP + case dns.IsSubDomain(FreeDnsDomain, uc.Domain): + ip = FreeDNSBoostrapIP + default: + return + } + ProxyLogger.Load().Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip) + uc.u.Host = ip + done = true + }) + return done +} + // Init initialized necessary values for an ListenerConfig. func (lc *ListenerConfig) Init() { if lc.Policy != nil { @@ -889,3 +948,18 @@ func (uc *UpstreamConfig) String() string { return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}", uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack) } + +// bootstrapIPsFromControlDDomain returns bootstrap IPs for ControlD domain. +func bootstrapIPsFromControlDDomain(domain string) []string { + switch domain { + case PremiumDnsDomain: + return []string{PremiumDNSBoostrapIP, PremiumDNSBoostrapIPv6} + case FreeDnsDomain: + return []string{FreeDNSBoostrapIP, FreeDNSBoostrapIPv6} + case premiumDnsDomainDev: + return []string{premiumDNSBoostrapIP, premiumDNSBoostrapIPv6} + case freeDnsDomainDev: + return []string{freeDNSBoostrapIP, freeDNSBoostrapIPv6} + } + return nil +} diff --git a/doh.go b/doh.go index d702995..73b2764 100644 --- a/doh.go +++ b/doh.go @@ -113,6 +113,12 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro c.Transport = transport } resp, err := c.Do(req) + if err != nil && r.uc.FallbackToDirectIP() { + retryCtx, cancel := r.uc.Context(context.WithoutCancel(ctx)) + defer cancel() + Log(ctx, ProxyLogger.Load().Warn().Err(err), "retrying request after fallback to direct ip") + resp, err = c.Do(req.Clone(retryCtx)) + } if err != nil { if r.isDoH3 { if closer, ok := c.Transport.(io.Closer); ok { diff --git a/internal/controld/config.go b/internal/controld/config.go index 5e65fdb..23542c7 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -24,7 +24,10 @@ import ( const ( apiDomainCom = "api.controld.com" + apiDomainComIPv4 = "147.185.34.1" + apiDomainComIPv6 = "2606:1a40:3::1" apiDomainDev = "api.controld.dev" + apiDomainDevIPv4 = "23.171.240.84" apiURLCom = "https://api.controld.com" apiURLDev = "https://api.controld.dev" resolverDataURLCom = apiURLCom + "/utility" @@ -136,11 +139,11 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/json") transport := apiTransport(cdDev) - client := http.Client{ + client := &http.Client{ Timeout: defaultTimeout, Transport: transport, } - resp, err := client.Do(req) + resp, err := doWithFallback(client, req, apiServerIP(cdDev)) if err != nil { return nil, fmt.Errorf("postUtilityAPI client.Do: %w", err) } @@ -177,11 +180,11 @@ func SendLogs(lr *LogsRequest, cdDev bool) error { req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/x-www-form-urlencoded") transport := apiTransport(cdDev) - client := http.Client{ + client := &http.Client{ Timeout: sendLogTimeout, Transport: transport, } - resp, err := client.Do(req) + resp, err := doWithFallback(client, req, apiServerIP(cdDev)) if err != nil { return fmt.Errorf("SendLogs client.Do: %w", err) } @@ -213,20 +216,20 @@ func apiTransport(cdDev bool) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { apiDomain := apiDomainCom + apiIpsV4 := []string{apiDomainComIPv4} + apiIpsV6 := []string{apiDomainComIPv6} + apiIPs := []string{apiDomainComIPv4, apiDomainComIPv6} if cdDev { apiDomain = apiDomainDev - } - - // First try IPv4 - dialer := &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, + apiIpsV4 = []string{apiDomainDevIPv4} + apiIpsV6 = []string{} + apiIPs = []string{apiDomainDevIPv4} } ips := ctrld.LookupIP(apiDomain) if len(ips) == 0 { - ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, falling back to direct connection to %s", apiDomain, addr) - return dialer.DialContext(ctx, network, addr) + ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, use direct IPs: %v", apiDomain, apiIPs) + ips = apiIPs } // Separate IPv4 and IPv6 addresses @@ -239,35 +242,62 @@ func apiTransport(cdDev bool) *http.Transport { } } + dial := func(ctx context.Context, network string, addrs []string) (net.Conn, error) { + d := &ctrldnet.ParallelDialer{} + return d.DialContext(ctx, network, addrs, ctrld.ProxyLogger.Load()) + } _, port, _ := net.SplitHostPort(addr) // Try IPv4 first if len(ipv4s) > 0 { - addrs := make([]string, len(ipv4s)) - for i, ip := range ipv4s { - addrs[i] = net.JoinHostPort(ip, port) - } - d := &ctrldnet.ParallelDialer{} - if conn, err := d.DialContext(ctx, "tcp4", addrs, ctrld.ProxyLogger.Load()); err == nil { + if conn, err := dial(ctx, "tcp4", addrsFromPort(ipv4s, port)); err == nil { return conn, nil } } - - // Fall back to IPv6 if available - if len(ipv6s) > 0 { - addrs := make([]string, len(ipv6s)) - for i, ip := range ipv6s { - addrs[i] = net.JoinHostPort(ip, port) - } - d := &ctrldnet.ParallelDialer{} - return d.DialContext(ctx, "tcp6", addrs, ctrld.ProxyLogger.Load()) + // Fallback to direct IPv4 + if conn, err := dial(ctx, "tcp4", addrsFromPort(apiIpsV4, port)); err == nil { + return conn, nil } - // Final fallback to direct connection - return dialer.DialContext(ctx, network, addr) + // Fallback to IPv6 if available + if len(ipv6s) > 0 { + if conn, err := dial(ctx, "tcp6", addrsFromPort(ipv6s, port)); err == nil { + return conn, nil + } + } + // Fallback to direct IPv6 + return dial(ctx, "tcp6", addrsFromPort(apiIpsV6, port)) } if router.Name() == ddwrt.Name || runtime.GOOS == "android" { transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()} } return transport } + +func addrsFromPort(ips []string, port string) []string { + addrs := make([]string, len(ips)) + for i, ip := range ips { + addrs[i] = net.JoinHostPort(ip, port) + } + return addrs +} + +func doWithFallback(client *http.Client, req *http.Request, apiIp string) (*http.Response, error) { + resp, err := client.Do(req) + if err != nil { + ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to send request, fallback to direct IP: %s", apiIp) + ipReq := req.Clone(req.Context()) + ipReq.Host = apiIp + ipReq.URL.Host = apiIp + resp, err = client.Do(ipReq) + } + return resp, err +} + +// apiServerIP returns the direct IP to connect to API server. +func apiServerIP(cdDev bool) string { + if cdDev { + return apiDomainDevIPv4 + } + return apiDomainComIPv4 +}