package cli

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"os/exec"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"
	"golang.org/x/sync/errgroup"
	"tailscale.com/net/netmon"
	"tailscale.com/net/tsaddr"

	"github.com/Control-D-Inc/ctrld"
	"github.com/Control-D-Inc/ctrld/internal/controld"
	"github.com/Control-D-Inc/ctrld/internal/dnscache"
	ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)

// DNS proxy constants for configuration and behavior control
const (
	// staleTTL is the TTL for stale cache entries.
	// This allows serving cached responses even when upstreams are temporarily unavailable.
	staleTTL = 60 * time.Second

	// localTTL is the TTL for local network responses.
	// A longer TTL for local queries reduces unnecessary repeated lookups.
	localTTL = 3600 * time.Second

	// EDNS0_OPTION_MAC is the dnsmasq EDNS0 option code for attaching a client MAC address.
	// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
	// This is also dns.EDNS0LOCALSTART, but we define our own constant here for clarity.
	// It enables MAC address-based client identification for policy routing.
	EDNS0_OPTION_MAC = 0xFDE9

	// selfUninstallMaxQueries is the number of REFUSED queries seen before checking for self-uninstallation.
	// This prevents premature self-uninstallation due to temporary network issues.
	selfUninstallMaxQueries = 32
)

// osUpstreamConfig defines the default OS resolver configuration.
// This is used as a fallback when all configured upstreams fail.
var osUpstreamConfig = &ctrld.UpstreamConfig{
	Name:    "OS resolver",
	Type:    ctrld.ResolverTypeOS,
	Timeout: 3000,
}

// privateUpstreamConfig defines the default private resolver configuration.
// This is used for internal network queries that should not go to public resolvers.
var privateUpstreamConfig = &ctrld.UpstreamConfig{
	Name:    "Private resolver",
	Type:    ctrld.ResolverTypePrivate,
	Timeout: 2000,
}

// proxyRequest contains data for proxying a DNS query to upstream.
// This structure encapsulates all the information needed to process a DNS request.
type proxyRequest struct {
	msg             *dns.Msg
	ci              *ctrld.ClientInfo
	failoverRcodes  []int
	ufr             *upstreamForResult
	staleAnswer     *dns.Msg
	isLanOrPtrQuery bool
	upstreamConfigs []*ctrld.UpstreamConfig
}

// proxyResponse contains data for proxying a DNS response from upstream.
// This structure encapsulates the response and metadata for logging and metrics.
type proxyResponse struct {
	answer     *dns.Msg
	upstream   string
	cached     bool
	clientInfo bool
	refused    bool
}

// upstreamForResult represents the result of processing rules for a request.
// This contains the matched policy information for logging and debugging.
type upstreamForResult struct {
	upstreams      []string
	matchedPolicy  string
	matchedNetwork string
	matchedRule    string
	matched        bool
	srcAddr        string
}

// serveDNS sets up and starts a DNS server on the specified listener, handling DNS queries and network monitoring.
// This is the main entry point for DNS server functionality.
func (p *prog) serveDNS(ctx context.Context, listenerNum string) error {
	logger := p.logger.Load()
	logger.Debug().Msg("DNS server setup started")

	listenerConfig := p.cfg.Listener[listenerNum]
	if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
		p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: Failed to allocate listen IP")
		return allocErr
	}

	handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
		p.handleDNSQuery(w, m, listenerNum, listenerConfig)
	})

	logger.Debug().Msg("DNS server setup completed")
	return p.startListeners(ctx, listenerConfig, handler)
}

// startListeners starts DNS listeners on the specified configurations, supporting UDP and TCP protocols.
// It handles local IPv6, RFC 1918, and specified IP listeners, reacting to stop signals or errors.
// This function manages the lifecycle of DNS server listeners.
func (p *prog) startListeners(ctx context.Context, cfg *ctrld.ListenerConfig, handler dns.Handler) error {
	logger := p.logger.Load()
	logger.Debug().Msg("Starting DNS listeners")

	g, gctx := errgroup.WithContext(ctx)

	for _, proto := range []string{"udp", "tcp"} {
		if needLocalIPv6Listener() {
			logger.Debug().Str("protocol", proto).Msg("Starting local IPv6 listener")
			g.Go(func() error {
				s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(cfg.Port)), proto, handler)
				defer s.Shutdown()
				select {
				case <-p.stopCh:
				case <-gctx.Done():
				case err := <-errCh:
					p.Warn().Err(err).Msg("Local IPv6 listener failed")
				}
				return nil
			})
		}

		// When we spawn a listener on 127.0.0.1, also spawn listeners on the machine's RFC 1918
		// addresses if explicitly enabled via the rfc1918 flag, so ctrld can receive queries from LAN clients.
		if needRFC1918Listeners(cfg) {
			logger.Debug().Str("protocol", proto).Msg("Starting RFC1918 listeners")
			g.Go(func() error {
				for _, addr := range ctrld.Rfc1918Addresses() {
					func() {
						listenAddr := net.JoinHostPort(addr, strconv.Itoa(cfg.Port))
						s, errCh := runDNSServer(listenAddr, proto, handler)
						defer s.Shutdown()
						select {
						case <-p.stopCh:
						case <-gctx.Done():
						case err := <-errCh:
							p.Warn().Err(err).Msgf("Could not listen on %s: %s", proto, listenAddr)
						}
					}()
				}
				return nil
			})
		}

		logger.Debug().Str("protocol", proto).Str("ip", cfg.IP).Int("port", cfg.Port).Msg("Starting main listener")
		g.Go(func() error {
			addr := net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port))
			s, errCh := runDNSServer(addr, proto, handler)
			defer s.Shutdown()
			p.started <- struct{}{}
			select {
			case <-p.stopCh:
			case <-gctx.Done():
			case err := <-errCh:
				return err
			}
			return nil
		})
	}

	logger.Debug().Msg("DNS listeners started successfully")
	return g.Wait()
}

// handleDNSQuery processes incoming DNS queries, validates client access, and routes the query to appropriate handlers.
// This is the main entry point for all DNS query processing.
func (p *prog) handleDNSQuery(w dns.ResponseWriter, m *dns.Msg, listenerNum string, listenerConfig *ctrld.ListenerConfig) {
	p.sema.acquire()
	defer p.sema.release()

	if len(m.Question) == 0 {
		sendDNSResponse(w, m, dns.RcodeFormatError)
		return
	}

	reqID := requestID()
	ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqID)
	ctx = ctrld.LoggerCtx(ctx, p.logger.Load())

	ctrld.Log(ctx, p.Debug(), "Processing DNS query from %s", w.RemoteAddr().String())

	if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
		ctrld.Log(ctx, p.Debug(), "Query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
		sendDNSResponse(w, m, dns.RcodeRefused)
		return
	}

	go p.detectLoop(m)

	q := m.Question[0]
	domain := canonicalName(q.Name)

	if p.handleSpecialDomains(ctx, w, m, domain) {
		ctrld.Log(ctx, p.Debug(), "Special domain query handled")
		return
	}

	ctrld.Log(ctx, p.Debug(), "Processing standard query for domain: %s", domain)
	p.processStandardQuery(&standardQueryRequest{
		ctx:            ctx,
		writer:         w,
		msg:            m,
		listenerNum:    listenerNum,
		listenerConfig: listenerConfig,
		domain:         domain,
	})
}

// handleSpecialDomains processes special domain queries, handles errors, purges the cache if necessary,
// and reports whether the query was handled.
// This covers internal test domains and cache management commands.
func (p *prog) handleSpecialDomains(ctx context.Context, w dns.ResponseWriter, m *dns.Msg, domain string) bool {
	switch {
	case domain == "":
		ctrld.Log(ctx, p.Debug(), "Empty domain query, sending format error")
		sendDNSResponse(w, m, dns.RcodeFormatError)
		return true
	case domain == selfCheckInternalTestDomain:
		ctrld.Log(ctx, p.Debug(), "Internal test domain query: %s", domain)
		answer := resolveInternalDomainTestQuery(ctx, domain, m)
		_ = w.WriteMsg(answer)
		return true
	}

	if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
		p.cache.Purge()
		ctrld.Log(ctx, p.Debug(), "Received query %q, local cache is purged", domain)
	}

	return false
}

// standardQueryRequest represents a standard DNS query request with associated context and configuration.
// This encapsulates all the data needed to process a standard DNS query.
type standardQueryRequest struct {
	ctx            context.Context
	writer         dns.ResponseWriter
	msg            *dns.Msg
	listenerNum    string
	listenerConfig *ctrld.ListenerConfig
	domain         string
}

// processStandardQuery handles a standard DNS query by routing it through appropriate upstreams and writing a DNS response.
// This is the main processing pipeline for normal DNS queries.
func (p *prog) processStandardQuery(req *standardQueryRequest) {
	ctrld.Log(req.ctx, p.Debug(), "Processing standard query started")

	remoteIP, _, _ := net.SplitHostPort(req.writer.RemoteAddr().String())
	ci := p.getClientInfo(remoteIP, req.msg)
	ci.ClientIDPref = p.cfg.Service.ClientIDPref

	stripClientSubnet(req.msg)
	remoteAddr := spoofRemoteAddr(req.writer.RemoteAddr(), ci)
	fmtSrcToDest := fmtRemoteToLocal(req.listenerNum, ci.Hostname, remoteAddr.String())

	startTime := time.Now()
	q := req.msg.Question[0]
	ctrld.Log(req.ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], req.domain)

	ur := p.upstreamFor(req.ctx, req.listenerNum, req.listenerConfig, remoteAddr, ci.Mac, req.domain)

	var answer *dns.Msg
	// Handle the restricted listener case.
	if !ur.matched && req.listenerConfig.Restricted {
		ctrld.Log(req.ctx, p.Debug(), "Query refused, %s does not match any network policy", remoteAddr.String())
		answer = new(dns.Msg)
		answer.SetRcode(req.msg, dns.RcodeRefused)
		// Post-process the refused query.
		go p.postProcessStandardQuery(ci, req.listenerConfig, q, &proxyResponse{answer: answer, refused: true})
	} else {
		// Process a normal query.
		ctrld.Log(req.ctx, p.Debug(), "Starting proxy query processing")
		pr := p.proxy(req.ctx, &proxyRequest{
			msg:            req.msg,
			ci:             ci,
			failoverRcodes: p.getFailoverRcodes(req.listenerConfig),
			ufr:            ur,
		})

		rtt := time.Since(startTime)
		ctrld.Log(req.ctx, p.Debug(), "Received response of %d bytes in %s", pr.answer.Len(), rtt)

		go p.postProcessStandardQuery(ci, req.listenerConfig, q, pr)
		answer = pr.answer
	}

	if err := req.writer.WriteMsg(answer); err != nil {
		ctrld.Log(req.ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client")
	}

	ctrld.Log(req.ctx, p.Debug(), "Standard query processing completed")
}

// postProcessStandardQuery performs additional actions after processing a standard DNS query, such as metrics recording,
// handling canonical name adjustments, and triggering specific post-query actions like uninstallation procedures.
func (p *prog) postProcessStandardQuery(ci *ctrld.ClientInfo, listenerConfig *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) {
	p.doSelfUninstall(pr)
	p.recordMetrics(ci, listenerConfig, q, pr)
	p.forceFetchingAPI(canonicalName(q.Name))
}

// getFailoverRcodes retrieves the failover response codes from the provided ListenerConfig. Returns nil if no policy exists.
func (p *prog) getFailoverRcodes(cfg *ctrld.ListenerConfig) []int {
	if cfg.Policy != nil {
		return cfg.Policy.FailoverRcodeNumbers
	}
	return nil
}

// recordMetrics updates Prometheus metrics for DNS queries, including query count and client-specific query statistics.
func (p *prog) recordMetrics(ci *ctrld.ClientInfo, cfg *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) {
	upstream := pr.upstream
	switch {
	case pr.cached:
		upstream = "cache"
	case pr.clientInfo:
		upstream = "client_info_table"
	}
	labelValues := []string{
		net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port)),
		ci.IP,
		ci.Mac,
		ci.Hostname,
		upstream,
		dns.TypeToString[q.Qtype],
		dns.RcodeToString[pr.answer.Rcode],
	}
	p.WithLabelValuesInc(statsQueriesCount, labelValues...)
	p.WithLabelValuesInc(statsClientQueriesCount, ci.IP, ci.Mac, ci.Hostname)
}

// sendDNSResponse sends a DNS response with the specified RCODE to the client using the provided ResponseWriter.
func sendDNSResponse(w dns.ResponseWriter, m *dns.Msg, rcode int) {
	answer := new(dns.Msg)
	answer.SetRcode(m, rcode)
	_ = w.WriteMsg(answer)
}

// upstreamFor returns the list of upstreams for resolving the given domain,
// matching by policies defined in the listener config. The result's matched
// field reports whether the request matched a policy.
//
// Though a domain policy has higher priority than a network policy, it is still
// processed later, because policy logging wants to know whether a network rule
// is disregarded in favor of the domain-level rule.
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) {
	upstreams := []string{upstreamPrefix + defaultUpstreamNum}
	matchedPolicy := "no policy"
	matchedNetwork := "no network"
	matchedRule := "no rule"
	matched := false
	res = &upstreamForResult{srcAddr: addr.String()}

	defer func() {
		res.upstreams = upstreams
		res.matched = matched
		res.matchedPolicy = matchedPolicy
		res.matchedNetwork = matchedNetwork
		res.matchedRule = matchedRule
	}()

	if lc.Policy == nil {
		return
	}

	do := func(policyUpstreams []string) {
		upstreams = append([]string(nil), policyUpstreams...)
	}

	var networkTargets []string
	var sourceIP net.IP
	switch addr := addr.(type) {
	case *net.UDPAddr:
		sourceIP = addr.IP
	case *net.TCPAddr:
		sourceIP = addr.IP
	}

networkRules:
	for _, rule := range lc.Policy.Networks {
		for source, targets := range rule {
			networkNum := strings.TrimPrefix(source, "network.")
			nc := p.cfg.Network[networkNum]
			if nc == nil {
				continue
			}
			for _, ipNet := range nc.IPNets {
				if ipNet.Contains(sourceIP) {
					matchedPolicy = lc.Policy.Name
					matchedNetwork = source
					networkTargets = targets
					matched = true
					break networkRules
				}
			}
		}
	}

macRules:
	for _, rule := range lc.Policy.Macs {
		for source, targets := range rule {
			if source != "" && (strings.EqualFold(source, srcMac) || wildcardMatches(strings.ToLower(source), strings.ToLower(srcMac))) {
				matchedPolicy = lc.Policy.Name
				matchedNetwork = source
				networkTargets = targets
				matched = true
				break macRules
			}
		}
	}

	for _, rule := range lc.Policy.Rules {
		// There's only one entry per rule; config validation ensures this.
		for source, targets := range rule {
			if source == domain || wildcardMatches(source, domain) {
				matchedPolicy = lc.Policy.Name
				if len(networkTargets) > 0 {
					matchedNetwork += " (unenforced)"
				}
				matchedRule = source
				do(targets)
				matched = true
				return
			}
		}
	}

	if matched {
		do(networkTargets)
	}

	return
}

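// To illustrate the precedence above with hypothetical values (the network
// range, rule pattern, and query below are illustrative, not taken from any
// real config): suppose the listener policy has a network rule mapping
// "network.0" (say, 192.168.1.0/24) to one set of upstreams and a domain rule
// mapping "*.example.com" to another. A query for www.example.com arriving
// from 192.168.1.5 matches both; the domain rule's upstreams win, and the
// result reports the disregarded network rule:
//
//	matchedNetwork = "network.0 (unenforced)"
//	matchedRule    = "*.example.com"
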
// proxyPrivatePtrLookup performs a private PTR DNS lookup based on the client info table for the given query.
// It prevents DNS loops by locking the processing of the same domain name simultaneously.
// If a valid IP-to-hostname mapping exists, it creates a PTR DNS record as the response.
// Returns the DNS response if a hostname is found, or nil otherwise.
func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg {
	cDomainName := msg.Question[0].Name
	locked := p.ptrLoopGuard.TryLock(cDomainName)
	defer p.ptrLoopGuard.Unlock(cDomainName)
	if !locked {
		return nil
	}
	ip := ipFromARPA(cDomainName)
	if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" {
		answer := new(dns.Msg)
		answer.SetReply(msg)
		answer.Compress = true
		answer.Answer = []dns.RR{&dns.PTR{
			Hdr: dns.RR_Header{
				Name:   msg.Question[0].Name,
				Rrtype: dns.TypePTR,
				Class:  dns.ClassINET,
			},
			Ptr: dns.Fqdn(name),
		}}
		ctrld.Log(ctx, p.Info(), "Private PTR lookup, using client info table")
		ctrld.Log(ctx, p.Debug(), "Client info: %v", ctrld.ClientInfo{
			Mac:      p.ciTable.LookupMac(ip.String()),
			IP:       ip.String(),
			Hostname: name,
		})
		return answer
	}
	return nil
}

// proxyLanHostnameQuery resolves LAN hostnames to their corresponding IP addresses based on the dns.Msg request.
// It uses a loop guard mechanism to prevent DNS query loops and ensures a hostname is processed only once at a time.
// This method queries the client info table for the hostname's IP address and logs relevant debug and client info.
// If the hostname matches known IPs in the table, it generates an appropriate dns.Msg response; otherwise, it returns nil.
func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg {
	q := msg.Question[0]
	hostname := strings.TrimSuffix(q.Name, ".")
	locked := p.lanLoopGuard.TryLock(hostname)
	defer p.lanLoopGuard.Unlock(hostname)
	if !locked {
		return nil
	}
	if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil {
		answer := new(dns.Msg)
		answer.SetReply(msg)
		answer.Compress = true
		switch {
		case ip.Is4():
			answer.Answer = []dns.RR{&dns.A{
				Hdr: dns.RR_Header{
					Name:   msg.Question[0].Name,
					Rrtype: dns.TypeA,
					Class:  dns.ClassINET,
					Ttl:    uint32(localTTL.Seconds()),
				},
				A: ip.AsSlice(),
			}}
		case ip.Is6():
			answer.Answer = []dns.RR{&dns.AAAA{
				Hdr: dns.RR_Header{
					Name:   msg.Question[0].Name,
					Rrtype: dns.TypeAAAA,
					Class:  dns.ClassINET,
					Ttl:    uint32(localTTL.Seconds()),
				},
				AAAA: ip.AsSlice(),
			}}
		}
		ctrld.Log(ctx, p.Info(), "Lan hostname lookup, using client info table")
		ctrld.Log(ctx, p.Debug(), "Client info: %v", ctrld.ClientInfo{
			Mac:      p.ciTable.LookupMac(ip.String()),
			IP:       ip.String(),
			Hostname: hostname,
		})
		return answer
	}
	return nil
}

// handleSpecialQueryTypes processes specific types of DNS queries such as SRV, PTR, and LAN hostname lookups.
// It modifies upstreams and upstreamConfigs based on the query type and updates the query context accordingly.
// Returns a proxyResponse if the query is resolved locally; otherwise, returns nil to proceed with upstream processing.
func (p *prog) handleSpecialQueryTypes(ctx *context.Context, req *proxyRequest, upstreams *[]string, upstreamConfigs *[]*ctrld.UpstreamConfig) *proxyResponse {
	if req.ufr.matched {
		ctrld.Log(*ctx, p.Debug(), "%s, %s, %s -> %v",
			req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, *upstreams)
		return nil
	}

	switch {
	case isSrvLanLookup(req.msg):
		*upstreams = []string{upstreamOS}
		*upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
		*ctx = ctrld.LanQueryCtx(*ctx)
		ctrld.Log(*ctx, p.Debug(), "SRV record lookup, using upstreams: %v", *upstreams)
		return nil
	case isPrivatePtrLookup(req.msg):
		req.isLanOrPtrQuery = true
		if answer := p.proxyPrivatePtrLookup(*ctx, req.msg); answer != nil {
			return &proxyResponse{answer: answer, clientInfo: true}
		}
		*upstreams, *upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(*upstreams, *upstreamConfigs)
		*ctx = ctrld.LanQueryCtx(*ctx)
		ctrld.Log(*ctx, p.Debug(), "Private PTR lookup, using upstreams: %v", *upstreams)
		return nil
	case isLanHostnameQuery(req.msg):
		req.isLanOrPtrQuery = true
		if answer := p.proxyLanHostnameQuery(*ctx, req.msg); answer != nil {
			return &proxyResponse{answer: answer, clientInfo: true}
		}
		*upstreams = []string{upstreamOS}
		*upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
		*ctx = ctrld.LanQueryCtx(*ctx)
		ctrld.Log(*ctx, p.Debug(), "Lan hostname lookup, using upstreams: %v", *upstreams)
		return nil
	default:
		ctrld.Log(*ctx, p.Debug(), "No explicit policy matched, using default routing -> %v", *upstreams)
		return nil
	}
}

// proxy handles DNS query proxying by selecting upstreams, attempting cache lookups, and querying configured resolvers.
func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
	ctrld.Log(ctx, p.Debug(), "Proxy query processing started")

	upstreams, upstreamConfigs := p.initializeUpstreams(req)
	ctrld.Log(ctx, p.Debug(), "Initialized upstreams: %v", upstreams)

	if specialRes := p.handleSpecialQueryTypes(&ctx, req, &upstreams, &upstreamConfigs); specialRes != nil {
		ctrld.Log(ctx, p.Debug(), "Special query type handled")
		return specialRes
	}

	if cachedRes := p.tryCache(ctx, req, upstreams); cachedRes != nil {
		ctrld.Log(ctx, p.Debug(), "Cache hit, returning cached response")
		return cachedRes
	}

	ctrld.Log(ctx, p.Debug(), "No cache hit, trying upstreams")
	if res := p.tryUpstreams(ctx, req, upstreams, upstreamConfigs); res != nil {
		ctrld.Log(ctx, p.Debug(), "Upstream query successful")
		return res
	}

	ctrld.Log(ctx, p.Debug(), "All upstreams failed, handling failure")
	return p.handleAllUpstreamsFailure(ctx, req, upstreams)
}

// initializeUpstreams determines which upstreams and configurations to use for a given proxyRequest.
// If no upstreams are configured, it defaults to the operating system's resolver configuration.
// Returns a slice of upstream names and their corresponding configurations.
func (p *prog) initializeUpstreams(req *proxyRequest) ([]string, []*ctrld.UpstreamConfig) {
	upstreams := req.ufr.upstreams
	upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
	if len(upstreamConfigs) == 0 {
		return []string{upstreamOS}, []*ctrld.UpstreamConfig{osUpstreamConfig}
	}
	return upstreams, upstreamConfigs
}

// tryCache attempts to retrieve a cached response for the given DNS request from the specified upstreams.
// Returns a proxyResponse if a cache hit occurs; otherwise, returns nil.
// Skips cache checking if caching is disabled or the request is a PTR query.
// Iterates through the provided upstreams to find a cached response using the checkCache method.
func (p *prog) tryCache(ctx context.Context, req *proxyRequest, upstreams []string) *proxyResponse {
	if p.cache == nil || req.msg.Question[0].Qtype == dns.TypePTR { // https://www.rfc-editor.org/rfc/rfc1035#section-7.4
		ctrld.Log(ctx, p.Debug(), "Cache disabled or PTR query, skipping cache lookup")
		return nil
	}

	ctrld.Log(ctx, p.Debug(), "Checking cache for upstreams: %v", upstreams)
	for _, upstream := range upstreams {
		if res := p.checkCache(ctx, req, upstream); res != nil {
			ctrld.Log(ctx, p.Debug(), "Cache hit found for upstream: %s", upstream)
			return res
		}
	}

	ctrld.Log(ctx, p.Debug(), "No cache hit found")
	return nil
}

// checkCache checks if a cached DNS response exists for the given request and upstream.
// Returns a proxyResponse with the cached response if found and valid, or nil otherwise.
func (p *prog) checkCache(ctx context.Context, req *proxyRequest, upstream string) *proxyResponse {
	cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream))
	if cachedValue == nil {
		ctrld.Log(ctx, p.Debug(), "No cached value found for upstream: %s", upstream)
		return nil
	}

	answer := cachedValue.Msg.Copy()
	ctrld.SetCacheReply(answer, req.msg, answer.Rcode)
	now := time.Now()

	if cachedValue.Expire.After(now) {
		ctrld.Log(ctx, p.Debug(), "Hit cached response")
		setCachedAnswerTTL(answer, now, cachedValue.Expire)
		return &proxyResponse{answer: answer, cached: true}
	}

	ctrld.Log(ctx, p.Debug(), "Cached response expired, storing as stale")
	req.staleAnswer = answer
	return nil
}

// updateCache updates the DNS response cache with the given request, response, TTL, and upstream information.
func (p *prog) updateCache(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string) {
	ttl := ttlFromMsg(answer)
	now := time.Now()
	expired := now.Add(time.Duration(ttl) * time.Second)
	if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 {
		expired = now.Add(time.Duration(cachedTTL) * time.Second)
	}
	setCachedAnswerTTL(answer, now, expired)
	p.cache.Add(dnscache.NewKey(req.msg, upstream), dnscache.NewValue(answer, expired))
	ctrld.Log(ctx, p.Debug(), "Added cached response")
}

// serveStaleResponse serves a stale cached DNS response when an upstream query fails, updating the TTL of cached records.
func (p *prog) serveStaleResponse(ctx context.Context, staleAnswer *dns.Msg) *proxyResponse {
	ctrld.Log(ctx, p.Debug(), "Serving stale cached response")
	now := time.Now()
	setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL))
	return &proxyResponse{answer: staleAnswer, cached: true}
}

// handleAllUpstreamsFailure handles the failure scenario when all upstream resolvers fail to respond or process the request.
func (p *prog) handleAllUpstreamsFailure(ctx context.Context, req *proxyRequest, upstreams []string) *proxyResponse {
	ctrld.Log(ctx, p.Error(), "All %v endpoints failed", upstreams)

	if p.leakOnUpstreamFailure() {
		ctrld.Log(ctx, p.Debug(), "Leak on upstream failure enabled")
		if p.um.countHealthy(upstreams) == 0 {
			ctrld.Log(ctx, p.Debug(), "No healthy upstreams, triggering recovery")
			p.triggerRecovery(upstreams[0] == upstreamOS)
		} else {
			ctrld.Log(ctx, p.Debug(), "One upstream is down but at least one is healthy; skipping recovery trigger")
		}

		if upstreams[0] != upstreamOS {
			ctrld.Log(ctx, p.Debug(), "Trying OS resolver as fallback")
			if answer := p.tryOSResolver(ctx, req); answer != nil {
				ctrld.Log(ctx, p.Debug(), "OS resolver fallback successful")
				return answer
			}
		}
	}

	ctrld.Log(ctx, p.Debug(), "Returning server failure response")
	answer := new(dns.Msg)
	answer.SetRcode(req.msg, dns.RcodeServerFailure)
	return &proxyResponse{answer: answer}
}

// shouldContinueWithNextUpstream determines whether processing should continue with the next upstream based on response conditions.
func (p *prog) shouldContinueWithNextUpstream(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string, lastUpstream bool) bool {
	if answer.Rcode == dns.RcodeSuccess {
		ctrld.Log(ctx, p.Debug(), "Successful response, not continuing to next upstream")
		return false
	}

	// We are doing a LAN/PTR lookup using the private resolver, so always process the next one.
	// Except for the last: there, we want to send a response instead of reporting that all upstreams failed.
	if req.isLanOrPtrQuery && !lastUpstream {
		ctrld.Log(ctx, p.Debug(), "No response for LAN/PTR query from %s, proceeding to next upstream", upstream)
		return true
	}

	if len(req.upstreamConfigs) > 1 && slices.Contains(req.failoverRcodes, answer.Rcode) {
		ctrld.Log(ctx, p.Debug(), "Failover rcode matched, proceeding to next upstream")
		return true
	}

	ctrld.Log(ctx, p.Debug(), "Not continuing to next upstream")
	return false
}

// prepareSuccessResponse prepares a successful DNS response for a given request, logs it, and updates the cache if applicable.
func (p *prog) prepareSuccessResponse(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string, upstreamConfig *ctrld.UpstreamConfig) *proxyResponse {
	ctrld.Log(ctx, p.Debug(), "Preparing success response")

	answer.Compress = true

	if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
		ctrld.Log(ctx, p.Debug(), "Updating cache with successful response")
		p.updateCache(ctx, req, answer, upstream)
	}

	hostname := ""
	if req.ci != nil {
		hostname = req.ci.Hostname
	}

	ctrld.Log(ctx, p.Info(), "REPLY: %s -> %s (%s): %s",
		upstream, req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode])

	return &proxyResponse{
		answer:   answer,
		upstream: upstreamConfig.Endpoint,
	}
}

// tryUpstreams attempts to proxy a DNS request through the provided upstreams and their configurations sequentially.
// It returns a successful proxyResponse if any upstream processes the request successfully, or nil otherwise.
// The function supports "serve stale" for the cache by utilizing cached responses when upstreams fail.
func (p *prog) tryUpstreams(ctx context.Context, req *proxyRequest, upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) *proxyResponse {
	serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
	req.upstreamConfigs = upstreamConfigs

	ctrld.Log(ctx, p.Debug(), "Trying %d upstreams", len(upstreamConfigs))

	for n, upstreamConfig := range upstreamConfigs {
		last := n == len(upstreamConfigs)-1
		ctrld.Log(ctx, p.Debug(), "Processing upstream %d/%d: %s", n+1, len(upstreamConfigs), upstreams[n])

		if res := p.processUpstream(ctx, req, upstreams[n], upstreamConfig, serveStaleCache, last); res != nil {
			ctrld.Log(ctx, p.Debug(), "Upstream %s succeeded", upstreams[n])
			return res
		}

		ctrld.Log(ctx, p.Debug(), "Upstream %s failed", upstreams[n])
	}

	ctrld.Log(ctx, p.Debug(), "All upstreams failed")
	return nil
}

// processUpstream proxies a DNS query to a given upstream server and processes the response based on the provided configuration.
// It supports serving stale cache when upstream queries fail, and checks if processing should continue to another upstream.
// Returns a proxyResponse on success or nil if the upstream query fails or processing conditions are not met.
func (p *prog) processUpstream(ctx context.Context, req *proxyRequest, upstream string, upstreamConfig *ctrld.UpstreamConfig, serveStaleCache, lastUpstream bool) *proxyResponse {
	if upstreamConfig == nil {
		ctrld.Log(ctx, p.Debug(), "Upstream config is nil, skipping")
		return nil
	}
	if p.isLoop(upstreamConfig) {
		logger := p.Debug().
			Str("upstream", upstreamConfig.String()).
			Str("query", req.msg.Question[0].Name).
			Bool("is_lan_query", req.isLanOrPtrQuery)
		ctrld.Log(ctx, logger, "DNS loop detected")
		return nil
	}

	ctrld.Log(ctx, p.Debug(), "Querying upstream: %s", upstream)
	answer := p.queryUpstream(ctx, req, upstream, upstreamConfig)
	if answer == nil {
		ctrld.Log(ctx, p.Debug(), "Upstream query failed")
		if serveStaleCache && req.staleAnswer != nil {
			ctrld.Log(ctx, p.Debug(), "Serving stale response due to upstream failure")
			return p.serveStaleResponse(ctx, req.staleAnswer)
		}
		return nil
	}

	ctrld.Log(ctx, p.Debug(), "Upstream query successful")
	if p.shouldContinueWithNextUpstream(ctx, req, answer, upstream, lastUpstream) {
		return nil
	}
	return p.prepareSuccessResponse(ctx, req, answer, upstream, upstreamConfig)
}

// queryUpstream sends a DNS query to a specified upstream using its configuration and handles errors and retries.
func (p *prog) queryUpstream(ctx context.Context, req *proxyRequest, upstream string, upstreamConfig *ctrld.UpstreamConfig) *dns.Msg {
	if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
		ctrld.Log(ctx, p.Debug(), "Adding client info to upstream query")
		ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
	}

	ctrld.Log(ctx, p.Debug(), "Sending query to %s: %s", upstream, upstreamConfig.Name)
	dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig)
	if err != nil {
		ctrld.Log(ctx, p.Error().Err(err), "Failed to create resolver")
		return nil
	}

	resolveCtx, cancel := upstreamConfig.Context(ctx)
	defer cancel()

	ctrld.Log(ctx, p.Debug(), "Resolving query with upstream")
	answer, err := dnsResolver.Resolve(resolveCtx, req.msg)
	if answer != nil {
		ctrld.Log(ctx, p.Debug(), "Upstream resolution successful")
		p.um.mu.Lock()
		p.um.failureReq[upstream] = 0
		p.um.down[upstream] = false
		p.um.mu.Unlock()
		return answer
	}

	ctrld.Log(ctx, p.Error().Err(err), "Failed to resolve query")
	// Increase the failure count when there is no answer, regardless of what kind of error we got.
	p.um.increaseFailureCount(upstream)
	if err != nil {
		// For a timeout error (i.e., context deadline exceeded), force re-bootstrapping.
		var e net.Error
		if errors.As(err, &e) && e.Timeout() {
			ctrld.Log(ctx, p.Debug(), "Timeout error, forcing re-bootstrapping")
			upstreamConfig.ReBootstrap(ctx)
		}
		// For a network error, turn IPv6 off if it is enabled.
		if ctrld.HasIPv6(ctx) && (errUrlNetworkError(err) || errNetworkError(err)) {
			ctrld.Log(ctx, p.Debug(), "Network error, disabling IPv6")
			ctrld.DisableIPv6(ctx)
		}
	}
	return nil
}

// triggerRecovery attempts to initiate a recovery process if no healthy upstreams are detected.
// If "isOSFailure" is true, the recovery will account for an operating system failure.
// Logs are generated to indicate whether recovery is triggered or already in progress.
func (p *prog) triggerRecovery(isOSFailure bool) {
	p.recoveryCancelMu.Lock()
	defer p.recoveryCancelMu.Unlock()

	if p.recoveryCancel == nil {
		var reason RecoveryReason
		if isOSFailure {
			reason = RecoveryReasonOSFailure
		} else {
			reason = RecoveryReasonRegularFailure
		}
		p.Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
		go p.handleRecovery(reason)
	} else {
		p.Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
	}
}

// tryOSResolver attempts to query the OS resolver as a fallback mechanism when other upstreams fail.
// Logs success or failure of the query attempt and returns a proxyResponse or nil based on the query result.
func (p *prog) tryOSResolver(ctx context.Context, req *proxyRequest) *proxyResponse {
	ctrld.Log(ctx, p.Debug(), "Attempting query to OS resolver as a retry catch all")
	answer := p.queryUpstream(ctx, req, upstreamOS, osUpstreamConfig)
	if answer != nil {
		ctrld.Log(ctx, p.Debug(), "OS resolver retry query successful")
		return &proxyResponse{answer: answer, upstream: osUpstreamConfig.Endpoint}
	}
	ctrld.Log(ctx, p.Debug(), "OS resolver retry query failed")
	return nil
}

// upstreamsAndUpstreamConfigForPtr returns the updated upstreams and upstreamConfigs for a private PTR lookup scenario.
func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
	if len(p.localUpstreams) > 0 {
		tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
		tmp = append(tmp, p.localUpstreams...)
		tmp = append(tmp, upstreams...)
		return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp)
	}
	return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...)
}

// upstreamConfigsFromUpstreamNumbers converts a list of upstream names into their corresponding UpstreamConfig objects.
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
	upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
	for _, upstream := range upstreams {
		upstreamNum := strings.TrimPrefix(upstream, upstreamPrefix)
		upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
	}
	return upstreamConfigs
}

// canonicalName returns the canonical name from an FQDN, with the trailing "." trimmed.
func canonicalName(fqdn string) string {
	q := strings.TrimSpace(fqdn)
	q = strings.TrimSuffix(q, ".")
	// https://datatracker.ietf.org/doc/html/rfc4343
	q = strings.ToLower(q)

	return q
}

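// A quick illustration of the normalization above (inputs are hypothetical):
//
//	canonicalName("WWW.Example.COM.") // "www.example.com"
//	canonicalName(" host.lan. ")      // "host.lan"
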
// wildcardMatches reports whether str matches the wildcard pattern in a case-insensitive manner.
func wildcardMatches(wildcard, str string) bool {
	// Wildcard match.
	wildCardParts := strings.Split(strings.ToLower(wildcard), "*")
	if len(wildCardParts) != 2 {
		return false
	}

	str = strings.ToLower(str)
	switch {
	case len(wildCardParts[0]) > 0 && len(wildCardParts[1]) > 0:
		// The domain must match both prefix and suffix.
		return strings.HasPrefix(str, wildCardParts[0]) && strings.HasSuffix(str, wildCardParts[1])

	case len(wildCardParts[1]) > 0:
		// Only the suffix must match.
		return strings.HasSuffix(str, wildCardParts[1])

	case len(wildCardParts[0]) > 0:
		// Only the prefix must match.
		return strings.HasPrefix(str, wildCardParts[0])
	}

	return false
}

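// Some illustrative cases (hypothetical patterns, grounded in the logic above;
// note that a pattern must contain exactly one "*"):
//
//	wildcardMatches("*.example.com", "foo.example.com") // true: suffix match
//	wildcardMatches("api.*", "api.internal")            // true: prefix match
//	wildcardMatches("a*z", "abcz")                      // true: prefix and suffix match
//	wildcardMatches("example.com", "example.com")       // false: no "*" in the pattern
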
// fmtRemoteToLocal formats a remote address to indicate its mapping to a local listener, using the listener number and hostname.
func fmtRemoteToLocal(listenerNum, hostname, remote string) string {
	return fmt.Sprintf("%s (%s) -> listener.%s", remote, hostname, listenerNum)
}

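// For example (hypothetical values):
//
//	fmtRemoteToLocal("0", "laptop", "192.168.1.5:54321")
//	// "192.168.1.5:54321 (laptop) -> listener.0"
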
// requestID generates a random 6-character hexadecimal string to uniquely identify a request. It panics on error.
func requestID() string {
	b := make([]byte, 3) // 6 hex chars
	if _, err := rand.Read(b); err != nil {
		panic(err)
	}
	return hex.EncodeToString(b)
}

// setCachedAnswerTTL updates the TTL of each DNS record in the provided message based on the current and expiration times.
func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) {
	ttlSecs := expiredTime.Sub(now).Seconds()
	if ttlSecs < 0 {
		return
	}

	ttl := uint32(ttlSecs)
	for _, rr := range answer.Answer {
		rr.Header().Ttl = ttl
	}
	for _, rr := range answer.Ns {
		rr.Header().Ttl = ttl
	}
	for _, rr := range answer.Extra {
		if rr.Header().Rrtype != dns.TypeOPT {
			rr.Header().Ttl = ttl
		}
	}
}

// ttlFromMsg returns the TTL of the first record in the Answer or Ns section of a DNS message.
// If no records exist in either section, it returns 0.
func ttlFromMsg(msg *dns.Msg) uint32 {
	if len(msg.Answer) > 0 {
		return msg.Answer[0].Header().Ttl
	}
	if len(msg.Ns) > 0 {
		return msg.Ns[0].Header().Ttl
	}
	return 0
}

// needLocalIPv6Listener reports whether a local IPv6 listener is required, by verifying IPv6 support and that the OS is Windows.
func needLocalIPv6Listener() bool {
	// On Windows, there's no easy way to disable/remove the IPv6 DNS resolver, so we check whether we can
	// listen on ::1, then spawn a listener for receiving DNS requests.
	return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows"
}

// ipAndMacFromMsg extracts the IP and MAC information included in a DNS message, if any.
func ipAndMacFromMsg(msg *dns.Msg) (string, string) {
	ip, mac := "", ""
	if opt := msg.IsEdns0(); opt != nil {
		for _, s := range opt.Option {
			switch e := s.(type) {
			case *dns.EDNS0_LOCAL:
				if e.Code == EDNS0_OPTION_MAC {
					mac = net.HardwareAddr(e.Data).String()
				}
			case *dns.EDNS0_SUBNET:
				if len(e.Address) > 0 && !e.Address.IsLoopback() {
					ip = e.Address.String()
				}
			}
		}
	}
	return ip, mac
}

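// As a rough sketch of the sending side (illustrative only; macBytes is a
// hypothetical 6-byte MAC): a dnsmasq-style client attaches the MAC as an
// EDNS0 local option, which the loop above then decodes:
//
//	opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
//	opt.Option = append(opt.Option, &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: macBytes})
//	msg.Extra = append(msg.Extra, opt)
//
// ipAndMacFromMsg would then report the MAC as net.HardwareAddr(macBytes).String().
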
// stripClientSubnet removes the EDNS0 client subnet option from a DNS message if the IP is an RFC 1918 or loopback address;
// passing these to an upstream is pointless, since they cannot be used by anything on the WAN.
func stripClientSubnet(msg *dns.Msg) {
	if opt := msg.IsEdns0(); opt != nil {
		opts := make([]dns.EDNS0, 0, len(opt.Option))
		for _, s := range opt.Option {
			if e, ok := s.(*dns.EDNS0_SUBNET); ok && (e.Address.IsPrivate() || e.Address.IsLoopback()) {
				continue
			}
			opts = append(opts, s)
		}
		if len(opts) != len(opt.Option) {
			opt.Option = opts
		}
	}
}

// spoofRemoteAddr returns addr with its IP replaced by the client info IP, preserving the original port and zone.
// If no client info IP is available, addr is returned unchanged.
func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr {
	if ci != nil && ci.IP != "" {
		switch addr := addr.(type) {
		case *net.UDPAddr:
			return &net.UDPAddr{
				IP:   net.ParseIP(ci.IP),
				Port: addr.Port,
				Zone: addr.Zone,
			}
		case *net.TCPAddr:
			return &net.TCPAddr{
				IP:   net.ParseIP(ci.IP),
				Port: addr.Port,
				Zone: addr.Zone,
			}
		}
	}
	return addr
}

// runDNSServer starts a DNS server for the given address and network,
// with the given handler. It ensures the server has started listening.
// Any error will be reported to the caller via the returned channel.
//
// It's the caller's responsibility to call Shutdown to close the server.
func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-chan error) {
	mainLog.Load().Debug().Str("address", addr).Str("network", network).Msg("Starting DNS server")

	s := &dns.Server{
		Addr:    addr,
		Net:     network,
		Handler: handler,
	}

	startedCh := make(chan struct{})
	// Wrap the close in sync.OnceFunc once, so calling NotifyStartedFunc more
	// than once (server start plus the error path below) cannot double-close.
	s.NotifyStartedFunc = sync.OnceFunc(func() { close(startedCh) })

	errCh := make(chan error)
	go func() {
		defer close(errCh)
		if err := s.ListenAndServe(); err != nil {
			s.NotifyStartedFunc()
			mainLog.Load().Error().Err(err).Msgf("Could not listen and serve on: %s", s.Addr)
			errCh <- err
		}
	}()
	<-startedCh
	mainLog.Load().Debug().Str("address", addr).Str("network", network).Msg("DNS server started successfully")
	return s, errCh
}

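// Typical usage, mirroring startListeners above (sketch; the address and
// stopCh are illustrative):
//
//	s, errCh := runDNSServer("127.0.0.1:5353", "udp", handler)
//	defer s.Shutdown()
//	select {
//	case <-stopCh:
//	case err := <-errCh:
//		// The server failed to listen and serve.
//	}
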
// getClientInfo derives the client info (IP, MAC, hostname) for a query, using the app callback
// when present, EDNS0 data from the message, and the client info table as fallbacks.
func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo {
	ci := &ctrld.ClientInfo{}
	if p.appCallback != nil {
		ci.IP = p.appCallback.LanIp()
		ci.Mac = p.appCallback.MacAddress()
		ci.Hostname = p.appCallback.HostName()
		ci.Self = true
		return ci
	}
	ci.IP, ci.Mac = ipAndMacFromMsg(msg)
	switch {
	case ci.IP != "" && ci.Mac != "":
		// Nothing to do.
	case ci.IP == "" && ci.Mac != "":
		// Have MAC, no IP.
		ci.IP = p.ciTable.LookupIP(ci.Mac)
	case ci.IP == "" && ci.Mac == "":
		// Have nothing; use the remote IP, then look up the MAC.
		ci.IP = remoteIP
		fallthrough
	case ci.IP != "" && ci.Mac == "":
		// Have IP, no MAC.
		ci.Mac = p.ciTable.LookupMac(ci.IP)
	}

	// If the MAC is still empty here, that means the request was made from a virtual interface,
	// like VPN/WireGuard clients, so we use ci.IP as the hostname to distinguish those clients.
	if ci.Mac == "" {
		if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" {
			ci.Hostname = hostname
		} else {
			// Only use the IP as the hostname for IPv4 clients.
			// When an Android device joins the network, it uses ctrld to resolve
			// its private DNS once and never reaches ctrld again. Each time, it uses
			// a different IPv6 address, which would create hundreds or thousands of
			// different client IDs for the same device, which is pointless.
			//
			// TODO(cuonglm): investigate whether this can be a false positive for other clients?
			if !ctrldnet.IsIPv6(ci.IP) {
				ci.Hostname = ci.IP
				p.ciTable.StoreVPNClient(ci)
			}
		}
	} else {
		ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac)
	}
	ci.Self = p.queryFromSelf(ci.IP)
	// If this is a query from self but ci.IP is not a loopback IP,
	// try using the hostname mapping for the loopback IP if present.
	if ci.Self {
		if name := p.ciTable.LocalHostname(); name != "" {
			ci.Hostname = name
		}
	}
	p.spoofLoopbackIpInClientInfo(ci)
	return ci
}

// spoofLoopbackIpInClientInfo replaces loopback IPs in client info.
//
//   - Prefers IPv4.
//   - Prefers RFC 1918 addresses.
func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
	if ip := net.ParseIP(ci.IP); ip == nil || !ip.IsLoopback() {
		return
	}
	if ip := p.ciTable.LookupRFC1918IPv4(ci.Mac); ip != "" {
		ci.IP = ip
	}
}

// doSelfUninstall performs self-uninstallation if all of these conditions are met:
//
//   - There is only one ControlD upstream in use.
//   - The number of REFUSED queries seen so far exceeds selfUninstallMaxQueries.
//   - The cdUID is deleted.
func (p *prog) doSelfUninstall(pr *proxyResponse) {
	answer := pr.answer
	if pr.refused || !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused {
		return
	}

	p.selfUninstallMu.Lock()
	defer p.selfUninstallMu.Unlock()
	if p.checkingSelfUninstall {
		return
	}

	logger := p.logger.Load().With().Str("mode", "self-uninstall")
	if p.refusedQueryCount > selfUninstallMaxQueries {
		p.checkingSelfUninstall = true
		loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
		_, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev)
		logger.Debug().Msg("Maximum number of refused queries reached, checking device status")
		selfUninstallCheck(err, p, logger)

		if err != nil {
			logger.Warn().Err(err).Msg("Could not fetch resolver config")
		}
		// Cool-off period to prevent abusing the API.
		go p.selfUninstallCoolOfPeriod()
		return
	}
	p.refusedQueryCount++
}

// selfUninstallCoolOfPeriod waits for 30 minutes before
// calling the API again to check the ControlD device status.
func (p *prog) selfUninstallCoolOfPeriod() {
	t := time.NewTimer(time.Minute * 30)
	defer t.Stop()
	<-t.C
	p.selfUninstallMu.Lock()
	p.checkingSelfUninstall = false
	p.refusedQueryCount = 0
	p.selfUninstallMu.Unlock()
}

// forceFetchingAPI sends a signal to force re-syncing the API config when running in cd mode
// and the domain is "<cdUID>.verify.controld.com".
func (p *prog) forceFetchingAPI(domain string) {
	if cdUID == "" {
		return
	}
	resolverID, parent, _ := strings.Cut(domain, ".")
	if resolverID != cdUID {
		return
	}
	switch {
	case cdDev && parent == "verify.controld.dev":
		// Matches ControlD dev.
	case parent == "verify.controld.com":
		// Matches ControlD.
	default:
		return
	}
	_ = p.apiForceReloadGroup.DoChan("force_sync_api", func() (interface{}, error) {
		p.apiForceReloadCh <- struct{}{}
		// Wait here to prevent abusing the API if we are flooded.
		time.Sleep(timeDurationOrDefault(p.cfg.Service.ForceRefetchWaitTime, 30) * time.Second)
		return nil, nil
	})
}

// timeDurationOrDefault returns the time duration value from n if n is non-nil and positive.
// Otherwise, it returns the time duration value of defaultN.
func timeDurationOrDefault(n *int, defaultN int) time.Duration {
	if n != nil && *n > 0 {
		return time.Duration(*n)
	}
	return time.Duration(defaultN)
}

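// Note that the returned value is a bare count, not a number of nanoseconds;
// callers multiply it by a unit, as forceFetchingAPI does above:
//
//	timeDurationOrDefault(nil, 30) * time.Second // 30s
//	n := 5
//	timeDurationOrDefault(&n, 30) * time.Second  // 5s
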
// queryFromSelf reports whether the input IP belongs to the device running ctrld.
func (p *prog) queryFromSelf(ip string) bool {
	if val, ok := p.queryFromSelfMap.Load(ip); ok {
		return val.(bool)
	}
	netIP := netip.MustParseAddr(ip)
	regularIPs, loopbackIPs, err := netmon.LocalAddresses()
	if err != nil {
		p.Warn().Err(err).Msg("Could not get local addresses")
		return false
	}
	for _, localIP := range slices.Concat(regularIPs, loopbackIPs) {
		if localIP.Compare(netIP) == 0 {
			p.queryFromSelfMap.Store(ip, true)
			return true
		}
	}
	p.queryFromSelfMap.Store(ip, false)
	return false
}

// needRFC1918Listeners reports whether ctrld needs to spawn listeners for RFC 1918 addresses.
// This is helpful on non-desktop platforms for receiving queries from LAN clients.
func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool {
	return rfc1918 && lc.IP == "127.0.0.1" && lc.Port == 53
}

// ipFromARPA parses an FQDN arpa domain and returns the IP address if valid, or nil otherwise.
func ipFromARPA(arpa string) net.IP {
	if arpa, ok := strings.CutSuffix(arpa, ".in-addr.arpa."); ok {
		if ptrIP := net.ParseIP(arpa); ptrIP != nil {
			return net.IP{ptrIP[15], ptrIP[14], ptrIP[13], ptrIP[12]}
		}
	}
	if arpa, ok := strings.CutSuffix(arpa, ".ip6.arpa."); ok {
		l := net.IPv6len * 2
		base := 16
		ip := make(net.IP, net.IPv6len)
		for i := 0; i < l && arpa != ""; i++ {
			idx := strings.LastIndexByte(arpa, '.')
			off := idx + 1
			if idx == -1 {
				idx = 0
				off = 0
			} else if idx == len(arpa)-1 {
				return nil
			}
			n, err := strconv.ParseUint(arpa[off:], base, 8)
			if err != nil {
				return nil
			}
			b := byte(n)
			ii := i / 2
			if i&1 == 1 {
				b |= ip[ii] << 4
			}
			ip[ii] = b
			arpa = arpa[:idx]
		}
		return ip
	}
	return nil
}

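// For instance (illustrative inputs):
//
//	ipFromARPA("4.3.2.1.in-addr.arpa.") // 1.2.3.4 (in-addr.arpa labels are in reverse octet order)
//	ipFromARPA("example.com.")          // nil (not an arpa domain)
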
// isPrivatePtrLookup reports whether the DNS message is a PTR query for a LAN/CGNAT network.
func isPrivatePtrLookup(m *dns.Msg) bool {
	if m == nil || len(m.Question) == 0 {
		return false
	}
	q := m.Question[0]
	if ip := ipFromARPA(q.Name); ip != nil {
		if addr, ok := netip.AddrFromSlice(ip); ok {
			return addr.IsPrivate() ||
				addr.IsLoopback() ||
				addr.IsLinkLocalUnicast() ||
				tsaddr.CGNATRange().Contains(addr)
		}
	}
	return false
}

// isLanHostnameQuery reports whether the DNS message is an A/AAAA query for a LAN hostname.
func isLanHostnameQuery(m *dns.Msg) bool {
	if m == nil || len(m.Question) == 0 {
		return false
	}
	q := m.Question[0]
	switch q.Qtype {
	case dns.TypeA, dns.TypeAAAA:
	default:
		return false
	}
	return isLanHostname(q.Name)
}

// isSrvLanLookup reports whether the DNS message is an SRV query for a LAN hostname.
func isSrvLanLookup(m *dns.Msg) bool {
	if m == nil || len(m.Question) == 0 {
		return false
	}
	q := m.Question[0]
	return q.Qtype == dns.TypeSRV && isLanHostname(q.Name)
}

// isLanHostname reports whether name is a LAN hostname.
func isLanHostname(name string) bool {
	name = strings.TrimSuffix(name, ".")
	return !strings.Contains(name, ".") ||
		strings.HasSuffix(name, ".domain") ||
		strings.HasSuffix(name, ".lan") ||
		strings.HasSuffix(name, ".local")
}

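// A few illustrative cases:
//
//	isLanHostname("printer")     // true: single label, no dots
//	isLanHostname("nas.local.")  // true: ".local" suffix
//	isLanHostname("example.com") // false
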
// isWanClient reports whether the input address is a WAN address.
func isWanClient(na net.Addr) bool {
	var ip netip.Addr
	if ap, err := netip.ParseAddrPort(na.String()); err == nil {
		ip = ap.Addr()
	}
	return !ip.IsLoopback() &&
		!ip.IsPrivate() &&
		!ip.IsLinkLocalUnicast() &&
		!ip.IsLinkLocalMulticast() &&
		!tsaddr.CGNATRange().Contains(ip)
}

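// Illustrative inputs (hypothetical addresses):
//
//	isWanClient(&net.UDPAddr{IP: net.ParseIP("203.0.113.7"), Port: 53}) // true: public address
//	isWanClient(&net.UDPAddr{IP: net.ParseIP("192.168.1.5"), Port: 53}) // false: RFC 1918
//	isWanClient(&net.UDPAddr{IP: net.ParseIP("100.64.0.1"), Port: 53})  // false: CGNAT range
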
// resolveInternalDomainTestQuery resolves an internal test domain query, returning the answer to the caller.
func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.Msg) *dns.Msg {
	logger := ctrld.LoggerFromCtx(ctx)
	ctrld.Log(ctx, logger.Debug(), "Internal domain test query")

	q := m.Question[0]
	answer := new(dns.Msg)
	rrStr := fmt.Sprintf("%s A %s", domain, net.IPv4zero)
	if q.Qtype == dns.TypeAAAA {
		rrStr = fmt.Sprintf("%s AAAA %s", domain, net.IPv6zero)
	}
	rr, err := dns.NewRR(rrStr)
	if err == nil {
		answer.Answer = append(answer.Answer, rr)
	}
	answer.SetReply(m)
	return answer
}

// FlushDNSCache flushes the DNS cache on macOS. It is a no-op on other platforms.
func FlushDNSCache() error {
	// Not macOS; nothing to do.
	if runtime.GOOS != "darwin" {
		return nil
	}

	// Flush the DNS cache via mDNSResponder.
	// This is typically needed on modern macOS systems.
	if out, err := exec.Command("killall", "-HUP", "mDNSResponder").CombinedOutput(); err != nil {
		return fmt.Errorf("failed to flush mDNSResponder: %w, output: %s", err, string(out))
	}

	// Optionally, flush the directory services cache.
	if out, err := exec.Command("dscacheutil", "-flushcache").CombinedOutput(); err != nil {
		return fmt.Errorf("failed to flush dscacheutil: %w, output: %s", err, string(out))
	}

	return nil
}

// monitorNetworkChanges starts monitoring for network interface changes
|
|
func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
|
mon, err := netmon.New(func(format string, args ...any) {
|
|
// Always fetch the latest logger (and inject the prefix)
|
|
p.logger.Load().Printf("netmon: "+format, args...)
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("creating network monitor: %w", err)
|
|
}
|
|
|
|
mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
|
|
// Get map of valid interfaces
|
|
validIfaces := ctrld.ValidInterfaces(ctrld.LoggerCtx(ctx, p.logger.Load()))
|
|
|
|
isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New)
|
|
|
|
p.Debug().
|
|
Interface("old_state", delta.Old).
|
|
Interface("new_state", delta.New).
|
|
Bool("is_major_change", isMajorChange).
|
|
Msg("Network change detected")
|
|
|
|
changed := false
|
|
activeInterfaceExists := false
|
|
var changeIPs []netip.Prefix
|
|
// Check each valid interface for changes
|
|
for ifaceName := range validIfaces {
|
|
oldIface, oldExists := delta.Old.Interface[ifaceName]
|
|
newIface, newExists := delta.New.Interface[ifaceName]
|
|
if !newExists {
|
|
continue
|
|
}
|
|
|
|
oldIPs := delta.Old.InterfaceIPs[ifaceName]
|
|
newIPs := delta.New.InterfaceIPs[ifaceName]
|
|
|
|
// if a valid interface did not exist in old
|
|
// check that its up and has usable IPs
|
|
if !oldExists {
|
|
// The interface is new (was not present in the old state).
|
|
usableNewIPs := filterUsableIPs(newIPs)
|
|
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
|
changed = true
|
|
changeIPs = usableNewIPs
|
|
p.Debug().
|
|
Str("interface", ifaceName).
|
|
Interface("new_ips", usableNewIPs).
|
|
Msg("Interface newly appeared (was not present in old state)")
|
|
break
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Filter new IPs to only those that are usable.
|
|
usableNewIPs := filterUsableIPs(newIPs)
|
|
|
|
// Check if interface is up and has usable IPs.
|
|
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
|
activeInterfaceExists = true
|
|
}
|
|
|
|
// Compare interface states and IPs (interfaceIPsEqual will itself filter the IPs).
|
|
if !interfaceStatesEqual(&oldIface, &newIface) || !interfaceIPsEqual(oldIPs, newIPs) {
|
|
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
|
changed = true
|
|
changeIPs = usableNewIPs
|
|
p.Debug().
|
|
Str("interface", ifaceName).
|
|
Interface("old_ips", oldIPs).
|
|
Interface("new_ips", usableNewIPs).
|
|
Msg("Interface state or IPs changed")
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// if the default route changed, set changed to true
|
|
if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface {
|
|
changed = true
|
|
p.Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface)
|
|
}
|
|
|
|
if !changed {
|
|
p.Debug().Msg("Ignoring interface change - no valid interfaces affected")
|
|
// check if the default IPs are still on an interface that is up
|
|
ValidateDefaultLocalIPsFromDelta(delta.New)
|
|
return
|
|
}
|
|
|
|
if !activeInterfaceExists {
|
|
p.Debug().Msg("No active interfaces found, skipping reinitialization")
|
|
return
|
|
}
|
|
|
|
// Get IPs from default route interface in new state
|
|
selfIP := p.defaultRouteIP()
|
|
|
|
// Ensure that selfIP is an IPv4 address.
|
|
// If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it
|
|
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil {
|
|
p.Debug().Msgf("DefaultRouteIP returned a non-ipv4 address: %s, ignoring it", selfIP)
|
|
selfIP = ""
|
|
}
|
|
var ipv6 string
|
|
|
|
if delta.New.DefaultRouteInterface != "" {
|
|
p.Debug().Msgf("Default route interface: %s, ips: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface])
|
|
for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] {
|
|
ipAddr, _ := netip.ParsePrefix(ip.String())
|
|
addr := ipAddr.Addr()
|
|
if selfIP == "" && addr.Is4() {
|
|
p.Debug().Msgf("Checking ip: %s", addr.String())
|
|
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
|
selfIP = addr.String()
|
|
}
|
|
}
|
|
if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
|
ipv6 = addr.String()
|
|
}
|
|
}
|
|
} else {
|
|
// If no default route interface is set yet, use the changed IPs
|
|
p.Debug().Msgf("No default route interface found, using changed ips: %v", changeIPs)
|
|
for _, ip := range changeIPs {
|
|
ipAddr, _ := netip.ParsePrefix(ip.String())
|
|
addr := ipAddr.Addr()
|
|
if selfIP == "" && addr.Is4() {
|
|
p.Debug().Msgf("Checking ip: %s", addr.String())
|
|
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
|
selfIP = addr.String()
|
|
}
|
|
}
|
|
if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
|
ipv6 = addr.String()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Only set the IPv4 default if selfIP is a valid IPv4 address.
|
|
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil {
|
|
ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, p.logger.Load()), ip)
|
|
if !isMobile() && p.ciTable != nil {
|
|
p.ciTable.SetSelfIP(selfIP)
|
|
}
|
|
}
|
|
if ip := net.ParseIP(ipv6); ip != nil {
|
|
ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, p.logger.Load()), ip)
|
|
}
|
|
p.Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
|
|
|
p.handleRecovery(RecoveryReasonNetworkChange)
|
|
})
|
|
|
|
mon.Start()
|
|
p.Debug().Msg("Network monitor started")
|
|
return nil
|
|
}
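
// exampleNetmonLifecycle is an illustrative sketch, not part of the original
// logic: the minimal netmon lifecycle that monitorNetworkChanges builds on is
// to construct the monitor with a printf-style log adapter, register a change
// callback, and then start it. The callback here only logs whether the change
// is "major"; the logf parameter is an assumption for demonstration.
func exampleNetmonLifecycle(logf func(format string, args ...any)) (*netmon.Monitor, error) {
	mon, err := netmon.New(logf)
	if err != nil {
		return nil, fmt.Errorf("creating network monitor: %w", err)
	}
	mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
		// The delta carries both old and new interface states for comparison.
		logf("major change: %v", mon.IsMajorChangeFrom(delta.Old, delta.New))
	})
	mon.Start()
	return mon, nil
}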

// interfaceStatesEqual compares two interface states.
func interfaceStatesEqual(a, b *netmon.Interface) bool {
	if a == nil || b == nil {
		return a == b
	}
	return a.IsUp() == b.IsUp()
}

// filterUsableIPs is a helper that returns only "usable" IP prefixes,
// filtering out link-local, loopback, multicast, unspecified, broadcast, or CGNAT addresses.
func filterUsableIPs(prefixes []netip.Prefix) []netip.Prefix {
	var usable []netip.Prefix
	for _, p := range prefixes {
		addr := p.Addr()
		if addr.IsLinkLocalUnicast() ||
			addr.IsLoopback() ||
			addr.IsMulticast() ||
			addr.IsUnspecified() ||
			addr.IsLinkLocalMulticast() ||
			(addr.Is4() && addr.String() == "255.255.255.255") ||
			tsaddr.CGNATRange().Contains(addr) {
			continue
		}
		usable = append(usable, p)
	}
	return usable
}
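
// exampleFilterUsableIPs is an illustrative sketch, not part of the original
// logic: it shows which kinds of prefixes filterUsableIPs keeps. The sample
// addresses are assumptions chosen for demonstration; only the routable
// 192.0.2.10/24 prefix survives.
func exampleFilterUsableIPs() []netip.Prefix {
	candidates := []netip.Prefix{
		netip.MustParsePrefix("192.0.2.10/24"),  // routable unicast: kept
		netip.MustParsePrefix("169.254.1.5/16"), // link-local unicast: dropped
		netip.MustParsePrefix("127.0.0.1/8"),    // loopback: dropped
		netip.MustParsePrefix("100.64.0.1/10"),  // CGNAT range: dropped
	}
	return filterUsableIPs(candidates) // -> [192.0.2.10/24]
}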

// interfaceIPsEqual compares only the usable (non-link-local, non-loopback, etc.)
// IP addresses of two prefix lists.
func interfaceIPsEqual(a, b []netip.Prefix) bool {
	aUsable := filterUsableIPs(a)
	bUsable := filterUsableIPs(b)
	if len(aUsable) != len(bUsable) {
		return false
	}

	aMap := make(map[string]bool)
	for _, ip := range aUsable {
		aMap[ip.String()] = true
	}
	for _, ip := range bUsable {
		if !aMap[ip.String()] {
			return false
		}
	}
	return true
}
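
// exampleInterfaceIPsEqual is an illustrative sketch, not part of the original
// logic: the comparison is order-insensitive and ignores unusable prefixes,
// so gaining a link-local address alone is not treated as an IP change. The
// sample prefixes are assumptions for demonstration.
func exampleInterfaceIPsEqual() bool {
	oldIPs := []netip.Prefix{
		netip.MustParsePrefix("192.0.2.10/24"),
	}
	newIPs := []netip.Prefix{
		netip.MustParsePrefix("169.254.1.5/16"), // unusable, filtered out
		netip.MustParsePrefix("192.0.2.10/24"),
	}
	return interfaceIPsEqual(oldIPs, newIPs) // true: the usable sets match
}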

// checkUpstreamOnce sends a test query to the specified upstream.
// Returns nil if the upstream responds successfully.
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
	p.Debug().Msgf("Starting check for upstream: %s", upstream)

	resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), uc)
	if err != nil {
		p.Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
		return err
	}

	timeout := 1000 * time.Millisecond
	if uc.Timeout > 0 {
		timeout = time.Millisecond * time.Duration(uc.Timeout)
	}
	p.Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout)

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load()))
	p.Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)

	start := time.Now()
	msg := uc.VerifyMsg()
	_, err = resolver.Resolve(ctx, msg)
	duration := time.Since(start)

	if err != nil {
		p.Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration)
	} else {
		p.Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration)
	}
	return err
}

// handleRecovery orchestrates the recovery process by coordinating multiple smaller methods.
// It handles recovery cancellation logic, creates the recovery context, prepares the system,
// waits for upstream recovery with timeout, and completes the recovery process.
// The method is designed to be called from a goroutine and handles different recovery reasons
// (network changes, regular failures, OS failures) with appropriate logic for each.
func (p *prog) handleRecovery(reason RecoveryReason) {
	p.Debug().Msg("Starting recovery process: removing DNS settings")

	// Handle recovery cancellation based on reason.
	if !p.shouldStartRecovery(reason) {
		return
	}

	// Create recovery context and cleanup function.
	recoveryCtx, cleanup := p.createRecoveryContext()
	defer cleanup()

	// Remove DNS settings and prepare for recovery.
	if err := p.prepareForRecovery(reason); err != nil {
		p.Error().Err(err).Msg("Failed to prepare for recovery")
		return
	}

	// Build upstream map based on the recovery reason.
	upstreams := p.buildRecoveryUpstreams(reason)

	// Wait for upstream recovery.
	recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
	if err != nil {
		p.Error().Err(err).Msg("Recovery failed; DNS settings remain removed")
		return
	}

	// Complete recovery process.
	if err := p.completeRecovery(reason, recovered); err != nil {
		p.Error().Err(err).Msg("Failed to complete recovery")
		return
	}

	p.Info().Msgf("Recovery completed successfully for upstream %q", recovered)
}

// shouldStartRecovery determines if recovery should start based on the reason and current state.
// Returns true if recovery should proceed, false otherwise.
func (p *prog) shouldStartRecovery(reason RecoveryReason) bool {
	p.recoveryCancelMu.Lock()
	defer p.recoveryCancelMu.Unlock()

	if reason == RecoveryReasonNetworkChange {
		// For network changes, cancel any existing recovery check because the network state has changed.
		if p.recoveryCancel != nil {
			p.Debug().Msg("Cancelling existing recovery check (network change)")
			p.recoveryCancel()
			p.recoveryCancel = nil
		}
		return true
	}

	// For upstream failures, if a recovery is already in progress, do nothing new.
	if p.recoveryCancel != nil {
		p.Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
		return false
	}

	return true
}

// createRecoveryContext creates a new recovery context and returns it along with a cleanup function.
func (p *prog) createRecoveryContext() (context.Context, func()) {
	p.recoveryCancelMu.Lock()
	recoveryCtx, cancel := context.WithCancel(context.Background())
	p.recoveryCancel = cancel
	p.recoveryCancelMu.Unlock()

	cleanup := func() {
		p.recoveryCancelMu.Lock()
		p.recoveryCancel = nil
		p.recoveryCancelMu.Unlock()
	}

	return recoveryCtx, cleanup
}
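
// exampleSingleFlightCancel is an illustrative sketch, not part of the
// original logic, of the pattern shouldStartRecovery and createRecoveryContext
// implement together: a mutex-guarded cancel handle where nil means "no
// recovery running", so duplicate triggers are dropped while preempting
// callers (network changes) cancel the stale run first. The type and method
// names are assumptions for demonstration.
type exampleSingleFlightCancel struct {
	mu     sync.Mutex
	cancel context.CancelFunc
}

// begin returns a context for a new run, or ok=false when a run is already in
// flight and the caller chose not to preempt it.
func (s *exampleSingleFlightCancel) begin(preempt bool) (ctx context.Context, ok bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.cancel != nil {
		if !preempt {
			return nil, false
		}
		s.cancel() // preempt the in-flight run
	}
	ctx, s.cancel = context.WithCancel(context.Background())
	return ctx, true
}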

// prepareForRecovery removes DNS settings and initializes the OS resolver if needed.
func (p *prog) prepareForRecovery(reason RecoveryReason) error {
	// Set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface.
	p.recoveryRunning.Store(true)

	// Remove DNS settings. We do not want to restore any static DNS settings;
	// we must try to get the DHCP values, and any static DNS settings
	// will be appended to nameservers from the saved interface values.
	p.resetDNS(false, false)

	// For an OS failure, reinitialize OS resolver nameservers immediately.
	if reason == RecoveryReasonOSFailure {
		if err := p.reinitializeOSResolver("OS resolver failure detected"); err != nil {
			return fmt.Errorf("failed to reinitialize OS resolver: %w", err)
		}
	}

	return nil
}

// reinitializeOSResolver reinitializes the OS resolver and logs the results.
func (p *prog) reinitializeOSResolver(message string) error {
	p.Debug().Msg(message)
	loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
	ns := ctrld.InitializeOsResolver(loggerCtx, true)
	if len(ns) == 0 {
		p.Warn().Msg("No nameservers found for OS resolver; using existing values")
	} else {
		p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
	}
	return nil
}

// completeRecovery completes the recovery process by resetting upstream state and reapplying DNS settings.
func (p *prog) completeRecovery(reason RecoveryReason, recovered string) error {
	// Reset the upstream failure count and down state.
	p.um.reset(recovered)

	// For network changes we also reinitialize the OS resolver.
	if reason == RecoveryReasonNetworkChange {
		if err := p.reinitializeOSResolver("Network change detected during recovery"); err != nil {
			return fmt.Errorf("failed to reinitialize OS resolver during network change: %w", err)
		}
	}

	// Apply our DNS settings back and log the interface state.
	p.setDNS()
	p.logInterfacesState()

	// Allow watchdogs to put the listener back on the interface if it's changed for any reason.
	p.recoveryRunning.Store(false)

	return nil
}

// waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers.
// It returns the name of the recovered upstream or an error if the check times out.
func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string]*ctrld.UpstreamConfig) (string, error) {
	// Derive a cancelable context so that, once one upstream recovers, the
	// remaining probe goroutines are released and wg.Wait cannot block forever.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	recoveredCh := make(chan string, 1)
	var wg sync.WaitGroup

	p.Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams))

	for name, uc := range upstreams {
		wg.Add(1)
		go func(name string, uc *ctrld.UpstreamConfig) {
			defer wg.Done()
			p.Debug().Msgf("Starting recovery check loop for upstream: %s", name)
			attempts := 0
			for {
				select {
				case <-ctx.Done():
					p.Debug().Msgf("Context canceled for upstream %s", name)
					return
				default:
					attempts++
					// checkUpstreamOnce will reset any failure counters on success.
					if err := p.checkUpstreamOnce(name, uc); err == nil {
						p.Debug().Msgf("Upstream %s recovered successfully", name)
						select {
						case recoveredCh <- name:
							p.Debug().Msgf("Sent recovery notification for upstream %s", name)
						default:
							p.Debug().Msg("Recovery channel full, another upstream already recovered")
						}
						return
					}
					p.Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
					time.Sleep(checkUpstreamBackoffSleep)

					// If this is the upstreamOS and the attempt count is a multiple of 3,
					// try to reinitialize the OS resolver to ensure we can recover.
					if name == upstreamOS && attempts%3 == 0 {
						p.Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts)
						ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, p.logger.Load()), true)
						if len(ns) == 0 {
							p.Warn().Msg("No nameservers found for OS resolver; using existing values")
						} else {
							p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
						}
					}
				}
			}
		}(name, uc)
	}

	var recovered string
	select {
	case recovered = <-recoveredCh:
	case <-ctx.Done():
		return "", ctx.Err()
	}
	// Release the remaining probe goroutines before waiting for them.
	cancel()
	wg.Wait()
	return recovered, nil
}
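
// exampleFirstRecovered is an illustrative sketch, not part of the original
// logic, of the coordination pattern waitForUpstreamRecovery relies on: a
// buffered channel of capacity one plus a non-blocking send means the first
// probe goroutine to succeed wins, and later winners simply drop their result.
// The probe parameter is an assumption for demonstration.
func exampleFirstRecovered(ctx context.Context, names []string, probe func(string) error) (string, error) {
	recoveredCh := make(chan string, 1)
	for _, name := range names {
		go func(name string) {
			for ctx.Err() == nil {
				if probe(name) == nil {
					select {
					case recoveredCh <- name: // first success wins
					default: // someone else already recovered
					}
					return
				}
				time.Sleep(time.Second) // simple fixed backoff between attempts
			}
		}(name)
	}
	select {
	case recovered := <-recoveredCh:
		return recovered, nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}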

// buildRecoveryUpstreams constructs the map of upstream configurations to test.
// For OS failures we supply the manual OS resolver upstream configuration.
// For network change or regular failure we use the upstreams defined in p.cfg (ignoring OS).
func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.UpstreamConfig {
	upstreams := make(map[string]*ctrld.UpstreamConfig)
	switch reason {
	case RecoveryReasonOSFailure:
		upstreams[upstreamOS] = osUpstreamConfig
	case RecoveryReasonNetworkChange, RecoveryReasonRegularFailure:
		// Use all configured upstreams except any OS type.
		for k, uc := range p.cfg.Upstream {
			if uc.Type != ctrld.ResolverTypeOS {
				upstreams[upstreamPrefix+k] = uc
			}
		}
	}
	return upstreams
}

// ValidateDefaultLocalIPsFromDelta checks if the default local IPv4 and IPv6 stored
// are still present in the new network state (provided by delta.New).
// If a stored default IP is no longer active, it resets that default (sets it to nil)
// so that it won't be used in subsequent custom dialer contexts.
func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) {
	currentIPv4 := ctrld.GetDefaultLocalIPv4()
	currentIPv6 := ctrld.GetDefaultLocalIPv6()

	// Build a map of active IP addresses from the new state.
	activeIPs := make(map[string]bool)
	for _, prefixes := range newState.InterfaceIPs {
		for _, prefix := range prefixes {
			activeIPs[prefix.Addr().String()] = true
		}
	}

	// Check if the default IPv4 is still active.
	if currentIPv4 != nil && !activeIPs[currentIPv4.String()] {
		mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", currentIPv4)
		ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil)
	}

	// Check if the default IPv6 is still active.
	if currentIPv6 != nil && !activeIPs[currentIPv6.String()] {
		mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. Resetting.", currentIPv6)
		ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil)
	}
}
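
// exampleValidateDefaults is an illustrative sketch, not part of the original
// logic: it builds a minimal netmon.State in which only 192.0.2.10 is active,
// so any stored default local IP other than that address would be reset. The
// interface name and prefix are assumptions for demonstration.
func exampleValidateDefaults() {
	state := &netmon.State{
		InterfaceIPs: map[string][]netip.Prefix{
			"eth0": {netip.MustParsePrefix("192.0.2.10/24")},
		},
	}
	ValidateDefaultLocalIPsFromDelta(state)
}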