mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Resolve "OS upstream failure / wrong default route"
This commit is contained in:
committed by
Cuong Manh Le
parent
49eb152d02
commit
a0c5062e3a
@@ -853,10 +853,12 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Msg("ctrld listener is ready")
|
||||
mainLog.Load().Debug().Msg("performing self-check")
|
||||
|
||||
lc := cfg.FirstListener()
|
||||
addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
|
||||
|
||||
mainLog.Load().Debug().Msgf("performing listener test, sending queries to %s", addr)
|
||||
|
||||
if err := selfCheckResolveDomain(context.TODO(), addr, "internal", selfCheckInternalTestDomain); err != nil {
|
||||
return false, status, err
|
||||
}
|
||||
|
||||
@@ -43,21 +43,42 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||
// Configuring the Dns server to forward queries to ctrld instead.
|
||||
if hasLocalDnsServerRunning() {
|
||||
mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders")
|
||||
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
oldForwardersContent, _ := os.ReadFile(file)
|
||||
mainLog.Load().Debug().Msgf("Using forwarders file: %s", file)
|
||||
|
||||
oldForwardersContent, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent))
|
||||
}
|
||||
|
||||
hasLocalIPv6Listener := needLocalIPv6Listener()
|
||||
mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status")
|
||||
|
||||
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
|
||||
if !hasLocalIPv6Listener {
|
||||
return false
|
||||
}
|
||||
return s == "::1"
|
||||
})
|
||||
mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list")
|
||||
|
||||
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Successfully wrote new forwarders file")
|
||||
}
|
||||
|
||||
oldForwarders := strings.Split(string(oldForwardersContent), ",")
|
||||
mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders")
|
||||
|
||||
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders")
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -229,7 +250,11 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
if len(value) > 0 {
|
||||
mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value)
|
||||
parsed := parseDNSServers(value)
|
||||
ns = append(ns, parsed...)
|
||||
for _, pns := range parsed {
|
||||
if !slices.Contains(ns, pns) {
|
||||
ns = append(ns, pns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -120,6 +120,7 @@ type prog struct {
|
||||
runningIface string
|
||||
requiredMultiNICsConfig bool
|
||||
adDomain string
|
||||
runningOnDomainController bool
|
||||
|
||||
selfUninstallMu sync.Mutex
|
||||
refusedQueryCount int
|
||||
@@ -276,6 +277,11 @@ func (p *prog) preRun() {
|
||||
|
||||
func (p *prog) postRun() {
|
||||
if !service.Interactive() {
|
||||
if runtime.GOOS == "windows" {
|
||||
isDC, roleInt := isRunningOnDomainController()
|
||||
p.runningOnDomainController = isDC
|
||||
mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt)
|
||||
}
|
||||
p.resetDNS(false, false)
|
||||
ns := ctrld.InitializeOsResolver(false)
|
||||
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
|
||||
@@ -1410,5 +1416,25 @@ func (p *prog) leakOnUpstreamFailure() bool {
|
||||
if router.Name() != "" {
|
||||
return false
|
||||
}
|
||||
// if we are running on ADDC, we should not leak on upstream failure
|
||||
if p.runningOnDomainController {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Domain controller role values from Win32_ComputerSystem
|
||||
// https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/win32-computersystem
|
||||
const (
|
||||
BackupDomainController = 4
|
||||
PrimaryDomainController = 5
|
||||
)
|
||||
|
||||
// isRunningOnDomainController checks if the current machine is a domain controller
|
||||
// by querying the DomainRole property from Win32_ComputerSystem via WMI.
|
||||
func isRunningOnDomainController() (bool, int) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return false, 0
|
||||
}
|
||||
return isRunningOnDomainControllerWindows()
|
||||
}
|
||||
|
||||
@@ -18,3 +18,5 @@ func openLogFile(path string, flags int) (*os.File, error) {
|
||||
func hasLocalDnsServerRunning() bool { return false }
|
||||
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
|
||||
|
||||
func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 }
|
||||
|
||||
@@ -2,12 +2,18 @@ package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
@@ -165,3 +171,57 @@ func hasLocalDnsServerRunning() bool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isRunningOnDomainControllerWindows() (bool, int) {
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("Win32_ComputerSystem")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("WMI query failed")
|
||||
return false, 0
|
||||
}
|
||||
if instances == nil {
|
||||
mainLog.Load().Debug().Msg("WMI query returned nil instances")
|
||||
return false, 0
|
||||
}
|
||||
defer instances.Close()
|
||||
|
||||
if len(instances) == 0 {
|
||||
mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem")
|
||||
return false, 0
|
||||
}
|
||||
|
||||
val, err := instances[0].GetProperty("DomainRole")
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property")
|
||||
return false, 0
|
||||
}
|
||||
if val == nil {
|
||||
mainLog.Load().Debug().Msg("DomainRole property is nil")
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Safely handle varied types: string or integer
|
||||
var roleInt int
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
// "4", "5", etc.
|
||||
parsed, parseErr := strconv.Atoi(v)
|
||||
if parseErr != nil {
|
||||
mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v)
|
||||
return false, 0
|
||||
}
|
||||
roleInt = parsed
|
||||
case int8, int16, int32, int64:
|
||||
roleInt = int(reflect.ValueOf(v).Int())
|
||||
case uint8, uint16, uint32, uint64:
|
||||
roleInt = int(reflect.ValueOf(v).Uint())
|
||||
default:
|
||||
mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v)
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Check if role indicates a domain controller
|
||||
isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController
|
||||
return isDC, roleInt
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user