diff --git a/config.go b/config.go index ad55dad..bdfa389 100644 --- a/config.go +++ b/config.go @@ -282,6 +282,9 @@ type UpstreamConfig struct { doqConnPool *doqConnPool doqConnPool4 *doqConnPool doqConnPool6 *doqConnPool + dotClientPool *dotConnPool + dotClientPool4 *dotConnPool + dotClientPool6 *dotConnPool certPool *x509.CertPool u *url.URL fallbackOnce sync.Once @@ -496,7 +499,7 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { // ReBootstrap re-setup the bootstrap IP and the transport. func (uc *UpstreamConfig) ReBootstrap() { switch uc.Type { - case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ: + case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ, ResolverTypeDOT: default: return } @@ -508,11 +511,11 @@ func (uc *UpstreamConfig) ReBootstrap() { }) } -// SetupTransport initializes the network transport used to connect to upstream server. -// For now, only DoH upstream is supported. +// SetupTransport initializes the network transport used to connect to upstream servers. +// For now, DoH/DoH3/DoQ/DoT upstreams are supported. func (uc *UpstreamConfig) SetupTransport() { switch uc.Type { - case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ: + case ResolverTypeDOH, ResolverTypeDOH3, ResolverTypeDOQ, ResolverTypeDOT: default: return } @@ -523,21 +526,26 @@ func (uc *UpstreamConfig) SetupTransport() { case IpStackV6: ips = uc.bootstrapIPs6 } + uc.transport = uc.newDOHTransport(ips) uc.http3RoundTripper = uc.newDOH3Transport(ips) uc.doqConnPool = uc.newDOQConnPool(ips) + uc.dotClientPool = uc.newDOTClientPool(ips) if uc.IPStack == IpStackSplit { uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4) uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4) uc.doqConnPool4 = uc.newDOQConnPool(uc.bootstrapIPs4) + uc.dotClientPool4 = uc.newDOTClientPool(uc.bootstrapIPs4) if HasIPv6() { uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6) uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6) uc.doqConnPool6 = uc.newDOQConnPool(uc.bootstrapIPs6) + uc.dotClientPool6 = uc.newDOTClientPool(uc.bootstrapIPs6) } else { uc.transport6 = uc.transport4 uc.http3RoundTripper6 = uc.http3RoundTripper4 uc.doqConnPool6 = uc.doqConnPool4 + uc.dotClientPool6 = uc.dotClientPool4 } } } @@ -656,6 +664,10 @@ func (uc *UpstreamConfig) ping() error { // For DoQ, we just ensure transport is set up by calling doqTransport // DoQ doesn't use HTTP, so we can't ping it the same way _ = uc.doqTransport(typ) + case ResolverTypeDOT: + // For DoT, we just ensure transport is set up by calling dotTransport + // DoT doesn't use HTTP, so we can't ping it the same way + _ = uc.dotTransport(typ) } } diff --git a/config_quic.go b/config_quic.go index f6192d5..237bb82 100644 --- a/config_quic.go +++ b/config_quic.go @@ -13,25 +13,6 @@ import ( "github.com/quic-go/quic-go/http3" ) -func (uc *UpstreamConfig) setupDOH3Transport() { - switch uc.IPStack { - case IpStackBoth, "": - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) - case IpStackV4: - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4) - case IpStackV6: - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6) - case IpStackSplit: - uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4) - if HasIPv6() { - uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6) - } else { - uc.http3RoundTripper6 = uc.http3RoundTripper4 - } - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) - } -} - func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { if uc.Type != ResolverTypeDOH3 { return nil @@ -82,6 +63,11 @@ func (uc *UpstreamConfig) doqTransport(dnsType uint16) *doqConnPool { return transportByIpStack(uc.IPStack, dnsType, uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6) } +func (uc *UpstreamConfig) dotTransport(dnsType uint16) *dotConnPool { + uc.ensureSetupTransport() + return transportByIpStack(uc.IPStack, dnsType, uc.dotClientPool, uc.dotClientPool4, uc.dotClientPool6) +} + // Putting the code for quic parallel dialer here: // // - quic dialer is different with net.Dialer @@ -156,3 +142,10 @@ func (uc *UpstreamConfig) newDOQConnPool(addrs []string) *doqConnPool { } return newDOQConnPool(uc, addrs) } + +func (uc *UpstreamConfig) newDOTClientPool(addrs []string) *dotConnPool { + if uc.Type != ResolverTypeDOT { + return nil + } + return newDOTClientPool(uc, addrs) +} diff --git a/doh.go b/doh.go index 58aaf16..6b41c11 100644 --- a/doh.go +++ b/doh.go @@ -85,6 +85,10 @@ type dohResolver struct { // Resolve performs DNS query with given DNS message using DOH protocol. func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } + data, err := msg.Pack() if err != nil { return nil, err diff --git a/doq.go b/doq.go index 2b74f83..8d8a4e8 100644 --- a/doq.go +++ b/doq.go @@ -21,6 +21,10 @@ type doqResolver struct { } func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } + // Get the appropriate connection pool based on DNS type and IP stack dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { diff --git a/dot.go b/dot.go index 295134c..fe65089 100644 --- a/dot.go +++ b/dot.go @@ -3,7 +3,12 @@ package ctrld import ( "context" "crypto/tls" + "errors" + "io" "net" + "runtime" + "sync" + "time" "github.com/miekg/dns" ) @@ -13,30 +18,292 @@ type dotResolver struct { } func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - // The dialer is used to prevent bootstrapping cycle. - // If r.endpoint is set to dns.controld.dev, we need to resolve - // dns.controld.dev first. By using a dialer with custom resolver, - // we ensure that we can always resolve the bootstrap domain - // regardless of the machine DNS status. - dialer := newDialer(net.JoinHostPort(controldPublicDns, "53")) + if err := validateMsg(msg); err != nil { + return nil, err + } + dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - tcpNet, _ := r.uc.netForDNSType(dnsTyp) - dnsClient := &dns.Client{ - Net: tcpNet, - Dialer: dialer, - TLSConfig: &tls.Config{RootCAs: r.uc.certPool}, - } - endpoint := r.uc.Endpoint - if r.uc.BootstrapIP != "" { - dnsClient.TLSConfig.ServerName = r.uc.Domain - dnsClient.Net = "tcp-tls" - _, port, _ := net.SplitHostPort(endpoint) - endpoint = net.JoinHostPort(r.uc.BootstrapIP, port) + + pool := r.uc.dotTransport(dnsTyp) + if pool == nil { + return nil, errors.New("DoT client pool is not available") } - answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint) - return answer, wrapCertificateVerificationError(err) + return pool.Resolve(ctx, msg) +} + +// dotConnPool manages a pool of TCP/TLS connections for DoT queries. +type dotConnPool struct { + uc *UpstreamConfig + addrs []string + port string + tlsConfig *tls.Config + dialer *net.Dialer + mu sync.RWMutex + conns map[string]*dotConn + closed bool +} + +type dotConn struct { + conn net.Conn + lastUsed time.Time + refCount int + mu sync.Mutex +} + +func newDOTClientPool(uc *UpstreamConfig, addrs []string) *dotConnPool { + _, port, _ := net.SplitHostPort(uc.Endpoint) + if port == "" { + port = "853" + } + + // The dialer is used to prevent bootstrapping cycle. + // If endpoint is set to dns.controld.dev, we need to resolve + // dns.controld.dev first. By using a dialer with custom resolver, + // we ensure that we can always resolve the bootstrap domain + // regardless of the machine DNS status. + dialer := newDialer(net.JoinHostPort(controldPublicDns, "53")) + + tlsConfig := &tls.Config{ + RootCAs: uc.certPool, + } + + if uc.BootstrapIP != "" { + tlsConfig.ServerName = uc.Domain + } + + pool := &dotConnPool{ + uc: uc, + addrs: addrs, + port: port, + tlsConfig: tlsConfig, + dialer: dialer, + conns: make(map[string]*dotConn), + } + + // Use SetFinalizer here because we need to call a method on the pool itself. + // AddCleanup would require passing the pool as arg (which panics) or capturing + // it in a closure (which prevents GC). SetFinalizer is appropriate for this case. + runtime.SetFinalizer(pool, func(p *dotConnPool) { + p.CloseIdleConnections() + }) + + return pool +} + +// Resolve performs a DNS query using a pooled TCP/TLS connection. +func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if msg == nil { + return nil, errors.New("nil DNS message") + } + + conn, addr, err := p.getConn(ctx) + if err != nil { + return nil, wrapCertificateVerificationError(err) + } + + // Set deadline + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(5 * time.Second) + } + _ = conn.SetDeadline(deadline) + + client := dns.Client{Net: "tcp-tls"} + answer, _, err := client.ExchangeWithConnContext(ctx, msg, &dns.Conn{Conn: conn}) + isGood := err == nil + p.putConn(addr, conn, isGood) + + if err != nil { + return nil, wrapCertificateVerificationError(err) + } + + return answer, nil +} + +// getConn gets a TCP/TLS connection from the pool or creates a new one. +func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, string, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return nil, "", io.EOF + } + + // Try to reuse an existing connection + for addr, dotConn := range p.conns { + dotConn.mu.Lock() + if dotConn.refCount == 0 && dotConn.conn != nil { + dotConn.refCount++ + dotConn.lastUsed = time.Now() + conn := dotConn.conn + dotConn.mu.Unlock() + return conn, addr, nil + } + dotConn.mu.Unlock() + } + + // No available connection, create a new one + addr, conn, err := p.dialConn(ctx) + if err != nil { + return nil, "", err + } + + dotConn := &dotConn{ + conn: conn, + lastUsed: time.Now(), + refCount: 1, + } + p.conns[addr] = dotConn + + return conn, addr, nil +} + +// putConn returns a connection to the pool. +func (p *dotConnPool) putConn(addr string, conn net.Conn, isGood bool) { + p.mu.Lock() + defer p.mu.Unlock() + + dotConn, ok := p.conns[addr] + if !ok { + return + } + + dotConn.mu.Lock() + defer dotConn.mu.Unlock() + + dotConn.refCount-- + if dotConn.refCount < 0 { + dotConn.refCount = 0 + } + + // If connection is bad, remove it from pool + if !isGood { + delete(p.conns, addr) + if conn != nil { + conn.Close() + } + return + } + + dotConn.lastUsed = time.Now() +} + +// dialConn creates a new TCP/TLS connection. +func (p *dotConnPool) dialConn(ctx context.Context) (string, net.Conn, error) { + logger := ProxyLogger.Load() + var endpoint string + + if p.uc.BootstrapIP != "" { + endpoint = net.JoinHostPort(p.uc.BootstrapIP, p.port) + Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint) + conn, err := p.dialer.DialContext(ctx, "tcp", endpoint) + if err != nil { + return "", nil, err + } + tlsConn := tls.Client(conn, p.tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() + return "", nil, err + } + return endpoint, tlsConn, nil + } + + // Try bootstrap IPs in parallel + if len(p.addrs) > 0 { + type result struct { + conn net.Conn + addr string + err error + } + + ch := make(chan result, len(p.addrs)) + done := make(chan struct{}) + defer close(done) + + for _, addr := range p.addrs { + go func(addr string) { + endpoint := net.JoinHostPort(addr, p.port) + conn, err := p.dialer.DialContext(ctx, "tcp", endpoint) + if err != nil { + select { + case ch <- result{conn: nil, addr: endpoint, err: err}: + case <-done: + } + return + } + tlsConfig := p.tlsConfig.Clone() + tlsConfig.ServerName = p.uc.Domain + tlsConn := tls.Client(conn, tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() + select { + case ch <- result{conn: nil, addr: endpoint, err: err}: + case <-done: + } + return + } + select { + case ch <- result{conn: tlsConn, addr: endpoint, err: nil}: + case <-done: + if conn != nil { + conn.Close() + } + } + }(addr) + } + + errs := make([]error, 0, len(p.addrs)) + for range len(p.addrs) { + select { + case res := <-ch: + if res.err == nil && res.conn != nil { + Log(ctx, logger.Debug(), "Sending DoT request to: %s", res.addr) + return res.addr, res.conn, nil + } + if res.err != nil { + errs = append(errs, res.err) + } + case <-ctx.Done(): + return "", nil, ctx.Err() + } + } + + return "", nil, errors.Join(errs...) + } + + // Fallback to endpoint resolution + endpoint = p.uc.Endpoint + Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint) + conn, err := p.dialer.DialContext(ctx, "tcp", endpoint) + if err != nil { + return "", nil, err + } + tlsConn := tls.Client(conn, p.tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() + return "", nil, err + } + return endpoint, tlsConn, nil +} + +// CloseIdleConnections closes all connections in the pool. +func (p *dotConnPool) CloseIdleConnections() { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return + } + p.closed = true + for addr, dotConn := range p.conns { + dotConn.mu.Lock() + if dotConn.conn != nil { + dotConn.conn.Close() + } + dotConn.mu.Unlock() + delete(p.conns, addr) + } } diff --git a/resolver.go b/resolver.go index 3aeddd0..914233d 100644 --- a/resolver.go +++ b/resolver.go @@ -291,6 +291,9 @@ const hotCacheTTL = time.Second // for a short period (currently 1 second), reducing unnecessary traffics // sent to upstreams. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } if len(msg.Question) == 0 { return nil, errors.New("no question found") } @@ -509,6 +512,10 @@ type legacyResolver struct { } func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } + // See comment in (*dotResolver).resolve method. dialer := newDialer(net.JoinHostPort(controldPublicDns, "53")) dnsTyp := uint16(0) @@ -534,6 +541,9 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e type dummyResolver struct{} func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if err := validateMsg(msg); err != nil { + return nil, err + } ans := new(dns.Msg) ans.SetReply(msg) return ans, nil @@ -769,3 +779,13 @@ func isLanAddr(addr netip.Addr) bool { addr.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(addr) } + +func validateMsg(msg *dns.Msg) error { + if msg == nil { + return errors.New("nil DNS message") + } + if len(msg.Question) == 0 { + return errors.New("no question found") + } + return nil +}