From fbc6468ee34d9c3974fbd51c7e6e52d20d6194f7 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 28 Jan 2026 23:50:43 +0700 Subject: [PATCH] refactor(dot): simplify DoT connection pool implementation Replace the map-based pool and refCount bookkeeping with a channel-based pool. Drop the closed state, per-connection address tracking, and extra mutexes so the pool relies on the channel for concurrency and lifecycle. --- dot.go | 124 ++++++++++++++++++++------------------------------------- 1 file changed, 44 insertions(+), 80 deletions(-) diff --git a/dot.go b/dot.go index e8049bb..66dc710 100644 --- a/dot.go +++ b/dot.go @@ -7,7 +7,6 @@ import ( "io" "net" "runtime" - "sync" "time" "github.com/miekg/dns" @@ -44,23 +43,20 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro return answer, err } -// dotConnPool manages a pool of TCP/TLS connections for DoT queries. +const dotPoolSize = 16 + +// dotConnPool manages a pool of TCP/TLS connections for DoT queries using a buffered channel. type dotConnPool struct { uc *UpstreamConfig addrs []string port string tlsConfig *tls.Config dialer *net.Dialer - mu sync.RWMutex - conns map[string]*dotConn - closed bool + conns chan *dotConn } type dotConn struct { - conn *tls.Conn - lastUsed time.Time - refCount int - mu sync.Mutex + conn *tls.Conn } func newDOTClientPool(_ context.Context, uc *UpstreamConfig, addrs []string) *dotConnPool { @@ -90,7 +86,7 @@ func newDOTClientPool(_ context.Context, uc *UpstreamConfig, addrs []string) *do port: port, tlsConfig: tlsConfig, dialer: dialer, - conns: make(map[string]*dotConn), + conns: make(chan *dotConn, dotPoolSize), } // Use SetFinalizer here because we need to call a method on the pool itself. @@ -109,7 +105,7 @@ func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro return nil, errors.New("nil DNS message") } - conn, addr, err := p.getConn(ctx) + conn, err := p.getConn(ctx) if err != nil { return nil, wrapCertificateVerificationError(err) } @@ -117,7 +113,7 @@ func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro client := dns.Client{Net: "tcp-tls"} answer, _, err := client.ExchangeWithConnContext(ctx, msg, &dns.Conn{Conn: conn}) isGood := err == nil - p.putConn(addr, conn, isGood) + p.putConn(conn, isGood) if err != nil { return nil, wrapCertificateVerificationError(err) @@ -127,71 +123,42 @@ func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro } // getConn gets a TCP/TLS connection from the pool or creates a new one. -func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, string, error) { - p.mu.Lock() - defer p.mu.Unlock() - - if p.closed { - return nil, "", io.EOF - } - - // Try to reuse an existing connection - for addr, dotConn := range p.conns { - dotConn.mu.Lock() - if dotConn.refCount == 0 && dotConn.conn != nil && isAlive(dotConn.conn) { - dotConn.refCount++ - dotConn.lastUsed = time.Now() - conn := dotConn.conn - dotConn.mu.Unlock() - return conn, addr, nil +// A connection is taken from the channel while in use; putConn returns it. +func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, error) { + for { + select { + case dc := <-p.conns: + if dc.conn != nil && isAlive(dc.conn) { + return dc.conn, nil + } + if dc.conn != nil { + dc.conn.Close() + } + default: + _, conn, err := p.dialConn(ctx) + if err != nil { + return nil, err + } + return conn, nil } - dotConn.mu.Unlock() } - - // No available connection, create a new one - addr, conn, err := p.dialConn(ctx) - if err != nil { - return nil, "", err - } - - dotConn := &dotConn{ - conn: conn, - lastUsed: time.Now(), - refCount: 1, - } - p.conns[addr] = dotConn - - return conn, addr, nil } -// putConn returns a connection to the pool. -func (p *dotConnPool) putConn(addr string, conn net.Conn, isGood bool) { - p.mu.Lock() - defer p.mu.Unlock() - - dotConn, ok := p.conns[addr] - if !ok { - return - } - - dotConn.mu.Lock() - defer dotConn.mu.Unlock() - - dotConn.refCount-- - if dotConn.refCount < 0 { - dotConn.refCount = 0 - } - - // If connection is bad, remove it from pool - if !isGood { - delete(p.conns, addr) +// putConn returns a connection to the pool for reuse by other goroutines. +func (p *dotConnPool) putConn(conn net.Conn, isGood bool) { + if !isGood || conn == nil { if conn != nil { conn.Close() } return } - - dotConn.lastUsed = time.Now() + dc := &dotConn{conn: conn.(*tls.Conn)} + select { + case p.conns <- dc: + default: + // Channel full, close the connection + dc.conn.Close() + } } // dialConn creates a new TCP/TLS connection. @@ -293,20 +260,17 @@ func (p *dotConnPool) dialConn(ctx context.Context) (string, *tls.Conn, error) { } // CloseIdleConnections closes all connections in the pool. +// Connections currently checked out (in use) are not closed. func (p *dotConnPool) CloseIdleConnections() { - p.mu.Lock() - defer p.mu.Unlock() - if p.closed { - return - } - p.closed = true - for addr, dotConn := range p.conns { - dotConn.mu.Lock() - if dotConn.conn != nil { - dotConn.conn.Close() + for { + select { + case dc := <-p.conns: + if dc.conn != nil { + dc.conn.Close() + } + default: + return } - dotConn.mu.Unlock() - delete(p.conns, addr) } }