diff --git a/config.go b/config.go index bef32e3..b06f09f 100644 --- a/config.go +++ b/config.go @@ -347,9 +347,7 @@ func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() { uc.transport = uc.newDOHTransport(uc.bootstrapIPs6) case IpStackSplit: uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if ctrldnet.IPv6Available(ctx) { + if hasIPv6() { uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6) } else { uc.transport6 = uc.transport4 @@ -419,7 +417,10 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { case dns.TypeA: return pick(uc.bootstrapIPs4) default: - return pick(uc.bootstrapIPs6) + if hasIPv6() { + return pick(uc.bootstrapIPs6) + } + return pick(uc.bootstrapIPs4) } } return pick(uc.bootstrapIPs) @@ -438,7 +439,10 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { case dns.TypeA: return "tcp4-tls", "udp4" default: - return "tcp6-tls", "udp6" + if hasIPv6() { + return "tcp6-tls", "udp6" + } + return "tcp4-tls", "udp4" } } return "tcp-tls", "udp" diff --git a/dot.go b/dot.go index 68cf2e1..1fef409 100644 --- a/dot.go +++ b/dot.go @@ -33,6 +33,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro endpoint := r.uc.Endpoint if r.uc.BootstrapIP != "" { dnsClient.TLSConfig.ServerName = r.uc.Domain + dnsClient.Net = "tcp-tls" _, port, _ := net.SplitHostPort(endpoint) endpoint = net.JoinHostPort(r.uc.BootstrapIP, port) } diff --git a/net.go b/net.go new file mode 100644 index 0000000..110d67e --- /dev/null +++ b/net.go @@ -0,0 +1,46 @@ +package ctrld + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "time" + + "tailscale.com/logtail/backoff" + + ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" +) + +var ( + hasIPv6Once sync.Once + ipv6Available atomic.Bool +) + +func hasIPv6() bool { + hasIPv6Once.Do(func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + val := ctrldnet.IPv6Available(ctx) + ipv6Available.Store(val) + go probingIPv6(val) + }) + return ipv6Available.Load() +} + +// TODO(cuonglm): doing poll check natively for supported platforms. +func probingIPv6(old bool) { + b := backoff.NewBackoff("probingIPv6", func(format string, args ...any) {}, 30*time.Second) + bCtx := context.Background() + for { + func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + cur := ctrldnet.IPv6Available(ctx) + if ipv6Available.CompareAndSwap(old, cur) { + old = cur + } + }() + b.BackOff(bCtx, errors.New("no change")) + } +} diff --git a/resolver.go b/resolver.go index befa298..391a4e8 100644 --- a/resolver.go +++ b/resolver.go @@ -125,7 +125,14 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e Net: udpNet, Dialer: dialer, } - answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.uc.Endpoint) + endpoint := r.uc.Endpoint + if r.uc.BootstrapIP != "" { + dnsClient.Net = "udp" + _, port, _ := net.SplitHostPort(endpoint) + endpoint = net.JoinHostPort(r.uc.BootstrapIP, port) + } + + answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint) return answer, err }