fix upgrade flow

set service on new run, fix duplicate args

set service on new run, fix duplicate args

revert startCmd in upgrade flow due to pin compat issues

make restart reset DNS like upgrade, add debugging to uninstall method

debugging

debugging

debugging

debugging

debugging WMI

remove stackexchange lib, use ms wmi pkg

debugging

debugging

set correct class

fix os reolver init issues

fix netadapter class

use os resolver instead of fetching default nameservers while already running

remove debug lines

fix lookup IP

fix lookup IP

fix lookup IP

fix lookup IP

fix dns namserver retries when not needed
This commit is contained in:
Alex
2025-01-30 05:09:51 -05:00
committed by Cuong Manh Le
parent e573a490c9
commit f7a6dbe39b
10 changed files with 221 additions and 260 deletions

View File

@@ -56,10 +56,12 @@ func getActiveDirectoryDomain() (string, error) {
defer log.SetOutput(os.Stderr)
whost := host.NewWmiLocalHost()
cs, err := hh.GetComputerSystem(whost)
if cs != nil {
defer cs.Close()
}
if err != nil {
return "", err
}
defer cs.Close()
pod, err := cs.GetPropertyPartOfDomain()
if err != nil {
return "", err

View File

@@ -126,7 +126,7 @@ func initCLI() {
rootCmd.CompletionOptions.HiddenDefaultCmd = true
initRunCmd()
startCmd, startCmdAlias := initStartCmd()
startCmd := initStartCmd()
stopCmd := initStopCmd()
restartCmd := initRestartCmd()
reloadCmd := initReloadCmd(restartCmd)
@@ -135,7 +135,7 @@ func initCLI() {
interfacesCmd := initInterfacesCmd()
initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd)
initClientsCmd()
initUpgradeCmd(startCmdAlias)
initUpgradeCmd()
initLogCmd()
}
@@ -243,10 +243,6 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
if err := s.Run(); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to start service")
}
// Configure Windows service failure actions
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
}
}()
}
writeDefaultConfig := !noConfigStart && configBase64 == ""
@@ -394,6 +390,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
}
}
}
// Configure Windows service failure actions
_ = ConfigureWindowsServiceFailureActions(ctrldServiceName)
})
p.onStopped = append(p.onStopped, func() {
for _, lc := range p.cfg.Listener {
@@ -1615,22 +1613,27 @@ var errRequiredDeactivationPin = errors.New("deactivation pin is required to sto
// checkDeactivationPin validates if the deactivation pin matches one in ControlD config.
func checkDeactivationPin(s service.Service, stopCh chan struct{}) error {
mainLog.Load().Debug().Msg("Checking deactivation pin")
dir, err := socketDir()
if err != nil {
mainLog.Load().Err(err).Msg("could not check deactivation pin")
return err
}
mainLog.Load().Debug().Msg("Creating control client")
var cc *controlClient
if s == nil {
cc = newSocketControlClientMobile(dir, stopCh)
} else {
cc = newSocketControlClient(context.TODO(), s, dir)
}
mainLog.Load().Debug().Msg("Control client done")
if cc == nil {
return nil // ctrld is not running.
}
data, _ := json.Marshal(&deactivationRequest{Pin: deactivationPin})
resp, _ := cc.post(deactivationPath, bytes.NewReader(data))
mainLog.Load().Debug().Msg("Posting deactivation request")
resp, err := cc.post(deactivationPath, bytes.NewReader(data))
mainLog.Load().Debug().Msg("Posting deactivation request done")
if resp != nil {
switch resp.StatusCode {
case http.StatusBadRequest:
@@ -1694,7 +1697,7 @@ func curCdUID() string {
if s, _ := newService(&prog{}, svcConfig); s != nil {
// Configure Windows service failure actions
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
mainLog.Load().Debug().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
}
if dir, _ := socketDir(); dir != "" {
cc := newSocketControlClient(context.TODO(), s, dir)
@@ -1777,6 +1780,7 @@ func resetDnsNoLog(p *prog) {
func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task {
return task{func() error {
if iface == "" {
mainLog.Load().Debug().Msg("no iface, skipping resetDnsTask")
return nil
}
// Always reset DNS first, ensuring DNS setting is in a good state.

View File

@@ -164,7 +164,7 @@ func initRunCmd() *cobra.Command {
return runCmd
}
func initStartCmd() (*cobra.Command, *cobra.Command) {
func initStartCmd() *cobra.Command {
startCmd := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
@@ -391,7 +391,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
tasks := []task{
{s.Stop, false, "Stop"},
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Generate NextDNS config"},
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"},
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
resetDnsTask(p, s, isCtrldInstalled, currentIface),
{func() error {
@@ -534,7 +534,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
rootCmd.AddCommand(startCmdAlias)
return startCmd, startCmdAlias
return startCmd
}
func initStopCmd() *cobra.Command {
@@ -647,6 +647,15 @@ func initRestartCmd() *cobra.Command {
mainLog.Load().Warn().Msg("service not installed")
return
}
if iface == "" {
iface = "auto"
}
p.preRun()
if ir := runningIface(s); ir != nil {
p.runningIface = ir.Name
p.requiredMultiNICsConfig = ir.All
}
initLogging()
if cdMode {
@@ -656,11 +665,53 @@ func initRestartCmd() *cobra.Command {
if ir := runningIface(s); ir != nil {
iface = ir.Name
}
tasks := []task{
{s.Stop, false, "Stop"},
{s.Start, true, "Start"},
doRestart := func() bool {
tasks := []task{
{s.Stop, true, "Stop"},
{func() error {
p.router.Cleanup()
p.resetDNS()
return nil
}, false, "Cleanup"},
{func() error {
time.Sleep(time.Second * 1)
return nil
}, false, "Waiting for service to stop"},
}
if doTasks(tasks) {
if router.WaitProcessExited() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
loop:
for {
select {
case <-ctx.Done():
mainLog.Load().Error().Msg("timeout while waiting for service to stop")
break loop
default:
}
time.Sleep(time.Second)
if status, _ := s.Status(); status == service.StatusStopped {
break
}
}
}
} else {
return false
}
tasks = []task{
{s.Start, true, "Start"},
}
return doTasks(tasks)
}
if doTasks(tasks) {
if doRestart() {
dir, err := socketDir()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server")
@@ -668,11 +719,13 @@ func initRestartCmd() *cobra.Command {
}
cc := newSocketControlClient(context.TODO(), s, dir)
if cc == nil {
mainLog.Load().Notice().Msg("Service was not restarted")
mainLog.Load().Error().Msg("Could not complete service restart")
os.Exit(1)
}
_, _ = cc.post(ifacePath, nil)
mainLog.Load().Notice().Msg("Service restarted")
} else {
mainLog.Load().Error().Msg("Service restart failed")
}
},
}
@@ -1049,7 +1102,7 @@ func initClientsCmd() *cobra.Command {
return clientsCmd
}
func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
func initUpgradeCmd() *cobra.Command {
const (
upgradeChannelDev = "dev"
upgradeChannelProd = "prod"
@@ -1087,6 +1140,14 @@ func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
mainLog.Load().Error().Msg(err.Error())
return
}
if iface == "" {
iface = "auto"
}
p.preRun()
if ir := runningIface(s); ir != nil {
p.runningIface = ir.Name
p.requiredMultiNICsConfig = ir.All
}
svcInstalled := true
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
@@ -1121,23 +1182,56 @@ func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
mainLog.Load().Fatal().Err(err).Msg("failed to update current binary")
}
// we run the actual commands to make sure all the logic we want is executed
doRestart := func() bool {
// run the start command so that we reinit the service
// this is to fix the non restarting options on windows for existing clients
// we have to reset os.Args, since other commands use it.
curCdUID := curCdUID()
startArgs := []string{}
os.Args = []string{"ctrld", "start"}
if curCdUID != "" {
startArgs = append(startArgs, fmt.Sprintf("--cd=%s", curCdUID))
os.Args = append(os.Args, fmt.Sprintf("--cd=%s", curCdUID))
if !svcInstalled {
return true
}
startCmd.Run(startCmd, startArgs)
tasks := []task{
{s.Stop, true, "Stop"},
{func() error {
p.router.Cleanup()
p.resetDNS()
return nil
}, false, "Cleanup"},
{func() error {
time.Sleep(time.Second * 1)
return nil
}, false, "Waiting for service to stop"},
}
if doTasks(tasks) {
return true
if router.WaitProcessExited() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
loop:
for {
select {
case <-ctx.Done():
mainLog.Load().Error().Msg("timeout while waiting for service to stop")
break loop
default:
}
time.Sleep(time.Second)
if status, _ := s.Status(); status == service.StatusStopped {
break
}
}
}
}
tasks = []task{
{s.Start, true, "Start"},
}
if doTasks(tasks) {
if dir, err := socketDir(); err == nil {
if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil {
_, _ = cc.post(ifacePath, nil)
return true
}
}
}
return false
}
if svcInstalled {
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")

View File

@@ -40,11 +40,13 @@ func validInterfaces() []string {
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("MSFT_NetAdapter")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
if instances != nil {
defer instances.Close()
}
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
return nil
}
defer instances.Close()
var adapters []string
for _, i := range instances {
adapter, err := netadapter.NewNetworkAdapter(i)

View File

@@ -792,6 +792,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
func (p *prog) resetDNS() {
if p.runningIface == "" {
mainLog.Load().Debug().Msg("no running interface, skipping resetDNS")
return
}
// See corresponding comments in (*prog).setDNS function.

View File

@@ -52,7 +52,21 @@ func ConfigureWindowsServiceFailureActions(serviceName string) error {
}
defer s.Close()
// restart 3 times with a delay of 2 seconds
// 1. Retrieve the current config
cfg, err := s.Config()
if err != nil {
return err
}
// 2. Update the Description
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
// 3. Apply the updated config
if err := s.UpdateConfig(cfg); err != nil {
return err
}
// Then proceed with existing actions, e.g. setting failure actions
actions := []mgr.RecoveryAction{
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds

1
go.mod
View File

@@ -45,7 +45,6 @@ require (
require (
aead.dev/minisign v0.2.0 // indirect
github.com/StackExchange/wmi v1.2.1 // indirect
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect

6
go.sum
View File

@@ -42,8 +42,6 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA=
github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE=
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
@@ -95,7 +93,6 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
@@ -452,7 +449,6 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -492,8 +488,6 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=

View File

@@ -12,7 +12,6 @@ import (
"time"
"unsafe"
"github.com/StackExchange/wmi"
"github.com/microsoft/wmi/pkg/base/host"
"github.com/microsoft/wmi/pkg/base/instance"
"github.com/microsoft/wmi/pkg/base/query"
@@ -24,9 +23,9 @@ import (
)
const (
maxRetries = 5
retryDelay = 1 * time.Second
defaultTimeout = 5 * time.Second
maxDNSAdapterRetries = 5
retryDelayDNSAdapter = 1 * time.Second
defaultDNSAdapterTimeout = 10 * time.Second
minDNSServers = 1 // Minimum number of DNS servers we want to find
NetSetupUnknown uint32 = 0
NetSetupWorkgroup uint32 = 1
@@ -57,19 +56,18 @@ func dnsFns() []dnsFn {
}
func dnsFromAdapter() []string {
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout)
defer cancel()
var ns []string
var err error
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
for i := 0; i < maxRetries; i++ {
for i := 0; i < maxDNSAdapterRetries; i++ {
if ctx.Err() != nil {
Log(context.Background(), logger.Debug(),
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
@@ -80,12 +78,18 @@ func dnsFromAdapter() []string {
if err == nil && len(ns) >= minDNSServers {
if i > 0 {
Log(context.Background(), logger.Debug(),
"Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns))
"Successfully got DNS servers after %d attempts, found %d servers",
i+1, len(ns))
}
return ns
}
// Log the specific failure reason
// if osResolver is not initialized, this is likely a command line run
// and ctrld is already on the interface, abort retries
if or == nil {
return ns
}
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get DNS servers, attempt %d: %v", i+1, err)
@@ -97,17 +101,16 @@ func dnsFromAdapter() []string {
select {
case <-ctx.Done():
return nil
case <-time.After(retryDelay):
case <-time.After(retryDelayDNSAdapter):
}
}
Log(context.Background(), logger.Debug(),
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxRetries)
return ns // Return whatever we got, even if insufficient
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries)
return ns
}
func getDNSServers(ctx context.Context) ([]string, error) {
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
@@ -133,25 +136,18 @@ func getDNSServers(ctx context.Context) ([]string, error) {
var dcServers []string
isDomain := checkDomainJoined()
if isDomain {
domainName, err := getLocalADDomain()
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get local AD domain: %v", err)
} else {
// Load netapi32.dll
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
dsDcName := netapi32.NewProc("DsGetDcNameW")
var info *DomainControllerInfo
flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME)
flags := uint32(DS_RETURN_DNS_NAME |
DS_IP_REQUIRED |
DS_IS_DNS_NAME)
// Convert domain name to UTF16 pointer
domainUTF16, err := windows.UTF16PtrFromString(domainName)
if err != nil {
Log(context.Background(), logger.Debug(),
@@ -190,15 +186,12 @@ func getDNSServers(ctx context.Context) ([]string, error) {
} else if info != nil {
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
// Get DC address
if info.DomainControllerAddress != nil {
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
Log(context.Background(), logger.Debug(),
"Found domain controller address: %s", dcAddr)
// Try to resolve DC
if ip := net.ParseIP(dcAddr); ip != nil {
dcServers = append(dcServers, ip.String())
Log(context.Background(), logger.Debug(),
@@ -210,7 +203,6 @@ func getDNSServers(ctx context.Context) ([]string, error) {
}
}
}
}
}
@@ -278,28 +270,26 @@ func getDNSServers(ctx context.Context) ([]string, error) {
}
ipStr := ip.String()
logger := logger.Debug().
l := logger.Debug().
Str("ip", ipStr).
Str("adapter", aa.FriendlyName())
if ip.IsLoopback() {
logger.Msg("Skipping loopback IP")
l.Msg("Skipping loopback IP")
continue
}
if seen[ipStr] {
logger.Msg("Skipping duplicate IP")
l.Msg("Skipping duplicate IP")
continue
}
if _, ok := addressMap[ipStr]; ok {
logger.Msg("Skipping local interface IP")
l.Msg("Skipping local interface IP")
continue
}
seen[ipStr] = true
ns = append(ns, ipStr)
logger.Msg("Added DNS server")
l.Msg("Added DNS server")
}
}
@@ -330,7 +320,6 @@ func nameserversFromResolvconf() []string {
// checkDomainJoined checks if the machine is joined to an Active Directory domain
// Returns whether it's domain joined and the domain name if available
func checkDomainJoined() bool {
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
@@ -348,9 +337,10 @@ func checkDomainJoined() bool {
domainName := windows.UTF16PtrToString(domain)
Log(context.Background(), logger.Debug(),
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", domainName, status)
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)",
domainName, status)
// Consider both traditional and cloud domains as valid domain joins
// Consider domain or cloud domain as domain-joined
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
Log(context.Background(), logger.Debug(),
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
@@ -362,36 +352,44 @@ func checkDomainJoined() bool {
return isDomain
}
// Win32_ComputerSystem is the minimal struct for WMI query
type Win32_ComputerSystem struct {
Domain string
}
// getLocalADDomain tries to detect the AD domain in two ways:
// 1. USERDNSDOMAIN env var (often set in AD logon sessions)
// 2. WMI Win32_ComputerSystem.Domain
// getLocalADDomain uses Microsoft's WMI wrappers (github.com/microsoft/wmi/pkg/*)
// to query the Domain field from Win32_ComputerSystem instead of a direct go-ole call.
func getLocalADDomain() (string, error) {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
// 1) Check environment variable
envDomain := os.Getenv("USERDNSDOMAIN")
if envDomain != "" {
return strings.TrimSpace(envDomain), nil
}
// 2) Check WMI (requires Windows + admin privileges or sufficient access)
var result []Win32_ComputerSystem
err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result)
// 2) Query WMI via the microsoft/wmi library
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("Win32_ComputerSystem")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
if instances != nil {
defer instances.Close()
}
if err != nil {
return "", fmt.Errorf("WMI query failed: %v", err)
}
if len(result) == 0 {
// If no results, return an error
if len(instances) == 0 {
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
}
domain := strings.TrimSpace(result[0].Domain)
if domain == "" {
// We only care about the first row
domainVal, err := instances[0].GetProperty("Domain")
if err != nil {
return "", fmt.Errorf("machine does not appear to have a domain set: %v", err)
}
domainName := strings.TrimSpace(fmt.Sprintf("%v", domainVal))
if domainName == "" {
return "", fmt.Errorf("machine does not appear to have a domain set")
}
return domain, nil
return domainName, nil
}
// validInterfaces returns a list of all physical interfaces.
@@ -410,12 +408,14 @@ func validInterfaces() map[string]struct{} {
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("MSFT_NetAdapter")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
if instances != nil {
defer instances.Close()
}
if err != nil {
Log(context.Background(), logger.Warn(),
"failed to get wmi network adapter: %v", err)
return nil
}
defer instances.Close()
var adapters []string
for _, i := range instances {
adapter, err := netadapter.NewNetworkAdapter(i)
@@ -470,5 +470,4 @@ func validInterfaces() map[string]struct{} {
m[ifaceName] = struct{}{}
}
return m
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
"slices"
@@ -11,10 +12,9 @@ import (
"sync"
"sync/atomic"
"time"
"io"
"github.com/rs/zerolog"
"github.com/miekg/dns"
"github.com/rs/zerolog"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
)
@@ -48,11 +48,13 @@ const (
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
// or is the Resolver used for ResolverTypeOS.
var or = newResolverWithNameserver(defaultNameservers())
var localResolver = newLocalResolver()
var (
resolverMutex sync.Mutex
or *osResolver
)
func newLocalResolver() Resolver {
var nss []string
for _, addr := range Rfc1918Addresses() {
@@ -86,7 +88,6 @@ func availableNameservers() []string {
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
machineIPsMap := make(map[string]struct{}, len(regularIPs))
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
@@ -129,6 +130,9 @@ func availableNameservers() []string {
// calling this function.
func InitializeOsResolver() []string {
ns := initializeOsResolver(availableNameservers())
resolverMutex.Lock()
defer resolverMutex.Unlock()
or = newResolverWithNameserver(ns)
return ns
}
@@ -138,6 +142,7 @@ func InitializeOsResolver() []string {
// - First available LAN servers are saved and store.
// - Later calls, if no LAN servers available, the saved servers above will be used.
func initializeOsResolver(servers []string) []string {
var lanNss, publicNss []string
// First categorize servers
@@ -154,171 +159,13 @@ func initializeOsResolver(servers []string) []string {
}
}
// Store initial servers immediately
if len(lanNss) > 0 {
or.initializedLanServers.CompareAndSwap(nil, &lanNss)
or.lanServers.Store(&lanNss)
}
if len(publicNss) == 0 {
publicNss = []string{controldPublicDnsWithPort}
}
or.publicServers.Store(&publicNss)
// no longer testing servers in the background
// if DCHP nameservers are not working, this is outside of our control
// // Test servers in background and remove failures
// go func() {
// // Test servers in parallel but maintain order
// type result struct {
// index int
// server string
// valid bool
// }
// testServers := func(servers []string) []string {
// if len(servers) == 0 {
// return nil
// }
// results := make(chan result, len(servers))
// var wg sync.WaitGroup
// for i, server := range servers {
// wg.Add(1)
// go func(idx int, s string) {
// defer wg.Done()
// results <- result{
// index: idx,
// server: s,
// valid: testNameServerFn(s),
// }
// }(i, server)
// }
// go func() {
// wg.Wait()
// close(results)
// }()
// // Collect results maintaining original order
// validServers := make([]string, 0, len(servers))
// ordered := make([]result, 0, len(servers))
// for r := range results {
// ordered = append(ordered, r)
// }
// slices.SortFunc(ordered, func(a, b result) int {
// return a.index - b.index
// })
// for _, r := range ordered {
// if r.valid {
// validServers = append(validServers, r.server)
// } else {
// ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing")
// }
// }
// return validServers
// }
// // Test and update LAN servers
// if validLanNss := testServers(lanNss); len(validLanNss) > 0 {
// or.lanServers.Store(&validLanNss)
// }
// // Test and update public servers
// validPublicNss := testServers(publicNss)
// if len(validPublicNss) == 0 {
// validPublicNss = []string{controldPublicDnsWithPort}
// }
// or.publicServers.Store(&validPublicNss)
// }()
return slices.Concat(lanNss, publicNss)
}
// // testNameserverFn sends a test query to DNS nameserver to check if the server is available.
// var testNameServerFn = testNameserver
// // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
// func testNameserver(addr string) bool {
// // Skip link-local addresses without scope IDs and deprecated site-local addresses
// if ip, err := netip.ParseAddr(addr); err == nil {
// if ip.Is6() {
// if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") {
// ProxyLogger.Load().Debug().
// Str("nameserver", addr).
// Msg("skipping link-local IPv6 address without scope ID")
// return false
// }
// // Skip deprecated site-local addresses (fec0::/10)
// if strings.HasPrefix(ip.String(), "fec0:") {
// ProxyLogger.Load().Debug().
// Str("nameserver", addr).
// Msg("skipping deprecated site-local IPv6 address")
// return false
// }
// }
// }
// ProxyLogger.Load().Debug().
// Str("input_addr", addr).
// Msg("testing nameserver")
// // Handle both IPv4 and IPv6 addresses
// serverAddr := addr
// host, port, err := net.SplitHostPort(addr)
// if err != nil {
// // No port in address, add default port 53
// serverAddr = net.JoinHostPort(addr, "53")
// } else if port == "" {
// // Has split markers but empty port
// serverAddr = net.JoinHostPort(host, "53")
// }
// ProxyLogger.Load().Debug().
// Str("server_addr", serverAddr).
// Msg("using server address")
// // Test domains that are likely to exist and respond quickly
// testDomains := []struct {
// name string
// qtype uint16
// }{
// {".", dns.TypeNS}, // Root NS query - should always work
// {"controld.com.", dns.TypeA}, // Fallback to a reliable domain
// }
// client := &dns.Client{
// Timeout: 2 * time.Second,
// Net: "udp",
// }
// // Try each test query until one succeeds
// for _, test := range testDomains {
// msg := new(dns.Msg)
// msg.SetQuestion(test.name, test.qtype)
// msg.RecursionDesired = true
// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
// resp, _, err := client.ExchangeContext(ctx, msg, serverAddr)
// cancel()
// if err == nil && resp != nil {
// return true
// }
// ProxyLogger.Load().Error().
// Err(err).
// Str("nameserver", serverAddr).
// Str("test_domain", test.name).
// Str("query_type", dns.TypeToString[test.qtype]).
// Msg("DNS availability test failed")
// }
// return false
// }
// Resolver is the interface that wraps the basic DNS operations.
//
// Resolve resolves the DNS query, return the result and the corresponding error.
@@ -339,6 +186,9 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
case ResolverTypeDOQ:
return &doqResolver{uc: uc}, nil
case ResolverTypeOS:
if or == nil {
or = newResolverWithNameserver(defaultNameservers())
}
return or, nil
case ResolverTypeLegacy:
return &legacyResolver{uc: uc}, nil
@@ -351,9 +201,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
}
type osResolver struct {
initializedLanServers atomic.Pointer[[]string]
lanServers atomic.Pointer[[]string]
publicServers atomic.Pointer[[]string]
lanServers atomic.Pointer[[]string]
publicServers atomic.Pointer[[]string]
}
type osResolverResult struct {
@@ -504,7 +353,10 @@ func LookupIP(domain string) []string {
}
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
nss := defaultNameservers()
if or == nil {
or = newResolverWithNameserver(defaultNameservers())
}
nss := *or.lanServers.Load()
if withBootstrapDNS {
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
}