Files
ctrld/cmd/cli/prog.go
Cuong Manh Le d71d1341b6 refactor(prog): move network monitoring outside listener loop
Move the network monitoring goroutine initialization outside the listener
loop to prevent it from being started multiple times. Previously, the
network monitoring was started once per listener during first run, which
was unnecessary and could lead to multiple monitoring instances.

The change ensures network monitoring is started only once per program
execution cycle, improving efficiency and preventing potential resource
waste from duplicate monitoring goroutines.

- Extract network monitoring goroutine from listener loop
- Start network monitoring once per run cycle instead of per listener
- Maintain same functionality while improving resource usage
2025-08-12 16:49:05 +07:00

1589 lines
46 KiB
Go

package cli
import (
"bytes"
"context"
"errors"
"fmt"
"io/fs"
"math/rand"
"net"
"net/netip"
"net/url"
"os"
"os/exec"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/Masterminds/semver/v3"
"github.com/kardianos/service"
"github.com/rs/zerolog"
"github.com/spf13/viper"
"golang.org/x/sync/singleflight"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
"github.com/Control-D-Inc/ctrld/internal/controld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
"github.com/Control-D-Inc/ctrld/internal/router"
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
)
const (
defaultSemaphoreCap = 256
ctrldLogUnixSock = "ctrld_start.sock"
ctrldControlUnixSock = "ctrld_control.sock"
// iOS unix socket name max length is 11.
ctrldControlUnixSockMobile = "cd.sock"
upstreamPrefix = "upstream."
upstreamOS = upstreamPrefix + "os"
upstreamOSLocal = upstreamOS + ".local"
dnsWatchdogDefaultInterval = 20 * time.Second
ctrldServiceName = "ctrld"
)
// RecoveryReason provides context for why we are waiting for recovery.
// recovery involves removing the listener IP from the interface and
// waiting for the upstreams to work before returning
type RecoveryReason int
const (
RecoveryReasonNetworkChange RecoveryReason = iota
RecoveryReasonRegularFailure
RecoveryReasonOSFailure
)
// ControlSocketName returns name for control unix socket.
func ControlSocketName() string {
if isMobile() {
return ctrldControlUnixSockMobile
} else {
return ctrldControlUnixSock
}
}
// logf is a function variable used for logging formatted debug messages with optional arguments.
// This is used only when creating a new DNS OS configurator.
var logf = func(format string, args ...any) {
mainLog.Load().Debug().Msgf(format, args...)
}
// noopLogf is like logf but discards formatted log messages and arguments without any processing.
//
//lint:ignore U1000 use in newLoopbackOSConfigurator
var noopLogf = func(format string, args ...any) {}
var svcConfig = &service.Config{
Name: ctrldServiceName,
DisplayName: "Control-D Helper Service",
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
Option: service.KeyValue{},
}
var useSystemdResolved = false
type prog struct {
mu sync.Mutex
waitCh chan struct{}
stopCh chan struct{}
pinCodeValidCh chan struct{}
reloadCh chan struct{} // For Windows.
reloadDoneCh chan struct{}
apiReloadCh chan *ctrld.Config
apiForceReloadCh chan struct{}
apiForceReloadGroup singleflight.Group
logConn net.Conn
cs *controlServer
csSetDnsDone chan struct{}
csSetDnsOk bool
dnsWg sync.WaitGroup
dnsWatcherClosedOnce sync.Once
dnsWatcherStopCh chan struct{}
rc *controld.ResolverConfig
cfg *ctrld.Config
localUpstreams []string
ptrNameservers []string
appCallback *AppCallback
cache dnscache.Cacher
cacheFlushDomainsMap map[string]struct{}
sema semaphore
ciTable *clientinfo.Table
um *upstreamMonitor
router router.Router
ptrLoopGuard *loopGuard
lanLoopGuard *loopGuard
metricsQueryStats atomic.Bool
queryFromSelfMap sync.Map
initInternalLogWriterOnce sync.Once
internalLogWriter *logWriter
internalWarnLogWriter *logWriter
internalLogSent time.Time
runningIface string
requiredMultiNICsConfig bool
adDomain string
runningOnDomainController bool
selfUninstallMu sync.Mutex
refusedQueryCount int
canSelfUninstall atomic.Bool
checkingSelfUninstall bool
loopMu sync.Mutex
loop map[string]bool
recoveryCancelMu sync.Mutex
recoveryCancel context.CancelFunc
recoveryRunning atomic.Bool
started chan struct{}
onStartedDone chan struct{}
onStarted []func()
onStopped []func()
}
func (p *prog) Start(s service.Service) error {
go p.runWait()
return nil
}
// runWait runs ctrld components, waiting for signal to reload.
func (p *prog) runWait() {
p.mu.Lock()
p.cfg = &cfg
p.mu.Unlock()
reloadSigCh := make(chan os.Signal, 1)
notifyReloadSigCh(reloadSigCh)
reload := false
logger := mainLog.Load()
for {
reloadCh := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
p.run(reload, reloadCh)
reload = true
}()
var newCfg *ctrld.Config
select {
case sig := <-reloadSigCh:
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
case <-p.reloadCh:
logger.Notice().Msg("reloading...")
case apiCfg := <-p.apiReloadCh:
newCfg = apiCfg
case <-p.stopCh:
close(reloadCh)
return
}
waitOldRunDone := func() {
close(reloadCh)
<-done
}
if newCfg == nil {
newCfg = &ctrld.Config{}
confFile := v.ConfigFileUsed()
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
ctrld.InitConfig(v, "ctrld")
if configPath != "" {
confFile = configPath
}
v.SetConfigFile(confFile)
if err := v.ReadInConfig(); err != nil {
logger.Err(err).Msg("could not read new config")
waitOldRunDone()
continue
}
if err := v.Unmarshal(&newCfg); err != nil {
logger.Err(err).Msg("could not unmarshal new config")
waitOldRunDone()
continue
}
if cdUID != "" {
if rc, err := processCDFlags(newCfg); err != nil {
logger.Err(err).Msg("could not fetch ControlD config")
waitOldRunDone()
continue
} else {
p.mu.Lock()
p.rc = rc
p.mu.Unlock()
}
}
}
waitOldRunDone()
p.mu.Lock()
curListener := p.cfg.Listener
p.mu.Unlock()
for n, lc := range newCfg.Listener {
curLc := curListener[n]
if curLc == nil {
continue
}
if lc.IP == "" {
lc.IP = curLc.IP
}
if lc.Port == 0 {
lc.Port = curLc.Port
}
}
if err := validateConfig(newCfg); err != nil {
logger.Err(err).Msg("invalid config")
continue
}
addExtraSplitDnsRule(newCfg)
if err := writeConfigFile(newCfg); err != nil {
logger.Err(err).Msg("could not write new config")
}
// This needs to be done here, otherwise, the DNS handler may observe an invalid
// upstream config because its initialization function have not been called yet.
mainLog.Load().Debug().Msg("setup upstream with new config")
p.setupUpstream(newCfg)
p.mu.Lock()
*p.cfg = *newCfg
p.mu.Unlock()
logger.Notice().Msg("reloading config successfully")
select {
case p.reloadDoneCh <- struct{}{}:
default:
}
}
}
func (p *prog) preRun() {
if iface == "auto" {
iface = defaultIfaceName()
p.requiredMultiNICsConfig = requiredMultiNICsConfig()
}
p.runningIface = iface
}
func (p *prog) postRun() {
if !service.Interactive() {
if runtime.GOOS == "windows" {
isDC, roleInt := isRunningOnDomainController()
p.runningOnDomainController = isDC
mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt)
}
p.resetDNS(false, false)
ns := ctrld.InitializeOsResolver(false)
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
p.setDNS()
p.csSetDnsDone <- struct{}{}
close(p.csSetDnsDone)
p.logInterfacesState()
}
}
// apiConfigReload calls API to check for latest config update then reload ctrld if necessary.
func (p *prog) apiConfigReload() {
if cdUID == "" {
return
}
ticker := time.NewTicker(timeDurationOrDefault(p.cfg.Service.RefetchTime, 3600) * time.Second)
defer ticker.Stop()
logger := mainLog.Load().With().Str("mode", "api-reload").Logger()
logger.Debug().Msg("starting custom config reload timer")
lastUpdated := time.Now().Unix()
curVerStr := curVersion()
curVer, err := semver.NewVersion(curVerStr)
isStable := curVer != nil && curVer.Prerelease() == ""
if err != nil || !isStable {
l := mainLog.Load().Warn()
if err != nil {
l = l.Err(err)
}
l.Msgf("current version is not stable, skipping self-upgrade: %s", curVerStr)
}
doReloadApiConfig := func(forced bool, logger zerolog.Logger) {
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
selfUninstallCheck(err, p, logger)
if err != nil {
logger.Warn().Err(err).Msg("could not fetch resolver config")
return
}
// Performing self-upgrade check for production version.
if isStable {
_ = selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger)
}
if resolverConfig.DeactivationPin != nil {
newDeactivationPin := *resolverConfig.DeactivationPin
curDeactivationPin := cdDeactivationPin.Load()
switch {
case curDeactivationPin != defaultDeactivationPin:
logger.Debug().Msg("saving deactivation pin")
case curDeactivationPin != newDeactivationPin:
logger.Debug().Msg("update deactivation pin")
}
cdDeactivationPin.Store(newDeactivationPin)
} else {
cdDeactivationPin.Store(defaultDeactivationPin)
}
p.mu.Lock()
rc := p.rc
p.rc = resolverConfig
p.mu.Unlock()
noCustomConfig := resolverConfig.Ctrld.CustomConfig == ""
noExcludeListChanged := true
if rc != nil {
slices.Sort(rc.Exclude)
slices.Sort(resolverConfig.Exclude)
noExcludeListChanged = slices.Equal(rc.Exclude, resolverConfig.Exclude)
}
if noCustomConfig && noExcludeListChanged {
return
}
if noCustomConfig && !noExcludeListChanged {
logger.Debug().Msg("exclude list changes detected, reloading...")
p.apiReloadCh <- nil
return
}
if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated || forced {
lastUpdated = time.Now().Unix()
cfg := &ctrld.Config{}
var cfgErr error
if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil {
setListenerDefaultValue(cfg)
setNetworkDefaultValue(cfg)
cfgErr = validateConfig(cfg)
}
if cfgErr != nil {
logger.Warn().Err(err).Msg("skipping invalid custom config")
if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil {
logger.Error().Err(err).Msg("could not mark custom last update failed")
}
return
}
logger.Debug().Msg("custom config changes detected, reloading...")
p.apiReloadCh <- cfg
} else {
logger.Debug().Msg("custom config does not change")
}
}
for {
select {
case <-p.apiForceReloadCh:
doReloadApiConfig(true, logger.With().Bool("forced", true).Logger())
case <-ticker.C:
doReloadApiConfig(false, logger)
case <-p.stopCh:
return
}
}
}
func (p *prog) setupUpstream(cfg *ctrld.Config) {
localUpstreams := make([]string, 0, len(cfg.Upstream))
ptrNameservers := make([]string, 0, len(cfg.Upstream))
isControlDUpstream := false
for n := range cfg.Upstream {
uc := cfg.Upstream[n]
sdns := uc.Type == ctrld.ResolverTypeSDNS
uc.Init()
if sdns {
mainLog.Load().Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type)
}
isControlDUpstream = isControlDUpstream || uc.IsControlD()
if uc.BootstrapIP == "" {
uc.SetupBootstrapIP()
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
} else {
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
}
uc.SetCertPool(rootCertPool)
go uc.Ping()
if canBeLocalUpstream(uc.Domain) {
localUpstreams = append(localUpstreams, upstreamPrefix+n)
}
if uc.IsDiscoverable() {
ptrNameservers = append(ptrNameservers, uc.Endpoint)
}
}
// Self-uninstallation is ok If there is only 1 ControlD upstream, and no remote config.
if len(cfg.Upstream) == 1 && isControlDUpstream {
p.canSelfUninstall.Store(true)
}
p.localUpstreams = localUpstreams
p.ptrNameservers = ptrNameservers
}
// run runs the ctrld main components.
//
// The reload boolean indicates that the function is run when ctrld first start
// or when ctrld receive reloading signal. Platform specifics setup is only done
// on started, mean reload is "false".
//
// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded,
// so all listeners could be terminated and re-spawned again.
func (p *prog) run(reload bool, reloadCh chan struct{}) {
// Wait the caller to signal that we can do our logic.
<-p.waitCh
if !reload {
p.preRun()
}
numListeners := len(p.cfg.Listener)
if !reload {
p.started = make(chan struct{}, numListeners)
if p.cs != nil {
p.csSetDnsDone = make(chan struct{}, 1)
p.registerControlServerHandler()
if err := p.cs.start(); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not start control server")
}
mainLog.Load().Debug().Msgf("control server started: %s", p.cs.addr)
}
}
p.onStartedDone = make(chan struct{})
p.loop = make(map[string]bool)
p.lanLoopGuard = newLoopGuard()
p.ptrLoopGuard = newLoopGuard()
p.cacheFlushDomainsMap = nil
p.metricsQueryStats.Store(p.cfg.Service.MetricsQueryStats)
if p.cfg.Service.CacheEnable {
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled")
} else {
p.cache = cacher
p.cacheFlushDomainsMap = make(map[string]struct{}, 256)
for _, domain := range p.cfg.Service.CacheFlushDomains {
p.cacheFlushDomainsMap[canonicalName(domain)] = struct{}{}
}
}
}
if domain, err := getActiveDirectoryDomain(); err == nil && domain != "" && hasLocalDnsServerRunning() {
mainLog.Load().Debug().Msgf("active directory domain: %s", domain)
p.adDomain = domain
}
var wg sync.WaitGroup
wg.Add(len(p.cfg.Listener))
for _, nc := range p.cfg.Network {
for _, cidr := range nc.Cidrs {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
mainLog.Load().Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr")
continue
}
nc.IPNets = append(nc.IPNets, ipNet)
}
}
p.um = newUpstreamMonitor(p.cfg)
if !reload {
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
n := *mcr
if n == 0 {
p.sema = &noopSemaphore{}
} else {
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
}
}
p.setupUpstream(p.cfg)
p.setupClientInfoDiscover(defaultRouteIP())
}
// context for managing spawn goroutines.
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
// Newer versions of android and iOS denies permission which breaks connectivity.
if !isMobile() && !reload {
wg.Add(1)
go func() {
defer wg.Done()
p.runClientInfoDiscover(ctx)
}()
go p.watchLinkState(ctx)
}
if !reload {
go func() {
// Start network monitoring
if err := p.monitorNetworkChanges(); err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring")
}
}()
}
for listenerNum := range p.cfg.Listener {
p.cfg.Listener[listenerNum].Init()
if !reload {
go func(listenerNum string) {
listenerConfig := p.cfg.Listener[listenerNum]
upstreamConfig := p.cfg.Upstream[listenerNum]
if upstreamConfig == nil {
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
}
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
if err := p.serveDNS(listenerNum); err != nil {
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
}
mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr)
}(listenerNum)
}
go func() {
defer func() {
cancelFunc()
wg.Done()
}()
select {
case <-p.stopCh:
case <-ctx.Done():
case <-reloadCh:
}
}()
}
if !reload {
for i := 0; i < numListeners; i++ {
<-p.started
}
for _, f := range p.onStarted {
f()
}
}
close(p.onStartedDone)
wg.Add(1)
go func() {
defer wg.Done()
// Check for possible DNS loop.
p.checkDnsLoop()
// Start check DNS loop ticker.
p.checkDnsLoopTicker(ctx)
}()
wg.Add(1)
// Prometheus exporter goroutine.
go func() {
defer wg.Done()
p.runMetricsServer(ctx, reloadCh)
}()
if !reload {
// Stop writing log to unix socket.
consoleWriter.Out = os.Stdout
p.initLogging(false)
if p.logConn != nil {
_ = p.logConn.Close()
}
go p.apiConfigReload()
p.postRun()
}
wg.Wait()
}
// setupClientInfoDiscover performs necessary works for running client info discover.
func (p *prog) setupClientInfoDiscover(selfIP string) {
p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers)
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
}
if leaseFiles := dnsmasq.AdditionalLeaseFiles(); len(leaseFiles) > 0 {
mainLog.Load().Debug().Msgf("watching additional lease files: %v", leaseFiles)
for _, leaseFile := range leaseFiles {
p.ciTable.AddLeaseFile(leaseFile, ctrld.Dnsmasq)
}
}
}
// runClientInfoDiscover runs the client info discover.
func (p *prog) runClientInfoDiscover(ctx context.Context) {
p.ciTable.Init()
p.ciTable.RefreshLoop(ctx)
}
// metricsEnabled reports whether prometheus exporter is enabled/disabled.
func (p *prog) metricsEnabled() bool {
return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != ""
}
func (p *prog) Stop(s service.Service) error {
p.stopDnsWatchers()
mainLog.Load().Debug().Msg("dns watchers stopped")
for _, f := range p.onStopped {
f()
}
mainLog.Load().Debug().Msg("finish running onStopped functions")
defer func() {
mainLog.Load().Info().Msg("Service stopped")
}()
if err := p.deAllocateIP(); err != nil {
mainLog.Load().Error().Err(err).Msg("de-allocate ip failed")
return err
}
if deactivationPinSet() {
select {
case <-p.pinCodeValidCh:
// Allow stopping the service, pinCodeValidCh is only filled
// after control server did validate the pin code.
case <-time.After(time.Millisecond * 100):
// No valid pin code was checked, that mean we are stopping
// because of OS signal sent directly from someone else.
// In this case, restarting ctrld service by ourselves.
mainLog.Load().Debug().Msgf("receiving stopping signal without valid pin code")
mainLog.Load().Debug().Msgf("self restarting ctrld service")
if exe, err := os.Executable(); err == nil {
cmd := exec.Command(exe, "restart")
cmd.SysProcAttr = sysProcAttrForDetachedChildProcess()
if err := cmd.Start(); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to run self restart command")
}
} else {
mainLog.Load().Error().Err(err).Msg("failed to self restart ctrld service")
}
os.Exit(deactivationPinInvalidExitCode)
}
}
close(p.stopCh)
return nil
}
func (p *prog) stopDnsWatchers() {
// Ensure all DNS watchers goroutine are terminated,
// so it won't mess up with other DNS changes.
p.dnsWatcherClosedOnce.Do(func() {
close(p.dnsWatcherStopCh)
})
p.dnsWg.Wait()
}
func (p *prog) allocateIP(ip string) error {
p.mu.Lock()
defer p.mu.Unlock()
if !p.cfg.Service.AllocateIP {
return nil
}
return allocateIP(ip)
}
func (p *prog) deAllocateIP() error {
p.mu.Lock()
defer p.mu.Unlock()
if !p.cfg.Service.AllocateIP {
return nil
}
for _, lc := range p.cfg.Listener {
if err := deAllocateIP(lc.IP); err != nil {
return err
}
}
return nil
}
func (p *prog) setDNS() {
setDnsOK := false
defer func() {
p.csSetDnsOk = setDnsOK
}()
if cfg.Listener == nil {
return
}
lc := cfg.FirstListener()
if lc == nil {
return
}
ns := lc.IP
switch {
case lc.IsDirectDnsListener():
// If ctrld is direct listener, use 127.0.0.1 as nameserver.
ns = "127.0.0.1"
case lc.Port != 53:
ns = "127.0.0.1"
if resolver := router.LocalResolverIP(); resolver != "" {
ns = resolver
}
default:
// If we ever reach here, it means ctrld is running on lc.IP port 53,
// so we could just use lc.IP as nameserver.
}
nameservers := []string{ns}
if needRFC1918Listeners(lc) {
nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
}
if needLocalIPv6Listener() {
nameservers = append(nameservers, "::1")
}
slices.Sort(nameservers)
netIfaceName := ""
netIface := p.setDnsForRunningIface(nameservers)
if netIface != nil {
netIfaceName = netIface.Name
}
setDnsOK = true
if p.requiredMultiNICsConfig {
withEachPhysicalInterfaces(netIfaceName, "set DNS", func(i *net.Interface) error {
return setDnsIgnoreUnusableInterface(i, nameservers)
})
}
// resolvconf file is only useful when we have default route interface,
// then set DNS on this interface will push change to /etc/resolv.conf file.
if netIface != nil && shouldWatchResolvconf() {
servers := make([]netip.Addr, len(nameservers))
for i := range nameservers {
servers[i] = netip.MustParseAddr(nameservers[i])
}
p.dnsWg.Add(1)
go func() {
defer p.dnsWg.Done()
p.watchResolvConf(netIface, servers, setResolvConf)
}()
}
if p.dnsWatchdogEnabled() {
p.dnsWg.Add(1)
go func() {
defer p.dnsWg.Done()
p.dnsWatchdog(netIface, nameservers)
}()
}
}
func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.Interface) {
if p.runningIface == "" {
return
}
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
const maxDNSRetryAttempts = 3
const retryDelay = 1 * time.Second
var netIface *net.Interface
var err error
for attempt := 1; attempt <= maxDNSRetryAttempts; attempt++ {
netIface, err = netInterface(p.runningIface)
if err == nil {
break
}
if attempt < maxDNSRetryAttempts {
// Try to find a different working interface
newIface := findWorkingInterface(p.runningIface)
if newIface != p.runningIface {
p.runningIface = newIface
logger = mainLog.Load().With().Str("iface", p.runningIface).Logger()
logger.Info().Msg("switched to new interface")
continue
}
logger.Warn().Err(err).Int("attempt", attempt).Msg("could not get interface, retrying...")
time.Sleep(retryDelay)
continue
}
logger.Error().Err(err).Msg("could not get interface after all attempts")
return
}
if err := setupNetworkManager(); err != nil {
logger.Error().Err(err).Msg("could not patch NetworkManager")
return
}
runningIface = netIface
logger.Debug().Msg("setting DNS for interface")
if err := setDNS(netIface, nameservers); err != nil {
logger.Error().Err(err).Msgf("could not set DNS for interface")
return
}
logger.Debug().Msg("setting DNS successfully")
return
}
// dnsWatchdogEnabled reports whether DNS watchdog is enabled.
func (p *prog) dnsWatchdogEnabled() bool {
if ptr := p.cfg.Service.DnsWatchdogEnabled; ptr != nil {
return *ptr
}
return true
}
// dnsWatchdogDuration returns the time duration between each DNS watchdog loop.
func (p *prog) dnsWatchdogDuration() time.Duration {
if ptr := p.cfg.Service.DnsWatchdogInvterval; ptr != nil {
if (*ptr).Seconds() > 0 {
return *ptr
}
}
return dnsWatchdogDefaultInterval
}
// dnsWatchdog watches for DNS changes on Darwin and Windows then re-applying ctrld's settings.
// This is only works when deactivation pin set.
func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
if !requiredMultiNICsConfig() {
return
}
mainLog.Load().Debug().Msg("start DNS settings watchdog")
ns := nameservers
slices.Sort(ns)
ticker := time.NewTicker(p.dnsWatchdogDuration())
for {
select {
case <-p.dnsWatcherStopCh:
return
case <-p.stopCh:
mainLog.Load().Debug().Msg("stop dns watchdog")
return
case <-ticker.C:
if p.recoveryRunning.Load() {
return
}
if dnsChanged(iface, ns) {
mainLog.Load().Debug().Msg("DNS settings were changed, re-applying settings")
// Check if the interface already has static DNS servers configured.
// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
staticDNS, err := currentStaticDNS(iface)
if err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name)
} else if len(staticDNS) > 0 {
//filter out loopback addresses
staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
return net.ParseIP(s).IsLoopback()
})
// if we have a static config and no saved IPs already, save them
if len(staticDNS) > 0 && len(savedStaticNameservers(iface)) == 0 {
// Save these static DNS values so that they can be restored later.
if err := saveCurrentStaticDNS(iface); err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name)
}
}
}
if err := setDNS(iface, ns); err != nil {
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
}
}
if p.requiredMultiNICsConfig {
ifaceName := ""
if iface != nil {
ifaceName = iface.Name
}
withEachPhysicalInterfaces(ifaceName, "", func(i *net.Interface) error {
if dnsChanged(i, ns) {
// Check if the interface already has static DNS servers configured.
// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
staticDNS, err := currentStaticDNS(i)
if err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name)
} else if len(staticDNS) > 0 {
//filter out loopback addresses
staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
return net.ParseIP(s).IsLoopback()
})
// if we have a static config and no saved IPs already, save them
if len(staticDNS) > 0 && len(savedStaticNameservers(i)) == 0 {
// Save these static DNS values so that they can be restored later.
if err := saveCurrentStaticDNS(i); err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name)
}
}
}
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
} else {
mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
}
}
return nil
})
}
}
}
}
// resetDNS performs a DNS reset for all interfaces.
func (p *prog) resetDNS(isStart bool, restoreStatic bool) {
netIfaceName := ""
if netIface := p.resetDNSForRunningIface(isStart, restoreStatic); netIface != nil {
netIfaceName = netIface.Name
}
// See corresponding comments in (*prog).setDNS function.
if p.requiredMultiNICsConfig {
withEachPhysicalInterfaces(netIfaceName, "reset DNS", resetDnsIgnoreUnusableInterface)
}
}
// resetDNSForRunningIface performs a DNS reset on the running interface.
// The parameter isStart indicates whether this is being called as part of a start (or restart)
// command. When true, we check if the current static DNS configuration already differs from the
// service listener (127.0.0.1). If so, we assume that an admin has manually changed the interface's
// static DNS settings and we do not override them using the potentially out-of-date saved file.
// Otherwise, we restore the saved configuration (if any) or reset to DHCP.
func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runningIface *net.Interface) {
if p.runningIface == "" {
mainLog.Load().Debug().Msg("no running interface, skipping resetDNS")
return
}
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
netIface, err := netInterface(p.runningIface)
if err != nil {
logger.Error().Err(err).Msg("could not get interface")
return
}
runningIface = netIface
if err := restoreNetworkManager(); err != nil {
logger.Error().Err(err).Msg("could not restore NetworkManager")
return
}
// If starting, check the current static DNS configuration.
if isStart {
current, err := currentStaticDNS(netIface)
if err != nil {
logger.Warn().Err(err).Msg("unable to obtain current static DNS configuration; proceeding to restore saved config")
} else if len(current) > 0 {
// If any static DNS value is not our own listener, assume an admin override.
hasManualConfig := false
for _, ns := range current {
if ns != "127.0.0.1" && ns != "::1" {
hasManualConfig = true
break
}
}
if hasManualConfig {
logger.Debug().Msgf("Detected manual DNS configuration on interface %q: %v; not overriding with saved configuration", netIface.Name, current)
return
}
}
}
// Default logic: if there is a saved static DNS configuration, restore it.
saved := savedStaticNameservers(netIface)
if len(saved) > 0 && restoreStatic {
logger.Debug().Msgf("Restoring interface %q from saved static config: %v", netIface.Name, saved)
if err := setDNS(netIface, saved); err != nil {
logger.Error().Err(err).Msgf("failed to restore static DNS config on interface %q", netIface.Name)
return
}
} else {
logger.Debug().Msgf("No saved static DNS config for interface %q; resetting to DHCP", netIface.Name)
if err := resetDNS(netIface); err != nil {
logger.Error().Err(err).Msgf("failed to reset DNS to DHCP on interface %q", netIface.Name)
return
}
}
return
}
func (p *prog) logInterfacesState() {
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
addrs, err := i.Addrs()
if err != nil {
mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses")
}
nss, err := currentStaticDNS(i)
if err != nil {
mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS")
}
if len(nss) == 0 {
nss = currentDNS(i)
}
mainLog.Load().Debug().
Any("addrs", addrs).
Strs("nameservers", nss).
Int("index", i.Index).
Msgf("interface state: %s", i.Name)
return nil
})
}
// findWorkingInterface looks for a network interface with a valid IP configuration
func findWorkingInterface(currentIface string) string {
// Helper to check if IP is valid (not link-local)
isValidIP := func(ip net.IP) bool {
return ip != nil &&
!ip.IsLinkLocalUnicast() &&
!ip.IsLinkLocalMulticast() &&
!ip.IsLoopback() &&
!ip.IsUnspecified()
}
// Helper to check if interface has valid IP configuration
hasValidIPConfig := func(iface *net.Interface) bool {
if iface == nil || iface.Flags&net.FlagUp == 0 {
return false
}
addrs, err := iface.Addrs()
if err != nil {
mainLog.Load().Debug().
Str("interface", iface.Name).
Err(err).
Msg("failed to get interface addresses")
return false
}
for _, addr := range addrs {
// Check for IP network
if ipNet, ok := addr.(*net.IPNet); ok {
if isValidIP(ipNet.IP) {
return true
}
}
}
return false
}
// Get default route interface
defaultRoute, err := netmon.DefaultRoute()
if err != nil {
mainLog.Load().Debug().
Err(err).
Msg("failed to get default route")
} else {
mainLog.Load().Debug().
Str("default_route_iface", defaultRoute.InterfaceName).
Msg("found default route")
}
// Get all interfaces
ifaces, err := net.Interfaces()
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to list network interfaces")
return currentIface // Return current interface as fallback
}
var firstWorkingIface string
var currentIfaceValid bool
// Single pass through interfaces
for _, iface := range ifaces {
// Must be physical (has MAC address)
if len(iface.HardwareAddr) == 0 {
continue
}
// Skip interfaces that are:
// - Loopback
// - Not up
// - Point-to-point (like VPN tunnels)
if iface.Flags&net.FlagLoopback != 0 ||
iface.Flags&net.FlagUp == 0 ||
iface.Flags&net.FlagPointToPoint != 0 {
continue
}
if !hasValidIPConfig(&iface) {
continue
}
// Found working physical interface
if err == nil && defaultRoute.InterfaceName == iface.Name {
// Found interface with default route - use it immediately
mainLog.Load().Info().
Str("old_iface", currentIface).
Str("new_iface", iface.Name).
Msg("switching to interface with default route")
return iface.Name
}
// Keep track of first working interface as fallback
if firstWorkingIface == "" {
firstWorkingIface = iface.Name
}
// Check if this is our current interface
if iface.Name == currentIface {
currentIfaceValid = true
}
}
// Return interfaces in order of preference:
// 1. Current interface if it's still valid
if currentIfaceValid {
mainLog.Load().Debug().
Str("interface", currentIface).
Msg("keeping current interface")
return currentIface
}
// 2. First working interface found
if firstWorkingIface != "" {
mainLog.Load().Info().
Str("old_iface", currentIface).
Str("new_iface", firstWorkingIface).
Msg("switching to first working physical interface")
return firstWorkingIface
}
// 3. Fall back to current interface if nothing else works
mainLog.Load().Warn().
Str("current_iface", currentIface).
Msg("no working physical interface found, keeping current")
return currentIface
}
func randomLocalIP() string {
n := rand.Intn(254-2) + 2
return fmt.Sprintf("127.0.0.%d", n)
}
func randomPort() int {
max := 1<<16 - 1
min := 1025
n := rand.Intn(max-min) + min
return n
}
// runLogServer starts a unix listener, use by startCmd to gather log from runCmd.
func runLogServer(sockPath string) net.Conn {
addr, err := net.ResolveUnixAddr("unix", sockPath)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("invalid log sock path")
return nil
}
ln, err := net.ListenUnix("unix", addr)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not listen log socket")
return nil
}
defer ln.Close()
server, err := ln.Accept()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not accept connection")
return nil
}
return server
}
func errAddrInUse(err error) bool {
var opErr *net.OpError
if errors.As(err, &opErr) {
return errors.Is(opErr.Err, syscall.EADDRINUSE) || errors.Is(opErr.Err, windowsEADDRINUSE)
}
return false
}
var _ = errAddrInUse
// https://learn.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2
var (
windowsECONNREFUSED = syscall.Errno(10061)
windowsENETUNREACH = syscall.Errno(10051)
windowsEINVAL = syscall.Errno(10022)
windowsEADDRINUSE = syscall.Errno(10048)
windowsEHOSTUNREACH = syscall.Errno(10065)
)
func errUrlNetworkError(err error) bool {
var urlErr *url.Error
if errors.As(err, &urlErr) {
return errNetworkError(urlErr.Err)
}
return false
}
func errNetworkError(err error) bool {
var opErr *net.OpError
if errors.As(err, &opErr) {
if opErr.Temporary() {
return true
}
switch {
case errors.Is(opErr.Err, syscall.ECONNREFUSED),
errors.Is(opErr.Err, syscall.EINVAL),
errors.Is(opErr.Err, syscall.ENETUNREACH),
errors.Is(opErr.Err, windowsENETUNREACH),
errors.Is(opErr.Err, windowsEINVAL),
errors.Is(opErr.Err, windowsECONNREFUSED),
errors.Is(opErr.Err, windowsEHOSTUNREACH):
return true
}
}
return false
}
// errConnectionRefused reports whether err is connection refused.
func errConnectionRefused(err error) bool {
var opErr *net.OpError
if !errors.As(err, &opErr) {
return false
}
return errors.Is(opErr.Err, syscall.ECONNREFUSED) || errors.Is(opErr.Err, windowsECONNREFUSED)
}
func ifaceFirstPrivateIP(iface *net.Interface) string {
if iface == nil {
return ""
}
do := func(addrs []net.Addr, v4 bool) net.IP {
for _, addr := range addrs {
if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.IsPrivate() {
if v4 {
return netIP.IP.To4()
}
return netIP.IP
}
}
return nil
}
addrs, _ := iface.Addrs()
if ip := do(addrs, true); ip != nil {
return ip.String()
}
if ip := do(addrs, false); ip != nil {
return ip.String()
}
return ""
}
// defaultRouteIP returns private IP string of the default route if present, prefer IPv4 over IPv6.
func defaultRouteIP() string {
dr, err := netmon.DefaultRoute()
if err != nil {
return ""
}
drNetIface, err := netInterface(dr.InterfaceName)
if err != nil {
return ""
}
mainLog.Load().Debug().Str("iface", drNetIface.Name).Msg("checking default route interface")
if ip := ifaceFirstPrivateIP(drNetIface); ip != "" {
mainLog.Load().Debug().Str("ip", ip).Msg("found ip with default route interface")
return ip
}
// If we reach here, it means the default route interface is connected directly to ISP.
// We need to find the LAN interface with the same Mac address with drNetIface.
//
// There could be multiple LAN interfaces with the same Mac address, so we find all private
// IPs then using the smallest one.
var addrs []netip.Addr
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
if i.Name == drNetIface.Name {
return
}
if bytes.Equal(i.HardwareAddr, drNetIface.HardwareAddr) {
for _, pfx := range prefixes {
addr := pfx.Addr()
if addr.IsPrivate() {
addrs = append(addrs, addr)
}
}
}
})
if len(addrs) == 0 {
mainLog.Load().Warn().Msg("no default route IP found")
return ""
}
sort.Slice(addrs, func(i, j int) bool {
return addrs[i].Less(addrs[j])
})
ip := addrs[0].String()
mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP")
return ip
}
// canBeLocalUpstream reports whether the IP address can be used as a local upstream.
func canBeLocalUpstream(addr string) bool {
if ip, err := netip.ParseAddr(addr); err == nil {
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
}
return false
}
// withEachPhysicalInterfaces runs the function f with each physical interfaces, excluding
// the interface that matches excludeIfaceName. The context is used to clarify the
// log message when error happens.
func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) {
validIfacesMap := validInterfacesMap()
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
// Skip loopback/virtual/down interface.
if i.IsLoopback() || len(i.HardwareAddr) == 0 {
return
}
// Skip invalid interface.
if !validInterface(i.Interface, validIfacesMap) {
return
}
netIface := i.Interface
if patched, err := patchNetIfaceName(netIface); err != nil {
mainLog.Load().Debug().Err(err).Msg("failed to patch net interface name")
return
} else if !patched {
// The interface is not functional, skipping.
return
}
// Skip excluded interface.
if netIface.Name == excludeIfaceName {
return
}
// TODO: investigate whether we should report this error?
if err := f(netIface); err == nil {
if context != "" {
mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", context, i.Name)
}
} else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) {
mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name)
}
})
}
// requiredMultiNicConfig reports whether ctrld needs to set/reset DNS for multiple NICs.
func requiredMultiNICsConfig() bool {
switch runtime.GOOS {
case "windows", "darwin":
return true
default:
return false
}
}
var errSaveCurrentStaticDNSNotSupported = errors.New("saving current DNS is not supported on this platform")
// saveCurrentStaticDNS saves the current static DNS settings for restoring later.
// Only works on Windows and Mac.
func saveCurrentStaticDNS(iface *net.Interface) error {
if iface == nil {
mainLog.Load().Debug().Msg("could not save current static DNS settings for nil interface")
return nil
}
switch runtime.GOOS {
case "windows", "darwin":
default:
return errSaveCurrentStaticDNSNotSupported
}
file := savedStaticDnsSettingsFilePath(iface)
ns, err := currentStaticDNS(iface)
if err != nil {
mainLog.Load().Warn().Err(err).Msgf("could not get current static DNS settings for %q", iface.Name)
return err
}
if len(ns) == 0 {
mainLog.Load().Debug().Msgf("no static DNS settings for %q, removing old static DNS settings file", iface.Name)
_ = os.Remove(file) // removing old static DNS settings
return nil
}
//filter out loopback addresses
ns = slices.DeleteFunc(ns, func(s string) bool {
return net.ParseIP(s).IsLoopback()
})
//if we now have no static DNS settings and the file already exists
// return and do not save the file
if len(ns) == 0 {
mainLog.Load().Debug().Msgf("loopback on %q, skipping saving static DNS settings", iface.Name)
return nil
}
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
mainLog.Load().Warn().Err(err).Msgf("could not remove old static DNS settings file: %s", file)
}
nss := strings.Join(ns, ",")
mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, nss)
if err := os.WriteFile(file, []byte(nss), 0600); err != nil {
mainLog.Load().Err(err).Msgf("could not save DNS settings for iface: %s", iface.Name)
return err
}
mainLog.Load().Debug().Msgf("save DNS settings for interface %q successfully", iface.Name)
return nil
}
// savedStaticDnsSettingsFilePath returns the path to saved DNS settings of the given interface.
func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
if iface == nil {
return ""
}
return absHomeDir(".dns_" + iface.Name)
}
// savedStaticNameservers returns the static DNS nameservers of the given interface.
//
//lint:ignore U1000 use in os_windows.go and os_darwin.go
func savedStaticNameservers(iface *net.Interface) []string {
if iface == nil {
mainLog.Load().Debug().Msg("could not get saved static DNS settings for nil interface")
return nil
}
file := savedStaticDnsSettingsFilePath(iface)
if data, _ := os.ReadFile(file); len(data) > 0 {
saveValues := strings.Split(string(data), ",")
returnValues := []string{}
// check each one, if its in loopback range, remove it
for _, v := range saveValues {
if net.ParseIP(v).IsLoopback() {
continue
}
returnValues = append(returnValues, v)
}
return returnValues
}
return nil
}
// dnsChanged reports whether DNS settings for given interface was changed.
// It returns false for a nil iface.
//
// The caller must sort the nameservers before calling this function.
func dnsChanged(iface *net.Interface, nameservers []string) bool {
if iface == nil {
return false
}
curNameservers, _ := currentStaticDNS(iface)
slices.Sort(curNameservers)
if !slices.Equal(curNameservers, nameservers) {
mainLog.Load().Debug().Msgf("interface %q current DNS settings: %v, expected: %v", iface.Name, curNameservers, nameservers)
return true
}
return false
}
// selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then.
func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
var uer *controld.ErrorResponse
if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
p.stopDnsWatchers()
// Perform self-uninstall now.
selfUninstall(p, logger)
}
}
// shouldUpgrade checks if the version target vt is greater than the current one cv.
// Major version upgrades are not allowed to prevent breaking changes.
//
// The callers must ensure curVer and logger are non-nil.
// Returns true if upgrade is allowed, false otherwise.
func shouldUpgrade(vt string, cv *semver.Version, logger *zerolog.Logger) bool {
if vt == "" {
logger.Debug().Msg("no version target set, skipped checking self-upgrade")
return false
}
vts := vt
if !strings.HasPrefix(vts, "v") {
vts = "v" + vts
}
targetVer, err := semver.NewVersion(vts)
if err != nil {
logger.Warn().Err(err).Msgf("invalid target version, skipped self-upgrade: %s", vt)
return false
}
// Prevent major version upgrades to avoid breaking changes
if targetVer.Major() != cv.Major() {
logger.Warn().
Str("target", vt).
Str("current", cv.String()).
Msgf("major version upgrade not allowed (target: %d, current: %d), skipped self-upgrade", targetVer.Major(), cv.Major())
return false
}
if !targetVer.GreaterThan(cv) {
logger.Debug().
Str("target", vt).
Str("current", cv.String()).
Msgf("target version is not greater than current one, skipped self-upgrade")
return false
}
return true
}
// performUpgrade executes the self-upgrade command.
// Returns true if upgrade was initiated successfully, false otherwise.
func performUpgrade(vt string) bool {
exe, err := os.Executable()
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to get executable path, skipped self-upgrade")
return false
}
cmd := exec.Command(exe, "upgrade", "prod", "-vv")
cmd.SysProcAttr = sysProcAttrForDetachedChildProcess()
if err := cmd.Start(); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade")
return false
}
mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vt)
return true
}
// selfUpgradeCheck checks if the version target vt is greater
// than the current one cv, perform self-upgrade then.
// Major version upgrades are not allowed to prevent breaking changes.
//
// The callers must ensure curVer and logger are non-nil.
// Returns true if upgrade is allowed and should proceed, false otherwise.
func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) bool {
if shouldUpgrade(vt, cv, logger) {
return performUpgrade(vt)
}
return false
}
// leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow
// when upstream failures occur.
func (p *prog) leakOnUpstreamFailure() bool {
if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil {
return *ptr
}
// Default is false on routers, since this leaking is only useful for devices that move between networks.
if router.Name() != "" {
return false
}
// if we are running on ADDC, we should not leak on upstream failure
if p.runningOnDomainController {
return false
}
return true
}
// Domain controller role values from Win32_ComputerSystem
// https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/win32-computersystem
const (
BackupDomainController = 4
PrimaryDomainController = 5
)
// isRunningOnDomainController checks if the current machine is a domain controller
// by querying the DomainRole property from Win32_ComputerSystem via WMI.
func isRunningOnDomainController() (bool, int) {
if runtime.GOOS != "windows" {
return false, 0
}
return isRunningOnDomainControllerWindows()
}