refactor(dot): simplify DoT connection pool implementation

Replace the map-based pool and refCount bookkeeping with a channel-based
pool. Drop the closed state, per-connection address tracking, and
extra mutexes so the pool relies on the channel for concurrency and
lifecycle.
This commit is contained in:
Cuong Manh Le
2026-01-28 23:50:43 +07:00
committed by Cuong Manh Le
parent e45e56c021
commit 60dd366cc4
+44 -80
View File
@@ -7,7 +7,6 @@ import (
"io" "io"
"net" "net"
"runtime" "runtime"
"sync"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -44,23 +43,20 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
return answer, err return answer, err
} }
// dotConnPool manages a pool of TCP/TLS connections for DoT queries. const dotPoolSize = 16
// dotConnPool manages a pool of TCP/TLS connections for DoT queries using a buffered channel.
type dotConnPool struct { type dotConnPool struct {
uc *UpstreamConfig uc *UpstreamConfig
addrs []string addrs []string
port string port string
tlsConfig *tls.Config tlsConfig *tls.Config
dialer *net.Dialer dialer *net.Dialer
mu sync.RWMutex conns chan *dotConn
conns map[string]*dotConn
closed bool
} }
type dotConn struct { type dotConn struct {
conn *tls.Conn conn *tls.Conn
lastUsed time.Time
refCount int
mu sync.Mutex
} }
func newDOTClientPool(_ context.Context, uc *UpstreamConfig, addrs []string) *dotConnPool { func newDOTClientPool(_ context.Context, uc *UpstreamConfig, addrs []string) *dotConnPool {
@@ -90,7 +86,7 @@ func newDOTClientPool(_ context.Context, uc *UpstreamConfig, addrs []string) *do
port: port, port: port,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
dialer: dialer, dialer: dialer,
conns: make(map[string]*dotConn), conns: make(chan *dotConn, dotPoolSize),
} }
// Use SetFinalizer here because we need to call a method on the pool itself. // Use SetFinalizer here because we need to call a method on the pool itself.
@@ -109,7 +105,7 @@ func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
return nil, errors.New("nil DNS message") return nil, errors.New("nil DNS message")
} }
conn, addr, err := p.getConn(ctx) conn, err := p.getConn(ctx)
if err != nil { if err != nil {
return nil, wrapCertificateVerificationError(err) return nil, wrapCertificateVerificationError(err)
} }
@@ -117,7 +113,7 @@ func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
client := dns.Client{Net: "tcp-tls"} client := dns.Client{Net: "tcp-tls"}
answer, _, err := client.ExchangeWithConnContext(ctx, msg, &dns.Conn{Conn: conn}) answer, _, err := client.ExchangeWithConnContext(ctx, msg, &dns.Conn{Conn: conn})
isGood := err == nil isGood := err == nil
p.putConn(addr, conn, isGood) p.putConn(conn, isGood)
if err != nil { if err != nil {
return nil, wrapCertificateVerificationError(err) return nil, wrapCertificateVerificationError(err)
@@ -127,71 +123,42 @@ func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
} }
// getConn gets a TCP/TLS connection from the pool or creates a new one. // getConn gets a TCP/TLS connection from the pool or creates a new one.
func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, string, error) { // A connection is taken from the channel while in use; putConn returns it.
p.mu.Lock() func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, error) {
defer p.mu.Unlock() for {
select {
if p.closed { case dc := <-p.conns:
return nil, "", io.EOF if dc.conn != nil && isAlive(dc.conn) {
} return dc.conn, nil
}
// Try to reuse an existing connection if dc.conn != nil {
for addr, dotConn := range p.conns { dc.conn.Close()
dotConn.mu.Lock() }
if dotConn.refCount == 0 && dotConn.conn != nil && isAlive(dotConn.conn) { default:
dotConn.refCount++ _, conn, err := p.dialConn(ctx)
dotConn.lastUsed = time.Now() if err != nil {
conn := dotConn.conn return nil, err
dotConn.mu.Unlock() }
return conn, addr, nil return conn, nil
} }
dotConn.mu.Unlock()
} }
// No available connection, create a new one
addr, conn, err := p.dialConn(ctx)
if err != nil {
return nil, "", err
}
dotConn := &dotConn{
conn: conn,
lastUsed: time.Now(),
refCount: 1,
}
p.conns[addr] = dotConn
return conn, addr, nil
} }
// putConn returns a connection to the pool. // putConn returns a connection to the pool for reuse by other goroutines.
func (p *dotConnPool) putConn(addr string, conn net.Conn, isGood bool) { func (p *dotConnPool) putConn(conn net.Conn, isGood bool) {
p.mu.Lock() if !isGood || conn == nil {
defer p.mu.Unlock()
dotConn, ok := p.conns[addr]
if !ok {
return
}
dotConn.mu.Lock()
defer dotConn.mu.Unlock()
dotConn.refCount--
if dotConn.refCount < 0 {
dotConn.refCount = 0
}
// If connection is bad, remove it from pool
if !isGood {
delete(p.conns, addr)
if conn != nil { if conn != nil {
conn.Close() conn.Close()
} }
return return
} }
dc := &dotConn{conn: conn.(*tls.Conn)}
dotConn.lastUsed = time.Now() select {
case p.conns <- dc:
default:
// Channel full, close the connection
dc.conn.Close()
}
} }
// dialConn creates a new TCP/TLS connection. // dialConn creates a new TCP/TLS connection.
@@ -293,20 +260,17 @@ func (p *dotConnPool) dialConn(ctx context.Context) (string, *tls.Conn, error) {
} }
// CloseIdleConnections closes all connections in the pool. // CloseIdleConnections closes all connections in the pool.
// Connections currently checked out (in use) are not closed.
func (p *dotConnPool) CloseIdleConnections() { func (p *dotConnPool) CloseIdleConnections() {
p.mu.Lock() for {
defer p.mu.Unlock() select {
if p.closed { case dc := <-p.conns:
return if dc.conn != nil {
} dc.conn.Close()
p.closed = true }
for addr, dotConn := range p.conns { default:
dotConn.mu.Lock() return
if dotConn.conn != nil {
dotConn.conn.Close()
} }
dotConn.mu.Unlock()
delete(p.conns, addr)
} }
} }