diff --git a/cmd/ctrld/os_linux.go b/cmd/ctrld/os_linux.go index 839d99d..307ee3a 100644 --- a/cmd/ctrld/os_linux.go +++ b/cmd/ctrld/os_linux.go @@ -112,7 +112,7 @@ func resetDNS(iface *net.Interface) (err error) { } // TODO(cuonglm): handle DHCPv6 properly. - if ctrldnet.SupportsIPv6() { + if ctrldnet.IPv6Available(ctx) { c := client6.NewClient() conversation, err := c.Exchange(iface.Name) if err != nil { diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index 549d9ff..9aba028 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -74,7 +74,7 @@ func (p *prog) run() { uc.Init() if uc.BootstrapIP == "" { uc.SetupBootstrapIP() - mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n) + mainLog.Info().Msgf("Bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) } else { mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Using bootstrap IP for upstream.%s", n) } diff --git a/config.go b/config.go index c77d73c..7e7bccb 100644 --- a/config.go +++ b/config.go @@ -205,6 +205,10 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool { return false } +func (uc *UpstreamConfig) BootstrapIPs() []string { + return uc.bootstrapIPs +} + // SetCertPool sets the system cert pool used for TLS connections. func (uc *UpstreamConfig) SetCertPool(cp *x509.CertPool) { uc.certPool = cp @@ -220,19 +224,6 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { // The first usable IP will be used as bootstrap IP of the upstream. func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) { uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS) - for _, ip := range uc.bootstrapIPs { - if uc.BootstrapIP == "" { - // Remember what's the current IP in bootstrap IPs list, - // so we can select next one upon re-bootstrapping. - uc.nextBootstrapIP.Add(1) - - // If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip. - if !ctrldnet.SupportsIPv6() && ctrldnet.IsIPv6(ip) { - continue - } - uc.BootstrapIP = ip - } - } ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs) } @@ -245,32 +236,7 @@ func (uc *UpstreamConfig) ReBootstrap() { } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { ProxyLog.Debug().Msg("re-bootstrapping upstream ip") - n := uint32(len(uc.bootstrapIPs)) - if n == 0 { - uc.SetupBootstrapIP() - uc.setupTransportWithoutPingUpstream() - } - - timeoutMs := 1000 - if uc.Timeout > 0 && uc.Timeout < timeoutMs { - timeoutMs = uc.Timeout - } - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) - defer cancel() - - hasIPv6 := ctrldnet.IPv6Available(ctx) - // Only attempt n times, because if there's no usable ip, - // the bootstrap ip will be kept as-is. - for i := uint32(0); i < n; i++ { - // Select the next ip in bootstrap ip list. - next := uc.nextBootstrapIP.Add(1) - ip := uc.bootstrapIPs[(next-1)%n] - if !hasIPv6 && ctrldnet.IsIPv6(ip) { - continue - } - uc.BootstrapIP = ip - break - } + uc.BootstrapIP = "" uc.setupTransportWithoutPingUpstream() return true, nil }) @@ -312,18 +278,26 @@ func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() { } dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - dialer := &net.Dialer{ - Timeout: dialerTimeout, - KeepAlive: dialerTimeout, - } - // if we have a bootstrap ip set, use it to avoid DNS lookup + _, port, _ := net.SplitHostPort(addr) if uc.BootstrapIP != "" { - if _, port, _ := net.SplitHostPort(addr); port != "" { - addr = net.JoinHostPort(uc.BootstrapIP, port) - } + dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout} + addr := net.JoinHostPort(uc.BootstrapIP, port) + Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr) + return dialer.DialContext(ctx, network, addr) } - Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr) - return dialer.DialContext(ctx, network, addr) + pd := &ctrldnet.ParallelDialer{} + pd.Timeout = dialerTimeout + pd.KeepAlive = dialerTimeout + addrs := make([]string, len(uc.bootstrapIPs)) + for i := range uc.bootstrapIPs { + addrs[i] = net.JoinHostPort(uc.bootstrapIPs[i], port) + } + conn, err := pd.DialContext(ctx, network, addrs) + if err != nil { + return nil, err + } + Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", conn.RemoteAddr()) + return conn, nil } } @@ -374,21 +348,6 @@ func defaultPortFor(typ string) string { return "53" } -func availableNameservers() []string { - nss := nameservers() - n := 0 - for _, ns := range nss { - ip, _, _ := net.SplitHostPort(ns) - // skipping invalid entry or ipv6 nameserver if ipv6 not available. - if ip == "" || (ctrldnet.IsIPv6(ip) && !ctrldnet.SupportsIPv6()) { - continue - } - nss[n] = ns - n++ - } - return nss[:n] -} - // ResolverTypeFromEndpoint tries guessing the resolver type with a given endpoint // using following rules: // diff --git a/config_internal_test.go b/config_internal_test.go index 4ec872a..4c67826 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -16,8 +16,8 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { } uc.Init() uc.setupBootstrapIP(false) - if uc.BootstrapIP == "" { - t.Log(availableNameservers()) + if len(uc.bootstrapIPs) == 0 { + t.Log(nameservers()) t.Fatal("could not bootstrap ip without bootstrap DNS") } t.Log(uc) diff --git a/config_quic.go b/config_quic.go index c9b641f..4b9f4c9 100644 --- a/config_quic.go +++ b/config_quic.go @@ -5,7 +5,9 @@ package ctrld import ( "context" "crypto/tls" + "errors" "net" + "sync" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" @@ -20,26 +22,91 @@ func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() { rt := &http3.RoundTripper{} rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - host := addr - ProxyLog.Debug().Msgf("debug dial context D0H3 %s - %s", addr, bootstrapDNS) + domain := addr + _, port, _ := net.SplitHostPort(addr) // if we have a bootstrap ip set, use it to avoid DNS lookup if uc.BootstrapIP != "" { - if _, port, _ := net.SplitHostPort(addr); port != "" { - addr = net.JoinHostPort(uc.BootstrapIP, port) - } + addr = net.JoinHostPort(uc.BootstrapIP, port) ProxyLog.Debug().Msgf("sending doh3 request to: %s", addr) + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, err + } + remoteAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg) } - remoteAddr, err := net.ResolveUDPAddr("udp", addr) + addrs := make([]string, len(uc.bootstrapIPs)) + for i := range uc.bootstrapIPs { + addrs[i] = net.JoinHostPort(uc.bootstrapIPs[i], port) + } + pd := &quicParallelDialer{} + conn, err := pd.Dial(ctx, domain, addrs, tlsCfg, cfg) if err != nil { return nil, err } - - udpConn, err := net.ListenUDP("udp", nil) - if err != nil { - return nil, err - } - return quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg) + ProxyLog.Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr()) + return conn, err } uc.http3RoundTripper = rt } + +// Putting the code for quic parallel dialer here: +// +// - quic dialer is different with net.Dialer +// - simplification for quic free version +type parallelDialerResult struct { + conn quic.EarlyConnection + err error +} + +type quicParallelDialer struct{} + +func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + if len(addrs) == 0 { + return nil, errors.New("empty addresses") + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + ch := make(chan *parallelDialerResult, len(addrs)) + var wg sync.WaitGroup + wg.Add(len(addrs)) + go func() { + wg.Wait() + close(ch) + }() + + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, err + } + + for _, addr := range addrs { + go func(addr string) { + defer wg.Done() + remoteAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + ch <- ¶llelDialerResult{conn: nil, err: err} + return + } + + conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg) + ch <- ¶llelDialerResult{conn: conn, err: err} + }(addr) + } + + errs := make([]error, 0, len(addrs)) + for res := range ch { + if res.err == nil { + cancel() + return res.conn, res.err + } + errs = append(errs, res.err) + } + + return nil, errors.Join(errs...) +} diff --git a/internal/net/net.go b/internal/net/net.go index e64a908..1c43bbb 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -13,7 +13,6 @@ import ( const ( controldIPv6Test = "ipv6.controld.io" - controldIPv4Test = "ipv4.controld.io" bootstrapDNS = "76.76.2.0:53" ) @@ -38,7 +37,6 @@ var probeStackDialer = &net.Dialer{ var ( stackOnce atomic.Pointer[sync.Once] - ipv6Enabled bool canListenIPv6Local bool hasNetworkUp bool ) @@ -47,13 +45,8 @@ func init() { stackOnce.Store(new(sync.Once)) } -func supportIPv4() bool { - _, err := probeStackDialer.Dial("tcp4", net.JoinHostPort(controldIPv4Test, "80")) - return err == nil -} - func supportIPv6(ctx context.Context) bool { - _, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "80")) + _, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "443")) return err == nil } @@ -75,7 +68,6 @@ func probeStack() { b.BackOff(context.Background(), err) } } - ipv6Enabled = supportIPv6(context.Background()) canListenIPv6Local = supportListenIPv6Local() } @@ -84,11 +76,6 @@ func Up() bool { return hasNetworkUp } -func SupportsIPv6() bool { - stackOnce.Load().Do(probeStack) - return ipv6Enabled -} - func SupportsIPv6ListenLocal() bool { stackOnce.Load().Do(probeStack) return canListenIPv6Local diff --git a/resolver.go b/resolver.go index a1b8efa..084281c 100644 --- a/resolver.go +++ b/resolver.go @@ -131,7 +131,7 @@ func LookupIP(domain string) []string { } func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { - resolver := &osResolver{nameservers: availableNameservers()} + resolver := &osResolver{nameservers: nameservers()} if withBootstrapDNS { resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) }