Files
ctrld/config_quic.go
T
Cuong Manh Le 98ca63325f fix(doq): share QUIC transport, close send side before read (RFC 9250)
DoQ pools now keep a single quic.Transport and UDP socket for all dials,
so parallel dial and reconnect churn no longer allocate a new socket per
attempt or leak the winner's UDP conn when the caller owns the packet
conn.

quicParallelDialer accepts an optional transport: when set, dials use
Transport.DialEarly on that socket; when nil, behavior matches the old
per-dial ListenUDP path (losers close their sockets).

Per RFC 9250 §4.2, close the query stream's send side before reading the
response so strict upstreams see STREAM FIN before answering.

CloseIdleConnections closes the shared transport and underlying UDP
conn so checked-out connections and the OS socket are torn down.

Add a FIN-strict test server, coverage for bootstrap vs parallel-dial
paths, and a Linux-only FD churn regression test.
2026-05-16 04:13:38 +07:00

175 lines
5.0 KiB
Go

package ctrld
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"runtime"
"sync"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
)
// newDOH3Transport builds the HTTP/3 round tripper used for DoH3 upstreams.
// It returns nil when uc is not a DoH3 upstream.
//
// Server certificates are verified against uc.certPool. Every QUIC dial goes
// to uc.BootstrapIP when it is set (avoiding a DNS lookup of the upstream
// host); otherwise the candidate addrs are raced in parallel, each combined
// with the port of the address the HTTP/3 transport asked to dial. A finalizer
// closes idle connections once the transport becomes unreachable.
func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper {
	if uc.Type != ResolverTypeDOH3 {
		return nil
	}
	rt := &http3.Transport{}
	rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool, MinVersion: tls.VersionTLS12}
	logger := LoggerFromCtx(ctx)
	rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
		_, port, err := net.SplitHostPort(addr)
		if err != nil {
			// addr must be host:port; without a port every candidate below
			// would be dialed on port 0, so fail fast instead.
			return nil, err
		}
		// if we have a bootstrap ip set, use it to avoid DNS lookup
		if uc.BootstrapIP != "" {
			addr = net.JoinHostPort(uc.BootstrapIP, port)
			Log(ctx, logger.Debug(), "Sending doh3 request to: %s", addr)
			// Resolve before opening the socket so a bad address cannot leak it.
			remoteAddr, err := net.ResolveUDPAddr("udp", addr)
			if err != nil {
				return nil, err
			}
			udpConn, err := net.ListenUDP("udp", nil)
			if err != nil {
				return nil, err
			}
			conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
			if err != nil {
				udpConn.Close()
				return nil, err
			}
			// quic.DialEarly does not take ownership of a caller-supplied
			// PacketConn, so release the socket once the connection ends.
			go func() {
				<-conn.Context().Done()
				udpConn.Close()
			}()
			return conn, nil
		}
		dialAddrs := make([]string, len(addrs))
		for i := range addrs {
			dialAddrs[i] = net.JoinHostPort(addrs[i], port)
		}
		pd := &quicParallelDialer{}
		conn, err := pd.Dial(ctx, dialAddrs, tlsCfg, cfg)
		if err != nil {
			return nil, err
		}
		Log(ctx, logger.Debug(), "Sending doh3 request to: %s", conn.RemoteAddr())
		return conn, nil
	}
	runtime.SetFinalizer(rt, func(rt *http3.Transport) {
		rt.CloseIdleConnections()
	})
	return rt
}
// doh3Transport returns the HTTP/3 round tripper to use for a query of type
// dnsType. It calls ensureSetupTransport first, then lets transportByIpStack
// pick between the upstream's generic, "4" and "6" round trippers based on
// uc.IPStack and dnsType.
func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper {
	uc.ensureSetupTransport(ctx)
	generic, v4, v6 := uc.http3RoundTripper, uc.http3RoundTripper4, uc.http3RoundTripper6
	return transportByIpStack(uc.IPStack, dnsType, generic, v4, v6)
}
// doqTransport returns the DoQ connection pool to use for a query of type
// dnsType. It calls ensureSetupTransport first, then lets transportByIpStack
// pick between the upstream's generic, "4" and "6" pools based on uc.IPStack
// and dnsType.
func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doqConnPool {
	uc.ensureSetupTransport(ctx)
	generic, v4, v6 := uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6
	return transportByIpStack(uc.IPStack, dnsType, generic, v4, v6)
}
// dotTransport returns the DoT connection pool to use for a query of type
// dnsType. It calls ensureSetupTransport first, then lets transportByIpStack
// pick between the upstream's generic, "4" and "6" client pools based on
// uc.IPStack and dnsType.
func (uc *UpstreamConfig) dotTransport(ctx context.Context, dnsType uint16) *dotConnPool {
	uc.ensureSetupTransport(ctx)
	generic, v4, v6 := uc.dotClientPool, uc.dotClientPool4, uc.dotClientPool6
	return transportByIpStack(uc.IPStack, dnsType, generic, v4, v6)
}
// The QUIC parallel dialer lives in this file rather than reusing a generic
// helper because quic-go's dial API differs from net.Dialer's.
// NOTE(review): the original comment also cited a "simplification for quic
// free version" as a reason; its intent is unclear — confirm with the author.
type (
	// parallelDialerResult carries the outcome of a single dial attempt back
	// to quicParallelDialer.Dial, which treats a nil err as success and hands
	// out conn.
	parallelDialerResult struct {
		conn *quic.Conn
		err  error
	}

	// quicParallelDialer races DialEarly across a list of remote addresses and
	// returns the first connection that succeeds.
	//
	// With a non-nil transport, every attempt goes through
	// transport.DialEarly and therefore shares that transport's UDP socket.
	// With a nil transport (the zero value), each attempt listens on a fresh
	// UDP socket and dials it with quic.DialEarly; Dial closes the socket of
	// any attempt whose dial fails.
	quicParallelDialer struct {
		// transport is the optional shared quic.Transport used for all dials.
		transport *quic.Transport
	}
)
// Dial races quic DialEarly calls to every address in addrs and returns the
// first connection that is established. Once a winner is chosen, the shared
// context is cancelled so the remaining attempts abort. Any attempt that still
// connects afterwards closes its own connection instead of leaking it. If
// every attempt fails, the returned error joins all of their errors.
//
// Each address is resolved with net.ResolveUDPAddr, so it is expected in
// host:port form. An empty addrs returns an error immediately.
//
// Without a shared transport (d.transport == nil), each attempt listens on its
// own UDP socket. quic.DialEarly does not take ownership of a caller-supplied
// PacketConn, so the winner's socket is closed once its connection ends. Every
// other socket is closed as soon as its attempt fails or loses the race.
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
	if len(addrs) == 0 {
		return nil, errors.New("empty addresses")
	}
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// done is closed when Dial returns; attempts that finish later observe it
	// and clean up after themselves.
	done := make(chan struct{})
	defer close(done)

	// ch is deliberately unbuffered. With a buffer of len(addrs), a late
	// attempt could always complete its send and park a live connection in a
	// channel nobody reads any more. Unbuffered, a send only succeeds while
	// Dial is still receiving, so once Dial has returned the done branch is
	// the only one that can fire.
	ch := make(chan *parallelDialerResult)
	report := func(res *parallelDialerResult) bool {
		select {
		case ch <- res:
			return true
		case <-done:
			return false
		}
	}

	var wg sync.WaitGroup
	wg.Add(len(addrs))
	go func() {
		wg.Wait()
		close(ch)
	}()
	for _, addr := range addrs {
		go func(addr string) {
			defer wg.Done()
			remoteAddr, err := net.ResolveUDPAddr("udp", addr)
			if err != nil {
				report(&parallelDialerResult{err: err})
				return
			}
			if d.transport != nil {
				// Shared socket: only the connection itself needs cleanup.
				conn, err := d.transport.DialEarly(ctx, remoteAddr, tlsCfg, cfg)
				if !report(&parallelDialerResult{conn: conn, err: err}) && conn != nil {
					conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
				}
				return
			}
			udpConn, err := net.ListenUDP("udp", nil)
			if err != nil {
				report(&parallelDialerResult{err: err})
				return
			}
			conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
			if err != nil {
				udpConn.Close()
				report(&parallelDialerResult{err: err})
				return
			}
			if !report(&parallelDialerResult{conn: conn}) {
				// Lost the race: Dial already returned another connection.
				conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
				udpConn.Close()
				return
			}
			// Won the race: the caller owns conn now, but nothing else will
			// close udpConn, so release it when the connection ends.
			go func() {
				<-conn.Context().Done()
				udpConn.Close()
			}()
		}(addr)
	}

	errs := make([]error, 0, len(addrs))
	for res := range ch {
		if res.err == nil {
			cancel()
			return res.conn, nil
		}
		errs = append(errs, res.err)
	}
	return nil, errors.Join(errs...)
}
// newDOQConnPool creates the DoQ connection pool for uc over addrs by
// delegating to the package-level newDOQConnPool constructor. It returns nil
// when uc is not a DoQ upstream.
func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *doqConnPool {
	if uc.Type == ResolverTypeDOQ {
		return newDOQConnPool(ctx, uc, addrs)
	}
	return nil
}
// newDOTClientPool creates the DoT client pool for uc over addrs by delegating
// to the package-level newDOTClientPool constructor. It returns nil when uc is
// not a DoT upstream.
func (uc *UpstreamConfig) newDOTClientPool(ctx context.Context, addrs []string) *dotConnPool {
	if uc.Type == ResolverTypeDOT {
		return newDOTClientPool(ctx, uc, addrs)
	}
	return nil
}