mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
The "ctrld start" command is running slow, and using much CPU than necessary. The problem was made because of several things: 1. ctrld process is waiting for 5 seconds before marking listeners up. That ends up adding those seconds to the self-check process, even though the listeners may have been already available. 2. While creating socket control client, "s.Status()" is called to obtain ctrld service status, so we could terminate early if the service failed to run. However, that would make a lot of syscall in a hot loop, eating the CPU constantly while the command is running. On Windows, that call would become slower after each calls. The same effect could be seen using Windows services manager GUI, by pressing start/stop/restart button fast enough, we could see a timeout raised. 3. The socket control server is started lately, after all the listeners up. That would make the loop for creating socket control client run longer and use much resources than necessary. Fixes for these problems are quite obvious: 1. Removing hard code 5 seconds waiting. NotifyStartedFunc is enough to ensure that listeners are ready for accepting requests. 2. Check "s.Status()" only once before the loop. There has been already 30 seconds timeout, so if anything went wrong, the self-check process could be terminated, and won't hang forever. 3. Starting socket control server earlier, so newSocketControlClient can connect to server with fewest attempts, then querying "/started" endpoint to ensure the listeners have been ready. With these fixes, "ctrld start" now run much faster on modern machines, taking ~1-2 seconds (previously ~5-8 seconds) to finish. On dual cores VM, it takes ~5-8 seconds (previously a few dozen seconds or timeout). --- While at it, there are two refactoring for making the code easier to read/maintain: - PersistentPreRun is now used in root command to init console logging, so we don't have to initialize them in sub-commands. - NotifyStartedFunc now use channel for synchronization, instead of a mutex, making the ugly asymetric calls to lock goes away, making the code more idiom, and theoretically have better performance.
789 lines
20 KiB
Go
789 lines
20 KiB
Go
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io/fs"
|
|
"math/rand"
|
|
"net"
|
|
"net/netip"
|
|
"net/url"
|
|
"os"
|
|
"runtime"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
|
|
"github.com/kardianos/service"
|
|
"github.com/spf13/viper"
|
|
"tailscale.com/net/interfaces"
|
|
"tailscale.com/net/tsaddr"
|
|
|
|
"github.com/Control-D-Inc/ctrld"
|
|
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
|
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
|
"github.com/Control-D-Inc/ctrld/internal/router"
|
|
)
|
|
|
|
const (
|
|
defaultSemaphoreCap = 256
|
|
ctrldLogUnixSock = "ctrld_start.sock"
|
|
ctrldControlUnixSock = "ctrld_control.sock"
|
|
// iOS unix socket name max length is 11.
|
|
ctrldControlUnixSockMobile = "cd.sock"
|
|
upstreamPrefix = "upstream."
|
|
upstreamOS = upstreamPrefix + "os"
|
|
upstreamPrivate = upstreamPrefix + "private"
|
|
)
|
|
|
|
// ControlSocketName returns name for control unix socket.
|
|
func ControlSocketName() string {
|
|
if isMobile() {
|
|
return ctrldControlUnixSockMobile
|
|
} else {
|
|
return ctrldControlUnixSock
|
|
}
|
|
}
|
|
|
|
var logf = func(format string, args ...any) {
|
|
mainLog.Load().Debug().Msgf(format, args...)
|
|
}
|
|
|
|
var svcConfig = &service.Config{
|
|
Name: "ctrld",
|
|
DisplayName: "Control-D Helper Service",
|
|
Option: service.KeyValue{},
|
|
}
|
|
|
|
var useSystemdResolved = false
|
|
|
|
type prog struct {
|
|
mu sync.Mutex
|
|
waitCh chan struct{}
|
|
stopCh chan struct{}
|
|
reloadCh chan struct{} // For Windows.
|
|
reloadDoneCh chan struct{}
|
|
logConn net.Conn
|
|
cs *controlServer
|
|
|
|
cfg *ctrld.Config
|
|
localUpstreams []string
|
|
ptrNameservers []string
|
|
appCallback *AppCallback
|
|
cache dnscache.Cacher
|
|
cacheFlushDomainsMap map[string]struct{}
|
|
sema semaphore
|
|
ciTable *clientinfo.Table
|
|
um *upstreamMonitor
|
|
router router.Router
|
|
ptrLoopGuard *loopGuard
|
|
lanLoopGuard *loopGuard
|
|
|
|
loopMu sync.Mutex
|
|
loop map[string]bool
|
|
|
|
started chan struct{}
|
|
onStartedDone chan struct{}
|
|
onStarted []func()
|
|
onStopped []func()
|
|
}
|
|
|
|
func (p *prog) Start(s service.Service) error {
|
|
go p.runWait()
|
|
return nil
|
|
}
|
|
|
|
// runWait runs ctrld components, waiting for signal to reload.
|
|
func (p *prog) runWait() {
|
|
p.mu.Lock()
|
|
p.cfg = &cfg
|
|
p.mu.Unlock()
|
|
reloadSigCh := make(chan os.Signal, 1)
|
|
notifyReloadSigCh(reloadSigCh)
|
|
|
|
reload := false
|
|
logger := mainLog.Load()
|
|
for {
|
|
reloadCh := make(chan struct{})
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
p.run(reload, reloadCh)
|
|
reload = true
|
|
}()
|
|
select {
|
|
case sig := <-reloadSigCh:
|
|
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
|
|
case <-p.reloadCh:
|
|
logger.Notice().Msg("reloading...")
|
|
case <-p.stopCh:
|
|
close(reloadCh)
|
|
return
|
|
}
|
|
|
|
waitOldRunDone := func() {
|
|
close(reloadCh)
|
|
<-done
|
|
}
|
|
newCfg := &ctrld.Config{}
|
|
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
|
ctrld.InitConfig(v, "ctrld")
|
|
if configPath != "" {
|
|
v.SetConfigFile(configPath)
|
|
}
|
|
if err := v.ReadInConfig(); err != nil {
|
|
logger.Err(err).Msg("could not read new config")
|
|
waitOldRunDone()
|
|
continue
|
|
}
|
|
if err := v.Unmarshal(&newCfg); err != nil {
|
|
logger.Err(err).Msg("could not unmarshal new config")
|
|
waitOldRunDone()
|
|
continue
|
|
}
|
|
if cdUID != "" {
|
|
if err := processCDFlags(newCfg); err != nil {
|
|
logger.Err(err).Msg("could not fetch ControlD config")
|
|
waitOldRunDone()
|
|
continue
|
|
}
|
|
}
|
|
|
|
waitOldRunDone()
|
|
|
|
p.mu.Lock()
|
|
curListener := p.cfg.Listener
|
|
p.mu.Unlock()
|
|
|
|
for n, lc := range newCfg.Listener {
|
|
curLc := curListener[n]
|
|
if curLc == nil {
|
|
continue
|
|
}
|
|
if lc.IP == "" {
|
|
lc.IP = curLc.IP
|
|
}
|
|
if lc.Port == 0 {
|
|
lc.Port = curLc.Port
|
|
}
|
|
}
|
|
if err := validateConfig(newCfg); err != nil {
|
|
logger.Err(err).Msg("invalid config")
|
|
continue
|
|
}
|
|
|
|
// This needs to be done here, otherwise, the DNS handler may observe an invalid
|
|
// upstream config because its initialization function have not been called yet.
|
|
mainLog.Load().Debug().Msg("setup upstream with new config")
|
|
p.setupUpstream(newCfg)
|
|
|
|
p.mu.Lock()
|
|
*p.cfg = *newCfg
|
|
p.mu.Unlock()
|
|
|
|
logger.Notice().Msg("reloading config successfully")
|
|
select {
|
|
case p.reloadDoneCh <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *prog) preRun() {
|
|
if !service.Interactive() {
|
|
p.setDNS()
|
|
}
|
|
if runtime.GOOS == "darwin" {
|
|
p.onStopped = append(p.onStopped, func() {
|
|
if !service.Interactive() {
|
|
p.resetDNS()
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
|
localUpstreams := make([]string, 0, len(cfg.Upstream))
|
|
ptrNameservers := make([]string, 0, len(cfg.Upstream))
|
|
for n := range cfg.Upstream {
|
|
uc := cfg.Upstream[n]
|
|
uc.Init()
|
|
if uc.BootstrapIP == "" {
|
|
uc.SetupBootstrapIP()
|
|
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
|
} else {
|
|
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
|
|
}
|
|
uc.SetCertPool(rootCertPool)
|
|
go uc.Ping()
|
|
|
|
if canBeLocalUpstream(uc.Domain) {
|
|
localUpstreams = append(localUpstreams, upstreamPrefix+n)
|
|
}
|
|
if uc.IsDiscoverable() {
|
|
ptrNameservers = append(ptrNameservers, uc.Endpoint)
|
|
}
|
|
}
|
|
p.localUpstreams = localUpstreams
|
|
p.ptrNameservers = ptrNameservers
|
|
}
|
|
|
|
// run runs the ctrld main components.
|
|
//
|
|
// The reload boolean indicates that the function is run when ctrld first start
|
|
// or when ctrld receive reloading signal. Platform specifics setup is only done
|
|
// on started, mean reload is "false".
|
|
//
|
|
// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded,
|
|
// so all listeners could be terminated and re-spawned again.
|
|
func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
|
// Wait the caller to signal that we can do our logic.
|
|
<-p.waitCh
|
|
if !reload {
|
|
p.preRun()
|
|
}
|
|
numListeners := len(p.cfg.Listener)
|
|
if !reload {
|
|
p.started = make(chan struct{}, numListeners)
|
|
if p.cs != nil {
|
|
p.registerControlServerHandler()
|
|
if err := p.cs.start(); err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
|
}
|
|
mainLog.Load().Debug().Msgf("control server started: %s", p.cs.addr)
|
|
}
|
|
}
|
|
p.onStartedDone = make(chan struct{})
|
|
p.loop = make(map[string]bool)
|
|
p.lanLoopGuard = newLoopGuard()
|
|
p.ptrLoopGuard = newLoopGuard()
|
|
p.cacheFlushDomainsMap = nil
|
|
if p.cfg.Service.CacheEnable {
|
|
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled")
|
|
} else {
|
|
p.cache = cacher
|
|
p.cacheFlushDomainsMap = make(map[string]struct{}, 256)
|
|
for _, domain := range p.cfg.Service.CacheFlushDomains {
|
|
p.cacheFlushDomainsMap[canonicalName(domain)] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(p.cfg.Listener))
|
|
|
|
for _, nc := range p.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr")
|
|
continue
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
|
|
p.um = newUpstreamMonitor(p.cfg)
|
|
|
|
if !reload {
|
|
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
|
|
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
|
|
n := *mcr
|
|
if n == 0 {
|
|
p.sema = &noopSemaphore{}
|
|
} else {
|
|
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
|
|
}
|
|
}
|
|
p.setupUpstream(p.cfg)
|
|
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers)
|
|
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
|
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
|
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
|
p.ciTable.AddLeaseFile(leaseFile, format)
|
|
}
|
|
}
|
|
|
|
// context for managing spawn goroutines.
|
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
|
defer cancelFunc()
|
|
|
|
// Newer versions of android and iOS denies permission which breaks connectivity.
|
|
if !isMobile() && !reload {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
p.ciTable.Init()
|
|
p.ciTable.RefreshLoop(ctx)
|
|
}()
|
|
go p.watchLinkState(ctx)
|
|
}
|
|
|
|
for listenerNum := range p.cfg.Listener {
|
|
p.cfg.Listener[listenerNum].Init()
|
|
if !reload {
|
|
go func(listenerNum string) {
|
|
listenerConfig := p.cfg.Listener[listenerNum]
|
|
upstreamConfig := p.cfg.Upstream[listenerNum]
|
|
if upstreamConfig == nil {
|
|
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
|
|
}
|
|
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
|
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
|
|
if err := p.serveDNS(listenerNum); err != nil {
|
|
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
|
}
|
|
}(listenerNum)
|
|
}
|
|
go func() {
|
|
defer func() {
|
|
cancelFunc()
|
|
wg.Done()
|
|
}()
|
|
select {
|
|
case <-p.stopCh:
|
|
case <-ctx.Done():
|
|
case <-reloadCh:
|
|
}
|
|
}()
|
|
}
|
|
|
|
if !reload {
|
|
for i := 0; i < numListeners; i++ {
|
|
<-p.started
|
|
}
|
|
for _, f := range p.onStarted {
|
|
f()
|
|
}
|
|
}
|
|
|
|
close(p.onStartedDone)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
// Check for possible DNS loop.
|
|
p.checkDnsLoop()
|
|
// Start check DNS loop ticker.
|
|
p.checkDnsLoopTicker(ctx)
|
|
}()
|
|
|
|
wg.Add(1)
|
|
// Prometheus exporter goroutine.
|
|
go func() {
|
|
defer wg.Done()
|
|
p.runMetricsServer(ctx, reloadCh)
|
|
}()
|
|
|
|
if !reload {
|
|
// Stop writing log to unix socket.
|
|
consoleWriter.Out = os.Stdout
|
|
initLoggingWithBackup(false)
|
|
if p.logConn != nil {
|
|
_ = p.logConn.Close()
|
|
}
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
// metricsEnabled reports whether prometheus exporter is enabled/disabled.
|
|
func (p *prog) metricsEnabled() bool {
|
|
return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != ""
|
|
}
|
|
|
|
func (p *prog) Stop(s service.Service) error {
|
|
mainLog.Load().Info().Msg("Service stopped")
|
|
close(p.stopCh)
|
|
if err := p.deAllocateIP(); err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("de-allocate ip failed")
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *prog) allocateIP(ip string) error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
if !p.cfg.Service.AllocateIP {
|
|
return nil
|
|
}
|
|
return allocateIP(ip)
|
|
}
|
|
|
|
func (p *prog) deAllocateIP() error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
if !p.cfg.Service.AllocateIP {
|
|
return nil
|
|
}
|
|
for _, lc := range p.cfg.Listener {
|
|
if err := deAllocateIP(lc.IP); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *prog) setDNS() {
|
|
if cfg.Listener == nil {
|
|
return
|
|
}
|
|
if iface == "" {
|
|
return
|
|
}
|
|
// allIfaces tracks whether we should set DNS for all physical interfaces.
|
|
allIfaces := false
|
|
if iface == "auto" {
|
|
iface = defaultIfaceName()
|
|
// If iface is "auto", it means user does not specify "--iface" flag.
|
|
// In this case, ctrld has to set DNS for all physical interfaces, so
|
|
// thing will still work when user switch from one to the other.
|
|
allIfaces = requiredMultiNICsConfig()
|
|
}
|
|
lc := cfg.FirstListener()
|
|
if lc == nil {
|
|
return
|
|
}
|
|
logger := mainLog.Load().With().Str("iface", iface).Logger()
|
|
netIface, err := netInterface(iface)
|
|
if err != nil {
|
|
logger.Error().Err(err).Msg("could not get interface")
|
|
return
|
|
}
|
|
if err := setupNetworkManager(); err != nil {
|
|
logger.Error().Err(err).Msg("could not patch NetworkManager")
|
|
return
|
|
}
|
|
|
|
logger.Debug().Msg("setting DNS for interface")
|
|
ns := lc.IP
|
|
switch {
|
|
case lc.IsDirectDnsListener():
|
|
// If ctrld is direct listener, use 127.0.0.1 as nameserver.
|
|
ns = "127.0.0.1"
|
|
case lc.Port != 53:
|
|
ns = "127.0.0.1"
|
|
if resolver := router.LocalResolverIP(); resolver != "" {
|
|
ns = resolver
|
|
}
|
|
default:
|
|
// If we ever reach here, it means ctrld is running on lc.IP port 53,
|
|
// so we could just use lc.IP as nameserver.
|
|
}
|
|
|
|
nameservers := []string{ns}
|
|
if needRFC1918Listeners(lc) {
|
|
nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
|
|
}
|
|
if err := setDNS(netIface, nameservers); err != nil {
|
|
logger.Error().Err(err).Msgf("could not set DNS for interface")
|
|
return
|
|
}
|
|
logger.Debug().Msg("setting DNS successfully")
|
|
if shouldWatchResolvconf() {
|
|
servers := make([]netip.Addr, len(nameservers))
|
|
for i := range nameservers {
|
|
servers[i] = netip.MustParseAddr(nameservers[i])
|
|
}
|
|
go watchResolvConf(netIface, servers, setResolvConf)
|
|
}
|
|
if allIfaces {
|
|
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
|
|
return setDnsIgnoreUnusableInterface(i, nameservers)
|
|
})
|
|
}
|
|
}
|
|
|
|
func (p *prog) resetDNS() {
|
|
if iface == "" {
|
|
return
|
|
}
|
|
allIfaces := false
|
|
if iface == "auto" {
|
|
iface = defaultIfaceName()
|
|
// See corresponding comments in (*prog).setDNS function.
|
|
allIfaces = requiredMultiNICsConfig()
|
|
}
|
|
logger := mainLog.Load().With().Str("iface", iface).Logger()
|
|
netIface, err := netInterface(iface)
|
|
if err != nil {
|
|
logger.Error().Err(err).Msg("could not get interface")
|
|
return
|
|
}
|
|
if err := restoreNetworkManager(); err != nil {
|
|
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
|
return
|
|
}
|
|
logger.Debug().Msg("Restoring DNS for interface")
|
|
if err := resetDNS(netIface); err != nil {
|
|
logger.Error().Err(err).Msgf("could not reset DNS")
|
|
return
|
|
}
|
|
logger.Debug().Msg("Restoring DNS successfully")
|
|
if allIfaces {
|
|
withEachPhysicalInterfaces(netIface.Name, "reset DNS", resetDnsIgnoreUnusableInterface)
|
|
}
|
|
}
|
|
|
|
func randomLocalIP() string {
|
|
n := rand.Intn(254-2) + 2
|
|
return fmt.Sprintf("127.0.0.%d", n)
|
|
}
|
|
|
|
func randomPort() int {
|
|
max := 1<<16 - 1
|
|
min := 1025
|
|
n := rand.Intn(max-min) + min
|
|
return n
|
|
}
|
|
|
|
// runLogServer starts a unix listener, use by startCmd to gather log from runCmd.
|
|
func runLogServer(sockPath string) net.Conn {
|
|
addr, err := net.ResolveUnixAddr("unix", sockPath)
|
|
if err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("invalid log sock path")
|
|
return nil
|
|
}
|
|
ln, err := net.ListenUnix("unix", addr)
|
|
if err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not listen log socket")
|
|
return nil
|
|
}
|
|
defer ln.Close()
|
|
|
|
server, err := ln.Accept()
|
|
if err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not accept connection")
|
|
return nil
|
|
}
|
|
return server
|
|
}
|
|
|
|
func errAddrInUse(err error) bool {
|
|
var opErr *net.OpError
|
|
if errors.As(err, &opErr) {
|
|
return errors.Is(opErr.Err, syscall.EADDRINUSE) || errors.Is(opErr.Err, windowsEADDRINUSE)
|
|
}
|
|
return false
|
|
}
|
|
|
|
var _ = errAddrInUse
|
|
|
|
// https://learn.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2
|
|
var (
|
|
windowsECONNREFUSED = syscall.Errno(10061)
|
|
windowsENETUNREACH = syscall.Errno(10051)
|
|
windowsEINVAL = syscall.Errno(10022)
|
|
windowsEADDRINUSE = syscall.Errno(10048)
|
|
windowsEHOSTUNREACH = syscall.Errno(10065)
|
|
)
|
|
|
|
func errUrlNetworkError(err error) bool {
|
|
var urlErr *url.Error
|
|
if errors.As(err, &urlErr) {
|
|
return errNetworkError(urlErr.Err)
|
|
}
|
|
return false
|
|
}
|
|
|
|
func errNetworkError(err error) bool {
|
|
var opErr *net.OpError
|
|
if errors.As(err, &opErr) {
|
|
if opErr.Temporary() {
|
|
return true
|
|
}
|
|
switch {
|
|
case errors.Is(opErr.Err, syscall.ECONNREFUSED),
|
|
errors.Is(opErr.Err, syscall.EINVAL),
|
|
errors.Is(opErr.Err, syscall.ENETUNREACH),
|
|
errors.Is(opErr.Err, windowsENETUNREACH),
|
|
errors.Is(opErr.Err, windowsEINVAL),
|
|
errors.Is(opErr.Err, windowsECONNREFUSED),
|
|
errors.Is(opErr.Err, windowsEHOSTUNREACH):
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// errConnectionRefused reports whether err is connection refused.
|
|
func errConnectionRefused(err error) bool {
|
|
var opErr *net.OpError
|
|
if !errors.As(err, &opErr) {
|
|
return false
|
|
}
|
|
return errors.Is(opErr.Err, syscall.ECONNREFUSED) || errors.Is(opErr.Err, windowsECONNREFUSED)
|
|
}
|
|
|
|
func ifaceFirstPrivateIP(iface *net.Interface) string {
|
|
if iface == nil {
|
|
return ""
|
|
}
|
|
do := func(addrs []net.Addr, v4 bool) net.IP {
|
|
for _, addr := range addrs {
|
|
if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.IsPrivate() {
|
|
if v4 {
|
|
return netIP.IP.To4()
|
|
}
|
|
return netIP.IP
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
addrs, _ := iface.Addrs()
|
|
if ip := do(addrs, true); ip != nil {
|
|
return ip.String()
|
|
}
|
|
if ip := do(addrs, false); ip != nil {
|
|
return ip.String()
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// defaultRouteIP returns private IP string of the default route if present, prefer IPv4 over IPv6.
|
|
func defaultRouteIP() string {
|
|
dr, err := interfaces.DefaultRoute()
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
drNetIface, err := netInterface(dr.InterfaceName)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
mainLog.Load().Debug().Str("iface", drNetIface.Name).Msg("checking default route interface")
|
|
if ip := ifaceFirstPrivateIP(drNetIface); ip != "" {
|
|
mainLog.Load().Debug().Str("ip", ip).Msg("found ip with default route interface")
|
|
return ip
|
|
}
|
|
|
|
// If we reach here, it means the default route interface is connected directly to ISP.
|
|
// We need to find the LAN interface with the same Mac address with drNetIface.
|
|
//
|
|
// There could be multiple LAN interfaces with the same Mac address, so we find all private
|
|
// IPs then using the smallest one.
|
|
var addrs []netip.Addr
|
|
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
|
if i.Name == drNetIface.Name {
|
|
return
|
|
}
|
|
if bytes.Equal(i.HardwareAddr, drNetIface.HardwareAddr) {
|
|
for _, pfx := range prefixes {
|
|
addr := pfx.Addr()
|
|
if addr.IsPrivate() {
|
|
addrs = append(addrs, addr)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
if len(addrs) == 0 {
|
|
mainLog.Load().Warn().Msg("no default route IP found")
|
|
return ""
|
|
}
|
|
sort.Slice(addrs, func(i, j int) bool {
|
|
return addrs[i].Less(addrs[j])
|
|
})
|
|
|
|
ip := addrs[0].String()
|
|
mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP")
|
|
return ip
|
|
}
|
|
|
|
// canBeLocalUpstream reports whether the IP address can be used as a local upstream.
|
|
func canBeLocalUpstream(addr string) bool {
|
|
if ip, err := netip.ParseAddr(addr); err == nil {
|
|
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
|
}
|
|
return false
|
|
}
|
|
|
|
// withEachPhysicalInterfaces runs the function f with each physical interfaces, excluding
|
|
// the interface that matches excludeIfaceName. The context is used to clarify the
|
|
// log message when error happens.
|
|
func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) {
|
|
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
|
// Skip loopback/virtual interface.
|
|
if i.IsLoopback() || len(i.HardwareAddr) == 0 {
|
|
return
|
|
}
|
|
// Skip invalid interface.
|
|
if !validInterface(i.Interface) {
|
|
return
|
|
}
|
|
netIface := i.Interface
|
|
if err := patchNetIfaceName(netIface); err != nil {
|
|
mainLog.Load().Debug().Err(err).Msg("failed to patch net interface name")
|
|
return
|
|
}
|
|
// Skip excluded interface.
|
|
if netIface.Name == excludeIfaceName {
|
|
return
|
|
}
|
|
// TODO: investigate whether we should report this error?
|
|
if err := f(netIface); err == nil {
|
|
mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name)
|
|
} else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) {
|
|
mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name)
|
|
}
|
|
})
|
|
}
|
|
|
|
// requiredMultiNicConfig reports whether ctrld needs to set/reset DNS for multiple NICs.
|
|
func requiredMultiNICsConfig() bool {
|
|
switch runtime.GOOS {
|
|
case "windows", "darwin":
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
var errSaveCurrentStaticDNSNotSupported = errors.New("saving current DNS is not supported on this platform")
|
|
|
|
// saveCurrentStaticDNS saves the current static DNS settings for restoring later.
|
|
// Only works on Windows and Mac.
|
|
func saveCurrentStaticDNS(iface *net.Interface) error {
|
|
switch runtime.GOOS {
|
|
case "windows", "darwin":
|
|
default:
|
|
return errSaveCurrentStaticDNSNotSupported
|
|
}
|
|
file := savedStaticDnsSettingsFilePath(iface)
|
|
ns, _ := currentStaticDNS(iface)
|
|
if len(ns) == 0 {
|
|
_ = os.Remove(file) // removing old static DNS settings
|
|
return nil
|
|
}
|
|
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
mainLog.Load().Warn().Err(err).Msg("could not remove old static DNS settings file")
|
|
}
|
|
mainLog.Load().Debug().Msgf("DNS settings for %s is static, saving ...", iface.Name)
|
|
if err := os.WriteFile(file, []byte(strings.Join(ns, ",")), 0600); err != nil {
|
|
mainLog.Load().Err(err).Msgf("could not save DNS settings for iface: %s", iface.Name)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// savedStaticDnsSettingsFilePath returns the path to saved DNS settings of the given interface.
|
|
func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
|
|
return absHomeDir(".dns_" + iface.Name)
|
|
}
|
|
|
|
// savedStaticNameservers returns the static DNS nameservers of the given interface.
|
|
//
|
|
//lint:ignore U1000 use in os_windows.go and os_darwin.go
|
|
func savedStaticNameservers(iface *net.Interface) []string {
|
|
file := savedStaticDnsSettingsFilePath(iface)
|
|
if data, _ := os.ReadFile(file); len(data) > 0 {
|
|
return strings.Split(string(data), ",")
|
|
}
|
|
return nil
|
|
}
|