mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
fix logging fix logging try to enable nameserver logs try to enable nameserver logs handle flags in interface state changes debugging debugging debugging fix state detection, AD status fix fix debugging line more dc info always log state changes remove unused method windows AD IP discovery windows AD IP discovery windows AD IP discovery
365 lines
10 KiB
Go
365 lines
10 KiB
Go
package ctrld
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
"unsafe"
|
|
"io"
|
|
"os"
|
|
|
|
"github.com/rs/zerolog"
|
|
"golang.org/x/sys/windows"
|
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
|
"github.com/StackExchange/wmi"
|
|
)
|
|
|
|
// Tuning for adapter-based DNS server discovery.
const (
	// maxRetries is the number of getDNSServers attempts made by dnsFromAdapter.
	maxRetries = 3
	// retryDelay is the pause between consecutive attempts.
	retryDelay = 500 * time.Millisecond
	// defaultTimeout bounds the whole discovery, all retries included.
	defaultTimeout = 5 * time.Second
	// minDNSServers is the minimum number of DNS servers we want to find.
	minDNSServers = 1
)

// Domain join status values reported by NetGetJoinInformation
// (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3).
const (
	NetSetupUnknown uint32 = iota
	NetSetupWorkgroup
	NetSetupDomain
	NetSetupCloudDomain
)

// Flag bits accepted by DsGetDcNameW. Names and values follow the Windows
// SDK's DS_* constants.
const (
	DS_FORCE_REDISCOVERY          = 1 << 0  // 0x00000001
	DS_DIRECTORY_SERVICE_REQUIRED = 1 << 4  // 0x00000010
	DS_BACKGROUND_ONLY            = 1 << 8  // 0x00000100
	DS_IP_REQUIRED                = 1 << 9  // 0x00000200
	DS_IS_DNS_NAME                = 1 << 17 // 0x00020000
	DS_RETURN_DNS_NAME            = 1 << 30 // 0x40000000
)
|
|
|
|
// DomainControllerInfo is the output structure written by DsGetDcNameW.
// Its fields appear to mirror the Win32 DOMAIN_CONTROLLER_INFOW layout.
// The memory is allocated and filled in by Windows, so field order and
// types must not be changed. Instances obtained from DsGetDcNameW are
// released with windows.NetApiBufferFree.
type DomainControllerInfo struct {
	// DomainControllerName is a NUL-terminated UTF-16 string.
	DomainControllerName *uint16
	// DomainControllerAddress is a NUL-terminated UTF-16 string. The caller
	// strips a leading `\\` and parses the remainder as an IP address.
	DomainControllerAddress *uint16
	// DomainControllerAddressType is the type of DomainControllerAddress.
	DomainControllerAddressType uint32
	// DomainGuid is the domain's GUID.
	DomainGuid windows.GUID
	// DomainName is a NUL-terminated UTF-16 string.
	DomainName *uint16
	// DnsForestName is a NUL-terminated UTF-16 string.
	DnsForestName *uint16
	// Flags holds the DC flag bits.
	Flags uint32
	// DcSiteName is a NUL-terminated UTF-16 string.
	DcSiteName *uint16
	// ClientSiteName is a NUL-terminated UTF-16 string.
	ClientSiteName *uint16
}
|
|
|
|
func dnsFns() []dnsFn {
|
|
return []dnsFn{dnsFromAdapter}
|
|
}
|
|
|
|
func dnsFromAdapter() []string {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
|
defer cancel()
|
|
|
|
var ns []string
|
|
var err error
|
|
|
|
//load the logger
|
|
logger := zerolog.New(io.Discard)
|
|
if ProxyLogger.Load() != nil {
|
|
logger = *ProxyLogger.Load()
|
|
}
|
|
|
|
for i := 0; i < maxRetries; i++ {
|
|
if ctx.Err() != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
|
|
return nil
|
|
}
|
|
|
|
ns, err = getDNSServers(ctx)
|
|
if err == nil && len(ns) >= minDNSServers {
|
|
if i > 0 {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns))
|
|
}
|
|
return ns
|
|
}
|
|
|
|
// Log the specific failure reason
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get DNS servers, attempt %d: %v", i+1, err)
|
|
} else {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Got insufficient DNS servers, retrying, found %d servers", len(ns))
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case <-time.After(retryDelay):
|
|
}
|
|
}
|
|
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxRetries)
|
|
return ns // Return whatever we got, even if insufficient
|
|
}
|
|
|
|
func getDNSServers(ctx context.Context) ([]string, error) {
|
|
//load the logger
|
|
logger := zerolog.New(io.Discard)
|
|
if ProxyLogger.Load() != nil {
|
|
logger = *ProxyLogger.Load()
|
|
}
|
|
// Check context before making the call
|
|
if ctx.Err() != nil {
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
// Get DNS servers from adapters (existing method)
|
|
flags := winipcfg.GAAFlagIncludeGateways |
|
|
winipcfg.GAAFlagIncludePrefix
|
|
|
|
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("getting adapters: %w", err)
|
|
}
|
|
|
|
Log(context.Background(), logger.Debug(),
|
|
"Found network adapters, count=%d", len(aas))
|
|
|
|
// Try to get domain controller info if domain-joined
|
|
var dcServers []string
|
|
isDomain := checkDomainJoined()
|
|
if isDomain {
|
|
|
|
domainName, err := getLocalADDomain()
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get local AD domain: %v", err)
|
|
|
|
} else {
|
|
|
|
// Load netapi32.dll
|
|
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
|
|
dsDcName := netapi32.NewProc("DsGetDcNameW")
|
|
|
|
var info *DomainControllerInfo
|
|
|
|
flags := uint32(DS_RETURN_DNS_NAME |
|
|
DS_IP_REQUIRED |
|
|
DS_IS_DNS_NAME)
|
|
|
|
// Convert domain name to UTF16 pointer
|
|
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to convert domain name to UTF16: %v", err)
|
|
} else {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
|
|
|
|
// Call DsGetDcNameW with domain name
|
|
ret, _, err := dsDcName.Call(
|
|
0, // ComputerName - can be NULL
|
|
uintptr(unsafe.Pointer(domainUTF16)), // DomainName
|
|
0, // DomainGuid - not needed
|
|
0, // SiteName - not needed
|
|
uintptr(flags), // Flags
|
|
uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output
|
|
|
|
if ret != 0 {
|
|
switch ret {
|
|
case 1355: // ERROR_NO_SUCH_DOMAIN
|
|
Log(context.Background(), logger.Debug(),
|
|
"Domain not found: %s (%d)", domainName, ret)
|
|
case 1311: // ERROR_NO_LOGON_SERVERS
|
|
Log(context.Background(), logger.Debug(),
|
|
"No logon servers available for domain: %s (%d)", domainName, ret)
|
|
case 1004: // ERROR_DC_NOT_FOUND
|
|
Log(context.Background(), logger.Debug(),
|
|
"Domain controller not found for domain: %s (%d)", domainName, ret)
|
|
case 1722: // RPC_S_SERVER_UNAVAILABLE
|
|
Log(context.Background(), logger.Debug(),
|
|
"RPC server unavailable for domain: %s (%d)", domainName, ret)
|
|
default:
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
|
|
}
|
|
} else if info != nil {
|
|
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
|
|
|
|
// Get DC address
|
|
if info.DomainControllerAddress != nil {
|
|
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
|
|
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
|
|
|
|
Log(context.Background(), logger.Debug(),
|
|
"Found domain controller address: %s", dcAddr)
|
|
|
|
// Try to resolve DC
|
|
if ip := net.ParseIP(dcAddr); ip != nil {
|
|
dcServers = append(dcServers, ip.String())
|
|
Log(context.Background(), logger.Debug(),
|
|
"Added domain controller DNS servers: %v", dcServers)
|
|
}
|
|
} else {
|
|
Log(context.Background(), logger.Debug(),
|
|
"No domain controller address found")
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
// Continue with existing adapter DNS collection
|
|
ns := make([]string, 0, len(aas)*2)
|
|
seen := make(map[string]bool)
|
|
addressMap := make(map[string]struct{})
|
|
|
|
// Collect all local IPs
|
|
for _, aa := range aas {
|
|
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
|
|
continue
|
|
}
|
|
|
|
Log(context.Background(), logger.Debug(),
|
|
"Processing adapter %s", aa.FriendlyName())
|
|
|
|
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
|
|
ip := a.Address.IP().String()
|
|
addressMap[ip] = struct{}{}
|
|
Log(context.Background(), logger.Debug(),
|
|
"Added local IP %s from adapter %s", ip, aa.FriendlyName())
|
|
}
|
|
}
|
|
|
|
// Collect DNS servers
|
|
for _, aa := range aas {
|
|
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
|
continue
|
|
}
|
|
|
|
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
|
ip := dns.Address.IP()
|
|
if ip == nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Skipping nil IP from adapter %s", aa.FriendlyName())
|
|
continue
|
|
}
|
|
|
|
ipStr := ip.String()
|
|
logger := logger.Debug().
|
|
Str("ip", ipStr).
|
|
Str("adapter", aa.FriendlyName())
|
|
|
|
if ip.IsLoopback() {
|
|
logger.Msg("Skipping loopback IP")
|
|
continue
|
|
}
|
|
|
|
if seen[ipStr] {
|
|
logger.Msg("Skipping duplicate IP")
|
|
continue
|
|
}
|
|
|
|
if _, ok := addressMap[ipStr]; ok {
|
|
logger.Msg("Skipping local interface IP")
|
|
continue
|
|
}
|
|
|
|
seen[ipStr] = true
|
|
ns = append(ns, ipStr)
|
|
logger.Msg("Added DNS server")
|
|
}
|
|
}
|
|
|
|
// Add DC servers if they're not already in the list
|
|
for _, dcServer := range dcServers {
|
|
if !seen[dcServer] {
|
|
seen[dcServer] = true
|
|
ns = append(ns, dcServer)
|
|
Log(context.Background(), logger.Debug(),
|
|
"Added additional domain controller DNS server: %s", dcServer)
|
|
}
|
|
}
|
|
|
|
if len(ns) == 0 {
|
|
return nil, fmt.Errorf("no valid DNS servers found")
|
|
}
|
|
|
|
Log(context.Background(), logger.Debug(),
|
|
"DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
|
|
len(ns), ns, len(dcServers))
|
|
return ns, nil
|
|
}
|
|
|
|
// nameserversFromResolvconf reads no configuration in this implementation
// and always returns a nil slice.
func nameserversFromResolvconf() []string {
	var servers []string // intentionally left nil
	return servers
}
|
|
|
|
// checkDomainJoined checks if the machine is joined to an Active Directory domain
|
|
// Returns whether it's domain joined and the domain name if available
|
|
func checkDomainJoined() bool {
|
|
//load the logger
|
|
logger := zerolog.New(io.Discard)
|
|
if ProxyLogger.Load() != nil {
|
|
logger = *ProxyLogger.Load()
|
|
}
|
|
var domain *uint16
|
|
var status uint32
|
|
|
|
err := windows.NetGetJoinInformation(nil, &domain, &status)
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get domain join status: %v", err)
|
|
return false
|
|
}
|
|
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain)))
|
|
|
|
domainName := windows.UTF16PtrToString(domain)
|
|
Log(context.Background(), logger.Debug(),
|
|
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", domainName, status)
|
|
|
|
// Consider both traditional and cloud domains as valid domain joins
|
|
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
|
|
Log(context.Background(), logger.Debug(),
|
|
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
|
|
status,
|
|
status == NetSetupDomain,
|
|
status == NetSetupCloudDomain,
|
|
isDomain)
|
|
|
|
return isDomain
|
|
}
|
|
|
|
// Win32_ComputerSystem is the minimal struct for WMI query results used by
// getLocalADDomain. Only the Domain property is selected
// ("SELECT Domain FROM Win32_ComputerSystem"), so only that field is declared.
// The underscore in the type name follows the WMI class name rather than Go
// naming conventions.
// NOTE(review): the wmi package is assumed to map result properties onto
// fields by name, so the field must stay named "Domain" — confirm before renaming.
type Win32_ComputerSystem struct {
	Domain string
}
|
|
|
|
// getLocalADDomain tries to detect the AD domain in two ways:
|
|
// 1) USERDNSDOMAIN env var (often set in AD logon sessions)
|
|
// 2) WMI Win32_ComputerSystem.Domain
|
|
func getLocalADDomain() (string, error) {
|
|
// 1) Check environment variable
|
|
envDomain := os.Getenv("USERDNSDOMAIN")
|
|
if envDomain != "" {
|
|
return strings.TrimSpace(envDomain), nil
|
|
}
|
|
|
|
// 2) Check WMI (requires Windows + admin privileges or sufficient access)
|
|
var result []Win32_ComputerSystem
|
|
err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result)
|
|
if err != nil {
|
|
return "", fmt.Errorf("WMI query failed: %v", err)
|
|
}
|
|
if len(result) == 0 {
|
|
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
|
|
}
|
|
|
|
domain := strings.TrimSpace(result[0].Domain)
|
|
if domain == "" {
|
|
return "", fmt.Errorf("machine does not appear to have a domain set")
|
|
}
|
|
return domain, nil
|
|
}
|