all: use parallel dialer for bootstrapping ip

So we don't have to depend on network probing for checking ipv4/ipv6
enabled, making ctrld working more stably.
This commit is contained in:
Cuong Manh Le
2023-04-25 01:36:51 +07:00
committed by Cuong Manh Le
parent f73cbde7a5
commit 0af7f64bca
7 changed files with 108 additions and 95 deletions

View File

@@ -112,7 +112,7 @@ func resetDNS(iface *net.Interface) (err error) {
}
// TODO(cuonglm): handle DHCPv6 properly.
if ctrldnet.SupportsIPv6() {
if ctrldnet.IPv6Available(ctx) {
c := client6.NewClient()
conversation, err := c.Exchange(iface.Name)
if err != nil {

View File

@@ -74,7 +74,7 @@ func (p *prog) run() {
uc.Init()
if uc.BootstrapIP == "" {
uc.SetupBootstrapIP()
mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n)
mainLog.Info().Msgf("Bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
} else {
mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Using bootstrap IP for upstream.%s", n)
}

View File

@@ -205,6 +205,10 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
return false
}
func (uc *UpstreamConfig) BootstrapIPs() []string {
return uc.bootstrapIPs
}
// SetCertPool sets the system cert pool used for TLS connections.
func (uc *UpstreamConfig) SetCertPool(cp *x509.CertPool) {
uc.certPool = cp
@@ -220,19 +224,6 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
// The first usable IP will be used as bootstrap IP of the upstream.
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
for _, ip := range uc.bootstrapIPs {
if uc.BootstrapIP == "" {
// Remember what's the current IP in bootstrap IPs list,
// so we can select next one upon re-bootstrapping.
uc.nextBootstrapIP.Add(1)
// If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip.
if !ctrldnet.SupportsIPv6() && ctrldnet.IsIPv6(ip) {
continue
}
uc.BootstrapIP = ip
}
}
ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs)
}
@@ -245,32 +236,7 @@ func (uc *UpstreamConfig) ReBootstrap() {
}
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
ProxyLog.Debug().Msg("re-bootstrapping upstream ip")
n := uint32(len(uc.bootstrapIPs))
if n == 0 {
uc.SetupBootstrapIP()
uc.setupTransportWithoutPingUpstream()
}
timeoutMs := 1000
if uc.Timeout > 0 && uc.Timeout < timeoutMs {
timeoutMs = uc.Timeout
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
defer cancel()
hasIPv6 := ctrldnet.IPv6Available(ctx)
// Only attempt n times, because if there's no usable ip,
// the bootstrap ip will be kept as-is.
for i := uint32(0); i < n; i++ {
// Select the next ip in bootstrap ip list.
next := uc.nextBootstrapIP.Add(1)
ip := uc.bootstrapIPs[(next-1)%n]
if !hasIPv6 && ctrldnet.IsIPv6(ip) {
continue
}
uc.BootstrapIP = ip
break
}
uc.BootstrapIP = ""
uc.setupTransportWithoutPingUpstream()
return true, nil
})
@@ -312,18 +278,26 @@ func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
}
dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond
uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: dialerTimeout,
KeepAlive: dialerTimeout,
}
// if we have a bootstrap ip set, use it to avoid DNS lookup
_, port, _ := net.SplitHostPort(addr)
if uc.BootstrapIP != "" {
if _, port, _ := net.SplitHostPort(addr); port != "" {
addr = net.JoinHostPort(uc.BootstrapIP, port)
}
dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout}
addr := net.JoinHostPort(uc.BootstrapIP, port)
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr)
return dialer.DialContext(ctx, network, addr)
}
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr)
return dialer.DialContext(ctx, network, addr)
pd := &ctrldnet.ParallelDialer{}
pd.Timeout = dialerTimeout
pd.KeepAlive = dialerTimeout
addrs := make([]string, len(uc.bootstrapIPs))
for i := range uc.bootstrapIPs {
addrs[i] = net.JoinHostPort(uc.bootstrapIPs[i], port)
}
conn, err := pd.DialContext(ctx, network, addrs)
if err != nil {
return nil, err
}
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", conn.RemoteAddr())
return conn, nil
}
}
@@ -374,21 +348,6 @@ func defaultPortFor(typ string) string {
return "53"
}
func availableNameservers() []string {
nss := nameservers()
n := 0
for _, ns := range nss {
ip, _, _ := net.SplitHostPort(ns)
// skipping invalid entry or ipv6 nameserver if ipv6 not available.
if ip == "" || (ctrldnet.IsIPv6(ip) && !ctrldnet.SupportsIPv6()) {
continue
}
nss[n] = ns
n++
}
return nss[:n]
}
// ResolverTypeFromEndpoint tries guessing the resolver type with a given endpoint
// using following rules:
//

View File

@@ -16,8 +16,8 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
}
uc.Init()
uc.setupBootstrapIP(false)
if uc.BootstrapIP == "" {
t.Log(availableNameservers())
if len(uc.bootstrapIPs) == 0 {
t.Log(nameservers())
t.Fatal("could not bootstrap ip without bootstrap DNS")
}
t.Log(uc)

View File

@@ -5,7 +5,9 @@ package ctrld
import (
"context"
"crypto/tls"
"errors"
"net"
"sync"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
@@ -20,26 +22,91 @@ func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {
rt := &http3.RoundTripper{}
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
host := addr
ProxyLog.Debug().Msgf("debug dial context D0H3 %s - %s", addr, bootstrapDNS)
domain := addr
_, port, _ := net.SplitHostPort(addr)
// if we have a bootstrap ip set, use it to avoid DNS lookup
if uc.BootstrapIP != "" {
if _, port, _ := net.SplitHostPort(addr); port != "" {
addr = net.JoinHostPort(uc.BootstrapIP, port)
}
addr = net.JoinHostPort(uc.BootstrapIP, port)
ProxyLog.Debug().Msgf("sending doh3 request to: %s", addr)
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
}
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
addrs := make([]string, len(uc.bootstrapIPs))
for i := range uc.bootstrapIPs {
addrs[i] = net.JoinHostPort(uc.bootstrapIPs[i], port)
}
pd := &quicParallelDialer{}
conn, err := pd.Dial(ctx, domain, addrs, tlsCfg, cfg)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg)
ProxyLog.Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
return conn, err
}
uc.http3RoundTripper = rt
}
// Putting the code for quic parallel dialer here:
//
// - quic dialer is different with net.Dialer
// - simplification for quic free version
type parallelDialerResult struct {
conn quic.EarlyConnection
err error
}
type quicParallelDialer struct{}
func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
if len(addrs) == 0 {
return nil, errors.New("empty addresses")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
ch := make(chan *parallelDialerResult, len(addrs))
var wg sync.WaitGroup
wg.Add(len(addrs))
go func() {
wg.Wait()
close(ch)
}()
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
for _, addr := range addrs {
go func(addr string) {
defer wg.Done()
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
ch <- &parallelDialerResult{conn: nil, err: err}
return
}
conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
ch <- &parallelDialerResult{conn: conn, err: err}
}(addr)
}
errs := make([]error, 0, len(addrs))
for res := range ch {
if res.err == nil {
cancel()
return res.conn, res.err
}
errs = append(errs, res.err)
}
return nil, errors.Join(errs...)
}

View File

@@ -13,7 +13,6 @@ import (
const (
controldIPv6Test = "ipv6.controld.io"
controldIPv4Test = "ipv4.controld.io"
bootstrapDNS = "76.76.2.0:53"
)
@@ -38,7 +37,6 @@ var probeStackDialer = &net.Dialer{
var (
stackOnce atomic.Pointer[sync.Once]
ipv6Enabled bool
canListenIPv6Local bool
hasNetworkUp bool
)
@@ -47,13 +45,8 @@ func init() {
stackOnce.Store(new(sync.Once))
}
func supportIPv4() bool {
_, err := probeStackDialer.Dial("tcp4", net.JoinHostPort(controldIPv4Test, "80"))
return err == nil
}
func supportIPv6(ctx context.Context) bool {
_, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "80"))
_, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "443"))
return err == nil
}
@@ -75,7 +68,6 @@ func probeStack() {
b.BackOff(context.Background(), err)
}
}
ipv6Enabled = supportIPv6(context.Background())
canListenIPv6Local = supportListenIPv6Local()
}
@@ -84,11 +76,6 @@ func Up() bool {
return hasNetworkUp
}
func SupportsIPv6() bool {
stackOnce.Load().Do(probeStack)
return ipv6Enabled
}
func SupportsIPv6ListenLocal() bool {
stackOnce.Load().Do(probeStack)
return canListenIPv6Local

View File

@@ -131,7 +131,7 @@ func LookupIP(domain string) []string {
}
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
resolver := &osResolver{nameservers: availableNameservers()}
resolver := &osResolver{nameservers: nameservers()}
if withBootstrapDNS {
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
}