package ctrld import ( "context" "crypto/tls" "errors" "io" "net" "runtime" "sync" "time" "github.com/miekg/dns" ) type dotResolver struct { uc *UpstreamConfig } func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { if err := validateMsg(msg); err != nil { return nil, err } logger := LoggerFromCtx(ctx) Log(ctx, logger.Debug(), "DoT resolver query started") dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } pool := r.uc.dotTransport(ctx, dnsTyp) if pool == nil { Log(ctx, logger.Error(), "DoT client pool is not available") return nil, errors.New("DoT client pool is not available") } answer, err := pool.Resolve(ctx, msg) if err != nil { Log(ctx, logger.Error().Err(err), "DoT request failed") } else { Log(ctx, logger.Debug(), "DoT resolver query successful") } return answer, err } // dotConnPool manages a pool of TCP/TLS connections for DoT queries. type dotConnPool struct { uc *UpstreamConfig addrs []string port string tlsConfig *tls.Config dialer *net.Dialer mu sync.RWMutex conns map[string]*dotConn closed bool } type dotConn struct { conn net.Conn lastUsed time.Time refCount int mu sync.Mutex } func newDOTClientPool(_ context.Context, uc *UpstreamConfig, addrs []string) *dotConnPool { _, port, _ := net.SplitHostPort(uc.Endpoint) if port == "" { port = "853" } // The dialer is used to prevent bootstrapping cycle. // If endpoint is set to dns.controld.dev, we need to resolve // dns.controld.dev first. By using a dialer with custom resolver, // we ensure that we can always resolve the bootstrap domain // regardless of the machine DNS status. dialer := newDialer(net.JoinHostPort(controldPublicDns, "53")) tlsConfig := &tls.Config{ RootCAs: uc.certPool, } if uc.BootstrapIP != "" { tlsConfig.ServerName = uc.Domain } pool := &dotConnPool{ uc: uc, addrs: addrs, port: port, tlsConfig: tlsConfig, dialer: dialer, conns: make(map[string]*dotConn), } // Use SetFinalizer here because we need to call a method on the pool itself. // AddCleanup would require passing the pool as arg (which panics) or capturing // it in a closure (which prevents GC). SetFinalizer is appropriate for this case. runtime.SetFinalizer(pool, func(p *dotConnPool) { p.CloseIdleConnections() }) return pool } // Resolve performs a DNS query using a pooled TCP/TLS connection. func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { if msg == nil { return nil, errors.New("nil DNS message") } conn, addr, err := p.getConn(ctx) if err != nil { return nil, wrapCertificateVerificationError(err) } // Set deadline deadline, ok := ctx.Deadline() if !ok { deadline = time.Now().Add(5 * time.Second) } _ = conn.SetDeadline(deadline) client := dns.Client{Net: "tcp-tls"} answer, _, err := client.ExchangeWithConnContext(ctx, msg, &dns.Conn{Conn: conn}) isGood := err == nil p.putConn(addr, conn, isGood) if err != nil { return nil, wrapCertificateVerificationError(err) } return answer, nil } // getConn gets a TCP/TLS connection from the pool or creates a new one. func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, string, error) { p.mu.Lock() defer p.mu.Unlock() if p.closed { return nil, "", io.EOF } // Try to reuse an existing connection for addr, dotConn := range p.conns { dotConn.mu.Lock() if dotConn.refCount == 0 && dotConn.conn != nil { dotConn.refCount++ dotConn.lastUsed = time.Now() conn := dotConn.conn dotConn.mu.Unlock() return conn, addr, nil } dotConn.mu.Unlock() } // No available connection, create a new one addr, conn, err := p.dialConn(ctx) if err != nil { return nil, "", err } dotConn := &dotConn{ conn: conn, lastUsed: time.Now(), refCount: 1, } p.conns[addr] = dotConn return conn, addr, nil } // putConn returns a connection to the pool. func (p *dotConnPool) putConn(addr string, conn net.Conn, isGood bool) { p.mu.Lock() defer p.mu.Unlock() dotConn, ok := p.conns[addr] if !ok { return } dotConn.mu.Lock() defer dotConn.mu.Unlock() dotConn.refCount-- if dotConn.refCount < 0 { dotConn.refCount = 0 } // If connection is bad, remove it from pool if !isGood { delete(p.conns, addr) if conn != nil { conn.Close() } return } dotConn.lastUsed = time.Now() } // dialConn creates a new TCP/TLS connection. func (p *dotConnPool) dialConn(ctx context.Context) (string, net.Conn, error) { logger := LoggerFromCtx(ctx) var endpoint string if p.uc.BootstrapIP != "" { endpoint = net.JoinHostPort(p.uc.BootstrapIP, p.port) Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint) conn, err := p.dialer.DialContext(ctx, "tcp", endpoint) if err != nil { return "", nil, err } tlsConn := tls.Client(conn, p.tlsConfig) if err := tlsConn.HandshakeContext(ctx); err != nil { conn.Close() return "", nil, err } return endpoint, tlsConn, nil } // Try bootstrap IPs in parallel if len(p.addrs) > 0 { type result struct { conn net.Conn addr string err error } ch := make(chan result, len(p.addrs)) done := make(chan struct{}) defer close(done) for _, addr := range p.addrs { go func(addr string) { endpoint := net.JoinHostPort(addr, p.port) conn, err := p.dialer.DialContext(ctx, "tcp", endpoint) if err != nil { select { case ch <- result{conn: nil, addr: endpoint, err: err}: case <-done: } return } tlsConfig := p.tlsConfig.Clone() tlsConfig.ServerName = p.uc.Domain tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.HandshakeContext(ctx); err != nil { conn.Close() select { case ch <- result{conn: nil, addr: endpoint, err: err}: case <-done: } return } select { case ch <- result{conn: tlsConn, addr: endpoint, err: nil}: case <-done: if conn != nil { conn.Close() } } }(addr) } errs := make([]error, 0, len(p.addrs)) for range len(p.addrs) { select { case res := <-ch: if res.err == nil && res.conn != nil { Log(ctx, logger.Debug(), "Sending DoT request to: %s", res.addr) return res.addr, res.conn, nil } if res.err != nil { errs = append(errs, res.err) } case <-ctx.Done(): return "", nil, ctx.Err() } } return "", nil, errors.Join(errs...) } // Fallback to endpoint resolution endpoint = p.uc.Endpoint Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint) conn, err := p.dialer.DialContext(ctx, "tcp", endpoint) if err != nil { return "", nil, err } tlsConn := tls.Client(conn, p.tlsConfig) if err := tlsConn.HandshakeContext(ctx); err != nil { conn.Close() return "", nil, err } return endpoint, tlsConn, nil } // CloseIdleConnections closes all connections in the pool. func (p *dotConnPool) CloseIdleConnections() { p.mu.Lock() defer p.mu.Unlock() if p.closed { return } p.closed = true for addr, dotConn := range p.conns { dotConn.mu.Lock() if dotConn.conn != nil { dotConn.conn.Close() } dotConn.mu.Unlock() delete(p.conns, addr) } }