mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
cmd/cli: get static DNS using syscall
This commit is contained in:
committed by
Cuong Manh Le
parent
5e9b4244e7
commit
6837176ec7
@@ -13,14 +13,15 @@ import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
const (
|
||||
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -177,25 +178,31 @@ func currentDNS(iface *net.Interface) []string {
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err)
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("luid.GUID: %w", err)
|
||||
}
|
||||
var ns []string
|
||||
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
interfaceKeyPath := path + guid.String()
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err)
|
||||
}
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
value, _, err := k.GetStringValue(key)
|
||||
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
||||
return nil, fmt.Errorf("%s: %w", key, err)
|
||||
}
|
||||
if len(value) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
for _, e := range strings.Split(value, ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
}
|
||||
}
|
||||
|
||||
68
cmd/cli/os_windows_test.go
Normal file
68
cmd/cli/os_windows_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func Test_currentStaticDNS(t *testing.T) {
|
||||
iface, err := net.InterfaceByName(defaultIfaceName())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
start := time.Now()
|
||||
staticDns, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
staticDnsPowershell, err := currentStaticDnsPowershell(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(staticDns)
|
||||
slices.Sort(staticDnsPowershell)
|
||||
if !slices.Equal(staticDns, staticDnsPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns)
|
||||
}
|
||||
}
|
||||
|
||||
func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ns []string
|
||||
for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
Reference in New Issue
Block a user