mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-13 10:26:06 +00:00
feat: introduce DNS intercept mode infrastructure
This commit is contained in:
committed by
Cuong Manh Le
parent
490ebbba88
commit
f76a332329
@@ -342,6 +342,16 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
processLogAndCacheFlags()
|
||||
}
|
||||
|
||||
// Persist intercept_mode to config when provided via CLI flag on full install.
|
||||
// This ensures the config file reflects the actual running mode for RMM/MDM visibility.
|
||||
if interceptMode == "dns" || interceptMode == "hard" {
|
||||
if cfg.Service.InterceptMode != interceptMode {
|
||||
cfg.Service.InterceptMode = interceptMode
|
||||
updated = true
|
||||
p.Info().Msgf("writing intercept_mode = %q to config", interceptMode)
|
||||
}
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := writeConfigFile(&cfg); err != nil {
|
||||
notifyExitToLogServer()
|
||||
|
||||
@@ -51,6 +51,7 @@ func InitRunCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
_ = runCmd.Flags().MarkHidden("iface")
|
||||
runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
|
||||
runCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener")
|
||||
runCmd.Flags().StringVarP(&interceptMode, "intercept-mode", "", "", "OS-level DNS interception mode: 'dns' (with VPN split routing) or 'hard' (all DNS through ctrld, no VPN split routing)")
|
||||
|
||||
runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true}
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -254,3 +255,53 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
|
||||
return serviceCmd
|
||||
}
|
||||
|
||||
// validInterceptMode reports whether the given value is a recognized --intercept-mode.
|
||||
// This is the single source of truth for mode validation — used by the early start
|
||||
// command check, the runtime validation in prog.go, and onlyInterceptFlags below.
|
||||
// Add new modes here to have them recognized everywhere.
|
||||
func validInterceptMode(mode string) bool {
|
||||
switch mode {
|
||||
case "off", "dns", "hard":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// onlyInterceptFlags reports whether args contain only intercept mode
|
||||
// flags (--intercept-mode <value>) and flags that are auto-added by the
|
||||
// start command alias (--iface). This is used to detect "ctrld start --intercept-mode dns"
|
||||
// (or "off" to disable) on an existing installation, where the intent is to modify the
|
||||
// intercept flag on the existing service without replacing other arguments.
|
||||
//
|
||||
// Note: the startCmdAlias appends "--iface=auto" to os.Args when --iface isn't
|
||||
// explicitly provided, so we must allow it here.
|
||||
func onlyInterceptFlags(args []string) bool {
|
||||
hasIntercept := false
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
switch {
|
||||
case arg == "--intercept-mode":
|
||||
// Next arg must be a valid mode value.
|
||||
if i+1 < len(args) && validInterceptMode(args[i+1]) {
|
||||
hasIntercept = true
|
||||
i++ // skip the value
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
case strings.HasPrefix(arg, "--intercept-mode="):
|
||||
val := strings.TrimPrefix(arg, "--intercept-mode=")
|
||||
if validInterceptMode(val) {
|
||||
hasIntercept = true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
case arg == "--iface=auto" || arg == "--iface" || arg == "auto":
|
||||
// Auto-added by startCmdAlias or its value; safe to ignore.
|
||||
continue
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return hasIntercept
|
||||
}
|
||||
|
||||
@@ -36,6 +36,14 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
|
||||
setDependencies(svcConfig)
|
||||
svcConfig.Arguments = append([]string{"run"}, osArgs...)
|
||||
|
||||
// Validate --intercept-mode early, before installing the service.
|
||||
// Without this, a typo like "--intercept-mode fds" would install the service,
|
||||
// the child process would Fatal() on the invalid value, and the parent would
|
||||
// then uninstall — confusing and destructive.
|
||||
if interceptMode != "" && !validInterceptMode(interceptMode) {
|
||||
logger.Fatal().Msgf("invalid --intercept-mode value %q: must be 'off', 'dns', or 'hard'", interceptMode)
|
||||
}
|
||||
|
||||
// Initialize service manager with proper configuration
|
||||
s, p, err := sc.initializeServiceManagerWithServiceConfig(svcConfig)
|
||||
if err != nil {
|
||||
@@ -53,6 +61,49 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
|
||||
// Get current running iface, if any.
|
||||
var currentIface *ifaceResponse
|
||||
|
||||
// Handle "ctrld start --intercept-mode dns|hard" on an existing
|
||||
// service BEFORE the pin check. Adding intercept mode is an enhancement, not
|
||||
// deactivation, so it doesn't require the deactivation pin. We modify the
|
||||
// plist/registry directly and restart the service via the OS service manager.
|
||||
osArgsEarly := os.Args[2:]
|
||||
if os.Args[1] == "service" {
|
||||
osArgsEarly = os.Args[3:]
|
||||
}
|
||||
osArgsEarly = filterEmptyStrings(osArgsEarly)
|
||||
interceptOnly := onlyInterceptFlags(osArgsEarly)
|
||||
svcExists := serviceConfigFileExists()
|
||||
logger.Debug().Msgf("intercept upgrade check: args=%v interceptOnly=%v svcConfigExists=%v interceptMode=%q", osArgsEarly, interceptOnly, svcExists, interceptMode)
|
||||
if interceptOnly && svcExists {
|
||||
// Remove any existing intercept flags before applying the new value.
|
||||
_ = removeServiceFlag("--intercept-mode")
|
||||
|
||||
if interceptMode == "off" {
|
||||
// "off" = remove intercept mode entirely (just the removal above).
|
||||
logger.Notice().Msg("Existing service detected — removing --intercept-mode from service arguments")
|
||||
} else {
|
||||
// Add the new mode value.
|
||||
logger.Notice().Msgf("Existing service detected — appending --intercept-mode %s to service arguments", interceptMode)
|
||||
if err := appendServiceFlag("--intercept-mode"); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to append intercept flag to service arguments")
|
||||
}
|
||||
if err := appendServiceFlag(interceptMode); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to append intercept mode value to service arguments")
|
||||
}
|
||||
}
|
||||
|
||||
// Stop the service if running (bypasses ctrld pin — this is an
|
||||
// enhancement, not deactivation). Then fall through to the normal
|
||||
// startOnly path which handles start, self-check, and reporting.
|
||||
if isCtrldRunning {
|
||||
logger.Notice().Msg("Stopping service for intercept mode upgrade")
|
||||
_ = s.Stop()
|
||||
isCtrldRunning = false
|
||||
}
|
||||
startOnly = true
|
||||
isCtrldInstalled = true
|
||||
// Fall through to startOnly path below.
|
||||
}
|
||||
|
||||
// If pin code was set, do not allow running start command.
|
||||
if isCtrldRunning {
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
@@ -78,20 +129,31 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
|
||||
return
|
||||
}
|
||||
if res.OK {
|
||||
name := res.Name
|
||||
if iff, err := net.InterfaceByName(name); err == nil {
|
||||
_, _ = patchNetIfaceName(iff)
|
||||
name = iff.Name
|
||||
}
|
||||
logger := logger.With().Str("iface", name)
|
||||
logger.Debug().Msg("Setting DNS successfully")
|
||||
if res.All {
|
||||
// Log that DNS is set for other interfaces.
|
||||
withEachPhysicalInterfaces(
|
||||
name,
|
||||
"set DNS",
|
||||
func(i *net.Interface) error { return nil },
|
||||
)
|
||||
// In intercept mode, show intercept-specific status instead of
|
||||
// per-interface DNS messages (which are irrelevant).
|
||||
if res.InterceptMode != "" {
|
||||
switch res.InterceptMode {
|
||||
case "hard":
|
||||
logger.Notice().Msg("DNS hard intercept mode active — all DNS traffic intercepted, no VPN split routing")
|
||||
default:
|
||||
logger.Notice().Msg("DNS intercept mode active — all DNS traffic intercepted via OS packet filter")
|
||||
}
|
||||
} else {
|
||||
name := res.Name
|
||||
if iff, err := net.InterfaceByName(name); err == nil {
|
||||
_, _ = patchNetIfaceName(iff)
|
||||
name = iff.Name
|
||||
}
|
||||
ifaceLogger := logger.With().Str("iface", name)
|
||||
ifaceLogger.Debug().Msg("Setting DNS successfully")
|
||||
if res.All {
|
||||
// Log that DNS is set for other interfaces.
|
||||
withEachPhysicalInterfaces(
|
||||
name,
|
||||
"set DNS",
|
||||
func(i *net.Interface) error { return nil },
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -179,6 +241,10 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
|
||||
os.Exit(1)
|
||||
}
|
||||
reportSetDnsOk(sockDir)
|
||||
// Verify service registration after successful start.
|
||||
if err := verifyServiceRegistration(); err != nil {
|
||||
logger.Warn().Err(err).Msg("Service registry verification failed")
|
||||
}
|
||||
} else {
|
||||
logger.Error().Err(err).Msg("Failed to start existing ctrld service")
|
||||
os.Exit(1)
|
||||
@@ -301,6 +367,10 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
|
||||
os.Exit(1)
|
||||
}
|
||||
reportSetDnsOk(sockDir)
|
||||
// Verify service registration after successful start.
|
||||
if err := verifyServiceRegistration(); err != nil {
|
||||
logger.Warn().Err(err).Msg("Service registry verification failed")
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service start command completed")
|
||||
@@ -350,6 +420,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service")
|
||||
_ = startCmd.Flags().MarkHidden("start_only")
|
||||
startCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener")
|
||||
startCmd.Flags().StringVarP(&interceptMode, "intercept-mode", "", "", "OS-level DNS interception mode: 'dns' (with VPN split routing) or 'hard' (all DNS through ctrld, no VPN split routing)")
|
||||
|
||||
// Start command alias
|
||||
startCmdAlias := &cobra.Command{
|
||||
|
||||
@@ -32,9 +32,10 @@ const (
|
||||
)
|
||||
|
||||
type ifaceResponse struct {
|
||||
Name string `json:"name"`
|
||||
All bool `json:"all"`
|
||||
OK bool `json:"ok"`
|
||||
Name string `json:"name"`
|
||||
All bool `json:"all"`
|
||||
OK bool `json:"ok"`
|
||||
InterceptMode string `json:"intercept_mode,omitempty"` // "dns", "hard", or "" (not intercepting)
|
||||
}
|
||||
|
||||
// controlServer represents an HTTP server for handling control requests
|
||||
@@ -279,6 +280,10 @@ func (p *prog) registerControlServerHandler() {
|
||||
res.Name = p.runningIface
|
||||
res.All = p.requiredMultiNICsConfig
|
||||
res.OK = true
|
||||
// Report intercept mode to the start command for proper log output.
|
||||
if interceptMode == "dns" || interceptMode == "hard" {
|
||||
res.InterceptMode = interceptMode
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(res); err != nil {
|
||||
|
||||
39
cmd/cli/dns_intercept_others.go
Normal file
39
cmd/cli/dns_intercept_others.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// startDNSIntercept is not supported on this platform.
|
||||
// DNS intercept mode is only available on Windows (via WFP) and macOS (via pf).
|
||||
func (p *prog) startDNSIntercept() error {
|
||||
return fmt.Errorf("dns intercept: not supported on this platform (only Windows and macOS)")
|
||||
}
|
||||
|
||||
// stopDNSIntercept is a no-op on unsupported platforms.
|
||||
func (p *prog) stopDNSIntercept() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// exemptVPNDNSServers is a no-op on unsupported platforms.
|
||||
func (p *prog) exemptVPNDNSServers(exemptions []vpnDNSExemption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensurePFAnchorActive is a no-op on unsupported platforms.
|
||||
func (p *prog) ensurePFAnchorActive() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// checkTunnelInterfaceChanges is a no-op on unsupported platforms.
|
||||
func (p *prog) checkTunnelInterfaceChanges() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// scheduleDelayedRechecks is a no-op on unsupported platforms.
|
||||
func (p *prog) scheduleDelayedRechecks() {}
|
||||
|
||||
// pfInterceptMonitor is a no-op on unsupported platforms.
|
||||
func (p *prog) pfInterceptMonitor() {}
|
||||
@@ -244,6 +244,21 @@ func (p *prog) handleSpecialDomains(ctx context.Context, w dns.ResponseWriter, m
|
||||
return true
|
||||
}
|
||||
|
||||
// Interception probe: if we're expecting a probe query and this matches,
|
||||
// signal the prober and respond NXDOMAIN. Used by both macOS pf probes
|
||||
// (_pf-probe-*) and Windows NRPT probes (_nrpt-probe-*) to verify that
|
||||
// DNS interception is actually routing queries to ctrld's listener.
|
||||
if probeID, ok := p.pfProbeExpected.Load().(string); ok && probeID != "" && domain == probeID {
|
||||
if chPtr, ok := p.pfProbeCh.Load().(*chan struct{}); ok && chPtr != nil {
|
||||
select {
|
||||
case *chPtr <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
sendDNSResponse(w, m, dns.RcodeNameError) // NXDOMAIN
|
||||
return true
|
||||
}
|
||||
|
||||
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
|
||||
p.cache.Purge()
|
||||
ctrld.Log(ctx, p.Debug(), "Received query %q, local cache is purged", domain)
|
||||
@@ -592,6 +607,19 @@ func (p *prog) handleSpecialQueryTypes(ctx *context.Context, req *proxyRequest,
|
||||
func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
ctrld.Log(ctx, p.Debug(), "Proxy query processing started")
|
||||
|
||||
// DNS intercept recovery bypass: forward all queries to OS/DHCP resolver.
|
||||
// This runs when upstreams are unreachable (e.g., captive portal network)
|
||||
// and allows the network's DNS to handle authentication pages.
|
||||
if dnsIntercept && p.recoveryBypass.Load() {
|
||||
ctrld.Log(ctx, p.Debug(), "Recovery bypass active: forwarding to OS resolver")
|
||||
answer := p.queryUpstream(ctx, req, upstreamOS, osUpstreamConfig)
|
||||
if answer != nil {
|
||||
return &proxyResponse{answer: answer, upstream: osUpstreamConfig.Endpoint}
|
||||
}
|
||||
ctrld.Log(ctx, p.Debug(), "OS resolver failed during recovery bypass")
|
||||
// Fall through to normal flow as last resort
|
||||
}
|
||||
|
||||
upstreams, upstreamConfigs := p.initializeUpstreams(req)
|
||||
ctrld.Log(ctx, p.Debug(), "Initialized upstreams: %v", upstreams)
|
||||
|
||||
@@ -605,6 +633,36 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
return cachedRes
|
||||
}
|
||||
|
||||
// VPN DNS split routing (only in dns-intercept mode)
|
||||
if dnsIntercept && p.vpnDNS != nil && len(req.msg.Question) > 0 {
|
||||
domain := req.msg.Question[0].Name
|
||||
if vpnServers := p.vpnDNS.UpstreamForDomain(domain); len(vpnServers) > 0 {
|
||||
ctrld.Log(ctx, p.Debug(), "VPN DNS route matched for domain %s, using servers: %v", domain, vpnServers)
|
||||
|
||||
// Try each VPN DNS server
|
||||
for _, server := range vpnServers {
|
||||
upstreamConfig := p.vpnDNS.upstreamConfigFor(server)
|
||||
ctrld.Log(ctx, p.Debug(), "Querying VPN DNS server: %s", server)
|
||||
|
||||
answer := p.queryUpstream(ctx, req, "vpn-dns", upstreamConfig)
|
||||
if answer != nil {
|
||||
ctrld.Log(ctx, p.Debug(), "VPN DNS query successful")
|
||||
|
||||
// Update cache if enabled
|
||||
if p.cache != nil {
|
||||
p.updateCache(ctx, req, answer, "vpn-dns")
|
||||
}
|
||||
|
||||
return &proxyResponse{answer: answer, cached: false}
|
||||
} else {
|
||||
ctrld.Log(ctx, p.Debug(), "VPN DNS server %s failed", server)
|
||||
}
|
||||
}
|
||||
|
||||
ctrld.Log(ctx, p.Debug(), "All VPN DNS servers failed, falling back to normal upstreams")
|
||||
}
|
||||
}
|
||||
|
||||
ctrld.Log(ctx, p.Debug(), "No cache hit, trying upstreams")
|
||||
if res := p.tryUpstreams(ctx, req, upstreams, upstreamConfigs); res != nil {
|
||||
ctrld.Log(ctx, p.Debug(), "Upstream query successful")
|
||||
@@ -1164,12 +1222,30 @@ func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo {
|
||||
} else {
|
||||
ci.Self = p.queryFromSelf(ci.IP)
|
||||
}
|
||||
|
||||
// In DNS intercept mode, ALL queries are from the local machine — pf/WFP
|
||||
// intercepts outbound DNS and redirects to ctrld. The source IP may be a
|
||||
// virtual interface (Tailscale, VPN) that has no ARP/MAC entry, causing
|
||||
// missing x-cd-mac, x-cd-host, and x-cd-os headers. Force Self=true and
|
||||
// populate from the primary physical interface info.
|
||||
if dnsIntercept && !ci.Self {
|
||||
ci.Self = true
|
||||
}
|
||||
|
||||
// If this is a query from self, but ci.IP is not loopback IP,
|
||||
// try using hostname mapping for lookback IP if presents.
|
||||
if ci.Self {
|
||||
if name := p.ciTable.LocalHostname(); name != "" {
|
||||
ci.Hostname = name
|
||||
}
|
||||
// If MAC is still empty (e.g., query arrived via virtual interface IP
|
||||
// like Tailscale), fall back to the loopback MAC mapping which addSelf()
|
||||
// populates from the primary physical interface.
|
||||
if ci.Mac == "" {
|
||||
if mac := p.ciTable.LookupMac("127.0.0.1"); mac != "" {
|
||||
ci.Mac = mac
|
||||
}
|
||||
}
|
||||
}
|
||||
p.spoofLoopbackIpInClientInfo(ci)
|
||||
return ci
|
||||
@@ -1532,6 +1608,62 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
p.Debug().Msg("Ignoring interface change - no valid interfaces affected")
|
||||
// check if the default IPs are still on an interface that is up
|
||||
ValidateDefaultLocalIPsFromDelta(delta.New)
|
||||
// Even minor interface changes can trigger macOS pf reloads — verify anchor.
|
||||
// We check immediately AND schedule delayed re-checks (2s + 4s) to catch
|
||||
// programs like Windscribe that modify pf rules and DNS settings
|
||||
// asynchronously after the network change event fires.
|
||||
if dnsIntercept && p.dnsInterceptState != nil {
|
||||
if !p.pfStabilizing.Load() {
|
||||
p.ensurePFAnchorActive()
|
||||
}
|
||||
// Check tunnel interfaces unconditionally — it decides internally
|
||||
// whether to enter stabilization or rebuild immediately.
|
||||
p.checkTunnelInterfaceChanges()
|
||||
// Schedule delayed re-checks to catch async VPN teardown changes.
|
||||
// These also refresh the OS resolver and VPN DNS routes.
|
||||
p.scheduleDelayedRechecks()
|
||||
|
||||
// Detect interface appearance/disappearance — hypervisors (Parallels,
|
||||
// VMware, VirtualBox) reload pf when creating/destroying virtual network
|
||||
// interfaces, which can corrupt pf's internal translation state.
|
||||
if delta.Old != nil {
|
||||
interfaceChanged := false
|
||||
var changedIface string
|
||||
for ifaceName := range delta.Old.Interface {
|
||||
if ifaceName == "lo0" {
|
||||
continue
|
||||
}
|
||||
if _, exists := delta.New.Interface[ifaceName]; !exists {
|
||||
interfaceChanged = true
|
||||
changedIface = ifaceName
|
||||
break
|
||||
}
|
||||
}
|
||||
if !interfaceChanged {
|
||||
for ifaceName := range delta.New.Interface {
|
||||
if ifaceName == "lo0" {
|
||||
continue
|
||||
}
|
||||
if _, exists := delta.Old.Interface[ifaceName]; !exists {
|
||||
interfaceChanged = true
|
||||
changedIface = ifaceName
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if interfaceChanged {
|
||||
p.Info().Str("interface", changedIface).
|
||||
Msg("DNS intercept: interface appeared/disappeared — starting interception probe monitor")
|
||||
go p.pfInterceptMonitor()
|
||||
}
|
||||
}
|
||||
}
|
||||
// Refresh VPN DNS on tunnel interface changes (e.g., Tailscale connect/disconnect)
|
||||
// even though the physical interface didn't change. Runs after tunnel checks
|
||||
// so the pf anchor rebuild includes current VPN DNS exemptions.
|
||||
if dnsIntercept && p.vpnDNS != nil {
|
||||
p.vpnDNS.Refresh(ctx)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1602,6 +1734,26 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
p.Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
||||
|
||||
p.handleRecovery(RecoveryReasonNetworkChange)
|
||||
|
||||
// After network changes, verify our pf anchor is still active and
|
||||
// refresh VPN DNS state. Order matters: tunnel checks first (may rebuild
|
||||
// anchor), then VPN DNS refresh (updates exemptions in anchor), then
|
||||
// delayed re-checks for async VPN teardown.
|
||||
if dnsIntercept && p.dnsInterceptState != nil {
|
||||
if !p.pfStabilizing.Load() {
|
||||
p.ensurePFAnchorActive()
|
||||
}
|
||||
// Check tunnel interfaces unconditionally — it decides internally
|
||||
// whether to enter stabilization or rebuild immediately.
|
||||
p.checkTunnelInterfaceChanges()
|
||||
// Refresh VPN DNS routes — runs after tunnel checks so the anchor
|
||||
// rebuild includes current VPN DNS exemptions.
|
||||
if p.vpnDNS != nil {
|
||||
p.vpnDNS.Refresh(ctrld.LoggerCtx(ctx, p.logger.Load()))
|
||||
}
|
||||
// Schedule delayed re-checks to catch async VPN teardown changes.
|
||||
p.scheduleDelayedRechecks()
|
||||
}
|
||||
})
|
||||
|
||||
mon.Start()
|
||||
@@ -1781,7 +1933,50 @@ func (p *prog) prepareForRecovery(reason RecoveryReason) error {
|
||||
// Set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface
|
||||
p.recoveryRunning.Store(true)
|
||||
|
||||
// Remove DNS settings - we do not want to restore any static DNS settings
|
||||
// In DNS intercept mode, don't tear down WFP/pf filters.
|
||||
// Instead, enable recovery bypass so proxy() forwards queries to
|
||||
// the OS/DHCP resolver. This handles captive portal authentication
|
||||
// without the overhead of filter teardown/rebuild.
|
||||
if dnsIntercept && p.dnsInterceptState != nil {
|
||||
p.recoveryBypass.Store(true)
|
||||
p.Info().Msg("DNS intercept recovery: enabling DHCP bypass (filters stay active)")
|
||||
|
||||
// Reinitialize OS resolver to discover DHCP servers on the new network.
|
||||
// This is critical for captive portals — we need the network's DNS servers
|
||||
// to resolve the auth page.
|
||||
p.Debug().Msg("DNS intercept recovery: discovering DHCP nameservers")
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
dhcpServers := ctrld.InitializeOsResolver(loggerCtx, true)
|
||||
if len(dhcpServers) == 0 {
|
||||
p.Warn().Msg("DNS intercept recovery: no DHCP nameservers found")
|
||||
} else {
|
||||
p.Info().Msgf("DNS intercept recovery: found DHCP nameservers: %v", dhcpServers)
|
||||
}
|
||||
|
||||
// Exempt DHCP nameservers from intercept filters so the OS resolver
|
||||
// can actually reach them on port 53. Without this, the WFP block
|
||||
// or pf redirect would intercept ctrld's own recovery queries.
|
||||
if len(dhcpServers) > 0 {
|
||||
// Strip :53 port suffix if present (exemptVPNDNSServers expects bare IPs).
|
||||
var dhcpExemptions []vpnDNSExemption
|
||||
for _, s := range dhcpServers {
|
||||
host := s
|
||||
if h, _, err := net.SplitHostPort(s); err == nil {
|
||||
host = h
|
||||
}
|
||||
dhcpExemptions = append(dhcpExemptions, vpnDNSExemption{Server: host})
|
||||
}
|
||||
p.Info().Msgf("DNS intercept recovery: exempting DHCP nameservers from filters: %v", dhcpServers)
|
||||
if err := p.exemptVPNDNSServers(dhcpExemptions); err != nil {
|
||||
p.Warn().Err(err).Msg("DNS intercept recovery: failed to exempt DHCP nameservers — recovery queries may fail")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Traditional flow: remove DNS settings to expose DHCP nameservers
|
||||
// we do not want to restore any static DNS settings
|
||||
// we must try to get the DHCP values, any static DNS settings
|
||||
// will be appended to nameservers from the saved interface values
|
||||
p.resetDNS(false, false)
|
||||
@@ -1814,6 +2009,33 @@ func (p *prog) completeRecovery(reason RecoveryReason, recovered string) error {
|
||||
// Reset the upstream failure count and down state
|
||||
p.um.reset(recovered)
|
||||
|
||||
// In DNS intercept mode, just disable the bypass — filters are still active.
|
||||
if dnsIntercept && p.dnsInterceptState != nil {
|
||||
// Always reset recoveryRunning, even on error paths below.
|
||||
defer p.recoveryRunning.Store(false)
|
||||
|
||||
p.recoveryBypass.Store(false)
|
||||
p.Info().Msg("DNS intercept recovery complete: disabling DHCP bypass, resuming normal flow")
|
||||
|
||||
// Refresh VPN DNS routes in case VPN state changed during recovery.
|
||||
// This also re-exempts VPN DNS servers (which may have changed) and
|
||||
// removes any DHCP exemptions that were added during recovery.
|
||||
if p.vpnDNS != nil {
|
||||
p.vpnDNS.Refresh(ctrld.LoggerCtx(context.Background(), p.logger.Load()))
|
||||
}
|
||||
|
||||
// Reinitialize OS resolver for the recovered state.
|
||||
if reason == RecoveryReasonNetworkChange {
|
||||
if err := p.reinitializeOSResolver("Network change detected during recovery"); err != nil {
|
||||
return fmt.Errorf("failed to reinitialize OS resolver during network change: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Traditional flow: reapply DNS settings.
|
||||
|
||||
// For network changes we also reinitialize the OS resolver.
|
||||
if reason == RecoveryReasonNetworkChange {
|
||||
if err := p.reinitializeOSResolver("Network change detected during recovery"); err != nil {
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"go.uber.org/zap"
|
||||
@@ -42,6 +45,9 @@ var (
|
||||
cleanup bool
|
||||
startOnly bool
|
||||
rfc1918 bool
|
||||
interceptMode string // "", "dns", or "hard" — set via --intercept-mode flag or config
|
||||
dnsIntercept bool // derived: interceptMode == "dns" || interceptMode == "hard"
|
||||
hardIntercept bool // derived: interceptMode == "hard"
|
||||
|
||||
mainLog atomic.Pointer[ctrld.Logger]
|
||||
consoleWriter zapcore.Core
|
||||
@@ -68,6 +74,12 @@ func init() {
|
||||
// Main is the entry point for the CLI application
|
||||
// It initializes configuration, sets up the CLI structure, and executes the root command
|
||||
func Main() {
|
||||
// Fast path for pf interception probe subprocess.
|
||||
if len(os.Args) >= 4 && os.Args[1] == "pf-probe-send" {
|
||||
pfProbeSend(os.Args[2], os.Args[3])
|
||||
return
|
||||
}
|
||||
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
rootCmd := initCLI()
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
@@ -229,3 +241,21 @@ func initCache() {
|
||||
cfg.Service.CacheSize = 4096
|
||||
}
|
||||
}
|
||||
|
||||
// pfProbeSend is a minimal subprocess that sends a pre-built DNS query packet
|
||||
// to the specified host on port 53.
|
||||
func pfProbeSend(host, hexPacket string) {
|
||||
packet, err := hex.DecodeString(hexPacket)
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
conn, err := net.DialTimeout("udp", net.JoinHostPort(host, "53"), time.Second)
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(time.Second))
|
||||
_, _ = conn.Write(packet)
|
||||
buf := make([]byte, 512)
|
||||
_, _ = conn.Read(buf)
|
||||
}
|
||||
|
||||
104
cmd/cli/prog.go
104
cmd/cli/prog.go
@@ -133,6 +133,51 @@ type prog struct {
|
||||
recoveryCancelMu sync.Mutex
|
||||
recoveryCancel context.CancelFunc
|
||||
recoveryRunning atomic.Bool
|
||||
// recoveryBypass is set when dns-intercept mode enters recovery.
|
||||
// While true, proxy() forwards all queries to the OS/DHCP resolver
|
||||
// instead of the configured upstreams. This allows captive portal
|
||||
// authentication without tearing down WFP/pf filters.
|
||||
recoveryBypass atomic.Bool
|
||||
|
||||
// DNS intercept mode state (platform-specific).
|
||||
// On Windows: *wfpState, on macOS: *pfState, nil on other platforms.
|
||||
dnsInterceptState any
|
||||
|
||||
// lastTunnelIfaces tracks the set of active VPN/tunnel interfaces (utun*, ipsec*, etc.)
|
||||
// discovered during the last pf anchor rule build. When the set changes (e.g., a VPN
|
||||
// connects and creates utun420), we rebuild the pf anchor to add interface-specific
|
||||
// intercept rules for the new interface. Protected by mu.
|
||||
lastTunnelIfaces []string //lint:ignore U1000 used on darwin
|
||||
|
||||
// pfStabilizing is true while we're waiting for a VPN's pf ruleset to settle.
|
||||
// While true, the watchdog and network change callbacks do NOT restore our rules.
|
||||
pfStabilizing atomic.Bool
|
||||
|
||||
// pfStabilizeCancel cancels the active stabilization goroutine, if any.
|
||||
// Protected by mu.
|
||||
pfStabilizeCancel context.CancelFunc //lint:ignore U1000 used on darwin
|
||||
|
||||
// pfLastRestoreTime records when we last restored our anchor (unix millis).
|
||||
// Used to detect immediate re-wipes (VPN reconnect cycle).
|
||||
pfLastRestoreTime atomic.Int64 //lint:ignore U1000 used on darwin
|
||||
|
||||
// pfBackoffMultiplier tracks exponential backoff for stabilization.
|
||||
// Resets to 0 when rules survive for >60s.
|
||||
pfBackoffMultiplier atomic.Int32 //lint:ignore U1000 used on darwin
|
||||
|
||||
// pfMonitorRunning ensures only one pfInterceptMonitor goroutine runs at a time.
|
||||
// When an interface appears/disappears, we spawn a monitor that probes pf
|
||||
// interception with exponential backoff and auto-heals if broken.
|
||||
pfMonitorRunning atomic.Bool //lint:ignore U1000 used on darwin
|
||||
|
||||
// pfProbeExpected holds the domain name of a pending pf interception probe.
|
||||
pfProbeExpected atomic.Value // string
|
||||
|
||||
// pfProbeCh is signaled when the DNS handler receives the expected probe query.
|
||||
pfProbeCh atomic.Value // *chan struct{}
|
||||
|
||||
// VPN DNS manager for split DNS routing when intercept mode is active.
|
||||
vpnDNS *vpnDNSManager
|
||||
|
||||
started chan struct{}
|
||||
onStartedDone chan struct{}
|
||||
@@ -700,6 +745,54 @@ func (p *prog) setDNS() {
|
||||
p.csSetDnsOk = setDnsOK
|
||||
}()
|
||||
|
||||
// Validate and resolve intercept mode.
|
||||
// CLI flag (--intercept-mode) takes priority over config file.
|
||||
// Valid values: "" (off), "dns" (with VPN split routing), "hard" (all DNS through ctrld).
|
||||
if interceptMode != "" && !validInterceptMode(interceptMode) {
|
||||
p.Fatal().Msgf("invalid --intercept-mode value %q: must be 'off', 'dns', or 'hard'", interceptMode)
|
||||
}
|
||||
if interceptMode == "" || interceptMode == "off" {
|
||||
interceptMode = cfg.Service.InterceptMode
|
||||
if interceptMode != "" && interceptMode != "off" {
|
||||
p.Info().Msgf("Intercept mode enabled via config (intercept_mode = %q)", interceptMode)
|
||||
}
|
||||
}
|
||||
|
||||
// Derive convenience bools from interceptMode.
|
||||
switch interceptMode {
|
||||
case "dns":
|
||||
dnsIntercept = true
|
||||
case "hard":
|
||||
dnsIntercept = true
|
||||
hardIntercept = true
|
||||
}
|
||||
|
||||
// DNS intercept mode: use OS-level packet interception (WFP/pf) instead of
|
||||
// modifying interface DNS settings. This eliminates race conditions with VPN
|
||||
// software that also manages DNS. See issue #489.
|
||||
if dnsIntercept {
|
||||
if err := p.startDNSIntercept(); err != nil {
|
||||
p.Error().Err(err).Msg("DNS intercept mode failed — falling back to interface DNS settings")
|
||||
// Fall through to traditional setDNS behavior.
|
||||
} else {
|
||||
if hardIntercept {
|
||||
p.Info().Msg("Hard intercept mode active — all DNS through ctrld, no VPN split routing")
|
||||
} else {
|
||||
p.Info().Msg("DNS intercept mode active — skipping interface DNS configuration and watchdog")
|
||||
|
||||
// Initialize VPN DNS manager for split DNS routing.
|
||||
// Discovers search domains from virtual/VPN interfaces and forwards
|
||||
// matching queries to the DNS server on that interface.
|
||||
// Skipped in --intercept-mode hard where all DNS goes through ctrld.
|
||||
p.vpnDNS = newVPNDNSManager(&p.logger, p.exemptVPNDNSServers)
|
||||
p.vpnDNS.Refresh(ctrld.LoggerCtx(context.Background(), p.logger.Load()))
|
||||
}
|
||||
|
||||
setDnsOK = true
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Listener == nil {
|
||||
return
|
||||
}
|
||||
@@ -918,7 +1011,18 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
|
||||
}
|
||||
|
||||
// resetDNS performs a DNS reset for all interfaces.
|
||||
// In DNS intercept mode, this tears down the WFP/pf filters instead.
|
||||
func (p *prog) resetDNS(isStart bool, restoreStatic bool) {
|
||||
if dnsIntercept && p.dnsInterceptState != nil {
|
||||
if err := p.stopDNSIntercept(); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to stop DNS intercept mode during reset")
|
||||
}
|
||||
|
||||
// Clean up VPN DNS manager
|
||||
p.vpnDNS = nil
|
||||
|
||||
return
|
||||
}
|
||||
netIfaceName := ""
|
||||
if netIface := p.resetDNSForRunningIface(isStart, restoreStatic); netIface != nil {
|
||||
netIfaceName = netIface.Name
|
||||
|
||||
134
cmd/cli/service_args_darwin.go
Normal file
134
cmd/cli/service_args_darwin.go
Normal file
@@ -0,0 +1,134 @@
|
||||
//go:build darwin
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const launchdPlistPath = "/Library/LaunchDaemons/ctrld.plist"
|
||||
|
||||
// serviceConfigFileExists returns true if the launchd plist for ctrld exists on disk.
|
||||
// This is more reliable than checking launchctl status, which may report "not found"
|
||||
// if the service was unloaded but the plist file still exists.
|
||||
func serviceConfigFileExists() bool {
|
||||
_, err := os.Stat(launchdPlistPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// appendServiceFlag appends a CLI flag (e.g., "--intercept-mode") to the installed
|
||||
// service's launch arguments. This is used when upgrading an existing installation
|
||||
// to intercept mode without losing the existing --cd flag and other arguments.
|
||||
//
|
||||
// On macOS, this modifies the launchd plist at /Library/LaunchDaemons/ctrld.plist
|
||||
// using the "defaults" command, which is the standard way to edit plists.
|
||||
//
|
||||
// The function is idempotent: if the flag already exists, it's a no-op.
|
||||
func appendServiceFlag(flag string) error {
|
||||
// Read current ProgramArguments from plist.
|
||||
out, err := exec.Command("defaults", "read", launchdPlistPath, "ProgramArguments").CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read plist ProgramArguments: %w (output: %s)", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
|
||||
// Check if the flag is already present (idempotent).
|
||||
args := string(out)
|
||||
if strings.Contains(args, flag) {
|
||||
mainLog.Load().Debug().Msgf("Service flag %q already present in plist, skipping", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use PlistBuddy to append the flag to ProgramArguments array.
|
||||
// PlistBuddy is more reliable than "defaults" for array manipulation.
|
||||
addCmd := exec.Command(
|
||||
"/usr/libexec/PlistBuddy",
|
||||
"-c", fmt.Sprintf("Add :ProgramArguments: string %s", flag),
|
||||
launchdPlistPath,
|
||||
)
|
||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to append %q to plist ProgramArguments: %w (output: %s)", flag, err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
|
||||
mainLog.Load().Info().Msgf("Appended %q to service launch arguments", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyServiceRegistration is a no-op on macOS (launchd plist verification not needed).
|
||||
func verifyServiceRegistration() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeServiceFlag removes a CLI flag (and its value, if the next argument is not
|
||||
// a flag) from the installed service's launch arguments. For example, removing
|
||||
// "--intercept-mode" also removes the following "dns" or "hard" value argument.
|
||||
//
|
||||
// The function is idempotent: if the flag doesn't exist, it's a no-op.
|
||||
func removeServiceFlag(flag string) error {
|
||||
// Read current ProgramArguments to find the index.
|
||||
out, err := exec.Command("/usr/libexec/PlistBuddy", "-c", "Print :ProgramArguments", launchdPlistPath).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read plist ProgramArguments: %w (output: %s)", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
|
||||
// Parse the PlistBuddy output to find the flag's index.
|
||||
// PlistBuddy prints arrays as:
|
||||
// Array {
|
||||
// /path/to/ctrld
|
||||
// run
|
||||
// --cd=xxx
|
||||
// --intercept-mode
|
||||
// dns
|
||||
// }
|
||||
lines := strings.Split(string(out), "\n")
|
||||
var entries []string
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "Array {" || trimmed == "}" || trimmed == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, trimmed)
|
||||
}
|
||||
|
||||
index := -1
|
||||
for i, entry := range entries {
|
||||
if entry == flag {
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if index < 0 {
|
||||
mainLog.Load().Debug().Msgf("Service flag %q not present in plist, skipping removal", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the next entry is a value (not a flag). If so, delete it first
|
||||
// (deleting by index shifts subsequent entries down, so delete value before flag).
|
||||
hasValue := index+1 < len(entries) && !strings.HasPrefix(entries[index+1], "-")
|
||||
if hasValue {
|
||||
delVal := exec.Command(
|
||||
"/usr/libexec/PlistBuddy",
|
||||
"-c", fmt.Sprintf("Delete :ProgramArguments:%d", index+1),
|
||||
launchdPlistPath,
|
||||
)
|
||||
if out, err := delVal.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to remove value for %q from plist: %w (output: %s)", flag, err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the flag itself.
|
||||
delCmd := exec.Command(
|
||||
"/usr/libexec/PlistBuddy",
|
||||
"-c", fmt.Sprintf("Delete :ProgramArguments:%d", index),
|
||||
launchdPlistPath,
|
||||
)
|
||||
if out, err := delCmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to remove %q from plist ProgramArguments: %w (output: %s)", flag, err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
|
||||
mainLog.Load().Info().Msgf("Removed %q from service launch arguments", flag)
|
||||
return nil
|
||||
}
|
||||
38
cmd/cli/service_args_others.go
Normal file
38
cmd/cli/service_args_others.go
Normal file
@@ -0,0 +1,38 @@
|
||||
//go:build !darwin && !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// serviceConfigFileExists checks common service config file locations on Linux.
|
||||
func serviceConfigFileExists() bool {
|
||||
// systemd unit file
|
||||
if _, err := os.Stat("/etc/systemd/system/ctrld.service"); err == nil {
|
||||
return true
|
||||
}
|
||||
// SysV init script
|
||||
if _, err := os.Stat("/etc/init.d/ctrld"); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// appendServiceFlag is not yet implemented on this platform.
|
||||
// Linux services (systemd) store args in unit files; intercept mode
|
||||
// should be set via the config file (intercept_mode) on these platforms.
|
||||
func appendServiceFlag(flag string) error {
|
||||
return fmt.Errorf("appending service flags is not supported on this platform; use intercept_mode in config instead")
|
||||
}
|
||||
|
||||
// verifyServiceRegistration is a no-op on this platform.
|
||||
func verifyServiceRegistration() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeServiceFlag is not yet implemented on this platform.
|
||||
func removeServiceFlag(flag string) error {
|
||||
return fmt.Errorf("removing service flags is not supported on this platform; use intercept_mode in config instead")
|
||||
}
|
||||
153
cmd/cli/service_args_windows.go
Normal file
153
cmd/cli/service_args_windows.go
Normal file
@@ -0,0 +1,153 @@
|
||||
//go:build windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
// serviceConfigFileExists returns true if the ctrld Windows service is registered.
|
||||
func serviceConfigFileExists() bool {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer m.Disconnect()
|
||||
s, err := m.OpenService(ctrldServiceName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
s.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
// appendServiceFlag appends a CLI flag (e.g., "--intercept-mode") to the installed
|
||||
// Windows service's BinPath arguments. This is used when upgrading an existing
|
||||
// installation to intercept mode without losing the existing --cd flag.
|
||||
//
|
||||
// The function is idempotent: if the flag already exists, it's a no-op.
|
||||
func appendServiceFlag(flag string) error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Windows SCM: %w", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(ctrldServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open service %q: %w", ctrldServiceName, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
config, err := s.Config()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read service config: %w", err)
|
||||
}
|
||||
|
||||
// Check if flag already present (idempotent).
|
||||
if strings.Contains(config.BinaryPathName, flag) {
|
||||
mainLog.Load().Debug().Msgf("Service flag %q already present in BinPath, skipping", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Append the flag to BinPath.
|
||||
config.BinaryPathName = strings.TrimSpace(config.BinaryPathName) + " " + flag
|
||||
|
||||
if err := s.UpdateConfig(config); err != nil {
|
||||
return fmt.Errorf("failed to update service config with %q: %w", flag, err)
|
||||
}
|
||||
|
||||
mainLog.Load().Info().Msgf("Appended %q to service BinPath", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyServiceRegistration opens the Windows Service Control Manager and verifies
|
||||
// that the ctrld service is correctly registered: logs the BinaryPathName, checks
|
||||
// that --intercept-mode is present if expected, and verifies SERVICE_AUTO_START.
|
||||
func verifyServiceRegistration() error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Windows SCM: %w", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(ctrldServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open service %q: %w", ctrldServiceName, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
config, err := s.Config()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read service config: %w", err)
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Msgf("Service registry: BinaryPathName = %q", config.BinaryPathName)
|
||||
|
||||
// If intercept mode is set, verify the flag is present in BinPath.
|
||||
if interceptMode == "dns" || interceptMode == "hard" {
|
||||
if !strings.Contains(config.BinaryPathName, "--intercept-mode") {
|
||||
return fmt.Errorf("service registry: --intercept-mode flag missing from BinaryPathName (expected mode %q)", interceptMode)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Service registry: --intercept-mode flag present in BinaryPathName")
|
||||
}
|
||||
|
||||
// Verify auto-start. mgr.StartAutomatic == 2 == SERVICE_AUTO_START.
|
||||
if config.StartType != mgr.StartAutomatic {
|
||||
return fmt.Errorf("service registry: StartType is %d, expected SERVICE_AUTO_START (%d)", config.StartType, mgr.StartAutomatic)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeServiceFlag removes a CLI flag (and its value, if present) from the installed
|
||||
// Windows service's BinPath. For example, removing "--intercept-mode" also removes
|
||||
// the following "dns" or "hard" value. The function is idempotent.
|
||||
func removeServiceFlag(flag string) error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Windows SCM: %w", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(ctrldServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open service %q: %w", ctrldServiceName, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
config, err := s.Config()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read service config: %w", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(config.BinaryPathName, flag) {
|
||||
mainLog.Load().Debug().Msgf("Service flag %q not present in BinPath, skipping removal", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Split BinPath into parts, find and remove the flag + its value (if any).
|
||||
parts := strings.Fields(config.BinaryPathName)
|
||||
var newParts []string
|
||||
for i := 0; i < len(parts); i++ {
|
||||
if parts[i] == flag {
|
||||
// Skip the flag. Also skip the next part if it's a value (not a flag).
|
||||
if i+1 < len(parts) && !strings.HasPrefix(parts[i+1], "-") {
|
||||
i++ // skip value too
|
||||
}
|
||||
continue
|
||||
}
|
||||
newParts = append(newParts, parts[i])
|
||||
}
|
||||
config.BinaryPathName = strings.Join(newParts, " ")
|
||||
|
||||
if err := s.UpdateConfig(config); err != nil {
|
||||
return fmt.Errorf("failed to update service config: %w", err)
|
||||
}
|
||||
|
||||
mainLog.Load().Info().Msgf("Removed %q from service BinPath", flag)
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user