diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index fcbf9b4..e3657ec 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -853,10 +853,12 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo } mainLog.Load().Debug().Msg("ctrld listener is ready") - mainLog.Load().Debug().Msg("performing self-check") lc := cfg.FirstListener() addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)) + + mainLog.Load().Debug().Msgf("performing listener test, sending queries to %s", addr) + if err := selfCheckResolveDomain(context.TODO(), addr, "internal", selfCheckInternalTestDomain); err != nil { return false, status, err } diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 4866267..e1bcd9a 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -43,21 +43,42 @@ func setDNS(iface *net.Interface, nameservers []string) error { // If there's a Dns server running, that means we are on AD with Dns feature enabled. // Configuring the Dns server to forward queries to ctrld instead. if hasLocalDnsServerRunning() { + mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders") + file := absHomeDir(windowsForwardersFilename) - oldForwardersContent, _ := os.ReadFile(file) + mainLog.Load().Debug().Msgf("Using forwarders file: %s", file) + + oldForwardersContent, err := os.ReadFile(file) + if err != nil { + mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file") + } else { + mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent)) + } + hasLocalIPv6Listener := needLocalIPv6Listener() + mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status") + forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool { if !hasLocalIPv6Listener { return false } return s == "::1" }) + mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list") + if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil { mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings") + } else { + mainLog.Load().Debug().Msg("Successfully wrote new forwarders file") } + oldForwarders := strings.Split(string(oldForwardersContent), ",") + mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders") + if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil { mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings") + } else { + mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders") } } }) @@ -229,7 +250,11 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) { if len(value) > 0 { mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value) parsed := parseDNSServers(value) - ns = append(ns, parsed...) + for _, pns := range parsed { + if !slices.Contains(ns, pns) { + ns = append(ns, pns) + } + } } } }() diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index b119423..48e6708 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -120,6 +120,7 @@ type prog struct { runningIface string requiredMultiNICsConfig bool adDomain string + runningOnDomainController bool selfUninstallMu sync.Mutex refusedQueryCount int @@ -276,6 +277,11 @@ func (p *prog) preRun() { func (p *prog) postRun() { if !service.Interactive() { + if runtime.GOOS == "windows" { + isDC, roleInt := isRunningOnDomainController() + p.runningOnDomainController = isDC + mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt) + } p.resetDNS(false, false) ns := ctrld.InitializeOsResolver(false) mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) @@ -1410,5 +1416,25 @@ func (p *prog) leakOnUpstreamFailure() bool { if router.Name() != "" { return false } + // if we are running on ADDC, we should not leak on upstream failure + if p.runningOnDomainController { + return false + } return true } + +// Domain controller role values from Win32_ComputerSystem +// https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/win32-computersystem +const ( + BackupDomainController = 4 + PrimaryDomainController = 5 +) + +// isRunningOnDomainController checks if the current machine is a domain controller +// by querying the DomainRole property from Win32_ComputerSystem via WMI. +func isRunningOnDomainController() (bool, int) { + if runtime.GOOS != "windows" { + return false, 0 + } + return isRunningOnDomainControllerWindows() +} diff --git a/cmd/cli/service_others.go b/cmd/cli/service_others.go index 056903c..954b228 100644 --- a/cmd/cli/service_others.go +++ b/cmd/cli/service_others.go @@ -18,3 +18,5 @@ func openLogFile(path string, flags int) (*os.File, error) { func hasLocalDnsServerRunning() bool { return false } func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil } + +func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 } diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index c4df5a5..fddb0ef 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -2,12 +2,18 @@ package cli import ( "os" + "reflect" "runtime" + "strconv" "strings" "syscall" "time" "unsafe" + "github.com/microsoft/wmi/pkg/base/host" + "github.com/microsoft/wmi/pkg/base/instance" + "github.com/microsoft/wmi/pkg/base/query" + "github.com/microsoft/wmi/pkg/constant" "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc/mgr" ) @@ -165,3 +171,57 @@ func hasLocalDnsServerRunning() bool { } } } + +func isRunningOnDomainControllerWindows() (bool, int) { + whost := host.NewWmiLocalHost() + q := query.NewWmiQuery("Win32_ComputerSystem") + instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q) + if err != nil { + mainLog.Load().Debug().Err(err).Msg("WMI query failed") + return false, 0 + } + if instances == nil { + mainLog.Load().Debug().Msg("WMI query returned nil instances") + return false, 0 + } + defer instances.Close() + + if len(instances) == 0 { + mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem") + return false, 0 + } + + val, err := instances[0].GetProperty("DomainRole") + if err != nil { + mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property") + return false, 0 + } + if val == nil { + mainLog.Load().Debug().Msg("DomainRole property is nil") + return false, 0 + } + + // Safely handle varied types: string or integer + var roleInt int + switch v := val.(type) { + case string: + // "4", "5", etc. + parsed, parseErr := strconv.Atoi(v) + if parseErr != nil { + mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v) + return false, 0 + } + roleInt = parsed + case int8, int16, int32, int64: + roleInt = int(reflect.ValueOf(v).Int()) + case uint8, uint16, uint32, uint64: + roleInt = int(reflect.ValueOf(v).Uint()) + default: + mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v) + return false, 0 + } + + // Check if role indicates a domain controller + isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController + return isDC, roleInt +}