fix(doq): share QUIC transport, close send side before read (RFC 9250)

DoQ pools now keep a single quic.Transport and UDP socket for all dials,
so parallel dial and reconnect churn no longer allocate a new socket per
attempt or leak the winner's UDP conn when the caller owns the packet
conn.

quicParallelDialer accepts an optional transport: when set, dials use
Transport.DialEarly on that socket; when nil, behavior matches the old
per-dial ListenUDP path (losers close their sockets).

Per RFC 9250 §4.2, close the query stream's send side before reading the
response so strict upstreams see STREAM FIN before answering.

CloseIdleConnections closes the shared transport and underlying UDP
conn so checked-out connections and the OS socket are torn down.

Add a FIN-strict test server, coverage for bootstrap vs parallel-dial
paths, and a Linux-only FD churn regression test.
This commit is contained in:
Cuong Manh Le
2026-05-14 21:01:29 +07:00
committed by Cuong Manh Le
parent 7b360288ed
commit 98ca63325f
3 changed files with 397 additions and 29 deletions
+28 -6
View File
@@ -78,7 +78,17 @@ type parallelDialerResult struct {
err error
}
type quicParallelDialer struct{}
// quicParallelDialer races DialEarly across a list of remote addresses and
// returns the first successful connection. When transport is non-nil, all
// dials share that transport's UDP socket, which removes both the per-dial
// socket allocation and the winner-path socket leak that an owner-of-the-conn
// receiver cannot clean up. When transport is nil, the dialer falls back to a
// fresh UDP socket per attempt (compat path used where no shared transport is
// available yet); the loser paths close their sockets, and the winner path's
// socket is owned by quic.DialEarly's internal transport.
type quicParallelDialer struct {
transport *quic.Transport
}
// Dial performs parallel dialing to the given address list.
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
@@ -106,12 +116,24 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t
ch <- &parallelDialerResult{conn: nil, err: err}
return
}
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
ch <- &parallelDialerResult{conn: nil, err: err}
return
var (
conn *quic.Conn
udpConn *net.UDPConn
)
if d.transport != nil {
conn, err = d.transport.DialEarly(ctx, remoteAddr, tlsCfg, cfg)
} else {
udpConn, err = net.ListenUDP("udp", nil)
if err != nil {
ch <- &parallelDialerResult{conn: nil, err: err}
return
}
conn, err = quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
if err != nil {
udpConn.Close()
udpConn = nil
}
}
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
select {
case ch <- &parallelDialerResult{conn: conn, err: err}:
case <-done:
+75 -14
View File
@@ -10,6 +10,7 @@ import (
"io"
"net"
"runtime"
"sync"
"time"
"github.com/miekg/dns"
@@ -51,6 +52,10 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
const doqPoolSize = 16
// doqConnPool manages a pool of QUIC connections for DoQ queries using a buffered channel.
// A single quic.Transport (and its UDP socket) is shared by every connection in the pool,
// so the OS socket lifecycle is tied to the pool rather than to each dial. Without this
// ownership model, a strict DoQ upstream that triggers reconnect churn would leak one
// caller-owned UDP socket per dial — see github.com/Control-D-Inc/ctrld/issues/309.
type doqConnPool struct {
uc *UpstreamConfig
addrs []string
@@ -58,6 +63,13 @@ type doqConnPool struct {
tlsConfig *tls.Config
quicConfig *quic.Config
conns chan *doqConn
transportMu sync.Mutex
transport *quic.Transport
transportConn *net.UDPConn
transportErr error
transportInit bool
closed bool
}
type doqConn struct {
@@ -178,10 +190,17 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er
return nil, err
}
// Read response
buf, err := io.ReadAll(stream)
stream.Close()
// RFC 9250 section 4.2 requires the client to indicate end-of-request by
// closing the send side of the stream (STREAM FIN). Servers may defer
// processing until FIN arrives, so the close must happen before reading.
// Stream.Close closes only the send direction; the receive direction
// remains open for the response.
if err := stream.Close(); err != nil {
p.putConn(conn, false)
return nil, err
}
buf, err := io.ReadAll(stream)
if err != nil {
p.putConn(conn, false)
return nil, err
@@ -259,25 +278,26 @@ func (p *doqConnPool) putConn(conn *quic.Conn, isGood bool) {
}
// dialConn creates a new QUIC connection using parallel dialing like DoH3.
// All connections from the pool multiplex on a single pool-owned UDP socket,
// so reconnect churn cannot grow the host's FD count.
func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) {
logger := LoggerFromCtx(ctx)
tr, err := p.getOrInitTransport()
if err != nil {
return "", nil, err
}
// If we have a bootstrap IP, use it directly
if p.uc.BootstrapIP != "" {
addr := net.JoinHostPort(p.uc.BootstrapIP, p.port)
Log(ctx, logger.Debug(), "Sending DoQ request to: %s", addr)
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return "", nil, err
}
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
udpConn.Close()
return "", nil, err
}
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, p.tlsConfig, p.quicConfig)
conn, err := tr.DialEarly(ctx, remoteAddr, p.tlsConfig, p.quicConfig)
if err != nil {
udpConn.Close()
return "", nil, err
}
return addr, conn, nil
@@ -289,7 +309,7 @@ func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error)
dialAddrs[i] = net.JoinHostPort(p.addrs[i], p.port)
}
pd := &quicParallelDialer{}
pd := &quicParallelDialer{transport: tr}
conn, err := pd.Dial(ctx, dialAddrs, p.tlsConfig, p.quicConfig)
if err != nil {
return "", nil, err
@@ -300,9 +320,35 @@ func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error)
return addr, conn, nil
}
// CloseIdleConnections closes all connections in the pool.
// Connections currently checked out (in use) are not closed.
// getOrInitTransport returns the pool's shared quic.Transport, initialising it
// on first call. Once the pool has been closed it permanently returns an error
// so that callers cannot resurrect a dead pool.
func (p *doqConnPool) getOrInitTransport() (*quic.Transport, error) {
p.transportMu.Lock()
defer p.transportMu.Unlock()
if p.closed {
return nil, errors.New("doq pool closed")
}
if p.transportInit {
return p.transport, p.transportErr
}
p.transportInit = true
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
p.transportErr = err
return nil, err
}
p.transportConn = udpConn
p.transport = &quic.Transport{Conn: udpConn}
return p.transport, nil
}
// CloseIdleConnections closes all idle connections, the shared quic.Transport,
// and the pool's UDP socket. Connections currently checked out (in use) get
// terminated by the transport close as well — without that, the OS socket
// would remain bound to a goroutine that the caller cannot reach to clean up.
func (p *doqConnPool) CloseIdleConnections() {
drain:
for {
select {
case dc := <-p.conns:
@@ -310,7 +356,22 @@ func (p *doqConnPool) CloseIdleConnections() {
dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
}
default:
return
break drain
}
}
p.transportMu.Lock()
if p.closed {
p.transportMu.Unlock()
return
}
p.closed = true
tr := p.transport
udpConn := p.transportConn
p.transportMu.Unlock()
if tr != nil {
_ = tr.Close()
}
if udpConn != nil {
_ = udpConn.Close()
}
}
+294 -9
View File
@@ -10,6 +10,8 @@ import (
"io"
"math/big"
"net"
"os"
"runtime"
"strings"
"testing"
"time"
@@ -255,35 +257,32 @@ func newMalformedDoQServer(t *testing.T, response []byte) *malformedDoQServer {
response: response,
}
go s.serve(t)
go s.serve()
t.Cleanup(func() { _ = listener.Close() })
return s
}
func (s *malformedDoQServer) serve(t *testing.T) {
func (s *malformedDoQServer) serve() {
for {
conn, err := s.listener.Accept(context.Background())
if err != nil {
if strings.Contains(err.Error(), "server closed") {
return
}
return
}
go s.handleConn(t, conn)
go s.handleConn(conn)
}
}
func (s *malformedDoQServer) handleConn(t *testing.T, conn *quic.Conn) {
func (s *malformedDoQServer) handleConn(conn *quic.Conn) {
for {
stream, err := conn.AcceptStream(context.Background())
if err != nil {
return
}
go s.handleStream(t, stream)
go s.handleStream(stream)
}
}
func (s *malformedDoQServer) handleStream(t *testing.T, stream *quic.Stream) {
func (s *malformedDoQServer) handleStream(stream *quic.Stream) {
defer stream.Close()
// Drain the client's DoQ-framed request so the client's writes complete
@@ -385,3 +384,289 @@ func TestDoQResolve_MalformedResponse(t *testing.T) {
})
}
}
// strictDoQServer accepts DoQ queries but defers the response until the
// client signals end-of-request with STREAM FIN, as required by RFC 9250
// section 4.2. It exists to lock in the fix for
// github.com/Control-D-Inc/ctrld/issues/309 where a client
// that never closes its send side caused the server to wait forever and the
// client to churn through reconnects.
type strictDoQServer struct {
listener *quic.Listener
cert *x509.Certificate
addr string
}
func newStrictDoQServer(t *testing.T) *strictDoQServer {
t.Helper()
testCert := generateTestCertificate(t)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{testCert.tlsCert},
NextProtos: []string{"doq"},
MinVersion: tls.VersionTLS12,
}
listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil)
if err != nil {
t.Fatalf("failed to create QUIC listener: %v", err)
}
s := &strictDoQServer{
listener: listener,
cert: testCert.cert,
addr: listener.Addr().String(),
}
go s.serve()
t.Cleanup(func() { _ = listener.Close() })
return s
}
func (s *strictDoQServer) serve() {
for {
conn, err := s.listener.Accept(context.Background())
if err != nil {
return
}
go s.handleConn(conn)
}
}
func (s *strictDoQServer) handleConn(conn *quic.Conn) {
for {
stream, err := conn.AcceptStream(context.Background())
if err != nil {
return
}
go s.handleStream(stream)
}
}
func (s *strictDoQServer) handleStream(stream *quic.Stream) {
defer stream.Close()
// Drain until the client closes the send side. This is the behaviour
// that triggered the bug: if the client never sends STREAM FIN, this
// read blocks until the stream's deadline fires.
body, err := io.ReadAll(stream)
if err != nil {
return
}
if len(body) < 2 {
return
}
msgLen := uint16(body[0])<<8 | uint16(body[1])
if int(msgLen) != len(body)-2 {
return
}
msg := new(dns.Msg)
if err := msg.Unpack(body[2:]); err != nil {
return
}
response := new(dns.Msg)
response.SetReply(msg)
response.Authoritative = true
if len(msg.Question) > 0 && msg.Question[0].Qtype == dns.TypeA {
response.Answer = append(response.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: net.ParseIP("192.0.2.1"),
})
}
respBytes, err := response.Pack()
if err != nil {
return
}
respLen := uint16(len(respBytes))
if _, err := stream.Write([]byte{byte(respLen >> 8), byte(respLen & 0xFF)}); err != nil {
return
}
if _, err := stream.Write(respBytes); err != nil {
return
}
}
func newStrictDoQUpstream(t *testing.T, cert *x509.Certificate, addr string, useBootstrap bool) *UpstreamConfig {
t.Helper()
pool := x509.NewCertPool()
pool.AddCert(cert)
host, _, err := net.SplitHostPort(addr)
if err != nil {
t.Fatalf("split host/port %q: %v", addr, err)
}
uc := &UpstreamConfig{
Name: "doq-strict",
Type: ResolverTypeDOQ,
Endpoint: addr,
Domain: host,
Timeout: 3000,
}
if useBootstrap {
uc.BootstrapIP = host
}
uc.SetCertPool(pool)
return uc
}
// TestDoQResolve_StrictServerWaitsForFIN exercises the RFC 9250 client-FIN
// requirement. With the bug present, the server's io.ReadAll blocks until
// the stream deadline expires and the client sees a timeout, so a successful
// resolve here proves that the client now sends STREAM FIN before reading.
func TestDoQResolve_StrictServerWaitsForFIN(t *testing.T) {
t.Parallel()
server := newStrictDoQServer(t)
uc := newStrictDoQUpstream(t, server.cert, server.addr, true)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
host, _, _ := net.SplitHostPort(server.addr)
pool := newDOQConnPool(ctx, uc, []string{host})
t.Cleanup(pool.CloseIdleConnections)
msg := new(dns.Msg)
msg.SetQuestion("example.com.", dns.TypeA)
msg.RecursionDesired = true
answer, err := pool.Resolve(ctx, msg)
if err != nil {
t.Fatalf("Resolve failed against strict DoQ server: %v", err)
}
if answer == nil || len(answer.Answer) == 0 {
t.Fatalf("Resolve returned no answer records: %+v", answer)
}
a, ok := answer.Answer[0].(*dns.A)
if !ok || !a.A.Equal(net.ParseIP("192.0.2.1")) {
t.Fatalf("unexpected answer: %+v", answer.Answer[0])
}
}
// TestDoQResolve_ParallelDialPathStrictFIN exercises the parallel-dial path
// (no BootstrapIP) against the same FIN-strict server, so that both the
// single-dial branch and the parallel-dial branch are covered.
func TestDoQResolve_ParallelDialPathStrictFIN(t *testing.T) {
t.Parallel()
server := newStrictDoQServer(t)
uc := newStrictDoQUpstream(t, server.cert, server.addr, false)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
host, _, _ := net.SplitHostPort(server.addr)
pool := newDOQConnPool(ctx, uc, []string{host})
t.Cleanup(pool.CloseIdleConnections)
msg := new(dns.Msg)
msg.SetQuestion("example.com.", dns.TypeA)
msg.RecursionDesired = true
answer, err := pool.Resolve(ctx, msg)
if err != nil {
t.Fatalf("Resolve (parallel-dial path) failed against strict DoQ server: %v", err)
}
if answer == nil || len(answer.Answer) == 0 {
t.Fatalf("Resolve (parallel-dial path) returned no answer records: %+v", answer)
}
}
// TestDoQPool_ChurnDoesNotGrowFDs exercises the reconnect-churn scenario
// described in github.com/Control-D-Inc/ctrld/issues/309: repeated dials
// against a server that closes existing connections must not grow the process
// FD count, because the pool now shares one UDP socket via quic.Transport instead
// of allocating one per dial. Linux-only because /proc/self/fd is the cheapest
// portable proxy for "what's still open."
func TestDoQPool_ChurnDoesNotGrowFDs(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("FD accounting via /proc/self/fd is linux-only")
}
t.Parallel()
server := newStrictDoQServer(t)
uc := newStrictDoQUpstream(t, server.cert, server.addr, true)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
host, _, _ := net.SplitHostPort(server.addr)
pool := newDOQConnPool(ctx, uc, []string{host})
t.Cleanup(pool.CloseIdleConnections)
makeQuery := func(i int) *dns.Msg {
msg := new(dns.Msg)
// Vary the question so any caching layer cannot short-circuit.
msg.SetQuestion(dns.Fqdn(strings.Repeat("a", 1+i%8)+".example.com"), dns.TypeA)
msg.RecursionDesired = true
return msg
}
// Warm the pool so the steady-state transport and at least one
// connection are open. Without this, the first resolve in the measured
// loop would inflate the baseline.
if _, err := pool.Resolve(ctx, makeQuery(0)); err != nil {
t.Fatalf("warm-up Resolve failed: %v", err)
}
baseline := countOpenFDs(t)
// Force reconnect churn by closing the connection between each query.
// Without the fix this would leak one UDP socket per round; with the
// fix the pool's shared transport keeps a single socket open.
const rounds = 20
for i := 1; i <= rounds; i++ {
// Drain any pooled connection so the next Resolve has to redial.
drainPooledConns(pool)
if _, err := pool.Resolve(ctx, makeQuery(i)); err != nil {
t.Fatalf("Resolve in churn loop iteration %d failed: %v", i, err)
}
}
// Give quic-go a moment to drop any background goroutines that hold
// references to closed sockets.
time.Sleep(200 * time.Millisecond)
after := countOpenFDs(t)
// Allow a small slack for transient FDs (goroutine wake-ups, qlog,
// etc.) but reject anything that scales with the number of rounds.
const slack = 5
if after > baseline+slack {
t.Fatalf("FD count grew under DoQ churn: baseline=%d after=%d rounds=%d (slack=%d)", baseline, after, rounds, slack)
}
}
// drainPooledConns removes any idle pooled connections so the next Resolve
// is forced to dial a fresh one. It does not close the pool's transport.
func drainPooledConns(p *doqConnPool) {
for {
select {
case dc := <-p.conns:
if dc.conn != nil {
dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
}
default:
return
}
}
}
func countOpenFDs(t *testing.T) int {
t.Helper()
entries, err := os.ReadDir("/proc/self/fd")
if err != nil {
t.Fatalf("read /proc/self/fd: %v", err)
}
return len(entries)
}