refactor(doq): simplify DoQ connection pool implementation

Replace the map-based pool and refCount bookkeeping with a channel-based
pool. Drop the closed state, per-connection address tracking, and extra
mutexes so the pool relies on the channel for concurrency and lifecycle,
matching the approach used in the DoT pool.
This commit is contained in:
Cuong Manh Le
2026-01-28 23:50:53 +07:00
committed by Cuong Manh Le
parent fbc6468ee3
commit 4640a9f20a

147
doq.go
View File

@@ -9,7 +9,6 @@ import (
"io"
"net"
"runtime"
"sync"
"time"
"github.com/miekg/dns"
@@ -48,22 +47,19 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
return answer, err
}
// doqConnPool manages a pool of QUIC connections for DoQ queries.
const doqPoolSize = 16
// doqConnPool manages a pool of QUIC connections for DoQ queries using a buffered channel.
type doqConnPool struct {
uc *UpstreamConfig
addrs []string
port string
tlsConfig *tls.Config
mu sync.RWMutex
conns map[string]*doqConn
closed bool
conns chan *doqConn
}
type doqConn struct {
conn *quic.Conn
lastUsed time.Time
refCount int
mu sync.Mutex
conn *quic.Conn
}
func newDOQConnPool(_ context.Context, uc *UpstreamConfig, addrs []string) *doqConnPool {
@@ -83,7 +79,7 @@ func newDOQConnPool(_ context.Context, uc *UpstreamConfig, addrs []string) *doqC
addrs: addrs,
port: port,
tlsConfig: tlsConfig,
conns: make(map[string]*doqConn),
conns: make(chan *doqConn, doqPoolSize),
}
// Use SetFinalizer here because we need to call a method on the pool itself.
@@ -116,7 +112,7 @@ func (p *doqConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
}
func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
conn, addr, err := p.getConn(ctx)
conn, err := p.getConn(ctx)
if err != nil {
return nil, err
}
@@ -124,14 +120,14 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er
// Pack the DNS message
msgBytes, err := msg.Pack()
if err != nil {
p.putConn(addr, conn, false)
p.putConn(conn, false)
return nil, err
}
// Open a new stream for this query
stream, err := conn.OpenStream()
if err != nil {
p.putConn(addr, conn, false)
p.putConn(conn, false)
return nil, err
}
@@ -147,13 +143,13 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er
var msgLenBytes = []byte{byte(msgLen >> 8), byte(msgLen & 0xFF)}
if _, err := stream.Write(msgLenBytes); err != nil {
stream.Close()
p.putConn(addr, conn, false)
p.putConn(conn, false)
return nil, err
}
if _, err := stream.Write(msgBytes); err != nil {
stream.Close()
p.putConn(addr, conn, false)
p.putConn(conn, false)
return nil, err
}
@@ -163,7 +159,7 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er
// Return connection to pool (mark as potentially bad if error occurred)
isGood := err == nil && len(buf) > 0
p.putConn(addr, conn, isGood)
p.putConn(conn, isGood)
if err != nil {
return nil, err
@@ -184,79 +180,42 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er
}
// getConn gets a QUIC connection from the pool or creates a new one.
func (p *doqConnPool) getConn(ctx context.Context) (*quic.Conn, string, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return nil, "", io.EOF
}
// Try to reuse an existing connection
for addr, doqConn := range p.conns {
doqConn.mu.Lock()
if doqConn.refCount == 0 && doqConn.conn != nil {
// Check if connection is still alive
select {
case <-doqConn.conn.Context().Done():
// Connection is closed, remove it
doqConn.mu.Unlock()
delete(p.conns, addr)
continue
default:
// A connection is removed from the channel while in use; putConn returns it to the pool.
func (p *doqConnPool) getConn(ctx context.Context) (*quic.Conn, error) {
for {
select {
case dc := <-p.conns:
if dc.conn != nil && dc.conn.Context().Err() == nil {
return dc.conn, nil
}
doqConn.refCount++
doqConn.lastUsed = time.Now()
conn := doqConn.conn
doqConn.mu.Unlock()
return conn, addr, nil
if dc.conn != nil {
dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
}
default:
_, conn, err := p.dialConn(ctx)
if err != nil {
return nil, err
}
return conn, nil
}
doqConn.mu.Unlock()
}
// No available connection, create a new one
addr, conn, err := p.dialConn(ctx)
if err != nil {
return nil, "", err
}
doqConn := &doqConn{
conn: conn,
lastUsed: time.Now(),
refCount: 1,
}
p.conns[addr] = doqConn
return conn, addr, nil
}
// putConn returns a connection to the pool.
func (p *doqConnPool) putConn(addr string, conn *quic.Conn, isGood bool) {
p.mu.Lock()
defer p.mu.Unlock()
doqConn, ok := p.conns[addr]
if !ok {
// putConn returns a connection to the pool for reuse by other goroutines.
func (p *doqConnPool) putConn(conn *quic.Conn, isGood bool) {
if !isGood || conn == nil || conn.Context().Err() != nil {
if conn != nil {
conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
}
return
}
doqConn.mu.Lock()
defer doqConn.mu.Unlock()
doqConn.refCount--
if doqConn.refCount < 0 {
doqConn.refCount = 0
dc := &doqConn{conn: conn}
select {
case p.conns <- dc:
default:
// Channel full, close the connection
dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
}
// If connection is bad or closed, remove it from pool
if !isGood || conn.Context().Err() != nil {
delete(p.conns, addr)
conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
return
}
doqConn.lastUsed = time.Now()
}
// dialConn creates a new QUIC connection using parallel dialing like DoH3.
@@ -301,23 +260,17 @@ func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error)
return addr, conn, nil
}
// CloseIdleConnections closes all idle connections in the pool.
// When called during cleanup (e.g., from finalizer), it closes all connections
// regardless of refCount to prevent resource leaks.
// CloseIdleConnections closes all connections in the pool.
// Connections currently checked out (in use) are not closed.
func (p *doqConnPool) CloseIdleConnections() {
p.mu.Lock()
defer p.mu.Unlock()
p.closed = true
for addr, dc := range p.conns {
dc.mu.Lock()
if dc.conn != nil {
// Close all connections to ensure proper cleanup, even if in use
// This prevents resource leaks when the pool is being destroyed
dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
for {
select {
case dc := <-p.conns:
if dc.conn != nil {
dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
}
default:
return
}
dc.mu.Unlock()
delete(p.conns, addr)
}
}