From f4a938c873eab311dad65ff4e5de1d9fb7876912 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 6 Jan 2026 14:46:00 +0700 Subject: [PATCH] perf(doq): implement connection pooling for improved performance Implement QUIC connection pooling for DoQ resolver to match DoH3 performance. Previously, DoQ created a new QUIC connection for every DNS query, incurring significant handshake overhead. Now connections are reused across queries, eliminating this overhead for subsequent requests. The implementation follows the same pattern as DoH3, using parallel dialing and connection pooling to achieve comparable performance characteristics. --- config.go | 32 +++++- config_quic.go | 25 +++++ doq.go | 260 +++++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 286 insertions(+), 31 deletions(-) diff --git a/config.go b/config.go index 3e6548d..4a3c113 100644 --- a/config.go +++ b/config.go @@ -282,6 +282,9 @@ type UpstreamConfig struct { http3RoundTripper http.RoundTripper http3RoundTripper4 http.RoundTripper http3RoundTripper6 http.RoundTripper + doqConnPool *doqConnPool + doqConnPool4 *doqConnPool + doqConnPool6 *doqConnPool certPool *x509.CertPool u *url.URL fallbackOnce sync.Once @@ -504,7 +507,7 @@ func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) { // ReBootstrap re-setup the bootstrap IP and the transport. func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { switch uc.Type { - case ResolverTypeDOH, ResolverTypeDOH3: + case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ: default: return } @@ -525,6 +528,27 @@ func (uc *UpstreamConfig) SetupTransport(ctx context.Context) { uc.setupDOHTransport(ctx) case ResolverTypeDOH3: uc.setupDOH3Transport(ctx) + case ResolverTypeDOQ: + uc.setupDOQTransport(ctx) + } +} + +func (uc *UpstreamConfig) setupDOQTransport(ctx context.Context) { + switch uc.IPStack { + case IpStackBoth, "": + uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs) + case IpStackV4: + uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs4) + case IpStackV6: + uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs6) + case IpStackSplit: + uc.doqConnPool4 = uc.newDOQConnPool(ctx, uc.bootstrapIPs4) + if HasIPv6(ctx) { + uc.doqConnPool6 = uc.newDOQConnPool(ctx, uc.bootstrapIPs6) + } else { + uc.doqConnPool6 = uc.doqConnPool4 + } + uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs) } } @@ -612,7 +636,7 @@ func (uc *UpstreamConfig) ErrorPing(ctx context.Context) error { func (uc *UpstreamConfig) ping(ctx context.Context) error { switch uc.Type { - case ResolverTypeDOH, ResolverTypeDOH3: + case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ: default: return nil } @@ -646,6 +670,10 @@ func (uc *UpstreamConfig) ping(ctx context.Context) error { if err := ping(uc.doh3Transport(ctx, typ)); err != nil { return err } + case ResolverTypeDOQ: + // For DoQ, we just ensure transport is set up by calling doqTransport + // DoQ doesn't use HTTP, so we can't ping it the same way + _ = uc.doqTransport(ctx, typ) } } diff --git a/config_quic.go b/config_quic.go index 57bd864..6172ba2 100644 --- a/config_quic.go +++ b/config_quic.go @@ -92,6 +92,27 @@ func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) htt return uc.http3RoundTripper } +func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doqConnPool { + uc.transportOnce.Do(func() { + uc.SetupTransport(ctx) + }) + if uc.rebootstrap.CompareAndSwap(true, false) { + uc.SetupTransport(ctx) + } + switch uc.IPStack { + case IpStackBoth, IpStackV4, IpStackV6: + return uc.doqConnPool + case IpStackSplit: + switch dnsType { + case dns.TypeA: + return uc.doqConnPool4 + default: + return uc.doqConnPool6 + } + } + return uc.doqConnPool +} + // Putting the code for quic parallel dialer here: // // - quic dialer is different with net.Dialer @@ -159,3 +180,7 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t return nil, errors.Join(errs...) } + +func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *doqConnPool { + return newDOQConnPool(ctx, uc, addrs) +} diff --git a/doq.go b/doq.go index b665cec..d309e45 100644 --- a/doq.go +++ b/doq.go @@ -5,8 +5,11 @@ package ctrld import ( "context" "crypto/tls" + "errors" "io" "net" + "runtime" + "sync" "time" "github.com/miekg/dns" @@ -21,22 +24,19 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro logger := LoggerFromCtx(ctx) Log(ctx, logger.Debug(), "DoQ resolver query started") - endpoint := r.uc.Endpoint - tlsConfig := &tls.Config{NextProtos: []string{"doq"}} - ip := r.uc.BootstrapIP - if ip == "" { - dnsTyp := uint16(0) - if msg != nil && len(msg.Question) > 0 { - dnsTyp = msg.Question[0].Qtype - } - ip = r.uc.bootstrapIPForDNSType(ctx, dnsTyp) + // Get the appropriate connection pool based on DNS type and IP stack + dnsTyp := uint16(0) + if msg != nil && len(msg.Question) > 0 { + dnsTyp = msg.Question[0].Qtype } - tlsConfig.ServerName = r.uc.Domain - _, port, _ := net.SplitHostPort(endpoint) - endpoint = net.JoinHostPort(ip, port) - Log(ctx, logger.Debug(), "Sending DoQ request to: %s", endpoint) - answer, err := resolve(ctx, msg, endpoint, tlsConfig) + pool := r.uc.doqTransport(ctx, dnsTyp) + if pool == nil { + Log(ctx, logger.Error(), "DoQ connection pool is not available") + return nil, errors.New("DoQ connection pool is not available") + } + + answer, err := pool.Resolve(ctx, msg) if err != nil { Log(ctx, logger.Error().Err(err), "DoQ request failed") } else { @@ -45,11 +45,59 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro return answer, err } -func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) { - // DoQ quic-go server returns io.EOF error after running for a long time, - // even for a good stream. So retrying the query for 5 times before giving up. +// doqConnPool manages a pool of QUIC connections for DoQ queries. +type doqConnPool struct { + uc *UpstreamConfig + addrs []string + port string + tlsConfig *tls.Config + mu sync.RWMutex + conns map[string]*doqConn + closed bool +} + +type doqConn struct { + conn *quic.Conn + lastUsed time.Time + refCount int + mu sync.Mutex +} + +func newDOQConnPool(ctx context.Context, uc *UpstreamConfig, addrs []string) *doqConnPool { + _, port, _ := net.SplitHostPort(uc.Endpoint) + if port == "" { + port = "853" + } + + tlsConfig := &tls.Config{ + NextProtos: []string{"doq"}, + RootCAs: uc.certPool, + ServerName: uc.Domain, + } + + pool := &doqConnPool{ + uc: uc, + addrs: addrs, + port: port, + tlsConfig: tlsConfig, + conns: make(map[string]*doqConn), + } + + // Use SetFinalizer here because we need to call a method on the pool itself. + // AddCleanup would require passing the pool as arg (which panics) or capturing + // it in a closure (which prevents GC). SetFinalizer is appropriate for this case. + runtime.SetFinalizer(pool, func(p *doqConnPool) { + p.CloseIdleConnections() + }) + + return pool +} + +// Resolve performs a DNS query using a pooled QUIC connection. +func (p *doqConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + // Retry logic for io.EOF errors (as per original implementation) for i := 0; i < 5; i++ { - answer, err := doResolve(ctx, msg, endpoint, tlsConfig) + answer, err := p.doResolve(ctx, msg) if err == io.EOF { continue } @@ -58,57 +106,72 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls. } return answer, nil } - return nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(quic.InternalError), ErrorMessage: quic.InternalError.Message()} + return nil, &quic.ApplicationError{ + ErrorCode: quic.ApplicationErrorCode(quic.InternalError), + ErrorMessage: quic.InternalError.Message(), + } } -func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) { - session, err := quic.DialAddr(ctx, endpoint, tlsConfig, nil) +func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + conn, addr, err := p.getConn(ctx) if err != nil { return nil, err } - defer session.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + // Pack the DNS message msgBytes, err := msg.Pack() if err != nil { + p.putConn(addr, conn, false) return nil, err } - stream, err := session.OpenStream() + // Open a new stream for this query + stream, err := conn.OpenStream() if err != nil { + p.putConn(addr, conn, false) return nil, err } + // Set deadline deadline, ok := ctx.Deadline() if !ok { deadline = time.Now().Add(5 * time.Second) } _ = stream.SetDeadline(deadline) + // Write message length (2 bytes) followed by message var msgLen = uint16(len(msgBytes)) var msgLenBytes = []byte{byte(msgLen >> 8), byte(msgLen & 0xFF)} if _, err := stream.Write(msgLenBytes); err != nil { + stream.Close() + p.putConn(addr, conn, false) return nil, err } if _, err := stream.Write(msgBytes); err != nil { + stream.Close() + p.putConn(addr, conn, false) return nil, err } + // Read response buf, err := io.ReadAll(stream) + stream.Close() + + // Return connection to pool (mark as potentially bad if error occurred) + isGood := err == nil && len(buf) > 0 + p.putConn(addr, conn, isGood) + if err != nil { return nil, err } - _ = stream.Close() - - // io.ReadAll hide the io.EOF error returned by quic-go server. - // Once we figure out why quic-go server sends io.EOF after running - // for a long time, we can have a better way to handle this. For now, - // make sure io.EOF error returned, so the caller can handle it cleanly. + // io.ReadAll hides io.EOF error, so check for empty buffer if len(buf) == 0 { return nil, io.EOF } + // Unpack DNS response (skip 2-byte length prefix) answer := new(dns.Msg) if err := answer.Unpack(buf[2:]); err != nil { return nil, err @@ -116,3 +179,142 @@ func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tl answer.SetReply(msg) return answer, nil } + +// getConn gets a QUIC connection from the pool or creates a new one. +func (p *doqConnPool) getConn(ctx context.Context) (*quic.Conn, string, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return nil, "", io.EOF + } + + // Try to reuse an existing connection + for addr, doqConn := range p.conns { + doqConn.mu.Lock() + if doqConn.refCount == 0 && doqConn.conn != nil { + // Check if connection is still alive + select { + case <-doqConn.conn.Context().Done(): + // Connection is closed, remove it + doqConn.mu.Unlock() + delete(p.conns, addr) + continue + default: + } + + doqConn.refCount++ + doqConn.lastUsed = time.Now() + conn := doqConn.conn + doqConn.mu.Unlock() + return conn, addr, nil + } + doqConn.mu.Unlock() + } + + // No available connection, create a new one + addr, conn, err := p.dialConn(ctx) + if err != nil { + return nil, "", err + } + + doqConn := &doqConn{ + conn: conn, + lastUsed: time.Now(), + refCount: 1, + } + p.conns[addr] = doqConn + + return conn, addr, nil +} + +// putConn returns a connection to the pool. +func (p *doqConnPool) putConn(addr string, conn *quic.Conn, isGood bool) { + p.mu.Lock() + defer p.mu.Unlock() + + doqConn, ok := p.conns[addr] + if !ok { + return + } + + doqConn.mu.Lock() + defer doqConn.mu.Unlock() + + doqConn.refCount-- + if doqConn.refCount < 0 { + doqConn.refCount = 0 + } + + // If connection is bad or closed, remove it from pool + if !isGood || conn.Context().Err() != nil { + delete(p.conns, addr) + conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + return + } + + doqConn.lastUsed = time.Now() +} + +// dialConn creates a new QUIC connection using parallel dialing like DoH3. +func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) { + logger := LoggerFromCtx(ctx) + + // If we have a bootstrap IP, use it directly + if p.uc.BootstrapIP != "" { + addr := net.JoinHostPort(p.uc.BootstrapIP, p.port) + Log(ctx, logger.Debug(), "Sending DoQ request to: %s", addr) + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return "", nil, err + } + remoteAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + udpConn.Close() + return "", nil, err + } + conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, p.tlsConfig, nil) + if err != nil { + udpConn.Close() + return "", nil, err + } + return addr, conn, nil + } + + // Use parallel dialing like DoH3 + dialAddrs := make([]string, len(p.addrs)) + for i := range p.addrs { + dialAddrs[i] = net.JoinHostPort(p.addrs[i], p.port) + } + + pd := &quicParallelDialer{} + conn, err := pd.Dial(ctx, dialAddrs, p.tlsConfig, nil) + if err != nil { + return "", nil, err + } + + addr := conn.RemoteAddr().String() + Log(ctx, logger.Debug(), "Sending DoQ request to: %s", addr) + return addr, conn, nil +} + +// CloseIdleConnections closes all idle connections in the pool. +// When called during cleanup (e.g., from finalizer), it closes all connections +// regardless of refCount to prevent resource leaks. +func (p *doqConnPool) CloseIdleConnections() { + p.mu.Lock() + defer p.mu.Unlock() + + p.closed = true + + for addr, dc := range p.conns { + dc.mu.Lock() + if dc.conn != nil { + // Close all connections to ensure proper cleanup, even if in use + // This prevents resource leaks when the pool is being destroyed + dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + } + dc.mu.Unlock() + delete(p.conns, addr) + } +}