perf(dot): implement connection pooling for improved performance

Implement TCP/TLS connection pooling for DoT resolver to match DoQ
performance. Previously, DoT created a new TCP/TLS connection for every
DNS query, incurring significant TLS handshake overhead. Now connections are
reused across queries, eliminating this overhead for subsequent requests.

The implementation follows the same pattern as DoQ, using parallel dialing
and connection pooling to achieve comparable performance characteristics.
This commit is contained in:
Cuong Manh Le
2026-01-08 20:00:15 +07:00
committed by Cuong Manh Le
parent 8dd90cb354
commit f859c52916
6 changed files with 339 additions and 24 deletions

View File

@@ -288,6 +288,9 @@ type UpstreamConfig struct {
doqConnPool *doqConnPool
doqConnPool4 *doqConnPool
doqConnPool6 *doqConnPool
dotClientPool *dotConnPool
dotClientPool4 *dotConnPool
dotClientPool6 *dotConnPool
certPool *x509.CertPool
u *url.URL
fallbackOnce sync.Once
@@ -510,7 +513,7 @@ func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) {
// ReBootstrap re-setup the bootstrap IP and the transport.
func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) {
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ:
case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ, ResolverTypeDOT:
default:
return
}
@@ -524,10 +527,10 @@ func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) {
}
// SetupTransport initializes the network transport used to connect to upstream servers.
// For now, DoH/DoH3/DoQ upstreams are supported.
// For now, DoH/DoH3/DoQ/DoT upstreams are supported.
func (uc *UpstreamConfig) SetupTransport(ctx context.Context) {
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ:
case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ, ResolverTypeDOT:
default:
return
}
@@ -541,18 +544,22 @@ func (uc *UpstreamConfig) SetupTransport(ctx context.Context) {
uc.transport = uc.newDOHTransport(ctx, ips)
uc.http3RoundTripper = uc.newDOH3Transport(ctx, ips)
uc.doqConnPool = uc.newDOQConnPool(ctx, ips)
uc.dotClientPool = uc.newDOTClientPool(ctx, ips)
if uc.IPStack == IpStackSplit {
uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
uc.doqConnPool4 = uc.newDOQConnPool(ctx, uc.bootstrapIPs4)
uc.dotClientPool4 = uc.newDOTClientPool(ctx, uc.bootstrapIPs4)
if HasIPv6(ctx) {
uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
uc.doqConnPool6 = uc.newDOQConnPool(ctx, uc.bootstrapIPs6)
uc.dotClientPool6 = uc.newDOTClientPool(ctx, uc.bootstrapIPs6)
} else {
uc.transport6 = uc.transport4
uc.http3RoundTripper6 = uc.http3RoundTripper4
uc.doqConnPool6 = uc.doqConnPool4
uc.dotClientPool6 = uc.dotClientPool4
}
}
}
@@ -674,6 +681,10 @@ func (uc *UpstreamConfig) ping(ctx context.Context) error {
// For DoQ, we just ensure transport is set up by calling doqTransport
// DoQ doesn't use HTTP, so we can't ping it the same way
_ = uc.doqTransport(ctx, typ)
case ResolverTypeDOT:
// For DoT, we just ensure transport is set up by calling dotTransport
// DoT doesn't use HTTP, so we can't ping it the same way
_ = uc.dotTransport(ctx, typ)
}
}

View File

@@ -64,6 +64,11 @@ func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doq
return transportByIpStack(uc.IPStack, dnsType, uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6)
}
// dotTransport makes sure the transports have been set up, then picks the
// DoT connection pool matching the upstream's IP stack and the DNS query type.
func (uc *UpstreamConfig) dotTransport(ctx context.Context, dnsType uint16) *dotConnPool {
	uc.ensureSetupTransport(ctx)

	// Read the pools only after setup has run, so freshly created pools are seen.
	shared, v4, v6 := uc.dotClientPool, uc.dotClientPool4, uc.dotClientPool6
	return transportByIpStack(uc.IPStack, dnsType, shared, v4, v6)
}
// Putting the code for quic parallel dialer here:
//
// - quic dialer is different from net.Dialer
@@ -138,3 +143,10 @@ func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *d
}
return newDOQConnPool(ctx, uc, addrs)
}
// newDOTClientPool returns a DoT connection pool for addrs, or nil when the
// upstream is not a DoT resolver.
func (uc *UpstreamConfig) newDOTClientPool(ctx context.Context, addrs []string) *dotConnPool {
	if uc.Type == ResolverTypeDOT {
		return newDOTClientPool(ctx, uc, addrs)
	}
	return nil
}

3
doh.go
View File

@@ -88,6 +88,9 @@ type dohResolver struct {
// Resolve performs DNS query with given DNS message using DOH protocol.
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
if err := validateMsg(msg); err != nil {
return nil, err
}
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "DoH resolver query started")

3
doq.go
View File

@@ -21,6 +21,9 @@ type doqResolver struct {
}
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
if err := validateMsg(msg); err != nil {
return nil, err
}
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "DoQ resolver query started")

309
dot.go
View File

@@ -3,7 +3,12 @@ package ctrld
import (
	"context"
	"crypto/tls"
	"errors"
	"io"
	"net"
	"runtime"
	"strconv"
	"sync"
	"time"

	"github.com/miekg/dns"
)
@@ -13,39 +18,301 @@ type dotResolver struct {
}
func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
if err := validateMsg(msg); err != nil {
return nil, err
}
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "DoT resolver query started")
// The dialer is used to prevent bootstrapping cycle.
// If r.endpoint is set to dns.controld.dev, we need to resolve
// dns.controld.dev first. By using a dialer with custom resolver,
// we ensure that we can always resolve the bootstrap domain
// regardless of the machine DNS status.
dialer := newDialer(net.JoinHostPort(controldPublicDns, "53"))
dnsTyp := uint16(0)
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
tcpNet, _ := r.uc.netForDNSType(ctx, dnsTyp)
dnsClient := &dns.Client{
Net: tcpNet,
Dialer: dialer,
TLSConfig: &tls.Config{RootCAs: r.uc.certPool},
}
endpoint := r.uc.Endpoint
if r.uc.BootstrapIP != "" {
dnsClient.TLSConfig.ServerName = r.uc.Domain
dnsClient.Net = "tcp-tls"
_, port, _ := net.SplitHostPort(endpoint)
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
pool := r.uc.dotTransport(ctx, dnsTyp)
if pool == nil {
Log(ctx, logger.Error(), "DoT client pool is not available")
return nil, errors.New("DoT client pool is not available")
}
Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint)
answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint)
answer, err := pool.Resolve(ctx, msg)
if err != nil {
Log(ctx, logger.Error().Err(err), "DoT request failed")
} else {
Log(ctx, logger.Debug(), "DoT resolver query successful")
}
return answer, wrapCertificateVerificationError(err)
return answer, err
}
// dotConnPool manages a pool of TCP/TLS connections for DoT queries.
//
// Connections are dialed on demand by getConn, handed back through putConn,
// and all closed by CloseIdleConnections, after which the pool refuses
// further use.
type dotConnPool struct {
// uc is the upstream this pool connects to; its BootstrapIP, Domain and
// Endpoint are read when dialing.
uc *UpstreamConfig
// addrs are the IP addresses dialed in parallel when no BootstrapIP is set.
addrs []string
// port is the upstream port, taken from the endpoint ("853" when absent).
port string
// tlsConfig is the base TLS client configuration used for handshakes.
tlsConfig *tls.Config
// dialer opens the underlying TCP connections.
dialer *net.Dialer
// mu guards conns and closed.
mu sync.RWMutex
// conns holds the pooled connections, keyed by the address string that
// getConn hands out and putConn receives back.
conns map[string]*dotConn
// closed is set once CloseIdleConnections has run.
closed bool
}
// dotConn is a single pooled connection together with its bookkeeping.
type dotConn struct {
// conn is the established connection.
conn net.Conn
// lastUsed is refreshed whenever the connection is handed out or returned.
lastUsed time.Time
// refCount is the number of current users; a connection is only handed
// out for reuse while it is zero.
refCount int
// mu guards the fields above.
mu sync.Mutex
}
// newDOTClientPool creates an empty DoT connection pool for uc that will dial
// the given addrs. The port comes from uc.Endpoint and defaults to "853".
func newDOTClientPool(_ context.Context, uc *UpstreamConfig, addrs []string) *dotConnPool {
	port := "853"
	if _, endpointPort, _ := net.SplitHostPort(uc.Endpoint); endpointPort != "" {
		port = endpointPort
	}

	cfg := &tls.Config{RootCAs: uc.certPool}
	if uc.BootstrapIP != "" {
		// We dial an IP address, so the certificate must be checked against the domain.
		cfg.ServerName = uc.Domain
	}

	pool := &dotConnPool{
		uc:        uc,
		addrs:     addrs,
		port:      port,
		tlsConfig: cfg,
		// A dialer with its own resolver avoids a bootstrapping cycle: if the
		// endpoint is dns.controld.dev, that name has to be resolved first, and
		// this must work regardless of the machine DNS status.
		dialer: newDialer(net.JoinHostPort(controldPublicDns, "53")),
		conns:  make(map[string]*dotConn),
	}

	// SetFinalizer is used rather than AddCleanup because the cleanup is a method
	// on the pool itself: AddCleanup would need the pool as its argument (which
	// panics) or captured in a closure (which prevents GC).
	runtime.SetFinalizer(pool, (*dotConnPool).CloseIdleConnections)
	return pool
}
// Resolve performs a DNS query using a pooled TCP/TLS connection.
//
// A pooled connection may have been closed by the server while it sat idle,
// which only shows up when the next query is written to it. When the exchange
// fails on an obtained connection and ctx is still live, the query is therefore
// retried once; the failed connection has been dropped from the pool by then,
// so the retry uses a different one. Failures to obtain a connection at all
// are not retried.
//
// It returns an error for a nil msg. Any other error is passed through
// wrapCertificateVerificationError.
func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
	if msg == nil {
		return nil, errors.New("nil DNS message")
	}

	answer, exchanged, err := p.exchangeOnce(ctx, msg)
	if err != nil && exchanged && ctx.Err() == nil {
		answer, _, err = p.exchangeOnce(ctx, msg)
	}
	if err != nil {
		return nil, wrapCertificateVerificationError(err)
	}
	return answer, nil
}

// exchangeOnce runs a single query on one connection taken from the pool and
// gives the connection back afterwards. The boolean result reports whether a
// connection was obtained, i.e. whether the error (if any) came from the
// exchange itself rather than from getting a connection.
func (p *dotConnPool) exchangeOnce(ctx context.Context, msg *dns.Msg) (*dns.Msg, bool, error) {
	conn, addr, err := p.getConn(ctx)
	if err != nil {
		return nil, false, err
	}

	// Bound the exchange even when the caller supplied no deadline.
	deadline, ok := ctx.Deadline()
	if !ok {
		deadline = time.Now().Add(5 * time.Second)
	}
	// Best effort: if the deadline cannot be set the connection is already
	// unusable and the exchange below reports the failure.
	_ = conn.SetDeadline(deadline)

	client := dns.Client{Net: "tcp-tls"}
	answer, _, err := client.ExchangeWithConnContext(ctx, msg, &dns.Conn{Conn: conn})
	// A connection that failed an exchange is dropped from the pool.
	p.putConn(addr, conn, err == nil)
	if err != nil {
		return nil, true, err
	}
	return answer, true, nil
}
// getConn gets a TCP/TLS connection from the pool or creates a new one.
//
// It returns the connection together with the key it is tracked under; that
// key must be passed back to putConn. The key starts out as the dialed
// endpoint and gets a numeric suffix when another connection to the same
// endpoint is already tracked, so that concurrent connections to one endpoint
// never overwrite each other's pool entry.
//
// It returns io.EOF when the pool has been closed, or the dial error when no
// connection could be established.
func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, string, error) {
	p.mu.Lock()
	if p.closed {
		p.mu.Unlock()
		return nil, "", io.EOF
	}
	// Try to reuse a connection nobody is currently using.
	for key, pooled := range p.conns {
		pooled.mu.Lock()
		idle := pooled.refCount == 0 && pooled.conn != nil
		if idle {
			pooled.refCount++
			pooled.lastUsed = time.Now()
		}
		conn := pooled.conn
		pooled.mu.Unlock()
		if idle {
			p.mu.Unlock()
			return conn, key, nil
		}
	}
	// Dial without holding the pool lock: a TCP dial plus TLS handshake can
	// take long, and holding the lock would stall every other query as well
	// as every putConn.
	p.mu.Unlock()

	addr, conn, err := p.dialConn(ctx)
	if err != nil {
		return nil, "", err
	}

	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		// The pool was closed while dialing; do not leak the new connection.
		conn.Close()
		return nil, "", io.EOF
	}

	// Pick a key that is not in use yet so an existing entry for the same
	// endpoint (which is busy, otherwise it would have been reused) is kept.
	key := addr
	for n := 2; ; n++ {
		if _, taken := p.conns[key]; !taken {
			break
		}
		key = addr + "#" + strconv.Itoa(n)
	}
	p.conns[key] = &dotConn{
		conn:     conn,
		lastUsed: time.Now(),
		refCount: 1,
	}
	return conn, key, nil
}
// putConn returns a connection to the pool.
//
// addr is the key getConn returned for conn. isGood reports whether the last
// exchange on conn succeeded; a bad connection is removed from the pool and
// closed. A connection that the pool does not track under addr (the entry is
// gone, or it now belongs to a different connection) can never be handed out
// again, so it is closed and the tracked entry is left untouched.
func (p *dotConnPool) putConn(addr string, conn net.Conn, isGood bool) {
	p.mu.Lock()
	defer p.mu.Unlock()

	entry, ok := p.conns[addr]
	if !ok {
		if conn != nil {
			conn.Close()
		}
		return
	}

	entry.mu.Lock()
	defer entry.mu.Unlock()

	// The entry belongs to another connection: adjusting its reference count
	// or deleting it would corrupt the state of a connection that may be in use.
	if entry.conn != conn {
		if conn != nil {
			conn.Close()
		}
		return
	}

	if entry.refCount > 0 {
		entry.refCount--
	}

	// If connection is bad, remove it from pool
	if !isGood {
		delete(p.conns, addr)
		if conn != nil {
			conn.Close()
		}
		return
	}
	entry.lastUsed = time.Now()
}
// dialConn creates a new TCP/TLS connection.
//
// The target is chosen in this order:
//  1. the upstream's BootstrapIP, when set;
//  2. the pool's addrs, dialed in parallel, first successful handshake wins;
//  3. the upstream's Endpoint itself.
//
// It returns the endpoint that was connected to and the established
// connection, or the dial/handshake error (all of them joined for the
// parallel case, or ctx.Err() when ctx ends first).
func (p *dotConnPool) dialConn(ctx context.Context) (string, net.Conn, error) {
	logger := LoggerFromCtx(ctx)

	if p.uc.BootstrapIP != "" {
		endpoint := net.JoinHostPort(p.uc.BootstrapIP, p.port)
		Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint)
		conn, err := p.dialTLS(ctx, endpoint, p.tlsConfig)
		if err != nil {
			return "", nil, err
		}
		return endpoint, conn, nil
	}

	// Try bootstrap IPs in parallel
	if len(p.addrs) > 0 {
		endpoint, conn, err := p.dialParallel(ctx)
		if err != nil {
			return "", nil, err
		}
		Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint)
		return endpoint, conn, nil
	}

	// Fallback to endpoint resolution. The endpoint may come without a port,
	// in which case the pool's port is used.
	host, _, err := net.SplitHostPort(p.uc.Endpoint)
	if err != nil {
		host = p.uc.Endpoint
	}
	endpoint := net.JoinHostPort(host, p.port)

	// tls.Client, unlike tls.Dial, does not derive the server name from the
	// dialed address, and the handshake fails when it is empty.
	cfg := p.tlsConfig
	if cfg.ServerName == "" {
		cfg = cfg.Clone()
		cfg.ServerName = p.uc.Domain
		if cfg.ServerName == "" {
			cfg.ServerName = host
		}
	}

	Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint)
	conn, err := p.dialTLS(ctx, endpoint, cfg)
	if err != nil {
		return "", nil, err
	}
	return endpoint, conn, nil
}

// dialTLS opens a TCP connection to endpoint and completes a TLS handshake on
// it using cfg. The TCP connection is closed when the handshake fails.
func (p *dotConnPool) dialTLS(ctx context.Context, endpoint string, cfg *tls.Config) (net.Conn, error) {
	raw, err := p.dialer.DialContext(ctx, "tcp", endpoint)
	if err != nil {
		return nil, err
	}
	tlsConn := tls.Client(raw, cfg)
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		raw.Close()
		return nil, err
	}
	return tlsConn, nil
}

// dialParallel dials every address in p.addrs concurrently and returns the
// first connection whose TLS handshake succeeds. The remaining attempts are
// cancelled and any connection they still manage to establish is closed.
func (p *dotConnPool) dialParallel(ctx context.Context) (string, net.Conn, error) {
	type result struct {
		conn net.Conn
		addr string
		err  error
	}

	// Cancelling on return aborts attempts that are still in flight. It does
	// not affect the winner: its dial and handshake have already completed.
	dialCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	cfg := p.tlsConfig.Clone()
	cfg.ServerName = p.uc.Domain

	// Unbuffered on purpose: once nobody is receiving any more, a late
	// attempt can only take the Done branch below and close its connection
	// instead of parking it in a buffer forever.
	ch := make(chan result)
	for _, addr := range p.addrs {
		go func(addr string) {
			endpoint := net.JoinHostPort(addr, p.port)
			conn, err := p.dialTLS(dialCtx, endpoint, cfg)
			select {
			case ch <- result{conn: conn, addr: endpoint, err: err}:
			case <-dialCtx.Done():
				if conn != nil {
					conn.Close()
				}
			}
		}(addr)
	}

	errs := make([]error, 0, len(p.addrs))
	for range len(p.addrs) {
		select {
		case res := <-ch:
			if res.err == nil && res.conn != nil {
				return res.addr, res.conn, nil
			}
			if res.err != nil {
				errs = append(errs, res.err)
			}
		case <-ctx.Done():
			return "", nil, ctx.Err()
		}
	}
	return "", nil, errors.Join(errs...)
}
// CloseIdleConnections closes every connection held by the pool and marks the
// pool as closed. Calling it again is a no-op.
func (p *dotConnPool) CloseIdleConnections() {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.closed {
		return
	}
	p.closed = true

	for _, pooled := range p.conns {
		pooled.mu.Lock()
		if c := pooled.conn; c != nil {
			c.Close()
		}
		pooled.mu.Unlock()
	}
	// Drop all entries now that their connections are closed.
	clear(p.conns)
}

View File

@@ -267,6 +267,9 @@ const hotCacheTTL = time.Second
// for a short period (currently 1 second), reducing unnecessary traffic
// sent to upstreams.
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
if err := validateMsg(msg); err != nil {
return nil, err
}
if len(msg.Question) == 0 {
return nil, errors.New("no question found")
}
@@ -492,6 +495,9 @@ type legacyResolver struct {
}
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
if err := validateMsg(msg); err != nil {
return nil, err
}
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "Legacy resolver query started")
@@ -526,6 +532,9 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e
type dummyResolver struct{}
func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
if err := validateMsg(msg); err != nil {
return nil, err
}
ans := new(dns.Msg)
ans.SetReply(msg)
return ans, nil
@@ -749,3 +758,13 @@ func isLanAddr(addr netip.Addr) bool {
addr.IsLinkLocalUnicast() ||
tsaddr.CGNATRange().Contains(addr)
}
// validateMsg rejects DNS messages a resolver cannot work with: a nil message
// or one that carries no question. It returns nil for a usable message.
func validateMsg(msg *dns.Msg) error {
	switch {
	case msg == nil:
		return errors.New("nil DNS message")
	case len(msg.Question) == 0:
		return errors.New("no question found")
	}
	return nil
}