From 4640a9f20a09042a17e9b9a4b094c552692cba8e Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 28 Jan 2026 23:50:53 +0700 Subject: [PATCH] refactor(doq): simplify DoQ connection pool implementation Replace the map-based pool and refCount bookkeeping with a channel-based pool. Drop the closed state, per-connection address tracking, and extra mutexes so the pool relies on the channel for concurrency and lifecycle, matching the approach used in the DoT pool. --- doq.go | 147 ++++++++++++++++++++------------------------------------- 1 file changed, 50 insertions(+), 97 deletions(-) diff --git a/doq.go b/doq.go index c9202a3..142993f 100644 --- a/doq.go +++ b/doq.go @@ -9,7 +9,6 @@ import ( "io" "net" "runtime" - "sync" "time" "github.com/miekg/dns" @@ -48,22 +47,19 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro return answer, err } -// doqConnPool manages a pool of QUIC connections for DoQ queries. +const doqPoolSize = 16 + +// doqConnPool manages a pool of QUIC connections for DoQ queries using a buffered channel. type doqConnPool struct { uc *UpstreamConfig addrs []string port string tlsConfig *tls.Config - mu sync.RWMutex - conns map[string]*doqConn - closed bool + conns chan *doqConn } type doqConn struct { - conn *quic.Conn - lastUsed time.Time - refCount int - mu sync.Mutex + conn *quic.Conn } func newDOQConnPool(_ context.Context, uc *UpstreamConfig, addrs []string) *doqConnPool { @@ -83,7 +79,7 @@ func newDOQConnPool(_ context.Context, uc *UpstreamConfig, addrs []string) *doqC addrs: addrs, port: port, tlsConfig: tlsConfig, - conns: make(map[string]*doqConn), + conns: make(chan *doqConn, doqPoolSize), } // Use SetFinalizer here because we need to call a method on the pool itself. @@ -116,7 +112,7 @@ func (p *doqConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro } func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - conn, addr, err := p.getConn(ctx) + conn, err := p.getConn(ctx) if err != nil { return nil, err } @@ -124,14 +120,14 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er // Pack the DNS message msgBytes, err := msg.Pack() if err != nil { - p.putConn(addr, conn, false) + p.putConn(conn, false) return nil, err } // Open a new stream for this query stream, err := conn.OpenStream() if err != nil { - p.putConn(addr, conn, false) + p.putConn(conn, false) return nil, err } @@ -147,13 +143,13 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er var msgLenBytes = []byte{byte(msgLen >> 8), byte(msgLen & 0xFF)} if _, err := stream.Write(msgLenBytes); err != nil { stream.Close() - p.putConn(addr, conn, false) + p.putConn(conn, false) return nil, err } if _, err := stream.Write(msgBytes); err != nil { stream.Close() - p.putConn(addr, conn, false) + p.putConn(conn, false) return nil, err } @@ -163,7 +159,7 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er // Return connection to pool (mark as potentially bad if error occurred) isGood := err == nil && len(buf) > 0 - p.putConn(addr, conn, isGood) + p.putConn(conn, isGood) if err != nil { return nil, err @@ -184,79 +180,42 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er } // getConn gets a QUIC connection from the pool or creates a new one. -func (p *doqConnPool) getConn(ctx context.Context) (*quic.Conn, string, error) { - p.mu.Lock() - defer p.mu.Unlock() - - if p.closed { - return nil, "", io.EOF - } - - // Try to reuse an existing connection - for addr, doqConn := range p.conns { - doqConn.mu.Lock() - if doqConn.refCount == 0 && doqConn.conn != nil { - // Check if connection is still alive - select { - case <-doqConn.conn.Context().Done(): - // Connection is closed, remove it - doqConn.mu.Unlock() - delete(p.conns, addr) - continue - default: +// A connection is taken from the channel while in use; putConn returns it. +func (p *doqConnPool) getConn(ctx context.Context) (*quic.Conn, error) { + for { + select { + case dc := <-p.conns: + if dc.conn != nil && dc.conn.Context().Err() == nil { + return dc.conn, nil } - - doqConn.refCount++ - doqConn.lastUsed = time.Now() - conn := doqConn.conn - doqConn.mu.Unlock() - return conn, addr, nil + if dc.conn != nil { + dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + } + default: + _, conn, err := p.dialConn(ctx) + if err != nil { + return nil, err + } + return conn, nil } - doqConn.mu.Unlock() } - - // No available connection, create a new one - addr, conn, err := p.dialConn(ctx) - if err != nil { - return nil, "", err - } - - doqConn := &doqConn{ - conn: conn, - lastUsed: time.Now(), - refCount: 1, - } - p.conns[addr] = doqConn - - return conn, addr, nil } -// putConn returns a connection to the pool. -func (p *doqConnPool) putConn(addr string, conn *quic.Conn, isGood bool) { - p.mu.Lock() - defer p.mu.Unlock() - - doqConn, ok := p.conns[addr] - if !ok { +// putConn returns a connection to the pool for reuse by other goroutines. +func (p *doqConnPool) putConn(conn *quic.Conn, isGood bool) { + if !isGood || conn == nil || conn.Context().Err() != nil { + if conn != nil { + conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + } return } - - doqConn.mu.Lock() - defer doqConn.mu.Unlock() - - doqConn.refCount-- - if doqConn.refCount < 0 { - doqConn.refCount = 0 + dc := &doqConn{conn: conn} + select { + case p.conns <- dc: + default: + // Channel full, close the connection + dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") } - - // If connection is bad or closed, remove it from pool - if !isGood || conn.Context().Err() != nil { - delete(p.conns, addr) - conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") - return - } - - doqConn.lastUsed = time.Now() } // dialConn creates a new QUIC connection using parallel dialing like DoH3. @@ -301,23 +260,17 @@ func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) return addr, conn, nil } -// CloseIdleConnections closes all idle connections in the pool. -// When called during cleanup (e.g., from finalizer), it closes all connections -// regardless of refCount to prevent resource leaks. +// CloseIdleConnections closes all connections in the pool. +// Connections currently checked out (in use) are not closed. func (p *doqConnPool) CloseIdleConnections() { - p.mu.Lock() - defer p.mu.Unlock() - - p.closed = true - - for addr, dc := range p.conns { - dc.mu.Lock() - if dc.conn != nil { - // Close all connections to ensure proper cleanup, even if in use - // This prevents resource leaks when the pool is being destroyed - dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + for { + select { + case dc := <-p.conns: + if dc.conn != nil { + dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + } + default: + return } - dc.mu.Unlock() - delete(p.conns, addr) } }