Files
ctrld/cmd/ctrld_library/netstack/dns_bridge.go
2026-03-19 00:24:35 -04:00

229 lines
4.7 KiB
Go

package netstack
import (
"fmt"
"sync"
"time"
"github.com/miekg/dns"
)
// DNSBridge provides a bridge between the netstack DNS filter and the existing ctrld DNS proxy.
// It allows DNS queries captured from packets to be processed by the same logic as traditional DNS queries.
type DNSBridge struct {
// Channel for sending DNS queries
queryCh chan *DNSQuery
// Channel for receiving DNS responses
responseCh chan *DNSResponse
// Map to track pending queries by transaction ID
pendingQueries map[uint16]*PendingQuery
mu sync.RWMutex
// Timeout for DNS queries
queryTimeout time.Duration
// Running state
running bool
stopCh chan struct{}
wg sync.WaitGroup
}
// DNSQuery represents a DNS query to be processed
type DNSQuery struct {
ID uint16 // Transaction ID for matching response
Query []byte // Raw DNS query bytes
RespCh chan []byte // Response channel
SrcIP string // Source IP for logging
SrcPort uint16 // Source port
}
// DNSResponse represents a DNS response
type DNSResponse struct {
ID uint16
Response []byte
}
// PendingQuery tracks a query waiting for response
type PendingQuery struct {
Query *DNSQuery
Timestamp time.Time
}
// NewDNSBridge creates a new DNS bridge
func NewDNSBridge() *DNSBridge {
return &DNSBridge{
queryCh: make(chan *DNSQuery, 100),
responseCh: make(chan *DNSResponse, 100),
pendingQueries: make(map[uint16]*PendingQuery),
queryTimeout: 5 * time.Second,
stopCh: make(chan struct{}),
}
}
// Start starts the DNS bridge
func (b *DNSBridge) Start() {
b.mu.Lock()
if b.running {
b.mu.Unlock()
return
}
b.running = true
b.mu.Unlock()
// Start response handler
b.wg.Add(1)
go b.handleResponses()
// Start timeout checker
b.wg.Add(1)
go b.checkTimeouts()
}
// Stop stops the DNS bridge
func (b *DNSBridge) Stop() {
b.mu.Lock()
if !b.running {
b.mu.Unlock()
return
}
b.running = false
b.mu.Unlock()
close(b.stopCh)
b.wg.Wait()
// Clean up pending queries
b.mu.Lock()
for _, pending := range b.pendingQueries {
close(pending.Query.RespCh)
}
b.pendingQueries = make(map[uint16]*PendingQuery)
b.mu.Unlock()
}
// ProcessQuery processes a DNS query and waits for response
func (b *DNSBridge) ProcessQuery(query []byte, srcIP string, srcPort uint16) ([]byte, error) {
if len(query) < 12 {
return nil, fmt.Errorf("invalid DNS query: too short")
}
// Parse DNS message to get transaction ID
msg := new(dns.Msg)
if err := msg.Unpack(query); err != nil {
return nil, fmt.Errorf("failed to parse DNS query: %v", err)
}
// Create response channel
respCh := make(chan []byte, 1)
// Create query
dnsQuery := &DNSQuery{
ID: msg.Id,
Query: query,
RespCh: respCh,
SrcIP: srcIP,
SrcPort: srcPort,
}
// Store as pending
b.mu.Lock()
b.pendingQueries[msg.Id] = &PendingQuery{
Query: dnsQuery,
Timestamp: time.Now(),
}
b.mu.Unlock()
// Send query
select {
case b.queryCh <- dnsQuery:
case <-time.After(time.Second):
b.mu.Lock()
delete(b.pendingQueries, msg.Id)
b.mu.Unlock()
return nil, fmt.Errorf("query channel full")
}
// Wait for response with timeout
select {
case response := <-respCh:
b.mu.Lock()
delete(b.pendingQueries, msg.Id)
b.mu.Unlock()
return response, nil
case <-time.After(b.queryTimeout):
b.mu.Lock()
delete(b.pendingQueries, msg.Id)
b.mu.Unlock()
return nil, fmt.Errorf("DNS query timeout")
}
}
// GetQueryChannel returns the channel for receiving DNS queries
func (b *DNSBridge) GetQueryChannel() <-chan *DNSQuery {
return b.queryCh
}
// SendResponse sends a DNS response back to the waiting query
func (b *DNSBridge) SendResponse(id uint16, response []byte) error {
b.mu.RLock()
pending, exists := b.pendingQueries[id]
b.mu.RUnlock()
if !exists {
return fmt.Errorf("no pending query for ID %d", id)
}
select {
case pending.Query.RespCh <- response:
return nil
case <-time.After(time.Second):
return fmt.Errorf("failed to send response: channel blocked")
}
}
// handleResponses handles incoming responses
func (b *DNSBridge) handleResponses() {
defer b.wg.Done()
for {
select {
case <-b.stopCh:
return
case resp := <-b.responseCh:
if err := b.SendResponse(resp.ID, resp.Response); err != nil {
// Log error but continue
}
}
}
}
// checkTimeouts periodically checks for and removes timed out queries
func (b *DNSBridge) checkTimeouts() {
defer b.wg.Done()
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-b.stopCh:
return
case <-ticker.C:
now := time.Now()
b.mu.Lock()
for id, pending := range b.pendingQueries {
if now.Sub(pending.Timestamp) > b.queryTimeout {
close(pending.Query.RespCh)
delete(b.pendingQueries, id)
}
}
b.mu.Unlock()
}
}
}