mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-05-27 12:52:27 +02:00
98ca63325f
DoQ pools now keep a single quic.Transport and UDP socket for all dials, so parallel dial and reconnect churn no longer allocate a new socket per attempt or leak the winner's UDP conn when the caller owns the packet conn. quicParallelDialer accepts an optional transport: when set, dials use Transport.DialEarly on that socket; when nil, behavior matches the old per-dial ListenUDP path (losers close their sockets). Per RFC 9250 §4.2, close the query stream's send side before reading the response so strict upstreams see STREAM FIN before answering. CloseIdleConnections closes the shared transport and underlying UDP conn so checked-out connections and the OS socket are torn down. Add a FIN-strict test server, coverage for bootstrap vs parallel-dial paths, and a Linux-only FD churn regression test.
175 lines
5.0 KiB
Go
175 lines
5.0 KiB
Go
package ctrld
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"net/http"
|
|
"runtime"
|
|
"sync"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
"github.com/quic-go/quic-go/http3"
|
|
)
|
|
|
|
func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper {
|
|
if uc.Type != ResolverTypeDOH3 {
|
|
return nil
|
|
}
|
|
rt := &http3.Transport{}
|
|
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool, MinVersion: tls.VersionTLS12}
|
|
logger := LoggerFromCtx(ctx)
|
|
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
|
_, port, _ := net.SplitHostPort(addr)
|
|
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
|
if uc.BootstrapIP != "" {
|
|
addr = net.JoinHostPort(uc.BootstrapIP, port)
|
|
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", addr)
|
|
udpConn, err := net.ListenUDP("udp", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
|
}
|
|
dialAddrs := make([]string, len(addrs))
|
|
for i := range addrs {
|
|
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
|
}
|
|
pd := &quicParallelDialer{}
|
|
conn, err := pd.Dial(ctx, dialAddrs, tlsCfg, cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", conn.RemoteAddr())
|
|
return conn, err
|
|
}
|
|
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
|
|
rt.CloseIdleConnections()
|
|
})
|
|
return rt
|
|
}
|
|
|
|
func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper {
|
|
uc.ensureSetupTransport(ctx)
|
|
return transportByIpStack(uc.IPStack, dnsType, uc.http3RoundTripper, uc.http3RoundTripper4, uc.http3RoundTripper6)
|
|
}
|
|
|
|
func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doqConnPool {
|
|
uc.ensureSetupTransport(ctx)
|
|
return transportByIpStack(uc.IPStack, dnsType, uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6)
|
|
}
|
|
|
|
func (uc *UpstreamConfig) dotTransport(ctx context.Context, dnsType uint16) *dotConnPool {
|
|
uc.ensureSetupTransport(ctx)
|
|
return transportByIpStack(uc.IPStack, dnsType, uc.dotClientPool, uc.dotClientPool4, uc.dotClientPool6)
|
|
}
|
|
|
|
// Putting the code for quic parallel dialer here:
|
|
//
|
|
// - quic dialer is different with net.Dialer
|
|
// - simplification for quic free version
|
|
type parallelDialerResult struct {
|
|
conn *quic.Conn
|
|
err error
|
|
}
|
|
|
|
// quicParallelDialer races DialEarly across a list of remote addresses and
|
|
// returns the first successful connection. When transport is non-nil, all
|
|
// dials share that transport's UDP socket, which removes both the per-dial
|
|
// socket allocation and the winner-path socket leak that an owner-of-the-conn
|
|
// receiver cannot clean up. When transport is nil, the dialer falls back to a
|
|
// fresh UDP socket per attempt (compat path used where no shared transport is
|
|
// available yet); the loser paths close their sockets, and the winner path's
|
|
// socket is owned by quic.DialEarly's internal transport.
|
|
type quicParallelDialer struct {
|
|
transport *quic.Transport
|
|
}
|
|
|
|
// Dial performs parallel dialing to the given address list.
|
|
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
|
if len(addrs) == 0 {
|
|
return nil, errors.New("empty addresses")
|
|
}
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
done := make(chan struct{})
|
|
defer close(done)
|
|
ch := make(chan *parallelDialerResult, len(addrs))
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(addrs))
|
|
go func() {
|
|
wg.Wait()
|
|
close(ch)
|
|
}()
|
|
|
|
for _, addr := range addrs {
|
|
go func(addr string) {
|
|
defer wg.Done()
|
|
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
ch <- ¶llelDialerResult{conn: nil, err: err}
|
|
return
|
|
}
|
|
var (
|
|
conn *quic.Conn
|
|
udpConn *net.UDPConn
|
|
)
|
|
if d.transport != nil {
|
|
conn, err = d.transport.DialEarly(ctx, remoteAddr, tlsCfg, cfg)
|
|
} else {
|
|
udpConn, err = net.ListenUDP("udp", nil)
|
|
if err != nil {
|
|
ch <- ¶llelDialerResult{conn: nil, err: err}
|
|
return
|
|
}
|
|
conn, err = quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
|
if err != nil {
|
|
udpConn.Close()
|
|
udpConn = nil
|
|
}
|
|
}
|
|
select {
|
|
case ch <- ¶llelDialerResult{conn: conn, err: err}:
|
|
case <-done:
|
|
if conn != nil {
|
|
conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
|
|
}
|
|
if udpConn != nil {
|
|
udpConn.Close()
|
|
}
|
|
}
|
|
}(addr)
|
|
}
|
|
|
|
errs := make([]error, 0, len(addrs))
|
|
for res := range ch {
|
|
if res.err == nil {
|
|
cancel()
|
|
return res.conn, res.err
|
|
}
|
|
errs = append(errs, res.err)
|
|
}
|
|
|
|
return nil, errors.Join(errs...)
|
|
}
|
|
|
|
func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *doqConnPool {
|
|
if uc.Type != ResolverTypeDOQ {
|
|
return nil
|
|
}
|
|
return newDOQConnPool(ctx, uc, addrs)
|
|
}
|
|
|
|
func (uc *UpstreamConfig) newDOTClientPool(ctx context.Context, addrs []string) *dotConnPool {
|
|
if uc.Type != ResolverTypeDOT {
|
|
return nil
|
|
}
|
|
return newDOTClientPool(ctx, uc, addrs)
|
|
}
|