diff --git a/config_quic.go b/config_quic.go index a8c6872..3979795 100644 --- a/config_quic.go +++ b/config_quic.go @@ -78,7 +78,17 @@ type parallelDialerResult struct { err error } -type quicParallelDialer struct{} +// quicParallelDialer races DialEarly across a list of remote addresses and +// returns the first successful connection. When transport is non-nil, all +// dials share that transport's UDP socket, which removes both the per-dial +// socket allocation and the winner-path socket leak that an owner-of-the-conn +// receiver cannot clean up. When transport is nil, the dialer falls back to a +// fresh UDP socket per attempt (compat path used where no shared transport is +// available yet); the loser paths close their sockets, and the winner path's +// socket is owned by quic.DialEarly's internal transport. +type quicParallelDialer struct { + transport *quic.Transport +} // Dial performs parallel dialing to the given address list. func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) { @@ -106,12 +116,24 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t ch <- &parallelDialerResult{conn: nil, err: err} return } - udpConn, err := net.ListenUDP("udp", nil) - if err != nil { - ch <- &parallelDialerResult{conn: nil, err: err} - return + var ( + conn *quic.Conn + udpConn *net.UDPConn + ) + if d.transport != nil { + conn, err = d.transport.DialEarly(ctx, remoteAddr, tlsCfg, cfg) + } else { + udpConn, err = net.ListenUDP("udp", nil) + if err != nil { + ch <- &parallelDialerResult{conn: nil, err: err} + return + } + conn, err = quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg) + if err != nil { + udpConn.Close() + udpConn = nil + } } - conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg) select { case ch <- &parallelDialerResult{conn: conn, err: err}: case <-done: diff --git a/doq.go b/doq.go index eabb3a4..f6711f3 100644 --- a/doq.go +++ b/doq.go @@ -10,6 
+10,7 @@ import ( "io" "net" "runtime" + "sync" "time" "github.com/miekg/dns" @@ -51,6 +52,10 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro const doqPoolSize = 16 // doqConnPool manages a pool of QUIC connections for DoQ queries using a buffered channel. +// A single quic.Transport (and its UDP socket) is shared by every connection in the pool, +// so the OS socket lifecycle is tied to the pool rather than to each dial. Without this +// ownership model, a strict DoQ upstream that triggers reconnect churn would leak one +// caller-owned UDP socket per dial — see github.com/Control-D-Inc/ctrld/issues/309. type doqConnPool struct { uc *UpstreamConfig addrs []string @@ -58,6 +63,13 @@ type doqConnPool struct { tlsConfig *tls.Config quicConfig *quic.Config conns chan *doqConn + + transportMu sync.Mutex + transport *quic.Transport + transportConn *net.UDPConn + transportErr error + transportInit bool + closed bool } type doqConn struct { @@ -178,10 +190,17 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er return nil, err } - // Read response - buf, err := io.ReadAll(stream) - stream.Close() + // RFC 9250 section 4.2 requires the client to indicate end-of-request by + // closing the send side of the stream (STREAM FIN). Servers may defer + // processing until FIN arrives, so the close must happen before reading. + // Stream.Close closes only the send direction; the receive direction + // remains open for the response. + if err := stream.Close(); err != nil { + p.putConn(conn, false) + return nil, err + } + buf, err := io.ReadAll(stream) if err != nil { p.putConn(conn, false) return nil, err @@ -259,25 +278,26 @@ func (p *doqConnPool) putConn(conn *quic.Conn, isGood bool) { } // dialConn creates a new QUIC connection using parallel dialing like DoH3. +// All connections from the pool multiplex on a single pool-owned UDP socket, +// so reconnect churn cannot grow the host's FD count. 
func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) { logger := LoggerFromCtx(ctx) + tr, err := p.getOrInitTransport() + if err != nil { + return "", nil, err + } + // If we have a bootstrap IP, use it directly if p.uc.BootstrapIP != "" { addr := net.JoinHostPort(p.uc.BootstrapIP, p.port) Log(ctx, logger.Debug(), "Sending DoQ request to: %s", addr) - udpConn, err := net.ListenUDP("udp", nil) - if err != nil { - return "", nil, err - } remoteAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { - udpConn.Close() return "", nil, err } - conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, p.tlsConfig, p.quicConfig) + conn, err := tr.DialEarly(ctx, remoteAddr, p.tlsConfig, p.quicConfig) if err != nil { - udpConn.Close() return "", nil, err } return addr, conn, nil @@ -289,7 +309,7 @@ func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) dialAddrs[i] = net.JoinHostPort(p.addrs[i], p.port) } - pd := &quicParallelDialer{} + pd := &quicParallelDialer{transport: tr} conn, err := pd.Dial(ctx, dialAddrs, p.tlsConfig, p.quicConfig) if err != nil { return "", nil, err @@ -300,9 +320,35 @@ func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) return addr, conn, nil } -// CloseIdleConnections closes all connections in the pool. -// Connections currently checked out (in use) are not closed. +// getOrInitTransport returns the pool's shared quic.Transport, initialising it +// on first call. Once the pool has been closed it permanently returns an error +// so that callers cannot resurrect a dead pool. 
+func (p *doqConnPool) getOrInitTransport() (*quic.Transport, error) { + p.transportMu.Lock() + defer p.transportMu.Unlock() + if p.closed { + return nil, errors.New("doq pool closed") + } + if p.transportInit { + return p.transport, p.transportErr + } + p.transportInit = true + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + p.transportErr = err + return nil, err + } + p.transportConn = udpConn + p.transport = &quic.Transport{Conn: udpConn} + return p.transport, nil +} + +// CloseIdleConnections closes all idle connections, the shared quic.Transport, +// and the pool's UDP socket. Connections currently checked out (in use) get +// terminated by the transport close as well — without that, the OS socket +// would remain bound to a goroutine that the caller cannot reach to clean up. func (p *doqConnPool) CloseIdleConnections() { +drain: for { select { case dc := <-p.conns: @@ -310,7 +356,22 @@ func (p *doqConnPool) CloseIdleConnections() { dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") } default: - return + break drain } } + p.transportMu.Lock() + if p.closed { + p.transportMu.Unlock() + return + } + p.closed = true + tr := p.transport + udpConn := p.transportConn + p.transportMu.Unlock() + if tr != nil { + _ = tr.Close() + } + if udpConn != nil { + _ = udpConn.Close() + } } diff --git a/doq_test.go b/doq_test.go index 806e229..463f7e9 100644 --- a/doq_test.go +++ b/doq_test.go @@ -10,6 +10,8 @@ import ( "io" "math/big" "net" + "os" + "runtime" "strings" "testing" "time" @@ -255,35 +257,32 @@ func newMalformedDoQServer(t *testing.T, response []byte) *malformedDoQServer { response: response, } - go s.serve(t) + go s.serve() t.Cleanup(func() { _ = listener.Close() }) return s } -func (s *malformedDoQServer) serve(t *testing.T) { +func (s *malformedDoQServer) serve() { for { conn, err := s.listener.Accept(context.Background()) if err != nil { - if strings.Contains(err.Error(), "server closed") { - return - } return } - go 
s.handleConn(t, conn) + go s.handleConn(conn) } } -func (s *malformedDoQServer) handleConn(t *testing.T, conn *quic.Conn) { +func (s *malformedDoQServer) handleConn(conn *quic.Conn) { for { stream, err := conn.AcceptStream(context.Background()) if err != nil { return } - go s.handleStream(t, stream) + go s.handleStream(stream) } } -func (s *malformedDoQServer) handleStream(t *testing.T, stream *quic.Stream) { +func (s *malformedDoQServer) handleStream(stream *quic.Stream) { defer stream.Close() // Drain the client's DoQ-framed request so the client's writes complete @@ -385,3 +384,289 @@ func TestDoQResolve_MalformedResponse(t *testing.T) { }) } } + +// strictDoQServer accepts DoQ queries but defers the response until the +// client signals end-of-request with STREAM FIN, as required by RFC 9250 +// section 4.2. It exists to lock in the fix for +// github.com/Control-D-Inc/ctrld/issues/309 where a client +// that never closes its send side caused the server to wait forever and the +// client to churn through reconnects. 
+type strictDoQServer struct { + listener *quic.Listener + cert *x509.Certificate + addr string +} + +func newStrictDoQServer(t *testing.T) *strictDoQServer { + t.Helper() + + testCert := generateTestCertificate(t) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{testCert.tlsCert}, + NextProtos: []string{"doq"}, + MinVersion: tls.VersionTLS12, + } + + listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil) + if err != nil { + t.Fatalf("failed to create QUIC listener: %v", err) + } + + s := &strictDoQServer{ + listener: listener, + cert: testCert.cert, + addr: listener.Addr().String(), + } + go s.serve() + t.Cleanup(func() { _ = listener.Close() }) + return s +} + +func (s *strictDoQServer) serve() { + for { + conn, err := s.listener.Accept(context.Background()) + if err != nil { + return + } + go s.handleConn(conn) + } +} + +func (s *strictDoQServer) handleConn(conn *quic.Conn) { + for { + stream, err := conn.AcceptStream(context.Background()) + if err != nil { + return + } + go s.handleStream(stream) + } +} + +func (s *strictDoQServer) handleStream(stream *quic.Stream) { + defer stream.Close() + + // Drain until the client closes the send side. This is the behaviour + // that triggered the bug: if the client never sends STREAM FIN, this + // read blocks until the stream's deadline fires. 
+ body, err := io.ReadAll(stream) + if err != nil { + return + } + if len(body) < 2 { + return + } + msgLen := uint16(body[0])<<8 | uint16(body[1]) + if int(msgLen) != len(body)-2 { + return + } + + msg := new(dns.Msg) + if err := msg.Unpack(body[2:]); err != nil { + return + } + + response := new(dns.Msg) + response.SetReply(msg) + response.Authoritative = true + if len(msg.Question) > 0 && msg.Question[0].Qtype == dns.TypeA { + response.Answer = append(response.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: net.ParseIP("192.0.2.1"), + }) + } + + respBytes, err := response.Pack() + if err != nil { + return + } + respLen := uint16(len(respBytes)) + if _, err := stream.Write([]byte{byte(respLen >> 8), byte(respLen & 0xFF)}); err != nil { + return + } + if _, err := stream.Write(respBytes); err != nil { + return + } +} + +func newStrictDoQUpstream(t *testing.T, cert *x509.Certificate, addr string, useBootstrap bool) *UpstreamConfig { + t.Helper() + + pool := x509.NewCertPool() + pool.AddCert(cert) + + host, _, err := net.SplitHostPort(addr) + if err != nil { + t.Fatalf("split host/port %q: %v", addr, err) + } + + uc := &UpstreamConfig{ + Name: "doq-strict", + Type: ResolverTypeDOQ, + Endpoint: addr, + Domain: host, + Timeout: 3000, + } + if useBootstrap { + uc.BootstrapIP = host + } + uc.SetCertPool(pool) + return uc +} + +// TestDoQResolve_StrictServerWaitsForFIN exercises the RFC 9250 client-FIN +// requirement. With the bug present, the server's io.ReadAll blocks until +// the stream deadline expires and the client sees a timeout, so a successful +// resolve here proves that the client now sends STREAM FIN before reading. 
+func TestDoQResolve_StrictServerWaitsForFIN(t *testing.T) { + t.Parallel() + + server := newStrictDoQServer(t) + uc := newStrictDoQUpstream(t, server.cert, server.addr, true) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + host, _, _ := net.SplitHostPort(server.addr) + pool := newDOQConnPool(ctx, uc, []string{host}) + t.Cleanup(pool.CloseIdleConnections) + + msg := new(dns.Msg) + msg.SetQuestion("example.com.", dns.TypeA) + msg.RecursionDesired = true + + answer, err := pool.Resolve(ctx, msg) + if err != nil { + t.Fatalf("Resolve failed against strict DoQ server: %v", err) + } + if answer == nil || len(answer.Answer) == 0 { + t.Fatalf("Resolve returned no answer records: %+v", answer) + } + a, ok := answer.Answer[0].(*dns.A) + if !ok || !a.A.Equal(net.ParseIP("192.0.2.1")) { + t.Fatalf("unexpected answer: %+v", answer.Answer[0]) + } +} + +// TestDoQResolve_ParallelDialPathStrictFIN exercises the parallel-dial path +// (no BootstrapIP) against the same FIN-strict server, so that both the +// single-dial branch and the parallel-dial branch are covered. 
+func TestDoQResolve_ParallelDialPathStrictFIN(t *testing.T) { + t.Parallel() + + server := newStrictDoQServer(t) + uc := newStrictDoQUpstream(t, server.cert, server.addr, false) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + host, _, _ := net.SplitHostPort(server.addr) + pool := newDOQConnPool(ctx, uc, []string{host}) + t.Cleanup(pool.CloseIdleConnections) + + msg := new(dns.Msg) + msg.SetQuestion("example.com.", dns.TypeA) + msg.RecursionDesired = true + + answer, err := pool.Resolve(ctx, msg) + if err != nil { + t.Fatalf("Resolve (parallel-dial path) failed against strict DoQ server: %v", err) + } + if answer == nil || len(answer.Answer) == 0 { + t.Fatalf("Resolve (parallel-dial path) returned no answer records: %+v", answer) + } +} + +// TestDoQPool_ChurnDoesNotGrowFDs exercises the reconnect-churn scenario +// described in github.com/Control-D-Inc/ctrld/issues/309: repeated dials +// against a server that closes existing connections must not grow the process +// FD count, because the pool now shares one UDP socket via quic.Transport instead +// of allocating one per dial. Linux-only because /proc/self/fd is the cheapest +// portable proxy for "what's still open." +func TestDoQPool_ChurnDoesNotGrowFDs(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("FD accounting via /proc/self/fd is linux-only") + } + t.Parallel() + + server := newStrictDoQServer(t) + uc := newStrictDoQUpstream(t, server.cert, server.addr, true) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + host, _, _ := net.SplitHostPort(server.addr) + pool := newDOQConnPool(ctx, uc, []string{host}) + t.Cleanup(pool.CloseIdleConnections) + + makeQuery := func(i int) *dns.Msg { + msg := new(dns.Msg) + // Vary the question so any caching layer cannot short-circuit. 
+ msg.SetQuestion(dns.Fqdn(strings.Repeat("a", 1+i%8)+".example.com"), dns.TypeA) + msg.RecursionDesired = true + return msg + } + + // Warm the pool so the steady-state transport and at least one + // connection are open. Without this, the first resolve in the measured + // loop would inflate the baseline. + if _, err := pool.Resolve(ctx, makeQuery(0)); err != nil { + t.Fatalf("warm-up Resolve failed: %v", err) + } + + baseline := countOpenFDs(t) + + // Force reconnect churn by closing the connection between each query. + // Without the fix this would leak one UDP socket per round; with the + // fix the pool's shared transport keeps a single socket open. + const rounds = 20 + for i := 1; i <= rounds; i++ { + // Drain any pooled connection so the next Resolve has to redial. + drainPooledConns(pool) + + if _, err := pool.Resolve(ctx, makeQuery(i)); err != nil { + t.Fatalf("Resolve in churn loop iteration %d failed: %v", i, err) + } + } + + // Give quic-go a moment to drop any background goroutines that hold + // references to closed sockets. + time.Sleep(200 * time.Millisecond) + + after := countOpenFDs(t) + + // Allow a small slack for transient FDs (goroutine wake-ups, qlog, + // etc.) but reject anything that scales with the number of rounds. + const slack = 5 + if after > baseline+slack { + t.Fatalf("FD count grew under DoQ churn: baseline=%d after=%d rounds=%d (slack=%d)", baseline, after, rounds, slack) + } +} + +// drainPooledConns removes any idle pooled connections so the next Resolve +// is forced to dial a fresh one. It does not close the pool's transport. +func drainPooledConns(p *doqConnPool) { + for { + select { + case dc := <-p.conns: + if dc.conn != nil { + dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + } + default: + return + } + } +} + +func countOpenFDs(t *testing.T) int { + t.Helper() + entries, err := os.ReadDir("/proc/self/fd") + if err != nil { + t.Fatalf("read /proc/self/fd: %v", err) + } + return len(entries) +}