diff --git a/cmd/ctrld/dns_proxy.go b/cmd/ctrld/dns_proxy.go index 8ece9c1..4a428a4 100644 --- a/cmd/ctrld/dns_proxy.go +++ b/cmd/ctrld/dns_proxy.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/hex" + "errors" "fmt" "net" "runtime" @@ -182,6 +183,15 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i resolveCtx = timeoutCtx } answer, err := dnsResolver.Resolve(resolveCtx, msg) + if errors.Is(err, ctrld.ErrUpstreamFailed) { + ctrldnet.Reset() + if err := upstreamConfig.SetupBootstrapIP(); err != nil { + mainLog.Error().Err(err).Msg("could not re-initialize bootstrap IP") + } else { + mainLog.Debug().Msg("re-initialize bootstrap IP done") + } + return nil + } if err != nil { ctrld.Log(ctx, mainLog.Error().Err(err), "failed to resolve query") return nil diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index edb43bf..8a1dd6e 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -11,11 +11,9 @@ import ( "syscall" "github.com/kardianos/service" - "github.com/miekg/dns" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/dnscache" - ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) var logf = func(format string, args ...any) { @@ -37,7 +35,6 @@ type prog struct { func (p *prog) Start(s service.Service) error { p.cfg = &cfg go p.run() - mainLog.Info().Msg("Service started") return nil } @@ -67,45 +64,10 @@ func (p *prog) run() { for n := range p.cfg.Upstream { uc := p.cfg.Upstream[n] uc.Init() - if uc.BootstrapIP == "" { - // resolve it manually and set the bootstrap ip - c := new(dns.Client) - for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} { - if !ctrldnet.SupportsIPv6() && dnsType == dns.TypeAAAA { - continue - } - m := new(dns.Msg) - m.SetQuestion(uc.Domain+".", dnsType) - m.RecursionDesired = true - r, _, err := c.Exchange(m, net.JoinHostPort(bootstrapDNS, "53")) - if err != nil { - mainLog.Error().Err(err).Msgf("could not resolve domain %s for upstream.%s", uc.Domain, n) - continue - } - if r.Rcode != dns.RcodeSuccess { - mainLog.Error().Msgf("could not resolve domain return code: %d, upstream.%s", r.Rcode, n) - continue - } - if len(r.Answer) == 0 { - continue - } - for _, a := range r.Answer { - switch ar := a.(type) { - case *dns.A: - uc.BootstrapIP = ar.A.String() - case *dns.AAAA: - uc.BootstrapIP = ar.AAAA.String() - default: - continue - } - mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n) - // Stop if we reached here, because we got the bootstrap IP from r.Answer. - break - } - // If we reached here, uc.BootstrapIP was set, nothing to do anymore. - break - } + if err := uc.SetupBootstrapIP(); err != nil { + mainLog.Fatal().Err(err).Msgf("failed to setup bootstrap IP for upstream.%s", n) } + mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n) uc.SetupTransport() } diff --git a/config.go b/config.go index f6a132c..9082e01 100644 --- a/config.go +++ b/config.go @@ -2,19 +2,26 @@ package ctrld import ( "context" + "errors" "net" "net/http" "net/url" "os" "strings" + "sync" "time" - "github.com/Control-D-Inc/ctrld/internal/dnsrcode" "github.com/go-playground/validator/v10" "github.com/miekg/dns" "github.com/spf13/viper" + + "github.com/Control-D-Inc/ctrld/internal/dnsrcode" + ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) +// ErrUpstreamFailed indicates that ctrld failed to connect to an upstream. +var ErrUpstreamFailed = errors.New("could not connect to upstream") + // SetConfigName set the config name that ctrld will look for. func SetConfigName(v *viper.Viper, name string) { v.SetConfigName(name) @@ -100,6 +107,9 @@ type UpstreamConfig struct { Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"` transport *http.Transport `mapstructure:"-" toml:"-"` http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"` + + // guard BootstrapIP + mu sync.Mutex } // ListenerConfig specifies the networks configuration that ctrld will run on. @@ -155,13 +165,51 @@ func (uc *UpstreamConfig) SetupTransport() { } } +// SetupBootstrapIP manually find all available IPs of the upstream. +func (uc *UpstreamConfig) SetupBootstrapIP() error { + uc.mu.Lock() + defer uc.mu.Unlock() + + c := new(dns.Client) + m := new(dns.Msg) + dnsType := dns.TypeA + if ctrldnet.SupportsIPv6() { + dnsType = dns.TypeAAAA + } + m.SetQuestion(uc.Domain+".", dnsType) + m.RecursionDesired = true + r, _, err := c.Exchange(m, net.JoinHostPort(bootstrapDNS, "53")) + if err != nil { + ProxyLog.Error().Err(err).Msgf("could not resolve domain %s for upstream", uc.Domain) + return err + } + if r.Rcode != dns.RcodeSuccess { + ProxyLog.Error().Msgf("could not resolve domain return code: %d, upstream", r.Rcode) + return errors.New(dns.RcodeToString[r.Rcode]) + } + if len(r.Answer) == 0 { + return errors.New("no answer from bootstrap DNS server") + } + for _, a := range r.Answer { + switch ar := a.(type) { + case *dns.A: + uc.BootstrapIP = ar.A.String() + break + case *dns.AAAA: + uc.BootstrapIP = ar.AAAA.String() + break + } + } + return nil +} + func (uc *UpstreamConfig) setupDOHTransport() { uc.transport = http.DefaultTransport.(*http.Transport).Clone() uc.transport.IdleConnTimeout = 5 * time.Second uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { dialer := &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 10 * time.Second, + Timeout: 5 * time.Second, + KeepAlive: 5 * time.Second, } Log(ctx, ProxyLog.Debug(), "debug dial context %s - %s - %s", addr, network, bootstrapDNS) // if we have a bootstrap ip set, use it to avoid DNS lookup @@ -169,9 +217,14 @@ func (uc *UpstreamConfig) setupDOHTransport() { if _, port, _ := net.SplitHostPort(addr); port != "" { addr = net.JoinHostPort(uc.BootstrapIP, port) } - Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr) } - return dialer.DialContext(ctx, network, addr) + Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr) + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + Log(ctx, ProxyLog.Debug().Err(err), "could not dial to upstream") + return nil, ErrUpstreamFailed + } + return conn, nil } uc.pingUpstream() diff --git a/config_quic.go b/config_quic.go index 72ce351..fb00655 100644 --- a/config_quic.go +++ b/config_quic.go @@ -32,7 +32,12 @@ func (uc *UpstreamConfig) setupDOH3Transport() { if err != nil { return nil, err } - return quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg) + conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg) + if err != nil { + Log(ctx, ProxyLog.Debug().Err(err), "could not dial early to upstream") + return nil, ErrUpstreamFailed + } + return conn, nil } uc.http3RoundTripper = rt diff --git a/doq.go b/doq.go index 20919e3..ab3fbb6 100644 --- a/doq.go +++ b/doq.go @@ -47,7 +47,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls. func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) { session, err := quic.DialAddr(endpoint, tlsConfig, nil) if err != nil { - return nil, err + return nil, ErrUpstreamFailed } defer session.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") diff --git a/internal/net/net.go b/internal/net/net.go index 4360fc2..1488da9 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -4,6 +4,7 @@ import ( "context" "net" "sync" + "sync/atomic" "time" "tailscale.com/logtail/backoff" @@ -28,13 +29,17 @@ var Dialer = &net.Dialer{ } var ( - stackOnce sync.Once + stackOnce atomic.Pointer[sync.Once] ipv4Enabled bool ipv6Enabled bool canListenIPv6Local bool hasNetworkUp bool ) +func init() { + stackOnce.Store(new(sync.Once)) +} + func probeStack() { b := backoff.NewBackoff("probeStack", func(format string, args ...any) {}, time.Minute) for { @@ -57,23 +62,27 @@ func probeStack() { } } +func Reset() { + stackOnce.Store(new(sync.Once)) +} + func Up() bool { - stackOnce.Do(probeStack) + stackOnce.Load().Do(probeStack) return hasNetworkUp } func SupportsIPv4() bool { - stackOnce.Do(probeStack) + stackOnce.Load().Do(probeStack) return ipv4Enabled } func SupportsIPv6() bool { - stackOnce.Do(probeStack) + stackOnce.Load().Do(probeStack) return ipv6Enabled } func SupportsIPv6ListenLocal() bool { - stackOnce.Do(probeStack) + stackOnce.Load().Do(probeStack) return canListenIPv6Local } diff --git a/resolver.go b/resolver.go index 5c04f37..bb23627 100644 --- a/resolver.go +++ b/resolver.go @@ -93,5 +93,8 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e Dialer: dialer, } answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.endpoint) + if _, ok := err.(*net.OpError); ok { + return answer, ErrUpstreamFailed + } return answer, err }