diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index aa44418..5ff9360 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -13,14 +13,15 @@ import ( "sync" "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) const ( - v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` - v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` + v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` ) var ( @@ -177,25 +178,31 @@ func currentDNS(iface *net.Interface) []string { func currentStaticDNS(iface *net.Interface) ([]string, error) { luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return nil, err + return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err) } guid, err := luid.GUID() if err != nil { - return nil, err + return nil, fmt.Errorf("luid.GUID: %w", err) } var ns []string for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} { - interfaceKeyPath := path + guid.String() found := false + interfaceKeyPath := path + guid.String() + k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE) + if err != nil { + return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err) + } for _, key := range []string{"NameServer", "ProfileNameServer"} { if found { continue } - cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key) - out, err := powershell(cmd) - if err == nil && len(out) > 0 { + value, _, err := k.GetStringValue(key) + if err != nil && !errors.Is(err, registry.ErrNotExist) { + return nil, fmt.Errorf("%s: %w", key, err) + } + if len(value) > 0 { found = true - for _, e := range strings.Split(string(out), ",") { + for _, e := range strings.Split(value, ",") { ns = append(ns, strings.TrimRight(e, "\x00")) } } diff --git a/cmd/cli/os_windows_test.go b/cmd/cli/os_windows_test.go new file mode 100644 index 0000000..40be5ed --- /dev/null +++ b/cmd/cli/os_windows_test.go @@ -0,0 +1,68 @@ +package cli + +import ( + "fmt" + "net" + "slices" + "strings" + "testing" + "time" + + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func Test_currentStaticDNS(t *testing.T) { + iface, err := net.InterfaceByName(defaultIfaceName()) + if err != nil { + t.Fatal(err) + } + start := time.Now() + staticDns, err := currentStaticDNS(iface) + if err != nil { + t.Fatal(err) + } + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + staticDnsPowershell, err := currentStaticDnsPowershell(iface) + if err != nil { + t.Fatal(err) + } + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + slices.Sort(staticDns) + slices.Sort(staticDnsPowershell) + if !slices.Equal(staticDns, staticDnsPowershell) { + t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns) + } +} + +func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) { + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return nil, err + } + guid, err := luid.GUID() + if err != nil { + return nil, err + } + var ns []string + for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} { + interfaceKeyPath := path + guid.String() + found := false + for _, key := range []string{"NameServer", "ProfileNameServer"} { + if found { + continue + } + cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key) + out, err := powershell(cmd) + if err == nil && len(out) > 0 { + found = true + for _, e := range strings.Split(string(out), ",") { + ns = append(ns, strings.TrimRight(e, "\x00")) + } + } + } + } + return ns, nil +}