diff --git a/config.go b/config.go index c7ad161..5c95fce 100644 --- a/config.go +++ b/config.go @@ -288,6 +288,9 @@ type UpstreamConfig struct { doqConnPool *doqConnPool doqConnPool4 *doqConnPool doqConnPool6 *doqConnPool + dotClientPool *dotConnPool + dotClientPool4 *dotConnPool + dotClientPool6 *dotConnPool certPool *x509.CertPool u *url.URL fallbackOnce sync.Once @@ -510,7 +513,7 @@ func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) { // ReBootstrap re-setup the bootstrap IP and the transport. func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { switch uc.Type { - case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ: + case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ, ResolverTypeDOT: default: return } @@ -524,10 +527,10 @@ func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { } // SetupTransport initializes the network transport used to connect to upstream servers. -// For now, DoH/DoH3/DoQ upstreams are supported. +// For now, DoH/DoH3/DoQ/DoT upstreams are supported. func (uc *UpstreamConfig) SetupTransport(ctx context.Context) { switch uc.Type { - case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ: + case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ, ResolverTypeDOT: default: return } @@ -541,18 +544,22 @@ func (uc *UpstreamConfig) SetupTransport(ctx context.Context) { uc.transport = uc.newDOHTransport(ctx, ips) uc.http3RoundTripper = uc.newDOH3Transport(ctx, ips) uc.doqConnPool = uc.newDOQConnPool(ctx, ips) + uc.dotClientPool = uc.newDOTClientPool(ctx, ips) if uc.IPStack == IpStackSplit { uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4) uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4) uc.doqConnPool4 = uc.newDOQConnPool(ctx, uc.bootstrapIPs4) + uc.dotClientPool4 = uc.newDOTClientPool(ctx, uc.bootstrapIPs4) if HasIPv6(ctx) { uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6) uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6) uc.doqConnPool6 = uc.newDOQConnPool(ctx, uc.bootstrapIPs6) + uc.dotClientPool6 = uc.newDOTClientPool(ctx, uc.bootstrapIPs6) } else { uc.transport6 = uc.transport4 uc.http3RoundTripper6 = uc.http3RoundTripper4 uc.doqConnPool6 = uc.doqConnPool4 + uc.dotClientPool6 = uc.dotClientPool4 } } } @@ -674,6 +681,10 @@ func (uc *UpstreamConfig) ping(ctx context.Context) error { // For DoQ, we just ensure transport is set up by calling doqTransport // DoQ doesn't use HTTP, so we can't ping it the same way _ = uc.doqTransport(ctx, typ) + case ResolverTypeDOT: + // For DoT, we just ensure transport is set up by calling dotTransport + // DoT doesn't use HTTP, so we can't ping it the same way + _ = uc.dotTransport(ctx, typ) } } diff --git a/config_quic.go b/config_quic.go index df9f22b..f2469a3 100644 --- a/config_quic.go +++ b/config_quic.go @@ -64,6 +64,11 @@ func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doq return transportByIpStack(uc.IPStack, dnsType, uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6) } +func (uc *UpstreamConfig) dotTransport(ctx context.Context, dnsType uint16) *dotConnPool { + uc.ensureSetupTransport(ctx) + return transportByIpStack(uc.IPStack, dnsType, uc.dotClientPool, uc.dotClientPool4, uc.dotClientPool6) +} + // Putting the code for quic parallel dialer here: // // - quic dialer is different with net.Dialer @@ -138,3 +143,10 @@ func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *d } return newDOQConnPool(ctx, uc, addrs) } + +func (uc *UpstreamConfig) newDOTClientPool(ctx context.Context, addrs []string) *dotConnPool { + if uc.Type != ResolverTypeDOT { + return nil + } + return newDOTClientPool(ctx, uc, addrs) +} diff --git a/doh.go b/doh.go index 9e944dd..f5ec7e1 100644 --- a/doh.go +++ b/doh.go @@ -88,6 +88,9 @@ type dohResolver struct { // Resolve performs DNS query with given DNS message using DOH protocol. func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } logger := LoggerFromCtx(ctx) Log(ctx, logger.Debug(), "DoH resolver query started") diff --git a/doq.go b/doq.go index 6556eb3..c9202a3 100644 --- a/doq.go +++ b/doq.go @@ -21,6 +21,9 @@ type doqResolver struct { } func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } logger := LoggerFromCtx(ctx) Log(ctx, logger.Debug(), "DoQ resolver query started") diff --git a/dot.go b/dot.go index 96fa651..74f5ece 100644 --- a/dot.go +++ b/dot.go @@ -3,7 +3,12 @@ package ctrld import ( "context" "crypto/tls" + "errors" + "io" "net" + "runtime" + "sync" + "time" "github.com/miekg/dns" ) @@ -13,39 +18,301 @@ type dotResolver struct { } func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } logger := LoggerFromCtx(ctx) Log(ctx, logger.Debug(), "DoT resolver query started") - // The dialer is used to prevent bootstrapping cycle. - // If r.endpoint is set to dns.controld.dev, we need to resolve - // dns.controld.dev first. By using a dialer with custom resolver, - // we ensure that we can always resolve the bootstrap domain - // regardless of the machine DNS status. - dialer := newDialer(net.JoinHostPort(controldPublicDns, "53")) dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - tcpNet, _ := r.uc.netForDNSType(ctx, dnsTyp) - dnsClient := &dns.Client{ - Net: tcpNet, - Dialer: dialer, - TLSConfig: &tls.Config{RootCAs: r.uc.certPool}, - } - endpoint := r.uc.Endpoint - if r.uc.BootstrapIP != "" { - dnsClient.TLSConfig.ServerName = r.uc.Domain - dnsClient.Net = "tcp-tls" - _, port, _ := net.SplitHostPort(endpoint) - endpoint = net.JoinHostPort(r.uc.BootstrapIP, port) + + pool := r.uc.dotTransport(ctx, dnsTyp) + if pool == nil { + Log(ctx, logger.Error(), "DoT client pool is not available") + return nil, errors.New("DoT client pool is not available") } - Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint) - answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint) + answer, err := pool.Resolve(ctx, msg) if err != nil { Log(ctx, logger.Error().Err(err), "DoT request failed") } else { Log(ctx, logger.Debug(), "DoT resolver query successful") } - return answer, wrapCertificateVerificationError(err) + return answer, err +} + +// dotConnPool manages a pool of TCP/TLS connections for DoT queries. +type dotConnPool struct { + uc *UpstreamConfig + addrs []string + port string + tlsConfig *tls.Config + dialer *net.Dialer + mu sync.RWMutex + conns map[string]*dotConn + closed bool +} + +type dotConn struct { + conn net.Conn + lastUsed time.Time + refCount int + mu sync.Mutex +} + +func newDOTClientPool(_ context.Context, uc *UpstreamConfig, addrs []string) *dotConnPool { + _, port, _ := net.SplitHostPort(uc.Endpoint) + if port == "" { + port = "853" + } + + // The dialer is used to prevent bootstrapping cycle. + // If endpoint is set to dns.controld.dev, we need to resolve + // dns.controld.dev first. By using a dialer with custom resolver, + // we ensure that we can always resolve the bootstrap domain + // regardless of the machine DNS status. + dialer := newDialer(net.JoinHostPort(controldPublicDns, "53")) + + tlsConfig := &tls.Config{ + RootCAs: uc.certPool, + } + + if uc.BootstrapIP != "" { + tlsConfig.ServerName = uc.Domain + } + + pool := &dotConnPool{ + uc: uc, + addrs: addrs, + port: port, + tlsConfig: tlsConfig, + dialer: dialer, + conns: make(map[string]*dotConn), + } + + // Use SetFinalizer here because we need to call a method on the pool itself. + // AddCleanup would require passing the pool as arg (which panics) or capturing + // it in a closure (which prevents GC). SetFinalizer is appropriate for this case. + runtime.SetFinalizer(pool, func(p *dotConnPool) { + p.CloseIdleConnections() + }) + + return pool +} + +// Resolve performs a DNS query using a pooled TCP/TLS connection. +func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if msg == nil { + return nil, errors.New("nil DNS message") + } + + conn, addr, err := p.getConn(ctx) + if err != nil { + return nil, wrapCertificateVerificationError(err) + } + + // Set deadline + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(5 * time.Second) + } + _ = conn.SetDeadline(deadline) + + client := dns.Client{Net: "tcp-tls"} + answer, _, err := client.ExchangeWithConnContext(ctx, msg, &dns.Conn{Conn: conn}) + isGood := err == nil + p.putConn(addr, conn, isGood) + + if err != nil { + return nil, wrapCertificateVerificationError(err) + } + + return answer, nil +} + +// getConn gets a TCP/TLS connection from the pool or creates a new one. +func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, string, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return nil, "", io.EOF + } + + // Try to reuse an existing connection + for addr, dotConn := range p.conns { + dotConn.mu.Lock() + if dotConn.refCount == 0 && dotConn.conn != nil { + dotConn.refCount++ + dotConn.lastUsed = time.Now() + conn := dotConn.conn + dotConn.mu.Unlock() + return conn, addr, nil + } + dotConn.mu.Unlock() + } + + // No available connection, create a new one + addr, conn, err := p.dialConn(ctx) + if err != nil { + return nil, "", err + } + + dotConn := &dotConn{ + conn: conn, + lastUsed: time.Now(), + refCount: 1, + } + p.conns[addr] = dotConn + + return conn, addr, nil +} + +// putConn returns a connection to the pool. +func (p *dotConnPool) putConn(addr string, conn net.Conn, isGood bool) { + p.mu.Lock() + defer p.mu.Unlock() + + dotConn, ok := p.conns[addr] + if !ok { + return + } + + dotConn.mu.Lock() + defer dotConn.mu.Unlock() + + dotConn.refCount-- + if dotConn.refCount < 0 { + dotConn.refCount = 0 + } + + // If connection is bad, remove it from pool + if !isGood { + delete(p.conns, addr) + if conn != nil { + conn.Close() + } + return + } + + dotConn.lastUsed = time.Now() +} + +// dialConn creates a new TCP/TLS connection. +func (p *dotConnPool) dialConn(ctx context.Context) (string, net.Conn, error) { + logger := LoggerFromCtx(ctx) + var endpoint string + + if p.uc.BootstrapIP != "" { + endpoint = net.JoinHostPort(p.uc.BootstrapIP, p.port) + Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint) + conn, err := p.dialer.DialContext(ctx, "tcp", endpoint) + if err != nil { + return "", nil, err + } + tlsConn := tls.Client(conn, p.tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() + return "", nil, err + } + return endpoint, tlsConn, nil + } + + // Try bootstrap IPs in parallel + if len(p.addrs) > 0 { + type result struct { + conn net.Conn + addr string + err error + } + + ch := make(chan result, len(p.addrs)) + done := make(chan struct{}) + defer close(done) + + for _, addr := range p.addrs { + go func(addr string) { + endpoint := net.JoinHostPort(addr, p.port) + conn, err := p.dialer.DialContext(ctx, "tcp", endpoint) + if err != nil { + select { + case ch <- result{conn: nil, addr: endpoint, err: err}: + case <-done: + } + return + } + tlsConfig := p.tlsConfig.Clone() + tlsConfig.ServerName = p.uc.Domain + tlsConn := tls.Client(conn, tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() + select { + case ch <- result{conn: nil, addr: endpoint, err: err}: + case <-done: + } + return + } + select { + case ch <- result{conn: tlsConn, addr: endpoint, err: nil}: + case <-done: + if conn != nil { + conn.Close() + } + } + }(addr) + } + + errs := make([]error, 0, len(p.addrs)) + for range len(p.addrs) { + select { + case res := <-ch: + if res.err == nil && res.conn != nil { + Log(ctx, logger.Debug(), "Sending DoT request to: %s", res.addr) + return res.addr, res.conn, nil + } + if res.err != nil { + errs = append(errs, res.err) + } + case <-ctx.Done(): + return "", nil, ctx.Err() + } + } + + return "", nil, errors.Join(errs...) + } + + // Fallback to endpoint resolution + endpoint = p.uc.Endpoint + Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint) + conn, err := p.dialer.DialContext(ctx, "tcp", endpoint) + if err != nil { + return "", nil, err + } + tlsConn := tls.Client(conn, p.tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() + return "", nil, err + } + return endpoint, tlsConn, nil +} + +// CloseIdleConnections closes all connections in the pool. +func (p *dotConnPool) CloseIdleConnections() { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return + } + p.closed = true + for addr, dotConn := range p.conns { + dotConn.mu.Lock() + if dotConn.conn != nil { + dotConn.conn.Close() + } + dotConn.mu.Unlock() + delete(p.conns, addr) + } } diff --git a/resolver.go b/resolver.go index 878663d..19ca67b 100644 --- a/resolver.go +++ b/resolver.go @@ -267,6 +267,9 @@ const hotCacheTTL = time.Second // for a short period (currently 1 second), reducing unnecessary traffics // sent to upstreams. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } if len(msg.Question) == 0 { return nil, errors.New("no question found") } @@ -492,6 +495,9 @@ type legacyResolver struct { } func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } logger := LoggerFromCtx(ctx) Log(ctx, logger.Debug(), "Legacy resolver query started") @@ -526,6 +532,9 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e type dummyResolver struct{} func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } ans := new(dns.Msg) ans.SetReply(msg) return ans, nil @@ -749,3 +758,13 @@ func isLanAddr(addr netip.Addr) bool { addr.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(addr) } + +func validateMsg(msg *dns.Msg) error { + if msg == nil { + return errors.New("nil DNS message") + } + if len(msg.Question) == 0 { + return errors.New("no question found") + } + return nil +}