refactor(config): consolidate transport setup and eliminate duplication

Consolidate DoH/DoH3/DoQ transport initialization into a single
SetupTransport method and introduce generic helper functions to eliminate
duplicated IP stack selection logic across transport getters.

This reduces code duplication by ~77 lines while maintaining the same
functionality.
This commit is contained in:
Cuong Manh Le
2026-01-06 18:50:13 +07:00
committed by Cuong Manh Le
parent e8d1a4604e
commit 1f4c47318e
3 changed files with 54 additions and 114 deletions

119
config.go
View File

@@ -9,7 +9,6 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/netip"
@@ -509,54 +508,49 @@ func (uc *UpstreamConfig) ReBootstrap() {
// For now, only DoH upstream is supported.
func (uc *UpstreamConfig) SetupTransport() {
switch uc.Type {
case ResolverTypeDOH:
uc.setupDOHTransport()
case ResolverTypeDOH3:
uc.setupDOH3Transport()
case ResolverTypeDOQ:
uc.setupDOQTransport()
case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ:
default:
return
}
}
func (uc *UpstreamConfig) setupDOQTransport() {
ips := uc.bootstrapIPs
switch uc.IPStack {
case IpStackBoth, "":
uc.doqConnPool = uc.newDOQConnPool(uc.bootstrapIPs)
case IpStackV4:
uc.doqConnPool = uc.newDOQConnPool(uc.bootstrapIPs4)
ips = uc.bootstrapIPs4
case IpStackV6:
uc.doqConnPool = uc.newDOQConnPool(uc.bootstrapIPs6)
case IpStackSplit:
ips = uc.bootstrapIPs6
}
uc.transport = uc.newDOHTransport(ips)
uc.http3RoundTripper = uc.newDOH3Transport(ips)
uc.doqConnPool = uc.newDOQConnPool(ips)
if uc.IPStack == IpStackSplit {
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
uc.doqConnPool4 = uc.newDOQConnPool(uc.bootstrapIPs4)
if HasIPv6() {
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
uc.doqConnPool6 = uc.newDOQConnPool(uc.bootstrapIPs6)
} else {
uc.transport6 = uc.transport4
uc.http3RoundTripper6 = uc.http3RoundTripper4
uc.doqConnPool6 = uc.doqConnPool4
}
uc.doqConnPool = uc.newDOQConnPool(uc.bootstrapIPs)
}
}
func (uc *UpstreamConfig) setupDOHTransport() {
switch uc.IPStack {
case IpStackBoth, "":
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
case IpStackV4:
uc.transport = uc.newDOHTransport(uc.bootstrapIPs4)
case IpStackV6:
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
case IpStackSplit:
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
if HasIPv6() {
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
} else {
uc.transport6 = uc.transport4
}
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
func (uc *UpstreamConfig) ensureSetupTransport() {
uc.transportOnce.Do(func() {
uc.SetupTransport()
})
if uc.rebootstrap.CompareAndSwap(true, false) {
uc.SetupTransport()
}
}
func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
if uc.Type != ResolverTypeDOH {
return nil
}
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.MaxIdleConnsPerHost = 100
transport.TLSClientConfig = &tls.Config{
@@ -690,46 +684,8 @@ func (uc *UpstreamConfig) isNextDNS() bool {
}
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
uc.transportOnce.Do(func() {
uc.SetupTransport()
})
if uc.rebootstrap.CompareAndSwap(true, false) {
uc.SetupTransport()
}
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
return uc.transport
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return uc.transport4
default:
return uc.transport6
}
}
return uc.transport
}
func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
switch uc.IPStack {
case IpStackBoth:
return pick(uc.bootstrapIPs)
case IpStackV4:
return pick(uc.bootstrapIPs4)
case IpStackV6:
return pick(uc.bootstrapIPs6)
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return pick(uc.bootstrapIPs4)
default:
if HasIPv6() {
return pick(uc.bootstrapIPs6)
}
return pick(uc.bootstrapIPs4)
}
}
return pick(uc.bootstrapIPs)
uc.ensureSetupTransport()
return transportByIpStack(uc.IPStack, dnsType, uc.transport, uc.transport4, uc.transport6)
}
func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
@@ -974,10 +930,6 @@ func ResolverTypeFromEndpoint(endpoint string) string {
return ResolverTypeDOT
}
func pick(s []string) string {
return s[rand.Intn(len(s))]
}
// upstreamUID generates an unique identifier for an upstream.
func upstreamUID() string {
b := make([]byte, 4)
@@ -1013,3 +965,18 @@ func bootstrapIPsFromControlDDomain(domain string) []string {
}
return nil
}
func transportByIpStack[T any](ipStack string, dnsType uint16, transport, transport4, transport6 T) T {
switch ipStack {
case IpStackBoth, IpStackV4, IpStackV6:
return transport
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return transport4
default:
return transport6
}
}
return transport
}

View File

@@ -9,7 +9,6 @@ import (
"runtime"
"sync"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
)
@@ -34,6 +33,9 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
}
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
if uc.Type != ResolverTypeDOH3 {
return nil
}
rt := &http3.Transport{}
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
@@ -71,45 +73,13 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
}
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
uc.transportOnce.Do(func() {
uc.SetupTransport()
})
if uc.rebootstrap.CompareAndSwap(true, false) {
uc.SetupTransport()
}
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
return uc.http3RoundTripper
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return uc.http3RoundTripper4
default:
return uc.http3RoundTripper6
}
}
return uc.http3RoundTripper
uc.ensureSetupTransport()
return transportByIpStack(uc.IPStack, dnsType, uc.http3RoundTripper, uc.http3RoundTripper4, uc.http3RoundTripper6)
}
func (uc *UpstreamConfig) doqTransport(dnsType uint16) *doqConnPool {
uc.transportOnce.Do(func() {
uc.SetupTransport()
})
if uc.rebootstrap.CompareAndSwap(true, false) {
uc.SetupTransport()
}
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
return uc.doqConnPool
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return uc.doqConnPool4
default:
return uc.doqConnPool6
}
}
return uc.doqConnPool
uc.ensureSetupTransport()
return transportByIpStack(uc.IPStack, dnsType, uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6)
}
// Putting the code for quic parallel dialer here:
@@ -181,5 +151,8 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t
}
func (uc *UpstreamConfig) newDOQConnPool(addrs []string) *doqConnPool {
if uc.Type != ResolverTypeDOQ {
return nil
}
return newDOQConnPool(uc, addrs)
}

2
doq.go
View File

@@ -86,7 +86,7 @@ func newDOQConnPool(uc *UpstreamConfig, addrs []string) *doqConnPool {
// Resolve performs a DNS query using a pooled QUIC connection.
func (p *doqConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
// Retry logic for io.EOF errors (as per original implementation)
for i := 0; i < 5; i++ {
for range 5 {
answer, err := p.doResolve(ctx, msg)
if err == io.EOF {
continue