diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go index 3f9fa17..66180a9 100644 --- a/cmd/cli/ad_windows.go +++ b/cmd/cli/ad_windows.go @@ -56,10 +56,12 @@ func getActiveDirectoryDomain() (string, error) { defer log.SetOutput(os.Stderr) whost := host.NewWmiLocalHost() cs, err := hh.GetComputerSystem(whost) + if cs != nil { + defer cs.Close() + } if err != nil { return "", err } - defer cs.Close() pod, err := cs.GetPropertyPartOfDomain() if err != nil { return "", err diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 223f14e..af5bb75 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -126,7 +126,7 @@ func initCLI() { rootCmd.CompletionOptions.HiddenDefaultCmd = true initRunCmd() - startCmd, startCmdAlias := initStartCmd() + startCmd := initStartCmd() stopCmd := initStopCmd() restartCmd := initRestartCmd() reloadCmd := initReloadCmd(restartCmd) @@ -135,7 +135,7 @@ func initCLI() { interfacesCmd := initInterfacesCmd() initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd) initClientsCmd() - initUpgradeCmd(startCmdAlias) + initUpgradeCmd() initLogCmd() } @@ -243,10 +243,6 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if err := s.Run(); err != nil { mainLog.Load().Error().Err(err).Msg("failed to start service") } - // Configure Windows service failure actions - if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { - mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) - } }() } writeDefaultConfig := !noConfigStart && configBase64 == "" @@ -394,6 +390,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } } } + // Configure Windows service failure actions + _ = ConfigureWindowsServiceFailureActions(ctrldServiceName) }) p.onStopped = append(p.onStopped, func() { for _, lc := range p.cfg.Listener { @@ -1615,22 +1613,27 @@ var errRequiredDeactivationPin = errors.New("deactivation pin is required to sto // checkDeactivationPin validates if the deactivation pin matches one in ControlD config. func checkDeactivationPin(s service.Service, stopCh chan struct{}) error { + mainLog.Load().Debug().Msg("Checking deactivation pin") dir, err := socketDir() if err != nil { mainLog.Load().Err(err).Msg("could not check deactivation pin") return err } + mainLog.Load().Debug().Msg("Creating control client") var cc *controlClient if s == nil { cc = newSocketControlClientMobile(dir, stopCh) } else { cc = newSocketControlClient(context.TODO(), s, dir) } + mainLog.Load().Debug().Msg("Control client done") if cc == nil { return nil // ctrld is not running. } data, _ := json.Marshal(&deactivationRequest{Pin: deactivationPin}) - resp, _ := cc.post(deactivationPath, bytes.NewReader(data)) + mainLog.Load().Debug().Msg("Posting deactivation request") + resp, err := cc.post(deactivationPath, bytes.NewReader(data)) + mainLog.Load().Debug().Msg("Posting deactivation request done") if resp != nil { switch resp.StatusCode { case http.StatusBadRequest: @@ -1694,7 +1697,7 @@ func curCdUID() string { if s, _ := newService(&prog{}, svcConfig); s != nil { // Configure Windows service failure actions if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { - mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) + mainLog.Load().Debug().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) } if dir, _ := socketDir(); dir != "" { cc := newSocketControlClient(context.TODO(), s, dir) @@ -1777,6 +1780,7 @@ func resetDnsNoLog(p *prog) { func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task { return task{func() error { if iface == "" { + mainLog.Load().Debug().Msg("no iface, skipping resetDnsTask") return nil } // Always reset DNS first, ensuring DNS setting is in a good state. diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index f3555e5..70d4467 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -164,7 +164,7 @@ func initRunCmd() *cobra.Command { return runCmd } -func initStartCmd() (*cobra.Command, *cobra.Command) { +func initStartCmd() *cobra.Command { startCmd := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { checkHasElevatedPrivilege() @@ -391,7 +391,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c tasks := []task{ {s.Stop, false, "Stop"}, - {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Generate NextDNS config"}, + {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"}, {func() error { return ensureUninstall(s) }, false, "Ensure uninstall"}, resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { @@ -534,7 +534,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) rootCmd.AddCommand(startCmdAlias) - return startCmd, startCmdAlias + return startCmd } func initStopCmd() *cobra.Command { @@ -647,6 +647,15 @@ func initRestartCmd() *cobra.Command { mainLog.Load().Warn().Msg("service not installed") return } + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } + initLogging() if cdMode { @@ -656,11 +665,53 @@ func initRestartCmd() *cobra.Command { if ir := runningIface(s); ir != nil { iface = ir.Name } - tasks := []task{ - {s.Stop, false, "Stop"}, - {s.Start, true, "Start"}, + + doRestart := func() bool { + tasks := []task{ + {s.Stop, true, "Stop"}, + {func() error { + p.router.Cleanup() + p.resetDNS() + return nil + }, false, "Cleanup"}, + {func() error { + time.Sleep(time.Second * 1) + return nil + }, false, "Waiting for service to stop"}, + } + if doTasks(tasks) { + + if router.WaitProcessExited() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + loop: + for { + select { + case <-ctx.Done(): + mainLog.Load().Error().Msg("timeout while waiting for service to stop") + break loop + default: + } + time.Sleep(time.Second) + if status, _ := s.Status(); status == service.StatusStopped { + break + } + } + } + } else { + return false + } + + tasks = []task{ + {s.Start, true, "Start"}, + } + + return doTasks(tasks) + } - if doTasks(tasks) { + + if doRestart() { dir, err := socketDir() if err != nil { mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server") @@ -668,11 +719,13 @@ func initRestartCmd() *cobra.Command { } cc := newSocketControlClient(context.TODO(), s, dir) if cc == nil { - mainLog.Load().Notice().Msg("Service was not restarted") + mainLog.Load().Error().Msg("Could not complete service restart") os.Exit(1) } _, _ = cc.post(ifacePath, nil) mainLog.Load().Notice().Msg("Service restarted") + } else { + mainLog.Load().Error().Msg("Service restart failed") } }, } @@ -1049,7 +1102,7 @@ func initClientsCmd() *cobra.Command { return clientsCmd } -func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command { +func initUpgradeCmd() *cobra.Command { const ( upgradeChannelDev = "dev" upgradeChannelProd = "prod" @@ -1087,6 +1140,14 @@ func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command { mainLog.Load().Error().Msg(err.Error()) return } + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } svcInstalled := true if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { @@ -1121,23 +1182,56 @@ func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command { mainLog.Load().Fatal().Err(err).Msg("failed to update current binary") } - // we run the actual commands to make sure all the logic we want is executed doRestart := func() bool { - - // run the start command so that we reinit the service - // this is to fix the non restarting options on windows for existing clients - // we have to reset os.Args, since other commands use it. - curCdUID := curCdUID() - startArgs := []string{} - os.Args = []string{"ctrld", "start"} - if curCdUID != "" { - startArgs = append(startArgs, fmt.Sprintf("--cd=%s", curCdUID)) - os.Args = append(os.Args, fmt.Sprintf("--cd=%s", curCdUID)) + if !svcInstalled { + return true } - startCmd.Run(startCmd, startArgs) + tasks := []task{ + {s.Stop, true, "Stop"}, + {func() error { + p.router.Cleanup() + p.resetDNS() + return nil + }, false, "Cleanup"}, + {func() error { + time.Sleep(time.Second * 1) + return nil + }, false, "Waiting for service to stop"}, + } + if doTasks(tasks) { - return true + if router.WaitProcessExited() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + loop: + for { + select { + case <-ctx.Done(): + mainLog.Load().Error().Msg("timeout while waiting for service to stop") + break loop + default: + } + time.Sleep(time.Second) + if status, _ := s.Status(); status == service.StatusStopped { + break + } + } + } + } + + tasks = []task{ + {s.Start, true, "Start"}, + } + if doTasks(tasks) { + if dir, err := socketDir(); err == nil { + if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil { + _, _ = cc.post(ifacePath, nil) + return true + } + } + } + return false } if svcInstalled { mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index 6290a1c..bed06b5 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -40,11 +40,13 @@ func validInterfaces() []string { whost := host.NewWmiLocalHost() q := query.NewWmiQuery("MSFT_NetAdapter") instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) + if instances != nil { + defer instances.Close() + } if err != nil { mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter") return nil } - defer instances.Close() var adapters []string for _, i := range instances { adapter, err := netadapter.NewNetworkAdapter(i) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index f8147eb..8390680 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -792,6 +792,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces func (p *prog) resetDNS() { if p.runningIface == "" { + mainLog.Load().Debug().Msg("no running interface, skipping resetDNS") return } // See corresponding comments in (*prog).setDNS function. diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index 4d3d281..6e3bd82 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -52,7 +52,21 @@ func ConfigureWindowsServiceFailureActions(serviceName string) error { } defer s.Close() - // restart 3 times with a delay of 2 seconds + // 1. Retrieve the current config + cfg, err := s.Config() + if err != nil { + return err + } + + // 2. Update the Description + cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy" + + // 3. Apply the updated config + if err := s.UpdateConfig(cfg); err != nil { + return err + } + + // Then proceed with existing actions, e.g. setting failure actions actions := []mgr.RecoveryAction{ {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds diff --git a/go.mod b/go.mod index e570bae..635261f 100644 --- a/go.mod +++ b/go.mod @@ -45,7 +45,6 @@ require ( require ( aead.dev/minisign v0.2.0 // indirect - github.com/StackExchange/wmi v1.2.1 // indirect github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/go.sum b/go.sum index f2d5ff9..2ac97af 100644 --- a/go.sum +++ b/go.sum @@ -42,8 +42,6 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= -github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA= -github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE= github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= @@ -95,7 +93,6 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= @@ -452,7 +449,6 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -492,8 +488,6 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= diff --git a/nameservers_windows.go b/nameservers_windows.go index c71e065..54fb8b6 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -12,7 +12,6 @@ import ( "time" "unsafe" - "github.com/StackExchange/wmi" "github.com/microsoft/wmi/pkg/base/host" "github.com/microsoft/wmi/pkg/base/instance" "github.com/microsoft/wmi/pkg/base/query" @@ -24,9 +23,9 @@ import ( ) const ( - maxRetries = 5 - retryDelay = 1 * time.Second - defaultTimeout = 5 * time.Second + maxDNSAdapterRetries = 5 + retryDelayDNSAdapter = 1 * time.Second + defaultDNSAdapterTimeout = 10 * time.Second minDNSServers = 1 // Minimum number of DNS servers we want to find NetSetupUnknown uint32 = 0 NetSetupWorkgroup uint32 = 1 @@ -57,19 +56,18 @@ func dnsFns() []dnsFn { } func dnsFromAdapter() []string { - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout) defer cancel() var ns []string var err error - //load the logger logger := zerolog.New(io.Discard) if ProxyLogger.Load() != nil { logger = *ProxyLogger.Load() } - for i := 0; i < maxRetries; i++ { + for i := 0; i < maxDNSAdapterRetries; i++ { if ctx.Err() != nil { Log(context.Background(), logger.Debug(), "dnsFromAdapter lookup cancelled or timed out, attempt %d", i) @@ -80,12 +78,18 @@ func dnsFromAdapter() []string { if err == nil && len(ns) >= minDNSServers { if i > 0 { Log(context.Background(), logger.Debug(), - "Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns)) + "Successfully got DNS servers after %d attempts, found %d servers", + i+1, len(ns)) } return ns } - // Log the specific failure reason + // if osResolver is not initialized, this is likely a command line run + // and ctrld is already on the interface, abort retries + if or == nil { + return ns + } + if err != nil { Log(context.Background(), logger.Debug(), "Failed to get DNS servers, attempt %d: %v", i+1, err) @@ -97,17 +101,16 @@ func dnsFromAdapter() []string { select { case <-ctx.Done(): return nil - case <-time.After(retryDelay): + case <-time.After(retryDelayDNSAdapter): } } Log(context.Background(), logger.Debug(), - "Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxRetries) - return ns // Return whatever we got, even if insufficient + "Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries) + return ns } func getDNSServers(ctx context.Context) ([]string, error) { - //load the logger logger := zerolog.New(io.Discard) if ProxyLogger.Load() != nil { logger = *ProxyLogger.Load() @@ -133,25 +136,18 @@ func getDNSServers(ctx context.Context) ([]string, error) { var dcServers []string isDomain := checkDomainJoined() if isDomain { - domainName, err := getLocalADDomain() if err != nil { Log(context.Background(), logger.Debug(), "Failed to get local AD domain: %v", err) - } else { - // Load netapi32.dll netapi32 := windows.NewLazySystemDLL("netapi32.dll") dsDcName := netapi32.NewProc("DsGetDcNameW") var info *DomainControllerInfo + flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME) - flags := uint32(DS_RETURN_DNS_NAME | - DS_IP_REQUIRED | - DS_IS_DNS_NAME) - - // Convert domain name to UTF16 pointer domainUTF16, err := windows.UTF16PtrFromString(domainName) if err != nil { Log(context.Background(), logger.Debug(), @@ -190,15 +186,12 @@ func getDNSServers(ctx context.Context) ([]string, error) { } else if info != nil { defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info))) - // Get DC address if info.DomainControllerAddress != nil { dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress) dcAddr = strings.TrimPrefix(dcAddr, "\\\\") - Log(context.Background(), logger.Debug(), "Found domain controller address: %s", dcAddr) - // Try to resolve DC if ip := net.ParseIP(dcAddr); ip != nil { dcServers = append(dcServers, ip.String()) Log(context.Background(), logger.Debug(), @@ -210,7 +203,6 @@ func getDNSServers(ctx context.Context) ([]string, error) { } } } - } } @@ -278,28 +270,26 @@ func getDNSServers(ctx context.Context) ([]string, error) { } ipStr := ip.String() - logger := logger.Debug(). + l := logger.Debug(). Str("ip", ipStr). Str("adapter", aa.FriendlyName()) if ip.IsLoopback() { - logger.Msg("Skipping loopback IP") + l.Msg("Skipping loopback IP") continue } - if seen[ipStr] { - logger.Msg("Skipping duplicate IP") + l.Msg("Skipping duplicate IP") continue } - if _, ok := addressMap[ipStr]; ok { - logger.Msg("Skipping local interface IP") + l.Msg("Skipping local interface IP") continue } seen[ipStr] = true ns = append(ns, ipStr) - logger.Msg("Added DNS server") + l.Msg("Added DNS server") } } @@ -330,7 +320,6 @@ func nameserversFromResolvconf() []string { // checkDomainJoined checks if the machine is joined to an Active Directory domain // Returns whether it's domain joined and the domain name if available func checkDomainJoined() bool { - //load the logger logger := zerolog.New(io.Discard) if ProxyLogger.Load() != nil { logger = *ProxyLogger.Load() @@ -348,9 +337,10 @@ func checkDomainJoined() bool { domainName := windows.UTF16PtrToString(domain) Log(context.Background(), logger.Debug(), - "Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", domainName, status) + "Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", + domainName, status) - // Consider both traditional and cloud domains as valid domain joins + // Consider domain or cloud domain as domain-joined isDomain := status == NetSetupDomain || status == NetSetupCloudDomain Log(context.Background(), logger.Debug(), "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", @@ -362,36 +352,44 @@ func checkDomainJoined() bool { return isDomain } -// Win32_ComputerSystem is the minimal struct for WMI query -type Win32_ComputerSystem struct { - Domain string -} - -// getLocalADDomain tries to detect the AD domain in two ways: -// 1. USERDNSDOMAIN env var (often set in AD logon sessions) -// 2. WMI Win32_ComputerSystem.Domain +// getLocalADDomain uses Microsoft's WMI wrappers (github.com/microsoft/wmi/pkg/*) +// to query the Domain field from Win32_ComputerSystem instead of a direct go-ole call. func getLocalADDomain() (string, error) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) // 1) Check environment variable envDomain := os.Getenv("USERDNSDOMAIN") if envDomain != "" { return strings.TrimSpace(envDomain), nil } - // 2) Check WMI (requires Windows + admin privileges or sufficient access) - var result []Win32_ComputerSystem - err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result) + // 2) Query WMI via the microsoft/wmi library + whost := host.NewWmiLocalHost() + q := query.NewWmiQuery("Win32_ComputerSystem") + instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q) + if instances != nil { + defer instances.Close() + } if err != nil { return "", fmt.Errorf("WMI query failed: %v", err) } - if len(result) == 0 { + + // If no results, return an error + if len(instances) == 0 { return "", fmt.Errorf("no rows returned from Win32_ComputerSystem") } - domain := strings.TrimSpace(result[0].Domain) - if domain == "" { + // We only care about the first row + domainVal, err := instances[0].GetProperty("Domain") + if err != nil { + return "", fmt.Errorf("machine does not appear to have a domain set: %v", err) + } + + domainName := strings.TrimSpace(fmt.Sprintf("%v", domainVal)) + if domainName == "" { return "", fmt.Errorf("machine does not appear to have a domain set") } - return domain, nil + return domainName, nil } // validInterfaces returns a list of all physical interfaces. @@ -410,12 +408,14 @@ func validInterfaces() map[string]struct{} { whost := host.NewWmiLocalHost() q := query.NewWmiQuery("MSFT_NetAdapter") instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) + if instances != nil { + defer instances.Close() + } if err != nil { Log(context.Background(), logger.Warn(), "failed to get wmi network adapter: %v", err) return nil } - defer instances.Close() var adapters []string for _, i := range instances { adapter, err := netadapter.NewNetworkAdapter(i) @@ -470,5 +470,4 @@ func validInterfaces() map[string]struct{} { m[ifaceName] = struct{}{} } return m - } diff --git a/resolver.go b/resolver.go index e82b763..49b81af 100644 --- a/resolver.go +++ b/resolver.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "net" "net/netip" "slices" @@ -11,10 +12,9 @@ import ( "sync" "sync/atomic" "time" - "io" - "github.com/rs/zerolog" "github.com/miekg/dns" + "github.com/rs/zerolog" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" ) @@ -48,11 +48,13 @@ const ( var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") -// or is the Resolver used for ResolverTypeOS. -var or = newResolverWithNameserver(defaultNameservers()) - var localResolver = newLocalResolver() +var ( + resolverMutex sync.Mutex + or *osResolver +) + func newLocalResolver() Resolver { var nss []string for _, addr := range Rfc1918Addresses() { @@ -86,7 +88,6 @@ func availableNameservers() []string { regularIPs, loopbackIPs, _ := netmon.LocalAddresses() machineIPsMap := make(map[string]struct{}, len(regularIPs)) - //load the logger logger := zerolog.New(io.Discard) if ProxyLogger.Load() != nil { @@ -129,6 +130,9 @@ func availableNameservers() []string { // calling this function. func InitializeOsResolver() []string { ns := initializeOsResolver(availableNameservers()) + resolverMutex.Lock() + defer resolverMutex.Unlock() + or = newResolverWithNameserver(ns) return ns } @@ -138,6 +142,7 @@ func InitializeOsResolver() []string { // - First available LAN servers are saved and store. // - Later calls, if no LAN servers available, the saved servers above will be used. func initializeOsResolver(servers []string) []string { + var lanNss, publicNss []string // First categorize servers @@ -154,171 +159,13 @@ func initializeOsResolver(servers []string) []string { } } - // Store initial servers immediately - if len(lanNss) > 0 { - or.initializedLanServers.CompareAndSwap(nil, &lanNss) - or.lanServers.Store(&lanNss) - } - if len(publicNss) == 0 { publicNss = []string{controldPublicDnsWithPort} } - or.publicServers.Store(&publicNss) - - // no longer testing servers in the background - // if DCHP nameservers are not working, this is outside of our control - - // // Test servers in background and remove failures - // go func() { - // // Test servers in parallel but maintain order - // type result struct { - // index int - // server string - // valid bool - // } - - // testServers := func(servers []string) []string { - // if len(servers) == 0 { - // return nil - // } - - // results := make(chan result, len(servers)) - // var wg sync.WaitGroup - - // for i, server := range servers { - // wg.Add(1) - // go func(idx int, s string) { - // defer wg.Done() - // results <- result{ - // index: idx, - // server: s, - // valid: testNameServerFn(s), - // } - // }(i, server) - // } - - // go func() { - // wg.Wait() - // close(results) - // }() - - // // Collect results maintaining original order - // validServers := make([]string, 0, len(servers)) - // ordered := make([]result, 0, len(servers)) - // for r := range results { - // ordered = append(ordered, r) - // } - // slices.SortFunc(ordered, func(a, b result) int { - // return a.index - b.index - // }) - // for _, r := range ordered { - // if r.valid { - // validServers = append(validServers, r.server) - // } else { - // ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing") - // } - // } - // return validServers - // } - - // // Test and update LAN servers - // if validLanNss := testServers(lanNss); len(validLanNss) > 0 { - // or.lanServers.Store(&validLanNss) - // } - - // // Test and update public servers - // validPublicNss := testServers(publicNss) - // if len(validPublicNss) == 0 { - // validPublicNss = []string{controldPublicDnsWithPort} - // } - // or.publicServers.Store(&validPublicNss) - // }() return slices.Concat(lanNss, publicNss) } -// // testNameserverFn sends a test query to DNS nameserver to check if the server is available. -// var testNameServerFn = testNameserver - -// // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. -// func testNameserver(addr string) bool { -// // Skip link-local addresses without scope IDs and deprecated site-local addresses -// if ip, err := netip.ParseAddr(addr); err == nil { -// if ip.Is6() { -// if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") { -// ProxyLogger.Load().Debug(). -// Str("nameserver", addr). -// Msg("skipping link-local IPv6 address without scope ID") -// return false -// } -// // Skip deprecated site-local addresses (fec0::/10) -// if strings.HasPrefix(ip.String(), "fec0:") { -// ProxyLogger.Load().Debug(). -// Str("nameserver", addr). -// Msg("skipping deprecated site-local IPv6 address") -// return false -// } -// } -// } - -// ProxyLogger.Load().Debug(). -// Str("input_addr", addr). -// Msg("testing nameserver") - -// // Handle both IPv4 and IPv6 addresses -// serverAddr := addr -// host, port, err := net.SplitHostPort(addr) -// if err != nil { -// // No port in address, add default port 53 -// serverAddr = net.JoinHostPort(addr, "53") -// } else if port == "" { -// // Has split markers but empty port -// serverAddr = net.JoinHostPort(host, "53") -// } - -// ProxyLogger.Load().Debug(). -// Str("server_addr", serverAddr). -// Msg("using server address") - -// // Test domains that are likely to exist and respond quickly -// testDomains := []struct { -// name string -// qtype uint16 -// }{ -// {".", dns.TypeNS}, // Root NS query - should always work -// {"controld.com.", dns.TypeA}, // Fallback to a reliable domain -// } - -// client := &dns.Client{ -// Timeout: 2 * time.Second, -// Net: "udp", -// } - -// // Try each test query until one succeeds -// for _, test := range testDomains { -// msg := new(dns.Msg) -// msg.SetQuestion(test.name, test.qtype) -// msg.RecursionDesired = true - -// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) -// resp, _, err := client.ExchangeContext(ctx, msg, serverAddr) -// cancel() - -// if err == nil && resp != nil { -// return true -// } - -// ProxyLogger.Load().Error(). -// Err(err). -// Str("nameserver", serverAddr). -// Str("test_domain", test.name). -// Str("query_type", dns.TypeToString[test.qtype]). -// Msg("DNS availability test failed") -// } - -// return false -// } - // Resolver is the interface that wraps the basic DNS operations. // // Resolve resolves the DNS query, return the result and the corresponding error. @@ -339,6 +186,9 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { case ResolverTypeDOQ: return &doqResolver{uc: uc}, nil case ResolverTypeOS: + if or == nil { + or = newResolverWithNameserver(defaultNameservers()) + } return or, nil case ResolverTypeLegacy: return &legacyResolver{uc: uc}, nil @@ -351,9 +201,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { } type osResolver struct { - initializedLanServers atomic.Pointer[[]string] - lanServers atomic.Pointer[[]string] - publicServers atomic.Pointer[[]string] + lanServers atomic.Pointer[[]string] + publicServers atomic.Pointer[[]string] } type osResolverResult struct { @@ -504,7 +353,10 @@ func LookupIP(domain string) []string { } func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { - nss := defaultNameservers() + if or == nil { + or = newResolverWithNameserver(defaultNameservers()) + } + nss := *or.lanServers.Load() if withBootstrapDNS { nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...) }