mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
ignore non physical ifaces in validInterfaces method on Windows
debugging skip type 24 in nameserver detection skip type 24 in nameserver detection remove interface type check from valid interfaces for now skip non hardware interfaces in DNS nameserver lookup ignore win api log output set retries to 5 and 1s backoff reset DNS when upgrading to make sure we get the proper OS nameservers on start init running iface for upgrade update windows service options for auto restarts on failure make upgrade use the actual stop and start commands fix the windows service retry logic fix the windows service retry logic task debugging more task debugging windows service name fix windows service name fix fix start command args fix restart delay dont recover from non crash failures fix upgrade flow
This commit is contained in:
@@ -126,7 +126,7 @@ func initCLI() {
|
|||||||
rootCmd.CompletionOptions.HiddenDefaultCmd = true
|
rootCmd.CompletionOptions.HiddenDefaultCmd = true
|
||||||
|
|
||||||
initRunCmd()
|
initRunCmd()
|
||||||
startCmd := initStartCmd()
|
startCmd, startCmdAlias := initStartCmd()
|
||||||
stopCmd := initStopCmd()
|
stopCmd := initStopCmd()
|
||||||
restartCmd := initRestartCmd()
|
restartCmd := initRestartCmd()
|
||||||
reloadCmd := initReloadCmd(restartCmd)
|
reloadCmd := initReloadCmd(restartCmd)
|
||||||
@@ -135,7 +135,7 @@ func initCLI() {
|
|||||||
interfacesCmd := initInterfacesCmd()
|
interfacesCmd := initInterfacesCmd()
|
||||||
initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd)
|
initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd)
|
||||||
initClientsCmd()
|
initClientsCmd()
|
||||||
initUpgradeCmd()
|
initUpgradeCmd(startCmdAlias)
|
||||||
initLogCmd()
|
initLogCmd()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,6 +243,10 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
|||||||
if err := s.Run(); err != nil {
|
if err := s.Run(); err != nil {
|
||||||
mainLog.Load().Error().Err(err).Msg("failed to start service")
|
mainLog.Load().Error().Err(err).Msg("failed to start service")
|
||||||
}
|
}
|
||||||
|
// Configure Windows service failure actions
|
||||||
|
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
|
||||||
|
mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||||
@@ -1016,8 +1020,8 @@ func uninstall(p *prog, s service.Service) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
tasks := []task{
|
tasks := []task{
|
||||||
{s.Stop, false},
|
{s.Stop, false, "Stop"},
|
||||||
{s.Uninstall, true},
|
{s.Uninstall, true, "Uninstall"},
|
||||||
}
|
}
|
||||||
initLogging()
|
initLogging()
|
||||||
if doTasks(tasks) {
|
if doTasks(tasks) {
|
||||||
@@ -1688,6 +1692,10 @@ func runInCdMode() bool {
|
|||||||
// curCdUID returns the current ControlD UID used by running ctrld process.
|
// curCdUID returns the current ControlD UID used by running ctrld process.
|
||||||
func curCdUID() string {
|
func curCdUID() string {
|
||||||
if s, _ := newService(&prog{}, svcConfig); s != nil {
|
if s, _ := newService(&prog{}, svcConfig); s != nil {
|
||||||
|
// Configure Windows service failure actions
|
||||||
|
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
|
||||||
|
mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
|
||||||
|
}
|
||||||
if dir, _ := socketDir(); dir != "" {
|
if dir, _ := socketDir(); dir != "" {
|
||||||
cc := newSocketControlClient(context.TODO(), s, dir)
|
cc := newSocketControlClient(context.TODO(), s, dir)
|
||||||
if cc != nil {
|
if cc != nil {
|
||||||
@@ -1791,7 +1799,7 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceRe
|
|||||||
}
|
}
|
||||||
iface = oldIface
|
iface = oldIface
|
||||||
return nil
|
return nil
|
||||||
}, false}
|
}, false, "Reset DNS"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// doValidateCdRemoteConfig fetches and validates custom config for cdUID.
|
// doValidateCdRemoteConfig fetches and validates custom config for cdUID.
|
||||||
@@ -1840,7 +1848,7 @@ func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool {
|
|||||||
|
|
||||||
p.resetDNS()
|
p.resetDNS()
|
||||||
|
|
||||||
tasks := []task{{s.Uninstall, true}}
|
tasks := []task{{s.Uninstall, true, "Uninstall"}}
|
||||||
if doTasks(tasks) {
|
if doTasks(tasks) {
|
||||||
logger.Info().Msg("uninstalled service")
|
logger.Info().Msg("uninstalled service")
|
||||||
if doStop {
|
if doStop {
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func initRunCmd() *cobra.Command {
|
|||||||
return runCmd
|
return runCmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func initStartCmd() *cobra.Command {
|
func initStartCmd() (*cobra.Command, *cobra.Command) {
|
||||||
startCmd := &cobra.Command{
|
startCmd := &cobra.Command{
|
||||||
PreRun: func(cmd *cobra.Command, args []string) {
|
PreRun: func(cmd *cobra.Command, args []string) {
|
||||||
checkHasElevatedPrivilege()
|
checkHasElevatedPrivilege()
|
||||||
@@ -310,7 +310,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
|||||||
|
|
||||||
initLogging()
|
initLogging()
|
||||||
tasks := []task{
|
tasks := []task{
|
||||||
{s.Stop, false},
|
{s.Stop, false, "Stop"},
|
||||||
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||||
{func() error {
|
{func() error {
|
||||||
// Save current DNS so we can restore later.
|
// Save current DNS so we can restore later.
|
||||||
@@ -321,9 +321,12 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}, false},
|
}, false, "Save current DNS"},
|
||||||
{s.Start, true},
|
{func() error {
|
||||||
{noticeWritingControlDConfig, false},
|
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||||
|
}, false, "Configure Windows service failure actions"},
|
||||||
|
{s.Start, true, "Start"},
|
||||||
|
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
|
||||||
}
|
}
|
||||||
mainLog.Load().Notice().Msg("Starting existing ctrld service")
|
mainLog.Load().Notice().Msg("Starting existing ctrld service")
|
||||||
if doTasks(tasks) {
|
if doTasks(tasks) {
|
||||||
@@ -387,9 +390,9 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
|||||||
}
|
}
|
||||||
|
|
||||||
tasks := []task{
|
tasks := []task{
|
||||||
{s.Stop, false},
|
{s.Stop, false, "Stop"},
|
||||||
{func() error { return doGenerateNextDNSConfig(nextdns) }, true},
|
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Generate NextDNS config"},
|
||||||
{func() error { return ensureUninstall(s) }, false},
|
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
|
||||||
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||||
{func() error {
|
{func() error {
|
||||||
// Save current DNS so we can restore later.
|
// Save current DNS so we can restore later.
|
||||||
@@ -400,12 +403,15 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}, false},
|
}, false, "Save current DNS"},
|
||||||
{s.Install, false},
|
{s.Install, false, "Install"},
|
||||||
{s.Start, true},
|
{func() error {
|
||||||
|
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||||
|
}, false, "Configure Windows service failure actions"},
|
||||||
|
{s.Start, true, "Start"},
|
||||||
// Note that startCmd do not actually write ControlD config, but the config file was
|
// Note that startCmd do not actually write ControlD config, but the config file was
|
||||||
// generated after s.Start, so we notice users here for consistent with nextdns mode.
|
// generated after s.Start, so we notice users here for consistent with nextdns mode.
|
||||||
{noticeWritingControlDConfig, false},
|
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
|
||||||
}
|
}
|
||||||
mainLog.Load().Notice().Msg("Starting service")
|
mainLog.Load().Notice().Msg("Starting service")
|
||||||
if doTasks(tasks) {
|
if doTasks(tasks) {
|
||||||
@@ -528,7 +534,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
|||||||
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
||||||
rootCmd.AddCommand(startCmdAlias)
|
rootCmd.AddCommand(startCmdAlias)
|
||||||
|
|
||||||
return startCmd
|
return startCmd, startCmdAlias
|
||||||
}
|
}
|
||||||
|
|
||||||
func initStopCmd() *cobra.Command {
|
func initStopCmd() *cobra.Command {
|
||||||
@@ -558,7 +564,7 @@ func initStopCmd() *cobra.Command {
|
|||||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||||
os.Exit(deactivationPinInvalidExitCode)
|
os.Exit(deactivationPinInvalidExitCode)
|
||||||
}
|
}
|
||||||
if doTasks([]task{{s.Stop, true}}) {
|
if doTasks([]task{{s.Stop, true, "Stop"}}) {
|
||||||
p.router.Cleanup()
|
p.router.Cleanup()
|
||||||
p.resetDNS()
|
p.resetDNS()
|
||||||
|
|
||||||
@@ -651,8 +657,8 @@ func initRestartCmd() *cobra.Command {
|
|||||||
iface = ir.Name
|
iface = ir.Name
|
||||||
}
|
}
|
||||||
tasks := []task{
|
tasks := []task{
|
||||||
{s.Stop, false},
|
{s.Stop, false, "Stop"},
|
||||||
{s.Start, true},
|
{s.Start, true, "Start"},
|
||||||
}
|
}
|
||||||
if doTasks(tasks) {
|
if doTasks(tasks) {
|
||||||
dir, err := socketDir()
|
dir, err := socketDir()
|
||||||
@@ -1043,7 +1049,7 @@ func initClientsCmd() *cobra.Command {
|
|||||||
return clientsCmd
|
return clientsCmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func initUpgradeCmd() *cobra.Command {
|
func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
|
||||||
const (
|
const (
|
||||||
upgradeChannelDev = "dev"
|
upgradeChannelDev = "dev"
|
||||||
upgradeChannelProd = "prod"
|
upgradeChannelProd = "prod"
|
||||||
@@ -1115,23 +1121,23 @@ func initUpgradeCmd() *cobra.Command {
|
|||||||
mainLog.Load().Fatal().Err(err).Msg("failed to update current binary")
|
mainLog.Load().Fatal().Err(err).Msg("failed to update current binary")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we run the actual commands to make sure all the logic we want is executed
|
||||||
doRestart := func() bool {
|
doRestart := func() bool {
|
||||||
if !svcInstalled {
|
|
||||||
return true
|
// run the start command so that we reinit the service
|
||||||
|
// this is to fix the non restarting options on windows for existing clients
|
||||||
|
// we have to reset os.Args, since other commands use it.
|
||||||
|
curCdUID := curCdUID()
|
||||||
|
startArgs := []string{}
|
||||||
|
os.Args = []string{"ctrld", "start"}
|
||||||
|
if curCdUID != "" {
|
||||||
|
startArgs = append(startArgs, fmt.Sprintf("--cd=%s", curCdUID))
|
||||||
|
os.Args = append(os.Args, fmt.Sprintf("--cd=%s", curCdUID))
|
||||||
}
|
}
|
||||||
tasks := []task{
|
startCmd.Run(startCmd, startArgs)
|
||||||
{s.Stop, false},
|
|
||||||
{s.Start, false},
|
return true
|
||||||
}
|
|
||||||
if doTasks(tasks) {
|
|
||||||
if dir, err := socketDir(); err == nil {
|
|
||||||
if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil {
|
|
||||||
_, _ = cc.post(ifacePath, nil)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
if svcInstalled {
|
if svcInstalled {
|
||||||
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
|
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
|
||||||
|
|||||||
@@ -1339,9 +1339,9 @@ func parseInterfaceState(state *netmon.State) map[string]string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
result := make(map[string]string)
|
result := make(map[string]string)
|
||||||
|
|
||||||
stateStr := state.String()
|
stateStr := state.String()
|
||||||
|
|
||||||
// Extract interface information
|
// Extract interface information
|
||||||
ifsStart := strings.Index(stateStr, "ifs={")
|
ifsStart := strings.Index(stateStr, "ifs={")
|
||||||
if ifsStart == -1 {
|
if ifsStart == -1 {
|
||||||
@@ -1356,26 +1356,26 @@ func parseInterfaceState(state *netmon.State) map[string]string {
|
|||||||
|
|
||||||
// Get the content between ifs={ }
|
// Get the content between ifs={ }
|
||||||
ifsContent := strings.TrimSpace(ifsStr[:ifsEnd])
|
ifsContent := strings.TrimSpace(ifsStr[:ifsEnd])
|
||||||
|
|
||||||
// Split on "] " to get each interface entry
|
// Split on "] " to get each interface entry
|
||||||
entries := strings.Split(ifsContent, "] ")
|
entries := strings.Split(ifsContent, "] ")
|
||||||
|
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if entry == "" {
|
if entry == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split on ":["
|
// Split on ":["
|
||||||
parts := strings.Split(entry, ":[")
|
parts := strings.Split(entry, ":[")
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
name := strings.TrimSpace(parts[0])
|
name := strings.TrimSpace(parts[0])
|
||||||
state := "[" + strings.TrimSuffix(parts[1], "]") + "]"
|
state := "[" + strings.TrimSuffix(parts[1], "]") + "]"
|
||||||
|
|
||||||
result[strings.ToLower(name)] = state
|
result[strings.ToLower(name)] = state
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,23 +52,39 @@ func validInterfaces() []string {
|
|||||||
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter")
|
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
name, err := adapter.GetPropertyName()
|
||||||
|
if err != nil {
|
||||||
|
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
||||||
//
|
//
|
||||||
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
|
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
|
||||||
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
||||||
physical, err := adapter.GetPropertyConnectorPresent()
|
physical, err := adapter.GetPropertyConnectorPresent()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter connector present property")
|
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !physical {
|
if !physical {
|
||||||
|
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
name, err := adapter.GetPropertyName()
|
|
||||||
|
// Check if it's a hardware interface. Checking only for connector present is not enough
|
||||||
|
// because some interfaces are not physical but have a connector.
|
||||||
|
hardware, err := adapter.GetPropertyHardwareInterface()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
|
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !hardware {
|
||||||
|
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
adapters = append(adapters, name)
|
adapters = append(adapters, name)
|
||||||
}
|
}
|
||||||
return adapters
|
return adapters
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ const (
|
|||||||
upstreamOS = upstreamPrefix + "os"
|
upstreamOS = upstreamPrefix + "os"
|
||||||
upstreamPrivate = upstreamPrefix + "private"
|
upstreamPrivate = upstreamPrefix + "private"
|
||||||
dnsWatchdogDefaultInterval = 20 * time.Second
|
dnsWatchdogDefaultInterval = 20 * time.Second
|
||||||
|
ctrldServiceName = "ctrld"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ControlSocketName returns name for control unix socket.
|
// ControlSocketName returns name for control unix socket.
|
||||||
@@ -61,8 +62,9 @@ var logf = func(format string, args ...any) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var svcConfig = &service.Config{
|
var svcConfig = &service.Config{
|
||||||
Name: "ctrld",
|
Name: ctrldServiceName,
|
||||||
DisplayName: "Control-D Helper Service",
|
DisplayName: "Control-D Helper Service",
|
||||||
|
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
|
||||||
Option: service.KeyValue{},
|
Option: service.KeyValue{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -156,17 +156,18 @@ func (l *launchd) Status() (service.Status, error) {
|
|||||||
type task struct {
|
type task struct {
|
||||||
f func() error
|
f func() error
|
||||||
abortOnError bool
|
abortOnError bool
|
||||||
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
func doTasks(tasks []task) bool {
|
func doTasks(tasks []task) bool {
|
||||||
var prevErr error
|
|
||||||
for _, task := range tasks {
|
for _, task := range tasks {
|
||||||
|
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
|
||||||
if err := task.f(); err != nil {
|
if err := task.f(); err != nil {
|
||||||
if task.abortOnError {
|
if task.abortOnError {
|
||||||
mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error())
|
mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
prevErr = err
|
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -16,3 +16,5 @@ func openLogFile(path string, flags int) (*os.File, error) {
|
|||||||
|
|
||||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||||
func hasLocalDnsServerRunning() bool { return false }
|
func hasLocalDnsServerRunning() bool { return false }
|
||||||
|
|
||||||
|
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
|
||||||
|
|||||||
@@ -2,11 +2,14 @@ package cli
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
"golang.org/x/sys/windows/svc/mgr"
|
||||||
)
|
)
|
||||||
|
|
||||||
func hasElevatedPrivilege() (bool, error) {
|
func hasElevatedPrivilege() (bool, error) {
|
||||||
@@ -30,6 +33,53 @@ func hasElevatedPrivilege() (bool, error) {
|
|||||||
return token.IsMember(sid)
|
return token.IsMember(sid)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConfigureWindowsServiceFailureActions checks if the given service
|
||||||
|
// has the correct failure actions configured, and updates them if not.
|
||||||
|
func ConfigureWindowsServiceFailureActions(serviceName string) error {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
return nil // no-op on non-Windows
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
|
||||||
|
s, err := m.OpenService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
// restart 3 times with a delay of 2 seconds
|
||||||
|
actions := []mgr.RecoveryAction{
|
||||||
|
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds
|
||||||
|
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds
|
||||||
|
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the recovery actions (3 restarts, reset period = 120).
|
||||||
|
err = s.SetRecoveryActions(actions, 120)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that failure actions are NOT triggered on user-initiated stops.
|
||||||
|
var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG
|
||||||
|
failureActionsFlag.FailureActionsOnNonCrashFailures = 0
|
||||||
|
|
||||||
|
if err := windows.ChangeServiceConfig2(
|
||||||
|
s.Handle,
|
||||||
|
windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG,
|
||||||
|
(*byte)(unsafe.Pointer(&failureActionsFlag)),
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func openLogFile(path string, mode int) (*os.File, error) {
|
func openLogFile(path string, mode int) (*os.File, error) {
|
||||||
if len(path) == 0 {
|
if len(path) == 0 {
|
||||||
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -38,7 +38,7 @@ require (
|
|||||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||||
golang.org/x/net v0.33.0
|
golang.org/x/net v0.33.0
|
||||||
golang.org/x/sync v0.10.0
|
golang.org/x/sync v0.10.0
|
||||||
golang.org/x/sys v0.28.0
|
golang.org/x/sys v0.29.0
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
tailscale.com v1.74.0
|
tailscale.com v1.74.0
|
||||||
)
|
)
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -494,6 +494,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|||||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||||
|
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
|||||||
@@ -3,35 +3,41 @@ package ctrld
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
|
"github.com/StackExchange/wmi"
|
||||||
|
"github.com/microsoft/wmi/pkg/base/host"
|
||||||
|
"github.com/microsoft/wmi/pkg/base/instance"
|
||||||
|
"github.com/microsoft/wmi/pkg/base/query"
|
||||||
|
"github.com/microsoft/wmi/pkg/constant"
|
||||||
|
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||||
"github.com/StackExchange/wmi"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxRetries = 3
|
maxRetries = 5
|
||||||
retryDelay = 500 * time.Millisecond
|
retryDelay = 1 * time.Second
|
||||||
defaultTimeout = 5 * time.Second
|
defaultTimeout = 5 * time.Second
|
||||||
minDNSServers = 1 // Minimum number of DNS servers we want to find
|
minDNSServers = 1 // Minimum number of DNS servers we want to find
|
||||||
NetSetupUnknown uint32 = 0
|
NetSetupUnknown uint32 = 0
|
||||||
NetSetupWorkgroup uint32 = 1
|
NetSetupWorkgroup uint32 = 1
|
||||||
NetSetupDomain uint32 = 2
|
NetSetupDomain uint32 = 2
|
||||||
NetSetupCloudDomain uint32 = 3
|
NetSetupCloudDomain uint32 = 3
|
||||||
DS_FORCE_REDISCOVERY = 0x00000001
|
DS_FORCE_REDISCOVERY = 0x00000001
|
||||||
DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010
|
DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010
|
||||||
DS_BACKGROUND_ONLY = 0x00000100
|
DS_BACKGROUND_ONLY = 0x00000100
|
||||||
DS_IP_REQUIRED = 0x00000200
|
DS_IP_REQUIRED = 0x00000200
|
||||||
DS_IS_DNS_NAME = 0x00020000
|
DS_IS_DNS_NAME = 0x00020000
|
||||||
DS_RETURN_DNS_NAME = 0x40000000
|
DS_RETURN_DNS_NAME = 0x40000000
|
||||||
)
|
)
|
||||||
|
|
||||||
type DomainControllerInfo struct {
|
type DomainControllerInfo struct {
|
||||||
@@ -132,7 +138,7 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"Failed to get local AD domain: %v", err)
|
"Failed to get local AD domain: %v", err)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
// Load netapi32.dll
|
// Load netapi32.dll
|
||||||
@@ -141,9 +147,9 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
|||||||
|
|
||||||
var info *DomainControllerInfo
|
var info *DomainControllerInfo
|
||||||
|
|
||||||
flags := uint32(DS_RETURN_DNS_NAME |
|
flags := uint32(DS_RETURN_DNS_NAME |
|
||||||
DS_IP_REQUIRED |
|
DS_IP_REQUIRED |
|
||||||
DS_IS_DNS_NAME)
|
DS_IS_DNS_NAME)
|
||||||
|
|
||||||
// Convert domain name to UTF16 pointer
|
// Convert domain name to UTF16 pointer
|
||||||
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
||||||
@@ -221,6 +227,14 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip if software loopback or other non-physical types
|
||||||
|
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||||
|
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||||
|
Log(context.Background(), logger.Debug(),
|
||||||
|
"Skipping %s (software loopback)", aa.FriendlyName())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"Processing adapter %s", aa.FriendlyName())
|
"Processing adapter %s", aa.FriendlyName())
|
||||||
|
|
||||||
@@ -232,12 +246,29 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
validInterfacesMap := validInterfaces()
|
||||||
|
|
||||||
// Collect DNS servers
|
// Collect DNS servers
|
||||||
for _, aa := range aas {
|
for _, aa := range aas {
|
||||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip if software loopback or other non-physical types
|
||||||
|
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||||
|
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||||
|
Log(context.Background(), logger.Debug(),
|
||||||
|
"Skipping %s (software loopback)", aa.FriendlyName())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// if not in the validInterfacesMap, skip
|
||||||
|
if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok {
|
||||||
|
Log(context.Background(), logger.Debug(),
|
||||||
|
"Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
||||||
ip := dns.Address.IP()
|
ip := dns.Address.IP()
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
@@ -322,8 +353,8 @@ func checkDomainJoined() bool {
|
|||||||
// Consider both traditional and cloud domains as valid domain joins
|
// Consider both traditional and cloud domains as valid domain joins
|
||||||
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
|
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
|
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
|
||||||
status,
|
status,
|
||||||
status == NetSetupDomain,
|
status == NetSetupDomain,
|
||||||
status == NetSetupCloudDomain,
|
status == NetSetupCloudDomain,
|
||||||
isDomain)
|
isDomain)
|
||||||
@@ -333,32 +364,111 @@ func checkDomainJoined() bool {
|
|||||||
|
|
||||||
// Win32_ComputerSystem is the minimal struct for WMI query
|
// Win32_ComputerSystem is the minimal struct for WMI query
|
||||||
type Win32_ComputerSystem struct {
|
type Win32_ComputerSystem struct {
|
||||||
Domain string
|
Domain string
|
||||||
}
|
}
|
||||||
|
|
||||||
// getLocalADDomain tries to detect the AD domain in two ways:
|
// getLocalADDomain tries to detect the AD domain in two ways:
|
||||||
// 1) USERDNSDOMAIN env var (often set in AD logon sessions)
|
// 1. USERDNSDOMAIN env var (often set in AD logon sessions)
|
||||||
// 2) WMI Win32_ComputerSystem.Domain
|
// 2. WMI Win32_ComputerSystem.Domain
|
||||||
func getLocalADDomain() (string, error) {
|
func getLocalADDomain() (string, error) {
|
||||||
// 1) Check environment variable
|
// 1) Check environment variable
|
||||||
envDomain := os.Getenv("USERDNSDOMAIN")
|
envDomain := os.Getenv("USERDNSDOMAIN")
|
||||||
if envDomain != "" {
|
if envDomain != "" {
|
||||||
return strings.TrimSpace(envDomain), nil
|
return strings.TrimSpace(envDomain), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) Check WMI (requires Windows + admin privileges or sufficient access)
|
// 2) Check WMI (requires Windows + admin privileges or sufficient access)
|
||||||
var result []Win32_ComputerSystem
|
var result []Win32_ComputerSystem
|
||||||
err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result)
|
err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("WMI query failed: %v", err)
|
return "", fmt.Errorf("WMI query failed: %v", err)
|
||||||
}
|
}
|
||||||
if len(result) == 0 {
|
if len(result) == 0 {
|
||||||
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
|
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
domain := strings.TrimSpace(result[0].Domain)
|
||||||
|
if domain == "" {
|
||||||
|
return "", fmt.Errorf("machine does not appear to have a domain set")
|
||||||
|
}
|
||||||
|
return domain, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validInterfaces returns a list of all physical interfaces.
|
||||||
|
// this is a duplicate of what is in net_windows.go, we should
|
||||||
|
// clean this up so there is only one version
|
||||||
|
func validInterfaces() map[string]struct{} {
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
//load the logger
|
||||||
|
logger := zerolog.New(io.Discard)
|
||||||
|
if ProxyLogger.Load() != nil {
|
||||||
|
logger = *ProxyLogger.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
whost := host.NewWmiLocalHost()
|
||||||
|
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||||
|
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||||
|
if err != nil {
|
||||||
|
Log(context.Background(), logger.Warn(),
|
||||||
|
"failed to get wmi network adapter: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer instances.Close()
|
||||||
|
var adapters []string
|
||||||
|
for _, i := range instances {
|
||||||
|
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||||
|
if err != nil {
|
||||||
|
Log(context.Background(), logger.Warn(),
|
||||||
|
"failed to get network adapter: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
name, err := adapter.GetPropertyName()
|
||||||
|
if err != nil {
|
||||||
|
Log(context.Background(), logger.Warn(),
|
||||||
|
"failed to get interface name: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
||||||
|
//
|
||||||
|
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
|
||||||
|
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
||||||
|
physical, err := adapter.GetPropertyConnectorPresent()
|
||||||
|
if err != nil {
|
||||||
|
Log(context.Background(), logger.Debug(),
|
||||||
|
"failed to get network adapter connector present property: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !physical {
|
||||||
|
Log(context.Background(), logger.Debug(),
|
||||||
|
"skipping non-physical adapter: %s", name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a hardware interface. Checking only for connector present is not enough
|
||||||
|
// because some interfaces are not physical but have a connector.
|
||||||
|
hardware, err := adapter.GetPropertyHardwareInterface()
|
||||||
|
if err != nil {
|
||||||
|
Log(context.Background(), logger.Debug(),
|
||||||
|
"failed to get network adapter hardware interface property: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !hardware {
|
||||||
|
Log(context.Background(), logger.Debug(),
|
||||||
|
"skipping non-hardware interface: %s", name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
adapters = append(adapters, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := make(map[string]struct{})
|
||||||
|
for _, ifaceName := range adapters {
|
||||||
|
m[ifaceName] = struct{}{}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
|
||||||
domain := strings.TrimSpace(result[0].Domain)
|
|
||||||
if domain == "" {
|
|
||||||
return "", fmt.Errorf("machine does not appear to have a domain set")
|
|
||||||
}
|
|
||||||
return domain, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user