mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
391 lines
9.2 KiB
Go
391 lines
9.2 KiB
Go
package cli
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"runtime"
|
|
"strconv"
|
|
"sync"
|
|
"syscall"
|
|
|
|
"github.com/kardianos/service"
|
|
"tailscale.com/net/interfaces"
|
|
|
|
"github.com/Control-D-Inc/ctrld"
|
|
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
|
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
|
"github.com/Control-D-Inc/ctrld/internal/router"
|
|
)
|
|
|
|
const (
	// defaultSemaphoreCap is the number of slots given to the request
	// semaphore when service.max_concurrent_requests is not configured.
	defaultSemaphoreCap = 256
	// ctrldLogUnixSock is the file name of the unix socket used to stream
	// start-up logs. NOTE(review): presumably joined with a directory and
	// passed to runLogServer elsewhere — confirm.
	ctrldLogUnixSock = "ctrld_start.sock"
	// ctrldControlUnixSock is the file name of the control unix socket.
	// NOTE(review): presumably used by the control server — confirm.
	ctrldControlUnixSock = "ctrld_control.sock"
)
|
|
|
|
var logf = func(format string, args ...any) {
|
|
mainLog.Load().Debug().Msgf(format, args...)
|
|
}
|
|
|
|
var svcConfig = &service.Config{
|
|
Name: "ctrld",
|
|
DisplayName: "Control-D Helper Service",
|
|
Option: service.KeyValue{},
|
|
}
|
|
|
|
// useSystemdResolved defaults to false (the zero value).
var useSystemdResolved bool
|
|
|
|
type prog struct {
|
|
mu sync.Mutex
|
|
waitCh chan struct{}
|
|
stopCh chan struct{}
|
|
logConn net.Conn
|
|
cs *controlServer
|
|
|
|
cfg *ctrld.Config
|
|
cache dnscache.Cacher
|
|
sema semaphore
|
|
ciTable *clientinfo.Table
|
|
router router.Router
|
|
|
|
started chan struct{}
|
|
onStartedDone chan struct{}
|
|
onStarted []func()
|
|
onStopped []func()
|
|
}
|
|
|
|
func (p *prog) Start(s service.Service) error {
|
|
p.cfg = &cfg
|
|
go p.run()
|
|
return nil
|
|
}
|
|
|
|
func (p *prog) preRun() {
|
|
if !service.Interactive() {
|
|
p.setDNS()
|
|
}
|
|
if runtime.GOOS == "darwin" {
|
|
p.onStopped = append(p.onStopped, func() {
|
|
if !service.Interactive() {
|
|
p.resetDNS()
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func (p *prog) run() {
|
|
// Wait the caller to signal that we can do our logic.
|
|
<-p.waitCh
|
|
p.preRun()
|
|
numListeners := len(p.cfg.Listener)
|
|
p.started = make(chan struct{}, numListeners)
|
|
p.onStartedDone = make(chan struct{})
|
|
if p.cfg.Service.CacheEnable {
|
|
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled")
|
|
} else {
|
|
p.cache = cacher
|
|
}
|
|
}
|
|
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
|
|
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
|
|
n := *mcr
|
|
if n == 0 {
|
|
p.sema = &noopSemaphore{}
|
|
} else {
|
|
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
|
|
}
|
|
}
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(p.cfg.Listener))
|
|
|
|
for _, nc := range p.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
mainLog.Load().Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr")
|
|
continue
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
for n := range p.cfg.Upstream {
|
|
uc := p.cfg.Upstream[n]
|
|
uc.Init()
|
|
if uc.BootstrapIP == "" {
|
|
uc.SetupBootstrapIP()
|
|
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
|
} else {
|
|
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
|
|
}
|
|
uc.SetCertPool(rootCertPool)
|
|
go uc.Ping()
|
|
}
|
|
|
|
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID)
|
|
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
|
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
|
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
|
p.ciTable.AddLeaseFile(leaseFile, format)
|
|
}
|
|
|
|
go func() {
|
|
p.ciTable.Init()
|
|
p.ciTable.RefreshLoop(p.stopCh)
|
|
}()
|
|
go p.watchLinkState()
|
|
|
|
for listenerNum := range p.cfg.Listener {
|
|
p.cfg.Listener[listenerNum].Init()
|
|
go func(listenerNum string) {
|
|
defer wg.Done()
|
|
listenerConfig := p.cfg.Listener[listenerNum]
|
|
upstreamConfig := p.cfg.Upstream[listenerNum]
|
|
if upstreamConfig == nil {
|
|
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
|
|
}
|
|
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
|
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
|
|
if err := p.serveDNS(listenerNum); err != nil {
|
|
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
|
}
|
|
}(listenerNum)
|
|
}
|
|
|
|
for i := 0; i < numListeners; i++ {
|
|
<-p.started
|
|
}
|
|
for _, f := range p.onStarted {
|
|
f()
|
|
}
|
|
close(p.onStartedDone)
|
|
|
|
// Stop writing log to unix socket.
|
|
consoleWriter.Out = os.Stdout
|
|
initLoggingWithBackup(false)
|
|
if p.logConn != nil {
|
|
_ = p.logConn.Close()
|
|
}
|
|
if p.cs != nil {
|
|
p.registerControlServerHandler()
|
|
if err := p.cs.start(); err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
|
}
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func (p *prog) Stop(s service.Service) error {
|
|
mainLog.Load().Info().Msg("Service stopped")
|
|
close(p.stopCh)
|
|
if err := p.deAllocateIP(); err != nil {
|
|
mainLog.Load().Error().Err(err).Msg("de-allocate ip failed")
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *prog) allocateIP(ip string) error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
if !p.cfg.Service.AllocateIP {
|
|
return nil
|
|
}
|
|
return allocateIP(ip)
|
|
}
|
|
|
|
func (p *prog) deAllocateIP() error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
if !p.cfg.Service.AllocateIP {
|
|
return nil
|
|
}
|
|
for _, lc := range p.cfg.Listener {
|
|
if err := deAllocateIP(lc.IP); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *prog) setDNS() {
|
|
if cfg.Listener == nil {
|
|
return
|
|
}
|
|
if iface == "" {
|
|
return
|
|
}
|
|
if iface == "auto" {
|
|
iface = defaultIfaceName()
|
|
}
|
|
lc := cfg.FirstListener()
|
|
if lc == nil {
|
|
return
|
|
}
|
|
logger := mainLog.Load().With().Str("iface", iface).Logger()
|
|
netIface, err := netInterface(iface)
|
|
if err != nil {
|
|
logger.Error().Err(err).Msg("could not get interface")
|
|
return
|
|
}
|
|
if err := setupNetworkManager(); err != nil {
|
|
logger.Error().Err(err).Msg("could not patch NetworkManager")
|
|
return
|
|
}
|
|
|
|
logger.Debug().Msg("setting DNS for interface")
|
|
ns := lc.IP
|
|
switch {
|
|
case lc.IsDirectDnsListener():
|
|
// If ctrld is direct listener, use 127.0.0.1 as nameserver.
|
|
ns = "127.0.0.1"
|
|
case lc.Port != 53:
|
|
ns = "127.0.0.1"
|
|
if resolver := router.LocalResolverIP(); resolver != "" {
|
|
ns = resolver
|
|
}
|
|
default:
|
|
// If we ever reach here, it means ctrld is running on lc.IP port 53,
|
|
// so we could just use lc.IP as nameserver.
|
|
}
|
|
|
|
nameservers := []string{ns}
|
|
if needRFC1918Listeners(lc) {
|
|
nameservers = append(nameservers, rfc1918Addresses()...)
|
|
}
|
|
if err := setDNS(netIface, nameservers); err != nil {
|
|
logger.Error().Err(err).Msgf("could not set DNS for interface")
|
|
return
|
|
}
|
|
logger.Debug().Msg("setting DNS successfully")
|
|
}
|
|
|
|
func (p *prog) resetDNS() {
|
|
if iface == "" {
|
|
return
|
|
}
|
|
if iface == "auto" {
|
|
iface = defaultIfaceName()
|
|
}
|
|
logger := mainLog.Load().With().Str("iface", iface).Logger()
|
|
netIface, err := netInterface(iface)
|
|
if err != nil {
|
|
logger.Error().Err(err).Msg("could not get interface")
|
|
return
|
|
}
|
|
if err := restoreNetworkManager(); err != nil {
|
|
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
|
return
|
|
}
|
|
logger.Debug().Msg("Restoring DNS for interface")
|
|
if err := resetDNS(netIface); err != nil {
|
|
logger.Error().Err(err).Msgf("could not reset DNS")
|
|
return
|
|
}
|
|
logger.Debug().Msg("Restoring DNS successfully")
|
|
}
|
|
|
|
// randomLocalIP returns a random loopback address 127.0.0.N with N in
// [2, 253]. The upper bound is exclusive, as with rand.Intn.
func randomLocalIP() string {
	const lowOctet, highOctet = 2, 254
	return fmt.Sprintf("127.0.0.%d", lowOctet+rand.Intn(highOctet-lowOctet))
}
|
|
|
|
// randomPort returns a random unprivileged port in [1025, 65534]. The upper
// bound 1<<16-1 is exclusive, as with rand.Intn.
func randomPort() int {
	const (
		lowestPort  = 1025
		highestPort = 1<<16 - 1
	)
	return lowestPort + rand.Intn(highestPort-lowestPort)
}
|
|
|
|
// runLogServer starts a unix listener, use by startCmd to gather log from runCmd.
|
|
func runLogServer(sockPath string) net.Conn {
|
|
addr, err := net.ResolveUnixAddr("unix", sockPath)
|
|
if err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("invalid log sock path")
|
|
return nil
|
|
}
|
|
ln, err := net.ListenUnix("unix", addr)
|
|
if err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not listen log socket")
|
|
return nil
|
|
}
|
|
defer ln.Close()
|
|
|
|
server, err := ln.Accept()
|
|
if err != nil {
|
|
mainLog.Load().Warn().Err(err).Msg("could not accept connection")
|
|
return nil
|
|
}
|
|
return server
|
|
}
|
|
|
|
func errAddrInUse(err error) bool {
|
|
var opErr *net.OpError
|
|
if errors.As(err, &opErr) {
|
|
return errors.Is(opErr.Err, syscall.EADDRINUSE) || errors.Is(opErr.Err, windowsEADDRINUSE)
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Winsock error codes, checked alongside their syscall counterparts by
// errAddrInUse and errUrlNetworkError. See
// https://learn.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2
var (
	windowsEINVAL       syscall.Errno = 10022 // WSAEINVAL
	windowsEADDRINUSE   syscall.Errno = 10048 // WSAEADDRINUSE
	windowsENETUNREACH  syscall.Errno = 10051 // WSAENETUNREACH
	windowsECONNREFUSED syscall.Errno = 10061 // WSAECONNREFUSED
)
|
|
|
|
func errUrlNetworkError(err error) bool {
|
|
var urlErr *url.Error
|
|
if errors.As(err, &urlErr) {
|
|
var opErr *net.OpError
|
|
if errors.As(urlErr.Err, &opErr) {
|
|
if opErr.Temporary() {
|
|
return true
|
|
}
|
|
switch {
|
|
case errors.Is(opErr.Err, syscall.ECONNREFUSED),
|
|
errors.Is(opErr.Err, syscall.EINVAL),
|
|
errors.Is(opErr.Err, syscall.ENETUNREACH),
|
|
errors.Is(opErr.Err, windowsENETUNREACH),
|
|
errors.Is(opErr.Err, windowsEINVAL),
|
|
errors.Is(opErr.Err, windowsECONNREFUSED):
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// defaultRouteIP returns IP string of the default route if present, prefer IPv4 over IPv6.
|
|
func defaultRouteIP() string {
|
|
if dr, err := interfaces.DefaultRoute(); err == nil {
|
|
if netIface, err := netInterface(dr.InterfaceName); err == nil {
|
|
addrs, _ := netIface.Addrs()
|
|
do := func(v4 bool) net.IP {
|
|
for _, addr := range addrs {
|
|
if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.IsPrivate() {
|
|
if v4 {
|
|
return netIP.IP.To4()
|
|
}
|
|
return netIP.IP
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
if ip := do(true); ip != nil {
|
|
return ip.String()
|
|
}
|
|
if ip := do(false); ip != nil {
|
|
return ip.String()
|
|
}
|
|
}
|
|
}
|
|
return ""
|
|
}
|