refactor(dot): simplify DoT connection pool implementation

Replace the map-based pool and refCount bookkeeping with a channel-based
pool. Drop the closed state, per-connection address tracking, and
extra mutexes so the pool relies on the channel for concurrency and
lifecycle.
This commit is contained in:
Cuong Manh Le
2026-01-28 23:50:43 +07:00
committed by Cuong Manh Le
parent 09a689149e
commit fbc6468ee3

124
dot.go
View File

@@ -7,7 +7,6 @@ import (
"io"
"net"
"runtime"
"sync"
"time"
"github.com/miekg/dns"
@@ -44,23 +43,20 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
return answer, err
}
// dotConnPool manages a pool of TCP/TLS connections for DoT queries.
const dotPoolSize = 16
// dotConnPool manages a pool of TCP/TLS connections for DoT queries using a buffered channel.
type dotConnPool struct {
uc *UpstreamConfig
addrs []string
port string
tlsConfig *tls.Config
dialer *net.Dialer
mu sync.RWMutex
conns map[string]*dotConn
closed bool
conns chan *dotConn
}
type dotConn struct {
conn *tls.Conn
lastUsed time.Time
refCount int
mu sync.Mutex
conn *tls.Conn
}
func newDOTClientPool(_ context.Context, uc *UpstreamConfig, addrs []string) *dotConnPool {
@@ -90,7 +86,7 @@ func newDOTClientPool(_ context.Context, uc *UpstreamConfig, addrs []string) *do
port: port,
tlsConfig: tlsConfig,
dialer: dialer,
conns: make(map[string]*dotConn),
conns: make(chan *dotConn, dotPoolSize),
}
// Use SetFinalizer here because we need to call a method on the pool itself.
@@ -109,7 +105,7 @@ func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
return nil, errors.New("nil DNS message")
}
conn, addr, err := p.getConn(ctx)
conn, err := p.getConn(ctx)
if err != nil {
return nil, wrapCertificateVerificationError(err)
}
@@ -117,7 +113,7 @@ func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
client := dns.Client{Net: "tcp-tls"}
answer, _, err := client.ExchangeWithConnContext(ctx, msg, &dns.Conn{Conn: conn})
isGood := err == nil
p.putConn(addr, conn, isGood)
p.putConn(conn, isGood)
if err != nil {
return nil, wrapCertificateVerificationError(err)
@@ -127,71 +123,42 @@ func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
}
// getConn gets a TCP/TLS connection from the pool or creates a new one.
func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, string, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return nil, "", io.EOF
}
// Try to reuse an existing connection
for addr, dotConn := range p.conns {
dotConn.mu.Lock()
if dotConn.refCount == 0 && dotConn.conn != nil && isAlive(dotConn.conn) {
dotConn.refCount++
dotConn.lastUsed = time.Now()
conn := dotConn.conn
dotConn.mu.Unlock()
return conn, addr, nil
// A connection is taken from the channel while in use; putConn returns it.
func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, error) {
for {
select {
case dc := <-p.conns:
if dc.conn != nil && isAlive(dc.conn) {
return dc.conn, nil
}
if dc.conn != nil {
dc.conn.Close()
}
default:
_, conn, err := p.dialConn(ctx)
if err != nil {
return nil, err
}
return conn, nil
}
dotConn.mu.Unlock()
}
// No available connection, create a new one
addr, conn, err := p.dialConn(ctx)
if err != nil {
return nil, "", err
}
dotConn := &dotConn{
conn: conn,
lastUsed: time.Now(),
refCount: 1,
}
p.conns[addr] = dotConn
return conn, addr, nil
}
// putConn returns a connection to the pool.
func (p *dotConnPool) putConn(addr string, conn net.Conn, isGood bool) {
p.mu.Lock()
defer p.mu.Unlock()
dotConn, ok := p.conns[addr]
if !ok {
return
}
dotConn.mu.Lock()
defer dotConn.mu.Unlock()
dotConn.refCount--
if dotConn.refCount < 0 {
dotConn.refCount = 0
}
// If connection is bad, remove it from pool
if !isGood {
delete(p.conns, addr)
// putConn returns a connection to the pool for reuse by other goroutines.
func (p *dotConnPool) putConn(conn net.Conn, isGood bool) {
if !isGood || conn == nil {
if conn != nil {
conn.Close()
}
return
}
dotConn.lastUsed = time.Now()
dc := &dotConn{conn: conn.(*tls.Conn)}
select {
case p.conns <- dc:
default:
// Channel full, close the connection
dc.conn.Close()
}
}
// dialConn creates a new TCP/TLS connection.
@@ -293,20 +260,17 @@ func (p *dotConnPool) dialConn(ctx context.Context) (string, *tls.Conn, error) {
}
// CloseIdleConnections closes all connections in the pool.
// Connections currently checked out (in use) are not closed.
func (p *dotConnPool) CloseIdleConnections() {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return
}
p.closed = true
for addr, dotConn := range p.conns {
dotConn.mu.Lock()
if dotConn.conn != nil {
dotConn.conn.Close()
for {
select {
case dc := <-p.conns:
if dc.conn != nil {
dc.conn.Close()
}
default:
return
}
dotConn.mu.Unlock()
delete(p.conns, addr)
}
}