diff --git a/config.go b/config.go index 3e6548d..4a3c113 100644 --- a/config.go +++ b/config.go @@ -282,6 +282,9 @@ type UpstreamConfig struct { http3RoundTripper http.RoundTripper http3RoundTripper4 http.RoundTripper http3RoundTripper6 http.RoundTripper + doqConnPool *doqConnPool + doqConnPool4 *doqConnPool + doqConnPool6 *doqConnPool certPool *x509.CertPool u *url.URL fallbackOnce sync.Once @@ -504,7 +507,7 @@ func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) { // ReBootstrap re-setup the bootstrap IP and the transport. func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { switch uc.Type { - case ResolverTypeDOH, ResolverTypeDOH3: + case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ: default: return } @@ -525,6 +528,27 @@ func (uc *UpstreamConfig) SetupTransport(ctx context.Context) { uc.setupDOHTransport(ctx) case ResolverTypeDOH3: uc.setupDOH3Transport(ctx) + case ResolverTypeDOQ: + uc.setupDOQTransport(ctx) + } +} + +func (uc *UpstreamConfig) setupDOQTransport(ctx context.Context) { + switch uc.IPStack { + case IpStackBoth, "": + uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs) + case IpStackV4: + uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs4) + case IpStackV6: + uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs6) + case IpStackSplit: + uc.doqConnPool4 = uc.newDOQConnPool(ctx, uc.bootstrapIPs4) + if HasIPv6(ctx) { + uc.doqConnPool6 = uc.newDOQConnPool(ctx, uc.bootstrapIPs6) + } else { + uc.doqConnPool6 = uc.doqConnPool4 + } + uc.doqConnPool = uc.newDOQConnPool(ctx, uc.bootstrapIPs) } } @@ -612,7 +636,7 @@ func (uc *UpstreamConfig) ErrorPing(ctx context.Context) error { func (uc *UpstreamConfig) ping(ctx context.Context) error { switch uc.Type { - case ResolverTypeDOH, ResolverTypeDOH3: + case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ: default: return nil } @@ -646,6 +670,10 @@ func (uc *UpstreamConfig) ping(ctx context.Context) error { if err := ping(uc.doh3Transport(ctx, typ)); err != nil { return err } + case ResolverTypeDOQ: + // For DoQ, we just ensure transport is set up by calling doqTransport + // DoQ doesn't use HTTP, so we can't ping it the same way + _ = uc.doqTransport(ctx, typ) } } diff --git a/config_quic.go b/config_quic.go index 57bd864..6172ba2 100644 --- a/config_quic.go +++ b/config_quic.go @@ -92,6 +92,27 @@ func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) htt return uc.http3RoundTripper } +func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doqConnPool { + uc.transportOnce.Do(func() { + uc.SetupTransport(ctx) + }) + if uc.rebootstrap.CompareAndSwap(true, false) { + uc.SetupTransport(ctx) + } + switch uc.IPStack { + case IpStackBoth, IpStackV4, IpStackV6: + return uc.doqConnPool + case IpStackSplit: + switch dnsType { + case dns.TypeA: + return uc.doqConnPool4 + default: + return uc.doqConnPool6 + } + } + return uc.doqConnPool +} + // Putting the code for quic parallel dialer here: // // - quic dialer is different with net.Dialer @@ -159,3 +180,7 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t return nil, errors.Join(errs...) } + +func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *doqConnPool { + return newDOQConnPool(ctx, uc, addrs) +} diff --git a/doq.go b/doq.go index b665cec..d309e45 100644 --- a/doq.go +++ b/doq.go @@ -5,8 +5,11 @@ package ctrld import ( "context" "crypto/tls" + "errors" "io" "net" + "runtime" + "sync" "time" "github.com/miekg/dns" @@ -21,22 +24,19 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro logger := LoggerFromCtx(ctx) Log(ctx, logger.Debug(), "DoQ resolver query started") - endpoint := r.uc.Endpoint - tlsConfig := &tls.Config{NextProtos: []string{"doq"}} - ip := r.uc.BootstrapIP - if ip == "" { - dnsTyp := uint16(0) - if msg != nil && len(msg.Question) > 0 { - dnsTyp = msg.Question[0].Qtype - } - ip = r.uc.bootstrapIPForDNSType(ctx, dnsTyp) + // Get the appropriate connection pool based on DNS type and IP stack + dnsTyp := uint16(0) + if msg != nil && len(msg.Question) > 0 { + dnsTyp = msg.Question[0].Qtype } - tlsConfig.ServerName = r.uc.Domain - _, port, _ := net.SplitHostPort(endpoint) - endpoint = net.JoinHostPort(ip, port) - Log(ctx, logger.Debug(), "Sending DoQ request to: %s", endpoint) - answer, err := resolve(ctx, msg, endpoint, tlsConfig) + pool := r.uc.doqTransport(ctx, dnsTyp) + if pool == nil { + Log(ctx, logger.Error(), "DoQ connection pool is not available") + return nil, errors.New("DoQ connection pool is not available") + } + + answer, err := pool.Resolve(ctx, msg) if err != nil { Log(ctx, logger.Error().Err(err), "DoQ request failed") } else { @@ -45,11 +45,59 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro return answer, err } -func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) { - // DoQ quic-go server returns io.EOF error after running for a long time, - // even for a good stream. So retrying the query for 5 times before giving up. +// doqConnPool manages a pool of QUIC connections for DoQ queries. +type doqConnPool struct { + uc *UpstreamConfig + addrs []string + port string + tlsConfig *tls.Config + mu sync.RWMutex + conns map[string]*doqConn + closed bool +} + +type doqConn struct { + conn *quic.Conn + lastUsed time.Time + refCount int + mu sync.Mutex +} + +func newDOQConnPool(ctx context.Context, uc *UpstreamConfig, addrs []string) *doqConnPool { + _, port, _ := net.SplitHostPort(uc.Endpoint) + if port == "" { + port = "853" + } + + tlsConfig := &tls.Config{ + NextProtos: []string{"doq"}, + RootCAs: uc.certPool, + ServerName: uc.Domain, + } + + pool := &doqConnPool{ + uc: uc, + addrs: addrs, + port: port, + tlsConfig: tlsConfig, + conns: make(map[string]*doqConn), + } + + // Use SetFinalizer here because we need to call a method on the pool itself. + // AddCleanup would require passing the pool as arg (which panics) or capturing + // it in a closure (which prevents GC). SetFinalizer is appropriate for this case. + runtime.SetFinalizer(pool, func(p *doqConnPool) { + p.CloseIdleConnections() + }) + + return pool +} + +// Resolve performs a DNS query using a pooled QUIC connection. +func (p *doqConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + // Retry logic for io.EOF errors (as per original implementation) for i := 0; i < 5; i++ { - answer, err := doResolve(ctx, msg, endpoint, tlsConfig) + answer, err := p.doResolve(ctx, msg) if err == io.EOF { continue } @@ -58,57 +106,72 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls. } return answer, nil } - return nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(quic.InternalError), ErrorMessage: quic.InternalError.Message()} + return nil, &quic.ApplicationError{ + ErrorCode: quic.ApplicationErrorCode(quic.InternalError), + ErrorMessage: quic.InternalError.Message(), + } } -func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) { - session, err := quic.DialAddr(ctx, endpoint, tlsConfig, nil) +func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + conn, addr, err := p.getConn(ctx) if err != nil { return nil, err } - defer session.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + // Pack the DNS message msgBytes, err := msg.Pack() if err != nil { + p.putConn(addr, conn, false) return nil, err } - stream, err := session.OpenStream() + // Open a new stream for this query + stream, err := conn.OpenStream() if err != nil { + p.putConn(addr, conn, false) return nil, err } + // Set deadline deadline, ok := ctx.Deadline() if !ok { deadline = time.Now().Add(5 * time.Second) } _ = stream.SetDeadline(deadline) + // Write message length (2 bytes) followed by message var msgLen = uint16(len(msgBytes)) var msgLenBytes = []byte{byte(msgLen >> 8), byte(msgLen & 0xFF)} if _, err := stream.Write(msgLenBytes); err != nil { + stream.Close() + p.putConn(addr, conn, false) return nil, err } if _, err := stream.Write(msgBytes); err != nil { + stream.Close() + p.putConn(addr, conn, false) return nil, err } + // Read response buf, err := io.ReadAll(stream) + stream.Close() + + // Return connection to pool (mark as potentially bad if error occurred) + isGood := err == nil && len(buf) > 0 + p.putConn(addr, conn, isGood) + if err != nil { return nil, err } - _ = stream.Close() - - // io.ReadAll hide the io.EOF error returned by quic-go server. - // Once we figure out why quic-go server sends io.EOF after running - // for a long time, we can have a better way to handle this. For now, - // make sure io.EOF error returned, so the caller can handle it cleanly. + // io.ReadAll hides io.EOF error, so check for empty buffer if len(buf) == 0 { return nil, io.EOF } + // Unpack DNS response (skip 2-byte length prefix) answer := new(dns.Msg) if err := answer.Unpack(buf[2:]); err != nil { return nil, err @@ -116,3 +179,142 @@ func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tl answer.SetReply(msg) return answer, nil } + +// getConn gets a QUIC connection from the pool or creates a new one. +func (p *doqConnPool) getConn(ctx context.Context) (*quic.Conn, string, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return nil, "", io.EOF + } + + // Try to reuse an existing connection + for addr, doqConn := range p.conns { + doqConn.mu.Lock() + if doqConn.refCount == 0 && doqConn.conn != nil { + // Check if connection is still alive + select { + case <-doqConn.conn.Context().Done(): + // Connection is closed, remove it + doqConn.mu.Unlock() + delete(p.conns, addr) + continue + default: + } + + doqConn.refCount++ + doqConn.lastUsed = time.Now() + conn := doqConn.conn + doqConn.mu.Unlock() + return conn, addr, nil + } + doqConn.mu.Unlock() + } + + // No available connection, create a new one + addr, conn, err := p.dialConn(ctx) + if err != nil { + return nil, "", err + } + + doqConn := &doqConn{ + conn: conn, + lastUsed: time.Now(), + refCount: 1, + } + p.conns[addr] = doqConn + + return conn, addr, nil +} + +// putConn returns a connection to the pool. +func (p *doqConnPool) putConn(addr string, conn *quic.Conn, isGood bool) { + p.mu.Lock() + defer p.mu.Unlock() + + doqConn, ok := p.conns[addr] + if !ok { + return + } + + doqConn.mu.Lock() + defer doqConn.mu.Unlock() + + doqConn.refCount-- + if doqConn.refCount < 0 { + doqConn.refCount = 0 + } + + // If connection is bad or closed, remove it from pool + if !isGood || conn.Context().Err() != nil { + delete(p.conns, addr) + conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + return + } + + doqConn.lastUsed = time.Now() +} + +// dialConn creates a new QUIC connection using parallel dialing like DoH3. +func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) { + logger := LoggerFromCtx(ctx) + + // If we have a bootstrap IP, use it directly + if p.uc.BootstrapIP != "" { + addr := net.JoinHostPort(p.uc.BootstrapIP, p.port) + Log(ctx, logger.Debug(), "Sending DoQ request to: %s", addr) + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return "", nil, err + } + remoteAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + udpConn.Close() + return "", nil, err + } + conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, p.tlsConfig, nil) + if err != nil { + udpConn.Close() + return "", nil, err + } + return addr, conn, nil + } + + // Use parallel dialing like DoH3 + dialAddrs := make([]string, len(p.addrs)) + for i := range p.addrs { + dialAddrs[i] = net.JoinHostPort(p.addrs[i], p.port) + } + + pd := &quicParallelDialer{} + conn, err := pd.Dial(ctx, dialAddrs, p.tlsConfig, nil) + if err != nil { + return "", nil, err + } + + addr := conn.RemoteAddr().String() + Log(ctx, logger.Debug(), "Sending DoQ request to: %s", addr) + return addr, conn, nil +} + +// CloseIdleConnections closes all idle connections in the pool. +// When called during cleanup (e.g., from finalizer), it closes all connections +// regardless of refCount to prevent resource leaks. +func (p *doqConnPool) CloseIdleConnections() { + p.mu.Lock() + defer p.mu.Unlock() + + p.closed = true + + for addr, dc := range p.conns { + dc.mu.Lock() + if dc.conn != nil { + // Close all connections to ensure proper cleanup, even if in use + // This prevents resource leaks when the pool is being destroyed + dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") + } + dc.mu.Unlock() + delete(p.conns, addr) + } +}