diff --git a/config.go b/config.go index 4a3c113..63c7f6a 100644 --- a/config.go +++ b/config.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "math/rand" "net" "net/http" "net/netip" @@ -520,58 +519,53 @@ func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { }) } -// SetupTransport initializes the network transport used to connect to upstream server. -// For now, only DoH upstream is supported. +// SetupTransport initializes the network transport used to connect to upstream servers. +// For now, DoH/DoH3/DoQ upstreams are supported. func (uc *UpstreamConfig) SetupTransport(ctx context.Context) { switch uc.Type { - case ResolverTypeDOH: - uc.setupDOHTransport(ctx) - case ResolverTypeDOH3: - uc.setupDOH3Transport(ctx) - case ResolverTypeDOQ: - uc.setupDOQTransport(ctx) + case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ: + default: + return } -} - -func (uc *UpstreamConfig) setupDOQTransport(ctx context.Context) { + ips := uc.bootstrapIPs switch uc.IPStack { - case IpStackBoth, "": - uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs) case IpStackV4: - uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs4) + ips = uc.bootstrapIPs4 case IpStackV6: - uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs6) - case IpStackSplit: + ips = uc.bootstrapIPs6 + } + uc.transport = uc.newDOHTransport(ctx, ips) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, ips) + uc.doqConnPool = uc.newDOQConnPool(ctx, ips) + if uc.IPStack == IpStackSplit { + uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4) + uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4) uc.doqConnPool4 = uc.newDOQConnPool(ctx, uc.bootstrapIPs4) if HasIPv6(ctx) { + uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6) + uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6) uc.doqConnPool6 = uc.newDOQConnPool(ctx, uc.bootstrapIPs6) } else { + uc.transport6 = uc.transport4 + uc.http3RoundTripper6 = uc.http3RoundTripper4 uc.doqConnPool6 = uc.doqConnPool4 } - uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs) } } -func (uc *UpstreamConfig) setupDOHTransport(ctx context.Context) { - switch uc.IPStack { - case IpStackBoth, "": - uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs) - case IpStackV4: - uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs4) - case IpStackV6: - uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs6) - case IpStackSplit: - uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4) - if HasIPv6(ctx) { - uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6) - } else { - uc.transport6 = uc.transport4 - } - uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs) +func (uc *UpstreamConfig) ensureSetupTransport(ctx context.Context) { + uc.transportOnce.Do(func() { + uc.SetupTransport(ctx) + }) + if uc.rebootstrap.CompareAndSwap(true, false) { + uc.SetupTransport(ctx) } } func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *http.Transport { + if uc.Type != ResolverTypeDOH { + return nil + } transport := http.DefaultTransport.(*http.Transport).Clone() transport.MaxIdleConnsPerHost = 100 transport.TLSClientConfig = &tls.Config{ @@ -707,46 +701,8 @@ func (uc *UpstreamConfig) isNextDNS() bool { } func (uc *UpstreamConfig) dohTransport(ctx context.Context, dnsType uint16) http.RoundTripper { - uc.transportOnce.Do(func() { - uc.SetupTransport(ctx) - }) - if uc.rebootstrap.CompareAndSwap(true, false) { - uc.SetupTransport(ctx) - } - switch uc.IPStack { - case IpStackBoth, IpStackV4, IpStackV6: - return uc.transport - case IpStackSplit: - switch dnsType { - case dns.TypeA: - return uc.transport4 - default: - return uc.transport6 - } - } - return uc.transport -} - -func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uint16) string { - switch uc.IPStack { - case IpStackBoth: - return pick(uc.bootstrapIPs) - case IpStackV4: - return pick(uc.bootstrapIPs4) - case IpStackV6: - return pick(uc.bootstrapIPs6) - case IpStackSplit: - switch dnsType { - case dns.TypeA: - return pick(uc.bootstrapIPs4) - default: - if HasIPv6(ctx) { - return pick(uc.bootstrapIPs6) - } - return pick(uc.bootstrapIPs4) - } - } - return pick(uc.bootstrapIPs) + uc.ensureSetupTransport(ctx) + return transportByIpStack(uc.IPStack, dnsType, uc.transport, uc.transport4, uc.transport6) } func (uc *UpstreamConfig) netForDNSType(ctx context.Context, dnsType uint16) (string, string) { @@ -998,10 +954,6 @@ func ResolverTypeFromEndpoint(endpoint string) string { return ResolverTypeDOT } -func pick(s []string) string { - return s[rand.Intn(len(s))] -} - // upstreamUID generates an unique identifier for an upstream. func upstreamUID(ctx context.Context) string { logger := LoggerFromCtx(ctx) @@ -1038,3 +990,18 @@ func bootstrapIPsFromControlDDomain(domain string) []string { } return nil } + +func transportByIpStack[T any](ipStack string, dnsType uint16, transport, transport4, transport6 T) T { + switch ipStack { + case IpStackBoth, IpStackV4, IpStackV6: + return transport + case IpStackSplit: + switch dnsType { + case dns.TypeA: + return transport4 + default: + return transport6 + } + } + return transport +} diff --git a/config_quic.go b/config_quic.go index 6172ba2..df9f22b 100644 --- a/config_quic.go +++ b/config_quic.go @@ -9,31 +9,14 @@ import ( "runtime" "sync" - "github.com/miekg/dns" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" ) -func (uc *UpstreamConfig) setupDOH3Transport(ctx context.Context) { - switch uc.IPStack { - case IpStackBoth, "": - uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs) - case IpStackV4: - uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs4) - case IpStackV6: - uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs6) - case IpStackSplit: - uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4) - if HasIPv6(ctx) { - uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6) - } else { - uc.http3RoundTripper6 = uc.http3RoundTripper4 - } - uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs) - } -} - func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper { + if uc.Type != ResolverTypeDOH3 { + return nil + } rt := &http3.Transport{} rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} logger := LoggerFromCtx(ctx) @@ -72,45 +55,13 @@ func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) } func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper { - uc.transportOnce.Do(func() { - uc.SetupTransport(ctx) - }) - if uc.rebootstrap.CompareAndSwap(true, false) { - uc.SetupTransport(ctx) - } - switch uc.IPStack { - case IpStackBoth, IpStackV4, IpStackV6: - return uc.http3RoundTripper - case IpStackSplit: - switch dnsType { - case dns.TypeA: - return uc.http3RoundTripper4 - default: - return uc.http3RoundTripper6 - } - } - return uc.http3RoundTripper + uc.ensureSetupTransport(ctx) + return transportByIpStack(uc.IPStack, dnsType, uc.http3RoundTripper, uc.http3RoundTripper4, uc.http3RoundTripper6) } func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doqConnPool { - uc.transportOnce.Do(func() { - uc.SetupTransport(ctx) - }) - if uc.rebootstrap.CompareAndSwap(true, false) { - uc.SetupTransport(ctx) - } - switch uc.IPStack { - case IpStackBoth, IpStackV4, IpStackV6: - return uc.doqConnPool - case IpStackSplit: - switch dnsType { - case dns.TypeA: - return uc.doqConnPool4 - default: - return uc.doqConnPool6 - } - } - return uc.doqConnPool + uc.ensureSetupTransport(ctx) + return transportByIpStack(uc.IPStack, dnsType, uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6) } // Putting the code for quic parallel dialer here: @@ -182,5 +133,8 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t } func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *doqConnPool { + if uc.Type != ResolverTypeDOQ { + return nil + } return newDOQConnPool(ctx, uc, addrs) } diff --git a/doq.go b/doq.go index d309e45..6556eb3 100644 --- a/doq.go +++ b/doq.go @@ -63,7 +63,7 @@ type doqConn struct { mu sync.Mutex } -func newDOQConnPool(ctx context.Context, uc *UpstreamConfig, addrs []string) *doqConnPool { +func newDOQConnPool(_ context.Context, uc *UpstreamConfig, addrs []string) *doqConnPool { _, port, _ := net.SplitHostPort(uc.Endpoint) if port == "" { port = "853" @@ -96,7 +96,7 @@ func newDOQConnPool(ctx context.Context, uc *UpstreamConfig, addrs []string) *do // Resolve performs a DNS query using a pooled QUIC connection. func (p *doqConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { // Retry logic for io.EOF errors (as per original implementation) - for i := 0; i < 5; i++ { + for range 5 { answer, err := p.doResolve(ctx, msg) if err == io.EOF { continue