mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
This may create security vulnerabilities, such as DNS amplification or abuse, because the listener is exposed to the entire local network.
1636 lines
51 KiB
Go
1636 lines
51 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os/exec"
|
|
"runtime"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"golang.org/x/sync/errgroup"
|
|
"tailscale.com/net/netmon"
|
|
"tailscale.com/net/tsaddr"
|
|
|
|
"github.com/Control-D-Inc/ctrld"
|
|
"github.com/Control-D-Inc/ctrld/internal/controld"
|
|
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
|
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
|
"github.com/Control-D-Inc/ctrld/internal/router"
|
|
)
|
|
|
|
const (
	// staleTTL is the TTL given to an expired cached answer when it is
	// served in place of a failed upstream response.
	staleTTL = time.Minute
	// localTTL is the TTL of A/AAAA answers built from the client info table.
	localTTL = time.Hour

	// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
	// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
	// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
	EDNS0_OPTION_MAC = 0xFDE9

	// selfUninstallMaxQueries is number of REFUSED queries seen before checking for self-uninstallation.
	selfUninstallMaxQueries = 32
)
|
|
|
|
var osUpstreamConfig = &ctrld.UpstreamConfig{
|
|
Name: "OS resolver",
|
|
Type: ctrld.ResolverTypeOS,
|
|
Timeout: 3000,
|
|
}
|
|
|
|
var privateUpstreamConfig = &ctrld.UpstreamConfig{
|
|
Name: "Private resolver",
|
|
Type: ctrld.ResolverTypePrivate,
|
|
Timeout: 2000,
|
|
}
|
|
|
|
var localUpstreamConfig = &ctrld.UpstreamConfig{
|
|
Name: "Local resolver",
|
|
Type: ctrld.ResolverTypeLocal,
|
|
Timeout: 2000,
|
|
}
|
|
|
|
// proxyRequest contains data for proxying a DNS query to upstream.
|
|
type proxyRequest struct {
|
|
msg *dns.Msg
|
|
ci *ctrld.ClientInfo
|
|
failoverRcodes []int
|
|
ufr *upstreamForResult
|
|
}
|
|
|
|
// proxyResponse contains data for proxying a DNS response from upstream.
|
|
type proxyResponse struct {
|
|
answer *dns.Msg
|
|
cached bool
|
|
clientInfo bool
|
|
upstream string
|
|
}
|
|
|
|
// upstreamForResult is the outcome of evaluating the listener policy for a request.
type upstreamForResult struct {
	// upstreams holds the upstream names to try, in order.
	upstreams []string
	// matchedPolicy, matchedNetwork and matchedRule name what matched; they
	// carry "no policy", "no network" and "no rule" placeholders otherwise.
	matchedPolicy, matchedNetwork, matchedRule string
	// matched reports whether any network, MAC or domain rule matched.
	matched bool
	// srcAddr is the string form of the (possibly spoofed) client address.
	srcAddr string
}
|
|
|
|
func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
|
// Start network monitoring
|
|
if err := p.monitorNetworkChanges(mainCtx); err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring")
|
|
// Don't return here as we still want DNS service to run
|
|
}
|
|
|
|
listenerConfig := p.cfg.Listener[listenerNum]
|
|
// make sure ip is allocated
|
|
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
|
|
mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
|
|
return allocErr
|
|
}
|
|
|
|
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
|
p.sema.acquire()
|
|
defer p.sema.release()
|
|
if len(m.Question) == 0 {
|
|
answer := new(dns.Msg)
|
|
answer.SetRcode(m, dns.RcodeFormatError)
|
|
_ = w.WriteMsg(answer)
|
|
return
|
|
}
|
|
listenerConfig := p.cfg.Listener[listenerNum]
|
|
reqId := requestID()
|
|
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
|
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
|
|
answer := new(dns.Msg)
|
|
answer.SetRcode(m, dns.RcodeRefused)
|
|
_ = w.WriteMsg(answer)
|
|
return
|
|
}
|
|
go p.detectLoop(m)
|
|
q := m.Question[0]
|
|
domain := canonicalName(q.Name)
|
|
switch {
|
|
case domain == "":
|
|
answer := new(dns.Msg)
|
|
answer.SetRcode(m, dns.RcodeFormatError)
|
|
_ = w.WriteMsg(answer)
|
|
return
|
|
case domain == selfCheckInternalTestDomain:
|
|
answer := resolveInternalDomainTestQuery(ctx, domain, m)
|
|
_ = w.WriteMsg(answer)
|
|
return
|
|
}
|
|
|
|
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
|
|
p.cache.Purge()
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain)
|
|
}
|
|
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
|
|
ci := p.getClientInfo(remoteIP, m)
|
|
ci.ClientIDPref = p.cfg.Service.ClientIDPref
|
|
stripClientSubnet(m)
|
|
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci)
|
|
fmtSrcToDest := fmtRemoteToLocal(listenerNum, ci.Hostname, remoteAddr.String())
|
|
t := time.Now()
|
|
ctrld.Log(ctx, mainLog.Load().Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
|
|
ur := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain)
|
|
|
|
labelValues := make([]string, 0, len(statsQueriesCountLabels))
|
|
labelValues = append(labelValues, net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)))
|
|
labelValues = append(labelValues, ci.IP)
|
|
labelValues = append(labelValues, ci.Mac)
|
|
labelValues = append(labelValues, ci.Hostname)
|
|
|
|
var answer *dns.Msg
|
|
if !ur.matched && listenerConfig.Restricted {
|
|
ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String())
|
|
answer = new(dns.Msg)
|
|
answer.SetRcode(m, dns.RcodeRefused)
|
|
labelValues = append(labelValues, "") // no upstream
|
|
} else {
|
|
var failoverRcode []int
|
|
if listenerConfig.Policy != nil {
|
|
failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers
|
|
}
|
|
pr := p.proxy(ctx, &proxyRequest{
|
|
msg: m,
|
|
ci: ci,
|
|
failoverRcodes: failoverRcode,
|
|
ufr: ur,
|
|
})
|
|
go p.doSelfUninstall(pr.answer)
|
|
|
|
answer = pr.answer
|
|
rtt := time.Since(t)
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
|
upstream := pr.upstream
|
|
switch {
|
|
case pr.cached:
|
|
upstream = "cache"
|
|
case pr.clientInfo:
|
|
upstream = "client_info_table"
|
|
}
|
|
labelValues = append(labelValues, upstream)
|
|
}
|
|
labelValues = append(labelValues, dns.TypeToString[q.Qtype])
|
|
labelValues = append(labelValues, dns.RcodeToString[answer.Rcode])
|
|
go func() {
|
|
p.WithLabelValuesInc(statsQueriesCount, labelValues...)
|
|
p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...)
|
|
p.forceFetchingAPI(domain)
|
|
}()
|
|
if err := w.WriteMsg(answer); err != nil {
|
|
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client")
|
|
}
|
|
})
|
|
|
|
g, ctx := errgroup.WithContext(context.Background())
|
|
for _, proto := range []string{"udp", "tcp"} {
|
|
proto := proto
|
|
if needLocalIPv6Listener() {
|
|
g.Go(func() error {
|
|
s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), proto, handler)
|
|
defer s.Shutdown()
|
|
select {
|
|
case <-p.stopCh:
|
|
case <-ctx.Done():
|
|
case err := <-errCh:
|
|
// Local ipv6 listener should not terminate ctrld.
|
|
// It's a workaround for a quirk on Windows.
|
|
mainLog.Load().Warn().Err(err).Msg("local ipv6 listener failed")
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
// When we spawn a listener on 127.0.0.1, also spawn listeners on the RFC1918
|
|
// addresses of the machine. So ctrld could receive queries from LAN clients.
|
|
if needRFC1918Listeners(listenerConfig) {
|
|
g.Go(func() error {
|
|
for _, addr := range ctrld.Rfc1918Addresses() {
|
|
func() {
|
|
listenAddr := net.JoinHostPort(addr, strconv.Itoa(listenerConfig.Port))
|
|
s, errCh := runDNSServer(listenAddr, proto, handler)
|
|
defer s.Shutdown()
|
|
select {
|
|
case <-p.stopCh:
|
|
case <-ctx.Done():
|
|
case err := <-errCh:
|
|
// RFC1918 listener should not terminate ctrld.
|
|
// It's a workaround for a quirk on system with systemd-resolved.
|
|
mainLog.Load().Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr)
|
|
}
|
|
}()
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
g.Go(func() error {
|
|
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
|
s, errCh := runDNSServer(addr, proto, handler)
|
|
defer s.Shutdown()
|
|
|
|
p.started <- struct{}{}
|
|
|
|
select {
|
|
case <-p.stopCh:
|
|
case <-ctx.Done():
|
|
case err := <-errCh:
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
return g.Wait()
|
|
}
|
|
|
|
// upstreamFor returns the list of upstreams for resolving the given domain,
|
|
// matching by policies defined in the listener config. The second return value
|
|
// reports whether the domain matches the policy.
|
|
//
|
|
// Though domain policy has higher priority than network policy, it is still
|
|
// processed later, because policy logging want to know whether a network rule
|
|
// is disregarded in favor of the domain level rule.
|
|
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) {
|
|
upstreams := []string{upstreamPrefix + defaultUpstreamNum}
|
|
matchedPolicy := "no policy"
|
|
matchedNetwork := "no network"
|
|
matchedRule := "no rule"
|
|
matched := false
|
|
res = &upstreamForResult{srcAddr: addr.String()}
|
|
|
|
defer func() {
|
|
res.upstreams = upstreams
|
|
res.matched = matched
|
|
res.matchedPolicy = matchedPolicy
|
|
res.matchedNetwork = matchedNetwork
|
|
res.matchedRule = matchedRule
|
|
}()
|
|
|
|
if lc.Policy == nil {
|
|
return
|
|
}
|
|
|
|
do := func(policyUpstreams []string) {
|
|
upstreams = append([]string(nil), policyUpstreams...)
|
|
}
|
|
|
|
var networkTargets []string
|
|
var sourceIP net.IP
|
|
switch addr := addr.(type) {
|
|
case *net.UDPAddr:
|
|
sourceIP = addr.IP
|
|
case *net.TCPAddr:
|
|
sourceIP = addr.IP
|
|
}
|
|
|
|
networkRules:
|
|
for _, rule := range lc.Policy.Networks {
|
|
for source, targets := range rule {
|
|
networkNum := strings.TrimPrefix(source, "network.")
|
|
nc := p.cfg.Network[networkNum]
|
|
if nc == nil {
|
|
continue
|
|
}
|
|
for _, ipNet := range nc.IPNets {
|
|
if ipNet.Contains(sourceIP) {
|
|
matchedPolicy = lc.Policy.Name
|
|
matchedNetwork = source
|
|
networkTargets = targets
|
|
matched = true
|
|
break networkRules
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
macRules:
|
|
for _, rule := range lc.Policy.Macs {
|
|
for source, targets := range rule {
|
|
if source != "" && (strings.EqualFold(source, srcMac) || wildcardMatches(strings.ToLower(source), strings.ToLower(srcMac))) {
|
|
matchedPolicy = lc.Policy.Name
|
|
matchedNetwork = source
|
|
networkTargets = targets
|
|
matched = true
|
|
break macRules
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, rule := range lc.Policy.Rules {
|
|
// There's only one entry per rule, config validation ensures this.
|
|
for source, targets := range rule {
|
|
if source == domain || wildcardMatches(source, domain) {
|
|
matchedPolicy = lc.Policy.Name
|
|
if len(networkTargets) > 0 {
|
|
matchedNetwork += " (unenforced)"
|
|
}
|
|
matchedRule = source
|
|
do(targets)
|
|
matched = true
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
if matched {
|
|
do(networkTargets)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg {
|
|
cDomainName := msg.Question[0].Name
|
|
locked := p.ptrLoopGuard.TryLock(cDomainName)
|
|
defer p.ptrLoopGuard.Unlock(cDomainName)
|
|
if !locked {
|
|
return nil
|
|
}
|
|
ip := ipFromARPA(cDomainName)
|
|
if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" {
|
|
answer := new(dns.Msg)
|
|
answer.SetReply(msg)
|
|
answer.Compress = true
|
|
answer.Answer = []dns.RR{&dns.PTR{
|
|
Hdr: dns.RR_Header{
|
|
Name: msg.Question[0].Name,
|
|
Rrtype: dns.TypePTR,
|
|
Class: dns.ClassINET,
|
|
},
|
|
Ptr: dns.Fqdn(name),
|
|
}}
|
|
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table")
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
|
|
Mac: p.ciTable.LookupMac(ip.String()),
|
|
IP: ip.String(),
|
|
Hostname: name,
|
|
})
|
|
return answer
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg {
|
|
q := msg.Question[0]
|
|
hostname := strings.TrimSuffix(q.Name, ".")
|
|
locked := p.lanLoopGuard.TryLock(hostname)
|
|
defer p.lanLoopGuard.Unlock(hostname)
|
|
if !locked {
|
|
return nil
|
|
}
|
|
if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil {
|
|
answer := new(dns.Msg)
|
|
answer.SetReply(msg)
|
|
answer.Compress = true
|
|
switch {
|
|
case ip.Is4():
|
|
answer.Answer = []dns.RR{&dns.A{
|
|
Hdr: dns.RR_Header{
|
|
Name: msg.Question[0].Name,
|
|
Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET,
|
|
Ttl: uint32(localTTL.Seconds()),
|
|
},
|
|
A: ip.AsSlice(),
|
|
}}
|
|
case ip.Is6():
|
|
answer.Answer = []dns.RR{&dns.AAAA{
|
|
Hdr: dns.RR_Header{
|
|
Name: msg.Question[0].Name,
|
|
Rrtype: dns.TypeAAAA,
|
|
Class: dns.ClassINET,
|
|
Ttl: uint32(localTTL.Seconds()),
|
|
},
|
|
AAAA: ip.AsSlice(),
|
|
}}
|
|
}
|
|
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table")
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
|
|
Mac: p.ciTable.LookupMac(ip.String()),
|
|
IP: ip.String(),
|
|
Hostname: hostname,
|
|
})
|
|
return answer
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
|
var staleAnswer *dns.Msg
|
|
upstreams := req.ufr.upstreams
|
|
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
|
|
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
|
|
|
if len(upstreamConfigs) == 0 {
|
|
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
|
upstreams = []string{upstreamOS}
|
|
// For OS resolver, local addresses are ignored to prevent possible looping.
|
|
// However, on Active Directory Domain Controller, where it has local DNS server
|
|
// running and listening on local addresses, these local addresses must be used
|
|
// as nameservers, so queries for ADDC could be resolved as expected.
|
|
if p.isAdDomainQuery(req.msg) {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(),
|
|
"AD domain query detected for %s in domain %s",
|
|
req.msg.Question[0].Name, p.adDomain)
|
|
upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig}
|
|
upstreams = []string{upstreamOSLocal}
|
|
}
|
|
}
|
|
|
|
res := &proxyResponse{}
|
|
|
|
// LAN/PTR lookup flow:
|
|
//
|
|
// 1. If there's matching rule, follow it.
|
|
// 2. Try from client info table.
|
|
// 3. Try private resolver.
|
|
// 4. Try remote upstream.
|
|
isLanOrPtrQuery := false
|
|
if req.ufr.matched {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
|
} else {
|
|
switch {
|
|
case isSrvLanLookup(req.msg):
|
|
upstreams = []string{upstreamOS}
|
|
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
|
ctx = ctrld.LanQueryCtx(ctx)
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams)
|
|
case isPrivatePtrLookup(req.msg):
|
|
isLanOrPtrQuery = true
|
|
if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil {
|
|
res.answer = answer
|
|
res.clientInfo = true
|
|
return res
|
|
}
|
|
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs)
|
|
ctx = ctrld.LanQueryCtx(ctx)
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams)
|
|
case isLanHostnameQuery(req.msg):
|
|
isLanOrPtrQuery = true
|
|
if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil {
|
|
res.answer = answer
|
|
res.clientInfo = true
|
|
return res
|
|
}
|
|
upstreams = []string{upstreamOS}
|
|
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
|
ctx = ctrld.LanQueryCtx(ctx)
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams)
|
|
default:
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams)
|
|
}
|
|
}
|
|
|
|
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
|
|
if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
|
|
for _, upstream := range upstreams {
|
|
cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream))
|
|
if cachedValue == nil {
|
|
continue
|
|
}
|
|
answer := cachedValue.Msg.Copy()
|
|
answer.SetRcode(req.msg, answer.Rcode)
|
|
now := time.Now()
|
|
if cachedValue.Expire.After(now) {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response")
|
|
setCachedAnswerTTL(answer, now, cachedValue.Expire)
|
|
res.answer = answer
|
|
res.cached = true
|
|
return res
|
|
}
|
|
staleAnswer = answer
|
|
}
|
|
}
|
|
resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name)
|
|
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
|
|
if err != nil {
|
|
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver")
|
|
return nil, err
|
|
}
|
|
resolveCtx, cancel := upstreamConfig.Context(ctx)
|
|
defer cancel()
|
|
return dnsResolver.Resolve(resolveCtx, msg)
|
|
}
|
|
resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
|
if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request")
|
|
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
|
|
}
|
|
answer, err := resolve1(upstream, upstreamConfig, msg)
|
|
// if we have an answer, we should reset the failure count
|
|
// we dont use reset here since we dont want to prevent failure counts from being incremented
|
|
if answer != nil {
|
|
p.um.mu.Lock()
|
|
p.um.failureReq[upstream] = 0
|
|
p.um.down[upstream] = false
|
|
p.um.mu.Unlock()
|
|
return answer
|
|
}
|
|
|
|
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
|
|
|
|
// increase failure count when there is no answer
|
|
// rehardless of what kind of error we get
|
|
p.um.increaseFailureCount(upstream)
|
|
|
|
if err != nil {
|
|
// For timeout error (i.e: context deadline exceed), force re-bootstrapping.
|
|
var e net.Error
|
|
if errors.As(err, &e) && e.Timeout() {
|
|
upstreamConfig.ReBootstrap()
|
|
}
|
|
// For network error, turn ipv6 off if enabled.
|
|
if ctrld.HasIPv6() && (errUrlNetworkError(err) || errNetworkError(err)) {
|
|
ctrld.DisableIPv6()
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
for n, upstreamConfig := range upstreamConfigs {
|
|
if upstreamConfig == nil {
|
|
continue
|
|
}
|
|
logger := mainLog.Load().Debug().
|
|
Str("upstream", upstreamConfig.String()).
|
|
Str("query", req.msg.Question[0].Name).
|
|
Bool("is_ad_query", p.isAdDomainQuery(req.msg)).
|
|
Bool("is_lan_query", isLanOrPtrQuery)
|
|
|
|
if p.isLoop(upstreamConfig) {
|
|
ctrld.Log(ctx, logger, "DNS loop detected")
|
|
continue
|
|
}
|
|
answer := resolve(upstreams[n], upstreamConfig, req.msg)
|
|
if answer == nil {
|
|
if serveStaleCache && staleAnswer != nil {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response")
|
|
now := time.Now()
|
|
setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL))
|
|
res.answer = staleAnswer
|
|
res.cached = true
|
|
return res
|
|
}
|
|
continue
|
|
}
|
|
// We are doing LAN/PTR lookup using private resolver, so always process next one.
|
|
// Except for the last, we want to send response instead of saying all upstream failed.
|
|
if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n])
|
|
continue
|
|
}
|
|
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream")
|
|
continue
|
|
}
|
|
|
|
// set compression, as it is not set by default when unpacking
|
|
answer.Compress = true
|
|
|
|
if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
|
|
ttl := ttlFromMsg(answer)
|
|
now := time.Now()
|
|
expired := now.Add(time.Duration(ttl) * time.Second)
|
|
if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 {
|
|
expired = now.Add(time.Duration(cachedTTL) * time.Second)
|
|
}
|
|
setCachedAnswerTTL(answer, now, expired)
|
|
p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired))
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response")
|
|
}
|
|
hostname := ""
|
|
if req.ci != nil {
|
|
hostname = req.ci.Hostname
|
|
}
|
|
ctrld.Log(ctx, mainLog.Load().Info(), "REPLY: %s -> %s (%s): %s", upstreams[n], req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode])
|
|
res.answer = answer
|
|
res.upstream = upstreamConfig.Endpoint
|
|
return res
|
|
}
|
|
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
|
|
|
|
// if we have no healthy upstreams, trigger recovery flow
|
|
if p.leakOnUpstreamFailure() {
|
|
if p.um.countHealthy(upstreams) == 0 {
|
|
p.recoveryCancelMu.Lock()
|
|
if p.recoveryCancel == nil {
|
|
var reason RecoveryReason
|
|
if upstreams[0] == upstreamOS {
|
|
reason = RecoveryReasonOSFailure
|
|
} else {
|
|
reason = RecoveryReasonRegularFailure
|
|
}
|
|
mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
|
|
go p.handleRecovery(reason)
|
|
} else {
|
|
mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
|
|
}
|
|
p.recoveryCancelMu.Unlock()
|
|
} else {
|
|
mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
|
|
}
|
|
|
|
// attempt query to OS resolver while as a retry catch all
|
|
// we dont want this to happen if leakOnUpstreamFailure is false
|
|
if upstreams[0] != upstreamOS {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all")
|
|
answer := resolve(upstreamOS, osUpstreamConfig, req.msg)
|
|
if answer != nil {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful")
|
|
res.answer = answer
|
|
res.upstream = osUpstreamConfig.Endpoint
|
|
return res
|
|
}
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed")
|
|
}
|
|
}
|
|
|
|
answer := new(dns.Msg)
|
|
answer.SetRcode(req.msg, dns.RcodeServerFailure)
|
|
res.answer = answer
|
|
return res
|
|
}
|
|
|
|
func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
|
|
if len(p.localUpstreams) > 0 {
|
|
tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
|
|
tmp = append(tmp, p.localUpstreams...)
|
|
tmp = append(tmp, upstreams...)
|
|
return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp)
|
|
}
|
|
return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...)
|
|
}
|
|
|
|
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
|
|
upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
|
|
for _, upstream := range upstreams {
|
|
upstreamNum := strings.TrimPrefix(upstream, upstreamPrefix)
|
|
upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
|
|
}
|
|
return upstreamConfigs
|
|
}
|
|
|
|
func (p *prog) isAdDomainQuery(msg *dns.Msg) bool {
|
|
if p.adDomain == "" {
|
|
return false
|
|
}
|
|
cDomainName := canonicalName(msg.Question[0].Name)
|
|
return dns.IsSubDomain(p.adDomain, cDomainName)
|
|
}
|
|
|
|
// canonicalName normalizes an FQDN for comparison: surrounding whitespace is
// removed, one trailing "." is dropped, and the result is lowercased because
// DNS names are case-insensitive (https://datatracker.ietf.org/doc/html/rfc4343).
func canonicalName(fqdn string) string {
	trimmed := strings.TrimSuffix(strings.TrimSpace(fqdn), ".")
	return strings.ToLower(trimmed)
}
|
|
|
|
// wildcardMatches reports whether str matches the wildcard pattern, ignoring case.
//
// The pattern must contain exactly one "*" and at least one other character:
// "*suffix", "prefix*" and "prefix*suffix" are supported. The "*" stands for
// any (possibly empty) run of characters, so the prefix and the suffix must
// match non-overlapping parts of str. Any other pattern, including a bare "*"
// and a pattern without "*", never matches.
func wildcardMatches(wildcard, str string) bool {
	prefix, suffix, found := strings.Cut(strings.ToLower(wildcard), "*")
	if !found || strings.Contains(suffix, "*") {
		// Zero or more than one "*".
		return false
	}
	if prefix == "" && suffix == "" {
		return false
	}

	str = strings.ToLower(str)
	// Without this check "ab*bc" would match "abc", where prefix and suffix
	// share the middle character.
	if len(str) < len(prefix)+len(suffix) {
		return false
	}
	return strings.HasPrefix(str, prefix) && strings.HasSuffix(str, suffix)
}
|
|
|
|
// fmtRemoteToLocal renders the "source -> listener" part of a query log line
// as "<remote> (<hostname>) -> listener.<listenerNum>".
func fmtRemoteToLocal(listenerNum, hostname, remote string) string {
	const layout = "%s (%s) -> listener.%s"
	return fmt.Sprintf(layout, remote, hostname, listenerNum)
}
|
|
|
|
// requestID returns a random 6-character lowercase hex string used to tag the
// log lines of one request. It panics if the system random source fails.
func requestID() string {
	var raw [3]byte // 3 bytes encode to 6 hex chars
	if _, err := rand.Read(raw[:]); err != nil {
		panic(err)
	}
	return hex.EncodeToString(raw[:])
}
|
|
|
|
// containRcode reports whether rcode is one of rcodes.
func containRcode(rcodes []int, rcode int) bool {
	return slices.Contains(rcodes, rcode)
}
|
|
|
|
func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) {
|
|
ttlSecs := expiredTime.Sub(now).Seconds()
|
|
if ttlSecs < 0 {
|
|
return
|
|
}
|
|
|
|
ttl := uint32(ttlSecs)
|
|
for _, rr := range answer.Answer {
|
|
rr.Header().Ttl = ttl
|
|
}
|
|
for _, rr := range answer.Ns {
|
|
rr.Header().Ttl = ttl
|
|
}
|
|
for _, rr := range answer.Extra {
|
|
if rr.Header().Rrtype != dns.TypeOPT {
|
|
rr.Header().Ttl = ttl
|
|
}
|
|
}
|
|
}
|
|
|
|
func ttlFromMsg(msg *dns.Msg) uint32 {
|
|
for _, rr := range msg.Answer {
|
|
return rr.Header().Ttl
|
|
}
|
|
for _, rr := range msg.Ns {
|
|
return rr.Header().Ttl
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func needLocalIPv6Listener() bool {
|
|
// On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can
|
|
// listen on ::1, then spawn a listener for receiving DNS requests.
|
|
return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows"
|
|
}
|
|
|
|
// ipAndMacFromMsg extracts IP and MAC information included in a DNS message, if any.
|
|
func ipAndMacFromMsg(msg *dns.Msg) (string, string) {
|
|
ip, mac := "", ""
|
|
if opt := msg.IsEdns0(); opt != nil {
|
|
for _, s := range opt.Option {
|
|
switch e := s.(type) {
|
|
case *dns.EDNS0_LOCAL:
|
|
if e.Code == EDNS0_OPTION_MAC {
|
|
mac = net.HardwareAddr(e.Data).String()
|
|
}
|
|
case *dns.EDNS0_SUBNET:
|
|
if len(e.Address) > 0 && !e.Address.IsLoopback() {
|
|
ip = e.Address.String()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return ip, mac
|
|
}
|
|
|
|
// stripClientSubnet removes EDNS0_SUBNET from DNS message if the IP is RFC1918 or loopback address,
|
|
// passing them to upstream is pointless, these cannot be used by anything on the WAN.
|
|
func stripClientSubnet(msg *dns.Msg) {
|
|
if opt := msg.IsEdns0(); opt != nil {
|
|
opts := make([]dns.EDNS0, 0, len(opt.Option))
|
|
for _, s := range opt.Option {
|
|
if e, ok := s.(*dns.EDNS0_SUBNET); ok && (e.Address.IsPrivate() || e.Address.IsLoopback()) {
|
|
continue
|
|
}
|
|
opts = append(opts, s)
|
|
}
|
|
if len(opts) != len(opt.Option) {
|
|
opt.Option = opts
|
|
}
|
|
}
|
|
}
|
|
|
|
func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr {
|
|
if ci != nil && ci.IP != "" {
|
|
switch addr := addr.(type) {
|
|
case *net.UDPAddr:
|
|
udpAddr := &net.UDPAddr{
|
|
IP: net.ParseIP(ci.IP),
|
|
Port: addr.Port,
|
|
Zone: addr.Zone,
|
|
}
|
|
return udpAddr
|
|
case *net.TCPAddr:
|
|
udpAddr := &net.TCPAddr{
|
|
IP: net.ParseIP(ci.IP),
|
|
Port: addr.Port,
|
|
Zone: addr.Zone,
|
|
}
|
|
return udpAddr
|
|
}
|
|
}
|
|
return addr
|
|
}
|
|
|
|
// runDNSServer starts a DNS server for given address and network,
|
|
// with the given handler. It ensures the server has started listening.
|
|
// Any error will be reported to the caller via returned channel.
|
|
//
|
|
// It's the caller responsibility to call Shutdown to close the server.
|
|
func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-chan error) {
|
|
s := &dns.Server{
|
|
Addr: addr,
|
|
Net: network,
|
|
Handler: handler,
|
|
}
|
|
|
|
startedCh := make(chan struct{})
|
|
s.NotifyStartedFunc = func() { sync.OnceFunc(func() { close(startedCh) })() }
|
|
|
|
errCh := make(chan error)
|
|
go func() {
|
|
defer close(errCh)
|
|
if err := s.ListenAndServe(); err != nil {
|
|
s.NotifyStartedFunc()
|
|
mainLog.Load().Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
|
|
errCh <- err
|
|
}
|
|
}()
|
|
<-startedCh
|
|
return s, errCh
|
|
}
|
|
|
|
func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo {
|
|
ci := &ctrld.ClientInfo{}
|
|
if p.appCallback != nil {
|
|
ci.IP = p.appCallback.LanIp()
|
|
ci.Mac = p.appCallback.MacAddress()
|
|
ci.Hostname = p.appCallback.HostName()
|
|
ci.Self = true
|
|
return ci
|
|
}
|
|
ci.IP, ci.Mac = ipAndMacFromMsg(msg)
|
|
switch {
|
|
case ci.IP != "" && ci.Mac != "":
|
|
// Nothing to do.
|
|
case ci.IP == "" && ci.Mac != "":
|
|
// Have MAC, no IP.
|
|
ci.IP = p.ciTable.LookupIP(ci.Mac)
|
|
case ci.IP == "" && ci.Mac == "":
|
|
// Have nothing, use remote IP then lookup MAC.
|
|
ci.IP = remoteIP
|
|
fallthrough
|
|
case ci.IP != "" && ci.Mac == "":
|
|
// Have IP, no MAC.
|
|
ci.Mac = p.ciTable.LookupMac(ci.IP)
|
|
}
|
|
|
|
// If MAC is still empty here, that mean the requests are made from virtual interface,
|
|
// like VPN/Wireguard clients, so we use ci.IP as hostname to distinguish those clients.
|
|
if ci.Mac == "" {
|
|
if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" {
|
|
ci.Hostname = hostname
|
|
} else {
|
|
// Only use IP as hostname for IPv4 clients.
|
|
// For Android devices, when it joins the network, it uses ctrld to resolve
|
|
// its private DNS once and never reaches ctrld again. For each time, it uses
|
|
// a different IPv6 address, which causes hundreds/thousands different client
|
|
// IDs created for the same device, which is pointless.
|
|
//
|
|
// TODO(cuonglm): investigate whether this can be a false positive for other clients?
|
|
if !ctrldnet.IsIPv6(ci.IP) {
|
|
ci.Hostname = ci.IP
|
|
p.ciTable.StoreVPNClient(ci)
|
|
}
|
|
}
|
|
} else {
|
|
ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac)
|
|
}
|
|
ci.Self = p.queryFromSelf(ci.IP)
|
|
// If this is a query from self, but ci.IP is not loopback IP,
|
|
// try using hostname mapping for lookback IP if presents.
|
|
if ci.Self {
|
|
if name := p.ciTable.LocalHostname(); name != "" {
|
|
ci.Hostname = name
|
|
}
|
|
}
|
|
p.spoofLoopbackIpInClientInfo(ci)
|
|
return ci
|
|
}
|
|
|
|
// spoofLoopbackIpInClientInfo replaces loopback IPs in client info.
|
|
//
|
|
// - Preference IPv4.
|
|
// - Preference RFC1918.
|
|
func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
|
|
if ip := net.ParseIP(ci.IP); ip == nil || !ip.IsLoopback() {
|
|
return
|
|
}
|
|
if ip := p.ciTable.LookupRFC1918IPv4(ci.Mac); ip != "" {
|
|
ci.IP = ip
|
|
}
|
|
}
|
|
|
|
// doSelfUninstall performs self-uninstall if these condition met:
|
|
//
|
|
// - There is only 1 ControlD upstream in-use.
|
|
// - Number of refused queries seen so far equals to selfUninstallMaxQueries.
|
|
// - The cdUID is deleted.
|
|
func (p *prog) doSelfUninstall(answer *dns.Msg) {
|
|
if !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused {
|
|
return
|
|
}
|
|
|
|
p.selfUninstallMu.Lock()
|
|
defer p.selfUninstallMu.Unlock()
|
|
if p.checkingSelfUninstall {
|
|
return
|
|
}
|
|
|
|
logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger()
|
|
if p.refusedQueryCount > selfUninstallMaxQueries {
|
|
p.checkingSelfUninstall = true
|
|
_, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
|
logger.Debug().Msg("maximum number of refused queries reached, checking device status")
|
|
selfUninstallCheck(err, p, logger)
|
|
|
|
if err != nil {
|
|
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
|
}
|
|
// Cool-of period to prevent abusing the API.
|
|
go p.selfUninstallCoolOfPeriod()
|
|
return
|
|
}
|
|
p.refusedQueryCount++
|
|
}
|
|
|
|
// selfUninstallCoolOfPeriod waits for 30 minutes before
|
|
// calling API again for checking ControlD device status.
|
|
func (p *prog) selfUninstallCoolOfPeriod() {
|
|
t := time.NewTimer(time.Minute * 30)
|
|
defer t.Stop()
|
|
<-t.C
|
|
p.selfUninstallMu.Lock()
|
|
p.checkingSelfUninstall = false
|
|
p.refusedQueryCount = 0
|
|
p.selfUninstallMu.Unlock()
|
|
}
|
|
|
|
// forceFetchingAPI sends signal to force syncing API config if run in cd mode,
|
|
// and the domain == "cdUID.verify.controld.com"
|
|
func (p *prog) forceFetchingAPI(domain string) {
|
|
if cdUID == "" {
|
|
return
|
|
}
|
|
resolverID, parent, _ := strings.Cut(domain, ".")
|
|
if resolverID != cdUID {
|
|
return
|
|
}
|
|
switch {
|
|
case cdDev && parent == "verify.controld.dev":
|
|
// match ControlD dev
|
|
case parent == "verify.controld.com":
|
|
// match ControlD
|
|
default:
|
|
return
|
|
}
|
|
_ = p.apiForceReloadGroup.DoChan("force_sync_api", func() (interface{}, error) {
|
|
p.apiForceReloadCh <- struct{}{}
|
|
// Wait here to prevent abusing API if we are flooded.
|
|
time.Sleep(timeDurationOrDefault(p.cfg.Service.ForceRefetchWaitTime, 30) * time.Second)
|
|
return nil, nil
|
|
})
|
|
}
|
|
|
|
// timeDurationOrDefault converts *n to a time.Duration when n is non-nil and
// strictly positive; in every other case it converts defaultN instead.
//
// The result is a bare count (nanoseconds as far as time.Duration is
// concerned); callers multiply it by the unit they need.
func timeDurationOrDefault(n *int, defaultN int) time.Duration {
	if n == nil || *n <= 0 {
		return time.Duration(defaultN)
	}
	return time.Duration(*n)
}
|
|
|
|
// queryFromSelf reports whether the input IP is from device running ctrld.
|
|
func (p *prog) queryFromSelf(ip string) bool {
|
|
if val, ok := p.queryFromSelfMap.Load(ip); ok {
|
|
return val.(bool)
|
|
}
|
|
netIP := netip.MustParseAddr(ip)
|
|
regularIPs, loopbackIPs, err := netmon.LocalAddresses()
|
|
if err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not get local addresses")
|
|
return false
|
|
}
|
|
for _, localIP := range slices.Concat(regularIPs, loopbackIPs) {
|
|
if localIP.Compare(netIP) == 0 {
|
|
p.queryFromSelfMap.Store(ip, true)
|
|
return true
|
|
}
|
|
}
|
|
p.queryFromSelfMap.Store(ip, false)
|
|
return false
|
|
}
|
|
|
|
// needRFC1918Listeners reports whether ctrld need to spawn listener for RFC 1918 addresses.
|
|
// This is helpful for non-desktop platforms to receive queries from LAN clients.
|
|
func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool {
|
|
return lc.IP == "127.0.0.1" && lc.Port == 53 && !ctrld.IsDesktopPlatform()
|
|
}
|
|
|
|
// ipFromARPA parses a fully-qualified reverse-lookup name and returns the IP
// address it encodes, or nil if the name is not a valid reverse name.
//
// Two forms are understood:
//
//   - "d.c.b.a.in-addr.arpa." yields the 4-byte IPv4 address a.b.c.d. The
//     part before the suffix must be a dotted-quad; IPv6 literals are rejected.
//   - "<nibbles>.ip6.arpa." yields a 16-byte IPv6 address. Each label must be
//     exactly one hex digit, least significant nibble first. Between 1 and 32
//     labels are accepted; a shorter name is treated as a prefix and the
//     remaining low-order bits are zero.
//
// The suffix match is case-sensitive and requires the trailing dot.
func ipFromARPA(arpa string) net.IP {
	if quad, ok := strings.CutSuffix(arpa, ".in-addr.arpa."); ok {
		// net.ParseIP also accepts IPv6 text; only a plain dotted-quad is a
		// valid in-addr.arpa name.
		if strings.Contains(quad, ":") {
			return nil
		}
		if ptrIP := net.ParseIP(quad).To4(); ptrIP != nil {
			return net.IP{ptrIP[3], ptrIP[2], ptrIP[1], ptrIP[0]}
		}
		return nil
	}

	nibbles, ok := strings.CutSuffix(arpa, ".ip6.arpa.")
	if !ok || nibbles == "" {
		return nil
	}
	labels := strings.Split(nibbles, ".")
	if len(labels) > net.IPv6len*2 {
		return nil
	}

	ip := make(net.IP, net.IPv6len)
	for i := range labels {
		// Labels are written least significant first, so walk them backwards
		// to fill the address from its most significant nibble.
		label := labels[len(labels)-1-i]
		if len(label) != 1 {
			// Empty labels and multi-digit values such as "ff" are not nibbles.
			return nil
		}
		n, err := strconv.ParseUint(label, 16, 8)
		if err != nil {
			return nil
		}
		if i%2 == 0 {
			ip[i/2] = byte(n) << 4 // high nibble of the byte
		} else {
			ip[i/2] |= byte(n) // low nibble of the byte
		}
	}
	return ip
}
|
|
|
|
// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN/CGNAT network.
|
|
func isPrivatePtrLookup(m *dns.Msg) bool {
|
|
if m == nil || len(m.Question) == 0 {
|
|
return false
|
|
}
|
|
q := m.Question[0]
|
|
if ip := ipFromARPA(q.Name); ip != nil {
|
|
if addr, ok := netip.AddrFromSlice(ip); ok {
|
|
return addr.IsPrivate() ||
|
|
addr.IsLoopback() ||
|
|
addr.IsLinkLocalUnicast() ||
|
|
tsaddr.CGNATRange().Contains(addr)
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// isLanHostnameQuery reports whether DNS message is an A/AAAA query with LAN hostname.
|
|
func isLanHostnameQuery(m *dns.Msg) bool {
|
|
if m == nil || len(m.Question) == 0 {
|
|
return false
|
|
}
|
|
q := m.Question[0]
|
|
switch q.Qtype {
|
|
case dns.TypeA, dns.TypeAAAA:
|
|
default:
|
|
return false
|
|
}
|
|
return isLanHostname(q.Name)
|
|
}
|
|
|
|
// isSrvLanLookup reports whether DNS message is an SRV query of a LAN hostname.
|
|
func isSrvLanLookup(m *dns.Msg) bool {
|
|
if m == nil || len(m.Question) == 0 {
|
|
return false
|
|
}
|
|
q := m.Question[0]
|
|
return q.Qtype == dns.TypeSRV && isLanHostname(q.Name)
|
|
}
|
|
|
|
// isLanHostname reports whether name looks like a LAN hostname.
//
// After dropping one trailing dot, a name qualifies when it is a single label
// (contains no dot) or ends in ".domain", ".lan" or ".local". The suffix
// comparison is case-sensitive.
func isLanHostname(name string) bool {
	host := strings.TrimSuffix(name, ".")
	if !strings.Contains(host, ".") {
		return true
	}
	for _, suffix := range [...]string{".domain", ".lan", ".local"} {
		if strings.HasSuffix(host, suffix) {
			return true
		}
	}
	return false
}
|
|
|
|
// isWanClient reports whether the input is a WAN address.
|
|
func isWanClient(na net.Addr) bool {
|
|
var ip netip.Addr
|
|
if ap, err := netip.ParseAddrPort(na.String()); err == nil {
|
|
ip = ap.Addr()
|
|
}
|
|
return !ip.IsLoopback() &&
|
|
!ip.IsPrivate() &&
|
|
!ip.IsLinkLocalUnicast() &&
|
|
!ip.IsLinkLocalMulticast() &&
|
|
!tsaddr.CGNATRange().Contains(ip)
|
|
}
|
|
|
|
// resolveInternalDomainTestQuery resolves internal test domain query, returning the answer to the caller.
|
|
func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.Msg) *dns.Msg {
|
|
ctrld.Log(ctx, mainLog.Load().Debug(), "internal domain test query")
|
|
|
|
q := m.Question[0]
|
|
answer := new(dns.Msg)
|
|
rrStr := fmt.Sprintf("%s A %s", domain, net.IPv4zero)
|
|
if q.Qtype == dns.TypeAAAA {
|
|
rrStr = fmt.Sprintf("%s AAAA %s", domain, net.IPv6zero)
|
|
}
|
|
rr, err := dns.NewRR(rrStr)
|
|
if err == nil {
|
|
answer.Answer = append(answer.Answer, rr)
|
|
}
|
|
answer.SetReply(m)
|
|
return answer
|
|
}
|
|
|
|
// FlushDNSCache flushes the system DNS caches on macOS and does nothing
// (returning nil) on every other operating system.
//
// It first sends SIGHUP to mDNSResponder via killall, then runs
// "dscacheutil -flushcache". The first command that fails aborts the
// sequence; its error and combined output are wrapped in the returned error.
func FlushDNSCache() error {
	if runtime.GOOS != "darwin" {
		return nil
	}

	// Restarting the mDNSResponder cache is what modern macOS needs.
	out, err := exec.Command("killall", "-HUP", "mDNSResponder").CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to flush mDNSResponder: %w, output: %s", err, string(out))
	}

	// Additionally flush the directory services cache.
	out, err = exec.Command("dscacheutil", "-flushcache").CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to flush dscacheutil: %w, output: %s", err, string(out))
	}

	return nil
}
|
|
|
|
// monitorNetworkChanges starts monitoring for network interface changes
|
|
func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
|
mon, err := netmon.New(func(format string, args ...any) {
|
|
// Always fetch the latest logger (and inject the prefix)
|
|
mainLog.Load().Printf("netmon: "+format, args...)
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("creating network monitor: %w", err)
|
|
}
|
|
|
|
mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
|
|
// Get map of valid interfaces
|
|
validIfaces := validInterfacesMap()
|
|
|
|
isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New)
|
|
|
|
mainLog.Load().Debug().
|
|
Interface("old_state", delta.Old).
|
|
Interface("new_state", delta.New).
|
|
Bool("is_major_change", isMajorChange).
|
|
Msg("Network change detected")
|
|
|
|
changed := false
|
|
activeInterfaceExists := false
|
|
var changeIPs []netip.Prefix
|
|
// Check each valid interface for changes
|
|
for ifaceName := range validIfaces {
|
|
oldIface, oldExists := delta.Old.Interface[ifaceName]
|
|
newIface, newExists := delta.New.Interface[ifaceName]
|
|
if !newExists {
|
|
continue
|
|
}
|
|
|
|
oldIPs := delta.Old.InterfaceIPs[ifaceName]
|
|
newIPs := delta.New.InterfaceIPs[ifaceName]
|
|
|
|
// if a valid interface did not exist in old
|
|
// check that its up and has usable IPs
|
|
if !oldExists {
|
|
// The interface is new (was not present in the old state).
|
|
usableNewIPs := filterUsableIPs(newIPs)
|
|
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
|
changed = true
|
|
changeIPs = usableNewIPs
|
|
mainLog.Load().Debug().
|
|
Str("interface", ifaceName).
|
|
Interface("new_ips", usableNewIPs).
|
|
Msg("Interface newly appeared (was not present in old state)")
|
|
break
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Filter new IPs to only those that are usable.
|
|
usableNewIPs := filterUsableIPs(newIPs)
|
|
|
|
// Check if interface is up and has usable IPs.
|
|
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
|
activeInterfaceExists = true
|
|
}
|
|
|
|
// Compare interface states and IPs (interfaceIPsEqual will itself filter the IPs).
|
|
if !interfaceStatesEqual(&oldIface, &newIface) || !interfaceIPsEqual(oldIPs, newIPs) {
|
|
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
|
changed = true
|
|
changeIPs = usableNewIPs
|
|
mainLog.Load().Debug().
|
|
Str("interface", ifaceName).
|
|
Interface("old_ips", oldIPs).
|
|
Interface("new_ips", usableNewIPs).
|
|
Msg("Interface state or IPs changed")
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// if the default route changed, set changed to true
|
|
if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface {
|
|
changed = true
|
|
mainLog.Load().Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface)
|
|
}
|
|
|
|
if !changed {
|
|
mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected")
|
|
// check if the default IPs are still on an interface that is up
|
|
ValidateDefaultLocalIPsFromDelta(delta.New)
|
|
return
|
|
}
|
|
|
|
if !activeInterfaceExists {
|
|
mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization")
|
|
return
|
|
}
|
|
|
|
// Get IPs from default route interface in new state
|
|
selfIP := defaultRouteIP()
|
|
|
|
// Ensure that selfIP is an IPv4 address.
|
|
// If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it
|
|
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil {
|
|
mainLog.Load().Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP)
|
|
selfIP = ""
|
|
}
|
|
var ipv6 string
|
|
|
|
if delta.New.DefaultRouteInterface != "" {
|
|
mainLog.Load().Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface])
|
|
for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] {
|
|
ipAddr, _ := netip.ParsePrefix(ip.String())
|
|
addr := ipAddr.Addr()
|
|
if selfIP == "" && addr.Is4() {
|
|
mainLog.Load().Debug().Msgf("checking IP: %s", addr.String())
|
|
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
|
selfIP = addr.String()
|
|
}
|
|
}
|
|
if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
|
ipv6 = addr.String()
|
|
}
|
|
}
|
|
} else {
|
|
// If no default route interface is set yet, use the changed IPs
|
|
mainLog.Load().Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs)
|
|
for _, ip := range changeIPs {
|
|
ipAddr, _ := netip.ParsePrefix(ip.String())
|
|
addr := ipAddr.Addr()
|
|
if selfIP == "" && addr.Is4() {
|
|
mainLog.Load().Debug().Msgf("checking IP: %s", addr.String())
|
|
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
|
selfIP = addr.String()
|
|
}
|
|
}
|
|
if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
|
ipv6 = addr.String()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Only set the IPv4 default if selfIP is a valid IPv4 address.
|
|
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil {
|
|
ctrld.SetDefaultLocalIPv4(ip)
|
|
if !isMobile() && p.ciTable != nil {
|
|
p.ciTable.SetSelfIP(selfIP)
|
|
}
|
|
}
|
|
if ip := net.ParseIP(ipv6); ip != nil {
|
|
ctrld.SetDefaultLocalIPv6(ip)
|
|
}
|
|
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
|
|
|
// we only trigger recovery flow for network changes on non router devices
|
|
if router.Name() == "" {
|
|
p.handleRecovery(RecoveryReasonNetworkChange)
|
|
}
|
|
})
|
|
|
|
mon.Start()
|
|
mainLog.Load().Debug().Msg("Network monitor started")
|
|
return nil
|
|
}
|
|
|
|
// interfaceStatesEqual compares two interface states
|
|
func interfaceStatesEqual(a, b *netmon.Interface) bool {
|
|
if a == nil || b == nil {
|
|
return a == b
|
|
}
|
|
return a.IsUp() == b.IsUp()
|
|
}
|
|
|
|
// filterUsableIPs is a helper that returns only "usable" IP prefixes,
|
|
// filtering out link-local, loopback, multicast, unspecified, broadcast, or CGNAT addresses.
|
|
func filterUsableIPs(prefixes []netip.Prefix) []netip.Prefix {
|
|
var usable []netip.Prefix
|
|
for _, p := range prefixes {
|
|
addr := p.Addr()
|
|
if addr.IsLinkLocalUnicast() ||
|
|
addr.IsLoopback() ||
|
|
addr.IsMulticast() ||
|
|
addr.IsUnspecified() ||
|
|
addr.IsLinkLocalMulticast() ||
|
|
(addr.Is4() && addr.String() == "255.255.255.255") ||
|
|
tsaddr.CGNATRange().Contains(addr) {
|
|
continue
|
|
}
|
|
usable = append(usable, p)
|
|
}
|
|
return usable
|
|
}
|
|
|
|
// Modified interfaceIPsEqual compares only the usable (non-link local, non-loopback, etc.) IP addresses.
|
|
func interfaceIPsEqual(a, b []netip.Prefix) bool {
|
|
aUsable := filterUsableIPs(a)
|
|
bUsable := filterUsableIPs(b)
|
|
if len(aUsable) != len(bUsable) {
|
|
return false
|
|
}
|
|
|
|
aMap := make(map[string]bool)
|
|
for _, ip := range aUsable {
|
|
aMap[ip.String()] = true
|
|
}
|
|
for _, ip := range bUsable {
|
|
if !aMap[ip.String()] {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// checkUpstreamOnce sends a test query to the specified upstream.
|
|
// Returns nil if the upstream responds successfully.
|
|
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
|
|
mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream)
|
|
|
|
resolver, err := ctrld.NewResolver(uc)
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
|
|
return err
|
|
}
|
|
|
|
msg := new(dns.Msg)
|
|
msg.SetQuestion(".", dns.TypeNS)
|
|
|
|
timeout := 1000 * time.Millisecond
|
|
if uc.Timeout > 0 {
|
|
timeout = time.Millisecond * time.Duration(uc.Timeout)
|
|
}
|
|
mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
|
|
uc.ReBootstrap()
|
|
mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)
|
|
|
|
start := time.Now()
|
|
_, err = resolver.Resolve(ctx, msg)
|
|
duration := time.Since(start)
|
|
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration)
|
|
} else {
|
|
mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// handleRecovery performs a unified recovery by removing DNS settings,
|
|
// canceling existing recovery checks for network changes, but coalescing duplicate
|
|
// upstream failure recoveries, waiting for recovery to complete (using a cancellable context without timeout),
|
|
// and then re-applying the DNS settings.
|
|
func (p *prog) handleRecovery(reason RecoveryReason) {
|
|
mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings")
|
|
|
|
// For network changes, cancel any existing recovery check because the network state has changed.
|
|
if reason == RecoveryReasonNetworkChange {
|
|
p.recoveryCancelMu.Lock()
|
|
if p.recoveryCancel != nil {
|
|
mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)")
|
|
p.recoveryCancel()
|
|
p.recoveryCancel = nil
|
|
}
|
|
p.recoveryCancelMu.Unlock()
|
|
} else {
|
|
// For upstream failures, if a recovery is already in progress, do nothing new.
|
|
p.recoveryCancelMu.Lock()
|
|
if p.recoveryCancel != nil {
|
|
mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
|
|
p.recoveryCancelMu.Unlock()
|
|
return
|
|
}
|
|
p.recoveryCancelMu.Unlock()
|
|
}
|
|
|
|
// Create a new recovery context without a fixed timeout.
|
|
p.recoveryCancelMu.Lock()
|
|
recoveryCtx, cancel := context.WithCancel(context.Background())
|
|
p.recoveryCancel = cancel
|
|
p.recoveryCancelMu.Unlock()
|
|
|
|
// Immediately remove our DNS settings from the interface.
|
|
// set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface
|
|
p.recoveryRunning.Store(true)
|
|
// we do not want to restore any static DNS settings
|
|
// we must try to get the DHCP values, any static DNS settings
|
|
// will be appended to nameservers from the saved interface values
|
|
p.resetDNS(false, false)
|
|
|
|
// For an OS failure, reinitialize OS resolver nameservers immediately.
|
|
if reason == RecoveryReasonOSFailure {
|
|
mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers")
|
|
ns := ctrld.InitializeOsResolver(true)
|
|
if len(ns) == 0 {
|
|
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
|
|
} else {
|
|
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
|
}
|
|
}
|
|
|
|
// Build upstream map based on the recovery reason.
|
|
upstreams := p.buildRecoveryUpstreams(reason)
|
|
|
|
// Wait indefinitely until one of the upstreams recovers.
|
|
recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed")
|
|
p.recoveryCancelMu.Lock()
|
|
p.recoveryCancel = nil
|
|
p.recoveryCancelMu.Unlock()
|
|
return
|
|
}
|
|
mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered)
|
|
|
|
// reset the upstream failure count and down state
|
|
p.um.reset(recovered)
|
|
|
|
// For network changes we also reinitialize the OS resolver.
|
|
if reason == RecoveryReasonNetworkChange {
|
|
ns := ctrld.InitializeOsResolver(true)
|
|
if len(ns) == 0 {
|
|
mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values")
|
|
} else {
|
|
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
|
}
|
|
}
|
|
|
|
// Apply our DNS settings back and log the interface state.
|
|
p.setDNS()
|
|
p.logInterfacesState()
|
|
|
|
// allow watchdogs to put the listener back on the interface if its changed for any reason
|
|
p.recoveryRunning.Store(false)
|
|
|
|
// Clear the recovery cancellation for a clean slate.
|
|
p.recoveryCancelMu.Lock()
|
|
p.recoveryCancel = nil
|
|
p.recoveryCancelMu.Unlock()
|
|
}
|
|
|
|
// waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers.
|
|
// It returns the name of the recovered upstream or an error if the check times out.
|
|
func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string]*ctrld.UpstreamConfig) (string, error) {
|
|
recoveredCh := make(chan string, 1)
|
|
var wg sync.WaitGroup
|
|
|
|
mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams))
|
|
|
|
for name, uc := range upstreams {
|
|
wg.Add(1)
|
|
go func(name string, uc *ctrld.UpstreamConfig) {
|
|
defer wg.Done()
|
|
mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name)
|
|
attempts := 0
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name)
|
|
return
|
|
default:
|
|
attempts++
|
|
// checkUpstreamOnce will reset any failure counters on success.
|
|
if err := p.checkUpstreamOnce(name, uc); err == nil {
|
|
mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name)
|
|
select {
|
|
case recoveredCh <- name:
|
|
mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name)
|
|
default:
|
|
mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered")
|
|
}
|
|
return
|
|
}
|
|
mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
|
|
time.Sleep(checkUpstreamBackoffSleep)
|
|
|
|
// if this is the upstreamOS and it's the 3rd attempt (or multiple of 3),
|
|
// we should try to reinit the OS resolver to ensure we can recover
|
|
if name == upstreamOS && attempts%3 == 0 {
|
|
mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts)
|
|
ns := ctrld.InitializeOsResolver(true)
|
|
if len(ns) == 0 {
|
|
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
|
|
} else {
|
|
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}(name, uc)
|
|
}
|
|
|
|
var recovered string
|
|
select {
|
|
case recovered = <-recoveredCh:
|
|
case <-ctx.Done():
|
|
return "", ctx.Err()
|
|
}
|
|
wg.Wait()
|
|
return recovered, nil
|
|
}
|
|
|
|
// buildRecoveryUpstreams constructs the map of upstream configurations to test.
|
|
// For OS failures we supply the manual OS resolver upstream configuration.
|
|
// For network change or regular failure we use the upstreams defined in p.cfg (ignoring OS).
|
|
func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.UpstreamConfig {
|
|
upstreams := make(map[string]*ctrld.UpstreamConfig)
|
|
switch reason {
|
|
case RecoveryReasonOSFailure:
|
|
upstreams[upstreamOS] = osUpstreamConfig
|
|
case RecoveryReasonNetworkChange, RecoveryReasonRegularFailure:
|
|
// Use all configured upstreams except any OS type.
|
|
for k, uc := range p.cfg.Upstream {
|
|
if uc.Type != ctrld.ResolverTypeOS {
|
|
upstreams[upstreamPrefix+k] = uc
|
|
}
|
|
}
|
|
}
|
|
return upstreams
|
|
}
|
|
|
|
// ValidateDefaultLocalIPsFromDelta checks if the default local IPv4 and IPv6 stored
|
|
// are still present in the new network state (provided by delta.New).
|
|
// If a stored default IP is no longer active, it resets that default (sets it to nil)
|
|
// so that it won't be used in subsequent custom dialer contexts.
|
|
func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) {
|
|
currentIPv4 := ctrld.GetDefaultLocalIPv4()
|
|
currentIPv6 := ctrld.GetDefaultLocalIPv6()
|
|
|
|
// Build a map of active IP addresses from the new state.
|
|
activeIPs := make(map[string]bool)
|
|
for _, prefixes := range newState.InterfaceIPs {
|
|
for _, prefix := range prefixes {
|
|
activeIPs[prefix.Addr().String()] = true
|
|
}
|
|
}
|
|
|
|
// Check if the default IPv4 is still active.
|
|
if currentIPv4 != nil && !activeIPs[currentIPv4.String()] {
|
|
mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", currentIPv4)
|
|
ctrld.SetDefaultLocalIPv4(nil)
|
|
}
|
|
|
|
// Check if the default IPv6 is still active.
|
|
if currentIPv6 != nil && !activeIPs[currentIPv6.String()] {
|
|
mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. Resetting.", currentIPv6)
|
|
ctrld.SetDefaultLocalIPv6(nil)
|
|
}
|
|
}
|