mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-13 10:26:06 +00:00
Implement TCP/TLS connection pooling for DoT resolver to match DoQ performance. Previously, DoT created a new TCP/TLS connection for every DNS query, incurring significant TLS handshake overhead. Now connections are reused across queries, eliminating this overhead for subsequent requests. The implementation follows the same pattern as DoQ, using parallel dialing and connection pooling to achieve comparable performance characteristics.
153 lines
4.1 KiB
Go
153 lines
4.1 KiB
Go
package ctrld
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"net/http"
|
|
"runtime"
|
|
"sync"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
"github.com/quic-go/quic-go/http3"
|
|
)
|
|
|
|
func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper {
|
|
if uc.Type != ResolverTypeDOH3 {
|
|
return nil
|
|
}
|
|
rt := &http3.Transport{}
|
|
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
|
logger := LoggerFromCtx(ctx)
|
|
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
|
_, port, _ := net.SplitHostPort(addr)
|
|
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
|
if uc.BootstrapIP != "" {
|
|
addr = net.JoinHostPort(uc.BootstrapIP, port)
|
|
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", addr)
|
|
udpConn, err := net.ListenUDP("udp", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
|
}
|
|
dialAddrs := make([]string, len(addrs))
|
|
for i := range addrs {
|
|
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
|
}
|
|
pd := &quicParallelDialer{}
|
|
conn, err := pd.Dial(ctx, dialAddrs, tlsCfg, cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", conn.RemoteAddr())
|
|
return conn, err
|
|
}
|
|
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
|
|
rt.CloseIdleConnections()
|
|
})
|
|
return rt
|
|
}
|
|
|
|
func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper {
|
|
uc.ensureSetupTransport(ctx)
|
|
return transportByIpStack(uc.IPStack, dnsType, uc.http3RoundTripper, uc.http3RoundTripper4, uc.http3RoundTripper6)
|
|
}
|
|
|
|
func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doqConnPool {
|
|
uc.ensureSetupTransport(ctx)
|
|
return transportByIpStack(uc.IPStack, dnsType, uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6)
|
|
}
|
|
|
|
func (uc *UpstreamConfig) dotTransport(ctx context.Context, dnsType uint16) *dotConnPool {
|
|
uc.ensureSetupTransport(ctx)
|
|
return transportByIpStack(uc.IPStack, dnsType, uc.dotClientPool, uc.dotClientPool4, uc.dotClientPool6)
|
|
}
|
|
|
|
// Putting the code for quic parallel dialer here:
|
|
//
|
|
// - quic dialer is different with net.Dialer
|
|
// - simplification for quic free version
|
|
type parallelDialerResult struct {
|
|
conn *quic.Conn
|
|
err error
|
|
}
|
|
|
|
// quicParallelDialer races QUIC dials to several addresses and keeps the first
// connection that is established. It carries no state, so its zero value is
// ready to use.
type quicParallelDialer struct{}
|
|
|
|
// Dial performs parallel dialing to the given address list.
|
|
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
|
if len(addrs) == 0 {
|
|
return nil, errors.New("empty addresses")
|
|
}
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
done := make(chan struct{})
|
|
defer close(done)
|
|
ch := make(chan *parallelDialerResult, len(addrs))
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(addrs))
|
|
go func() {
|
|
wg.Wait()
|
|
close(ch)
|
|
}()
|
|
|
|
for _, addr := range addrs {
|
|
go func(addr string) {
|
|
defer wg.Done()
|
|
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
ch <- ¶llelDialerResult{conn: nil, err: err}
|
|
return
|
|
}
|
|
udpConn, err := net.ListenUDP("udp", nil)
|
|
if err != nil {
|
|
ch <- ¶llelDialerResult{conn: nil, err: err}
|
|
return
|
|
}
|
|
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
|
select {
|
|
case ch <- ¶llelDialerResult{conn: conn, err: err}:
|
|
case <-done:
|
|
if conn != nil {
|
|
conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
|
|
}
|
|
if udpConn != nil {
|
|
udpConn.Close()
|
|
}
|
|
}
|
|
}(addr)
|
|
}
|
|
|
|
errs := make([]error, 0, len(addrs))
|
|
for res := range ch {
|
|
if res.err == nil {
|
|
cancel()
|
|
return res.conn, res.err
|
|
}
|
|
errs = append(errs, res.err)
|
|
}
|
|
|
|
return nil, errors.Join(errs...)
|
|
}
|
|
|
|
func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *doqConnPool {
|
|
if uc.Type != ResolverTypeDOQ {
|
|
return nil
|
|
}
|
|
return newDOQConnPool(ctx, uc, addrs)
|
|
}
|
|
|
|
func (uc *UpstreamConfig) newDOTClientPool(ctx context.Context, addrs []string) *dotConnPool {
|
|
if uc.Type != ResolverTypeDOT {
|
|
return nil
|
|
}
|
|
return newDOTClientPool(ctx, uc, addrs)
|
|
}
|