ignore non physical ifaces in validInterfaces method on Windows

debugging

skip type 24 in nameserver detection

skip type 24 in nameserver detection

remove interface type check from valid interfaces for now

skip non hardware interfaces in DNS nameserver lookup

ignore win api log output

set retries to 5 and 1s backoff

reset DNS when upgrading to make sure we get the proper OS nameservers on start

init running iface for upgrade

update windows service options for auto restarts on failure

make upgrade use the actual stop and start commands

fix the windows service retry logic

fix the windows service retry logic

task debugging

more task debugging

windows service name fix

windows service name fix

fix start command args

fix restart delay

don't recover from non-crash failures

fix upgrade flow
This commit is contained in:
Alex
2025-01-29 14:09:53 -05:00
committed by Cuong Manh Le
parent ce3281e70d
commit e573a490c9
11 changed files with 296 additions and 99 deletions

View File

@@ -126,7 +126,7 @@ func initCLI() {
rootCmd.CompletionOptions.HiddenDefaultCmd = true
initRunCmd()
startCmd := initStartCmd()
startCmd, startCmdAlias := initStartCmd()
stopCmd := initStopCmd()
restartCmd := initRestartCmd()
reloadCmd := initReloadCmd(restartCmd)
@@ -135,7 +135,7 @@ func initCLI() {
interfacesCmd := initInterfacesCmd()
initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd)
initClientsCmd()
initUpgradeCmd()
initUpgradeCmd(startCmdAlias)
initLogCmd()
}
@@ -243,6 +243,10 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
if err := s.Run(); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to start service")
}
// Configure Windows service failure actions
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
}
}()
}
writeDefaultConfig := !noConfigStart && configBase64 == ""
@@ -1016,8 +1020,8 @@ func uninstall(p *prog, s service.Service) {
return
}
tasks := []task{
{s.Stop, false},
{s.Uninstall, true},
{s.Stop, false, "Stop"},
{s.Uninstall, true, "Uninstall"},
}
initLogging()
if doTasks(tasks) {
@@ -1688,6 +1692,10 @@ func runInCdMode() bool {
// curCdUID returns the current ControlD UID used by running ctrld process.
func curCdUID() string {
if s, _ := newService(&prog{}, svcConfig); s != nil {
// Configure Windows service failure actions
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
}
if dir, _ := socketDir(); dir != "" {
cc := newSocketControlClient(context.TODO(), s, dir)
if cc != nil {
@@ -1791,7 +1799,7 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceRe
}
iface = oldIface
return nil
}, false}
}, false, "Reset DNS"}
}
// doValidateCdRemoteConfig fetches and validates custom config for cdUID.
@@ -1840,7 +1848,7 @@ func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool {
p.resetDNS()
tasks := []task{{s.Uninstall, true}}
tasks := []task{{s.Uninstall, true, "Uninstall"}}
if doTasks(tasks) {
logger.Info().Msg("uninstalled service")
if doStop {

View File

@@ -164,7 +164,7 @@ func initRunCmd() *cobra.Command {
return runCmd
}
func initStartCmd() *cobra.Command {
func initStartCmd() (*cobra.Command, *cobra.Command) {
startCmd := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
@@ -310,7 +310,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
initLogging()
tasks := []task{
{s.Stop, false},
{s.Stop, false, "Stop"},
resetDnsTask(p, s, isCtrldInstalled, currentIface),
{func() error {
// Save current DNS so we can restore later.
@@ -321,9 +321,12 @@ NOTE: running "ctrld start" without any arguments will start already installed c
return nil
})
return nil
}, false},
{s.Start, true},
{noticeWritingControlDConfig, false},
}, false, "Save current DNS"},
{func() error {
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
}, false, "Configure Windows service failure actions"},
{s.Start, true, "Start"},
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
}
mainLog.Load().Notice().Msg("Starting existing ctrld service")
if doTasks(tasks) {
@@ -387,9 +390,9 @@ NOTE: running "ctrld start" without any arguments will start already installed c
}
tasks := []task{
{s.Stop, false},
{func() error { return doGenerateNextDNSConfig(nextdns) }, true},
{func() error { return ensureUninstall(s) }, false},
{s.Stop, false, "Stop"},
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Generate NextDNS config"},
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
resetDnsTask(p, s, isCtrldInstalled, currentIface),
{func() error {
// Save current DNS so we can restore later.
@@ -400,12 +403,15 @@ NOTE: running "ctrld start" without any arguments will start already installed c
return nil
})
return nil
}, false},
{s.Install, false},
{s.Start, true},
}, false, "Save current DNS"},
{s.Install, false, "Install"},
{func() error {
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
}, false, "Configure Windows service failure actions"},
{s.Start, true, "Start"},
// Note that startCmd does not actually write the ControlD config, but the config file was
// generated after s.Start, so we notify users here for consistency with nextdns mode.
{noticeWritingControlDConfig, false},
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
}
mainLog.Load().Notice().Msg("Starting service")
if doTasks(tasks) {
@@ -528,7 +534,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
rootCmd.AddCommand(startCmdAlias)
return startCmd
return startCmd, startCmdAlias
}
func initStopCmd() *cobra.Command {
@@ -558,7 +564,7 @@ func initStopCmd() *cobra.Command {
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
os.Exit(deactivationPinInvalidExitCode)
}
if doTasks([]task{{s.Stop, true}}) {
if doTasks([]task{{s.Stop, true, "Stop"}}) {
p.router.Cleanup()
p.resetDNS()
@@ -651,8 +657,8 @@ func initRestartCmd() *cobra.Command {
iface = ir.Name
}
tasks := []task{
{s.Stop, false},
{s.Start, true},
{s.Stop, false, "Stop"},
{s.Start, true, "Start"},
}
if doTasks(tasks) {
dir, err := socketDir()
@@ -1043,7 +1049,7 @@ func initClientsCmd() *cobra.Command {
return clientsCmd
}
func initUpgradeCmd() *cobra.Command {
func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
const (
upgradeChannelDev = "dev"
upgradeChannelProd = "prod"
@@ -1115,23 +1121,23 @@ func initUpgradeCmd() *cobra.Command {
mainLog.Load().Fatal().Err(err).Msg("failed to update current binary")
}
// we run the actual commands to make sure all the logic we want is executed
doRestart := func() bool {
if !svcInstalled {
return true
// run the start command so that we reinit the service
// this is to fix the non restarting options on windows for existing clients
// we have to reset os.Args, since other commands use it.
curCdUID := curCdUID()
startArgs := []string{}
os.Args = []string{"ctrld", "start"}
if curCdUID != "" {
startArgs = append(startArgs, fmt.Sprintf("--cd=%s", curCdUID))
os.Args = append(os.Args, fmt.Sprintf("--cd=%s", curCdUID))
}
tasks := []task{
{s.Stop, false},
{s.Start, false},
}
if doTasks(tasks) {
if dir, err := socketDir(); err == nil {
if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil {
_, _ = cc.post(ifacePath, nil)
return true
}
}
}
return false
startCmd.Run(startCmd, startArgs)
return true
}
if svcInstalled {
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")

View File

@@ -1339,9 +1339,9 @@ func parseInterfaceState(state *netmon.State) map[string]string {
}
result := make(map[string]string)
stateStr := state.String()
// Extract interface information
ifsStart := strings.Index(stateStr, "ifs={")
if ifsStart == -1 {
@@ -1356,26 +1356,26 @@ func parseInterfaceState(state *netmon.State) map[string]string {
// Get the content between ifs={ }
ifsContent := strings.TrimSpace(ifsStr[:ifsEnd])
// Split on "] " to get each interface entry
entries := strings.Split(ifsContent, "] ")
for _, entry := range entries {
if entry == "" {
continue
}
// Split on ":["
parts := strings.Split(entry, ":[")
if len(parts) != 2 {
continue
}
name := strings.TrimSpace(parts[0])
state := "[" + strings.TrimSuffix(parts[1], "]") + "]"
result[strings.ToLower(name)] = state
}
return result
}
}

View File

@@ -52,23 +52,39 @@ func validInterfaces() []string {
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter")
continue
}
name, err := adapter.GetPropertyName()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
continue
}
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
//
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
// if this is a physical adapter or FALSE if this is not a physical adapter."
physical, err := adapter.GetPropertyConnectorPresent()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter connector present property")
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property")
continue
}
if !physical {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter")
continue
}
name, err := adapter.GetPropertyName()
// Check if it's a hardware interface. Checking only for connector present is not enough
// because some interfaces are not physical but have a connector.
hardware, err := adapter.GetPropertyHardwareInterface()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property")
continue
}
if !hardware {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface")
continue
}
adapters = append(adapters, name)
}
return adapters

View File

@@ -45,6 +45,7 @@ const (
upstreamOS = upstreamPrefix + "os"
upstreamPrivate = upstreamPrefix + "private"
dnsWatchdogDefaultInterval = 20 * time.Second
ctrldServiceName = "ctrld"
)
// ControlSocketName returns name for control unix socket.
@@ -61,8 +62,9 @@ var logf = func(format string, args ...any) {
}
var svcConfig = &service.Config{
Name: "ctrld",
Name: ctrldServiceName,
DisplayName: "Control-D Helper Service",
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
Option: service.KeyValue{},
}

View File

@@ -156,17 +156,18 @@ func (l *launchd) Status() (service.Status, error) {
// task is a single step executed by doTasks.
type task struct {
// f is the action to run; a non-nil error means the step failed.
f func() error
// abortOnError, when true, makes doTasks stop and return false as soon as f fails.
// When false, the failure is only logged and the remaining tasks still run.
abortOnError bool
// Name is a human-readable label used by doTasks in its log messages.
Name string
}
func doTasks(tasks []task) bool {
var prevErr error
for _, task := range tasks {
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
if err := task.f(); err != nil {
if task.abortOnError {
mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error())
mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err)
return false
}
prevErr = err
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
}
}
return true

View File

@@ -16,3 +16,5 @@ func openLogFile(path string, flags int) (*os.File, error) {
// hasLocalDnsServerRunning reports whether we are on Windows with a local DNS
// server running. This variant unconditionally reports false.
func hasLocalDnsServerRunning() bool {
	return false
}
// ConfigureWindowsServiceFailureActions is a no-op stub: it ignores serviceName
// and always returns nil, so callers can invoke it without checking the platform.
func ConfigureWindowsServiceFailureActions(serviceName string) error {
	return nil
}

View File

@@ -2,11 +2,14 @@ package cli
import (
"os"
"runtime"
"strings"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc/mgr"
)
func hasElevatedPrivilege() (bool, error) {
@@ -30,6 +33,53 @@ func hasElevatedPrivilege() (bool, error) {
return token.IsMember(sid)
}
// ConfigureWindowsServiceFailureActions (re)applies the recovery settings of the
// Windows service named serviceName. It does not inspect the current settings;
// every call overwrites them with:
//
//   - three restart actions, each delayed by 2 seconds, with a reset period of
//     120 (seconds, per the mgr package documentation);
//   - the "failure actions on non-crash failures" flag cleared.
//
// It returns the first error reported while connecting to the service control
// manager, opening the service, or writing either setting, and nil otherwise.
func ConfigureWindowsServiceFailureActions(serviceName string) error {
	// Nothing to configure when not running on Windows.
	if runtime.GOOS != "windows" {
		return nil
	}

	const (
		restartAttempts = 3
		restartDelay    = 2 * time.Second
		resetPeriod     = 120
	)

	scm, err := mgr.Connect()
	if err != nil {
		return err
	}
	defer scm.Disconnect()

	svc, err := scm.OpenService(serviceName)
	if err != nil {
		return err
	}
	defer svc.Close()

	// Build the identical restart action once per allowed attempt.
	recovery := make([]mgr.RecoveryAction, 0, restartAttempts)
	for i := 0; i < restartAttempts; i++ {
		recovery = append(recovery, mgr.RecoveryAction{
			Type:  mgr.ServiceRestart,
			Delay: restartDelay,
		})
	}
	if err := svc.SetRecoveryActions(recovery, resetPeriod); err != nil {
		return err
	}

	// Clear the flag so the recovery actions are not triggered by
	// user-initiated stops.
	nonCrashFlag := windows.SERVICE_FAILURE_ACTIONS_FLAG{FailureActionsOnNonCrashFailures: 0}
	return windows.ChangeServiceConfig2(
		svc.Handle,
		windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG,
		(*byte)(unsafe.Pointer(&nonCrashFlag)),
	)
}
func openLogFile(path string, mode int) (*os.File, error) {
if len(path) == 0 {
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}