diff --git a/cmd/ctrld/dns_proxy.go b/cmd/ctrld/dns_proxy.go index 4e67df6..bbe84f1 100644 --- a/cmd/ctrld/dns_proxy.go +++ b/cmd/ctrld/dns_proxy.go @@ -4,7 +4,6 @@ import ( "context" "crypto/rand" "encoding/hex" - "errors" "fmt" "net" "runtime" @@ -170,12 +169,12 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i staleAnswer = answer } } - resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { + resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { ctrld.Log(ctx, mainLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(upstreamConfig) if err != nil { ctrld.Log(ctx, mainLog.Error().Err(err), "failed to create resolver") - return nil + return nil, err } resolveCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -184,17 +183,17 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i defer cancel() resolveCtx = timeoutCtx } - answer, err := dnsResolver.Resolve(resolveCtx, msg) - if errors.Is(err, ctrld.ErrUpstreamFailed) { - ctrldnet.Reset() - if err := upstreamConfig.SetupBootstrapIP(); err != nil { - mainLog.Error().Err(err).Msg("could not re-initialize bootstrap IP") - } else { - mainLog.Debug().Msg("re-initialize bootstrap IP done") - } - return nil - } + return dnsResolver.Resolve(resolveCtx, msg) + } + resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { + answer, err := resolve1(n, upstreamConfig, msg) if err != nil { + // If any error occurred, re-bootstrap transport/ip, retry the request. + upstreamConfig.ReBootstrap() + answer, err = resolve1(n, upstreamConfig, msg) + if err == nil { + return answer + } ctrld.Log(ctx, mainLog.Error().Err(err), "failed to resolve query") return nil } diff --git a/config.go b/config.go index 9082e01..507c914 100644 --- a/config.go +++ b/config.go @@ -14,14 +14,12 @@ import ( "github.com/go-playground/validator/v10" "github.com/miekg/dns" "github.com/spf13/viper" + "golang.org/x/sync/singleflight" "github.com/Control-D-Inc/ctrld/internal/dnsrcode" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) -// ErrUpstreamFailed indicates that ctrld failed to connect to an upstream. -var ErrUpstreamFailed = errors.New("could not connect to upstream") - // SetConfigName set the config name that ctrld will look for. func SetConfigName(v *viper.Viper, name string) { v.SetConfigName(name) @@ -108,6 +106,7 @@ type UpstreamConfig struct { transport *http.Transport `mapstructure:"-" toml:"-"` http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"` + g singleflight.Group // guard BootstrapIP mu sync.Mutex } @@ -154,6 +153,22 @@ func (uc *UpstreamConfig) Init() { } } +// ReBootstrap re-setup the bootstrap IP and the transport. +func (uc *UpstreamConfig) ReBootstrap() { + _, _, _ = uc.g.Do("rebootstrap", func() (any, error) { + ProxyLog.Debug().Msg("re-bootstrapping upstream ip") + ctrldnet.Reset() + err := uc.SetupBootstrapIP() + if err != nil { + ProxyLog.Error().Err(err).Msg("re-bootstrapping failed") + } else { + ProxyLog.Debug().Msgf("bootstrap ip set to: %s", uc.BootstrapIP) + } + uc.SetupTransport() + return err == nil, err + }) +} + // SetupTransport initializes the network transport used to connect to upstream server. // For now, only DoH upstream is supported. func (uc *UpstreamConfig) SetupTransport() { @@ -208,8 +223,8 @@ func (uc *UpstreamConfig) setupDOHTransport() { uc.transport.IdleConnTimeout = 5 * time.Second uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { dialer := &net.Dialer{ - Timeout: 5 * time.Second, - KeepAlive: 5 * time.Second, + Timeout: 2 * time.Second, + KeepAlive: 2 * time.Second, } Log(ctx, ProxyLog.Debug(), "debug dial context %s - %s - %s", addr, network, bootstrapDNS) // if we have a bootstrap ip set, use it to avoid DNS lookup @@ -219,12 +234,7 @@ func (uc *UpstreamConfig) setupDOHTransport() { } } Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr) - conn, err := dialer.DialContext(ctx, network, addr) - if err != nil { - Log(ctx, ProxyLog.Debug().Err(err), "could not dial to upstream") - return nil, ErrUpstreamFailed - } - return conn, nil + return dialer.DialContext(ctx, network, addr) } uc.pingUpstream() diff --git a/config_quic.go b/config_quic.go index fb00655..72ce351 100644 --- a/config_quic.go +++ b/config_quic.go @@ -32,12 +32,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() { if err != nil { return nil, err } - conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg) - if err != nil { - Log(ctx, ProxyLog.Debug().Err(err), "could not dial early to upstream") - return nil, ErrUpstreamFailed - } - return conn, nil + return quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg) } uc.http3RoundTripper = rt diff --git a/doq.go b/doq.go index ab3fbb6..20919e3 100644 --- a/doq.go +++ b/doq.go @@ -47,7 +47,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls. func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) { session, err := quic.DialAddr(endpoint, tlsConfig, nil) if err != nil { - return nil, ErrUpstreamFailed + return nil, err } defer session.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") diff --git a/go.mod b/go.mod index 1b09384..7193372 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/spf13/cobra v1.4.0 github.com/spf13/viper v1.14.0 github.com/stretchr/testify v1.8.1 + golang.org/x/sync v0.1.0 golang.org/x/sys v0.5.0 golang.zx2c4.com/wireguard/windows v0.5.3 tailscale.com v1.34.1 @@ -67,7 +68,6 @@ require ( golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.6.0 // indirect golang.org/x/net v0.7.0 // indirect - golang.org/x/sync v0.1.0 // indirect golang.org/x/text v0.7.0 // indirect golang.org/x/tools v0.2.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/resolver.go b/resolver.go index bb23627..5c04f37 100644 --- a/resolver.go +++ b/resolver.go @@ -93,8 +93,5 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e Dialer: dialer, } answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.endpoint) - if _, ok := err.(*net.OpError); ok { - return answer, ErrUpstreamFailed - } return answer, err }