mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Replace the legacy Unix socket log communication between `ctrld start` and `ctrld run` with a modern HTTP-based system for better reliability and maintainability.

Benefits:
- More reliable communication protocol using standard HTTP
- Better error handling and connection management
- Cleaner separation of concerns with dedicated endpoints
- Easier to test and debug with HTTP-based communication
- More maintainable code with proper abstraction layers

This change maintains backward compatibility while providing a more robust foundation for inter-process communication between ctrld commands.
1492 lines
42 KiB
Go
1492 lines
42 KiB
Go
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"math/rand"
|
|
"net"
|
|
"net/netip"
|
|
"net/url"
|
|
"os"
|
|
"os/exec"
|
|
"runtime"
|
|
"slices"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/Masterminds/semver/v3"
|
|
"github.com/kardianos/service"
|
|
"github.com/spf13/viper"
|
|
"golang.org/x/sync/singleflight"
|
|
"tailscale.com/net/netmon"
|
|
"tailscale.com/net/tsaddr"
|
|
|
|
"github.com/Control-D-Inc/ctrld"
|
|
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
|
"github.com/Control-D-Inc/ctrld/internal/controld"
|
|
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
|
)
|
|
|
|
const (
|
|
defaultSemaphoreCap = 256
|
|
ctrldLogUnixSock = "ctrld_start.sock"
|
|
ctrldControlUnixSock = "ctrld_control.sock"
|
|
// iOS unix socket name max length is 11.
|
|
ctrldControlUnixSockMobile = "cd.sock"
|
|
upstreamPrefix = "upstream."
|
|
upstreamOS = upstreamPrefix + "os"
|
|
upstreamOSLocal = upstreamOS + ".local"
|
|
dnsWatchdogDefaultInterval = 20 * time.Second
|
|
ctrldServiceName = "ctrld"
|
|
)
|
|
|
|
// RecoveryReason provides context for why we are waiting for recovery.
|
|
// recovery involves removing the listener IP from the interface and
|
|
// waiting for the upstreams to work before returning
|
|
type RecoveryReason int
|
|
|
|
const (
|
|
RecoveryReasonNetworkChange RecoveryReason = iota
|
|
RecoveryReasonRegularFailure
|
|
RecoveryReasonOSFailure
|
|
)
|
|
|
|
// ControlSocketName returns name for control unix socket.
|
|
func ControlSocketName() string {
|
|
if isMobile() {
|
|
return ctrldControlUnixSockMobile
|
|
} else {
|
|
return ctrldControlUnixSock
|
|
}
|
|
}
|
|
|
|
// logf is a function variable used for logging formatted debug messages with optional arguments.
|
|
// This is used only when creating a new DNS OS configurator.
|
|
var logf = func(format string, args ...any) {
|
|
mainLog.Load().Debug().Msgf(format, args...)
|
|
}
|
|
|
|
// noopLogf is like logf but discards formatted log messages and arguments without any processing.
|
|
//
|
|
//lint:ignore U1000 use in newLoopbackOSConfigurator
|
|
var noopLogf = func(format string, args ...any) {}
|
|
|
|
var useSystemdResolved = false
|
|
|
|
type prog struct {
|
|
mu sync.Mutex
|
|
waitCh chan struct{}
|
|
stopCh chan struct{}
|
|
pinCodeValidCh chan struct{}
|
|
reloadCh chan struct{} // For Windows.
|
|
reloadDoneCh chan struct{}
|
|
apiReloadCh chan *ctrld.Config
|
|
apiForceReloadCh chan struct{}
|
|
apiForceReloadGroup singleflight.Group
|
|
logConn io.WriteCloser
|
|
cs *controlServer
|
|
logger atomic.Pointer[ctrld.Logger]
|
|
csSetDnsDone chan struct{}
|
|
csSetDnsOk bool
|
|
dnsWg sync.WaitGroup
|
|
dnsWatcherClosedOnce sync.Once
|
|
dnsWatcherStopCh chan struct{}
|
|
rc *controld.ResolverConfig
|
|
|
|
cfg *ctrld.Config
|
|
localUpstreams []string
|
|
ptrNameservers []string
|
|
appCallback *AppCallback
|
|
cache dnscache.Cacher
|
|
cacheFlushDomainsMap map[string]struct{}
|
|
sema semaphore
|
|
ciTable *clientinfo.Table
|
|
um *upstreamMonitor
|
|
ptrLoopGuard *loopGuard
|
|
lanLoopGuard *loopGuard
|
|
metricsQueryStats atomic.Bool
|
|
queryFromSelfMap sync.Map
|
|
initInternalLogWriterOnce sync.Once
|
|
internalLogWriter *logWriter
|
|
internalWarnLogWriter *logWriter
|
|
internalLogSent time.Time
|
|
runningIface string
|
|
requiredMultiNICsConfig bool
|
|
|
|
selfUninstallMu sync.Mutex
|
|
refusedQueryCount int
|
|
canSelfUninstall atomic.Bool
|
|
checkingSelfUninstall bool
|
|
|
|
loopMu sync.Mutex
|
|
loop map[string]bool
|
|
|
|
recoveryCancelMu sync.Mutex
|
|
recoveryCancel context.CancelFunc
|
|
recoveryRunning atomic.Bool
|
|
|
|
started chan struct{}
|
|
onStartedDone chan struct{}
|
|
onStarted []func()
|
|
onStopped []func()
|
|
}
|
|
|
|
func (p *prog) Start(_ service.Service) error {
|
|
go p.runWait()
|
|
return nil
|
|
}
|
|
|
|
// runWait runs ctrld components, waiting for signal to reload.
|
|
func (p *prog) runWait() {
|
|
p.mu.Lock()
|
|
p.cfg = &cfg
|
|
p.mu.Unlock()
|
|
reloadSigCh := make(chan os.Signal, 1)
|
|
notifyReloadSigCh(reloadSigCh)
|
|
|
|
reload := false
|
|
for {
|
|
reloadCh := make(chan struct{})
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
p.run(reload, reloadCh)
|
|
reload = true
|
|
}()
|
|
|
|
var newCfg *ctrld.Config
|
|
select {
|
|
case sig := <-reloadSigCh:
|
|
p.Notice().Msgf("Got signal: %s, reloading...", sig.String())
|
|
case <-p.reloadCh:
|
|
p.Notice().Msg("Reloading...")
|
|
case apiCfg := <-p.apiReloadCh:
|
|
newCfg = apiCfg
|
|
case <-p.stopCh:
|
|
close(reloadCh)
|
|
return
|
|
}
|
|
|
|
waitOldRunDone := func() {
|
|
close(reloadCh)
|
|
<-done
|
|
}
|
|
|
|
if newCfg == nil {
|
|
newCfg = &ctrld.Config{}
|
|
confFile := v.ConfigFileUsed()
|
|
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
|
ctrld.InitConfig(v, "ctrld")
|
|
if configPath != "" {
|
|
confFile = configPath
|
|
}
|
|
v.SetConfigFile(confFile)
|
|
if err := v.ReadInConfig(); err != nil {
|
|
p.Error().Err(err).Msg("Could not read new config")
|
|
waitOldRunDone()
|
|
continue
|
|
}
|
|
if err := v.Unmarshal(&newCfg); err != nil {
|
|
p.Error().Err(err).Msg("Could not unmarshal new config")
|
|
waitOldRunDone()
|
|
continue
|
|
}
|
|
if cdUID != "" {
|
|
if rc, err := processCDFlags(newCfg); err != nil {
|
|
p.Error().Err(err).Msg("Could not fetch controld config")
|
|
waitOldRunDone()
|
|
continue
|
|
} else {
|
|
p.mu.Lock()
|
|
p.rc = rc
|
|
p.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
waitOldRunDone()
|
|
|
|
p.mu.Lock()
|
|
curListener := p.cfg.Listener
|
|
p.mu.Unlock()
|
|
|
|
for n, lc := range newCfg.Listener {
|
|
curLc := curListener[n]
|
|
if curLc == nil {
|
|
continue
|
|
}
|
|
if lc.IP == "" {
|
|
lc.IP = curLc.IP
|
|
}
|
|
if lc.Port == 0 {
|
|
lc.Port = curLc.Port
|
|
}
|
|
}
|
|
if err := validateConfig(newCfg); err != nil {
|
|
p.Error().Err(err).Msg("Invalid config")
|
|
continue
|
|
}
|
|
|
|
addExtraSplitDnsRule(newCfg)
|
|
if err := writeConfigFile(newCfg); err != nil {
|
|
p.Error().Err(err).Msg("Could not write new config")
|
|
}
|
|
|
|
// This needs to be done here, otherwise, the DNS handler may observe an invalid
|
|
// upstream config because its initialization function have not been called yet.
|
|
p.Debug().Msg("Setup upstream with new config")
|
|
p.setupUpstream(newCfg)
|
|
|
|
p.mu.Lock()
|
|
*p.cfg = *newCfg
|
|
p.mu.Unlock()
|
|
|
|
p.Notice().Msg("Reloading config successfully")
|
|
|
|
select {
|
|
case p.reloadDoneCh <- struct{}{}:
|
|
p.Debug().Msg("Reload done signal sent")
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *prog) preRun() {
|
|
if iface == "auto" {
|
|
iface = defaultIfaceName()
|
|
p.requiredMultiNICsConfig = requiredMultiNICsConfig()
|
|
}
|
|
p.runningIface = iface
|
|
p.logger.Store(mainLog.Load())
|
|
}
|
|
|
|
func (p *prog) postRun() {
|
|
if !service.Interactive() {
|
|
p.resetDNS(false, false)
|
|
ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), false)
|
|
p.Debug().Msgf("Initialized os resolver with nameservers: %v", ns)
|
|
p.setDNS()
|
|
p.csSetDnsDone <- struct{}{}
|
|
close(p.csSetDnsDone)
|
|
p.logInterfacesState()
|
|
}
|
|
}
|
|
|
|
// apiConfigReload calls API to check for latest config update then reload ctrld if necessary.
|
|
func (p *prog) apiConfigReload() {
|
|
if cdUID == "" {
|
|
return
|
|
}
|
|
|
|
ticker := time.NewTicker(timeDurationOrDefault(p.cfg.Service.RefetchTime, 3600) * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
logger := p.logger.Load().With().Str("mode", "api-reload")
|
|
logger.Debug().Msg("Starting custom config reload timer")
|
|
lastUpdated := time.Now().Unix()
|
|
curVerStr := curVersion()
|
|
curVer, err := semver.NewVersion(curVerStr)
|
|
isStable := curVer != nil && curVer.Prerelease() == ""
|
|
if err != nil || !isStable {
|
|
l := p.Warn()
|
|
if err != nil {
|
|
l = l.Err(err)
|
|
}
|
|
l.Msgf("Current version is not stable, skipping self-upgrade: %s", curVerStr)
|
|
}
|
|
|
|
doReloadApiConfig := func(forced bool, logger *ctrld.Logger) {
|
|
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
|
resolverConfig, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev)
|
|
selfUninstallCheck(err, p, logger)
|
|
if err != nil {
|
|
logger.Warn().Err(err).Msg("Could not fetch resolver config")
|
|
return
|
|
}
|
|
|
|
// Performing self-upgrade check for production version.
|
|
if isStable {
|
|
_ = selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, logger)
|
|
}
|
|
|
|
if resolverConfig.DeactivationPin != nil {
|
|
newDeactivationPin := *resolverConfig.DeactivationPin
|
|
curDeactivationPin := cdDeactivationPin.Load()
|
|
switch {
|
|
case curDeactivationPin != defaultDeactivationPin:
|
|
logger.Debug().Msg("Saving deactivation pin")
|
|
case curDeactivationPin != newDeactivationPin:
|
|
logger.Debug().Msg("Update deactivation pin")
|
|
}
|
|
cdDeactivationPin.Store(newDeactivationPin)
|
|
} else {
|
|
cdDeactivationPin.Store(defaultDeactivationPin)
|
|
}
|
|
|
|
p.mu.Lock()
|
|
rc := p.rc
|
|
p.rc = resolverConfig
|
|
p.mu.Unlock()
|
|
noCustomConfig := resolverConfig.Ctrld.CustomConfig == ""
|
|
noExcludeListChanged := true
|
|
if rc != nil {
|
|
slices.Sort(rc.Exclude)
|
|
slices.Sort(resolverConfig.Exclude)
|
|
noExcludeListChanged = slices.Equal(rc.Exclude, resolverConfig.Exclude)
|
|
}
|
|
if noCustomConfig && noExcludeListChanged {
|
|
return
|
|
}
|
|
|
|
if noCustomConfig && !noExcludeListChanged {
|
|
logger.Debug().Msg("Exclude list changes detected, reloading...")
|
|
p.apiReloadCh <- nil
|
|
return
|
|
}
|
|
|
|
if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated || forced {
|
|
lastUpdated = time.Now().Unix()
|
|
cfg := &ctrld.Config{}
|
|
var cfgErr error
|
|
if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil {
|
|
setListenerDefaultValue(cfg)
|
|
setNetworkDefaultValue(cfg)
|
|
cfgErr = validateConfig(cfg)
|
|
}
|
|
if cfgErr != nil {
|
|
logger.Warn().Err(err).Msg("Skipping invalid custom config")
|
|
if _, err := controld.UpdateCustomLastFailed(loggerCtx, cdUID, appVersion, cdDev, true); err != nil {
|
|
logger.Error().Err(err).Msg("Could not mark custom last update failed")
|
|
}
|
|
return
|
|
}
|
|
logger.Debug().Msg("Custom config changes detected, reloading...")
|
|
p.apiReloadCh <- cfg
|
|
} else {
|
|
logger.Debug().Msg("Custom config does not change")
|
|
}
|
|
}
|
|
for {
|
|
select {
|
|
case <-p.apiForceReloadCh:
|
|
doReloadApiConfig(true, logger.With().Bool("forced", true))
|
|
case <-ticker.C:
|
|
doReloadApiConfig(false, logger)
|
|
case <-p.stopCh:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
|
localUpstreams := make([]string, 0, len(cfg.Upstream))
|
|
ptrNameservers := make([]string, 0, len(cfg.Upstream))
|
|
isControlDUpstream := false
|
|
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
|
for n := range cfg.Upstream {
|
|
uc := cfg.Upstream[n]
|
|
sdns := uc.Type == ctrld.ResolverTypeSDNS
|
|
uc.Init(loggerCtx)
|
|
if sdns {
|
|
p.Debug().Msgf("Initialized dns stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type)
|
|
}
|
|
isControlDUpstream = isControlDUpstream || uc.IsControlD()
|
|
if uc.BootstrapIP == "" {
|
|
uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), p.logger.Load()))
|
|
p.Info().Msgf("Bootstrap ips for upstream.%s: %q", n, uc.BootstrapIPs())
|
|
} else {
|
|
p.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Using bootstrap ip for upstream.%s", n)
|
|
}
|
|
uc.SetCertPool(rootCertPool)
|
|
go uc.Ping(loggerCtx)
|
|
|
|
if canBeLocalUpstream(uc.Domain) {
|
|
localUpstreams = append(localUpstreams, upstreamPrefix+n)
|
|
}
|
|
if uc.IsDiscoverable() {
|
|
ptrNameservers = append(ptrNameservers, uc.Endpoint)
|
|
}
|
|
}
|
|
// Self-uninstallation is ok If there is only 1 ControlD upstream, and no remote config.
|
|
if len(cfg.Upstream) == 1 && isControlDUpstream {
|
|
p.canSelfUninstall.Store(true)
|
|
}
|
|
p.localUpstreams = localUpstreams
|
|
p.ptrNameservers = ptrNameservers
|
|
}
|
|
|
|
// run runs the ctrld main components.
|
|
//
|
|
// The reload boolean indicates that the function is run when ctrld first start
|
|
// or when ctrld receive reloading signal. Platform specifics setup is only done
|
|
// on started, mean reload is "false".
|
|
//
|
|
// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded,
|
|
// so all listeners could be terminated and re-spawned again.
|
|
func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
|
// Wait the caller to signal that we can do our logic.
|
|
<-p.waitCh
|
|
if !reload {
|
|
p.preRun()
|
|
}
|
|
numListeners := len(p.cfg.Listener)
|
|
if !reload {
|
|
p.started = make(chan struct{}, numListeners)
|
|
if p.cs != nil {
|
|
p.csSetDnsDone = make(chan struct{}, 1)
|
|
p.registerControlServerHandler()
|
|
if err := p.cs.start(); err != nil {
|
|
p.Warn().Err(err).Msg("Could not start control server")
|
|
}
|
|
p.Debug().Msgf("Control server started: %s", p.cs.addr)
|
|
}
|
|
}
|
|
p.onStartedDone = make(chan struct{})
|
|
p.loop = make(map[string]bool)
|
|
p.lanLoopGuard = newLoopGuard()
|
|
p.ptrLoopGuard = newLoopGuard()
|
|
p.cacheFlushDomainsMap = nil
|
|
p.metricsQueryStats.Store(p.cfg.Service.MetricsQueryStats)
|
|
if p.cfg.Service.CacheEnable {
|
|
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
|
if err != nil {
|
|
p.Error().Err(err).Msg("Failed to create cacher, caching is disabled")
|
|
} else {
|
|
p.cache = cacher
|
|
p.cacheFlushDomainsMap = make(map[string]struct{}, 256)
|
|
for _, domain := range p.cfg.Service.CacheFlushDomains {
|
|
p.cacheFlushDomainsMap[canonicalName(domain)] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(p.cfg.Listener))
|
|
|
|
for _, nc := range p.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
p.Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("Invalid cidr")
|
|
continue
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
|
|
p.um = newUpstreamMonitor(p.cfg, p.logger.Load())
|
|
|
|
if !reload {
|
|
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
|
|
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
|
|
n := *mcr
|
|
if n == 0 {
|
|
p.sema = &noopSemaphore{}
|
|
} else {
|
|
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
|
|
}
|
|
}
|
|
p.setupUpstream(p.cfg)
|
|
p.setupClientInfoDiscover()
|
|
}
|
|
|
|
// context for managing spawn goroutines.
|
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
|
defer cancelFunc()
|
|
|
|
// Newer versions of android and iOS denies permission which breaks connectivity.
|
|
if !isMobile() && !reload {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
p.runClientInfoDiscover(ctx)
|
|
}()
|
|
go p.watchLinkState(ctx)
|
|
}
|
|
|
|
if !reload {
|
|
go func() {
|
|
// Start network monitoring
|
|
if err := p.monitorNetworkChanges(ctx); err != nil {
|
|
p.Error().Err(err).Msg("Failed to start network monitoring")
|
|
}
|
|
}()
|
|
}
|
|
|
|
for listenerNum := range p.cfg.Listener {
|
|
p.cfg.Listener[listenerNum].Init()
|
|
if !reload {
|
|
go func(listenerNum string) {
|
|
listenerConfig := p.cfg.Listener[listenerNum]
|
|
upstreamConfig := p.cfg.Upstream[listenerNum]
|
|
if upstreamConfig == nil {
|
|
p.Warn().Msgf("No default upstream for: [listener.%s]", listenerNum)
|
|
}
|
|
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
|
p.Info().Msgf("Starting dns server on listener.%s: %s", listenerNum, addr)
|
|
// serveCtx uses Background() context so listeners survive between reloads.
|
|
// Changes to listeners config require a service restart, not just reload.
|
|
serveCtx := context.Background()
|
|
if err := p.serveDNS(serveCtx, listenerNum); err != nil {
|
|
p.Fatal().Err(err).Msgf("Unable to start dns proxy on listener.%s", listenerNum)
|
|
}
|
|
p.Debug().Msgf("End of serveDNS listener.%s: %s", listenerNum, addr)
|
|
}(listenerNum)
|
|
}
|
|
go func() {
|
|
defer func() {
|
|
cancelFunc()
|
|
wg.Done()
|
|
}()
|
|
select {
|
|
case <-p.stopCh:
|
|
case <-ctx.Done():
|
|
case <-reloadCh:
|
|
}
|
|
}()
|
|
}
|
|
|
|
if !reload {
|
|
for i := 0; i < numListeners; i++ {
|
|
<-p.started
|
|
}
|
|
for _, f := range p.onStarted {
|
|
f()
|
|
}
|
|
}
|
|
|
|
close(p.onStartedDone)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
// Check for possible DNS loop.
|
|
p.checkDnsLoop()
|
|
// Start check DNS loop ticker.
|
|
p.checkDnsLoopTicker(ctx)
|
|
}()
|
|
|
|
wg.Add(1)
|
|
// Prometheus exporter goroutine.
|
|
go func() {
|
|
defer wg.Done()
|
|
p.runMetricsServer(ctx, reloadCh)
|
|
}()
|
|
|
|
if !reload {
|
|
// Stop writing log to unix socket.
|
|
consoleWriter = newHumanReadableZapCore(os.Stdout, consoleWriterLevel)
|
|
p.initLogging(false)
|
|
if p.logConn != nil {
|
|
_ = p.logConn.Close()
|
|
}
|
|
go p.apiConfigReload()
|
|
p.postRun()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
// setupClientInfoDiscover performs necessary works for running client info discover.
|
|
func (p *prog) setupClientInfoDiscover() {
|
|
selfIP := p.defaultRouteIP()
|
|
p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, p.logger.Load())
|
|
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
|
p.Debug().Msgf("Watching custom lease file: %s", leaseFile)
|
|
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
|
p.ciTable.AddLeaseFile(leaseFile, format)
|
|
}
|
|
}
|
|
|
|
// runClientInfoDiscover runs the client info discover.
|
|
func (p *prog) runClientInfoDiscover(ctx context.Context) {
|
|
p.ciTable.Init()
|
|
p.ciTable.RefreshLoop(ctx)
|
|
}
|
|
|
|
// metricsEnabled reports whether prometheus exporter is enabled/disabled.
|
|
func (p *prog) metricsEnabled() bool {
|
|
return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != ""
|
|
}
|
|
|
|
func (p *prog) Stop(_ service.Service) error {
|
|
p.stopDnsWatchers()
|
|
p.Debug().Msg("Dns watchers stopped")
|
|
for _, f := range p.onStopped {
|
|
f()
|
|
}
|
|
p.Debug().Msg("Finish running onStopped functions")
|
|
defer func() {
|
|
p.Info().Msg("Service stopped")
|
|
}()
|
|
if err := p.deAllocateIP(); err != nil {
|
|
p.Error().Err(err).Msg("De-allocate ip failed")
|
|
return err
|
|
}
|
|
if deactivationPinSet() {
|
|
select {
|
|
case <-p.pinCodeValidCh:
|
|
// Allow stopping the service, pinCodeValidCh is only filled
|
|
// after control server did validate the pin code.
|
|
case <-time.After(time.Millisecond * 100):
|
|
// No valid pin code was checked, that mean we are stopping
|
|
// because of OS signal sent directly from someone else.
|
|
// In this case, restarting ctrld service by ourselves.
|
|
p.Debug().Msgf("Receiving stopping signal without valid pin code")
|
|
p.Debug().Msgf("Self restarting ctrld service")
|
|
if exe, err := os.Executable(); err == nil {
|
|
cmd := exec.Command(exe, "restart")
|
|
cmd.SysProcAttr = sysProcAttrForDetachedChildProcess()
|
|
if err := cmd.Start(); err != nil {
|
|
p.Error().Err(err).Msg("Failed to run self restart command")
|
|
}
|
|
} else {
|
|
p.Error().Err(err).Msg("Failed to self restart ctrld service")
|
|
}
|
|
os.Exit(deactivationPinInvalidExitCode)
|
|
}
|
|
}
|
|
close(p.stopCh)
|
|
return nil
|
|
}
|
|
|
|
func (p *prog) stopDnsWatchers() {
|
|
// Ensure all DNS watchers goroutine are terminated,
|
|
// so it won't mess up with other DNS changes.
|
|
p.dnsWatcherClosedOnce.Do(func() {
|
|
close(p.dnsWatcherStopCh)
|
|
})
|
|
p.dnsWg.Wait()
|
|
}
|
|
|
|
func (p *prog) allocateIP(ip string) error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
if !p.cfg.Service.AllocateIP {
|
|
return nil
|
|
}
|
|
return allocateIP(ip)
|
|
}
|
|
|
|
func (p *prog) deAllocateIP() error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
if !p.cfg.Service.AllocateIP {
|
|
return nil
|
|
}
|
|
for _, lc := range p.cfg.Listener {
|
|
if err := deAllocateIP(lc.IP); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *prog) setDNS() {
|
|
setDnsOK := false
|
|
defer func() {
|
|
p.csSetDnsOk = setDnsOK
|
|
}()
|
|
|
|
if cfg.Listener == nil {
|
|
return
|
|
}
|
|
lc := cfg.FirstListener()
|
|
if lc == nil {
|
|
return
|
|
}
|
|
ns := lc.IP
|
|
switch {
|
|
case lc.IsDirectDnsListener():
|
|
// If ctrld is direct listener, use 127.0.0.1 as nameserver.
|
|
ns = "127.0.0.1"
|
|
case lc.Port != 53:
|
|
ns = "127.0.0.1"
|
|
default:
|
|
// If we ever reach here, it means ctrld is running on lc.IP port 53,
|
|
// so we could just use lc.IP as nameserver.
|
|
}
|
|
|
|
nameservers := []string{ns}
|
|
if needRFC1918Listeners(lc) {
|
|
nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
|
|
}
|
|
if needLocalIPv6Listener() {
|
|
nameservers = append(nameservers, "::1")
|
|
}
|
|
|
|
slices.Sort(nameservers)
|
|
|
|
netIfaceName := ""
|
|
netIface := p.setDnsForRunningIface(nameservers)
|
|
if netIface != nil {
|
|
netIfaceName = netIface.Name
|
|
}
|
|
setDnsOK = true
|
|
|
|
if p.requiredMultiNICsConfig {
|
|
withEachPhysicalInterfaces(netIfaceName, "set DNS", func(i *net.Interface) error {
|
|
return setDnsIgnoreUnusableInterface(i, nameservers)
|
|
})
|
|
}
|
|
// resolvconf file is only useful when we have default route interface,
|
|
// then set DNS on this interface will push change to /etc/resolv.conf file.
|
|
if netIface != nil && shouldWatchResolvconf() {
|
|
servers := make([]netip.Addr, len(nameservers))
|
|
for i := range nameservers {
|
|
servers[i] = netip.MustParseAddr(nameservers[i])
|
|
}
|
|
p.dnsWg.Add(1)
|
|
go func() {
|
|
defer p.dnsWg.Done()
|
|
p.watchResolvConf(netIface, servers, p.setResolvConf)
|
|
}()
|
|
}
|
|
if p.dnsWatchdogEnabled() {
|
|
p.dnsWg.Add(1)
|
|
go func() {
|
|
defer p.dnsWg.Done()
|
|
p.dnsWatchdog(netIface, nameservers)
|
|
}()
|
|
}
|
|
}
|
|
|
|
func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.Interface) {
|
|
if p.runningIface == "" {
|
|
return
|
|
}
|
|
|
|
logger := p.logger.Load().With().Str("iface", p.runningIface)
|
|
|
|
const maxDNSRetryAttempts = 3
|
|
const retryDelay = 1 * time.Second
|
|
var netIface *net.Interface
|
|
var err error
|
|
for attempt := 1; attempt <= maxDNSRetryAttempts; attempt++ {
|
|
netIface, err = netInterface(p.runningIface)
|
|
if err == nil {
|
|
break
|
|
}
|
|
if attempt < maxDNSRetryAttempts {
|
|
// Try to find a different working interface
|
|
newIface := p.findWorkingInterface()
|
|
if newIface != p.runningIface {
|
|
p.runningIface = newIface
|
|
logger = p.logger.Load().With().Str("iface", p.runningIface)
|
|
logger.Info().Msg("Switched to new interface")
|
|
continue
|
|
}
|
|
|
|
logger.Warn().Err(err).Int("attempt", attempt).Msg("Could not get interface, retrying...")
|
|
time.Sleep(retryDelay)
|
|
continue
|
|
}
|
|
logger.Error().Err(err).Msg("Could not get interface after all attempts")
|
|
return
|
|
}
|
|
if err := p.setupNetworkManager(); err != nil {
|
|
logger.Error().Err(err).Msg("Could not patch networkmanager")
|
|
return
|
|
}
|
|
|
|
runningIface = netIface
|
|
logger.Debug().Msg("Setting dns for interface")
|
|
if err := setDNS(netIface, nameservers); err != nil {
|
|
logger.Error().Err(err).Msgf("Could not set dns for interface")
|
|
return
|
|
}
|
|
logger.Debug().Msg("Setting dns successfully")
|
|
return
|
|
}
|
|
|
|
// dnsWatchdogEnabled reports whether DNS watchdog is enabled.
|
|
func (p *prog) dnsWatchdogEnabled() bool {
|
|
if ptr := p.cfg.Service.DnsWatchdogEnabled; ptr != nil {
|
|
return *ptr
|
|
}
|
|
return true
|
|
}
|
|
|
|
// dnsWatchdogDuration returns the time duration between each DNS watchdog loop.
|
|
func (p *prog) dnsWatchdogDuration() time.Duration {
|
|
if ptr := p.cfg.Service.DnsWatchdogInvterval; ptr != nil {
|
|
if (*ptr).Seconds() > 0 {
|
|
return *ptr
|
|
}
|
|
}
|
|
return dnsWatchdogDefaultInterval
|
|
}
|
|
|
|
// dnsWatchdog watches for DNS changes on Darwin and Windows then re-applying ctrld's settings.
|
|
// This is only works when deactivation pin set.
|
|
func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
|
|
if !requiredMultiNICsConfig() {
|
|
return
|
|
}
|
|
|
|
p.Debug().Msg("Start dns settings watchdog")
|
|
|
|
ns := nameservers
|
|
slices.Sort(ns)
|
|
ticker := time.NewTicker(p.dnsWatchdogDuration())
|
|
|
|
for {
|
|
select {
|
|
case <-p.dnsWatcherStopCh:
|
|
return
|
|
case <-p.stopCh:
|
|
p.Debug().Msg("Stop dns watchdog")
|
|
return
|
|
case <-ticker.C:
|
|
if p.recoveryRunning.Load() {
|
|
return
|
|
}
|
|
if p.dnsChanged(iface, ns) {
|
|
p.Debug().Msg("DNS settings were changed, re-applying settings")
|
|
// Check if the interface already has static DNS servers configured.
|
|
// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
|
|
staticDNS, err := currentStaticDNS(iface)
|
|
if err != nil {
|
|
p.Debug().Err(err).Msgf("Failed to get static DNS for interface %s", iface.Name)
|
|
} else if len(staticDNS) > 0 {
|
|
//filter out loopback addresses
|
|
staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
|
|
return net.ParseIP(s).IsLoopback()
|
|
})
|
|
// if we have a static config and no saved IPs already, save them
|
|
if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(iface)) == 0 {
|
|
// Save these static DNS values so that they can be restored later.
|
|
if err := saveCurrentStaticDNS(iface); err != nil {
|
|
p.Debug().Err(err).Msgf("Failed to save static DNS for interface %s", iface.Name)
|
|
}
|
|
}
|
|
}
|
|
if err := setDNS(iface, ns); err != nil {
|
|
p.Error().Err(err).Str("iface", iface.Name).Msgf("Could not re-apply DNS settings")
|
|
}
|
|
}
|
|
if p.requiredMultiNICsConfig {
|
|
ifaceName := ""
|
|
if iface != nil {
|
|
ifaceName = iface.Name
|
|
}
|
|
withEachPhysicalInterfaces(ifaceName, "", func(i *net.Interface) error {
|
|
if p.dnsChanged(i, ns) {
|
|
|
|
// Check if the interface already has static DNS servers configured.
|
|
// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
|
|
staticDNS, err := currentStaticDNS(i)
|
|
if err != nil {
|
|
p.Debug().Err(err).Msgf("Failed to get static DNS for interface %s", i.Name)
|
|
} else if len(staticDNS) > 0 {
|
|
//filter out loopback addresses
|
|
staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
|
|
return net.ParseIP(s).IsLoopback()
|
|
})
|
|
// if we have a static config and no saved IPs already, save them
|
|
if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(i)) == 0 {
|
|
// Save these static DNS values so that they can be restored later.
|
|
if err := saveCurrentStaticDNS(i); err != nil {
|
|
p.Debug().Err(err).Msgf("Failed to save static DNS for interface %s", i.Name)
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
|
|
p.Error().Err(err).Str("iface", i.Name).Msgf("Could not re-apply DNS settings")
|
|
} else {
|
|
p.Debug().Msgf("Re-applying DNS for interface %q successfully", i.Name)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// resetDNS performs a DNS reset for all interfaces.
|
|
func (p *prog) resetDNS(isStart bool, restoreStatic bool) {
|
|
netIfaceName := ""
|
|
if netIface := p.resetDNSForRunningIface(isStart, restoreStatic); netIface != nil {
|
|
netIfaceName = netIface.Name
|
|
}
|
|
// See corresponding comments in (*prog).setDNS function.
|
|
if p.requiredMultiNICsConfig {
|
|
withEachPhysicalInterfaces(netIfaceName, "reset DNS", resetDnsIgnoreUnusableInterface)
|
|
}
|
|
}
|
|
|
|
// resetDNSForRunningIface performs a DNS reset on the running interface.
|
|
// The parameter isStart indicates whether this is being called as part of a start (or restart)
|
|
// command. When true, we check if the current static DNS configuration already differs from the
|
|
// service listener (127.0.0.1). If so, we assume that an admin has manually changed the interface's
|
|
// static DNS settings and we do not override them using the potentially out-of-date saved file.
|
|
// Otherwise, we restore the saved configuration (if any) or reset to DHCP.
|
|
func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runningIface *net.Interface) {
|
|
if p.runningIface == "" {
|
|
p.Debug().Msg("No running interface, skipping resetDNS")
|
|
return
|
|
}
|
|
logger := p.logger.Load().With().Str("iface", p.runningIface)
|
|
netIface, err := netInterface(p.runningIface)
|
|
if err != nil {
|
|
logger.Error().Err(err).Msg("Could not get interface")
|
|
return
|
|
}
|
|
runningIface = netIface
|
|
if err := p.restoreNetworkManager(); err != nil {
|
|
logger.Error().Err(err).Msg("Could not restore NetworkManager")
|
|
return
|
|
}
|
|
|
|
// If starting, check the current static DNS configuration.
|
|
if isStart {
|
|
current, err := currentStaticDNS(netIface)
|
|
if err != nil {
|
|
logger.Warn().Err(err).Msg("Unable to obtain current static DNS configuration; proceeding to restore saved config")
|
|
} else if len(current) > 0 {
|
|
// If any static DNS value is not our own listener, assume an admin override.
|
|
hasManualConfig := false
|
|
for _, ns := range current {
|
|
if ns != "127.0.0.1" && ns != "::1" {
|
|
hasManualConfig = true
|
|
break
|
|
}
|
|
}
|
|
if hasManualConfig {
|
|
logger.Debug().Msgf("Detected manual DNS configuration on interface %q: %v; not overriding with saved configuration", netIface.Name, current)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Default logic: if there is a saved static DNS configuration, restore it.
|
|
saved := ctrld.SavedStaticNameservers(netIface)
|
|
if len(saved) > 0 && restoreStatic {
|
|
logger.Debug().Msgf("Restoring interface %q from saved static config: %v", netIface.Name, saved)
|
|
if err := setDNS(netIface, saved); err != nil {
|
|
logger.Error().Err(err).Msgf("Failed to restore static DNS config on interface %q", netIface.Name)
|
|
return
|
|
}
|
|
} else {
|
|
logger.Debug().Msgf("No saved static DNS config for interface %q; resetting to DHCP", netIface.Name)
|
|
if err := resetDNS(netIface); err != nil {
|
|
logger.Error().Err(err).Msgf("Failed to reset DNS to DHCP on interface %q", netIface.Name)
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (p *prog) logInterfacesState() {
|
|
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
|
addrs, err := i.Addrs()
|
|
if err != nil {
|
|
p.Warn().Str("interface", i.Name).Err(err).Msg("Failed to get addresses")
|
|
}
|
|
nss, err := currentStaticDNS(i)
|
|
if err != nil {
|
|
p.Warn().Str("interface", i.Name).Err(err).Msg("Failed to get DNS")
|
|
}
|
|
if len(nss) == 0 {
|
|
nss = currentDNS(i)
|
|
}
|
|
p.Debug().
|
|
Any("addrs", addrs).
|
|
Strs("nameservers", nss).
|
|
Int("index", i.Index).
|
|
Msgf("interface state: %s", i.Name)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// findWorkingInterface looks for a network interface with a valid IP configuration
|
|
func (p *prog) findWorkingInterface() string {
|
|
currentIface := p.runningIface
|
|
// Helper to check if IP is valid (not link-local)
|
|
isValidIP := func(ip net.IP) bool {
|
|
return ip != nil &&
|
|
!ip.IsLinkLocalUnicast() &&
|
|
!ip.IsLinkLocalMulticast() &&
|
|
!ip.IsLoopback() &&
|
|
!ip.IsUnspecified()
|
|
}
|
|
|
|
// Helper to check if interface has valid IP configuration
|
|
hasValidIPConfig := func(iface *net.Interface) bool {
|
|
if iface == nil || iface.Flags&net.FlagUp == 0 {
|
|
return false
|
|
}
|
|
|
|
addrs, err := iface.Addrs()
|
|
if err != nil {
|
|
p.Debug().
|
|
Str("interface", iface.Name).
|
|
Err(err).
|
|
Msg("failed to get interface addresses")
|
|
return false
|
|
}
|
|
|
|
for _, addr := range addrs {
|
|
// Check for IP network
|
|
if ipNet, ok := addr.(*net.IPNet); ok {
|
|
if isValidIP(ipNet.IP) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Get default route interface
|
|
foundDefaultRoute := false
|
|
defaultRoute, err := netmon.DefaultRoute()
|
|
if err != nil {
|
|
p.Debug().
|
|
Err(err).
|
|
Msg("failed to get default route")
|
|
} else {
|
|
foundDefaultRoute = true
|
|
p.Debug().
|
|
Str("default_route_iface", defaultRoute.InterfaceName).
|
|
Msg("found default route")
|
|
}
|
|
|
|
// Get all interfaces
|
|
ifaces, err := net.Interfaces()
|
|
if err != nil {
|
|
p.Error().Err(err).Msg("Failed to list network interfaces")
|
|
return currentIface // Return current interface as fallback
|
|
}
|
|
|
|
var firstWorkingIface string
|
|
var currentIfaceValid bool
|
|
|
|
// Single pass through interfaces
|
|
for _, iface := range ifaces {
|
|
// Must be physical (has MAC address)
|
|
if len(iface.HardwareAddr) == 0 {
|
|
continue
|
|
}
|
|
// Skip interfaces that are:
|
|
// - Loopback
|
|
// - Not up
|
|
// - Point-to-point (like VPN tunnels)
|
|
if iface.Flags&net.FlagLoopback != 0 ||
|
|
iface.Flags&net.FlagUp == 0 ||
|
|
iface.Flags&net.FlagPointToPoint != 0 {
|
|
continue
|
|
}
|
|
|
|
if !hasValidIPConfig(&iface) {
|
|
continue
|
|
}
|
|
|
|
// Found working physical interface
|
|
if foundDefaultRoute && defaultRoute.InterfaceName == iface.Name {
|
|
// Found interface with default route - use it immediately
|
|
p.Info().
|
|
Str("old_iface", currentIface).
|
|
Str("new_iface", iface.Name).
|
|
Msg("switching to interface with default route")
|
|
return iface.Name
|
|
}
|
|
|
|
// Keep track of first working interface as fallback
|
|
if firstWorkingIface == "" {
|
|
firstWorkingIface = iface.Name
|
|
}
|
|
|
|
// Check if this is our current interface
|
|
if iface.Name == currentIface {
|
|
currentIfaceValid = true
|
|
}
|
|
}
|
|
|
|
// Return interfaces in order of preference:
|
|
// 1. Current interface if it's still valid
|
|
if currentIfaceValid {
|
|
p.Debug().
|
|
Str("interface", currentIface).
|
|
Msg("keeping current interface")
|
|
return currentIface
|
|
}
|
|
|
|
// 2. First working interface found
|
|
if firstWorkingIface != "" {
|
|
p.Info().
|
|
Str("old_iface", currentIface).
|
|
Str("new_iface", firstWorkingIface).
|
|
Msg("switching to first working physical interface")
|
|
return firstWorkingIface
|
|
}
|
|
|
|
// 3. Fall back to current interface if nothing else works
|
|
p.Warn().
|
|
Str("current_iface", currentIface).
|
|
Msg("No working physical interface found, keeping current")
|
|
return currentIface
|
|
}
|
|
|
|
// randomLocalIP returns a random address of the form 127.0.0.N, where N is
// drawn from the range [2, 253].
func randomLocalIP() string {
	const (
		firstOctet = 2
		octetLimit = 254 // exclusive upper bound
	)
	lastOctet := firstOctet + rand.Intn(octetLimit-firstOctet)
	return fmt.Sprintf("127.0.0.%d", lastOctet)
}
|
|
|
|
// randomPort returns a random port number in the range [1025, 65534].
func randomPort() int {
	const (
		lowestPort = 1025
		portLimit  = 1<<16 - 1 // exclusive upper bound
	)
	return lowestPort + rand.Intn(portLimit-lowestPort)
}
|
|
|
|
func errAddrInUse(err error) bool {
|
|
var opErr *net.OpError
|
|
if errors.As(err, &opErr) {
|
|
return errors.Is(opErr.Err, syscall.EADDRINUSE) || errors.Is(opErr.Err, windowsEADDRINUSE)
|
|
}
|
|
return false
|
|
}
|
|
|
|
var _ = errAddrInUse
|
|
|
|
// Windows Sockets error codes, declared as plain errno values and listed in
// numeric order.
//
// https://learn.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2
var (
	windowsEINVAL       = syscall.Errno(10022) // WSAEINVAL
	windowsEADDRINUSE   = syscall.Errno(10048) // WSAEADDRINUSE
	windowsENETUNREACH  = syscall.Errno(10051) // WSAENETUNREACH
	windowsECONNREFUSED = syscall.Errno(10061) // WSAECONNREFUSED
	windowsEHOSTUNREACH = syscall.Errno(10065) // WSAEHOSTUNREACH
)
|
|
|
|
func errUrlNetworkError(err error) bool {
|
|
var urlErr *url.Error
|
|
if errors.As(err, &urlErr) {
|
|
return errNetworkError(urlErr.Err)
|
|
}
|
|
return false
|
|
}
|
|
|
|
func errNetworkError(err error) bool {
|
|
var opErr *net.OpError
|
|
if errors.As(err, &opErr) {
|
|
if opErr.Temporary() {
|
|
return true
|
|
}
|
|
switch {
|
|
case errors.Is(opErr.Err, syscall.ECONNREFUSED),
|
|
errors.Is(opErr.Err, syscall.EINVAL),
|
|
errors.Is(opErr.Err, syscall.ENETUNREACH),
|
|
errors.Is(opErr.Err, windowsENETUNREACH),
|
|
errors.Is(opErr.Err, windowsEINVAL),
|
|
errors.Is(opErr.Err, windowsECONNREFUSED),
|
|
errors.Is(opErr.Err, windowsEHOSTUNREACH):
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// errConnectionRefused reports whether err wraps a *net.OpError whose
// underlying error is connection refused, matching either
// syscall.ECONNREFUSED or the Winsock code windowsECONNREFUSED.
func errConnectionRefused(err error) bool {
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		cause := opErr.Err
		return errors.Is(cause, syscall.ECONNREFUSED) || errors.Is(cause, windowsECONNREFUSED)
	}
	return false
}
|
|
|
|
// ifaceFirstPrivateIP returns the string form of the first private IP address
// assigned to iface, preferring IPv4 over IPv6.
//
// It returns the first private IPv4 address if there is one, otherwise the
// first private IPv6 address, otherwise "". A nil iface yields "".
func ifaceFirstPrivateIP(iface *net.Interface) string {
	if iface == nil {
		return ""
	}

	// Best effort: the error is deliberately ignored; if no addresses come
	// back the loop below finds nothing and "" is returned.
	addrs, _ := iface.Addrs()

	var firstV6 net.IP
	for _, addr := range addrs {
		ipNet, ok := addr.(*net.IPNet)
		if !ok || !ipNet.IP.IsPrivate() {
			continue
		}
		// To4 returns nil for addresses that are not IPv4. Such an address
		// must not end the search for an IPv4 one, otherwise a private IPv6
		// address listed first would hide a private IPv4 address after it.
		if v4 := ipNet.IP.To4(); v4 != nil {
			return v4.String()
		}
		if firstV6 == nil {
			firstV6 = ipNet.IP
		}
	}

	if firstV6 != nil {
		return firstV6.String()
	}
	return ""
}
|
|
|
|
// defaultRouteIP returns private IP string of the default route if present, prefer IPv4 over IPv6.
|
|
func (p *prog) defaultRouteIP() string {
|
|
dr, err := netmon.DefaultRoute()
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
drNetIface, err := netInterface(dr.InterfaceName)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
p.Debug().Str("iface", drNetIface.Name).Msg("Checking default route interface")
|
|
if ip := ifaceFirstPrivateIP(drNetIface); ip != "" {
|
|
p.Debug().Str("ip", ip).Msg("Found ip with default route interface")
|
|
return ip
|
|
}
|
|
|
|
// If we reach here, it means the default route interface is connected directly to ISP.
|
|
// We need to find the LAN interface with the same Mac address with drNetIface.
|
|
//
|
|
// There could be multiple LAN interfaces with the same Mac address, so we find all private
|
|
// IPs then using the smallest one.
|
|
var addrs []netip.Addr
|
|
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
|
|
if i.Name == drNetIface.Name {
|
|
return
|
|
}
|
|
if bytes.Equal(i.HardwareAddr, drNetIface.HardwareAddr) {
|
|
for _, pfx := range prefixes {
|
|
addr := pfx.Addr()
|
|
if addr.IsPrivate() {
|
|
addrs = append(addrs, addr)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
if len(addrs) == 0 {
|
|
p.Warn().Msg("No default route IP found")
|
|
return ""
|
|
}
|
|
sort.Slice(addrs, func(i, j int) bool {
|
|
return addrs[i].Less(addrs[j])
|
|
})
|
|
|
|
ip := addrs[0].String()
|
|
p.Debug().Str("ip", ip).Msg("Found LAN interface IP")
|
|
return ip
|
|
}
|
|
|
|
// canBeLocalUpstream reports whether the IP address can be used as a local upstream.
|
|
func canBeLocalUpstream(addr string) bool {
|
|
if ip, err := netip.ParseAddr(addr); err == nil {
|
|
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
|
}
|
|
return false
|
|
}
|
|
|
|
// withEachPhysicalInterfaces runs the function f with each physical interfaces, excluding
|
|
// the interface that matches excludeIfaceName. The context is used to clarify the
|
|
// log message when error happens.
|
|
func withEachPhysicalInterfaces(excludeIfaceName, contextStr string, f func(i *net.Interface) error) {
|
|
validIfacesMap := validInterfacesMap(ctrld.LoggerCtx(context.Background(), mainLog.Load()))
|
|
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
|
|
// Skip loopback/virtual/down interface.
|
|
if i.IsLoopback() || len(i.HardwareAddr) == 0 {
|
|
return
|
|
}
|
|
// Skip invalid interface.
|
|
if !validInterface(i.Interface, validIfacesMap) {
|
|
return
|
|
}
|
|
netIface := i.Interface
|
|
if patched, err := patchNetIfaceName(netIface); err != nil {
|
|
mainLog.Load().Debug().Err(err).Msg("Failed to patch net interface name")
|
|
return
|
|
} else if !patched {
|
|
// The interface is not functional, skipping.
|
|
return
|
|
}
|
|
// Skip excluded interface.
|
|
if netIface.Name == excludeIfaceName {
|
|
return
|
|
}
|
|
// TODO: investigate whether we should report this error?
|
|
if err := f(netIface); err == nil {
|
|
if contextStr != "" {
|
|
mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", contextStr, i.Name)
|
|
}
|
|
} else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) {
|
|
mainLog.Load().Err(err).Msgf("%s for interface %q failed", contextStr, i.Name)
|
|
}
|
|
})
|
|
}
|
|
|
|
// requiredMultiNICsConfig reports whether ctrld needs to set/reset DNS for
// multiple NICs. That is the case on Windows and macOS only.
func requiredMultiNICsConfig() bool {
	return runtime.GOOS == "windows" || runtime.GOOS == "darwin"
}
|
|
|
|
// errSaveCurrentStaticDNSNotSupported is returned by saveCurrentStaticDNS when
// runtime.GOOS is neither "windows" nor "darwin".
var errSaveCurrentStaticDNSNotSupported = errors.New("saving current DNS is not supported on this platform")
|
|
|
|
// saveCurrentStaticDNS saves the current static DNS settings for restoring later.
|
|
// Only works on Windows and Mac.
|
|
func saveCurrentStaticDNS(iface *net.Interface) error {
|
|
if iface == nil {
|
|
mainLog.Load().Debug().Msg("Could not save current static DNS settings for nil interface")
|
|
return nil
|
|
}
|
|
switch runtime.GOOS {
|
|
case "windows", "darwin":
|
|
default:
|
|
return errSaveCurrentStaticDNSNotSupported
|
|
}
|
|
file := ctrld.SavedStaticDnsSettingsFilePath(iface)
|
|
ns, err := currentStaticDNS(iface)
|
|
if err != nil {
|
|
mainLog.Load().Warn().Err(err).Msgf("Could not get current static DNS settings for %q", iface.Name)
|
|
return err
|
|
}
|
|
if len(ns) == 0 {
|
|
mainLog.Load().Debug().Msgf("No static DNS settings for %q, removing old static DNS settings file", iface.Name)
|
|
_ = os.Remove(file) // removing old static DNS settings
|
|
return nil
|
|
}
|
|
//filter out loopback addresses
|
|
ns = slices.DeleteFunc(ns, func(s string) bool {
|
|
return net.ParseIP(s).IsLoopback()
|
|
})
|
|
//if we now have no static DNS settings and the file already exists
|
|
// return and do not save the file
|
|
if len(ns) == 0 {
|
|
mainLog.Load().Debug().Msgf("loopback on %q, skipping saving static DNS settings", iface.Name)
|
|
return nil
|
|
}
|
|
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
mainLog.Load().Warn().Err(err).Msgf("Could not remove old static DNS settings file: %s", file)
|
|
}
|
|
nss := strings.Join(ns, ",")
|
|
mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, nss)
|
|
if err := os.WriteFile(file, []byte(nss), 0600); err != nil {
|
|
mainLog.Load().Err(err).Msgf("Could not save DNS settings for iface: %s", iface.Name)
|
|
return err
|
|
}
|
|
mainLog.Load().Debug().Msgf("Save DNS settings for interface %q successfully", iface.Name)
|
|
return nil
|
|
}
|
|
|
|
// dnsChanged reports whether DNS settings for given interface was changed.
|
|
// It returns false for a nil iface.
|
|
//
|
|
// The caller must sort the nameservers before calling this function.
|
|
func (p *prog) dnsChanged(iface *net.Interface, nameservers []string) bool {
|
|
if iface == nil {
|
|
return false
|
|
}
|
|
curNameservers, _ := currentStaticDNS(iface)
|
|
slices.Sort(curNameservers)
|
|
if !slices.Equal(curNameservers, nameservers) {
|
|
p.Debug().Msgf("Interface %q current DNS settings: %v, expected: %v", iface.Name, curNameservers, nameservers)
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then.
|
|
func selfUninstallCheck(uninstallErr error, p *prog, logger *ctrld.Logger) {
|
|
var uer *controld.ErrorResponse
|
|
if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
|
|
p.stopDnsWatchers()
|
|
|
|
// Perform self-uninstall now.
|
|
selfUninstall(p, logger)
|
|
}
|
|
}
|
|
|
|
// shouldUpgrade checks if the version target vt is greater than the current one cv.
|
|
// Major version upgrades are not allowed to prevent breaking changes.
|
|
//
|
|
// The callers must ensure curVer and logger are non-nil.
|
|
// Returns true if upgrade is allowed, false otherwise.
|
|
func shouldUpgrade(vt string, cv *semver.Version, logger *ctrld.Logger) bool {
|
|
if vt == "" {
|
|
logger.Debug().Msg("No version target set, skipped checking self-upgrade")
|
|
return false
|
|
}
|
|
vts := vt
|
|
if !strings.HasPrefix(vts, "v") {
|
|
vts = "v" + vts
|
|
}
|
|
targetVer, err := semver.NewVersion(vts)
|
|
if err != nil {
|
|
logger.Warn().Err(err).Msgf("Invalid target version, skipped self-upgrade: %s", vt)
|
|
return false
|
|
}
|
|
|
|
// Prevent major version upgrades to avoid breaking changes
|
|
if targetVer.Major() != cv.Major() {
|
|
logger.Warn().
|
|
Str("target", vt).
|
|
Str("current", cv.String()).
|
|
Msgf("Major version upgrade not allowed (target: %d, current: %d), skipped self-upgrade", targetVer.Major(), cv.Major())
|
|
return false
|
|
}
|
|
|
|
if !targetVer.GreaterThan(cv) {
|
|
logger.Debug().
|
|
Str("target", vt).
|
|
Str("current", cv.String()).
|
|
Msgf("Target version is not greater than current one, skipped self-upgrade")
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// performUpgrade executes the self-upgrade command.
|
|
// Returns true if upgrade was initiated successfully, false otherwise.
|
|
func performUpgrade(vt string, logger *ctrld.Logger) bool {
|
|
exe, err := os.Executable()
|
|
if err != nil {
|
|
logger.Error().Err(err).Msg("Failed to get executable path, skipped self-upgrade")
|
|
return false
|
|
}
|
|
cmd := exec.Command(exe, "upgrade", "prod", "-vv")
|
|
cmd.SysProcAttr = sysProcAttrForDetachedChildProcess()
|
|
if err := cmd.Start(); err != nil {
|
|
logger.Error().Err(err).Msg("Failed to start self-upgrade")
|
|
return false
|
|
}
|
|
logger.Debug().Msgf("Self-upgrade triggered, version target: %s", vt)
|
|
return true
|
|
}
|
|
|
|
// selfUpgradeCheck checks if the version target vt is greater
|
|
// than the current one cv, perform self-upgrade then.
|
|
// Major version upgrades are not allowed to prevent breaking changes.
|
|
//
|
|
// The callers must ensure curVer and logger are non-nil.
|
|
// Returns true if upgrade is allowed and should proceed, false otherwise.
|
|
func selfUpgradeCheck(vt string, cv *semver.Version, logger *ctrld.Logger) bool {
|
|
if shouldUpgrade(vt, cv, logger) {
|
|
return performUpgrade(vt, logger)
|
|
}
|
|
return false
|
|
}
|
|
|
|
// leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow
|
|
// when upstream failures occur.
|
|
func (p *prog) leakOnUpstreamFailure() bool {
|
|
if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil {
|
|
return *ptr
|
|
}
|
|
return true
|
|
}
|