diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 7565517..223f14e 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -126,7 +126,7 @@ func initCLI() { rootCmd.CompletionOptions.HiddenDefaultCmd = true initRunCmd() - startCmd := initStartCmd() + startCmd, startCmdAlias := initStartCmd() stopCmd := initStopCmd() restartCmd := initRestartCmd() reloadCmd := initReloadCmd(restartCmd) @@ -135,7 +135,7 @@ func initCLI() { interfacesCmd := initInterfacesCmd() initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd) initClientsCmd() - initUpgradeCmd() + initUpgradeCmd(startCmdAlias) initLogCmd() } @@ -243,6 +243,10 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if err := s.Run(); err != nil { mainLog.Load().Error().Err(err).Msg("failed to start service") } + // Configure Windows service failure actions + if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { + mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) + } }() } writeDefaultConfig := !noConfigStart && configBase64 == "" @@ -1016,8 +1020,8 @@ func uninstall(p *prog, s service.Service) { return } tasks := []task{ - {s.Stop, false}, - {s.Uninstall, true}, + {s.Stop, false, "Stop"}, + {s.Uninstall, true, "Uninstall"}, } initLogging() if doTasks(tasks) { @@ -1688,6 +1692,10 @@ func runInCdMode() bool { // curCdUID returns the current ControlD UID used by running ctrld process. func curCdUID() string { if s, _ := newService(&prog{}, svcConfig); s != nil { + // Configure Windows service failure actions + if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { + mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) + } if dir, _ := socketDir(); dir != "" { cc := newSocketControlClient(context.TODO(), s, dir) if cc != nil { @@ -1791,7 +1799,7 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceRe } iface = oldIface return nil - }, false} + }, false, "Reset DNS"} } // doValidateCdRemoteConfig fetches and validates custom config for cdUID. @@ -1840,7 +1848,7 @@ func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool { p.resetDNS() - tasks := []task{{s.Uninstall, true}} + tasks := []task{{s.Uninstall, true, "Uninstall"}} if doTasks(tasks) { logger.Info().Msg("uninstalled service") if doStop { diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 8713ea5..f3555e5 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -164,7 +164,7 @@ func initRunCmd() *cobra.Command { return runCmd } -func initStartCmd() *cobra.Command { +func initStartCmd() (*cobra.Command, *cobra.Command) { startCmd := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { checkHasElevatedPrivilege() @@ -310,7 +310,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c initLogging() tasks := []task{ - {s.Stop, false}, + {s.Stop, false, "Stop"}, resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { // Save current DNS so we can restore later. @@ -321,9 +321,12 @@ NOTE: running "ctrld start" without any arguments will start already installed c return nil }) return nil - }, false}, - {s.Start, true}, - {noticeWritingControlDConfig, false}, + }, false, "Save current DNS"}, + {func() error { + return ConfigureWindowsServiceFailureActions(ctrldServiceName) + }, false, "Configure Windows service failure actions"}, + {s.Start, true, "Start"}, + {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, } mainLog.Load().Notice().Msg("Starting existing ctrld service") if doTasks(tasks) { @@ -387,9 +390,9 @@ NOTE: running "ctrld start" without any arguments will start already installed c } tasks := []task{ - {s.Stop, false}, - {func() error { return doGenerateNextDNSConfig(nextdns) }, true}, - {func() error { return ensureUninstall(s) }, false}, + {s.Stop, false, "Stop"}, + {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Generate NextDNS config"}, + {func() error { return ensureUninstall(s) }, false, "Ensure uninstall"}, resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { // Save current DNS so we can restore later. @@ -400,12 +403,15 @@ NOTE: running "ctrld start" without any arguments will start already installed c return nil }) return nil - }, false}, - {s.Install, false}, - {s.Start, true}, + }, false, "Save current DNS"}, + {s.Install, false, "Install"}, + {func() error { + return ConfigureWindowsServiceFailureActions(ctrldServiceName) + }, false, "Configure Windows service failure actions"}, + {s.Start, true, "Start"}, // Note that startCmd do not actually write ControlD config, but the config file was // generated after s.Start, so we notice users here for consistent with nextdns mode. - {noticeWritingControlDConfig, false}, + {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, } mainLog.Load().Notice().Msg("Starting service") if doTasks(tasks) { @@ -528,7 +534,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) rootCmd.AddCommand(startCmdAlias) - return startCmd + return startCmd, startCmdAlias } func initStopCmd() *cobra.Command { @@ -558,7 +564,7 @@ func initStopCmd() *cobra.Command { if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) } - if doTasks([]task{{s.Stop, true}}) { + if doTasks([]task{{s.Stop, true, "Stop"}}) { p.router.Cleanup() p.resetDNS() @@ -651,8 +657,8 @@ func initRestartCmd() *cobra.Command { iface = ir.Name } tasks := []task{ - {s.Stop, false}, - {s.Start, true}, + {s.Stop, false, "Stop"}, + {s.Start, true, "Start"}, } if doTasks(tasks) { dir, err := socketDir() @@ -1043,7 +1049,7 @@ func initClientsCmd() *cobra.Command { return clientsCmd } -func initUpgradeCmd() *cobra.Command { +func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command { const ( upgradeChannelDev = "dev" upgradeChannelProd = "prod" @@ -1115,23 +1121,23 @@ func initUpgradeCmd() *cobra.Command { mainLog.Load().Fatal().Err(err).Msg("failed to update current binary") } + // we run the actual commands to make sure all the logic we want is executed doRestart := func() bool { - if !svcInstalled { - return true + + // run the start command so that we reinit the service + // this is to fix the non restarting options on windows for existing clients + // we have to reset os.Args, since other commands use it. + curCdUID := curCdUID() + startArgs := []string{} + os.Args = []string{"ctrld", "start"} + if curCdUID != "" { + startArgs = append(startArgs, fmt.Sprintf("--cd=%s", curCdUID)) + os.Args = append(os.Args, fmt.Sprintf("--cd=%s", curCdUID)) } - tasks := []task{ - {s.Stop, false}, - {s.Start, false}, - } - if doTasks(tasks) { - if dir, err := socketDir(); err == nil { - if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil { - _, _ = cc.post(ifacePath, nil) - return true - } - } - } - return false + startCmd.Run(startCmd, startArgs) + + return true + } if svcInstalled { mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 0d67e88..ac808db 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1339,9 +1339,9 @@ func parseInterfaceState(state *netmon.State) map[string]string { } result := make(map[string]string) - + stateStr := state.String() - + // Extract interface information ifsStart := strings.Index(stateStr, "ifs={") if ifsStart == -1 { @@ -1356,26 +1356,26 @@ func parseInterfaceState(state *netmon.State) map[string]string { // Get the content between ifs={ } ifsContent := strings.TrimSpace(ifsStr[:ifsEnd]) - + // Split on "] " to get each interface entry entries := strings.Split(ifsContent, "] ") - + for _, entry := range entries { if entry == "" { continue } - + // Split on ":[" parts := strings.Split(entry, ":[") if len(parts) != 2 { continue } - + name := strings.TrimSpace(parts[0]) state := "[" + strings.TrimSuffix(parts[1], "]") + "]" - + result[strings.ToLower(name)] = state } return result -} \ No newline at end of file +} diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index fe075a3..6290a1c 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -52,23 +52,39 @@ func validInterfaces() []string { mainLog.Load().Warn().Err(err).Msg("failed to get network adapter") continue } + + name, err := adapter.GetPropertyName() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to get interface name") + continue + } + // From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85) // // "Indicates if a connector is present on the network adapter. This value is set to TRUE // if this is a physical adapter or FALSE if this is not a physical adapter." physical, err := adapter.GetPropertyConnectorPresent() if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get network adapter connector present property") + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property") continue } if !physical { + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter") continue } - name, err := adapter.GetPropertyName() + + // Check if it's a hardware interface. Checking only for connector present is not enough + // because some interfaces are not physical but have a connector. + hardware, err := adapter.GetPropertyHardwareInterface() if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get interface name") + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property") continue } + if !hardware { + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface") + continue + } + adapters = append(adapters, name) } return adapters diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 331f42a..f8147eb 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -45,6 +45,7 @@ const ( upstreamOS = upstreamPrefix + "os" upstreamPrivate = upstreamPrefix + "private" dnsWatchdogDefaultInterval = 20 * time.Second + ctrldServiceName = "ctrld" ) // ControlSocketName returns name for control unix socket. @@ -61,8 +62,9 @@ var logf = func(format string, args ...any) { } var svcConfig = &service.Config{ - Name: "ctrld", + Name: ctrldServiceName, DisplayName: "Control-D Helper Service", + Description: "A highly configurable, multi-protocol DNS forwarding proxy", Option: service.KeyValue{}, } diff --git a/cmd/cli/service.go b/cmd/cli/service.go index e4edfaf..82f144c 100644 --- a/cmd/cli/service.go +++ b/cmd/cli/service.go @@ -156,17 +156,18 @@ func (l *launchd) Status() (service.Status, error) { type task struct { f func() error abortOnError bool + Name string } func doTasks(tasks []task) bool { - var prevErr error for _, task := range tasks { + mainLog.Load().Debug().Msgf("Running task %s", task.Name) if err := task.f(); err != nil { if task.abortOnError { - mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error()) + mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err) return false } - prevErr = err + mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err) } } return true diff --git a/cmd/cli/service_others.go b/cmd/cli/service_others.go index 2303e30..056903c 100644 --- a/cmd/cli/service_others.go +++ b/cmd/cli/service_others.go @@ -16,3 +16,5 @@ func openLogFile(path string, flags int) (*os.File, error) { // hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. func hasLocalDnsServerRunning() bool { return false } + +func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil } diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index af4f317..4d3d281 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -2,11 +2,14 @@ package cli import ( "os" + "runtime" "strings" "syscall" + "time" "unsafe" "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc/mgr" ) func hasElevatedPrivilege() (bool, error) { @@ -30,6 +33,53 @@ func hasElevatedPrivilege() (bool, error) { return token.IsMember(sid) } +// ConfigureWindowsServiceFailureActions checks if the given service +// has the correct failure actions configured, and updates them if not. +func ConfigureWindowsServiceFailureActions(serviceName string) error { + if runtime.GOOS != "windows" { + return nil // no-op on non-Windows + } + + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + + s, err := m.OpenService(serviceName) + if err != nil { + return err + } + defer s.Close() + + // restart 3 times with a delay of 2 seconds + actions := []mgr.RecoveryAction{ + {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds + } + + // Set the recovery actions (3 restarts, reset period = 120). + err = s.SetRecoveryActions(actions, 120) + if err != nil { + return err + } + + // Ensure that failure actions are NOT triggered on user-initiated stops. + var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG + failureActionsFlag.FailureActionsOnNonCrashFailures = 0 + + if err := windows.ChangeServiceConfig2( + s.Handle, + windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG, + (*byte)(unsafe.Pointer(&failureActionsFlag)), + ); err != nil { + return err + } + + return nil +} + func openLogFile(path string, mode int) (*os.File, error) { if len(path) == 0 { return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND} diff --git a/go.mod b/go.mod index 8e9a8f7..e570bae 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/vishvananda/netlink v1.2.1-beta.2 golang.org/x/net v0.33.0 golang.org/x/sync v0.10.0 - golang.org/x/sys v0.28.0 + golang.org/x/sys v0.29.0 golang.zx2c4.com/wireguard/windows v0.5.3 tailscale.com v1.74.0 ) diff --git a/go.sum b/go.sum index fcf2ac7..f2d5ff9 100644 --- a/go.sum +++ b/go.sum @@ -494,6 +494,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/nameservers_windows.go b/nameservers_windows.go index a8c5191..c71e065 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -3,35 +3,41 @@ package ctrld import ( "context" "fmt" + "io" + "log" "net" + "os" "strings" "syscall" "time" "unsafe" - "io" - "os" + "github.com/StackExchange/wmi" + "github.com/microsoft/wmi/pkg/base/host" + "github.com/microsoft/wmi/pkg/base/instance" + "github.com/microsoft/wmi/pkg/base/query" + "github.com/microsoft/wmi/pkg/constant" + "github.com/microsoft/wmi/pkg/hardware/network/netadapter" "github.com/rs/zerolog" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "github.com/StackExchange/wmi" ) const ( - maxRetries = 3 - retryDelay = 500 * time.Millisecond - defaultTimeout = 5 * time.Second - minDNSServers = 1 // Minimum number of DNS servers we want to find - NetSetupUnknown uint32 = 0 - NetSetupWorkgroup uint32 = 1 - NetSetupDomain uint32 = 2 - NetSetupCloudDomain uint32 = 3 - DS_FORCE_REDISCOVERY = 0x00000001 - DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010 - DS_BACKGROUND_ONLY = 0x00000100 - DS_IP_REQUIRED = 0x00000200 - DS_IS_DNS_NAME = 0x00020000 - DS_RETURN_DNS_NAME = 0x40000000 + maxRetries = 5 + retryDelay = 1 * time.Second + defaultTimeout = 5 * time.Second + minDNSServers = 1 // Minimum number of DNS servers we want to find + NetSetupUnknown uint32 = 0 + NetSetupWorkgroup uint32 = 1 + NetSetupDomain uint32 = 2 + NetSetupCloudDomain uint32 = 3 + DS_FORCE_REDISCOVERY = 0x00000001 + DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010 + DS_BACKGROUND_ONLY = 0x00000100 + DS_IP_REQUIRED = 0x00000200 + DS_IS_DNS_NAME = 0x00020000 + DS_RETURN_DNS_NAME = 0x40000000 ) type DomainControllerInfo struct { @@ -132,7 +138,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { if err != nil { Log(context.Background(), logger.Debug(), "Failed to get local AD domain: %v", err) - + } else { // Load netapi32.dll @@ -141,9 +147,9 @@ func getDNSServers(ctx context.Context) ([]string, error) { var info *DomainControllerInfo - flags := uint32(DS_RETURN_DNS_NAME | - DS_IP_REQUIRED | - DS_IS_DNS_NAME) + flags := uint32(DS_RETURN_DNS_NAME | + DS_IP_REQUIRED | + DS_IS_DNS_NAME) // Convert domain name to UTF16 pointer domainUTF16, err := windows.UTF16PtrFromString(domainName) @@ -221,6 +227,14 @@ func getDNSServers(ctx context.Context) ([]string, error) { continue } + // Skip if software loopback or other non-physical types + // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows + if aa.IfType == winipcfg.IfTypeSoftwareLoopback { + Log(context.Background(), logger.Debug(), + "Skipping %s (software loopback)", aa.FriendlyName()) + continue + } + Log(context.Background(), logger.Debug(), "Processing adapter %s", aa.FriendlyName()) @@ -232,12 +246,29 @@ func getDNSServers(ctx context.Context) ([]string, error) { } } + validInterfacesMap := validInterfaces() + // Collect DNS servers for _, aa := range aas { if aa.OperStatus != winipcfg.IfOperStatusUp { continue } + // Skip if software loopback or other non-physical types + // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows + if aa.IfType == winipcfg.IfTypeSoftwareLoopback { + Log(context.Background(), logger.Debug(), + "Skipping %s (software loopback)", aa.FriendlyName()) + continue + } + + // if not in the validInterfacesMap, skip + if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok { + Log(context.Background(), logger.Debug(), + "Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) + continue + } + for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { ip := dns.Address.IP() if ip == nil { @@ -322,8 +353,8 @@ func checkDomainJoined() bool { // Consider both traditional and cloud domains as valid domain joins isDomain := status == NetSetupDomain || status == NetSetupCloudDomain Log(context.Background(), logger.Debug(), - "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", - status, + "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", + status, status == NetSetupDomain, status == NetSetupCloudDomain, isDomain) @@ -333,32 +364,111 @@ func checkDomainJoined() bool { // Win32_ComputerSystem is the minimal struct for WMI query type Win32_ComputerSystem struct { - Domain string + Domain string } // getLocalADDomain tries to detect the AD domain in two ways: -// 1) USERDNSDOMAIN env var (often set in AD logon sessions) -// 2) WMI Win32_ComputerSystem.Domain +// 1. USERDNSDOMAIN env var (often set in AD logon sessions) +// 2. WMI Win32_ComputerSystem.Domain func getLocalADDomain() (string, error) { - // 1) Check environment variable - envDomain := os.Getenv("USERDNSDOMAIN") - if envDomain != "" { - return strings.TrimSpace(envDomain), nil - } + // 1) Check environment variable + envDomain := os.Getenv("USERDNSDOMAIN") + if envDomain != "" { + return strings.TrimSpace(envDomain), nil + } - // 2) Check WMI (requires Windows + admin privileges or sufficient access) - var result []Win32_ComputerSystem - err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result) - if err != nil { - return "", fmt.Errorf("WMI query failed: %v", err) - } - if len(result) == 0 { - return "", fmt.Errorf("no rows returned from Win32_ComputerSystem") - } + // 2) Check WMI (requires Windows + admin privileges or sufficient access) + var result []Win32_ComputerSystem + err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result) + if err != nil { + return "", fmt.Errorf("WMI query failed: %v", err) + } + if len(result) == 0 { + return "", fmt.Errorf("no rows returned from Win32_ComputerSystem") + } + + domain := strings.TrimSpace(result[0].Domain) + if domain == "" { + return "", fmt.Errorf("machine does not appear to have a domain set") + } + return domain, nil +} + +// validInterfaces returns a list of all physical interfaces. +// this is a duplicate of what is in net_windows.go, we should +// clean this up so there is only one version +func validInterfaces() map[string]struct{} { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + + whost := host.NewWmiLocalHost() + q := query.NewWmiQuery("MSFT_NetAdapter") + instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) + if err != nil { + Log(context.Background(), logger.Warn(), + "failed to get wmi network adapter: %v", err) + return nil + } + defer instances.Close() + var adapters []string + for _, i := range instances { + adapter, err := netadapter.NewNetworkAdapter(i) + if err != nil { + Log(context.Background(), logger.Warn(), + "failed to get network adapter: %v", err) + continue + } + + name, err := adapter.GetPropertyName() + if err != nil { + Log(context.Background(), logger.Warn(), + "failed to get interface name: %v", err) + continue + } + + // From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85) + // + // "Indicates if a connector is present on the network adapter. This value is set to TRUE + // if this is a physical adapter or FALSE if this is not a physical adapter." + physical, err := adapter.GetPropertyConnectorPresent() + if err != nil { + Log(context.Background(), logger.Debug(), + "failed to get network adapter connector present property: %v", err) + continue + } + if !physical { + Log(context.Background(), logger.Debug(), + "skipping non-physical adapter: %s", name) + continue + } + + // Check if it's a hardware interface. Checking only for connector present is not enough + // because some interfaces are not physical but have a connector. + hardware, err := adapter.GetPropertyHardwareInterface() + if err != nil { + Log(context.Background(), logger.Debug(), + "failed to get network adapter hardware interface property: %v", err) + continue + } + if !hardware { + Log(context.Background(), logger.Debug(), + "skipping non-hardware interface: %s", name) + continue + } + + adapters = append(adapters, name) + } + + m := make(map[string]struct{}) + for _, ifaceName := range adapters { + m[ifaceName] = struct{}{} + } + return m - domain := strings.TrimSpace(result[0].Domain) - if domain == "" { - return "", fmt.Errorf("machine does not appear to have a domain set") - } - return domain, nil }