From 9b6a308958d4d7420c066e3f69a0afecbfdac10c Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 21 Nov 2024 20:24:46 +0700 Subject: [PATCH] cmd/cli: get AD domain using Windows API --- cmd/cli/ad_windows.go | 18 +++++++++++++----- cmd/cli/ad_windows_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 cmd/cli/ad_windows_test.go diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go index d7374d0..316414d 100644 --- a/cmd/cli/ad_windows.go +++ b/cmd/cli/ad_windows.go @@ -1,8 +1,11 @@ package cli import ( - "fmt" "strings" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" "github.com/Control-D-Inc/ctrld" ) @@ -40,10 +43,15 @@ func addExtraSplitDnsRule(cfg *ctrld.Config) bool { // getActiveDirectoryDomain returns AD domain name of this computer. func getActiveDirectoryDomain() (string, error) { - cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }" - output, err := powershell(cmd) + var domain *uint16 + var status uint32 + err := syscall.NetGetJoinInformation(nil, &domain, &status) if err != nil { - return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output)) + return "", err } - return string(output), nil + defer syscall.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) + if status == syscall.NetSetupDomainName { + return windows.UTF16PtrToString(domain), nil + } + return "", nil } diff --git a/cmd/cli/ad_windows_test.go b/cmd/cli/ad_windows_test.go new file mode 100644 index 0000000..392abbd --- /dev/null +++ b/cmd/cli/ad_windows_test.go @@ -0,0 +1,36 @@ +package cli + +import ( + "fmt" + "testing" + "time" +) + +func Test_getActiveDirectoryDomain(t *testing.T) { + start := time.Now() + domain, err := getActiveDirectoryDomain() + if err != nil { + t.Fatal(err) + } + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + domainPowershell, err := getActiveDirectoryDomainPowershell() + if err != nil { + t.Fatal(err) + } + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + if domain != domainPowershell { + t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain) + } +} + +func getActiveDirectoryDomainPowershell() (string, error) { + cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }" + output, err := powershell(cmd) + if err != nil { + return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output)) + } + return string(output), nil +}