perf(doq): implement connection pooling for improved performance

Implement QUIC connection pooling for the DoQ resolver to match DoH3
performance. Previously, DoQ created a new QUIC connection for every
DNS query, incurring significant handshake overhead. Now connections are
reused across queries, eliminating this overhead for subsequent requests.

The implementation follows the same pattern as DoH3, using parallel dialing
and connection pooling to achieve comparable performance characteristics.
This commit is contained in:
Cuong Manh Le
2026-01-06 14:46:00 +07:00
committed by Cuong Manh Le
parent aacba92698
commit e4e655414c
3 changed files with 286 additions and 31 deletions
+30 -2
View File
@@ -282,6 +282,9 @@ type UpstreamConfig struct {
http3RoundTripper http.RoundTripper http3RoundTripper http.RoundTripper
http3RoundTripper4 http.RoundTripper http3RoundTripper4 http.RoundTripper
http3RoundTripper6 http.RoundTripper http3RoundTripper6 http.RoundTripper
doqConnPool *doqConnPool
doqConnPool4 *doqConnPool
doqConnPool6 *doqConnPool
certPool *x509.CertPool certPool *x509.CertPool
u *url.URL u *url.URL
fallbackOnce sync.Once fallbackOnce sync.Once
@@ -504,7 +507,7 @@ func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) {
// ReBootstrap re-setup the bootstrap IP and the transport. // ReBootstrap re-setup the bootstrap IP and the transport.
func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) {
switch uc.Type { switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3: case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ:
default: default:
return return
} }
@@ -525,6 +528,27 @@ func (uc *UpstreamConfig) SetupTransport(ctx context.Context) {
uc.setupDOHTransport(ctx) uc.setupDOHTransport(ctx)
case ResolverTypeDOH3: case ResolverTypeDOH3:
uc.setupDOH3Transport(ctx) uc.setupDOH3Transport(ctx)
case ResolverTypeDOQ:
uc.setupDOQTransport(ctx)
}
}
// setupDOQTransport (re)creates the DoQ connection pool(s) for this upstream
// according to its configured IP stack. Each pool is built from the matching
// list of bootstrap IPs.
func (uc *UpstreamConfig) setupDOQTransport(ctx context.Context) {
switch uc.IPStack {
case IpStackBoth, "":
// Dual stack (also the default when unset): one pool over all bootstrap IPs.
uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs)
case IpStackV4:
// IPv4 only: the single pool dials the IPv4 bootstrap addresses.
uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs4)
case IpStackV6:
// IPv6 only: the single pool dials the IPv6 bootstrap addresses.
uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs6)
case IpStackSplit:
// Split stack: keep a dedicated pool per address family.
uc.doqConnPool4 = uc.newDOQConnPool(ctx, uc.bootstrapIPs4)
if HasIPv6(ctx) {
uc.doqConnPool6 = uc.newDOQConnPool(ctx, uc.bootstrapIPs6)
} else {
// HasIPv6 reported false: let the IPv6 slot share the IPv4 pool.
uc.doqConnPool6 = uc.doqConnPool4
}
// The general pool over all bootstrap IPs is populated as well.
uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs)
} }
} }
@@ -612,7 +636,7 @@ func (uc *UpstreamConfig) ErrorPing(ctx context.Context) error {
func (uc *UpstreamConfig) ping(ctx context.Context) error { func (uc *UpstreamConfig) ping(ctx context.Context) error {
switch uc.Type { switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3: case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ:
default: default:
return nil return nil
} }
@@ -646,6 +670,10 @@ func (uc *UpstreamConfig) ping(ctx context.Context) error {
if err := ping(uc.doh3Transport(ctx, typ)); err != nil { if err := ping(uc.doh3Transport(ctx, typ)); err != nil {
return err return err
} }
case ResolverTypeDOQ:
// For DoQ, we just ensure transport is set up by calling doqTransport
// DoQ doesn't use HTTP, so we can't ping it the same way
_ = uc.doqTransport(ctx, typ)
} }
} }
+25
View File
@@ -92,6 +92,27 @@ func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) htt
return uc.http3RoundTripper return uc.http3RoundTripper
} }
// doqTransport returns the DoQ connection pool to use for a query of the
// given DNS type.
//
// The transport is set up lazily on first use, and set up again whenever a
// re-bootstrap has been requested. With a split IP stack, A queries use the
// IPv4 pool and every other query type uses the IPv6 pool; all other stack
// settings use the single shared pool.
func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doqConnPool {
	setup := func() { uc.SetupTransport(ctx) }

	uc.transportOnce.Do(setup)
	if uc.rebootstrap.CompareAndSwap(true, false) {
		setup()
	}

	if uc.IPStack != IpStackSplit {
		return uc.doqConnPool
	}
	if dnsType == dns.TypeA {
		return uc.doqConnPool4
	}
	return uc.doqConnPool6
}
// Putting the code for quic parallel dialer here: // Putting the code for quic parallel dialer here:
// //
// - quic dialer is different with net.Dialer // - quic dialer is different with net.Dialer
@@ -159,3 +180,7 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t
return nil, errors.Join(errs...) return nil, errors.Join(errs...)
} }
// newDOQConnPool builds a DoQ connection pool bound to this upstream that
// dials the given bootstrap addresses. It is a method-form convenience around
// the package-level newDOQConnPool constructor.
func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *doqConnPool {
	pool := newDOQConnPool(ctx, uc, addrs)
	return pool
}
+231 -29
View File
@@ -5,8 +5,11 @@ package ctrld
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"io" "io"
"net" "net"
"runtime"
"sync"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -21,22 +24,19 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
logger := LoggerFromCtx(ctx) logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "DoQ resolver query started") Log(ctx, logger.Debug(), "DoQ resolver query started")
endpoint := r.uc.Endpoint // Get the appropriate connection pool based on DNS type and IP stack
tlsConfig := &tls.Config{NextProtos: []string{"doq"}} dnsTyp := uint16(0)
ip := r.uc.BootstrapIP if msg != nil && len(msg.Question) > 0 {
if ip == "" { dnsTyp = msg.Question[0].Qtype
dnsTyp := uint16(0)
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
ip = r.uc.bootstrapIPForDNSType(ctx, dnsTyp)
} }
tlsConfig.ServerName = r.uc.Domain
_, port, _ := net.SplitHostPort(endpoint)
endpoint = net.JoinHostPort(ip, port)
Log(ctx, logger.Debug(), "Sending DoQ request to: %s", endpoint) pool := r.uc.doqTransport(ctx, dnsTyp)
answer, err := resolve(ctx, msg, endpoint, tlsConfig) if pool == nil {
Log(ctx, logger.Error(), "DoQ connection pool is not available")
return nil, errors.New("DoQ connection pool is not available")
}
answer, err := pool.Resolve(ctx, msg)
if err != nil { if err != nil {
Log(ctx, logger.Error().Err(err), "DoQ request failed") Log(ctx, logger.Error().Err(err), "DoQ request failed")
} else { } else {
@@ -45,11 +45,59 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
return answer, err return answer, err
} }
func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) { // doqConnPool manages a pool of QUIC connections for DoQ queries.
// DoQ quic-go server returns io.EOF error after running for a long time, type doqConnPool struct {
// even for a good stream. So retrying the query for 5 times before giving up. uc *UpstreamConfig
addrs []string
port string
tlsConfig *tls.Config
mu sync.RWMutex
conns map[string]*doqConn
closed bool
}
type doqConn struct {
conn *quic.Conn
lastUsed time.Time
refCount int
mu sync.Mutex
}
func newDOQConnPool(ctx context.Context, uc *UpstreamConfig, addrs []string) *doqConnPool {
_, port, _ := net.SplitHostPort(uc.Endpoint)
if port == "" {
port = "853"
}
tlsConfig := &tls.Config{
NextProtos: []string{"doq"},
RootCAs: uc.certPool,
ServerName: uc.Domain,
}
pool := &doqConnPool{
uc: uc,
addrs: addrs,
port: port,
tlsConfig: tlsConfig,
conns: make(map[string]*doqConn),
}
// Use SetFinalizer here because we need to call a method on the pool itself.
// AddCleanup would require passing the pool as arg (which panics) or capturing
// it in a closure (which prevents GC). SetFinalizer is appropriate for this case.
runtime.SetFinalizer(pool, func(p *doqConnPool) {
p.CloseIdleConnections()
})
return pool
}
// Resolve performs a DNS query using a pooled QUIC connection.
func (p *doqConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
// Retry logic for io.EOF errors (as per original implementation)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
answer, err := doResolve(ctx, msg, endpoint, tlsConfig) answer, err := p.doResolve(ctx, msg)
if err == io.EOF { if err == io.EOF {
continue continue
} }
@@ -58,57 +106,72 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.
} }
return answer, nil return answer, nil
} }
return nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(quic.InternalError), ErrorMessage: quic.InternalError.Message()} return nil, &quic.ApplicationError{
ErrorCode: quic.ApplicationErrorCode(quic.InternalError),
ErrorMessage: quic.InternalError.Message(),
}
} }
func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) { func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
session, err := quic.DialAddr(ctx, endpoint, tlsConfig, nil) conn, addr, err := p.getConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer session.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
// Pack the DNS message
msgBytes, err := msg.Pack() msgBytes, err := msg.Pack()
if err != nil { if err != nil {
p.putConn(addr, conn, false)
return nil, err return nil, err
} }
stream, err := session.OpenStream() // Open a new stream for this query
stream, err := conn.OpenStream()
if err != nil { if err != nil {
p.putConn(addr, conn, false)
return nil, err return nil, err
} }
// Set deadline
deadline, ok := ctx.Deadline() deadline, ok := ctx.Deadline()
if !ok { if !ok {
deadline = time.Now().Add(5 * time.Second) deadline = time.Now().Add(5 * time.Second)
} }
_ = stream.SetDeadline(deadline) _ = stream.SetDeadline(deadline)
// Write message length (2 bytes) followed by message
var msgLen = uint16(len(msgBytes)) var msgLen = uint16(len(msgBytes))
var msgLenBytes = []byte{byte(msgLen >> 8), byte(msgLen & 0xFF)} var msgLenBytes = []byte{byte(msgLen >> 8), byte(msgLen & 0xFF)}
if _, err := stream.Write(msgLenBytes); err != nil { if _, err := stream.Write(msgLenBytes); err != nil {
stream.Close()
p.putConn(addr, conn, false)
return nil, err return nil, err
} }
if _, err := stream.Write(msgBytes); err != nil { if _, err := stream.Write(msgBytes); err != nil {
stream.Close()
p.putConn(addr, conn, false)
return nil, err return nil, err
} }
// Read response
buf, err := io.ReadAll(stream) buf, err := io.ReadAll(stream)
stream.Close()
// Return connection to pool (mark as potentially bad if error occurred)
isGood := err == nil && len(buf) > 0
p.putConn(addr, conn, isGood)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_ = stream.Close() // io.ReadAll hides io.EOF error, so check for empty buffer
// io.ReadAll hide the io.EOF error returned by quic-go server.
// Once we figure out why quic-go server sends io.EOF after running
// for a long time, we can have a better way to handle this. For now,
// make sure io.EOF error returned, so the caller can handle it cleanly.
if len(buf) == 0 { if len(buf) == 0 {
return nil, io.EOF return nil, io.EOF
} }
// Unpack DNS response (skip 2-byte length prefix)
answer := new(dns.Msg) answer := new(dns.Msg)
if err := answer.Unpack(buf[2:]); err != nil { if err := answer.Unpack(buf[2:]); err != nil {
return nil, err return nil, err
@@ -116,3 +179,142 @@ func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tl
answer.SetReply(msg) answer.SetReply(msg)
return answer, nil return answer, nil
} }
// getConn hands out a QUIC connection for a single query, together with the
// pool key (remote address) that the caller must pass back to putConn.
//
// An idle pooled connection (refCount == 0) that is still alive is reused;
// otherwise a new connection is dialed. It returns io.EOF once the pool has
// been closed.
//
// A freshly dialed connection is only registered in the pool when its address
// slot is free. If another in-use connection already occupies that slot, the
// new connection is handed out untracked and putConn closes it when the query
// finishes. This avoids overwriting (and thereby leaking) the connection that
// is already registered under the same address.
func (p *doqConnPool) getConn(ctx context.Context) (*quic.Conn, string, error) {
	p.mu.Lock()
	if p.closed {
		p.mu.Unlock()
		return nil, "", io.EOF
	}
	conn, addr := p.takeIdleConnLocked()
	p.mu.Unlock()
	if conn != nil {
		return conn, addr, nil
	}

	// Dial without holding the pool lock: a QUIC handshake can take a while
	// and must not stall other queries taking or returning connections.
	addr, conn, err := p.dialConn(ctx)
	if err != nil {
		return nil, "", err
	}

	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		// The pool was shut down while dialing; do not hand out the connection.
		conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
		return nil, "", io.EOF
	}
	if _, taken := p.conns[addr]; !taken {
		p.conns[addr] = &doqConn{
			conn:     conn,
			lastUsed: time.Now(),
			refCount: 1,
		}
	}
	return conn, addr, nil
}

// takeIdleConnLocked returns a live pooled connection that no query is
// currently using and marks it as in use, or (nil, "") when there is none.
// Dead idle connections found along the way are dropped from the pool.
// The caller must hold p.mu.
func (p *doqConnPool) takeIdleConnLocked() (*quic.Conn, string) {
	for addr, entry := range p.conns {
		entry.mu.Lock()
		if entry.refCount != 0 || entry.conn == nil {
			entry.mu.Unlock()
			continue
		}
		if entry.conn.Context().Err() != nil {
			// The connection's context is done, i.e. it has been closed.
			entry.mu.Unlock()
			delete(p.conns, addr)
			continue
		}
		entry.refCount++
		entry.lastUsed = time.Now()
		conn := entry.conn
		entry.mu.Unlock()
		return conn, addr
	}
	return nil, ""
}

// putConn returns a connection obtained from getConn.
//
// isGood reports whether the query on this connection completed cleanly. A
// bad or already-closed connection is removed from the pool and closed. A
// connection that is not the one registered under addr (untracked, or its
// entry is gone) is closed as well, without touching whichever entry
// currently occupies that slot.
func (p *doqConnPool) putConn(addr string, conn *quic.Conn, isGood bool) {
	p.mu.Lock()
	defer p.mu.Unlock()

	entry, ok := p.conns[addr]
	if !ok || entry.conn != conn {
		// Not pooled: nobody else can reuse it, so release it now instead of
		// leaking it. Closing an already-closed connection is harmless.
		conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
		return
	}

	entry.mu.Lock()
	defer entry.mu.Unlock()

	if entry.refCount > 0 {
		entry.refCount--
	}

	// If connection is bad or closed, remove it from pool
	if !isGood || conn.Context().Err() != nil {
		delete(p.conns, addr)
		conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
		return
	}

	entry.lastUsed = time.Now()
}
// dialConn establishes a new QUIC connection to the upstream. It returns the
// address used as the pool key, the connection, and any dial error.
//
// When the upstream has an explicit BootstrapIP, only that address is dialed.
// Otherwise all pool addresses are dialed in parallel (like DoH3) and the
// remote address of the resulting connection is reported.
func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) {
	logger := LoggerFromCtx(ctx)

	// If we have a bootstrap IP, use it directly
	if p.uc.BootstrapIP != "" {
		addr := net.JoinHostPort(p.uc.BootstrapIP, p.port)
		Log(ctx, logger.Debug(), "Sending DoQ request to: %s", addr)
		// DialAddrEarly creates its own UDP socket and, per the quic-go
		// documentation, closes it when the QUIC connection is closed. A
		// socket passed to DialEarly is never closed by quic-go, so every
		// successful dial would leak one UDP socket.
		conn, err := quic.DialAddrEarly(ctx, addr, p.tlsConfig, nil)
		if err != nil {
			return "", nil, err
		}
		return addr, conn, nil
	}

	// NOTE(review): fail explicitly when there is nothing to dial instead of
	// relying on how the parallel dialer behaves for an empty address list.
	if len(p.addrs) == 0 {
		return "", nil, errors.New("no bootstrap IP available for DoQ upstream")
	}

	// Use parallel dialing like DoH3
	dialAddrs := make([]string, 0, len(p.addrs))
	for _, ip := range p.addrs {
		dialAddrs = append(dialAddrs, net.JoinHostPort(ip, p.port))
	}

	pd := &quicParallelDialer{}
	conn, err := pd.Dial(ctx, dialAddrs, p.tlsConfig, nil)
	if err != nil {
		return "", nil, err
	}

	addr := conn.RemoteAddr().String()
	Log(ctx, logger.Debug(), "Sending DoQ request to: %s", addr)
	return addr, conn, nil
}
// CloseIdleConnections shuts the pool down.
//
// Despite its name, it closes every pooled connection regardless of refCount
// (in-use ones included), empties the pool, and marks it closed so that later
// getConn calls fail. The pool's finalizer calls it so that no QUIC
// connection outlives the pool.
func (p *doqConnPool) CloseIdleConnections() {
	p.mu.Lock()
	defer p.mu.Unlock()

	p.closed = true

	noErr := quic.ApplicationErrorCode(quic.NoError)
	for key, entry := range p.conns {
		entry.mu.Lock()
		if c := entry.conn; c != nil {
			// Close unconditionally, even while a query still holds the
			// connection, so nothing is left open once the pool is gone.
			c.CloseWithError(noErr, "")
		}
		entry.mu.Unlock()
		delete(p.conns, key)
	}
}