cmd/cli: get static DNS using syscall

This commit is contained in:
Cuong Manh Le
2024-11-25 20:06:42 +07:00
committed by Cuong Manh Le
parent 5e9b4244e7
commit 6837176ec7
2 changed files with 84 additions and 9 deletions

View File

@@ -13,14 +13,15 @@ import (
"sync"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
const (
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
)
var (
@@ -177,25 +178,31 @@ func currentDNS(iface *net.Interface) []string {
func currentStaticDNS(iface *net.Interface) ([]string, error) {
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return nil, err
return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err)
}
guid, err := luid.GUID()
if err != nil {
return nil, err
return nil, fmt.Errorf("luid.GUID: %w", err)
}
var ns []string
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
interfaceKeyPath := path + guid.String()
found := false
interfaceKeyPath := path + guid.String()
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
if err != nil {
return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err)
}
for _, key := range []string{"NameServer", "ProfileNameServer"} {
if found {
continue
}
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
out, err := powershell(cmd)
if err == nil && len(out) > 0 {
value, _, err := k.GetStringValue(key)
if err != nil && !errors.Is(err, registry.ErrNotExist) {
return nil, fmt.Errorf("%s: %w", key, err)
}
if len(value) > 0 {
found = true
for _, e := range strings.Split(string(out), ",") {
for _, e := range strings.Split(value, ",") {
ns = append(ns, strings.TrimRight(e, "\x00"))
}
}

View File

@@ -0,0 +1,68 @@
package cli
import (
"fmt"
"net"
"slices"
"strings"
"testing"
"time"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
func Test_currentStaticDNS(t *testing.T) {
iface, err := net.InterfaceByName(defaultIfaceName())
if err != nil {
t.Fatal(err)
}
start := time.Now()
staticDns, err := currentStaticDNS(iface)
if err != nil {
t.Fatal(err)
}
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
start = time.Now()
staticDnsPowershell, err := currentStaticDnsPowershell(iface)
if err != nil {
t.Fatal(err)
}
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
slices.Sort(staticDns)
slices.Sort(staticDnsPowershell)
if !slices.Equal(staticDns, staticDnsPowershell) {
t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns)
}
}
func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return nil, err
}
guid, err := luid.GUID()
if err != nil {
return nil, err
}
var ns []string
for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} {
interfaceKeyPath := path + guid.String()
found := false
for _, key := range []string{"NameServer", "ProfileNameServer"} {
if found {
continue
}
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
out, err := powershell(cmd)
if err == nil && len(out) > 0 {
found = true
for _, e := range strings.Split(string(out), ",") {
ns = append(ns, strings.TrimRight(e, "\x00"))
}
}
}
}
return ns, nil
}