mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-04-20 00:36:37 +02:00
ed98104384
Replace conn.OpenStream (non-blocking) with conn.OpenStreamSync so that the resolver waits for the server's MAX_STREAMS credit replenishment frame instead of immediately failing when the stream limit is temporarily exhausted. Also retry on StreamLimitReachedError as defense-in-depth for servers that are slow or fail to send MAX_STREAMS updates.
291 lines
7.2 KiB
Go
291 lines
7.2 KiB
Go
//go:build !qf
|
|
|
|
package ctrld
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"runtime"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/quic-go/quic-go"
|
|
)
|
|
|
|
type doqResolver struct {
|
|
uc *UpstreamConfig
|
|
}
|
|
|
|
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
|
if err := validateMsg(msg); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get the appropriate connection pool based on DNS type and IP stack
|
|
dnsTyp := uint16(0)
|
|
if msg != nil && len(msg.Question) > 0 {
|
|
dnsTyp = msg.Question[0].Qtype
|
|
}
|
|
|
|
pool := r.uc.doqTransport(dnsTyp)
|
|
if pool == nil {
|
|
return nil, errors.New("DoQ connection pool is not available")
|
|
}
|
|
|
|
return pool.Resolve(ctx, msg)
|
|
}
|
|
|
|
const doqPoolSize = 16
|
|
|
|
// doqConnPool manages a pool of QUIC connections for DoQ queries using a buffered channel.
|
|
type doqConnPool struct {
|
|
uc *UpstreamConfig
|
|
addrs []string
|
|
port string
|
|
tlsConfig *tls.Config
|
|
quicConfig *quic.Config
|
|
conns chan *doqConn
|
|
}
|
|
|
|
type doqConn struct {
|
|
conn *quic.Conn
|
|
}
|
|
|
|
func newDOQConnPool(uc *UpstreamConfig, addrs []string) *doqConnPool {
|
|
_, port, _ := net.SplitHostPort(uc.Endpoint)
|
|
if port == "" {
|
|
port = "853"
|
|
}
|
|
|
|
tlsConfig := &tls.Config{
|
|
NextProtos: []string{"doq"},
|
|
RootCAs: uc.certPool,
|
|
ServerName: uc.Domain,
|
|
}
|
|
|
|
quicConfig := &quic.Config{
|
|
KeepAlivePeriod: 15 * time.Second,
|
|
}
|
|
|
|
pool := &doqConnPool{
|
|
uc: uc,
|
|
addrs: addrs,
|
|
port: port,
|
|
tlsConfig: tlsConfig,
|
|
quicConfig: quicConfig,
|
|
conns: make(chan *doqConn, doqPoolSize),
|
|
}
|
|
|
|
// Use SetFinalizer here because we need to call a method on the pool itself.
|
|
// AddCleanup would require passing the pool as arg (which panics) or capturing
|
|
// it in a closure (which prevents GC). SetFinalizer is appropriate for this case.
|
|
runtime.SetFinalizer(pool, func(p *doqConnPool) {
|
|
p.CloseIdleConnections()
|
|
})
|
|
|
|
return pool
|
|
}
|
|
|
|
// Resolve performs a DNS query using a pooled QUIC connection.
|
|
func (p *doqConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
|
// Retry logic for transient errors: io.EOF (connection reset),
|
|
// IdleTimeoutError (stale pooled connection timed out), and
|
|
// StreamLimitReachedError (stream credit exhausted before server MAX_STREAMS arrived).
|
|
for range 5 {
|
|
answer, err := p.doResolve(ctx, msg)
|
|
if err == io.EOF {
|
|
continue
|
|
}
|
|
var idleErr *quic.IdleTimeoutError
|
|
if errors.As(err, &idleErr) {
|
|
continue
|
|
}
|
|
var streamLimitErr quic.StreamLimitReachedError
|
|
if errors.As(err, &streamLimitErr) {
|
|
continue
|
|
}
|
|
if err != nil {
|
|
return nil, wrapCertificateVerificationError(err)
|
|
}
|
|
return answer, nil
|
|
}
|
|
return nil, &quic.ApplicationError{
|
|
ErrorCode: quic.ApplicationErrorCode(quic.InternalError),
|
|
ErrorMessage: quic.InternalError.Message(),
|
|
}
|
|
}
|
|
|
|
func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
|
conn, err := p.getConn(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Pack the DNS message
|
|
msgBytes, err := msg.Pack()
|
|
if err != nil {
|
|
p.putConn(conn, false)
|
|
return nil, err
|
|
}
|
|
|
|
// Ensure the context has a deadline before calling OpenStreamSync, which
|
|
// blocks until the server sends a MAX_STREAMS update. Without a deadline the
|
|
// call could block indefinitely when the server never sends the update.
|
|
deadline, ok := ctx.Deadline()
|
|
if !ok {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, 5*time.Second)
|
|
defer cancel()
|
|
deadline, _ = ctx.Deadline()
|
|
}
|
|
|
|
// OpenStreamSync blocks until the server's MAX_STREAMS credit arrives,
|
|
// avoiding the StreamLimitReachedError race that OpenStream (non-blocking)
|
|
// triggers when the credit replenishment frame is still in flight.
|
|
stream, err := conn.OpenStreamSync(ctx)
|
|
if err != nil {
|
|
p.putConn(conn, false)
|
|
return nil, err
|
|
}
|
|
_ = stream.SetDeadline(deadline)
|
|
|
|
// Write message length (2 bytes) followed by message
|
|
var msgLen = uint16(len(msgBytes))
|
|
var msgLenBytes = []byte{byte(msgLen >> 8), byte(msgLen & 0xFF)}
|
|
if _, err := stream.Write(msgLenBytes); err != nil {
|
|
stream.Close()
|
|
p.putConn(conn, false)
|
|
return nil, err
|
|
}
|
|
|
|
if _, err := stream.Write(msgBytes); err != nil {
|
|
stream.Close()
|
|
p.putConn(conn, false)
|
|
return nil, err
|
|
}
|
|
|
|
// Read response
|
|
buf, err := io.ReadAll(stream)
|
|
stream.Close()
|
|
|
|
// Return connection to pool (mark as potentially bad if error occurred)
|
|
isGood := err == nil && len(buf) > 0
|
|
p.putConn(conn, isGood)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// io.ReadAll hides io.EOF error, so check for empty buffer
|
|
if len(buf) == 0 {
|
|
return nil, io.EOF
|
|
}
|
|
|
|
// Unpack DNS response (skip 2-byte length prefix)
|
|
answer := new(dns.Msg)
|
|
if err := answer.Unpack(buf[2:]); err != nil {
|
|
return nil, err
|
|
}
|
|
answer.SetReply(msg)
|
|
return answer, nil
|
|
}
|
|
|
|
// getConn gets a QUIC connection from the pool or creates a new one.
|
|
// A connection is taken from the channel while in use; putConn returns it.
|
|
func (p *doqConnPool) getConn(ctx context.Context) (*quic.Conn, error) {
|
|
for {
|
|
select {
|
|
case dc := <-p.conns:
|
|
if dc.conn != nil && dc.conn.Context().Err() == nil {
|
|
return dc.conn, nil
|
|
}
|
|
if dc.conn != nil {
|
|
dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
|
|
}
|
|
default:
|
|
_, conn, err := p.dialConn(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return conn, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// putConn returns a connection to the pool for reuse by other goroutines.
|
|
func (p *doqConnPool) putConn(conn *quic.Conn, isGood bool) {
|
|
if !isGood || conn == nil || conn.Context().Err() != nil {
|
|
if conn != nil {
|
|
conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
|
|
}
|
|
return
|
|
}
|
|
dc := &doqConn{conn: conn}
|
|
select {
|
|
case p.conns <- dc:
|
|
default:
|
|
// Channel full, close the connection
|
|
dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
|
|
}
|
|
}
|
|
|
|
// dialConn creates a new QUIC connection using parallel dialing like DoH3.
|
|
func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) {
|
|
logger := ProxyLogger.Load()
|
|
|
|
// If we have a bootstrap IP, use it directly
|
|
if p.uc.BootstrapIP != "" {
|
|
addr := net.JoinHostPort(p.uc.BootstrapIP, p.port)
|
|
Log(ctx, logger.Debug(), "Sending DoQ request to: %s", addr)
|
|
udpConn, err := net.ListenUDP("udp", nil)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
udpConn.Close()
|
|
return "", nil, err
|
|
}
|
|
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, p.tlsConfig, p.quicConfig)
|
|
if err != nil {
|
|
udpConn.Close()
|
|
return "", nil, err
|
|
}
|
|
return addr, conn, nil
|
|
}
|
|
|
|
// Use parallel dialing like DoH3
|
|
dialAddrs := make([]string, len(p.addrs))
|
|
for i := range p.addrs {
|
|
dialAddrs[i] = net.JoinHostPort(p.addrs[i], p.port)
|
|
}
|
|
|
|
pd := &quicParallelDialer{}
|
|
conn, err := pd.Dial(ctx, dialAddrs, p.tlsConfig, p.quicConfig)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
addr := conn.RemoteAddr().String()
|
|
Log(ctx, logger.Debug(), "Sending DoQ request to: %s", addr)
|
|
return addr, conn, nil
|
|
}
|
|
|
|
// CloseIdleConnections closes all connections in the pool.
|
|
// Connections currently checked out (in use) are not closed.
|
|
func (p *doqConnPool) CloseIdleConnections() {
|
|
for {
|
|
select {
|
|
case dc := <-p.conns:
|
|
if dc.conn != nil {
|
|
dc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
|
|
}
|
|
default:
|
|
return
|
|
}
|
|
}
|
|
}
|