refactor: improve network interface validation

Add context parameter to validInterfacesMap for better error handling and
logging. Move Windows-specific network adapter validation logic to the
ctrld package. Key changes include:

- Add context parameter to validInterfacesMap across all platforms
- Move Windows validInterfaces to ctrld.ValidInterfaces
- Improve error handling for virtual interface detection on Linux
- Update all callers to pass appropriate context

This change improves error reporting and makes the interface validation
code more maintainable across different platforms.
This commit is contained in:
Cuong Manh Le
2025-06-19 16:38:03 +07:00
committed by Cuong Manh Le
parent d5cb327620
commit 59ece456b1
8 changed files with 38 additions and 89 deletions

View File

@@ -1201,7 +1201,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
// Get map of valid interfaces
validIfaces := validInterfacesMap()
validIfaces := validInterfacesMap(ctrld.LoggerCtx(ctx, p.logger.Load()))
isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New)

View File

@@ -3,6 +3,7 @@ package cli
import (
"bufio"
"bytes"
"context"
"io"
"net"
"os/exec"
@@ -51,7 +52,7 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo
}
// validInterfacesMap returns a set of all valid hardware ports.
func validInterfacesMap() map[string]struct{} {
func validInterfacesMap(ctx context.Context) map[string]struct{} {
b, err := exec.Command("networksetup", "-listallhardwareports").Output()
if err != nil {
return nil

View File

@@ -1,12 +1,15 @@
package cli
import (
"context"
"net"
"net/netip"
"os"
"strings"
"tailscale.com/net/netmon"
"github.com/Control-D-Inc/ctrld"
)
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
@@ -19,16 +22,16 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo
}
// validInterfacesMap returns a set containing non virtual interfaces.
func validInterfacesMap() map[string]struct{} {
func validInterfacesMap(ctx context.Context) map[string]struct{} {
m := make(map[string]struct{})
vis := virtualInterfaces()
vis := virtualInterfaces(ctx)
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
if _, existed := vis[i.Name]; existed {
return
}
m[i.Name] = struct{}{}
})
// Fallback to default route interface if found nothing.
// Fallback to the default route interface if found nothing.
if len(m) == 0 {
defaultRoute, err := netmon.DefaultRoute()
if err != nil {
@@ -39,10 +42,15 @@ func validInterfacesMap() map[string]struct{} {
return m
}
// virtualInterfaces returns a map of virtual interfaces on current machine.
func virtualInterfaces() map[string]struct{} {
// virtualInterfaces returns a map of virtual interfaces on the current machine.
func virtualInterfaces(ctx context.Context) map[string]struct{} {
logger := ctrld.LoggerFromCtx(ctx)
s := make(map[string]struct{})
entries, _ := os.ReadDir("/sys/devices/virtual/net")
entries, err := os.ReadDir("/sys/devices/virtual/net")
if err != nil {
logger.Error().Err(err).Msg("failed to read /sys/devices/virtual/net")
return nil
}
for _, entry := range entries {
if entry.IsDir() {
s[strings.TrimSpace(entry.Name())] = struct{}{}

View File

@@ -3,6 +3,7 @@
package cli
import (
"context"
"net"
"tailscale.com/net/netmon"
@@ -13,7 +14,7 @@ func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
// validInterfacesMap returns a set containing only default route interfaces.
func validInterfacesMap() map[string]struct{} {
func validInterfacesMap(ctx context.Context) map[string]struct{} {
defaultRoute, err := netmon.DefaultRoute()
if err != nil {
return nil

View File

@@ -1,16 +1,10 @@
package cli
import (
"io"
"log"
"context"
"net"
"os"
"github.com/microsoft/wmi/pkg/base/host"
"github.com/microsoft/wmi/pkg/base/instance"
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/constant"
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
"github.com/Control-D-Inc/ctrld"
)
func patchNetIfaceName(iface *net.Interface) (bool, error) {
@@ -25,69 +19,10 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo
}
// validInterfacesMap returns a set of all physical interfaces.
func validInterfacesMap() map[string]struct{} {
func validInterfacesMap(ctx context.Context) map[string]struct{} {
m := make(map[string]struct{})
for _, ifaceName := range validInterfaces() {
for ifaceName := range ctrld.ValidInterfaces(ctx) {
m[ifaceName] = struct{}{}
}
return m
}
// validInterfaces returns a list of all physical interfaces.
func validInterfaces() []string {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("MSFT_NetAdapter")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
if instances != nil {
defer instances.Close()
}
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
return nil
}
var adapters []string
for _, i := range instances {
adapter, err := netadapter.NewNetworkAdapter(i)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter")
continue
}
name, err := adapter.GetPropertyName()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
continue
}
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
//
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
// if this is a physical adapter or FALSE if this is not a physical adapter."
physical, err := adapter.GetPropertyConnectorPresent()
if err != nil {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property")
continue
}
if !physical {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter")
continue
}
// Check if it's a hardware interface. Checking only for connector present is not enough
// because some interfaces are not physical but have a connector.
hardware, err := adapter.GetPropertyHardwareInterface()
if err != nil {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property")
continue
}
if !hardware {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface")
continue
}
adapters = append(adapters, name)
}
return adapters
}

View File

@@ -3,18 +3,23 @@ package cli
import (
"bufio"
"bytes"
"context"
"maps"
"slices"
"strings"
"testing"
"time"
"github.com/Control-D-Inc/ctrld"
)
func Test_validInterfaces(t *testing.T) {
verbose = 3
initConsoleLogging()
start := time.Now()
ifaces := validInterfaces()
im := ctrld.ValidInterfaces(ctrld.LoggerCtx(context.Background(), mainLog.Load()))
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
ifaces := slices.Collect(maps.Keys(im))
start = time.Now()
ifacesPowershell := validInterfacesPowershell()

View File

@@ -1320,8 +1320,8 @@ func canBeLocalUpstream(addr string) bool {
// withEachPhysicalInterfaces runs the function f with each physical interfaces, excluding
// the interface that matches excludeIfaceName. The context is used to clarify the
// log message when error happens.
func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) {
validIfacesMap := validInterfacesMap()
func withEachPhysicalInterfaces(excludeIfaceName, contextStr string, f func(i *net.Interface) error) {
validIfacesMap := validInterfacesMap(ctrld.LoggerCtx(context.Background(), mainLog.Load()))
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
// Skip loopback/virtual/down interface.
if i.IsLoopback() || len(i.HardwareAddr) == 0 {
@@ -1345,11 +1345,11 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.
}
// TODO: investigate whether we should report this error?
if err := f(netIface); err == nil {
if context != "" {
mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", context, i.Name)
if contextStr != "" {
mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", contextStr, i.Name)
}
} else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) {
mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name)
mainLog.Load().Err(err).Msgf("%s for interface %q failed", contextStr, i.Name)
}
})
}

View File

@@ -210,7 +210,7 @@ func getDNSServers(ctx context.Context) ([]string, error) {
}
}
validInterfacesMap := validInterfaces(ctx)
validInterfacesMap := ValidInterfaces(ctx)
// Collect DNS servers
for _, aa := range aas {
@@ -377,10 +377,9 @@ func getLocalADDomain() (string, error) {
return domainName, nil
}
// validInterfaces returns a list of all physical interfaces.
// this is a duplicate of what is in net_windows.go, we should
// clean this up so there is only one version
func validInterfaces(ctx context.Context) map[string]struct{} {
// ValidInterfaces returns a map of valid network interface names as keys with empty struct values.
// It filters interfaces to include only physical, hardware-based adapters using WMI queries.
func ValidInterfaces(ctx context.Context) map[string]struct{} {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)