Resolve "OS upstream failure / wrong default route"

This commit is contained in:
Alex Paguis
2025-02-22 11:51:07 +00:00
committed by Cuong Manh Le
parent 49eb152d02
commit a0c5062e3a
5 changed files with 118 additions and 3 deletions

View File

@@ -853,10 +853,12 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo
}
mainLog.Load().Debug().Msg("ctrld listener is ready")
mainLog.Load().Debug().Msg("performing self-check")
lc := cfg.FirstListener()
addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
mainLog.Load().Debug().Msgf("performing listener test, sending queries to %s", addr)
if err := selfCheckResolveDomain(context.TODO(), addr, "internal", selfCheckInternalTestDomain); err != nil {
return false, status, err
}

View File

@@ -43,21 +43,42 @@ func setDNS(iface *net.Interface, nameservers []string) error {
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
// Configuring the Dns server to forward queries to ctrld instead.
if hasLocalDnsServerRunning() {
mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders")
file := absHomeDir(windowsForwardersFilename)
oldForwardersContent, _ := os.ReadFile(file)
mainLog.Load().Debug().Msgf("Using forwarders file: %s", file)
oldForwardersContent, err := os.ReadFile(file)
if err != nil {
mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file")
} else {
mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent))
}
hasLocalIPv6Listener := needLocalIPv6Listener()
mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status")
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
if !hasLocalIPv6Listener {
return false
}
return s == "::1"
})
mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list")
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
} else {
mainLog.Load().Debug().Msg("Successfully wrote new forwarders file")
}
oldForwarders := strings.Split(string(oldForwardersContent), ",")
mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders")
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
} else {
mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders")
}
}
})
@@ -229,7 +250,11 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) {
if len(value) > 0 {
mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value)
parsed := parseDNSServers(value)
ns = append(ns, parsed...)
for _, pns := range parsed {
if !slices.Contains(ns, pns) {
ns = append(ns, pns)
}
}
}
}
}()

View File

@@ -120,6 +120,7 @@ type prog struct {
runningIface string
requiredMultiNICsConfig bool
adDomain string
runningOnDomainController bool
selfUninstallMu sync.Mutex
refusedQueryCount int
@@ -276,6 +277,11 @@ func (p *prog) preRun() {
func (p *prog) postRun() {
if !service.Interactive() {
if runtime.GOOS == "windows" {
isDC, roleInt := isRunningOnDomainController()
p.runningOnDomainController = isDC
mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt)
}
p.resetDNS(false, false)
ns := ctrld.InitializeOsResolver(false)
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
@@ -1410,5 +1416,25 @@ func (p *prog) leakOnUpstreamFailure() bool {
if router.Name() != "" {
return false
}
// if we are running on ADDC, we should not leak on upstream failure
if p.runningOnDomainController {
return false
}
return true
}
// Domain controller role values from Win32_ComputerSystem
// https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/win32-computersystem
const (
BackupDomainController = 4
PrimaryDomainController = 5
)
// isRunningOnDomainController checks if the current machine is a domain controller
// by querying the DomainRole property from Win32_ComputerSystem via WMI.
func isRunningOnDomainController() (bool, int) {
if runtime.GOOS != "windows" {
return false, 0
}
return isRunningOnDomainControllerWindows()
}

View File

@@ -18,3 +18,5 @@ func openLogFile(path string, flags int) (*os.File, error) {
func hasLocalDnsServerRunning() bool { return false }
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 }

View File

@@ -2,12 +2,18 @@ package cli
import (
"os"
"reflect"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"unsafe"
"github.com/microsoft/wmi/pkg/base/host"
"github.com/microsoft/wmi/pkg/base/instance"
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/constant"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc/mgr"
)
@@ -165,3 +171,57 @@ func hasLocalDnsServerRunning() bool {
}
}
}
func isRunningOnDomainControllerWindows() (bool, int) {
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("Win32_ComputerSystem")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
if err != nil {
mainLog.Load().Debug().Err(err).Msg("WMI query failed")
return false, 0
}
if instances == nil {
mainLog.Load().Debug().Msg("WMI query returned nil instances")
return false, 0
}
defer instances.Close()
if len(instances) == 0 {
mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem")
return false, 0
}
val, err := instances[0].GetProperty("DomainRole")
if err != nil {
mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property")
return false, 0
}
if val == nil {
mainLog.Load().Debug().Msg("DomainRole property is nil")
return false, 0
}
// Safely handle varied types: string or integer
var roleInt int
switch v := val.(type) {
case string:
// "4", "5", etc.
parsed, parseErr := strconv.Atoi(v)
if parseErr != nil {
mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v)
return false, 0
}
roleInt = parsed
case int8, int16, int32, int64:
roleInt = int(reflect.ValueOf(v).Int())
case uint8, uint16, uint32, uint64:
roleInt = int(reflect.ValueOf(v).Uint())
default:
mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v)
return false, 0
}
// Check if role indicates a domain controller
isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController
return isDC, roleInt
}