From a56711796fb1f67904436f1f3b38772d79dde666 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 20 Nov 2024 21:39:01 +0700 Subject: [PATCH] cmd/cli: set DNS using Windows API --- cmd/cli/os_windows.go | 55 +++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index b9412b6..1a22b0f 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -4,12 +4,13 @@ import ( "errors" "fmt" "net" + "net/netip" "os" "slices" - "strconv" "strings" "sync" + "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" @@ -30,14 +31,6 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e return setDNS(iface, nameservers) } -func setDnsPowershellCmd(iface *net.Interface, nameservers []string) string { - nss := make([]string, 0, len(nameservers)) - for _, ns := range nameservers { - nss = append(nss, strconv.Quote(ns)) - } - return fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ServerAddresses (%s)", iface.Index, strings.Join(nss, ",")) -} - // setDNS sets the dns server for the provided network interface func setDNS(iface *net.Interface, nameservers []string) error { if len(nameservers) == 0 { @@ -65,9 +58,36 @@ func setDNS(iface *net.Interface, nameservers []string) error { } } }) - out, err := powershell(setDnsPowershellCmd(iface, nameservers)) + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return fmt.Errorf("%w: %s", err, string(out)) + return fmt.Errorf("setDNS: %w", err) + } + var ( + serversV4 []netip.Addr + serversV6 []netip.Addr + ) + for _, ns := range nameservers { + if addr, err := netip.ParseAddr(ns); err == nil { + if addr.Is4() { + serversV4 = append(serversV4, addr) + } else { + serversV6 = append(serversV6, addr) + } + } + } + + if len(serversV4) == 0 && len(serversV6) == 0 { + return errors.New("invalid DNS nameservers") + } + if len(serversV4) > 0 { + if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil { + return fmt.Errorf("could not set DNS ipv4: %w", err) + } + } + if len(serversV6) > 0 { + if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil { + return fmt.Errorf("could not set DNS ipv6: %w", err) + } } return nil } @@ -96,11 +116,16 @@ func resetDNS(iface *net.Interface) error { } }) - // Restoring DHCP settings. - cmd := fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ResetServerAddresses", iface.Index) - out, err := powershell(cmd) + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return fmt.Errorf("%w: %s", err, string(out)) + return fmt.Errorf("resetDNS: %w", err) + } + // Restoring DHCP settings. + if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil { + return fmt.Errorf("could not reset DNS ipv4: %w", err) + } + if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil { + return fmt.Errorf("could not reset DNS ipv6: %w", err) } // If there's static DNS saved, restoring it.