From 59ece456b1cf7980afdcc8b64405e98950573d19 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 19 Jun 2025 16:38:03 +0700 Subject: [PATCH] refactor: improve network interface validation Add context parameter to validInterfacesMap for better error handling and logging. Move Windows-specific network adapter validation logic to the ctrld package. Key changes include: - Add context parameter to validInterfacesMap across all platforms - Move Windows validInterfaces to ctrld.ValidInterfaces - Improve error handling for virtual interface detection on Linux - Update all callers to pass appropriate context This change improves error reporting and makes the interface validation code more maintainable across different platforms. --- cmd/cli/dns_proxy.go | 2 +- cmd/cli/net_darwin.go | 3 +- cmd/cli/net_linux.go | 20 +++++++--- cmd/cli/net_others.go | 3 +- cmd/cli/net_windows.go | 73 ++----------------------------------- cmd/cli/net_windows_test.go | 7 +++- cmd/cli/prog.go | 10 ++--- nameservers_windows.go | 9 ++--- 8 files changed, 38 insertions(+), 89 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index c09e11d..4491160 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1201,7 +1201,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { // Get map of valid interfaces - validIfaces := validInterfacesMap() + validIfaces := validInterfacesMap(ctrld.LoggerCtx(ctx, p.logger.Load())) isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New) diff --git a/cmd/cli/net_darwin.go b/cmd/cli/net_darwin.go index 6233161..7dac51d 100644 --- a/cmd/cli/net_darwin.go +++ b/cmd/cli/net_darwin.go @@ -3,6 +3,7 @@ package cli import ( "bufio" "bytes" + "context" "io" "net" "os/exec" @@ -51,7 +52,7 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo } // validInterfacesMap returns a set of all valid hardware ports. -func validInterfacesMap() map[string]struct{} { +func validInterfacesMap(ctx context.Context) map[string]struct{} { b, err := exec.Command("networksetup", "-listallhardwareports").Output() if err != nil { return nil diff --git a/cmd/cli/net_linux.go b/cmd/cli/net_linux.go index ea17d3d..c6b30d7 100644 --- a/cmd/cli/net_linux.go +++ b/cmd/cli/net_linux.go @@ -1,12 +1,15 @@ package cli import ( + "context" "net" "net/netip" "os" "strings" "tailscale.com/net/netmon" + + "github.com/Control-D-Inc/ctrld" ) func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } @@ -19,16 +22,16 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo } // validInterfacesMap returns a set containing non virtual interfaces. -func validInterfacesMap() map[string]struct{} { +func validInterfacesMap(ctx context.Context) map[string]struct{} { m := make(map[string]struct{}) - vis := virtualInterfaces() + vis := virtualInterfaces(ctx) netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { if _, existed := vis[i.Name]; existed { return } m[i.Name] = struct{}{} }) - // Fallback to default route interface if found nothing. + // Fallback to the default route interface if found nothing. if len(m) == 0 { defaultRoute, err := netmon.DefaultRoute() if err != nil { @@ -39,10 +42,15 @@ func validInterfacesMap() map[string]struct{} { return m } -// virtualInterfaces returns a map of virtual interfaces on current machine. -func virtualInterfaces() map[string]struct{} { +// virtualInterfaces returns a map of virtual interfaces on the current machine. +func virtualInterfaces(ctx context.Context) map[string]struct{} { + logger := ctrld.LoggerFromCtx(ctx) s := make(map[string]struct{}) - entries, _ := os.ReadDir("/sys/devices/virtual/net") + entries, err := os.ReadDir("/sys/devices/virtual/net") + if err != nil { + logger.Error().Err(err).Msg("failed to read /sys/devices/virtual/net") + return nil + } for _, entry := range entries { if entry.IsDir() { s[strings.TrimSpace(entry.Name())] = struct{}{} diff --git a/cmd/cli/net_others.go b/cmd/cli/net_others.go index f347278..2015d06 100644 --- a/cmd/cli/net_others.go +++ b/cmd/cli/net_others.go @@ -3,6 +3,7 @@ package cli import ( + "context" "net" "tailscale.com/net/netmon" @@ -13,7 +14,7 @@ func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true } // validInterfacesMap returns a set containing only default route interfaces. -func validInterfacesMap() map[string]struct{} { +func validInterfacesMap(ctx context.Context) map[string]struct{} { defaultRoute, err := netmon.DefaultRoute() if err != nil { return nil diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index bed06b5..7b00a17 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -1,16 +1,10 @@ package cli import ( - "io" - "log" + "context" "net" - "os" - "github.com/microsoft/wmi/pkg/base/host" - "github.com/microsoft/wmi/pkg/base/instance" - "github.com/microsoft/wmi/pkg/base/query" - "github.com/microsoft/wmi/pkg/constant" - "github.com/microsoft/wmi/pkg/hardware/network/netadapter" + "github.com/Control-D-Inc/ctrld" ) func patchNetIfaceName(iface *net.Interface) (bool, error) { @@ -25,69 +19,10 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo } // validInterfacesMap returns a set of all physical interfaces. -func validInterfacesMap() map[string]struct{} { +func validInterfacesMap(ctx context.Context) map[string]struct{} { m := make(map[string]struct{}) - for _, ifaceName := range validInterfaces() { + for ifaceName := range ctrld.ValidInterfaces(ctx) { m[ifaceName] = struct{}{} } return m } - -// validInterfaces returns a list of all physical interfaces. -func validInterfaces() []string { - log.SetOutput(io.Discard) - defer log.SetOutput(os.Stderr) - whost := host.NewWmiLocalHost() - q := query.NewWmiQuery("MSFT_NetAdapter") - instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) - if instances != nil { - defer instances.Close() - } - if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter") - return nil - } - var adapters []string - for _, i := range instances { - adapter, err := netadapter.NewNetworkAdapter(i) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get network adapter") - continue - } - - name, err := adapter.GetPropertyName() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get interface name") - continue - } - - // From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85) - // - // "Indicates if a connector is present on the network adapter. This value is set to TRUE - // if this is a physical adapter or FALSE if this is not a physical adapter." - physical, err := adapter.GetPropertyConnectorPresent() - if err != nil { - mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property") - continue - } - if !physical { - mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter") - continue - } - - // Check if it's a hardware interface. Checking only for connector present is not enough - // because some interfaces are not physical but have a connector. - hardware, err := adapter.GetPropertyHardwareInterface() - if err != nil { - mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property") - continue - } - if !hardware { - mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface") - continue - } - - adapters = append(adapters, name) - } - return adapters -} diff --git a/cmd/cli/net_windows_test.go b/cmd/cli/net_windows_test.go index a15f119..551fe78 100644 --- a/cmd/cli/net_windows_test.go +++ b/cmd/cli/net_windows_test.go @@ -3,18 +3,23 @@ package cli import ( "bufio" "bytes" + "context" + "maps" "slices" "strings" "testing" "time" + + "github.com/Control-D-Inc/ctrld" ) func Test_validInterfaces(t *testing.T) { verbose = 3 initConsoleLogging() start := time.Now() - ifaces := validInterfaces() + im := ctrld.ValidInterfaces(ctrld.LoggerCtx(context.Background(), mainLog.Load())) t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + ifaces := slices.Collect(maps.Keys(im)) start = time.Now() ifacesPowershell := validInterfacesPowershell() diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 0cfd3b9..89cdab7 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -1320,8 +1320,8 @@ func canBeLocalUpstream(addr string) bool { // withEachPhysicalInterfaces runs the function f with each physical interfaces, excluding // the interface that matches excludeIfaceName. The context is used to clarify the // log message when error happens. -func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) { - validIfacesMap := validInterfacesMap() +func withEachPhysicalInterfaces(excludeIfaceName, contextStr string, f func(i *net.Interface) error) { + validIfacesMap := validInterfacesMap(ctrld.LoggerCtx(context.Background(), mainLog.Load())) netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { // Skip loopback/virtual/down interface. if i.IsLoopback() || len(i.HardwareAddr) == 0 { @@ -1345,11 +1345,11 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net. } // TODO: investigate whether we should report this error? if err := f(netIface); err == nil { - if context != "" { - mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", context, i.Name) + if contextStr != "" { + mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", contextStr, i.Name) } } else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) { - mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name) + mainLog.Load().Err(err).Msgf("%s for interface %q failed", contextStr, i.Name) } }) } diff --git a/nameservers_windows.go b/nameservers_windows.go index 596fb5f..ecffc89 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -210,7 +210,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { } } - validInterfacesMap := validInterfaces(ctx) + validInterfacesMap := ValidInterfaces(ctx) // Collect DNS servers for _, aa := range aas { @@ -377,10 +377,9 @@ func getLocalADDomain() (string, error) { return domainName, nil } -// validInterfaces returns a list of all physical interfaces. -// this is a duplicate of what is in net_windows.go, we should -// clean this up so there is only one version -func validInterfaces(ctx context.Context) map[string]struct{} { +// ValidInterfaces returns a map of valid network interface names as keys with empty struct values. +// It filters interfaces to include only physical, hardware-based adapters using WMI queries. +func ValidInterfaces(ctx context.Context) map[string]struct{} { log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr)