mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-25 23:30:41 +01:00
perf(dot): implement connection pooling for improved performance
Implement TCP/TLS connection pooling for DoT resolver to match DoQ performance. Previously, DoT created a new TCP/TLS connection for every DNS query, incurring significant TLS handshake overhead. Now connections are reused across queries, eliminating this overhead for subsequent requests. The implementation follows the same pattern as DoQ, using parallel dialing and connection pooling to achieve comparable performance characteristics.
This commit is contained in:
committed by
Cuong Manh Le
parent
2e8a0f00a0
commit
acbebcf7c2
20
config.go
20
config.go
@@ -282,6 +282,9 @@ type UpstreamConfig struct {
|
||||
doqConnPool *doqConnPool
|
||||
doqConnPool4 *doqConnPool
|
||||
doqConnPool6 *doqConnPool
|
||||
dotClientPool *dotConnPool
|
||||
dotClientPool4 *dotConnPool
|
||||
dotClientPool6 *dotConnPool
|
||||
certPool *x509.CertPool
|
||||
u *url.URL
|
||||
fallbackOnce sync.Once
|
||||
@@ -496,7 +499,7 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
// ReBootstrap sets up the bootstrap IP and the transport again.
|
||||
func (uc *UpstreamConfig) ReBootstrap() {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ:
|
||||
case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ, ResolverTypeDOT:
|
||||
default:
|
||||
return
|
||||
}
|
||||
@@ -508,11 +511,11 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
})
|
||||
}
|
||||
|
||||
// SetupTransport initializes the network transport used to connect to upstream server.
|
||||
// For now, only DoH upstream is supported.
|
||||
// SetupTransport initializes the network transport used to connect to upstream servers.
|
||||
// For now, DoH/DoH3/DoQ/DoT upstreams are supported.
|
||||
func (uc *UpstreamConfig) SetupTransport() {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ:
|
||||
case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ, ResolverTypeDOT:
|
||||
default:
|
||||
return
|
||||
}
|
||||
@@ -523,21 +526,26 @@ func (uc *UpstreamConfig) SetupTransport() {
|
||||
case IpStackV6:
|
||||
ips = uc.bootstrapIPs6
|
||||
}
|
||||
|
||||
uc.transport = uc.newDOHTransport(ips)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ips)
|
||||
uc.doqConnPool = uc.newDOQConnPool(ips)
|
||||
uc.dotClientPool = uc.newDOTClientPool(ips)
|
||||
if uc.IPStack == IpStackSplit {
|
||||
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
uc.doqConnPool4 = uc.newDOQConnPool(uc.bootstrapIPs4)
|
||||
uc.dotClientPool4 = uc.newDOTClientPool(uc.bootstrapIPs4)
|
||||
if HasIPv6() {
|
||||
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
uc.doqConnPool6 = uc.newDOQConnPool(uc.bootstrapIPs6)
|
||||
uc.dotClientPool6 = uc.newDOTClientPool(uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.transport6 = uc.transport4
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
uc.doqConnPool6 = uc.doqConnPool4
|
||||
uc.dotClientPool6 = uc.dotClientPool4
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -656,6 +664,10 @@ func (uc *UpstreamConfig) ping() error {
|
||||
// For DoQ, we just ensure transport is set up by calling doqTransport
|
||||
// DoQ doesn't use HTTP, so we can't ping it the same way
|
||||
_ = uc.doqTransport(typ)
|
||||
case ResolverTypeDOT:
|
||||
// For DoT, we just ensure transport is set up by calling dotTransport
|
||||
// DoT doesn't use HTTP, so we can't ping it the same way
|
||||
_ = uc.dotTransport(typ)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,25 +13,6 @@ import (
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
if HasIPv6() {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
}
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
if uc.Type != ResolverTypeDOH3 {
|
||||
return nil
|
||||
@@ -82,6 +63,11 @@ func (uc *UpstreamConfig) doqTransport(dnsType uint16) *doqConnPool {
|
||||
return transportByIpStack(uc.IPStack, dnsType, uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) dotTransport(dnsType uint16) *dotConnPool {
|
||||
uc.ensureSetupTransport()
|
||||
return transportByIpStack(uc.IPStack, dnsType, uc.dotClientPool, uc.dotClientPool4, uc.dotClientPool6)
|
||||
}
|
||||
|
||||
// Putting the code for quic parallel dialer here:
|
||||
//
|
||||
// - quic dialer is different with net.Dialer
|
||||
@@ -156,3 +142,10 @@ func (uc *UpstreamConfig) newDOQConnPool(addrs []string) *doqConnPool {
|
||||
}
|
||||
return newDOQConnPool(uc, addrs)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOTClientPool(addrs []string) *dotConnPool {
|
||||
if uc.Type != ResolverTypeDOT {
|
||||
return nil
|
||||
}
|
||||
return newDOTClientPool(uc, addrs)
|
||||
}
|
||||
|
||||
4
doh.go
4
doh.go
@@ -85,6 +85,10 @@ type dohResolver struct {
|
||||
|
||||
// Resolve performs DNS query with given DNS message using DOH protocol.
|
||||
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
if err := validateMsg(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
4
doq.go
4
doq.go
@@ -21,6 +21,10 @@ type doqResolver struct {
|
||||
}
|
||||
|
||||
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
if err := validateMsg(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the appropriate connection pool based on DNS type and IP stack
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
|
||||
307
dot.go
307
dot.go
@@ -3,7 +3,12 @@ package ctrld
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -13,30 +18,292 @@ type dotResolver struct {
|
||||
}
|
||||
|
||||
func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
// The dialer is used to prevent bootstrapping cycle.
|
||||
// If r.endpoint is set to dns.controld.dev, we need to resolve
|
||||
// dns.controld.dev first. By using a dialer with custom resolver,
|
||||
// we ensure that we can always resolve the bootstrap domain
|
||||
// regardless of the machine DNS status.
|
||||
dialer := newDialer(net.JoinHostPort(controldPublicDns, "53"))
|
||||
if err := validateMsg(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
|
||||
dnsClient := &dns.Client{
|
||||
Net: tcpNet,
|
||||
Dialer: dialer,
|
||||
TLSConfig: &tls.Config{RootCAs: r.uc.certPool},
|
||||
}
|
||||
endpoint := r.uc.Endpoint
|
||||
if r.uc.BootstrapIP != "" {
|
||||
dnsClient.TLSConfig.ServerName = r.uc.Domain
|
||||
dnsClient.Net = "tcp-tls"
|
||||
_, port, _ := net.SplitHostPort(endpoint)
|
||||
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
|
||||
|
||||
pool := r.uc.dotTransport(dnsTyp)
|
||||
if pool == nil {
|
||||
return nil, errors.New("DoT client pool is not available")
|
||||
}
|
||||
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint)
|
||||
return answer, wrapCertificateVerificationError(err)
|
||||
return pool.Resolve(ctx, msg)
|
||||
}
|
||||
|
||||
// dotConnPool manages a pool of TCP/TLS connections for DoT queries.
|
||||
type dotConnPool struct {
|
||||
uc *UpstreamConfig
|
||||
addrs []string
|
||||
port string
|
||||
tlsConfig *tls.Config
|
||||
dialer *net.Dialer
|
||||
mu sync.RWMutex
|
||||
conns map[string]*dotConn
|
||||
closed bool
|
||||
}
|
||||
|
||||
// dotConn is a pooled connection together with its usage bookkeeping.
type dotConn struct {
	conn     net.Conn   // the established connection
	lastUsed time.Time  // updated when the connection is handed out or returned
	refCount int        // number of callers currently using conn; 0 means idle
	mu       sync.Mutex // guards the fields above
}
|
||||
|
||||
func newDOTClientPool(uc *UpstreamConfig, addrs []string) *dotConnPool {
|
||||
_, port, _ := net.SplitHostPort(uc.Endpoint)
|
||||
if port == "" {
|
||||
port = "853"
|
||||
}
|
||||
|
||||
// The dialer is used to prevent bootstrapping cycle.
|
||||
// If endpoint is set to dns.controld.dev, we need to resolve
|
||||
// dns.controld.dev first. By using a dialer with custom resolver,
|
||||
// we ensure that we can always resolve the bootstrap domain
|
||||
// regardless of the machine DNS status.
|
||||
dialer := newDialer(net.JoinHostPort(controldPublicDns, "53"))
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: uc.certPool,
|
||||
}
|
||||
|
||||
if uc.BootstrapIP != "" {
|
||||
tlsConfig.ServerName = uc.Domain
|
||||
}
|
||||
|
||||
pool := &dotConnPool{
|
||||
uc: uc,
|
||||
addrs: addrs,
|
||||
port: port,
|
||||
tlsConfig: tlsConfig,
|
||||
dialer: dialer,
|
||||
conns: make(map[string]*dotConn),
|
||||
}
|
||||
|
||||
// Use SetFinalizer here because we need to call a method on the pool itself.
|
||||
// AddCleanup would require passing the pool as arg (which panics) or capturing
|
||||
// it in a closure (which prevents GC). SetFinalizer is appropriate for this case.
|
||||
runtime.SetFinalizer(pool, func(p *dotConnPool) {
|
||||
p.CloseIdleConnections()
|
||||
})
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
// Resolve performs a DNS query using a pooled TCP/TLS connection.
|
||||
func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
if msg == nil {
|
||||
return nil, errors.New("nil DNS message")
|
||||
}
|
||||
|
||||
conn, addr, err := p.getConn(ctx)
|
||||
if err != nil {
|
||||
return nil, wrapCertificateVerificationError(err)
|
||||
}
|
||||
|
||||
// Set deadline
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
deadline = time.Now().Add(5 * time.Second)
|
||||
}
|
||||
_ = conn.SetDeadline(deadline)
|
||||
|
||||
client := dns.Client{Net: "tcp-tls"}
|
||||
answer, _, err := client.ExchangeWithConnContext(ctx, msg, &dns.Conn{Conn: conn})
|
||||
isGood := err == nil
|
||||
p.putConn(addr, conn, isGood)
|
||||
|
||||
if err != nil {
|
||||
return nil, wrapCertificateVerificationError(err)
|
||||
}
|
||||
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
// getConn gets a TCP/TLS connection from the pool or creates a new one.
|
||||
func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, string, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return nil, "", io.EOF
|
||||
}
|
||||
|
||||
// Try to reuse an existing connection
|
||||
for addr, dotConn := range p.conns {
|
||||
dotConn.mu.Lock()
|
||||
if dotConn.refCount == 0 && dotConn.conn != nil {
|
||||
dotConn.refCount++
|
||||
dotConn.lastUsed = time.Now()
|
||||
conn := dotConn.conn
|
||||
dotConn.mu.Unlock()
|
||||
return conn, addr, nil
|
||||
}
|
||||
dotConn.mu.Unlock()
|
||||
}
|
||||
|
||||
// No available connection, create a new one
|
||||
addr, conn, err := p.dialConn(ctx)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
dotConn := &dotConn{
|
||||
conn: conn,
|
||||
lastUsed: time.Now(),
|
||||
refCount: 1,
|
||||
}
|
||||
p.conns[addr] = dotConn
|
||||
|
||||
return conn, addr, nil
|
||||
}
|
||||
|
||||
// putConn returns a connection to the pool.
|
||||
func (p *dotConnPool) putConn(addr string, conn net.Conn, isGood bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
dotConn, ok := p.conns[addr]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
dotConn.mu.Lock()
|
||||
defer dotConn.mu.Unlock()
|
||||
|
||||
dotConn.refCount--
|
||||
if dotConn.refCount < 0 {
|
||||
dotConn.refCount = 0
|
||||
}
|
||||
|
||||
// If connection is bad, remove it from pool
|
||||
if !isGood {
|
||||
delete(p.conns, addr)
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
dotConn.lastUsed = time.Now()
|
||||
}
|
||||
|
||||
// dialConn creates a new TCP/TLS connection.
|
||||
func (p *dotConnPool) dialConn(ctx context.Context) (string, net.Conn, error) {
|
||||
logger := ProxyLogger.Load()
|
||||
var endpoint string
|
||||
|
||||
if p.uc.BootstrapIP != "" {
|
||||
endpoint = net.JoinHostPort(p.uc.BootstrapIP, p.port)
|
||||
Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint)
|
||||
conn, err := p.dialer.DialContext(ctx, "tcp", endpoint)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
tlsConn := tls.Client(conn, p.tlsConfig)
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
conn.Close()
|
||||
return "", nil, err
|
||||
}
|
||||
return endpoint, tlsConn, nil
|
||||
}
|
||||
|
||||
// Try bootstrap IPs in parallel
|
||||
if len(p.addrs) > 0 {
|
||||
type result struct {
|
||||
conn net.Conn
|
||||
addr string
|
||||
err error
|
||||
}
|
||||
|
||||
ch := make(chan result, len(p.addrs))
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
for _, addr := range p.addrs {
|
||||
go func(addr string) {
|
||||
endpoint := net.JoinHostPort(addr, p.port)
|
||||
conn, err := p.dialer.DialContext(ctx, "tcp", endpoint)
|
||||
if err != nil {
|
||||
select {
|
||||
case ch <- result{conn: nil, addr: endpoint, err: err}:
|
||||
case <-done:
|
||||
}
|
||||
return
|
||||
}
|
||||
tlsConfig := p.tlsConfig.Clone()
|
||||
tlsConfig.ServerName = p.uc.Domain
|
||||
tlsConn := tls.Client(conn, tlsConfig)
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
conn.Close()
|
||||
select {
|
||||
case ch <- result{conn: nil, addr: endpoint, err: err}:
|
||||
case <-done:
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- result{conn: tlsConn, addr: endpoint, err: nil}:
|
||||
case <-done:
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}(addr)
|
||||
}
|
||||
|
||||
errs := make([]error, 0, len(p.addrs))
|
||||
for range len(p.addrs) {
|
||||
select {
|
||||
case res := <-ch:
|
||||
if res.err == nil && res.conn != nil {
|
||||
Log(ctx, logger.Debug(), "Sending DoT request to: %s", res.addr)
|
||||
return res.addr, res.conn, nil
|
||||
}
|
||||
if res.err != nil {
|
||||
errs = append(errs, res.err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return "", nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
// Fallback to endpoint resolution
|
||||
endpoint = p.uc.Endpoint
|
||||
Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint)
|
||||
conn, err := p.dialer.DialContext(ctx, "tcp", endpoint)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
tlsConn := tls.Client(conn, p.tlsConfig)
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
conn.Close()
|
||||
return "", nil, err
|
||||
}
|
||||
return endpoint, tlsConn, nil
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes all connections in the pool.
|
||||
func (p *dotConnPool) CloseIdleConnections() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
for addr, dotConn := range p.conns {
|
||||
dotConn.mu.Lock()
|
||||
if dotConn.conn != nil {
|
||||
dotConn.conn.Close()
|
||||
}
|
||||
dotConn.mu.Unlock()
|
||||
delete(p.conns, addr)
|
||||
}
|
||||
}
|
||||
|
||||
20
resolver.go
20
resolver.go
@@ -291,6 +291,9 @@ const hotCacheTTL = time.Second
|
||||
// for a short period (currently 1 second), reducing unnecessary traffic
|
||||
// sent to upstreams.
|
||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
if err := validateMsg(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(msg.Question) == 0 {
|
||||
return nil, errors.New("no question found")
|
||||
}
|
||||
@@ -509,6 +512,10 @@ type legacyResolver struct {
|
||||
}
|
||||
|
||||
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
if err := validateMsg(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// See the comment on the dialer in newDOTClientPool.
|
||||
dialer := newDialer(net.JoinHostPort(controldPublicDns, "53"))
|
||||
dnsTyp := uint16(0)
|
||||
@@ -534,6 +541,9 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e
|
||||
type dummyResolver struct{}
|
||||
|
||||
func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
if err := validateMsg(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ans := new(dns.Msg)
|
||||
ans.SetReply(msg)
|
||||
return ans, nil
|
||||
@@ -769,3 +779,13 @@ func isLanAddr(addr netip.Addr) bool {
|
||||
addr.IsLinkLocalUnicast() ||
|
||||
tsaddr.CGNATRange().Contains(addr)
|
||||
}
|
||||
|
||||
func validateMsg(msg *dns.Msg) error {
|
||||
if msg == nil {
|
||||
return errors.New("nil DNS message")
|
||||
}
|
||||
if len(msg.Question) == 0 {
|
||||
return errors.New("no question found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user