From fa50cd4df4bb2b83227b1f0cde3780962c6c7c35 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 8 Mar 2023 11:38:46 +0700 Subject: [PATCH] all: another rework on discovering bootstrap IPs Instead of re-query DNS record for upstream when re-bootstrapping, just query all records on startup, then selecting the next bootstrap ip depends on the current network stack. --- cmd/ctrld/prog.go | 8 ++- config.go | 136 +++++++++++++++++++++++++------------------- internal/net/net.go | 40 ++++++++----- 3 files changed, 108 insertions(+), 76 deletions(-) diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index 8a1dd6e..6b58116 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -64,10 +64,12 @@ func (p *prog) run() { for n := range p.cfg.Upstream { uc := p.cfg.Upstream[n] uc.Init() - if err := uc.SetupBootstrapIP(); err != nil { - mainLog.Fatal().Err(err).Msgf("failed to setup bootstrap IP for upstream.%s", n) + if uc.BootstrapIP == "" { + uc.SetupBootstrapIP() + mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n) + } else { + mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Using bootstrap IP for upstream.%s", n) } - mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n) uc.SetupTransport() } diff --git a/config.go b/config.go index f35a7ff..84e2f40 100644 --- a/config.go +++ b/config.go @@ -2,13 +2,12 @@ package ctrld import ( "context" - "errors" "net" "net/http" "net/url" "os" "strings" - "sync" + "sync/atomic" "time" "github.com/go-playground/validator/v10" @@ -106,9 +105,9 @@ type UpstreamConfig struct { transport *http.Transport `mapstructure:"-" toml:"-"` http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"` - g singleflight.Group - // guard BootstrapIP - mu sync.Mutex + g singleflight.Group + bootstrapIPs []string + nextBootstrapIP atomic.Uint32 } // ListenerConfig specifies the networks configuration that ctrld will run on. @@ -153,19 +152,85 @@ func (uc *UpstreamConfig) Init() { } } +// SetupBootstrapIP manually find all available IPs of the upstream. +// The first usable IP will be used as bootstrap IP of the upstream. +func (uc *UpstreamConfig) SetupBootstrapIP() { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(uc.Timeout)*time.Millisecond) + defer cancel() + + c := new(dns.Client) + bootstrapIP := func(record dns.RR) string { + switch ar := record.(type) { + case *dns.A: + return ar.A.String() + case *dns.AAAA: + return ar.AAAA.String() + } + return "" + } + + // Find all A, AAAA records of the upstream. + for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} { + m := new(dns.Msg) + m.SetQuestion(uc.Domain+".", dnsType) + m.RecursionDesired = true + r, _, err := c.ExchangeContext(ctx, m, net.JoinHostPort(bootstrapDNS, "53")) + if err != nil { + ProxyLog.Error().Err(err).Str("type", dns.TypeToString[dnsType]).Msgf("could not resolve domain %s for upstream", uc.Domain) + continue + } + if r.Rcode != dns.RcodeSuccess { + ProxyLog.Error().Msgf("could not resolve domain return code: %d, upstream", r.Rcode) + continue + } + if len(r.Answer) == 0 { + ProxyLog.Error().Msg("no answer from bootstrap DNS server") + continue + } + for _, a := range r.Answer { + ip := bootstrapIP(a) + if ip == "" { + continue + } + + // Storing the ip to uc.bootstrapIPs list, so it can be selected later + // when retrying failed request due to network stack changed. + uc.bootstrapIPs = append(uc.bootstrapIPs, ip) + if uc.BootstrapIP == "" { + // Remember what's the current IP in bootstrap IPs list, + // so we can select next one upon re-bootstrapping. + uc.nextBootstrapIP.Add(1) + + // If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip. + if !ctrldnet.IPv6Available() && ctrldnet.IsIPv6(ip) { + continue + } + uc.BootstrapIP = ip + } + } + } + ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs) +} + // ReBootstrap re-setup the bootstrap IP and the transport. func (uc *UpstreamConfig) ReBootstrap() { _, _, _ = uc.g.Do("rebootstrap", func() (any, error) { ProxyLog.Debug().Msg("re-bootstrapping upstream ip") - ctrldnet.Reset() - err := uc.SetupBootstrapIP() - if err != nil { - ProxyLog.Error().Err(err).Msg("re-bootstrapping failed") - } else { - ProxyLog.Debug().Msgf("bootstrap ip set to: %s", uc.BootstrapIP) + n := uint32(len(uc.bootstrapIPs)) + // Only attempt n times, because if there's no usable ip, + // the bootstrap ip will be kept as-is. + for i := uint32(0); i < n; i++ { + // Select the next ip in bootstrap ip list. + next := uc.nextBootstrapIP.Add(1) + ip := uc.bootstrapIPs[(next-1)%n] + if !ctrldnet.IPv6Available() && ctrldnet.IsIPv6(ip) { + continue + } + uc.BootstrapIP = ip + break } uc.SetupTransport() - return err == nil, err + return true, nil }) } @@ -180,53 +245,6 @@ func (uc *UpstreamConfig) SetupTransport() { } } -// SetupBootstrapIP manually find all available IPs of the upstream. -func (uc *UpstreamConfig) SetupBootstrapIP() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(uc.Timeout)*time.Millisecond) - defer cancel() - - uc.mu.Lock() - defer uc.mu.Unlock() - - c := new(dns.Client) - m := new(dns.Msg) - dnsType := dns.TypeA - if ctrldnet.SupportsIPv6() { - dnsType = dns.TypeAAAA - } - m.SetQuestion(uc.Domain+".", dnsType) - m.RecursionDesired = true - r, _, err := c.ExchangeContext(ctx, m, net.JoinHostPort(bootstrapDNS, "53")) - if err != nil { - ProxyLog.Error().Err(err).Msgf("could not resolve domain %s for upstream", uc.Domain) - return err - } - if r.Rcode != dns.RcodeSuccess { - ProxyLog.Error().Msgf("could not resolve domain return code: %d, upstream", r.Rcode) - return errors.New(dns.RcodeToString[r.Rcode]) - } - if len(r.Answer) == 0 { - return errors.New("no answer from bootstrap DNS server") - } - - bootstrapIP := func(record dns.RR) string { - switch ar := record.(type) { - case *dns.A: - return ar.A.String() - case *dns.AAAA: - return ar.AAAA.String() - } - return "" - } - for _, a := range r.Answer { - if ip := bootstrapIP(a); ip != "" { - uc.BootstrapIP = ip - break - } - } - return nil -} - func (uc *UpstreamConfig) setupDOHTransport() { uc.transport = http.DefaultTransport.(*http.Transport).Clone() uc.transport.IdleConnTimeout = 5 * time.Second diff --git a/internal/net/net.go b/internal/net/net.go index 1488da9..6c78586 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -40,6 +40,24 @@ func init() { stackOnce.Store(new(sync.Once)) } +func supportIPv4() bool { + _, err := Dialer.Dial("tcp4", net.JoinHostPort(controldIPv4Test, "80")) + return err == nil +} + +func supportIPv6() bool { + _, err := Dialer.Dial("tcp6", net.JoinHostPort(controldIPv6Test, "80")) + return err == nil +} + +func supportListenIPv6Local() bool { + if ln, err := net.Listen("tcp6", "[::1]:0"); err == nil { + ln.Close() + return true + } + return false +} + func probeStack() { b := backoff.NewBackoff("probeStack", func(format string, args ...any) {}, time.Minute) for { @@ -50,20 +68,9 @@ func probeStack() { b.BackOff(context.Background(), err) } } - if _, err := Dialer.Dial("tcp4", net.JoinHostPort(controldIPv4Test, "80")); err == nil { - ipv4Enabled = true - } - if _, err := Dialer.Dial("tcp6", net.JoinHostPort(controldIPv6Test, "80")); err == nil { - ipv6Enabled = true - } - if ln, err := net.Listen("tcp6", "[::1]:53"); err == nil { - ln.Close() - canListenIPv6Local = true - } -} - -func Reset() { - stackOnce.Store(new(sync.Once)) + ipv4Enabled = supportIPv4() + ipv6Enabled = supportIPv6() + canListenIPv6Local = supportListenIPv6Local() } func Up() bool { @@ -86,6 +93,11 @@ func SupportsIPv6ListenLocal() bool { return canListenIPv6Local } +// IPv6Available is like SupportsIPv6, but always do the check without caching. +func IPv6Available() bool { + return supportIPv6() +} + // IsIPv6 checks if the provided IP is v6. // //lint:ignore U1000 use in os_windows.go