mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-13 10:26:06 +00:00
refactor(config): consolidate transport setup and eliminate duplication
Consolidate DoH/DoH3/DoQ transport initialization into a single SetupTransport method and introduce generic helper functions to eliminate duplicated IP stack selection logic across transport getters. This reduces code duplication by ~77 lines while maintaining the same functionality.
This commit is contained in:
committed by
Cuong Manh Le
parent
f4a938c873
commit
366193514b
123
config.go
123
config.go
@@ -9,7 +9,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -520,58 +519,53 @@ func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// SetupTransport initializes the network transport used to connect to upstream server.
|
||||
// For now, only DoH upstream is supported.
|
||||
// SetupTransport initializes the network transport used to connect to upstream servers.
|
||||
// For now, DoH/DoH3/DoQ upstreams are supported.
|
||||
func (uc *UpstreamConfig) SetupTransport(ctx context.Context) {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH:
|
||||
uc.setupDOHTransport(ctx)
|
||||
case ResolverTypeDOH3:
|
||||
uc.setupDOH3Transport(ctx)
|
||||
case ResolverTypeDOQ:
|
||||
uc.setupDOQTransport(ctx)
|
||||
case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) setupDOQTransport(ctx context.Context) {
|
||||
ips := uc.bootstrapIPs
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs4)
|
||||
ips = uc.bootstrapIPs4
|
||||
case IpStackV6:
|
||||
uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
ips = uc.bootstrapIPs6
|
||||
}
|
||||
uc.transport = uc.newDOHTransport(ctx, ips)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, ips)
|
||||
uc.doqConnPool = uc.newDOQConnPool(ctx, ips)
|
||||
if uc.IPStack == IpStackSplit {
|
||||
uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
|
||||
uc.doqConnPool4 = uc.newDOQConnPool(ctx, uc.bootstrapIPs4)
|
||||
if HasIPv6(ctx) {
|
||||
uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
|
||||
uc.doqConnPool6 = uc.newDOQConnPool(ctx, uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.transport6 = uc.transport4
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
uc.doqConnPool6 = uc.doqConnPool4
|
||||
}
|
||||
uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) setupDOHTransport(ctx context.Context) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
|
||||
if HasIPv6(ctx) {
|
||||
uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.transport6 = uc.transport4
|
||||
}
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs)
|
||||
func (uc *UpstreamConfig) ensureSetupTransport(ctx context.Context) {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport(ctx)
|
||||
})
|
||||
if uc.rebootstrap.CompareAndSwap(true, false) {
|
||||
uc.SetupTransport(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *http.Transport {
|
||||
if uc.Type != ResolverTypeDOH {
|
||||
return nil
|
||||
}
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.MaxIdleConnsPerHost = 100
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
@@ -707,46 +701,8 @@ func (uc *UpstreamConfig) isNextDNS() bool {
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) dohTransport(ctx context.Context, dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport(ctx)
|
||||
})
|
||||
if uc.rebootstrap.CompareAndSwap(true, false) {
|
||||
uc.SetupTransport(ctx)
|
||||
}
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
return uc.transport
|
||||
case IpStackSplit:
|
||||
switch dnsType {
|
||||
case dns.TypeA:
|
||||
return uc.transport4
|
||||
default:
|
||||
return uc.transport6
|
||||
}
|
||||
}
|
||||
return uc.transport
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uint16) string {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth:
|
||||
return pick(uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
return pick(uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
return pick(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
switch dnsType {
|
||||
case dns.TypeA:
|
||||
return pick(uc.bootstrapIPs4)
|
||||
default:
|
||||
if HasIPv6(ctx) {
|
||||
return pick(uc.bootstrapIPs6)
|
||||
}
|
||||
return pick(uc.bootstrapIPs4)
|
||||
}
|
||||
}
|
||||
return pick(uc.bootstrapIPs)
|
||||
uc.ensureSetupTransport(ctx)
|
||||
return transportByIpStack(uc.IPStack, dnsType, uc.transport, uc.transport4, uc.transport6)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) netForDNSType(ctx context.Context, dnsType uint16) (string, string) {
|
||||
@@ -998,10 +954,6 @@ func ResolverTypeFromEndpoint(endpoint string) string {
|
||||
return ResolverTypeDOT
|
||||
}
|
||||
|
||||
func pick(s []string) string {
|
||||
return s[rand.Intn(len(s))]
|
||||
}
|
||||
|
||||
// upstreamUID generates an unique identifier for an upstream.
|
||||
func upstreamUID(ctx context.Context) string {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
@@ -1038,3 +990,18 @@ func bootstrapIPsFromControlDDomain(domain string) []string {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func transportByIpStack[T any](ipStack string, dnsType uint16, transport, transport4, transport6 T) T {
|
||||
switch ipStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
return transport
|
||||
case IpStackSplit:
|
||||
switch dnsType {
|
||||
case dns.TypeA:
|
||||
return transport4
|
||||
default:
|
||||
return transport6
|
||||
}
|
||||
}
|
||||
return transport
|
||||
}
|
||||
|
||||
@@ -9,31 +9,14 @@ import (
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport(ctx context.Context) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
|
||||
if HasIPv6(ctx) {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
}
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper {
|
||||
if uc.Type != ResolverTypeDOH3 {
|
||||
return nil
|
||||
}
|
||||
rt := &http3.Transport{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
logger := LoggerFromCtx(ctx)
|
||||
@@ -72,45 +55,13 @@ func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport(ctx)
|
||||
})
|
||||
if uc.rebootstrap.CompareAndSwap(true, false) {
|
||||
uc.SetupTransport(ctx)
|
||||
}
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
return uc.http3RoundTripper
|
||||
case IpStackSplit:
|
||||
switch dnsType {
|
||||
case dns.TypeA:
|
||||
return uc.http3RoundTripper4
|
||||
default:
|
||||
return uc.http3RoundTripper6
|
||||
}
|
||||
}
|
||||
return uc.http3RoundTripper
|
||||
uc.ensureSetupTransport(ctx)
|
||||
return transportByIpStack(uc.IPStack, dnsType, uc.http3RoundTripper, uc.http3RoundTripper4, uc.http3RoundTripper6)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doqConnPool {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport(ctx)
|
||||
})
|
||||
if uc.rebootstrap.CompareAndSwap(true, false) {
|
||||
uc.SetupTransport(ctx)
|
||||
}
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
return uc.doqConnPool
|
||||
case IpStackSplit:
|
||||
switch dnsType {
|
||||
case dns.TypeA:
|
||||
return uc.doqConnPool4
|
||||
default:
|
||||
return uc.doqConnPool6
|
||||
}
|
||||
}
|
||||
return uc.doqConnPool
|
||||
uc.ensureSetupTransport(ctx)
|
||||
return transportByIpStack(uc.IPStack, dnsType, uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6)
|
||||
}
|
||||
|
||||
// Putting the code for quic parallel dialer here:
|
||||
@@ -182,5 +133,8 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *doqConnPool {
|
||||
if uc.Type != ResolverTypeDOQ {
|
||||
return nil
|
||||
}
|
||||
return newDOQConnPool(ctx, uc, addrs)
|
||||
}
|
||||
|
||||
4
doq.go
4
doq.go
@@ -63,7 +63,7 @@ type doqConn struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newDOQConnPool(ctx context.Context, uc *UpstreamConfig, addrs []string) *doqConnPool {
|
||||
func newDOQConnPool(_ context.Context, uc *UpstreamConfig, addrs []string) *doqConnPool {
|
||||
_, port, _ := net.SplitHostPort(uc.Endpoint)
|
||||
if port == "" {
|
||||
port = "853"
|
||||
@@ -96,7 +96,7 @@ func newDOQConnPool(ctx context.Context, uc *UpstreamConfig, addrs []string) *do
|
||||
// Resolve performs a DNS query using a pooled QUIC connection.
|
||||
func (p *doqConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
// Retry logic for io.EOF errors (as per original implementation)
|
||||
for i := 0; i < 5; i++ {
|
||||
for range 5 {
|
||||
answer, err := p.doResolve(ctx, msg)
|
||||
if err == io.EOF {
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user