mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
430 lines
13 KiB
Go
430 lines
13 KiB
Go
package ctrld
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/microsoft/wmi/pkg/base/host"
|
|
"github.com/microsoft/wmi/pkg/base/instance"
|
|
"github.com/microsoft/wmi/pkg/base/query"
|
|
"github.com/microsoft/wmi/pkg/constant"
|
|
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
|
|
"golang.org/x/sys/windows"
|
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
|
"tailscale.com/net/netmon"
|
|
|
|
"github.com/Control-D-Inc/ctrld/internal/system"
|
|
)
|
|
|
|
// Retry policy used by dnsFromAdapter when discovering nameservers.
const (
	// maxDNSAdapterRetries is the maximum number of discovery attempts.
	maxDNSAdapterRetries = 5
	// retryDelayDNSAdapter is the pause between two discovery attempts.
	retryDelayDNSAdapter = time.Second
	// defaultDNSAdapterTimeout bounds the whole discovery, retries included.
	defaultDNSAdapterTimeout = 10 * time.Second
	// minDNSServers is the minimum number of DNS servers we want to find.
	minDNSServers = 1
)

// Flag bits accepted by the Windows DsGetDcNameW call. The names mirror the
// Windows SDK spelling; only DS_RETURN_DNS_NAME, DS_IP_REQUIRED and
// DS_IS_DNS_NAME are used in this file.
const (
	DS_FORCE_REDISCOVERY          = 0x1
	DS_DIRECTORY_SERVICE_REQUIRED = 0x10
	DS_BACKGROUND_ONLY            = 0x100
	DS_IP_REQUIRED                = 0x200
	DS_IS_DNS_NAME                = 0x20000
	DS_RETURN_DNS_NAME            = 0x40000000
)
|
|
|
|
type DomainControllerInfo struct {
|
|
DomainControllerName *uint16
|
|
DomainControllerAddress *uint16
|
|
DomainControllerAddressType uint32
|
|
DomainGuid windows.GUID
|
|
DomainName *uint16
|
|
DnsForestName *uint16
|
|
Flags uint32
|
|
DcSiteName *uint16
|
|
ClientSiteName *uint16
|
|
}
|
|
|
|
func dnsFns() []dnsFn {
|
|
return []dnsFn{dnsFromAdapter}
|
|
}
|
|
|
|
func dnsFromAdapter() []string {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout)
|
|
defer cancel()
|
|
|
|
var ns []string
|
|
var err error
|
|
|
|
logger := *ProxyLogger.Load()
|
|
|
|
for i := 0; i < maxDNSAdapterRetries; i++ {
|
|
if ctx.Err() != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
|
|
return nil
|
|
}
|
|
|
|
ns, err = getDNSServers(ctx)
|
|
if err == nil && len(ns) >= minDNSServers {
|
|
if i > 0 {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Successfully got DNS servers after %d attempts, found %d servers",
|
|
i+1, len(ns))
|
|
}
|
|
return ns
|
|
}
|
|
|
|
// if osResolver is not initialized, this is likely a command line run
|
|
// and ctrld is already on the interface, abort retries
|
|
if or == nil {
|
|
return ns
|
|
}
|
|
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get DNS servers, attempt %d: %v", i+1, err)
|
|
} else {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Got insufficient DNS servers, retrying, found %d servers", len(ns))
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case <-time.After(retryDelayDNSAdapter):
|
|
}
|
|
}
|
|
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries)
|
|
return ns
|
|
}
|
|
|
|
func getDNSServers(ctx context.Context) ([]string, error) {
|
|
logger := *ProxyLogger.Load()
|
|
|
|
// Check context before making the call
|
|
if ctx.Err() != nil {
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
// Get DNS servers from adapters (existing method)
|
|
flags := winipcfg.GAAFlagIncludeGateways |
|
|
winipcfg.GAAFlagIncludePrefix
|
|
|
|
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("getting adapters: %w", err)
|
|
}
|
|
|
|
Log(context.Background(), logger.Debug(),
|
|
"Found network adapters, count=%d", len(aas))
|
|
|
|
// Try to get domain controller info if domain-joined
|
|
var dcServers []string
|
|
isDomain := checkDomainJoined()
|
|
if isDomain {
|
|
domainName, err := system.GetActiveDirectoryDomain()
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get local AD domain: %v", err)
|
|
} else {
|
|
// Load netapi32.dll
|
|
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
|
|
dsDcName := netapi32.NewProc("DsGetDcNameW")
|
|
|
|
var info *DomainControllerInfo
|
|
flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME)
|
|
|
|
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to convert domain name to UTF16: %v", err)
|
|
} else {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
|
|
|
|
// Call DsGetDcNameW with domain name
|
|
ret, _, err := dsDcName.Call(
|
|
0, // ComputerName - can be NULL
|
|
uintptr(unsafe.Pointer(domainUTF16)), // DomainName
|
|
0, // DomainGuid - not needed
|
|
0, // SiteName - not needed
|
|
uintptr(flags), // Flags
|
|
uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output
|
|
|
|
if ret != 0 {
|
|
switch ret {
|
|
case 1355: // ERROR_NO_SUCH_DOMAIN
|
|
Log(context.Background(), logger.Debug(),
|
|
"Domain not found: %s (%d)", domainName, ret)
|
|
case 1311: // ERROR_NO_LOGON_SERVERS
|
|
Log(context.Background(), logger.Debug(),
|
|
"No logon servers available for domain: %s (%d)", domainName, ret)
|
|
case 1004: // ERROR_DC_NOT_FOUND
|
|
Log(context.Background(), logger.Debug(),
|
|
"Domain controller not found for domain: %s (%d)", domainName, ret)
|
|
case 1722: // RPC_S_SERVER_UNAVAILABLE
|
|
Log(context.Background(), logger.Debug(),
|
|
"RPC server unavailable for domain: %s (%d)", domainName, ret)
|
|
default:
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
|
|
}
|
|
} else if info != nil {
|
|
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
|
|
|
|
if info.DomainControllerAddress != nil {
|
|
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
|
|
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
|
|
Log(context.Background(), logger.Debug(),
|
|
"Found domain controller address: %s", dcAddr)
|
|
|
|
if ip := net.ParseIP(dcAddr); ip != nil {
|
|
dcServers = append(dcServers, ip.String())
|
|
Log(context.Background(), logger.Debug(),
|
|
"Added domain controller DNS servers: %v", dcServers)
|
|
}
|
|
} else {
|
|
Log(context.Background(), logger.Debug(),
|
|
"No domain controller address found")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Continue with existing adapter DNS collection
|
|
ns := make([]string, 0, len(aas)*2)
|
|
seen := make(map[string]bool)
|
|
addressMap := make(map[string]struct{})
|
|
|
|
// Collect all local IPs
|
|
for _, aa := range aas {
|
|
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
|
|
continue
|
|
}
|
|
|
|
// Skip if software loopback or other non-physical types
|
|
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
|
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Skipping %s (software loopback)", aa.FriendlyName())
|
|
continue
|
|
}
|
|
|
|
Log(context.Background(), logger.Debug(),
|
|
"Processing adapter %s", aa.FriendlyName())
|
|
|
|
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
|
|
ip := a.Address.IP().String()
|
|
addressMap[ip] = struct{}{}
|
|
Log(context.Background(), logger.Debug(),
|
|
"Added local IP %s from adapter %s", ip, aa.FriendlyName())
|
|
}
|
|
}
|
|
|
|
validInterfacesMap := validInterfaces()
|
|
|
|
// Collect DNS servers
|
|
for _, aa := range aas {
|
|
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
|
continue
|
|
}
|
|
|
|
// Skip if software loopback or other non-physical types
|
|
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
|
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Skipping %s (software loopback)", aa.FriendlyName())
|
|
continue
|
|
}
|
|
|
|
// if not in the validInterfacesMap, skip
|
|
if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
|
|
continue
|
|
}
|
|
|
|
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
|
ip := dns.Address.IP()
|
|
if ip == nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Skipping nil IP from adapter %s", aa.FriendlyName())
|
|
continue
|
|
}
|
|
|
|
ipStr := ip.String()
|
|
l := logger.Debug().
|
|
Str("ip", ipStr).
|
|
Str("adapter", aa.FriendlyName())
|
|
|
|
if ip.IsLoopback() {
|
|
l.Msg("Skipping loopback IP")
|
|
continue
|
|
}
|
|
if seen[ipStr] {
|
|
l.Msg("Skipping duplicate IP")
|
|
continue
|
|
}
|
|
if _, ok := addressMap[ipStr]; ok {
|
|
l.Msg("Skipping local interface IP")
|
|
continue
|
|
}
|
|
|
|
seen[ipStr] = true
|
|
ns = append(ns, ipStr)
|
|
l.Msg("Added DNS server")
|
|
}
|
|
}
|
|
|
|
// Add DC servers if they're not already in the list
|
|
for _, dcServer := range dcServers {
|
|
if !seen[dcServer] {
|
|
seen[dcServer] = true
|
|
ns = append(ns, dcServer)
|
|
Log(context.Background(), logger.Debug(),
|
|
"Added additional domain controller DNS server: %s", dcServer)
|
|
}
|
|
}
|
|
|
|
// if we have static DNS servers saved for the current default route, we should add them to the list
|
|
drIfaceName, err := netmon.DefaultRouteInterface()
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get default route interface: %v", err)
|
|
} else {
|
|
drIface, err := net.InterfaceByName(drIfaceName)
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Failed to get interface by name %s: %v", drIfaceName, err)
|
|
} else {
|
|
staticNs, file := SavedStaticNameservers(drIface)
|
|
Log(context.Background(), logger.Debug(),
|
|
"static dns servers from %s: %v", file, staticNs)
|
|
if len(staticNs) > 0 {
|
|
Log(context.Background(), logger.Debug(),
|
|
"Adding static DNS servers from %s: %v", drIfaceName, staticNs)
|
|
ns = append(ns, staticNs...)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(ns) == 0 {
|
|
return nil, fmt.Errorf("no valid DNS servers found")
|
|
}
|
|
|
|
Log(context.Background(), logger.Debug(),
|
|
"DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
|
|
len(ns), ns, len(dcServers))
|
|
return ns, nil
|
|
}
|
|
|
|
// currentNameserversFromResolvconf always reports no nameservers: it returns
// a nil slice and reads nothing.
func currentNameserversFromResolvconf() []string {
	var none []string
	return none
}
|
|
|
|
// checkDomainJoined checks if the machine is joined to an Active Directory domain
|
|
// Returns whether it's domain joined and the domain name if available
|
|
func checkDomainJoined() bool {
|
|
logger := *ProxyLogger.Load()
|
|
|
|
status, err := system.DomainJoinedStatus()
|
|
if err != nil {
|
|
logger.Debug().Msgf("Failed to get domain joined status: %v", err)
|
|
return false
|
|
}
|
|
isDomain := status == syscall.NetSetupDomainName
|
|
logger.Debug().Msg("Domain join status: (UnknownStatus=0, Unjoined=1, WorkgroupName=2, DomainName=3)")
|
|
logger.Debug().Msgf("Is domain joined? status=%d, result=%v", status, isDomain)
|
|
|
|
return isDomain
|
|
}
|
|
|
|
// validInterfaces returns a list of all physical interfaces.
|
|
// this is a duplicate of what is in net_windows.go, we should
|
|
// clean this up so there is only one version
|
|
func validInterfaces() map[string]struct{} {
|
|
log.SetOutput(io.Discard)
|
|
defer log.SetOutput(os.Stderr)
|
|
|
|
//load the logger
|
|
logger := *ProxyLogger.Load()
|
|
|
|
whost := host.NewWmiLocalHost()
|
|
q := query.NewWmiQuery("MSFT_NetAdapter")
|
|
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
|
if instances != nil {
|
|
defer instances.Close()
|
|
}
|
|
if err != nil {
|
|
Log(context.Background(), logger.Warn(),
|
|
"failed to get wmi network adapter: %v", err)
|
|
return nil
|
|
}
|
|
var adapters []string
|
|
for _, i := range instances {
|
|
adapter, err := netadapter.NewNetworkAdapter(i)
|
|
if err != nil {
|
|
Log(context.Background(), logger.Warn(),
|
|
"failed to get network adapter: %v", err)
|
|
continue
|
|
}
|
|
|
|
name, err := adapter.GetPropertyName()
|
|
if err != nil {
|
|
Log(context.Background(), logger.Warn(),
|
|
"failed to get interface name: %v", err)
|
|
continue
|
|
}
|
|
|
|
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
|
//
|
|
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
|
|
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
|
physical, err := adapter.GetPropertyConnectorPresent()
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"failed to get network adapter connector present property: %v", err)
|
|
continue
|
|
}
|
|
if !physical {
|
|
Log(context.Background(), logger.Debug(),
|
|
"skipping non-physical adapter: %s", name)
|
|
continue
|
|
}
|
|
|
|
// Check if it's a hardware interface. Checking only for connector present is not enough
|
|
// because some interfaces are not physical but have a connector.
|
|
hardware, err := adapter.GetPropertyHardwareInterface()
|
|
if err != nil {
|
|
Log(context.Background(), logger.Debug(),
|
|
"failed to get network adapter hardware interface property: %v", err)
|
|
continue
|
|
}
|
|
if !hardware {
|
|
Log(context.Background(), logger.Debug(),
|
|
"skipping non-hardware interface: %s", name)
|
|
continue
|
|
}
|
|
|
|
adapters = append(adapters, name)
|
|
}
|
|
|
|
m := make(map[string]struct{})
|
|
for _, ifaceName := range adapters {
|
|
m[ifaceName] = struct{}{}
|
|
}
|
|
return m
|
|
}
|