Files
ctrld/cmd/cli/prog.go
Cuong Manh Le c6365e6b74 cmd/cli: handle stop signal from service manager
So using "ctrld stop" or service manager to stop ctrld will end up with
the same result, stopped ctrld with a working DNS, and deactivation pin
code will always have effects if set.
2025-03-26 23:18:36 +07:00

1535 lines
44 KiB
Go

package cli
import (
"bytes"
"context"
"errors"
"fmt"
"io/fs"
"math/rand"
"net"
"net/netip"
"net/url"
"os"
"os/exec"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/Masterminds/semver/v3"
"github.com/kardianos/service"
"github.com/rs/zerolog"
"github.com/spf13/viper"
"golang.org/x/sync/singleflight"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
"github.com/Control-D-Inc/ctrld/internal/controld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
"github.com/Control-D-Inc/ctrld/internal/router"
)
const (
	// defaultSemaphoreCap is the capacity of the request semaphore used by
	// (*prog).run before Service.MaxConcurrentRequests (if set) overrides it.
	defaultSemaphoreCap = 256
	// ctrldLogUnixSock is the unix socket file name for startup logs.
	// NOTE(review): presumably the socket served by runLogServer — confirm against callers.
	ctrldLogUnixSock = "ctrld_start.sock"
	// ctrldControlUnixSock is the control unix socket file name returned by
	// ControlSocketName on non-mobile platforms.
	ctrldControlUnixSock = "ctrld_control.sock"
	// iOS unix socket name max length is 11.
	ctrldControlUnixSockMobile = "cd.sock"
	// upstreamPrefix is prepended to an upstream's config key to build its
	// name, e.g. in (*prog).setupUpstream.
	upstreamPrefix = "upstream."
	// upstreamOS and upstreamOSLocal evaluate to "upstream.os" and
	// "upstream.os.local" respectively.
	upstreamOS      = upstreamPrefix + "os"
	upstreamOSLocal = upstreamOS + ".local"
	// dnsWatchdogDefaultInterval is the DNS watchdog period used when the
	// config does not provide a positive interval.
	dnsWatchdogDefaultInterval = 20 * time.Second
	// ctrldServiceName is the name the service is registered under (see svcConfig).
	ctrldServiceName = "ctrld"
)
// RecoveryReason provides context for why we are waiting for recovery.
// Recovery involves removing the listener IP from the interface and
// waiting for the upstreams to work before returning.
type RecoveryReason int

// Recovery reasons. The values are assigned sequentially starting at 0.
// NOTE(review): the per-value meanings below are inferred from the names only — confirm against the recovery code.
const (
	// RecoveryReasonNetworkChange: presumably recovery triggered by a network change.
	RecoveryReasonNetworkChange RecoveryReason = iota
	// RecoveryReasonRegularFailure: presumably recovery triggered by a regular upstream failure.
	RecoveryReasonRegularFailure
	// RecoveryReasonOSFailure: presumably recovery triggered by an OS resolver failure.
	RecoveryReasonOSFailure
)
// ControlSocketName returns name for control unix socket.
func ControlSocketName() string {
if isMobile() {
return ctrldControlUnixSockMobile
} else {
return ctrldControlUnixSock
}
}
// logf writes a formatted message at debug level using the current main logger.
var logf = func(format string, args ...any) {
	mainLog.Load().Debug().Msgf(format, args...)
}

// svcConfig is the service definition (name, display name, description and
// an initially empty option set) used for the ctrld service.
var svcConfig = &service.Config{
	Name:        ctrldServiceName,
	DisplayName: "Control-D Helper Service",
	Description: "A highly configurable, multi-protocol DNS forwarding proxy",
	Option:      service.KeyValue{},
}

// useSystemdResolved defaults to false.
// NOTE(review): presumably toggled by platform-specific code when systemd-resolved is in use — confirm.
var useSystemdResolved = false
// prog holds the runtime state of a running ctrld instance: lifecycle
// channels, the active configuration, and the helpers used while serving DNS.
type prog struct {
	// mu is held around accesses to cfg and rc in runWait, apiConfigReload,
	// allocateIP and deAllocateIP. Other methods read cfg without it.
	mu sync.Mutex
	// waitCh blocks run until the caller signals that it may proceed.
	waitCh chan struct{}
	// stopCh is closed by Stop to terminate the long-running loops.
	stopCh chan struct{}
	// pinCodeValidCh is received from in Stop; it is only filled after the
	// control server validated the deactivation pin.
	pinCodeValidCh chan struct{}
	// reloadCh triggers a config reload in runWait.
	reloadCh chan struct{} // For Windows.
	// reloadDoneCh gets a non-blocking notification after a successful reload.
	reloadDoneCh chan struct{}
	// apiReloadCh carries a config from apiConfigReload to runWait; a nil
	// value makes runWait re-read the config from file.
	apiReloadCh chan *ctrld.Config
	// apiForceReloadCh triggers a forced API config check in apiConfigReload.
	apiForceReloadCh    chan struct{}
	apiForceReloadGroup singleflight.Group
	logConn             net.Conn
	// cs is the control server; run starts it on first start when non-nil.
	cs *controlServer
	// csSetDnsDone is created by run (when cs is set) and is signaled then
	// closed by postRun once setDNS returned; csSetDnsOk is set by setDNS.
	csSetDnsDone chan struct{}
	csSetDnsOk   bool
	// dnsWg tracks the DNS watcher goroutines spawned by setDNS;
	// dnsWatcherStopCh is closed exactly once by stopDnsWatchers.
	dnsWg                sync.WaitGroup
	dnsWatcherClosedOnce sync.Once
	dnsWatcherStopCh     chan struct{}

	rc  *controld.ResolverConfig
	cfg *ctrld.Config
	// localUpstreams and ptrNameservers are computed by setupUpstream.
	localUpstreams []string
	ptrNameservers []string
	appCallback    *AppCallback
	// cache and cacheFlushDomainsMap are set by run when caching is enabled.
	cache                     dnscache.Cacher
	cacheFlushDomainsMap      map[string]struct{}
	sema                      semaphore
	ciTable                   *clientinfo.Table
	um                        *upstreamMonitor
	router                    router.Router
	ptrLoopGuard              *loopGuard
	lanLoopGuard              *loopGuard
	metricsQueryStats         atomic.Bool
	queryFromSelfMap          sync.Map
	initInternalLogWriterOnce sync.Once
	internalLogWriter         *logWriter
	internalWarnLogWriter     *logWriter
	internalLogSent           time.Time
	// runningIface and requiredMultiNICsConfig are set by preRun.
	runningIface            string
	requiredMultiNICsConfig bool
	// adDomain is set by run; runningOnDomainController by postRun (Windows).
	adDomain                  string
	runningOnDomainController bool

	selfUninstallMu   sync.Mutex
	refusedQueryCount int
	// canSelfUninstall is set to true by setupUpstream when the only
	// configured upstream is a Control D upstream.
	canSelfUninstall      atomic.Bool
	checkingSelfUninstall bool

	loopMu sync.Mutex
	loop   map[string]bool

	recoveryCancelMu sync.Mutex
	recoveryCancel   context.CancelFunc
	// recoveryRunning is checked by dnsWatchdog, which exits while it is true.
	recoveryRunning atomic.Bool

	// started is buffered with one slot per listener; run receives from it
	// once per listener before running the onStarted hooks.
	started chan struct{}
	// onStartedDone is closed by run after the onStarted hooks have run.
	onStartedDone chan struct{}
	// onStarted hooks run in run on first start; onStopped hooks run in Stop.
	onStarted []func()
	onStopped []func()
}
// Start launches runWait in its own goroutine and returns immediately.
// It never reports an error; the service argument is unused.
func (p *prog) Start(s service.Service) error {
	go p.runWait()
	return nil
}
// runWait runs ctrld components, waiting for signal to reload.
//
// Each loop iteration starts p.run in a goroutine, then blocks until one of:
// an OS reload signal, a message on p.reloadCh, a config pushed through
// p.apiReloadCh, or p.stopCh being closed (which ends runWait). On reload the
// new config is loaded and validated, the old run is stopped, and the new
// config is copied into p.cfg before the next iteration starts p.run again.
func (p *prog) runWait() {
	p.mu.Lock()
	p.cfg = &cfg
	p.mu.Unlock()
	reloadSigCh := make(chan os.Signal, 1)
	notifyReloadSigCh(reloadSigCh)

	// reload is false only for the very first run; the run goroutine flips
	// it once p.run has returned.
	reload := false
	logger := mainLog.Load()
	for {
		// reloadCh tells the current run to terminate; done is closed when
		// that run has fully returned.
		reloadCh := make(chan struct{})
		done := make(chan struct{})
		go func() {
			defer close(done)
			p.run(reload, reloadCh)
			reload = true
		}()

		// newCfg stays nil unless the API pushed a ready-made config, in
		// which case the config file is not re-read below.
		var newCfg *ctrld.Config
		select {
		case sig := <-reloadSigCh:
			logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
		case <-p.reloadCh:
			logger.Notice().Msg("reloading...")
		case apiCfg := <-p.apiReloadCh:
			newCfg = apiCfg
		case <-p.stopCh:
			// Stopping: signal the current run and return without waiting for it.
			close(reloadCh)
			return
		}

		// waitOldRunDone stops the current run and waits until it returned.
		waitOldRunDone := func() {
			close(reloadCh)
			<-done
		}

		if newCfg == nil {
			// Re-read the config from file, using a fresh viper instance
			// that shadows the package-level v from here on.
			newCfg = &ctrld.Config{}
			confFile := v.ConfigFileUsed()
			v := viper.NewWithOptions(viper.KeyDelimiter("::"))
			ctrld.InitConfig(v, "ctrld")
			if configPath != "" {
				confFile = configPath
			}
			v.SetConfigFile(confFile)
			if err := v.ReadInConfig(); err != nil {
				logger.Err(err).Msg("could not read new config")
				waitOldRunDone()
				continue
			}
			if err := v.Unmarshal(&newCfg); err != nil {
				logger.Err(err).Msg("could not unmarshal new config")
				waitOldRunDone()
				continue
			}
			if cdUID != "" {
				if rc, err := processCDFlags(newCfg); err != nil {
					logger.Err(err).Msg("could not fetch ControlD config")
					waitOldRunDone()
					continue
				} else {
					p.mu.Lock()
					p.rc = rc
					p.mu.Unlock()
				}
			}
		}

		waitOldRunDone()

		// Listeners present in both configs inherit the current IP/port when
		// the new config leaves them unset.
		p.mu.Lock()
		curListener := p.cfg.Listener
		p.mu.Unlock()

		for n, lc := range newCfg.Listener {
			curLc := curListener[n]
			if curLc == nil {
				continue
			}
			if lc.IP == "" {
				lc.IP = curLc.IP
			}
			if lc.Port == 0 {
				lc.Port = curLc.Port
			}
		}
		if err := validateConfig(newCfg); err != nil {
			logger.Err(err).Msg("invalid config")
			continue
		}

		addExtraSplitDnsRule(newCfg)
		// A failure to persist the config is logged but does not abort the reload.
		if err := writeConfigFile(newCfg); err != nil {
			logger.Err(err).Msg("could not write new config")
		}

		// This needs to be done here, otherwise, the DNS handler may observe an invalid
		// upstream config because its initialization function has not been called yet.
		mainLog.Load().Debug().Msg("setup upstream with new config")
		p.setupUpstream(newCfg)

		// Overwrite the config in place so existing holders of p.cfg see the new values.
		p.mu.Lock()
		*p.cfg = *newCfg
		p.mu.Unlock()

		logger.Notice().Msg("reloading config successfully")

		// Non-blocking notification: dropped when nobody is waiting.
		select {
		case p.reloadDoneCh <- struct{}{}:
		default:
		}
	}
}
// preRun resolves the interface ctrld runs on. When the package-level iface
// is "auto" it is replaced by the default interface name and the multi-NICs
// requirement is evaluated; the resulting name is stored in p.runningIface.
func (p *prog) preRun() {
	autoDetect := iface == "auto"
	if autoDetect {
		iface = defaultIfaceName()
		p.requiredMultiNICsConfig = requiredMultiNICsConfig()
	}
	p.runningIface = iface
}
// postRun performs the DNS setup that only applies when ctrld is not running
// interactively: it resets DNS, initializes the OS resolver, points the
// system DNS at ctrld, signals and closes p.csSetDnsDone, and logs the state
// of the interfaces. On Windows it first records whether the host is a
// domain controller.
func (p *prog) postRun() {
	if service.Interactive() {
		return
	}

	if runtime.GOOS == "windows" {
		isDC, roleInt := isRunningOnDomainController()
		p.runningOnDomainController = isDC
		mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt)
	}

	p.resetDNS(false, false)

	nameservers := ctrld.InitializeOsResolver(false)
	mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", nameservers)

	p.setDNS()

	// Tell whoever waits on csSetDnsDone that the DNS setup attempt is over.
	p.csSetDnsDone <- struct{}{}
	close(p.csSetDnsDone)

	p.logInterfacesState()
}
// apiConfigReload calls the API to check for the latest config update, then
// reloads ctrld if necessary.
//
// It is a no-op when no Control D UID is set. Otherwise it loops until
// p.stopCh is closed, checking the API on every tick of the refetch timer
// (Service.RefetchTime seconds, 3600 being passed as the default) or
// whenever p.apiForceReloadCh fires. A check may trigger the self-uninstall
// and self-upgrade checks, refresh the deactivation pin, and push a reload
// through p.apiReloadCh: nil when only the exclude list changed, or the
// validated custom config when it changed or the check was forced.
func (p *prog) apiConfigReload() {
	if cdUID == "" {
		return
	}

	ticker := time.NewTicker(timeDurationOrDefault(p.cfg.Service.RefetchTime, 3600) * time.Second)
	defer ticker.Stop()

	logger := mainLog.Load().With().Str("mode", "api-reload").Logger()
	logger.Debug().Msg("starting custom config reload timer")
	lastUpdated := time.Now().Unix()
	curVerStr := curVersion()
	curVer, err := semver.NewVersion(curVerStr)
	// Only versions that parse and carry no pre-release tag are self-upgraded.
	isStable := curVer != nil && curVer.Prerelease() == ""
	if err != nil || !isStable {
		l := mainLog.Load().Warn()
		if err != nil {
			l = l.Err(err)
		}
		l.Msgf("current version is not stable, skipping self-upgrade: %s", curVerStr)
	}

	doReloadApiConfig := func(forced bool, logger zerolog.Logger) {
		resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
		selfUninstallCheck(err, p, logger)
		if err != nil {
			logger.Warn().Err(err).Msg("could not fetch resolver config")
			return
		}

		// Performing self-upgrade check for production version.
		if isStable {
			selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger)
		}

		if resolverConfig.DeactivationPin != nil {
			newDeactivationPin := *resolverConfig.DeactivationPin
			curDeactivationPin := cdDeactivationPin.Load()
			switch {
			case curDeactivationPin == newDeactivationPin:
				// Pin unchanged, nothing to report.
			case curDeactivationPin == defaultDeactivationPin:
				// No pin was stored before: this is the first save.
				logger.Debug().Msg("saving deactivation pin")
			default:
				logger.Debug().Msg("update deactivation pin")
			}
			cdDeactivationPin.Store(newDeactivationPin)
		} else {
			cdDeactivationPin.Store(defaultDeactivationPin)
		}

		// Swap in the new resolver config, keeping the previous one to
		// detect exclude list changes.
		p.mu.Lock()
		rc := p.rc
		p.rc = resolverConfig
		p.mu.Unlock()

		noCustomConfig := resolverConfig.Ctrld.CustomConfig == ""
		noExcludeListChanged := true
		if rc != nil {
			slices.Sort(rc.Exclude)
			slices.Sort(resolverConfig.Exclude)
			noExcludeListChanged = slices.Equal(rc.Exclude, resolverConfig.Exclude)
		}

		if noCustomConfig && noExcludeListChanged {
			return
		}

		if noCustomConfig && !noExcludeListChanged {
			logger.Debug().Msg("exclude list changes detected, reloading...")
			// A nil config makes runWait re-read the config from file.
			p.apiReloadCh <- nil
			return
		}

		if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated || forced {
			lastUpdated = time.Now().Unix()
			cfg := &ctrld.Config{}
			var cfgErr error
			if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil {
				setListenerDefaultValue(cfg)
				setNetworkDefaultValue(cfg)
				cfgErr = validateConfig(cfg)
			}
			if cfgErr != nil {
				// Report the validation error itself; the fetch error is
				// always nil at this point.
				logger.Warn().Err(cfgErr).Msg("skipping invalid custom config")
				if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil {
					logger.Error().Err(err).Msg("could not mark custom last update failed")
				}
				return
			}

			logger.Debug().Msg("custom config changes detected, reloading...")
			p.apiReloadCh <- cfg
		} else {
			logger.Debug().Msg("custom config does not change")
		}
	}

	for {
		select {
		case <-p.apiForceReloadCh:
			doReloadApiConfig(true, logger.With().Bool("forced", true).Logger())
		case <-ticker.C:
			doReloadApiConfig(false, logger)
		case <-p.stopCh:
			return
		}
	}
}
// setupUpstream initializes every upstream of cfg: it runs Init, resolves
// bootstrap IPs when none is configured, installs the root cert pool and
// starts a background ping. Along the way it collects the upstreams that can
// act as local upstreams and the endpoints of discoverable upstreams, and
// stores both lists on p. Self-uninstallation is enabled when cfg has
// exactly one upstream and it is a Control D upstream.
func (p *prog) setupUpstream(cfg *ctrld.Config) {
	upstreamCount := len(cfg.Upstream)
	locals := make([]string, 0, upstreamCount)
	ptrServers := make([]string, 0, upstreamCount)
	hasControlD := false

	for name, uc := range cfg.Upstream {
		// The type must be captured before Init runs.
		isStamp := uc.Type == ctrld.ResolverTypeSDNS
		uc.Init()
		if isStamp {
			mainLog.Load().Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type)
		}
		if !hasControlD {
			hasControlD = uc.IsControlD()
		}
		switch uc.BootstrapIP {
		case "":
			uc.SetupBootstrapIP()
			mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", name, uc.BootstrapIPs())
		default:
			mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", name)
		}
		uc.SetCertPool(rootCertPool)
		go uc.Ping()

		if canBeLocalUpstream(uc.Domain) {
			locals = append(locals, upstreamPrefix+name)
		}
		if uc.IsDiscoverable() {
			ptrServers = append(ptrServers, uc.Endpoint)
		}
	}

	// A single Control D upstream is the only setup allowed to self-uninstall.
	if upstreamCount == 1 && hasControlD {
		p.canSelfUninstall.Store(true)
	}
	p.localUpstreams = locals
	p.ptrNameservers = ptrServers
}
// run runs the ctrld main components.
//
// The reload boolean indicates whether the function runs because ctrld is
// starting for the first time (false) or because it received a reload signal
// (true). Platform-specific setup, the control server, the listeners, the
// client info discovery and the API reload loop are only started when reload
// is false.
//
// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded,
// so all listeners could be terminated and re-spawned again.
//
// run blocks until all goroutines tracked by its WaitGroup have returned.
func (p *prog) run(reload bool, reloadCh chan struct{}) {
	// Wait the caller to signal that we can do our logic.
	<-p.waitCh
	if !reload {
		p.preRun()
	}
	numListeners := len(p.cfg.Listener)
	if !reload {
		// One buffered slot per listener; see the receive loop further down.
		p.started = make(chan struct{}, numListeners)
		if p.cs != nil {
			p.csSetDnsDone = make(chan struct{}, 1)
			p.registerControlServerHandler()
			if err := p.cs.start(); err != nil {
				mainLog.Load().Warn().Err(err).Msg("could not start control server")
			}
			mainLog.Load().Debug().Msgf("control server started: %s", p.cs.addr)
		}
	}
	// State below is re-created on every run, including reloads.
	p.onStartedDone = make(chan struct{})
	p.loop = make(map[string]bool)
	p.lanLoopGuard = newLoopGuard()
	p.ptrLoopGuard = newLoopGuard()
	p.cacheFlushDomainsMap = nil
	p.metricsQueryStats.Store(p.cfg.Service.MetricsQueryStats)
	if p.cfg.Service.CacheEnable {
		cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
		if err != nil {
			mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled")
		} else {
			p.cache = cacher
			p.cacheFlushDomainsMap = make(map[string]struct{}, 256)
			for _, domain := range p.cfg.Service.CacheFlushDomains {
				p.cacheFlushDomainsMap[canonicalName(domain)] = struct{}{}
			}
		}
	}
	// The AD domain is only recorded when a local DNS server is running too.
	if domain, err := getActiveDirectoryDomain(); err == nil && domain != "" && hasLocalDnsServerRunning() {
		mainLog.Load().Debug().Msgf("active directory domain: %s", domain)
		p.adDomain = domain
	}

	// One WaitGroup slot per listener, released by the per-listener watcher
	// goroutine started in the listener loop below.
	var wg sync.WaitGroup
	wg.Add(len(p.cfg.Listener))

	// Pre-parse the configured CIDRs; invalid entries are logged and skipped.
	for _, nc := range p.cfg.Network {
		for _, cidr := range nc.Cidrs {
			_, ipNet, err := net.ParseCIDR(cidr)
			if err != nil {
				mainLog.Load().Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr")
				continue
			}
			nc.IPNets = append(nc.IPNets, ipNet)
		}
	}

	p.um = newUpstreamMonitor(p.cfg)

	if !reload {
		// Default semaphore, overridden by Service.MaxConcurrentRequests:
		// 0 disables limiting, any other value is the new capacity.
		p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
		if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
			n := *mcr
			if n == 0 {
				p.sema = &noopSemaphore{}
			} else {
				p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
			}
		}
		p.setupUpstream(p.cfg)
		p.setupClientInfoDiscover(defaultRouteIP())
	}

	// context for managing spawn goroutines.
	ctx, cancelFunc := context.WithCancel(context.Background())
	defer cancelFunc()

	// Newer versions of android and iOS denies permission which breaks connectivity.
	if !isMobile() && !reload {
		wg.Add(1)
		go func() {
			defer wg.Done()
			p.runClientInfoDiscover(ctx)
		}()
		go p.watchLinkState(ctx)
	}

	for listenerNum := range p.cfg.Listener {
		p.cfg.Listener[listenerNum].Init()
		if !reload {
			// DNS servers are only spawned on first start, not on reload.
			go func(listenerNum string) {
				listenerConfig := p.cfg.Listener[listenerNum]
				upstreamConfig := p.cfg.Upstream[listenerNum]
				if upstreamConfig == nil {
					mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
				}
				addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
				mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
				if err := p.serveDNS(ctx, listenerNum); err != nil {
					mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
				}
				mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr)
			}(listenerNum)
		}
		// Watcher: on stop, cancellation or reload, cancel the shared
		// context and release this listener's WaitGroup slot.
		go func() {
			defer func() {
				cancelFunc()
				wg.Done()
			}()
			select {
			case <-p.stopCh:
			case <-ctx.Done():
			case <-reloadCh:
			}
		}()
	}

	if !reload {
		// Wait for one signal per listener before running the onStarted hooks.
		// NOTE(review): presumably serveDNS sends on p.started once a listener is up — confirm.
		for i := 0; i < numListeners; i++ {
			<-p.started
		}
		for _, f := range p.onStarted {
			f()
		}
	}

	close(p.onStartedDone)

	wg.Add(1)
	go func() {
		defer wg.Done()
		// Check for possible DNS loop.
		p.checkDnsLoop()
		// Start check DNS loop ticker.
		p.checkDnsLoopTicker(ctx)
	}()

	wg.Add(1)
	// Prometheus exporter goroutine.
	go func() {
		defer wg.Done()
		p.runMetricsServer(ctx, reloadCh)
	}()

	if !reload {
		// Stop writing log to unix socket.
		consoleWriter.Out = os.Stdout
		p.initLogging(false)
		if p.logConn != nil {
			_ = p.logConn.Close()
		}
		go p.apiConfigReload()
		p.postRun()
	}
	wg.Wait()
}
// setupClientInfoDiscover creates the client info table for selfIP and, when
// the config names a custom DHCP lease file, registers that file (with its
// configured format) on the table.
func (p *prog) setupClientInfoDiscover(selfIP string) {
	p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers)

	leaseFile := p.cfg.Service.DHCPLeaseFile
	if leaseFile == "" {
		return
	}
	mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
	leaseFormat := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
	p.ciTable.AddLeaseFile(leaseFile, leaseFormat)
}
// runClientInfoDiscover initializes the client info table, then runs its
// refresh loop with ctx.
func (p *prog) runClientInfoDiscover(ctx context.Context) {
	table := p.ciTable
	table.Init()
	table.RefreshLoop(ctx)
}
// metricsEnabled reports whether the prometheus exporter is enabled, which is
// the case when query stats are on or a metrics listener address is set.
func (p *prog) metricsEnabled() bool {
	if p.cfg.Service.MetricsQueryStats {
		return true
	}
	return p.cfg.Service.MetricsListener != ""
}
// Stop shuts ctrld down: it stops the DNS watchers, runs the onStopped
// hooks, de-allocates the listener IPs and finally closes p.stopCh.
//
// It returns the error from de-allocating the listener IPs, in which case
// p.stopCh is left open. When a deactivation pin is set and no valid pin was
// confirmed within 100ms, the stop is treated as unauthorized: ctrld spawns
// a detached "restart" of itself and exits the process.
func (p *prog) Stop(s service.Service) error {
	p.stopDnsWatchers()
	mainLog.Load().Debug().Msg("dns watchers stopped")
	for _, f := range p.onStopped {
		f()
	}
	mainLog.Load().Debug().Msg("finish running onStopped functions")
	// Logged on every return path below (but not when os.Exit is reached).
	defer func() {
		mainLog.Load().Info().Msg("Service stopped")
	}()
	if err := p.deAllocateIP(); err != nil {
		mainLog.Load().Error().Err(err).Msg("de-allocate ip failed")
		return err
	}
	if deactivationPinSet() {
		select {
		case <-p.pinCodeValidCh:
			// Allow stopping the service, pinCodeValidCh is only filled
			// after control server did validate the pin code.
		case <-time.After(time.Millisecond * 100):
			// No valid pin code was checked, that means we are stopping
			// because of an OS signal sent directly from someone else.
			// In this case, restart the ctrld service by ourselves.
			mainLog.Load().Debug().Msgf("receiving stopping signal without valid pin code")
			mainLog.Load().Debug().Msgf("self restarting ctrld service")
			if exe, err := os.Executable(); err == nil {
				cmd := exec.Command(exe, "restart")
				cmd.SysProcAttr = sysProcAttrForDetachedChildProcess()
				if err := cmd.Start(); err != nil {
					mainLog.Load().Error().Err(err).Msg("failed to run self restart command")
				}
			} else {
				mainLog.Load().Error().Err(err).Msg("failed to self restart ctrld service")
			}
			os.Exit(deactivationPinInvalidExitCode)
		}
	}
	close(p.stopCh)
	return nil
}
// stopDnsWatchers signals all DNS watcher goroutines to terminate and waits
// for them, so they cannot interfere with later DNS changes. The stop channel
// is closed at most once, making repeated calls safe.
func (p *prog) stopDnsWatchers() {
	signalStop := func() {
		close(p.dnsWatcherStopCh)
	}
	p.dnsWatcherClosedOnce.Do(signalStop)
	p.dnsWg.Wait()
}
// allocateIP allocates ip through the package-level allocateIP helper when
// Service.AllocateIP is enabled; otherwise it does nothing and returns nil.
// The config is read while holding p.mu.
func (p *prog) allocateIP(ip string) error {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.cfg.Service.AllocateIP {
		return allocateIP(ip)
	}
	return nil
}
// deAllocateIP de-allocates the IP of every configured listener when
// Service.AllocateIP is enabled, stopping at and returning the first error.
// It returns nil when the option is disabled. The config is read while
// holding p.mu.
func (p *prog) deAllocateIP() error {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.cfg.Service.AllocateIP {
		for _, listener := range p.cfg.Listener {
			err := deAllocateIP(listener.IP)
			if err != nil {
				return err
			}
		}
	}
	return nil
}
// setDNS points the system DNS at ctrld.
//
// It derives the nameserver list from the first listener of the
// package-level cfg, applies it to the running interface (and to every
// physical interface when multi-NICs config is required), then spawns the
// resolv.conf watcher and the DNS watchdog goroutines where applicable.
// Both goroutines are tracked by p.dnsWg. On return p.csSetDnsOk holds
// whether the nameserver setup stage was reached.
func (p *prog) setDNS() {
	setDnsOK := false
	defer func() {
		p.csSetDnsOk = setDnsOK
	}()

	if cfg.Listener == nil {
		return
	}

	lc := cfg.FirstListener()
	if lc == nil {
		return
	}

	ns := lc.IP
	switch {
	case lc.IsDirectDnsListener():
		// If ctrld is direct listener, use 127.0.0.1 as nameserver.
		ns = "127.0.0.1"
	case lc.Port != 53:
		// Not on port 53: fall back to 127.0.0.1, or to the router's local
		// resolver IP when it reports one.
		ns = "127.0.0.1"
		if resolver := router.LocalResolverIP(); resolver != "" {
			ns = resolver
		}
	default:
		// If we ever reach here, it means ctrld is running on lc.IP port 53,
		// so we could just use lc.IP as nameserver.
	}

	nameservers := []string{ns}
	if needRFC1918Listeners(lc) {
		nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
	}
	if needLocalIPv6Listener() {
		nameservers = append(nameservers, "::1")
	}

	slices.Sort(nameservers)

	netIfaceName := ""
	netIface := p.setDnsForRunningIface(nameservers)
	if netIface != nil {
		netIfaceName = netIface.Name
	}
	// NOTE(review): set to true regardless of the outcome of
	// setDnsForRunningIface — confirm this is intended.
	setDnsOK = true

	if p.requiredMultiNICsConfig {
		withEachPhysicalInterfaces(netIfaceName, "set DNS", func(i *net.Interface) error {
			return setDnsIgnoreUnusableInterface(i, nameservers)
		})
	}

	// resolvconf file is only useful when we have default route interface,
	// then set DNS on this interface will push change to /etc/resolv.conf file.
	if netIface != nil && shouldWatchResolvconf() {
		// MustParseAddr panics on a value that is not a valid IP address.
		servers := make([]netip.Addr, len(nameservers))
		for i := range nameservers {
			servers[i] = netip.MustParseAddr(nameservers[i])
		}
		p.dnsWg.Add(1)
		go func() {
			defer p.dnsWg.Done()
			p.watchResolvConf(netIface, servers, setResolvConf)
		}()
	}
	if p.dnsWatchdogEnabled() {
		p.dnsWg.Add(1)
		go func() {
			defer p.dnsWg.Done()
			// netIface may be nil here.
			p.dnsWatchdog(netIface, nameservers)
		}()
	}
}
// setDnsForRunningIface applies nameservers to the interface ctrld is running
// on and returns that interface.
//
// The interface lookup is attempted up to three times. After a failed lookup
// (except the last one) it tries to switch to another working interface and
// retries right away; if no other interface is found it waits one second
// before retrying. It returns nil when no running interface is configured,
// when the interface cannot be found, or when NetworkManager cannot be
// patched. A failure to set DNS is logged, but the interface is still
// returned.
func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.Interface) {
	if p.runningIface == "" {
		return nil
	}

	const (
		maxDNSRetryAttempts = 3
		retryDelay          = 1 * time.Second
	)

	logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()

	var (
		netIface *net.Interface
		err      error
	)
	for attempt := 1; ; attempt++ {
		netIface, err = netInterface(p.runningIface)
		if err == nil {
			break
		}
		if attempt >= maxDNSRetryAttempts {
			logger.Error().Err(err).Msg("could not get interface after all attempts")
			return nil
		}

		// Try to find a different working interface
		if candidate := findWorkingInterface(p.runningIface); candidate != p.runningIface {
			p.runningIface = candidate
			logger = mainLog.Load().With().Str("iface", p.runningIface).Logger()
			logger.Info().Msg("switched to new interface")
			continue
		}

		logger.Warn().Err(err).Int("attempt", attempt).Msg("could not get interface, retrying...")
		time.Sleep(retryDelay)
	}

	if nmErr := setupNetworkManager(); nmErr != nil {
		logger.Error().Err(nmErr).Msg("could not patch NetworkManager")
		return nil
	}

	logger.Debug().Msg("setting DNS for interface")
	if setErr := setDNS(netIface, nameservers); setErr != nil {
		logger.Error().Err(setErr).Msgf("could not set DNS for interface")
		return netIface
	}
	logger.Debug().Msg("setting DNS successfully")
	return netIface
}
// dnsWatchdogEnabled reports whether the DNS watchdog is enabled. It follows
// Service.DnsWatchdogEnabled when that option is set and defaults to true
// otherwise.
func (p *prog) dnsWatchdogEnabled() bool {
	enabled := p.cfg.Service.DnsWatchdogEnabled
	if enabled == nil {
		return true
	}
	return *enabled
}
// dnsWatchdogDuration returns the interval between two DNS watchdog checks:
// the configured interval when it is set and positive, otherwise
// dnsWatchdogDefaultInterval.
func (p *prog) dnsWatchdogDuration() time.Duration {
	interval := p.cfg.Service.DnsWatchdogInvterval
	if interval == nil {
		return dnsWatchdogDefaultInterval
	}
	if (*interval).Seconds() > 0 {
		return *interval
	}
	return dnsWatchdogDefaultInterval
}
// dnsWatchdog periodically checks whether the DNS settings of iface (and of
// every physical interface when p.requiredMultiNICsConfig is set) still match
// nameservers, re-applying ctrld's settings when they were changed.
//
// It returns immediately when requiredMultiNICsConfig() reports false.
// Otherwise it runs until p.dnsWatcherStopCh or p.stopCh is closed, or until
// a tick finds a recovery in progress. Before overriding an interface, any
// non-loopback static DNS servers currently configured on it are saved
// (unless a saved set already exists) so they can be restored later.
//
// iface may be nil, in which case only the physical interfaces are checked.
// nameservers is not modified.
func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
	if !requiredMultiNICsConfig() {
		return
	}

	mainLog.Load().Debug().Msg("start DNS settings watchdog")

	// Sort a private copy: the caller's slice is shared with other goroutines.
	ns := slices.Clone(nameservers)
	slices.Sort(ns)

	ticker := time.NewTicker(p.dnsWatchdogDuration())
	defer ticker.Stop()

	ifaceName := ""
	if iface != nil {
		ifaceName = iface.Name
	}

	// saveStaticDNS saves the static DNS servers of i when the interface has
	// non-loopback static servers and nothing was saved for it yet.
	saveStaticDNS := func(i *net.Interface) {
		// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
		staticDNS, err := currentStaticDNS(i)
		if err != nil {
			mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name)
			return
		}
		if len(staticDNS) == 0 {
			return
		}
		// Filter out loopback addresses.
		staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
			return net.ParseIP(s).IsLoopback()
		})
		// If we have a static config and no saved IPs already, save them.
		if len(staticDNS) > 0 && len(savedStaticNameservers(i)) == 0 {
			if err := saveCurrentStaticDNS(i); err != nil {
				mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name)
			}
		}
	}

	for {
		select {
		case <-p.dnsWatcherStopCh:
			return
		case <-p.stopCh:
			mainLog.Load().Debug().Msg("stop dns watchdog")
			return
		case <-ticker.C:
			if p.recoveryRunning.Load() {
				return
			}
			// The running interface may be unknown (nil); skip it then
			// instead of dereferencing it.
			if iface != nil && dnsChanged(iface, ns) {
				mainLog.Load().Debug().Msg("DNS settings were changed, re-applying settings")
				saveStaticDNS(iface)
				if err := setDNS(iface, ns); err != nil {
					mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
				}
			}
			if p.requiredMultiNICsConfig {
				withEachPhysicalInterfaces(ifaceName, "", func(i *net.Interface) error {
					if !dnsChanged(i, ns) {
						return nil
					}
					saveStaticDNS(i)
					if err := setDnsIgnoreUnusableInterface(i, ns); err != nil {
						mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
					} else {
						mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
					}
					return nil
				})
			}
		}
	}
}
// resetDNS resets DNS on the running interface and, when multi-NICs config is
// required, on every physical interface as well. The isStart and
// restoreStatic flags are forwarded to resetDNSForRunningIface.
func (p *prog) resetDNS(isStart bool, restoreStatic bool) {
	var runningName string
	netIface := p.resetDNSForRunningIface(isStart, restoreStatic)
	if netIface != nil {
		runningName = netIface.Name
	}

	// See corresponding comments in (*prog).setDNS function.
	if !p.requiredMultiNICsConfig {
		return
	}
	withEachPhysicalInterfaces(runningName, "reset DNS", resetDnsIgnoreUnusableInterface)
}
// resetDNSForRunningIface resets DNS on the running interface and returns the
// interface. It returns nil when no running interface is set or when it
// cannot be looked up.
//
// With isStart set (a start or restart command), the current static DNS is
// inspected first: if any server differs from the loopback listener addresses
// (127.0.0.1 and ::1), it is treated as a manual admin change and left
// untouched rather than overridden with possibly stale saved values.
// Otherwise the saved static configuration is restored when one exists and
// restoreStatic is true; in every other case the interface is reset to DHCP.
func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runningIface *net.Interface) {
	if p.runningIface == "" {
		mainLog.Load().Debug().Msg("no running interface, skipping resetDNS")
		return nil
	}

	logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
	netIface, err := netInterface(p.runningIface)
	if err != nil {
		logger.Error().Err(err).Msg("could not get interface")
		return nil
	}

	// From here on the interface is known and returned on every path.
	if nmErr := restoreNetworkManager(); nmErr != nil {
		logger.Error().Err(nmErr).Msg("could not restore NetworkManager")
		return netIface
	}

	// If starting, check the current static DNS configuration.
	if isStart {
		current, dnsErr := currentStaticDNS(netIface)
		if dnsErr != nil {
			logger.Warn().Err(dnsErr).Msg("unable to obtain current static DNS configuration; proceeding to restore saved config")
		} else {
			// Any server other than our own listener means an admin override.
			notOurListener := func(ns string) bool {
				return ns != "127.0.0.1" && ns != "::1"
			}
			if slices.ContainsFunc(current, notOurListener) {
				logger.Debug().Msgf("Detected manual DNS configuration on interface %q: %v; not overriding with saved configuration", netIface.Name, current)
				return netIface
			}
		}
	}

	// Default logic: if there is a saved static DNS configuration, restore it.
	saved := savedStaticNameservers(netIface)
	if restorable := len(saved) > 0 && restoreStatic; !restorable {
		logger.Debug().Msgf("No saved static DNS config for interface %q; resetting to DHCP", netIface.Name)
		if resetErr := resetDNS(netIface); resetErr != nil {
			logger.Error().Err(resetErr).Msgf("failed to reset DNS to DHCP on interface %q", netIface.Name)
		}
		return netIface
	}

	logger.Debug().Msgf("Restoring interface %q from saved static config: %v", netIface.Name, saved)
	if setErr := setDNS(netIface, saved); setErr != nil {
		logger.Error().Err(setErr).Msgf("failed to restore static DNS config on interface %q", netIface.Name)
	}
	return netIface
}
// logInterfacesState logs, at debug level, the addresses, nameservers and
// index of every physical interface. Static nameservers are preferred;
// currentDNS is used as a fallback when none are found. Lookup failures are
// logged as warnings and do not stop the enumeration.
func (p *prog) logInterfacesState() {
	describe := func(ifc *net.Interface) error {
		ifaceAddrs, addrErr := ifc.Addrs()
		if addrErr != nil {
			mainLog.Load().Warn().Str("interface", ifc.Name).Err(addrErr).Msg("failed to get addresses")
		}

		servers, dnsErr := currentStaticDNS(ifc)
		if dnsErr != nil {
			mainLog.Load().Warn().Str("interface", ifc.Name).Err(dnsErr).Msg("failed to get DNS")
		}
		if len(servers) == 0 {
			servers = currentDNS(ifc)
		}

		mainLog.Load().Debug().
			Any("addrs", ifaceAddrs).
			Strs("nameservers", servers).
			Int("index", ifc.Index).
			Msgf("interface state: %s", ifc.Name)
		return nil
	}
	withEachPhysicalInterfaces("", "", describe)
}
// findWorkingInterface looks for a network interface with a valid IP
// configuration and returns its name.
//
// Only interfaces that have a hardware address, are up, and are neither
// loopback nor point-to-point are considered, and they must carry at least
// one address that is not link-local, loopback or unspecified. Among those,
// the order of preference is:
//  1. the interface holding the default route (when it could be determined),
//  2. currentIface, if it is still valid,
//  3. the first working interface found.
//
// currentIface is returned unchanged when interfaces cannot be listed or
// when no working interface exists.
func findWorkingInterface(currentIface string) string {
	// Helper to check if IP is valid (not link-local)
	isValidIP := func(ip net.IP) bool {
		return ip != nil &&
			!ip.IsLinkLocalUnicast() &&
			!ip.IsLinkLocalMulticast() &&
			!ip.IsLoopback() &&
			!ip.IsUnspecified()
	}

	// Helper to check if interface has valid IP configuration
	hasValidIPConfig := func(iface *net.Interface) bool {
		if iface == nil || iface.Flags&net.FlagUp == 0 {
			return false
		}

		addrs, err := iface.Addrs()
		if err != nil {
			mainLog.Load().Debug().
				Str("interface", iface.Name).
				Err(err).
				Msg("failed to get interface addresses")
			return false
		}

		for _, addr := range addrs {
			// Check for IP network
			if ipNet, ok := addr.(*net.IPNet); ok {
				if isValidIP(ipNet.IP) {
					return true
				}
			}
		}
		return false
	}

	// Get default route interface. The error is kept in its own variable so
	// that the interface listing below cannot overwrite it.
	defaultRoute, routeErr := netmon.DefaultRoute()
	hasDefaultRoute := routeErr == nil
	if !hasDefaultRoute {
		mainLog.Load().Debug().
			Err(routeErr).
			Msg("failed to get default route")
	} else {
		mainLog.Load().Debug().
			Str("default_route_iface", defaultRoute.InterfaceName).
			Msg("found default route")
	}

	// Get all interfaces
	ifaces, err := net.Interfaces()
	if err != nil {
		mainLog.Load().Error().Err(err).Msg("failed to list network interfaces")
		return currentIface // Return current interface as fallback
	}

	var firstWorkingIface string
	var currentIfaceValid bool

	// Single pass through interfaces
	for i := range ifaces {
		iface := &ifaces[i]

		// Must be physical (has MAC address)
		if len(iface.HardwareAddr) == 0 {
			continue
		}
		// Skip interfaces that are:
		// - Loopback
		// - Not up
		// - Point-to-point (like VPN tunnels)
		if iface.Flags&net.FlagLoopback != 0 ||
			iface.Flags&net.FlagUp == 0 ||
			iface.Flags&net.FlagPointToPoint != 0 {
			continue
		}

		if !hasValidIPConfig(iface) {
			continue
		}

		// Found working physical interface
		if hasDefaultRoute && defaultRoute.InterfaceName == iface.Name {
			// Found interface with default route - use it immediately
			mainLog.Load().Info().
				Str("old_iface", currentIface).
				Str("new_iface", iface.Name).
				Msg("switching to interface with default route")
			return iface.Name
		}

		// Keep track of first working interface as fallback
		if firstWorkingIface == "" {
			firstWorkingIface = iface.Name
		}

		// Check if this is our current interface
		if iface.Name == currentIface {
			currentIfaceValid = true
		}
	}

	// Return interfaces in order of preference:
	// 1. Current interface if it's still valid
	if currentIfaceValid {
		mainLog.Load().Debug().
			Str("interface", currentIface).
			Msg("keeping current interface")
		return currentIface
	}

	// 2. First working interface found
	if firstWorkingIface != "" {
		mainLog.Load().Info().
			Str("old_iface", currentIface).
			Str("new_iface", firstWorkingIface).
			Msg("switching to first working physical interface")
		return firstWorkingIface
	}

	// 3. Fall back to current interface if nothing else works
	mainLog.Load().Warn().
		Str("current_iface", currentIface).
		Msg("no working physical interface found, keeping current")
	return currentIface
}
// randomLocalIP returns a random loopback address of the form 127.0.0.N,
// with N in the range [2, 253].
func randomLocalIP() string {
	const (
		lowest = 2
		bound  = 254 // exclusive upper bound of the last octet
	)
	lastOctet := lowest + rand.Intn(bound-lowest)
	return fmt.Sprintf("127.0.0.%d", lastOctet)
}
// randomPort returns a random port number in the range [1025, 65534].
func randomPort() int {
	const (
		lowest = 1025
		bound  = 1<<16 - 1 // exclusive upper bound
	)
	return lowest + rand.Intn(bound-lowest)
}
// runLogServer listens on the unix socket at sockPath, waits for a single
// connection and returns it. It is used by startCmd to gather logs from
// runCmd. The listener is closed before returning. Any failure is logged as
// a warning and reported as a nil connection.
func runLogServer(sockPath string) net.Conn {
	unixAddr, err := net.ResolveUnixAddr("unix", sockPath)
	if err != nil {
		mainLog.Load().Warn().Err(err).Msg("invalid log sock path")
		return nil
	}

	listener, err := net.ListenUnix("unix", unixAddr)
	if err != nil {
		mainLog.Load().Warn().Err(err).Msg("could not listen log socket")
		return nil
	}
	defer listener.Close()

	conn, err := listener.Accept()
	if err != nil {
		mainLog.Load().Warn().Err(err).Msg("could not accept connection")
		return nil
	}
	return conn
}
// errAddrInUse reports whether err wraps a *net.OpError caused by an
// "address already in use" condition, using either the syscall errno or its
// Windows sockets counterpart.
func errAddrInUse(err error) bool {
	var opErr *net.OpError
	if !errors.As(err, &opErr) {
		return false
	}
	if errors.Is(opErr.Err, syscall.EADDRINUSE) {
		return true
	}
	return errors.Is(opErr.Err, windowsEADDRINUSE)
}

// Reference errAddrInUse so it is not reported as unused.
var _ = errAddrInUse
// Windows sockets error codes that the error helpers in this file compare
// against, next to the corresponding syscall errno values.
//
// https://learn.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2
var (
	windowsECONNREFUSED = syscall.Errno(10061)
	windowsENETUNREACH  = syscall.Errno(10051)
	windowsEINVAL       = syscall.Errno(10022)
	windowsEADDRINUSE   = syscall.Errno(10048)
	windowsEHOSTUNREACH = syscall.Errno(10065)
)
// errUrlNetworkError reports whether err wraps a *url.Error whose underlying
// error is a network error according to errNetworkError.
func errUrlNetworkError(err error) bool {
	var urlErr *url.Error
	if !errors.As(err, &urlErr) {
		return false
	}
	return errNetworkError(urlErr.Err)
}
// errNetworkError reports whether err wraps a *net.OpError that is either
// temporary or caused by one of the known network errno values: connection
// refused, invalid argument, network unreachable, or their Windows sockets
// counterparts (which additionally include host unreachable).
func errNetworkError(err error) bool {
	var opErr *net.OpError
	if !errors.As(err, &opErr) {
		return false
	}
	if opErr.Temporary() {
		return true
	}

	networkErrnos := []error{
		syscall.ECONNREFUSED,
		syscall.EINVAL,
		syscall.ENETUNREACH,
		windowsENETUNREACH,
		windowsEINVAL,
		windowsECONNREFUSED,
		windowsEHOSTUNREACH,
	}
	for _, target := range networkErrnos {
		if errors.Is(opErr.Err, target) {
			return true
		}
	}
	return false
}
// errConnectionRefused reports whether err is (or wraps) a *net.OpError caused
// by a refused connection, in either its Unix or Windows errno form.
func errConnectionRefused(err error) bool {
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		cause := opErr.Err
		return errors.Is(cause, syscall.ECONNREFUSED) || errors.Is(cause, windowsECONNREFUSED)
	}
	return false
}
// ifaceFirstPrivateIP returns the string form of the first private IP address
// assigned to iface, preferring IPv4 over IPv6.
//
// It returns an empty string if iface is nil, if its addresses cannot be
// listed, or if none of them is private.
func ifaceFirstPrivateIP(iface *net.Interface) string {
	if iface == nil {
		return ""
	}
	// A failure to list addresses is treated the same as having none: the
	// loop below simply finds nothing and "" is returned.
	addrs, _ := iface.Addrs()

	var firstPrivateV6 net.IP
	for _, addr := range addrs {
		ipNet, ok := addr.(*net.IPNet)
		if !ok || !ipNet.IP.IsPrivate() {
			continue
		}
		// Any private IPv4 address wins immediately.
		if ip4 := ipNet.IP.To4(); ip4 != nil {
			return ip4.String()
		}
		// A private IPv6 address must not end the search: a private IPv4
		// address may still appear later in the list. Remember only the first
		// one as a fallback.
		if firstPrivateV6 == nil {
			firstPrivateV6 = ipNet.IP
		}
	}
	if firstPrivateV6 != nil {
		return firstPrivateV6.String()
	}
	return ""
}
// defaultRouteIP returns the private IP, as a string, associated with the
// default route, or an empty string if none can be determined. IPv4 is
// preferred over IPv6.
func defaultRouteIP() string {
	route, err := netmon.DefaultRoute()
	if err != nil {
		return ""
	}
	routeIface, err := netInterface(route.InterfaceName)
	if err != nil {
		return ""
	}

	mainLog.Load().Debug().Str("iface", routeIface.Name).Msg("checking default route interface")
	if privateIP := ifaceFirstPrivateIP(routeIface); privateIP != "" {
		mainLog.Load().Debug().Str("ip", privateIP).Msg("found ip with default route interface")
		return privateIP
	}

	// Reaching this point means the default route interface is connected
	// directly to the ISP, so look for the LAN interface sharing its MAC address.
	//
	// Several LAN interfaces may share that MAC address, so every private IP
	// is collected and the smallest one is used.
	var candidates []netip.Addr
	netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
		if i.Name == routeIface.Name || !bytes.Equal(i.HardwareAddr, routeIface.HardwareAddr) {
			return
		}
		for _, pfx := range prefixes {
			if addr := pfx.Addr(); addr.IsPrivate() {
				candidates = append(candidates, addr)
			}
		}
	})
	if len(candidates) == 0 {
		mainLog.Load().Warn().Msg("no default route IP found")
		return ""
	}

	sort.Slice(candidates, func(a, b int) bool {
		return candidates[a].Less(candidates[b])
	})
	lanIP := candidates[0].String()
	mainLog.Load().Debug().Str("ip", lanIP).Msg("found LAN interface IP")
	return lanIP
}
// canBeLocalUpstream reports whether addr can be used as a local upstream:
// it must parse as an IP address that is loopback, private, link-local
// unicast, or inside the range returned by tsaddr.CGNATRange.
func canBeLocalUpstream(addr string) bool {
	ip, err := netip.ParseAddr(addr)
	if err != nil {
		return false
	}
	switch {
	case ip.IsLoopback(), ip.IsPrivate(), ip.IsLinkLocalUnicast():
		return true
	default:
		return tsaddr.CGNATRange().Contains(ip)
	}
}
// withEachPhysicalInterfaces calls f once for every physical interface except
// the one named excludeIfaceName. The context string only describes the
// operation in log messages; when it is empty, successful runs are not logged.
func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) {
	validIfacesMap := validInterfacesMap()
	netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
		// Loopback, virtual, and down interfaces are ignored.
		if i.IsLoopback() || len(i.HardwareAddr) == 0 {
			return
		}
		// So are interfaces that are not considered valid.
		if !validInterface(i.Interface, validIfacesMap) {
			return
		}

		netIface := i.Interface
		patched, err := patchNetIfaceName(netIface)
		if err != nil {
			mainLog.Load().Debug().Err(err).Msg("failed to patch net interface name")
			return
		}
		if !patched {
			// The interface is not functional, nothing to do.
			return
		}
		// The excluded interface is matched by its patched name.
		if netIface.Name == excludeIfaceName {
			return
		}

		// TODO: investigate whether we should report this error?
		err = f(netIface)
		switch {
		case err == nil:
			if context != "" {
				mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", context, i.Name)
			}
		case !errors.Is(err, errSaveCurrentStaticDNSNotSupported):
			mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name)
		}
	})
}
// requiredMultiNICsConfig reports whether ctrld needs to set/reset DNS for
// multiple NICs, which is the case on Windows and macOS only.
func requiredMultiNICsConfig() bool {
	goos := runtime.GOOS
	return goos == "windows" || goos == "darwin"
}
// errSaveCurrentStaticDNSNotSupported is returned by saveCurrentStaticDNS on
// platforms other than Windows and macOS.
var errSaveCurrentStaticDNSNotSupported = errors.New("saving current DNS is not supported on this platform")

// saveCurrentStaticDNS stores the current static DNS settings of iface in a
// file so they can be restored later. It only works on Windows and macOS; on
// any other platform errSaveCurrentStaticDNSNotSupported is returned.
//
// Nothing is saved, and nil is returned, when iface is nil, when the interface
// has no static DNS settings (any previously saved file is then removed), or
// when every configured nameserver is a loopback address (a previously saved
// file is then left untouched).
func saveCurrentStaticDNS(iface *net.Interface) error {
	if iface == nil {
		mainLog.Load().Debug().Msg("could not save current static DNS settings for nil interface")
		return nil
	}
	if goos := runtime.GOOS; goos != "windows" && goos != "darwin" {
		return errSaveCurrentStaticDNSNotSupported
	}

	settingsFile := savedStaticDnsSettingsFilePath(iface)
	nameservers, err := currentStaticDNS(iface)
	if err != nil {
		mainLog.Load().Warn().Err(err).Msgf("could not get current static DNS settings for %q", iface.Name)
		return err
	}
	if len(nameservers) == 0 {
		mainLog.Load().Debug().Msgf("no static DNS settings for %q, removing old static DNS settings file", iface.Name)
		_ = os.Remove(settingsFile) // best effort removal of stale settings
		return nil
	}

	// Loopback addresses are never saved.
	isLoopback := func(s string) bool {
		return net.ParseIP(s).IsLoopback()
	}
	nameservers = slices.DeleteFunc(nameservers, isLoopback)
	// When only loopback addresses were configured, return without saving and
	// leave any existing file as it is.
	if len(nameservers) == 0 {
		mainLog.Load().Debug().Msgf("loopback on %q, skipping saving static DNS settings", iface.Name)
		return nil
	}

	if err := os.Remove(settingsFile); err != nil && !errors.Is(err, fs.ErrNotExist) {
		mainLog.Load().Warn().Err(err).Msgf("could not remove old static DNS settings file: %s", settingsFile)
	}
	joined := strings.Join(nameservers, ",")
	mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, joined)
	if err := os.WriteFile(settingsFile, []byte(joined), 0600); err != nil {
		mainLog.Load().Err(err).Msgf("could not save DNS settings for iface: %s", iface.Name)
		return err
	}
	mainLog.Load().Debug().Msgf("save DNS settings for interface %q successfully", iface.Name)
	return nil
}
// savedStaticDnsSettingsFilePath returns the path of the file holding the
// saved DNS settings of iface: absHomeDir applied to ".dns_" followed by the
// interface name. It returns an empty string for a nil iface.
func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
	if iface == nil {
		return ""
	}
	fileName := ".dns_" + iface.Name
	return absHomeDir(fileName)
}
// savedStaticNameservers returns the static DNS nameservers previously saved
// for iface, with loopback addresses removed. It returns nil when iface is
// nil or when the saved settings file is missing, unreadable, or empty.
//
//lint:ignore U1000 use in os_windows.go and os_darwin.go
func savedStaticNameservers(iface *net.Interface) []string {
	if iface == nil {
		mainLog.Load().Debug().Msg("could not get saved static DNS settings for nil interface")
		return nil
	}
	// A read error is handled like an empty file: there is nothing saved.
	data, _ := os.ReadFile(savedStaticDnsSettingsFilePath(iface))
	if len(data) == 0 {
		return nil
	}
	// Filter out every entry that falls in the loopback range.
	saved := strings.Split(string(data), ",")
	return slices.DeleteFunc(saved, func(ns string) bool {
		return net.ParseIP(ns).IsLoopback()
	})
}
// dnsChanged reports whether the current static DNS settings of iface differ
// from nameservers. It returns false for a nil iface.
//
// The current settings are sorted before the comparison, so the caller must
// pass nameservers already sorted.
func dnsChanged(iface *net.Interface, nameservers []string) bool {
	if iface == nil {
		return false
	}
	// A lookup error leaves current empty, which is then compared as-is.
	current, _ := currentStaticDNS(iface)
	slices.Sort(current)

	changed := !slices.Equal(current, nameservers)
	if changed {
		mainLog.Load().Debug().Msgf("interface %q current DNS settings: %v, expected: %v", iface.Name, current, nameservers)
	}
	return changed
}
// selfUninstallCheck performs a self-uninstall when uninstallErr is (or wraps)
// a *controld.ErrorResponse carrying controld.InvalidConfigCode. The DNS
// watchers of p are stopped first. Any other error is ignored.
func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
	var apiErr *controld.ErrorResponse
	if !errors.As(uninstallErr, &apiErr) {
		return
	}
	if apiErr.ErrorField.Code != controld.InvalidConfigCode {
		return
	}
	p.stopDnsWatchers()
	// Perform self-uninstall now.
	selfUninstall(p, logger)
}
// selfUpgradeCheck triggers a self-upgrade when the version target vt is
// greater than the current version cv. The upgrade runs as a child process
// executing "<exe> upgrade prod -vv", which is started but not waited for.
//
// Nothing is done when vt is empty, is not a valid version, or is not greater
// than cv.
//
// The callers must ensure cv and logger are non-nil.
func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) {
	if vt == "" {
		logger.Debug().Msg("no version target set, skipped checking self-upgrade")
		return
	}

	// Normalize the target so it always carries a single leading "v" more
	// than the bare version string.
	vts := "v" + strings.TrimPrefix(vt, "v")
	targetVer, err := semver.NewVersion(vts)
	if err != nil {
		logger.Warn().Err(err).Msgf("invalid target version, skipped self-upgrade: %s", vt)
		return
	}
	if upgradeNeeded := targetVer.GreaterThan(cv); !upgradeNeeded {
		logger.Debug().
			Str("target", vt).
			Str("current", cv.String()).
			Msgf("target version is not greater than current one, skipped self-upgrade")
		return
	}

	exePath, err := os.Executable()
	if err != nil {
		mainLog.Load().Error().Err(err).Msg("failed to get executable path, skipped self-upgrade")
		return
	}
	// The child process uses the attributes from
	// sysProcAttrForDetachedChildProcess.
	upgradeCmd := exec.Command(exePath, "upgrade", "prod", "-vv")
	upgradeCmd.SysProcAttr = sysProcAttrForDetachedChildProcess()
	if err := upgradeCmd.Start(); err != nil {
		mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade")
		return
	}
	mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vts)
}
// leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow
// when upstream failures occur.
//
// An explicit Service.LeakOnUpstreamFailure config value always wins.
// Otherwise the default is true, except on routers and on domain controllers.
func (p *prog) leakOnUpstreamFailure() bool {
	if configured := p.cfg.Service.LeakOnUpstreamFailure; configured != nil {
		return *configured
	}
	switch {
	case router.Name() != "":
		// Leaking is only useful for devices that move between networks, so
		// it is off by default on routers.
		return false
	case p.runningOnDomainController:
		// On ADDC, we should not leak on upstream failure.
		return false
	default:
		return true
	}
}
// Domain controller role values from Win32_ComputerSystem
// https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/win32-computersystem
const (
	// BackupDomainController is the role value for a backup domain controller.
	BackupDomainController = 4
	// PrimaryDomainController is the role value for a primary domain controller.
	PrimaryDomainController = 5
)
// isRunningOnDomainController checks if the current machine is a domain
// controller. On Windows it delegates to isRunningOnDomainControllerWindows,
// which queries the DomainRole property from Win32_ComputerSystem via WMI; on
// every other platform it returns (false, 0).
//
// NOTE(review): the int result is presumably the DomainRole value; confirm
// against isRunningOnDomainControllerWindows.
func isRunningOnDomainController() (bool, int) {
	if runtime.GOOS == "windows" {
		return isRunningOnDomainControllerWindows()
	}
	return false, 0
}