package cli

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"os/exec"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"
	"golang.org/x/sync/errgroup"
	"tailscale.com/net/netmon"
	"tailscale.com/net/tsaddr"

	"github.com/Control-D-Inc/ctrld"
	"github.com/Control-D-Inc/ctrld/internal/controld"
	"github.com/Control-D-Inc/ctrld/internal/dnscache"
	ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)

// DNS proxy constants for configuration and behavior control.
const (
	// staleTTL is the TTL for stale cache entries.
	// This allows serving cached responses even when upstreams are temporarily unavailable.
	staleTTL = 60 * time.Second

	// localTTL is the TTL for local network responses.
	// A longer TTL for local queries reduces unnecessary repeated lookups.
	localTTL = 3600 * time.Second

	// EDNS0_OPTION_MAC is the dnsmasq EDNS0 code for adding the mac option.
	// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
	// This is also dns.EDNS0LOCALSTART, but we define our own constant here for clarity.
	// It enables MAC address-based client identification for policy routing.
	EDNS0_OPTION_MAC = 0xFDE9

	// selfUninstallMaxQueries is the number of REFUSED queries seen before checking for self-uninstallation.
	// This prevents premature self-uninstallation due to temporary network issues.
	selfUninstallMaxQueries = 32
)

// osUpstreamConfig defines the default OS resolver configuration
// This is used as a fallback when all configured upstreams fail
var osUpstreamConfig = &ctrld.UpstreamConfig{
	Name:    "OS resolver",
	Type:    ctrld.ResolverTypeOS,
	Timeout: 3000,
}

// privateUpstreamConfig defines the default private resolver configuration
// This is used for internal network queries that should not go to public resolvers
var privateUpstreamConfig = &ctrld.UpstreamConfig{
	Name:    "Private resolver",
	Type:    ctrld.ResolverTypePrivate,
	Timeout: 2000,
}

// proxyRequest contains data for proxying a DNS query to upstream.
// This structure encapsulates all the information needed to process a DNS request
type proxyRequest struct {
	msg             *dns.Msg
	ci              *ctrld.ClientInfo
	failoverRcodes  []int
	ufr             *upstreamForResult
	staleAnswer     *dns.Msg
	isLanOrPtrQuery bool
	upstreamConfigs []*ctrld.UpstreamConfig
}

// proxyResponse contains data for proxying a DNS response from upstream.
// This structure encapsulates the response and metadata for logging and metrics
type proxyResponse struct {
	answer     *dns.Msg
	upstream   string
	cached     bool
	clientInfo bool
	refused    bool
}

// upstreamForResult represents the result of processing rules for a request.
// This contains the matched policy information for logging and debugging
type upstreamForResult struct {
	upstreams      []string
	matchedPolicy  string
	matchedNetwork string
	matchedRule    string
	matched        bool
	srcAddr        string
}

// serveDNS sets up and starts a DNS server on the specified listener, handling DNS queries and network monitoring.
// This is the main entry point for DNS server functionality
func (p *prog) serveDNS(ctx context.Context, listenerNum string) error {
	logger := p.logger.Load()
	logger.Debug().Msg("DNS server setup started")

	listenerConfig := p.cfg.Listener[listenerNum]
	if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
		p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveDNS: Failed to allocate listen IP")
		return allocErr
	}

	handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
		p.handleDNSQuery(w, m, listenerNum, listenerConfig)
	})

	logger.Debug().Msg("DNS server setup completed")
	return p.startListeners(ctx, listenerConfig, handler)
}

// startListeners starts DNS listeners on specified configurations, supporting UDP and TCP protocols.
// It handles local IPv6, RFC 1918, and specified IP listeners, reacting to stop signals or errors.
// This function manages the lifecycle of DNS server listeners
func (p *prog) startListeners(ctx context.Context, cfg *ctrld.ListenerConfig, handler dns.Handler) error {
	logger := p.logger.Load()
	logger.Debug().Msg("Starting DNS listeners")

	g, gctx := errgroup.WithContext(ctx)

	for _, proto := range []string{"udp", "tcp"} {
		if needLocalIPv6Listener() {
			logger.Debug().Str("protocol", proto).Msg("Starting local IPv6 listener")
			g.Go(func() error {
				s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(cfg.Port)), proto, handler)
				defer s.Shutdown()
				select {
				case <-p.stopCh:
				case <-gctx.Done():
				case err := <-errCh:
					p.Warn().Err(err).Msg("Local IPv6 listener failed")
				}
				return nil
			})
		}

		if needRFC1918Listeners(cfg) {
			logger.Debug().Str("protocol", proto).Msg("Starting RFC1918 listeners")
			g.Go(func() error {
				for _, addr := range ctrld.Rfc1918Addresses() {
					func() {
						listenAddr := net.JoinHostPort(addr, strconv.Itoa(cfg.Port))
						s, errCh := runDNSServer(listenAddr, proto, handler)
						defer s.Shutdown()
						select {
						case <-p.stopCh:
						case <-gctx.Done():
						case err := <-errCh:
							p.Warn().Err(err).Msgf("Could not listen on %s: %s", proto, listenAddr)
						}
					}()
				}
				return nil
			})
		}

		logger.Debug().Str("protocol", proto).Str("ip", cfg.IP).Int("port", cfg.Port).Msg("Starting main listener")
		g.Go(func() error {
			addr := net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port))
			s, errCh := runDNSServer(addr, proto, handler)
			defer s.Shutdown()
			p.started <- struct{}{}
			select {
			case <-p.stopCh:
			case <-gctx.Done():
			case err := <-errCh:
				return err
			}
			return nil
		})
	}

	logger.Debug().Msg("DNS listeners started successfully")
	return g.Wait()
}

// handleDNSQuery processes incoming DNS queries, validates client access, and routes the query to appropriate handlers.
// This is the main entry point for all DNS query processing
func (p *prog) handleDNSQuery(w dns.ResponseWriter, m *dns.Msg, listenerNum string, listenerConfig *ctrld.ListenerConfig) {
	p.sema.acquire()
	defer p.sema.release()

	if len(m.Question) == 0 {
		sendDNSResponse(w, m, dns.RcodeFormatError)
		return
	}

	reqID := requestID()
	ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqID)
	ctx = ctrld.LoggerCtx(ctx, p.logger.Load())

	ctrld.Log(ctx, p.Debug(), "Processing DNS query from %s", w.RemoteAddr().String())

	if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
		ctrld.Log(ctx, p.Debug(), "Query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
		sendDNSResponse(w, m, dns.RcodeRefused)
		return
	}

	go p.detectLoop(m)

	q := m.Question[0]
	domain := canonicalName(q.Name)

	if p.handleSpecialDomains(ctx, w, m, domain) {
		ctrld.Log(ctx, p.Debug(), "Special domain query handled")
		return
	}

	ctrld.Log(ctx, p.Debug(), "Processing standard query for domain: %s", domain)
	p.processStandardQuery(&standardQueryRequest{
		ctx:            ctx,
		writer:         w,
		msg:            m,
		listenerNum:    listenerNum,
		listenerConfig: listenerConfig,
		domain:         domain,
	})
}

// handleSpecialDomains processes special domain queries, handles errors, purges cache if necessary, and returns a bool status.
// This handles internal test domains and cache management commands
func (p *prog) handleSpecialDomains(ctx context.Context, w dns.ResponseWriter, m *dns.Msg, domain string) bool {
	switch {
	case domain == "":
		ctrld.Log(ctx, p.Debug(), "Empty domain query, sending format error")
		sendDNSResponse(w, m, dns.RcodeFormatError)
		return true
	case domain == selfCheckInternalTestDomain:
		ctrld.Log(ctx, p.Debug(), "Internal test domain query: %s", domain)
		answer := resolveInternalDomainTestQuery(ctx, domain, m)
		_ = w.WriteMsg(answer)
		return true
	}

	if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
		p.cache.Purge()
		ctrld.Log(ctx, p.Debug(), "Received query %q, local cache is purged", domain)
	}

	return false
}

// standardQueryRequest represents a standard DNS query request with associated context and configuration.
// This encapsulates all the data needed to process a standard DNS query
type standardQueryRequest struct {
	ctx            context.Context
	writer         dns.ResponseWriter
	msg            *dns.Msg
	listenerNum    string
	listenerConfig *ctrld.ListenerConfig
	domain         string
}

// processStandardQuery handles a standard DNS query by routing it through appropriate upstreams and writing a DNS response.
// This is the main processing pipeline for normal DNS queries
func (p *prog) processStandardQuery(req *standardQueryRequest) {
	ctrld.Log(req.ctx, p.Debug(), "Processing standard query started")

	remoteIP, _, _ := net.SplitHostPort(req.writer.RemoteAddr().String())
	ci := p.getClientInfo(remoteIP, req.msg)
	ci.ClientIDPref = p.cfg.Service.ClientIDPref

	stripClientSubnet(req.msg)
	remoteAddr := spoofRemoteAddr(req.writer.RemoteAddr(), ci)
	fmtSrcToDest := fmtRemoteToLocal(req.listenerNum, ci.Hostname, remoteAddr.String())

	startTime := time.Now()
	q := req.msg.Question[0]
	ctrld.Log(req.ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], req.domain)

	ur := p.upstreamFor(req.ctx, req.listenerNum, req.listenerConfig, remoteAddr, ci.Mac, req.domain)

	var answer *dns.Msg
	// Handle restricted listener case
	if !ur.matched && req.listenerConfig.Restricted {
		ctrld.Log(req.ctx, p.Debug(), "Query refused, %s does not match any network policy", remoteAddr.String())
		answer = new(dns.Msg)
		answer.SetRcode(req.msg, dns.RcodeRefused)
		// Process the refused query
		go p.postProcessStandardQuery(ci, req.listenerConfig, q, &proxyResponse{answer: answer, refused: true})
	} else {
		// Process a normal query
		ctrld.Log(req.ctx, p.Debug(), "Starting proxy query processing")
		pr := p.proxy(req.ctx, &proxyRequest{
			msg:            req.msg,
			ci:             ci,
			failoverRcodes: p.getFailoverRcodes(req.listenerConfig),
			ufr:            ur,
		})

		rtt := time.Since(startTime)
		ctrld.Log(req.ctx, p.Debug(), "Received response of %d bytes in %s", pr.answer.Len(), rtt)

		go p.postProcessStandardQuery(ci, req.listenerConfig, q, pr)
		answer = pr.answer
	}

	if err := req.writer.WriteMsg(answer); err != nil {
		ctrld.Log(req.ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client")
	}

	ctrld.Log(req.ctx, p.Debug(), "Standard query processing completed")
}

// postProcessStandardQuery performs additional actions after processing a standard DNS query, such as metrics recording,
// handling canonical name adjustments, and triggering specific post-query actions like uninstallation procedures.
func (p *prog) postProcessStandardQuery(ci *ctrld.ClientInfo, listenerConfig *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) {
	p.doSelfUninstall(pr)
	p.recordMetrics(ci, listenerConfig, q, pr)
	p.forceFetchingAPI(canonicalName(q.Name))
}

// getFailoverRcodes retrieves the failover response codes from the provided ListenerConfig. Returns nil if no policy exists.
func (p *prog) getFailoverRcodes(cfg *ctrld.ListenerConfig) []int {
	if cfg.Policy != nil {
		return cfg.Policy.FailoverRcodeNumbers
	}
	return nil
}

// recordMetrics updates Prometheus metrics for DNS queries, including query count and client-specific query statistics.
func (p *prog) recordMetrics(ci *ctrld.ClientInfo, cfg *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) {
	upstream := pr.upstream
	switch {
	case pr.cached:
		upstream = "cache"
	case pr.clientInfo:
		upstream = "client_info_table"
	}
	labelValues := []string{
		net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port)),
		ci.IP,
		ci.Mac,
		ci.Hostname,
		upstream,
		dns.TypeToString[q.Qtype],
		dns.RcodeToString[pr.answer.Rcode],
	}
	p.WithLabelValuesInc(statsQueriesCount, labelValues...)
	p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...)
}

// sendDNSResponse sends a DNS response with the specified RCODE to the client using the provided ResponseWriter.
func sendDNSResponse(w dns.ResponseWriter, m *dns.Msg, rcode int) {
	answer := new(dns.Msg)
	answer.SetRcode(m, rcode)
	_ = w.WriteMsg(answer)
}

// upstreamFor returns the list of upstreams for resolving the given domain,
// matching by policies defined in the listener config. The returned result
// also reports whether the request matched a policy.
//
// Though domain policy has higher priority than network policy, it is still
// processed later, because policy logging wants to know whether a network rule
// is disregarded in favor of the domain level rule.
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) {
	upstreams := []string{upstreamPrefix + defaultUpstreamNum}
	matchedPolicy := "no policy"
	matchedNetwork := "no network"
	matchedRule := "no rule"
	matched := false
	res = &upstreamForResult{srcAddr: addr.String()}

	defer func() {
		res.upstreams = upstreams
		res.matched = matched
		res.matchedPolicy = matchedPolicy
		res.matchedNetwork = matchedNetwork
		res.matchedRule = matchedRule
	}()

	if lc.Policy == nil {
		return
	}

	do := func(policyUpstreams []string) {
		upstreams = append([]string(nil), policyUpstreams...)
	}

	var networkTargets []string
	var sourceIP net.IP
	switch addr := addr.(type) {
	case *net.UDPAddr:
		sourceIP = addr.IP
	case *net.TCPAddr:
		sourceIP = addr.IP
	}

networkRules:
	for _, rule := range lc.Policy.Networks {
		for source, targets := range rule {
			networkNum := strings.TrimPrefix(source, "network.")
			nc := p.cfg.Network[networkNum]
			if nc == nil {
				continue
			}
			for _, ipNet := range nc.IPNets {
				if ipNet.Contains(sourceIP) {
					matchedPolicy = lc.Policy.Name
					matchedNetwork = source
					networkTargets = targets
					matched = true
					break networkRules
				}
			}
		}
	}

macRules:
	for _, rule := range lc.Policy.Macs {
		for source, targets := range rule {
			if source != "" && (strings.EqualFold(source, srcMac) || wildcardMatches(strings.ToLower(source), strings.ToLower(srcMac))) {
				matchedPolicy = lc.Policy.Name
				matchedNetwork = source
				networkTargets = targets
				matched = true
				break macRules
			}
		}
	}

	for _, rule := range lc.Policy.Rules {
		// There's only one entry per rule, config validation ensures this.
		for source, targets := range rule {
			if source == domain || wildcardMatches(source, domain) {
				matchedPolicy = lc.Policy.Name
				if len(networkTargets) > 0 {
					matchedNetwork += " (unenforced)"
				}
				matchedRule = source
				do(targets)
				matched = true
				return
			}
		}
	}

	if matched {
		do(networkTargets)
	}

	return
}

// proxyPrivatePtrLookup performs a private PTR DNS lookup based on the client info table for the given query.
// It prevents DNS loops by locking the processing of the same domain name simultaneously.
// If a valid IP-to-hostname mapping exists, it creates a PTR DNS record as the response.
// Returns the DNS response if a hostname is found or nil otherwise.
func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg {
	cDomainName := msg.Question[0].Name
	locked := p.ptrLoopGuard.TryLock(cDomainName)
	defer p.ptrLoopGuard.Unlock(cDomainName)
	if !locked {
		return nil
	}
	ip := ipFromARPA(cDomainName)
	if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" {
		answer := new(dns.Msg)
		answer.SetReply(msg)
		answer.Compress = true
		answer.Answer = []dns.RR{&dns.PTR{
			Hdr: dns.RR_Header{
				Name:   msg.Question[0].Name,
				Rrtype: dns.TypePTR,
				Class:  dns.ClassINET,
			},
			Ptr: dns.Fqdn(name),
		}}
		ctrld.Log(ctx, p.Info(), "Private PTR lookup, using client info table")
		ctrld.Log(ctx, p.Debug(), "Client info: %v", ctrld.ClientInfo{
			Mac:      p.ciTable.LookupMac(ip.String()),
			IP:       ip.String(),
			Hostname: name,
		})
		return answer
	}
	return nil
}

// proxyLanHostnameQuery resolves LAN hostnames to their corresponding IP addresses based on the dns.Msg request.
// It uses a loop guard mechanism to prevent DNS query loops and ensures a hostname is processed only once at a time.
// This method queries the client info table for the hostname's IP address and logs relevant debug and client info.
// If the hostname matches known IPs in the table, it generates an appropriate dns.Msg response; otherwise, it returns nil.
func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg {
	q := msg.Question[0]
	hostname := strings.TrimSuffix(q.Name, ".")
	locked := p.lanLoopGuard.TryLock(hostname)
	defer p.lanLoopGuard.Unlock(hostname)
	if !locked {
		return nil
	}
	if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil {
		answer := new(dns.Msg)
		answer.SetReply(msg)
		answer.Compress = true
		switch {
		case ip.Is4():
			answer.Answer = []dns.RR{&dns.A{
				Hdr: dns.RR_Header{
					Name:   msg.Question[0].Name,
					Rrtype: dns.TypeA,
					Class:  dns.ClassINET,
					Ttl:    uint32(localTTL.Seconds()),
				},
				A: ip.AsSlice(),
			}}
		case ip.Is6():
			answer.Answer = []dns.RR{&dns.AAAA{
				Hdr: dns.RR_Header{
					Name:   msg.Question[0].Name,
					Rrtype: dns.TypeAAAA,
					Class:  dns.ClassINET,
					Ttl:    uint32(localTTL.Seconds()),
				},
				AAAA: ip.AsSlice(),
			}}
		}
		ctrld.Log(ctx, p.Info(), "Lan hostname lookup, using client info table")
		ctrld.Log(ctx, p.Debug(), "Client info: %v", ctrld.ClientInfo{
			Mac:      p.ciTable.LookupMac(ip.String()),
			IP:       ip.String(),
			Hostname: hostname,
		})
		return answer
	}
	return nil
}

// handleSpecialQueryTypes processes specific types of DNS queries such as SRV, PTR, and LAN hostname lookups.
// It modifies upstreams and upstreamConfigs based on the query type and updates the query context accordingly.
// Returns a proxyResponse if the query is resolved locally; otherwise, returns nil to proceed with upstream processing.
func (p *prog) handleSpecialQueryTypes(ctx *context.Context, req *proxyRequest, upstreams *[]string, upstreamConfigs *[]*ctrld.UpstreamConfig) *proxyResponse {
	if req.ufr.matched {
		ctrld.Log(*ctx, p.Debug(), "%s, %s, %s -> %v",
			req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, *upstreams)
		return nil
	}

	switch {
	case isSrvLanLookup(req.msg):
		*upstreams = []string{upstreamOS}
		*upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
		*ctx = ctrld.LanQueryCtx(*ctx)
		ctrld.Log(*ctx, p.Debug(), "SRV record lookup, using upstreams: %v", *upstreams)
		return nil
	case isPrivatePtrLookup(req.msg):
		req.isLanOrPtrQuery = true
		if answer := p.proxyPrivatePtrLookup(*ctx, req.msg); answer != nil {
			return &proxyResponse{answer: answer, clientInfo: true}
		}
		*upstreams, *upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(*upstreams, *upstreamConfigs)
		*ctx = ctrld.LanQueryCtx(*ctx)
		ctrld.Log(*ctx, p.Debug(), "Private PTR lookup, using upstreams: %v", *upstreams)
		return nil
	case isLanHostnameQuery(req.msg):
		req.isLanOrPtrQuery = true
		if answer := p.proxyLanHostnameQuery(*ctx, req.msg); answer != nil {
			return &proxyResponse{answer: answer, clientInfo: true}
		}
		*upstreams = []string{upstreamOS}
		*upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
		*ctx = ctrld.LanQueryCtx(*ctx)
		ctrld.Log(*ctx, p.Debug(), "Lan hostname lookup, using upstreams: %v", *upstreams)
		return nil
	default:
		ctrld.Log(*ctx, p.Debug(), "No explicit policy matched, using default routing -> %v", *upstreams)
		return nil
	}
}

// proxy handles DNS query proxying by selecting upstreams, attempting cache lookups, and querying configured resolvers.
func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
	ctrld.Log(ctx, p.Debug(), "Proxy query processing started")

	upstreams, upstreamConfigs := p.initializeUpstreams(req)
	ctrld.Log(ctx, p.Debug(), "Initialized upstreams: %v", upstreams)

	if specialRes := p.handleSpecialQueryTypes(&ctx, req, &upstreams, &upstreamConfigs); specialRes != nil {
		ctrld.Log(ctx, p.Debug(), "Special query type handled")
		return specialRes
	}

	if cachedRes := p.tryCache(ctx, req, upstreams); cachedRes != nil {
		ctrld.Log(ctx, p.Debug(), "Cache hit, returning cached response")
		return cachedRes
	}

	ctrld.Log(ctx, p.Debug(), "No cache hit, trying upstreams")
	if res := p.tryUpstreams(ctx, req, upstreams, upstreamConfigs); res != nil {
		ctrld.Log(ctx, p.Debug(), "Upstream query successful")
		return res
	}

	ctrld.Log(ctx, p.Debug(), "All upstreams failed, handling failure")
	return p.handleAllUpstreamsFailure(ctx, req, upstreams)
}

// initializeUpstreams determines which upstreams and configurations to use for a given proxyRequest.
// If no upstreams are configured, it defaults to the operating system's resolver configuration.
// Returns a slice of upstream names and their corresponding configurations.
func (p *prog) initializeUpstreams(req *proxyRequest) ([]string, []*ctrld.UpstreamConfig) {
	upstreams := req.ufr.upstreams
	upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
	if len(upstreamConfigs) == 0 {
		return []string{upstreamOS}, []*ctrld.UpstreamConfig{osUpstreamConfig}
	}
	return upstreams, upstreamConfigs
}

// tryCache attempts to retrieve a cached response for the given DNS request from specified upstreams.
// Returns a proxyResponse if a cache hit occurs; otherwise, returns nil.
// Skips cache checking if caching is disabled or the request is a PTR query.
// Iterates through the provided upstreams to find a cached response using the checkCache method.
func (p *prog) tryCache(ctx context.Context, req *proxyRequest, upstreams []string) *proxyResponse {
	if p.cache == nil || req.msg.Question[0].Qtype == dns.TypePTR { // https://www.rfc-editor.org/rfc/rfc1035#section-7.4
		ctrld.Log(ctx, p.Debug(), "Cache disabled or PTR query, skipping cache lookup")
		return nil
	}

	ctrld.Log(ctx, p.Debug(), "Checking cache for upstreams: %v", upstreams)
	for _, upstream := range upstreams {
		if res := p.checkCache(ctx, req, upstream); res != nil {
			ctrld.Log(ctx, p.Debug(), "Cache hit found for upstream: %s", upstream)
			return res
		}
	}

	ctrld.Log(ctx, p.Debug(), "No cache hit found")
	return nil
}

// checkCache checks if a cached DNS response exists for the given request and upstream.
// Returns a proxyResponse with the cached response if found and valid, or nil otherwise.
func (p *prog) checkCache(ctx context.Context, req *proxyRequest, upstream string) *proxyResponse {
	cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream))
	if cachedValue == nil {
		ctrld.Log(ctx, p.Debug(), "No cached value found for upstream: %s", upstream)
		return nil
	}

	answer := cachedValue.Msg.Copy()
	ctrld.SetCacheReply(answer, req.msg, answer.Rcode)
	now := time.Now()

	if cachedValue.Expire.After(now) {
		ctrld.Log(ctx, p.Debug(), "Hit cached response")
		setCachedAnswerTTL(answer, now, cachedValue.Expire)
		return &proxyResponse{answer: answer, cached: true}
	}

	ctrld.Log(ctx, p.Debug(), "Cached response expired, storing as stale")
	req.staleAnswer = answer
	return nil
}

// updateCache updates the DNS response cache with the given request, response, TTL, and upstream information.
func (p *prog) updateCache(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string) {
	ttl := ttlFromMsg(answer)
	now := time.Now()
	expired := now.Add(time.Duration(ttl) * time.Second)
	if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 {
		expired = now.Add(time.Duration(cachedTTL) * time.Second)
	}
	setCachedAnswerTTL(answer, now, expired)
	p.cache.Add(dnscache.NewKey(req.msg, upstream), dnscache.NewValue(answer, expired))
	ctrld.Log(ctx, p.Debug(), "Added cached response")
}

// serveStaleResponse serves a stale cached DNS response when an upstream query fails, updating TTL for cached records.
func (p *prog) serveStaleResponse(ctx context.Context, staleAnswer *dns.Msg) *proxyResponse {
	ctrld.Log(ctx, p.Debug(), "Serving stale cached response")
	now := time.Now()
	setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL))
	return &proxyResponse{answer: staleAnswer, cached: true}
}

// handleAllUpstreamsFailure handles the failure scenario when all upstream resolvers fail to respond or process the request.
func (p *prog) handleAllUpstreamsFailure(ctx context.Context, req *proxyRequest, upstreams []string) *proxyResponse {
	ctrld.Log(ctx, p.Error(), "All %v endpoints failed", upstreams)

	if p.leakOnUpstreamFailure() {
		ctrld.Log(ctx, p.Debug(), "Leak on upstream failure enabled")
		if p.um.countHealthy(upstreams) == 0 {
			ctrld.Log(ctx, p.Debug(), "No healthy upstreams, triggering recovery")
			p.triggerRecovery(upstreams[0] == upstreamOS)
		} else {
			ctrld.Log(ctx, p.Debug(), "One upstream is down but at least one is healthy; skipping recovery trigger")
		}

		if upstreams[0] != upstreamOS {
			ctrld.Log(ctx, p.Debug(), "Trying OS resolver as fallback")
			if answer := p.tryOSResolver(ctx, req); answer != nil {
				ctrld.Log(ctx, p.Debug(), "OS resolver fallback successful")
				return answer
			}
		}
	}

	ctrld.Log(ctx, p.Debug(), "Returning server failure response")
	answer := new(dns.Msg)
	answer.SetRcode(req.msg, dns.RcodeServerFailure)
	return &proxyResponse{answer: answer}
}

// shouldContinueWithNextUpstream determines whether processing should continue with the next upstream based on response conditions.
func (p *prog) shouldContinueWithNextUpstream(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string, lastUpstream bool) bool {
	if answer.Rcode == dns.RcodeSuccess {
		ctrld.Log(ctx, p.Debug(), "Successful response, not continuing to next upstream")
		return false
	}

	// We are doing a LAN/PTR lookup using the private resolver, so always process the next one.
	// Except for the last, where we want to send a response instead of reporting that all upstreams failed.
	if req.isLanOrPtrQuery && !lastUpstream {
		ctrld.Log(ctx, p.Debug(), "No response for LAN/PTR query from %s, proceeding to next upstream", upstream)
		return true
	}

	if len(req.upstreamConfigs) > 1 && slices.Contains(req.failoverRcodes, answer.Rcode) {
		ctrld.Log(ctx, p.Debug(), "Failover rcode matched, proceeding to next upstream")
		return true
	}

	ctrld.Log(ctx, p.Debug(), "Not continuing to next upstream")
	return false
}

// prepareSuccessResponse prepares a successful DNS response for a given request, logs it, and updates the cache if applicable.
func (p *prog) prepareSuccessResponse(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string, upstreamConfig *ctrld.UpstreamConfig) *proxyResponse {
	ctrld.Log(ctx, p.Debug(), "Preparing success response")

	answer.Compress = true

	if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
		ctrld.Log(ctx, p.Debug(), "Updating cache with successful response")
		p.updateCache(ctx, req, answer, upstream)
	}

	hostname := ""
	if req.ci != nil {
		hostname = req.ci.Hostname
	}

	ctrld.Log(ctx, p.Info(), "REPLY: %s -> %s (%s): %s",
		upstream, req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode])

	return &proxyResponse{
		answer:   answer,
		upstream: upstreamConfig.Endpoint,
	}
}

// tryUpstreams attempts to proxy a DNS request through the provided upstreams and their configurations sequentially.
// It returns a successful proxyResponse if any upstream processes the request successfully, or nil otherwise.
// The function supports "serve stale" for cache by utilizing cached responses when upstreams fail.
func (p *prog) tryUpstreams(ctx context.Context, req *proxyRequest, upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) *proxyResponse {
	serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
	req.upstreamConfigs = upstreamConfigs

	ctrld.Log(ctx, p.Debug(), "Trying %d upstreams", len(upstreamConfigs))

	for n, upstreamConfig := range upstreamConfigs {
		last := n == len(upstreamConfigs)-1
		ctrld.Log(ctx, p.Debug(), "Processing upstream %d/%d: %s", n+1, len(upstreamConfigs), upstreams[n])

		if res := p.processUpstream(ctx, req, upstreams[n], upstreamConfig, serveStaleCache, last); res != nil {
			ctrld.Log(ctx, p.Debug(), "Upstream %s succeeded", upstreams[n])
			return res
		}

		ctrld.Log(ctx, p.Debug(), "Upstream %s failed", upstreams[n])
	}

	ctrld.Log(ctx, p.Debug(), "All upstreams failed")
	return nil
}

// processUpstream proxies a DNS query to a given upstream server and processes the response based on the provided configuration.
// It supports serving stale cache when upstream queries fail, and checks if processing should continue to another upstream.
// Returns a proxyResponse on success or nil if the upstream query fails or processing conditions are not met.
func (p *prog) processUpstream(ctx context.Context, req *proxyRequest, upstream string, upstreamConfig *ctrld.UpstreamConfig, serveStaleCache, lastUpstream bool) *proxyResponse {
	if upstreamConfig == nil {
		ctrld.Log(ctx, p.Debug(), "Upstream config is nil, skipping")
		return nil
	}
	if p.isLoop(upstreamConfig) {
		logger := p.Debug().
			Str("upstream", upstreamConfig.String()).
			Str("query", req.msg.Question[0].Name).
			Bool("is_lan_query", req.isLanOrPtrQuery)
		ctrld.Log(ctx, logger, "DNS loop detected")
		return nil
	}

	ctrld.Log(ctx, p.Debug(), "Querying upstream: %s", upstream)
	answer := p.queryUpstream(ctx, req, upstream, upstreamConfig)
	if answer == nil {
		ctrld.Log(ctx, p.Debug(), "Upstream query failed")
		if serveStaleCache && req.staleAnswer != nil {
			ctrld.Log(ctx, p.Debug(), "Serving stale response due to upstream failure")
			return p.serveStaleResponse(ctx, req.staleAnswer)
		}
		return nil
	}

	ctrld.Log(ctx, p.Debug(), "Upstream query successful")
	if p.shouldContinueWithNextUpstream(ctx, req, answer, upstream, lastUpstream) {
		return nil
	}
	return p.prepareSuccessResponse(ctx, req, answer, upstream, upstreamConfig)
}

// queryUpstream sends a DNS query to a specified upstream using its configuration and handles errors and retries.
func (p *prog) queryUpstream(ctx context.Context, req *proxyRequest, upstream string, upstreamConfig *ctrld.UpstreamConfig) *dns.Msg {
	if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
		ctrld.Log(ctx, p.Debug(), "Adding client info to upstream query")
		ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
	}

	ctrld.Log(ctx, p.Debug(), "Sending query to %s: %s", upstream, upstreamConfig.Name)
	dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig)
	if err != nil {
		ctrld.Log(ctx, p.Error().Err(err), "Failed to create resolver")
		return nil
	}

	resolveCtx, cancel := upstreamConfig.Context(ctx)
	defer cancel()

	ctrld.Log(ctx, p.Debug(), "Resolving query with upstream")
	answer, err := dnsResolver.Resolve(resolveCtx, req.msg)
	if answer != nil {
		ctrld.Log(ctx, p.Debug(), "Upstream resolution successful")
		p.um.mu.Lock()
		p.um.failureReq[upstream] = 0
		p.um.down[upstream] = false
		p.um.mu.Unlock()
		return answer
	}

	ctrld.Log(ctx, p.Error().Err(err), "Failed to resolve query")
	// Increase the failure count when there is no answer, regardless of what kind of error we get.
	p.um.increaseFailureCount(upstream)
	if err != nil {
		// For a timeout error (i.e., context deadline exceeded), force re-bootstrapping.
		var e net.Error
		if errors.As(err, &e) && e.Timeout() {
			ctrld.Log(ctx, p.Debug(), "Timeout error, forcing re-bootstrapping")
			upstreamConfig.ReBootstrap(ctx)
		}
		// For a network error, turn IPv6 off if enabled.
		if ctrld.HasIPv6(ctx) && (errUrlNetworkError(err) || errNetworkError(err)) {
			ctrld.Log(ctx, p.Debug(), "Network error, disabling IPv6")
			ctrld.DisableIPv6(ctx)
		}
	}
	return nil
}

// triggerRecovery attempts to initiate a recovery process if no healthy upstreams are detected.
// If "isOSFailure" is true, the recovery will account for an operating system failure.
// Logs are generated to indicate whether recovery is triggered or already in progress.
func (p *prog) triggerRecovery(isOSFailure bool) {
	p.recoveryCancelMu.Lock()
	defer p.recoveryCancelMu.Unlock()

	if p.recoveryCancel == nil {
		var reason RecoveryReason
		if isOSFailure {
			reason = RecoveryReasonOSFailure
		} else {
			reason = RecoveryReasonRegularFailure
		}
		p.Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
		go p.handleRecovery(reason)
	} else {
		p.Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
	}
}

// tryOSResolver attempts to query the OS resolver as a fallback mechanism when other upstreams fail.
// Logs success or failure of the query attempt and returns a proxyResponse or nil based on query result.
func (p *prog) tryOSResolver(ctx context.Context, req *proxyRequest) *proxyResponse {
	ctrld.Log(ctx, p.Debug(), "Attempting query to OS resolver as a retry catch all")
	answer := p.queryUpstream(ctx, req, upstreamOS, osUpstreamConfig)
	if answer != nil {
		ctrld.Log(ctx, p.Debug(), "OS resolver retry query successful")
		return &proxyResponse{answer: answer, upstream: osUpstreamConfig.Endpoint}
	}
	ctrld.Log(ctx, p.Debug(), "OS resolver retry query failed")
	return nil
}

// upstreamsAndUpstreamConfigForPtr returns the updated upstreams and upstreamConfigs for a private PTR lookup scenario.
func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
	if len(p.localUpstreams) > 0 {
		tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
		tmp = append(tmp, p.localUpstreams...)
		tmp = append(tmp, upstreams...)
		return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp)
	}
	return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...)
}

// upstreamConfigsFromUpstreamNumbers converts a list of upstream names into their corresponding UpstreamConfig objects.
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
	upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
	for _, upstream := range upstreams {
		upstreamNum := strings.TrimPrefix(upstream, upstreamPrefix)
		upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
	}
	return upstreamConfigs
}

// canonicalName returns canonical name from FQDN with "." trimmed.
func canonicalName(fqdn string) string {
	q := strings.TrimSpace(fqdn)
	q = strings.TrimSuffix(q, ".")
	// https://datatracker.ietf.org/doc/html/rfc4343
	q = strings.ToLower(q)

	return q
}

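// Illustrative examples (editor's sketch, not part of the original source):
//
//	canonicalName("ExAmPlE.COM.")  // "example.com"
//	canonicalName(" example.com ") // "example.com" (surrounding whitespace trimmed)
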
// wildcardMatches reports whether string str matches the wildcard pattern in a case-insensitive manner.
func wildcardMatches(wildcard, str string) bool {
	// Wildcard match.
	wildCardParts := strings.Split(strings.ToLower(wildcard), "*")
	if len(wildCardParts) != 2 {
		return false
	}

	str = strings.ToLower(str)
	switch {
	case len(wildCardParts[0]) > 0 && len(wildCardParts[1]) > 0:
		// Domain must match both prefix and suffix.
		return strings.HasPrefix(str, wildCardParts[0]) && strings.HasSuffix(str, wildCardParts[1])

	case len(wildCardParts[1]) > 0:
		// Only suffix must match.
		return strings.HasSuffix(str, wildCardParts[1])

	case len(wildCardParts[0]) > 0:
		// Only prefix must match.
		return strings.HasPrefix(str, wildCardParts[0])
	}

	return false
}

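// Illustrative examples (editor's sketch, not part of the original source).
// Note that a pattern must contain exactly one "*" to match at all:
//
//	wildcardMatches("*.example.com", "foo.example.com") // true  (suffix match)
//	wildcardMatches("foo.*", "foo.example.com")         // true  (prefix match)
//	wildcardMatches("foo.*.com", "foo.example.com")     // true  (prefix and suffix match)
//	wildcardMatches("example.com", "example.com")       // false (no "*")
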
// fmtRemoteToLocal formats a remote address to indicate its mapping to a local listener using listener number and hostname.
func fmtRemoteToLocal(listenerNum, hostname, remote string) string {
	return fmt.Sprintf("%s (%s) -> listener.%s", remote, hostname, listenerNum)
}

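// For example (illustrative values), a query from 192.168.1.10:53567 made by
// host "laptop" hitting listener 0 is formatted as:
//
//	"192.168.1.10:53567 (laptop) -> listener.0"
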
// requestID generates a random 6-character hexadecimal string to uniquely identify a request. It panics on error.
func requestID() string {
	b := make([]byte, 3) // 6 chars
	if _, err := rand.Read(b); err != nil {
		panic(err)
	}
	return hex.EncodeToString(b)
}

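// The resulting ID is a short hex string such as "a1b2c3" (illustrative value).
// It is attached to the query context via ctrld.ReqIdCtxKey in handleDNSQuery,
// so all log lines belonging to one query can be correlated.
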
// setCachedAnswerTTL updates the TTL of each DNS record in the provided message based on the current and expiration times.
func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) {
	ttlSecs := expiredTime.Sub(now).Seconds()
	if ttlSecs < 0 {
		return
	}

	ttl := uint32(ttlSecs)
	for _, rr := range answer.Answer {
		rr.Header().Ttl = ttl
	}
	for _, rr := range answer.Ns {
		rr.Header().Ttl = ttl
	}
	for _, rr := range answer.Extra {
		if rr.Header().Rrtype != dns.TypeOPT {
			rr.Header().Ttl = ttl
		}
	}
}

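// Illustrative example (editor's sketch, not part of the original source):
// serving a cached answer that expires 42 seconds from now rewrites every
// record's TTL to 42, while OPT pseudo-records in the Extra section keep theirs:
//
//	now := time.Now()
//	setCachedAnswerTTL(answer, now, now.Add(42*time.Second))
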
// ttlFromMsg extracts and returns the TTL value from the first record in the Answer or Ns sections of a DNS message.
// If no records exist in either section, the function returns 0.
func ttlFromMsg(msg *dns.Msg) uint32 {
	for _, rr := range msg.Answer {
		return rr.Header().Ttl
	}
	for _, rr := range msg.Ns {
		return rr.Header().Ttl
	}
	return 0
}

// needLocalIPv6Listener checks if a local IPv6 listener is required on Windows by verifying IPv6 support and the OS type.
func needLocalIPv6Listener() bool {
	// On Windows, there's no easy way to disable/remove the IPv6 DNS resolver, so we check whether we can
	// listen on ::1, then spawn a listener for receiving DNS requests.
	return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows"
}

// ipAndMacFromMsg extracts IP and MAC information included in a DNS message, if any.
func ipAndMacFromMsg(msg *dns.Msg) (string, string) {
	ip, mac := "", ""
	if opt := msg.IsEdns0(); opt != nil {
		for _, s := range opt.Option {
			switch e := s.(type) {
			case *dns.EDNS0_LOCAL:
				if e.Code == EDNS0_OPTION_MAC {
					mac = net.HardwareAddr(e.Data).String()
				}
			case *dns.EDNS0_SUBNET:
				if len(e.Address) > 0 && !e.Address.IsLoopback() {
					ip = e.Address.String()
				}
			}
		}
	}
	return ip, mac
}

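// A minimal sketch (editor's assumption, not taken from the original source) of
// how a dnsmasq-style forwarder could attach the MAC option that ipAndMacFromMsg
// extracts; the hardware address is an illustrative value:
//
//	opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
//	hw, _ := net.ParseMAC("00:11:22:33:44:55")
//	opt.Option = append(opt.Option, &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw})
//	msg.Extra = append(msg.Extra, opt)
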
// stripClientSubnet removes EDNS0_SUBNET from the DNS message if the IP is an RFC 1918 or loopback address;
// passing these to the upstream is pointless, as they cannot be used by anything on the WAN.
func stripClientSubnet(msg *dns.Msg) {
	if opt := msg.IsEdns0(); opt != nil {
		opts := make([]dns.EDNS0, 0, len(opt.Option))
		for _, s := range opt.Option {
			if e, ok := s.(*dns.EDNS0_SUBNET); ok && (e.Address.IsPrivate() || e.Address.IsLoopback()) {
				continue
			}
			opts = append(opts, s)
		}
		if len(opts) != len(opt.Option) {
			opt.Option = opts
		}
	}
}

// spoofRemoteAddr replaces the remote address with the client info IP, if available,
// preserving the original port and zone.
func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr {
	if ci != nil && ci.IP != "" {
		switch addr := addr.(type) {
		case *net.UDPAddr:
			udpAddr := &net.UDPAddr{
				IP:   net.ParseIP(ci.IP),
				Port: addr.Port,
				Zone: addr.Zone,
			}
			return udpAddr
		case *net.TCPAddr:
			tcpAddr := &net.TCPAddr{
				IP:   net.ParseIP(ci.IP),
				Port: addr.Port,
				Zone: addr.Zone,
			}
			return tcpAddr
		}
	}
	return addr
}

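// For example (illustrative values), a query received from 127.0.0.1:54321 whose
// client info resolves to 192.168.1.10 is reported as coming from
// 192.168.1.10:54321.
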
// runDNSServer starts a DNS server for the given address and network,
// with the given handler. It ensures the server has started listening.
// Any error will be reported to the caller via the returned channel.
//
// It's the caller's responsibility to call Shutdown to close the server.
func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-chan error) {
	mainLog.Load().Debug().Str("address", addr).Str("network", network).Msg("Starting DNS server")

	s := &dns.Server{
		Addr:    addr,
		Net:     network,
		Handler: handler,
	}

	startedCh := make(chan struct{})
	s.NotifyStartedFunc = func() { sync.OnceFunc(func() { close(startedCh) })() }

	errCh := make(chan error)
	go func() {
		defer close(errCh)
		if err := s.ListenAndServe(); err != nil {
			s.NotifyStartedFunc()
			mainLog.Load().Error().Err(err).Msgf("Could not listen and serve on: %s", s.Addr)
			errCh <- err
		}
	}()
	<-startedCh
	mainLog.Load().Debug().Str("address", addr).Str("network", network).Msg("DNS server started successfully")
	return s, errCh
}

// getClientInfo derives the client's IP, MAC, and hostname for a DNS query,
// using the app callback if present, then EDNS0 data and the client info table.
func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo {
	ci := &ctrld.ClientInfo{}
	if p.appCallback != nil {
		ci.IP = p.appCallback.LanIp()
		ci.Mac = p.appCallback.MacAddress()
		ci.Hostname = p.appCallback.HostName()
		ci.Self = true
		return ci
	}
	ci.IP, ci.Mac = ipAndMacFromMsg(msg)
	switch {
	case ci.IP != "" && ci.Mac != "":
		// Nothing to do.
	case ci.IP == "" && ci.Mac != "":
		// Have MAC, no IP.
		ci.IP = p.ciTable.LookupIP(ci.Mac)
	case ci.IP == "" && ci.Mac == "":
		// Have nothing, use remote IP then lookup MAC.
		ci.IP = remoteIP
		fallthrough
	case ci.IP != "" && ci.Mac == "":
		// Have IP, no MAC.
		ci.Mac = p.ciTable.LookupMac(ci.IP)
	}

	// If MAC is still empty here, that means the requests are made from a virtual interface,
	// like VPN/Wireguard clients, so we use ci.IP as the hostname to distinguish those clients.
	if ci.Mac == "" {
		if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" {
			ci.Hostname = hostname
		} else {
			// Only use the IP as the hostname for IPv4 clients.
			// For Android devices, when one joins the network, it uses ctrld to resolve
			// its private DNS once and never reaches ctrld again. Each time, it uses
			// a different IPv6 address, which causes hundreds/thousands of different client
			// IDs to be created for the same device, which is pointless.
			//
			// TODO(cuonglm): investigate whether this can be a false positive for other clients?
			if !ctrldnet.IsIPv6(ci.IP) {
				ci.Hostname = ci.IP
				p.ciTable.StoreVPNClient(ci)
			}
		}
	} else {
		ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac)
	}
	ci.Self = p.queryFromSelf(ci.IP)
	// If this is a query from self, but ci.IP is not a loopback IP,
	// try using the hostname mapping for the loopback IP if present.
	if ci.Self {
		if name := p.ciTable.LocalHostname(); name != "" {
			ci.Hostname = name
		}
	}
	p.spoofLoopbackIpInClientInfo(ci)
	return ci
}

// spoofLoopbackIpInClientInfo replaces loopback IPs in client info.
//
//   - Prefer IPv4.
//   - Prefer RFC 1918 addresses.
func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
	if ip := net.ParseIP(ci.IP); ip == nil || !ip.IsLoopback() {
		return
	}
	if ip := p.ciTable.LookupRFC1918IPv4(ci.Mac); ip != "" {
		ci.IP = ip
	}
}

// doSelfUninstall performs self-uninstall if all of these conditions are met:
//
//   - There is only 1 ControlD upstream in use.
//   - The number of REFUSED queries seen so far equals selfUninstallMaxQueries.
//   - The cdUID is deleted.
func (p *prog) doSelfUninstall(pr *proxyResponse) {
	answer := pr.answer
	if pr.refused || !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused {
		return
	}

	p.selfUninstallMu.Lock()
	defer p.selfUninstallMu.Unlock()
	if p.checkingSelfUninstall {
		return
	}

	logger := p.logger.Load().With().Str("mode", "self-uninstall")
	if p.refusedQueryCount > selfUninstallMaxQueries {
		p.checkingSelfUninstall = true
		loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
		_, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev)
		logger.Debug().Msg("Maximum number of refused queries reached, checking device status")
		selfUninstallCheck(err, p, logger)

		if err != nil {
			logger.Warn().Err(err).Msg("Could not fetch resolver config")
		}
		// Cool-off period to prevent abusing the API.
		go p.selfUninstallCoolOfPeriod()
		return
	}
	p.refusedQueryCount++
}

// selfUninstallCoolOfPeriod waits for 30 minutes before
// calling the API again to check the ControlD device status.
func (p *prog) selfUninstallCoolOfPeriod() {
	t := time.NewTimer(time.Minute * 30)
	defer t.Stop()
	<-t.C
	p.selfUninstallMu.Lock()
	p.checkingSelfUninstall = false
	p.refusedQueryCount = 0
	p.selfUninstallMu.Unlock()
}

// forceFetchingAPI sends a signal to force syncing the API config if running in cd mode
// and the domain == "cdUID.verify.controld.com".
func (p *prog) forceFetchingAPI(domain string) {
	if cdUID == "" {
		return
	}
	resolverID, parent, _ := strings.Cut(domain, ".")
	if resolverID != cdUID {
		return
	}
	switch {
	case cdDev && parent == "verify.controld.dev":
		// match ControlD dev
	case parent == "verify.controld.com":
		// match ControlD
	default:
		return
	}
	_ = p.apiForceReloadGroup.DoChan("force_sync_api", func() (interface{}, error) {
		p.apiForceReloadCh <- struct{}{}
		// Wait here to prevent abusing the API if we are flooded.
		time.Sleep(timeDurationOrDefault(p.cfg.Service.ForceRefetchWaitTime, 30) * time.Second)
		return nil, nil
	})
}

// timeDurationOrDefault returns the time duration value from n if n is non-nil and positive.
// Otherwise, it returns the time duration value of defaultN.
func timeDurationOrDefault(n *int, defaultN int) time.Duration {
	if n != nil && *n > 0 {
		return time.Duration(*n)
	}
	return time.Duration(defaultN)
}

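// Note that the returned value is a bare count, not scaled to any unit; callers
// multiply it by the desired unit, as the force-reload path above does:
//
//	time.Sleep(timeDurationOrDefault(p.cfg.Service.ForceRefetchWaitTime, 30) * time.Second)
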
// queryFromSelf reports whether the input IP is from the device running ctrld.
func (p *prog) queryFromSelf(ip string) bool {
	if val, ok := p.queryFromSelfMap.Load(ip); ok {
		return val.(bool)
	}
	netIP := netip.MustParseAddr(ip)
	regularIPs, loopbackIPs, err := netmon.LocalAddresses()
	if err != nil {
		p.Warn().Err(err).Msg("Could not get local addresses")
		return false
	}
	for _, localIP := range slices.Concat(regularIPs, loopbackIPs) {
		if localIP.Compare(netIP) == 0 {
			p.queryFromSelfMap.Store(ip, true)
			return true
		}
	}
	p.queryFromSelfMap.Store(ip, false)
	return false
}

// needRFC1918Listeners reports whether ctrld needs to spawn listeners for RFC 1918 addresses.
// This is helpful for non-desktop platforms to receive queries from LAN clients.
func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool {
	return lc.IP == "127.0.0.1" && lc.Port == 53 && !ctrld.IsDesktopPlatform()
}

// ipFromARPA parses an FQDN arpa domain and returns the IP address if valid.
func ipFromARPA(arpa string) net.IP {
	if arpa, ok := strings.CutSuffix(arpa, ".in-addr.arpa."); ok {
		if ptrIP := net.ParseIP(arpa); ptrIP != nil {
			return net.IP{ptrIP[15], ptrIP[14], ptrIP[13], ptrIP[12]}
		}
	}
	if arpa, ok := strings.CutSuffix(arpa, ".ip6.arpa."); ok {
		l := net.IPv6len * 2
		base := 16
		ip := make(net.IP, net.IPv6len)
		for i := 0; i < l && arpa != ""; i++ {
			idx := strings.LastIndexByte(arpa, '.')
			off := idx + 1
			if idx == -1 {
				idx = 0
				off = 0
			} else if idx == len(arpa)-1 {
				return nil
			}
			n, err := strconv.ParseUint(arpa[off:], base, 8)
			if err != nil {
				return nil
			}
			b := byte(n)
			ii := i / 2
			if i&1 == 1 {
				b |= ip[ii] << 4
			}
			ip[ii] = b
			arpa = arpa[:idx]
		}
		return ip
	}
	return nil
}

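// Illustrative examples (editor's sketch, not part of the original source):
//
//	ipFromARPA("4.3.2.1.in-addr.arpa.") // 1.2.3.4
//	ipFromARPA("example.com.")          // nil (not an arpa name)
//
// ip6.arpa. names are decoded nibble by nibble in the same fashion.
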
// isPrivatePtrLookup reports whether the DNS message is a PTR query for a LAN/CGNAT network.
func isPrivatePtrLookup(m *dns.Msg) bool {
	if m == nil || len(m.Question) == 0 {
		return false
	}
	q := m.Question[0]
	if ip := ipFromARPA(q.Name); ip != nil {
		if addr, ok := netip.AddrFromSlice(ip); ok {
			return addr.IsPrivate() ||
				addr.IsLoopback() ||
				addr.IsLinkLocalUnicast() ||
				tsaddr.CGNATRange().Contains(addr)
		}
	}
	return false
}

// isLanHostnameQuery reports whether DNS message is an A/AAAA query with LAN hostname.
func isLanHostnameQuery(m *dns.Msg) bool {
	if m == nil || len(m.Question) == 0 {
		return false
	}
	q := m.Question[0]
	switch q.Qtype {
	case dns.TypeA, dns.TypeAAAA:
	default:
		return false
	}
	return isLanHostname(q.Name)
}

// isSrvLanLookup reports whether DNS message is an SRV query of a LAN hostname.
func isSrvLanLookup(m *dns.Msg) bool {
	if m == nil || len(m.Question) == 0 {
		return false
	}
	q := m.Question[0]
	return q.Qtype == dns.TypeSRV && isLanHostname(q.Name)
}

// isLanHostname reports whether name is a LAN hostname.
func isLanHostname(name string) bool {
	name = strings.TrimSuffix(name, ".")
	return !strings.Contains(name, ".") ||
		strings.HasSuffix(name, ".domain") ||
		strings.HasSuffix(name, ".lan") ||
		strings.HasSuffix(name, ".local")
}

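// Illustrative examples (editor's sketch, not part of the original source):
//
//	isLanHostname("printer.")         // true  (single label)
//	isLanHostname("nas.local.")       // true  (".local" suffix)
//	isLanHostname("www.example.com.") // false
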
// isWanClient reports whether the input is a WAN address.
func isWanClient(na net.Addr) bool {
	var ip netip.Addr
	if ap, err := netip.ParseAddrPort(na.String()); err == nil {
		ip = ap.Addr()
	}
	return !ip.IsLoopback() &&
		!ip.IsPrivate() &&
		!ip.IsLinkLocalUnicast() &&
		!ip.IsLinkLocalMulticast() &&
		!tsaddr.CGNATRange().Contains(ip)
}

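// Illustrative examples (editor's sketch, not part of the original source):
//
//	isWanClient(&net.UDPAddr{IP: net.ParseIP("192.168.1.10"), Port: 5353}) // false (RFC 1918)
//	isWanClient(&net.UDPAddr{IP: net.ParseIP("100.64.0.1"), Port: 5353})   // false (CGNAT)
//	isWanClient(&net.UDPAddr{IP: net.ParseIP("203.0.113.7"), Port: 5353})  // true
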
// resolveInternalDomainTestQuery resolves internal test domain query, returning the answer to the caller.
func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.Msg) *dns.Msg {
	logger := ctrld.LoggerFromCtx(ctx)
	ctrld.Log(ctx, logger.Debug(), "Internal domain test query")

	q := m.Question[0]
	answer := new(dns.Msg)
	rrStr := fmt.Sprintf("%s A %s", domain, net.IPv4zero)
	if q.Qtype == dns.TypeAAAA {
		rrStr = fmt.Sprintf("%s AAAA %s", domain, net.IPv6zero)
	}
	rr, err := dns.NewRR(rrStr)
	if err == nil {
		answer.Answer = append(answer.Answer, rr)
	}
	answer.SetReply(m)
	return answer
}

// FlushDNSCache flushes the DNS cache on macOS.
func FlushDNSCache() error {
	// if not macOS, return
	if runtime.GOOS != "darwin" {
		return nil
	}

	// Flush the DNS cache via mDNSResponder.
	// This is typically needed on modern macOS systems.
	if out, err := exec.Command("killall", "-HUP", "mDNSResponder").CombinedOutput(); err != nil {
		return fmt.Errorf("failed to flush mDNSResponder: %w, output: %s", err, string(out))
	}

	// Optionally, flush the directory services cache.
	if out, err := exec.Command("dscacheutil", "-flushcache").CombinedOutput(); err != nil {
		return fmt.Errorf("failed to flush dscacheutil: %w, output: %s", err, string(out))
	}

	return nil
}

// monitorNetworkChanges starts monitoring for network interface changes
func (p *prog) monitorNetworkChanges(ctx context.Context) error {
	mon, err := netmon.New(func(format string, args ...any) {
		// Always fetch the latest logger (and inject the prefix)
		p.logger.Load().Printf("netmon: "+format, args...)
	})
	if err != nil {
		return fmt.Errorf("creating network monitor: %w", err)
	}

	mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
		// Get map of valid interfaces
		validIfaces := validInterfacesMap(ctrld.LoggerCtx(ctx, p.logger.Load()))

		isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New)

		p.Debug().
			Interface("old_state", delta.Old).
			Interface("new_state", delta.New).
			Bool("is_major_change", isMajorChange).
			Msg("Network change detected")

		changed := false
		activeInterfaceExists := false
		var changeIPs []netip.Prefix
		// Check each valid interface for changes
		for ifaceName := range validIfaces {
			oldIface, oldExists := delta.Old.Interface[ifaceName]
			newIface, newExists := delta.New.Interface[ifaceName]
			if !newExists {
				continue
			}

			oldIPs := delta.Old.InterfaceIPs[ifaceName]
			newIPs := delta.New.InterfaceIPs[ifaceName]

			// If a valid interface did not exist in the old state,
			// check that it's up and has usable IPs.
			if !oldExists {
				// The interface is new (was not present in the old state).
				usableNewIPs := filterUsableIPs(newIPs)
				if newIface.IsUp() && len(usableNewIPs) > 0 {
					changed = true
					changeIPs = usableNewIPs
					p.Debug().
						Str("interface", ifaceName).
						Interface("new_ips", usableNewIPs).
						Msg("Interface newly appeared (was not present in old state)")
					break
				}
				continue
			}

			// Filter new IPs to only those that are usable.
			usableNewIPs := filterUsableIPs(newIPs)

			// Check if interface is up and has usable IPs.
			if newIface.IsUp() && len(usableNewIPs) > 0 {
				activeInterfaceExists = true
			}

			// Compare interface states and IPs (interfaceIPsEqual will itself filter the IPs).
			if !interfaceStatesEqual(&oldIface, &newIface) || !interfaceIPsEqual(oldIPs, newIPs) {
				if newIface.IsUp() && len(usableNewIPs) > 0 {
					changed = true
					changeIPs = usableNewIPs
					p.Debug().
						Str("interface", ifaceName).
						Interface("old_ips", oldIPs).
						Interface("new_ips", usableNewIPs).
						Msg("Interface state or IPs changed")
					break
				}
			}
		}

		// If the default route changed, set changed to true.
		if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface {
			changed = true
			p.Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface)
		}

		if !changed {
			p.Debug().Msg("Ignoring interface change - no valid interfaces affected")
			// Check if the default IPs are still on an interface that is up.
			ValidateDefaultLocalIPsFromDelta(delta.New)
			return
		}

		if !activeInterfaceExists {
			p.Debug().Msg("No active interfaces found, skipping reinitialization")
			return
		}

		// Get IPs from default route interface in new state
		selfIP := p.defaultRouteIP()

		// Ensure that selfIP is an IPv4 address.
		// If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it.
		if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil {
			p.Debug().Msgf("DefaultRouteIP returned a non-ipv4 address: %s, ignoring it", selfIP)
			selfIP = ""
		}
		var ipv6 string

		if delta.New.DefaultRouteInterface != "" {
			p.Debug().Msgf("Default route interface: %s, ips: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface])
			for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] {
				ipAddr, _ := netip.ParsePrefix(ip.String())
				addr := ipAddr.Addr()
				if selfIP == "" && addr.Is4() {
					p.Debug().Msgf("Checking ip: %s", addr.String())
					if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
						selfIP = addr.String()
					}
				}
				if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
					ipv6 = addr.String()
				}
			}
		} else {
			// If no default route interface is set yet, use the changed IPs.
			p.Debug().Msgf("No default route interface found, using changed ips: %v", changeIPs)
			for _, ip := range changeIPs {
				ipAddr, _ := netip.ParsePrefix(ip.String())
				addr := ipAddr.Addr()
				if selfIP == "" && addr.Is4() {
					p.Debug().Msgf("Checking ip: %s", addr.String())
					if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
						selfIP = addr.String()
					}
				}
				if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
					ipv6 = addr.String()
				}
			}
		}

		// Only set the IPv4 default if selfIP is a valid IPv4 address.
		if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil {
			ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, p.logger.Load()), ip)
			if !isMobile() && p.ciTable != nil {
				p.ciTable.SetSelfIP(selfIP)
			}
		}
		if ip := net.ParseIP(ipv6); ip != nil {
			ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, p.logger.Load()), ip)
		}
		p.Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)

		p.handleRecovery(RecoveryReasonNetworkChange)
	})

	mon.Start()
	p.Debug().Msg("Network monitor started")
	return nil
}

|
|
|
|
// interfaceStatesEqual compares two interface states; currently only the
// up/down status is considered.
func interfaceStatesEqual(a, b *netmon.Interface) bool {
	if a == nil || b == nil {
		return a == b
	}
	return a.IsUp() == b.IsUp()
}

// filterUsableIPs is a helper that returns only "usable" IP prefixes,
// filtering out link-local, loopback, multicast, unspecified, broadcast, or CGNAT addresses.
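//
// Illustrative example (hypothetical addresses, given the filtering rules above):
//
//	filterUsableIPs([]netip.Prefix{
//		netip.MustParsePrefix("192.168.1.5/24"),  // kept: ordinary private unicast
//		netip.MustParsePrefix("169.254.10.1/16"), // dropped: link-local
//		netip.MustParsePrefix("100.64.0.1/10"),   // dropped: CGNAT range
//	})
//	// => [192.168.1.5/24]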
func filterUsableIPs(prefixes []netip.Prefix) []netip.Prefix {
	var usable []netip.Prefix
	for _, p := range prefixes {
		addr := p.Addr()
		if addr.IsLinkLocalUnicast() ||
			addr.IsLoopback() ||
			addr.IsMulticast() ||
			addr.IsUnspecified() ||
			addr.IsLinkLocalMulticast() ||
			(addr.Is4() && addr.String() == "255.255.255.255") ||
			tsaddr.CGNATRange().Contains(addr) {
			continue
		}
		usable = append(usable, p)
	}
	return usable
}

// interfaceIPsEqual compares only the usable (non-link-local, non-loopback, etc.)
// IP addresses of the two prefix lists.
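//
// The comparison is order-insensitive; for example (hypothetical addresses):
//
//	a := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24"), netip.MustParsePrefix("fe80::1/64")}
//	b := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}
//	interfaceIPsEqual(a, b) // true: the link-local fe80::1 is filtered out before comparing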
func interfaceIPsEqual(a, b []netip.Prefix) bool {
	aUsable := filterUsableIPs(a)
	bUsable := filterUsableIPs(b)
	if len(aUsable) != len(bUsable) {
		return false
	}

	aMap := make(map[string]bool)
	for _, ip := range aUsable {
		aMap[ip.String()] = true
	}
	for _, ip := range bUsable {
		if !aMap[ip.String()] {
			return false
		}
	}
	return true
}

// checkUpstreamOnce sends a test query to the specified upstream.
// It returns nil if the upstream responds successfully.
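//
// Typical use, as in the recovery loop below:
//
//	if err := p.checkUpstreamOnce(name, uc); err == nil {
//		// The upstream answered; recovery can proceed.
//	}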
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
	p.Debug().Msgf("Starting check for upstream: %s", upstream)

	resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), uc)
	if err != nil {
		p.Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
		return err
	}

	timeout := 1000 * time.Millisecond
	if uc.Timeout > 0 {
		timeout = time.Millisecond * time.Duration(uc.Timeout)
	}
	p.Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout)

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load()))
	p.Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)

	start := time.Now()
	msg := uc.VerifyMsg()
	_, err = resolver.Resolve(ctx, msg)
	duration := time.Since(start)

	if err != nil {
		p.Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration)
	} else {
		p.Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration)
	}
	return err
}

// handleRecovery orchestrates the recovery process by coordinating multiple smaller methods.
// It handles recovery cancellation logic, creates the recovery context, prepares the system,
// waits for upstream recovery with a timeout, and completes the recovery process.
// The method is designed to be called from a goroutine and handles different recovery reasons
// (network changes, regular failures, OS failures) with appropriate logic for each.
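//
// The high-level flow, as implemented below, is:
//
//	shouldStartRecovery -> createRecoveryContext -> prepareForRecovery ->
//	buildRecoveryUpstreams -> waitForUpstreamRecovery -> completeRecovery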
func (p *prog) handleRecovery(reason RecoveryReason) {
	p.Debug().Msg("Starting recovery process: removing DNS settings")

	// Handle recovery cancellation based on reason.
	if !p.shouldStartRecovery(reason) {
		return
	}

	// Create recovery context and cleanup function.
	recoveryCtx, cleanup := p.createRecoveryContext()
	defer cleanup()

	// Remove DNS settings and prepare for recovery.
	if err := p.prepareForRecovery(reason); err != nil {
		p.Error().Err(err).Msg("Failed to prepare for recovery")
		return
	}

	// Build upstream map based on the recovery reason.
	upstreams := p.buildRecoveryUpstreams(reason)

	// Wait for upstream recovery.
	recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
	if err != nil {
		p.Error().Err(err).Msg("Recovery failed; DNS settings remain removed")
		return
	}

	// Complete recovery process.
	if err := p.completeRecovery(reason, recovered); err != nil {
		p.Error().Err(err).Msg("Failed to complete recovery")
		return
	}

	p.Info().Msgf("Recovery completed successfully for upstream %q", recovered)
}

// shouldStartRecovery determines if recovery should start based on the reason and current state.
// It returns true if recovery should proceed, false otherwise.
func (p *prog) shouldStartRecovery(reason RecoveryReason) bool {
	p.recoveryCancelMu.Lock()
	defer p.recoveryCancelMu.Unlock()

	if reason == RecoveryReasonNetworkChange {
		// For network changes, cancel any existing recovery check because the network state has changed.
		if p.recoveryCancel != nil {
			p.Debug().Msg("Cancelling existing recovery check (network change)")
			p.recoveryCancel()
			p.recoveryCancel = nil
		}
		return true
	}

	// For upstream failures, if a recovery is already in progress, do nothing new.
	if p.recoveryCancel != nil {
		p.Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
		return false
	}

	return true
}

// createRecoveryContext creates a new recovery context and returns it along with a cleanup function.
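//
// Callers are expected to defer the cleanup, as handleRecovery does:
//
//	recoveryCtx, cleanup := p.createRecoveryContext()
//	defer cleanup()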
func (p *prog) createRecoveryContext() (context.Context, func()) {
	p.recoveryCancelMu.Lock()
	recoveryCtx, cancel := context.WithCancel(context.Background())
	p.recoveryCancel = cancel
	p.recoveryCancelMu.Unlock()

	cleanup := func() {
		// Cancel this run's context so its resources are released (cancel is
		// idempotent, so this is safe even if shouldStartRecovery already
		// canceled it), then clear the shared cancel pointer.
		cancel()
		p.recoveryCancelMu.Lock()
		p.recoveryCancel = nil
		p.recoveryCancelMu.Unlock()
	}

	return recoveryCtx, cleanup
}

// prepareForRecovery removes DNS settings and reinitializes the OS resolver
// when the failure reason is an OS resolver failure.
func (p *prog) prepareForRecovery(reason RecoveryReason) error {
	// Set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface.
	p.recoveryRunning.Store(true)

	// Remove DNS settings. We do not want to restore any static DNS settings here:
	// we must try to get the DHCP values, and any static DNS settings
	// will be appended to the nameservers from the saved interface values.
	p.resetDNS(false, false)

	// For an OS failure, reinitialize OS resolver nameservers immediately.
	if reason == RecoveryReasonOSFailure {
		if err := p.reinitializeOSResolver("OS resolver failure detected"); err != nil {
			return fmt.Errorf("failed to reinitialize OS resolver: %w", err)
		}
	}

	return nil
}

// reinitializeOSResolver reinitializes the OS resolver and logs the results.
func (p *prog) reinitializeOSResolver(message string) error {
	p.Debug().Msg(message)
	loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
	ns := ctrld.InitializeOsResolver(loggerCtx, true)
	if len(ns) == 0 {
		p.Warn().Msg("No nameservers found for OS resolver; using existing values")
	} else {
		p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
	}
	return nil
}

// completeRecovery completes the recovery process by resetting upstream state and reapplying DNS settings.
func (p *prog) completeRecovery(reason RecoveryReason, recovered string) error {
	// Reset the upstream failure count and down state.
	p.um.reset(recovered)

	// For network changes we also reinitialize the OS resolver.
	if reason == RecoveryReasonNetworkChange {
		if err := p.reinitializeOSResolver("Network change detected during recovery"); err != nil {
			return fmt.Errorf("failed to reinitialize OS resolver during network change: %w", err)
		}
	}

	// Apply our DNS settings back and log the interface state.
	p.setDNS()
	p.logInterfacesState()

	// Allow watchdogs to put the listener back on the interface if it's changed for any reason.
	p.recoveryRunning.Store(false)

	return nil
}

// waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers.
// It returns the name of the recovered upstream, or an error if the check times out.
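//
// Each upstream is probed in its own goroutine; the first success wins via a
// buffered channel of size one, and the remaining checkers stop once the
// recovery context is canceled:
//
//	recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
//	if err != nil {
//		// No upstream recovered before the context was canceled.
//	}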
func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string]*ctrld.UpstreamConfig) (string, error) {
	// Derive a cancellable child context so the remaining checker goroutines
	// can be stopped once a winner is found; otherwise wg.Wait below could
	// block forever on upstreams that never recover.
	checkCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	recoveredCh := make(chan string, 1)
	var wg sync.WaitGroup

	p.Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams))

	for name, uc := range upstreams {
		wg.Add(1)
		go func(name string, uc *ctrld.UpstreamConfig) {
			defer wg.Done()
			p.Debug().Msgf("Starting recovery check loop for upstream: %s", name)
			attempts := 0
			for {
				select {
				case <-checkCtx.Done():
					p.Debug().Msgf("Context canceled for upstream %s", name)
					return
				default:
					attempts++
					// checkUpstreamOnce will reset any failure counters on success.
					if err := p.checkUpstreamOnce(name, uc); err == nil {
						p.Debug().Msgf("Upstream %s recovered successfully", name)
						select {
						case recoveredCh <- name:
							p.Debug().Msgf("Sent recovery notification for upstream %s", name)
						default:
							p.Debug().Msg("Recovery channel full, another upstream already recovered")
						}
						return
					}
					p.Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
					time.Sleep(checkUpstreamBackoffSleep)

					// If this is upstreamOS and the attempt count is a multiple of 3,
					// reinitialize the OS resolver to ensure we can recover.
					if name == upstreamOS && attempts%3 == 0 {
						p.Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts)
						ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(checkCtx, p.logger.Load()), true)
						if len(ns) == 0 {
							p.Warn().Msg("No nameservers found for OS resolver; using existing values")
						} else {
							p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
						}
					}
				}
			}
		}(name, uc)
	}

	var recovered string
	select {
	case recovered = <-recoveredCh:
	case <-ctx.Done():
		return "", ctx.Err()
	}
	// Stop the remaining checkers before waiting so wg.Wait returns promptly.
	cancel()
	wg.Wait()
	return recovered, nil
}

// buildRecoveryUpstreams constructs the map of upstream configurations to test.
// For OS failures we supply the manual OS resolver upstream configuration.
// For a network change or a regular failure we use the upstreams defined in p.cfg
// (ignoring OS-type resolvers).
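//
// Illustrative results (upstreamOS and upstreamPrefix are the package's existing key constants):
//
//	buildRecoveryUpstreams(RecoveryReasonOSFailure)     // {upstreamOS: osUpstreamConfig}
//	buildRecoveryUpstreams(RecoveryReasonNetworkChange) // {upstreamPrefix+k: uc} for every non-OS upstream in p.cfg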
func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.UpstreamConfig {
	upstreams := make(map[string]*ctrld.UpstreamConfig)
	switch reason {
	case RecoveryReasonOSFailure:
		upstreams[upstreamOS] = osUpstreamConfig
	case RecoveryReasonNetworkChange, RecoveryReasonRegularFailure:
		// Use all configured upstreams except any OS type.
		for k, uc := range p.cfg.Upstream {
			if uc.Type != ctrld.ResolverTypeOS {
				upstreams[upstreamPrefix+k] = uc
			}
		}
	}
	return upstreams
}

// ValidateDefaultLocalIPsFromDelta checks if the stored default local IPv4 and IPv6
// are still present in the new network state (provided by delta.New).
// If a stored default IP is no longer active, it resets that default (sets it to nil)
// so that it won't be used in subsequent custom dialer contexts.
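//
// Illustrative sketch (hypothetical addresses): if the stored default IPv4 is
// 192.168.1.5 but newState lists only 10.0.0.7 across all interfaces, the
// stored IPv4 default is reset to nil so subsequent custom dialer contexts stop using it.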
func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) {
	currentIPv4 := ctrld.GetDefaultLocalIPv4()
	currentIPv6 := ctrld.GetDefaultLocalIPv6()

	// Build a map of active IP addresses from the new state.
	activeIPs := make(map[string]bool)
	for _, prefixes := range newState.InterfaceIPs {
		for _, prefix := range prefixes {
			activeIPs[prefix.Addr().String()] = true
		}
	}

	// Check if the default IPv4 is still active.
	if currentIPv4 != nil && !activeIPs[currentIPv4.String()] {
		mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", currentIPv4)
		ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil)
	}

	// Check if the default IPv6 is still active.
	if currentIPv6 != nil && !activeIPs[currentIPv6.String()] {
		mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. Resetting.", currentIPv6)
		ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil)
	}
}