all: another rework on discovering bootstrap IPs

Instead of re-query DNS record for upstream when re-bootstrapping, just
query all records on startup, then selecting the next bootstrap ip
depends on the current network stack.
This commit is contained in:
Cuong Manh Le
2023-03-08 11:38:46 +07:00
committed by Cuong Manh Le
parent 018f6651c1
commit fa50cd4df4
3 changed files with 108 additions and 76 deletions

View File

@@ -64,10 +64,12 @@ func (p *prog) run() {
for n := range p.cfg.Upstream {
uc := p.cfg.Upstream[n]
uc.Init()
if err := uc.SetupBootstrapIP(); err != nil {
mainLog.Fatal().Err(err).Msgf("failed to setup bootstrap IP for upstream.%s", n)
if uc.BootstrapIP == "" {
uc.SetupBootstrapIP()
mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n)
} else {
mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Using bootstrap IP for upstream.%s", n)
}
mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n)
uc.SetupTransport()
}

136
config.go
View File

@@ -2,13 +2,12 @@ package ctrld
import (
"context"
"errors"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/go-playground/validator/v10"
@@ -106,9 +105,9 @@ type UpstreamConfig struct {
transport *http.Transport `mapstructure:"-" toml:"-"`
http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"`
g singleflight.Group
// guard BootstrapIP
mu sync.Mutex
g singleflight.Group
bootstrapIPs []string
nextBootstrapIP atomic.Uint32
}
// ListenerConfig specifies the networks configuration that ctrld will run on.
@@ -153,19 +152,85 @@ func (uc *UpstreamConfig) Init() {
}
}
// SetupBootstrapIP manually find all available IPs of the upstream.
// The first usable IP will be used as bootstrap IP of the upstream.
func (uc *UpstreamConfig) SetupBootstrapIP() {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(uc.Timeout)*time.Millisecond)
defer cancel()
c := new(dns.Client)
bootstrapIP := func(record dns.RR) string {
switch ar := record.(type) {
case *dns.A:
return ar.A.String()
case *dns.AAAA:
return ar.AAAA.String()
}
return ""
}
// Find all A, AAAA records of the upstream.
for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} {
m := new(dns.Msg)
m.SetQuestion(uc.Domain+".", dnsType)
m.RecursionDesired = true
r, _, err := c.ExchangeContext(ctx, m, net.JoinHostPort(bootstrapDNS, "53"))
if err != nil {
ProxyLog.Error().Err(err).Str("type", dns.TypeToString[dnsType]).Msgf("could not resolve domain %s for upstream", uc.Domain)
continue
}
if r.Rcode != dns.RcodeSuccess {
ProxyLog.Error().Msgf("could not resolve domain return code: %d, upstream", r.Rcode)
continue
}
if len(r.Answer) == 0 {
ProxyLog.Error().Msg("no answer from bootstrap DNS server")
continue
}
for _, a := range r.Answer {
ip := bootstrapIP(a)
if ip == "" {
continue
}
// Storing the ip to uc.bootstrapIPs list, so it can be selected later
// when retrying failed request due to network stack changed.
uc.bootstrapIPs = append(uc.bootstrapIPs, ip)
if uc.BootstrapIP == "" {
// Remember what's the current IP in bootstrap IPs list,
// so we can select next one upon re-bootstrapping.
uc.nextBootstrapIP.Add(1)
// If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip.
if !ctrldnet.IPv6Available() && ctrldnet.IsIPv6(ip) {
continue
}
uc.BootstrapIP = ip
}
}
}
ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs)
}
// ReBootstrap re-setup the bootstrap IP and the transport.
func (uc *UpstreamConfig) ReBootstrap() {
_, _, _ = uc.g.Do("rebootstrap", func() (any, error) {
ProxyLog.Debug().Msg("re-bootstrapping upstream ip")
ctrldnet.Reset()
err := uc.SetupBootstrapIP()
if err != nil {
ProxyLog.Error().Err(err).Msg("re-bootstrapping failed")
} else {
ProxyLog.Debug().Msgf("bootstrap ip set to: %s", uc.BootstrapIP)
n := uint32(len(uc.bootstrapIPs))
// Only attempt n times, because if there's no usable ip,
// the bootstrap ip will be kept as-is.
for i := uint32(0); i < n; i++ {
// Select the next ip in bootstrap ip list.
next := uc.nextBootstrapIP.Add(1)
ip := uc.bootstrapIPs[(next-1)%n]
if !ctrldnet.IPv6Available() && ctrldnet.IsIPv6(ip) {
continue
}
uc.BootstrapIP = ip
break
}
uc.SetupTransport()
return err == nil, err
return true, nil
})
}
@@ -180,53 +245,6 @@ func (uc *UpstreamConfig) SetupTransport() {
}
}
// SetupBootstrapIP manually find all available IPs of the upstream.
func (uc *UpstreamConfig) SetupBootstrapIP() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(uc.Timeout)*time.Millisecond)
defer cancel()
uc.mu.Lock()
defer uc.mu.Unlock()
c := new(dns.Client)
m := new(dns.Msg)
dnsType := dns.TypeA
if ctrldnet.SupportsIPv6() {
dnsType = dns.TypeAAAA
}
m.SetQuestion(uc.Domain+".", dnsType)
m.RecursionDesired = true
r, _, err := c.ExchangeContext(ctx, m, net.JoinHostPort(bootstrapDNS, "53"))
if err != nil {
ProxyLog.Error().Err(err).Msgf("could not resolve domain %s for upstream", uc.Domain)
return err
}
if r.Rcode != dns.RcodeSuccess {
ProxyLog.Error().Msgf("could not resolve domain return code: %d, upstream", r.Rcode)
return errors.New(dns.RcodeToString[r.Rcode])
}
if len(r.Answer) == 0 {
return errors.New("no answer from bootstrap DNS server")
}
bootstrapIP := func(record dns.RR) string {
switch ar := record.(type) {
case *dns.A:
return ar.A.String()
case *dns.AAAA:
return ar.AAAA.String()
}
return ""
}
for _, a := range r.Answer {
if ip := bootstrapIP(a); ip != "" {
uc.BootstrapIP = ip
break
}
}
return nil
}
func (uc *UpstreamConfig) setupDOHTransport() {
uc.transport = http.DefaultTransport.(*http.Transport).Clone()
uc.transport.IdleConnTimeout = 5 * time.Second

View File

@@ -40,6 +40,24 @@ func init() {
stackOnce.Store(new(sync.Once))
}
func supportIPv4() bool {
_, err := Dialer.Dial("tcp4", net.JoinHostPort(controldIPv4Test, "80"))
return err == nil
}
func supportIPv6() bool {
_, err := Dialer.Dial("tcp6", net.JoinHostPort(controldIPv6Test, "80"))
return err == nil
}
func supportListenIPv6Local() bool {
if ln, err := net.Listen("tcp6", "[::1]:0"); err == nil {
ln.Close()
return true
}
return false
}
func probeStack() {
b := backoff.NewBackoff("probeStack", func(format string, args ...any) {}, time.Minute)
for {
@@ -50,20 +68,9 @@ func probeStack() {
b.BackOff(context.Background(), err)
}
}
if _, err := Dialer.Dial("tcp4", net.JoinHostPort(controldIPv4Test, "80")); err == nil {
ipv4Enabled = true
}
if _, err := Dialer.Dial("tcp6", net.JoinHostPort(controldIPv6Test, "80")); err == nil {
ipv6Enabled = true
}
if ln, err := net.Listen("tcp6", "[::1]:53"); err == nil {
ln.Close()
canListenIPv6Local = true
}
}
func Reset() {
stackOnce.Store(new(sync.Once))
ipv4Enabled = supportIPv4()
ipv6Enabled = supportIPv6()
canListenIPv6Local = supportListenIPv6Local()
}
func Up() bool {
@@ -86,6 +93,11 @@ func SupportsIPv6ListenLocal() bool {
return canListenIPv6Local
}
// IPv6Available is like SupportsIPv6, but always do the check without caching.
func IPv6Available() bool {
return supportIPv6()
}
// IsIPv6 checks if the provided IP is v6.
//
//lint:ignore U1000 use in os_windows.go