mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
fix upgrade flow
set service on new run, fix duplicate args set service on new run, fix duplicate args revert startCmd in upgrade flow due to pin compat issues make restart reset DNS like upgrade, add debugging to uninstall method debugging debugging debugging debugging debugging WMI remove stackexchange lib, use ms wmi pkg debugging debugging set correct class fix os reolver init issues fix netadapter class use os resolver instead of fetching default nameservers while already running remove debug lines fix lookup IP fix lookup IP fix lookup IP fix lookup IP fix dns namserver retries when not needed
This commit is contained in:
@@ -56,10 +56,12 @@ func getActiveDirectoryDomain() (string, error) {
|
|||||||
defer log.SetOutput(os.Stderr)
|
defer log.SetOutput(os.Stderr)
|
||||||
whost := host.NewWmiLocalHost()
|
whost := host.NewWmiLocalHost()
|
||||||
cs, err := hh.GetComputerSystem(whost)
|
cs, err := hh.GetComputerSystem(whost)
|
||||||
|
if cs != nil {
|
||||||
|
defer cs.Close()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
defer cs.Close()
|
|
||||||
pod, err := cs.GetPropertyPartOfDomain()
|
pod, err := cs.GetPropertyPartOfDomain()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ func initCLI() {
|
|||||||
rootCmd.CompletionOptions.HiddenDefaultCmd = true
|
rootCmd.CompletionOptions.HiddenDefaultCmd = true
|
||||||
|
|
||||||
initRunCmd()
|
initRunCmd()
|
||||||
startCmd, startCmdAlias := initStartCmd()
|
startCmd := initStartCmd()
|
||||||
stopCmd := initStopCmd()
|
stopCmd := initStopCmd()
|
||||||
restartCmd := initRestartCmd()
|
restartCmd := initRestartCmd()
|
||||||
reloadCmd := initReloadCmd(restartCmd)
|
reloadCmd := initReloadCmd(restartCmd)
|
||||||
@@ -135,7 +135,7 @@ func initCLI() {
|
|||||||
interfacesCmd := initInterfacesCmd()
|
interfacesCmd := initInterfacesCmd()
|
||||||
initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd)
|
initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd)
|
||||||
initClientsCmd()
|
initClientsCmd()
|
||||||
initUpgradeCmd(startCmdAlias)
|
initUpgradeCmd()
|
||||||
initLogCmd()
|
initLogCmd()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,10 +243,6 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
|||||||
if err := s.Run(); err != nil {
|
if err := s.Run(); err != nil {
|
||||||
mainLog.Load().Error().Err(err).Msg("failed to start service")
|
mainLog.Load().Error().Err(err).Msg("failed to start service")
|
||||||
}
|
}
|
||||||
// Configure Windows service failure actions
|
|
||||||
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
|
|
||||||
mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||||
@@ -394,6 +390,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Configure Windows service failure actions
|
||||||
|
_ = ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||||
})
|
})
|
||||||
p.onStopped = append(p.onStopped, func() {
|
p.onStopped = append(p.onStopped, func() {
|
||||||
for _, lc := range p.cfg.Listener {
|
for _, lc := range p.cfg.Listener {
|
||||||
@@ -1615,22 +1613,27 @@ var errRequiredDeactivationPin = errors.New("deactivation pin is required to sto
|
|||||||
|
|
||||||
// checkDeactivationPin validates if the deactivation pin matches one in ControlD config.
|
// checkDeactivationPin validates if the deactivation pin matches one in ControlD config.
|
||||||
func checkDeactivationPin(s service.Service, stopCh chan struct{}) error {
|
func checkDeactivationPin(s service.Service, stopCh chan struct{}) error {
|
||||||
|
mainLog.Load().Debug().Msg("Checking deactivation pin")
|
||||||
dir, err := socketDir()
|
dir, err := socketDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mainLog.Load().Err(err).Msg("could not check deactivation pin")
|
mainLog.Load().Err(err).Msg("could not check deactivation pin")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
mainLog.Load().Debug().Msg("Creating control client")
|
||||||
var cc *controlClient
|
var cc *controlClient
|
||||||
if s == nil {
|
if s == nil {
|
||||||
cc = newSocketControlClientMobile(dir, stopCh)
|
cc = newSocketControlClientMobile(dir, stopCh)
|
||||||
} else {
|
} else {
|
||||||
cc = newSocketControlClient(context.TODO(), s, dir)
|
cc = newSocketControlClient(context.TODO(), s, dir)
|
||||||
}
|
}
|
||||||
|
mainLog.Load().Debug().Msg("Control client done")
|
||||||
if cc == nil {
|
if cc == nil {
|
||||||
return nil // ctrld is not running.
|
return nil // ctrld is not running.
|
||||||
}
|
}
|
||||||
data, _ := json.Marshal(&deactivationRequest{Pin: deactivationPin})
|
data, _ := json.Marshal(&deactivationRequest{Pin: deactivationPin})
|
||||||
resp, _ := cc.post(deactivationPath, bytes.NewReader(data))
|
mainLog.Load().Debug().Msg("Posting deactivation request")
|
||||||
|
resp, err := cc.post(deactivationPath, bytes.NewReader(data))
|
||||||
|
mainLog.Load().Debug().Msg("Posting deactivation request done")
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
switch resp.StatusCode {
|
switch resp.StatusCode {
|
||||||
case http.StatusBadRequest:
|
case http.StatusBadRequest:
|
||||||
@@ -1694,7 +1697,7 @@ func curCdUID() string {
|
|||||||
if s, _ := newService(&prog{}, svcConfig); s != nil {
|
if s, _ := newService(&prog{}, svcConfig); s != nil {
|
||||||
// Configure Windows service failure actions
|
// Configure Windows service failure actions
|
||||||
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
|
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
|
||||||
mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
|
mainLog.Load().Debug().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
|
||||||
}
|
}
|
||||||
if dir, _ := socketDir(); dir != "" {
|
if dir, _ := socketDir(); dir != "" {
|
||||||
cc := newSocketControlClient(context.TODO(), s, dir)
|
cc := newSocketControlClient(context.TODO(), s, dir)
|
||||||
@@ -1777,6 +1780,7 @@ func resetDnsNoLog(p *prog) {
|
|||||||
func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task {
|
func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task {
|
||||||
return task{func() error {
|
return task{func() error {
|
||||||
if iface == "" {
|
if iface == "" {
|
||||||
|
mainLog.Load().Debug().Msg("no iface, skipping resetDnsTask")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// Always reset DNS first, ensuring DNS setting is in a good state.
|
// Always reset DNS first, ensuring DNS setting is in a good state.
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func initRunCmd() *cobra.Command {
|
|||||||
return runCmd
|
return runCmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func initStartCmd() (*cobra.Command, *cobra.Command) {
|
func initStartCmd() *cobra.Command {
|
||||||
startCmd := &cobra.Command{
|
startCmd := &cobra.Command{
|
||||||
PreRun: func(cmd *cobra.Command, args []string) {
|
PreRun: func(cmd *cobra.Command, args []string) {
|
||||||
checkHasElevatedPrivilege()
|
checkHasElevatedPrivilege()
|
||||||
@@ -391,7 +391,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
|||||||
|
|
||||||
tasks := []task{
|
tasks := []task{
|
||||||
{s.Stop, false, "Stop"},
|
{s.Stop, false, "Stop"},
|
||||||
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Generate NextDNS config"},
|
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"},
|
||||||
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
|
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
|
||||||
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||||
{func() error {
|
{func() error {
|
||||||
@@ -534,7 +534,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
|||||||
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
||||||
rootCmd.AddCommand(startCmdAlias)
|
rootCmd.AddCommand(startCmdAlias)
|
||||||
|
|
||||||
return startCmd, startCmdAlias
|
return startCmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func initStopCmd() *cobra.Command {
|
func initStopCmd() *cobra.Command {
|
||||||
@@ -647,6 +647,15 @@ func initRestartCmd() *cobra.Command {
|
|||||||
mainLog.Load().Warn().Msg("service not installed")
|
mainLog.Load().Warn().Msg("service not installed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if iface == "" {
|
||||||
|
iface = "auto"
|
||||||
|
}
|
||||||
|
p.preRun()
|
||||||
|
if ir := runningIface(s); ir != nil {
|
||||||
|
p.runningIface = ir.Name
|
||||||
|
p.requiredMultiNICsConfig = ir.All
|
||||||
|
}
|
||||||
|
|
||||||
initLogging()
|
initLogging()
|
||||||
|
|
||||||
if cdMode {
|
if cdMode {
|
||||||
@@ -656,11 +665,53 @@ func initRestartCmd() *cobra.Command {
|
|||||||
if ir := runningIface(s); ir != nil {
|
if ir := runningIface(s); ir != nil {
|
||||||
iface = ir.Name
|
iface = ir.Name
|
||||||
}
|
}
|
||||||
tasks := []task{
|
|
||||||
{s.Stop, false, "Stop"},
|
doRestart := func() bool {
|
||||||
{s.Start, true, "Start"},
|
tasks := []task{
|
||||||
|
{s.Stop, true, "Stop"},
|
||||||
|
{func() error {
|
||||||
|
p.router.Cleanup()
|
||||||
|
p.resetDNS()
|
||||||
|
return nil
|
||||||
|
}, false, "Cleanup"},
|
||||||
|
{func() error {
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
|
return nil
|
||||||
|
}, false, "Waiting for service to stop"},
|
||||||
|
}
|
||||||
|
if doTasks(tasks) {
|
||||||
|
|
||||||
|
if router.WaitProcessExited() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
mainLog.Load().Error().Msg("timeout while waiting for service to stop")
|
||||||
|
break loop
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
if status, _ := s.Status(); status == service.StatusStopped {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks = []task{
|
||||||
|
{s.Start, true, "Start"},
|
||||||
|
}
|
||||||
|
|
||||||
|
return doTasks(tasks)
|
||||||
|
|
||||||
}
|
}
|
||||||
if doTasks(tasks) {
|
|
||||||
|
if doRestart() {
|
||||||
dir, err := socketDir()
|
dir, err := socketDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server")
|
mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server")
|
||||||
@@ -668,11 +719,13 @@ func initRestartCmd() *cobra.Command {
|
|||||||
}
|
}
|
||||||
cc := newSocketControlClient(context.TODO(), s, dir)
|
cc := newSocketControlClient(context.TODO(), s, dir)
|
||||||
if cc == nil {
|
if cc == nil {
|
||||||
mainLog.Load().Notice().Msg("Service was not restarted")
|
mainLog.Load().Error().Msg("Could not complete service restart")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
_, _ = cc.post(ifacePath, nil)
|
_, _ = cc.post(ifacePath, nil)
|
||||||
mainLog.Load().Notice().Msg("Service restarted")
|
mainLog.Load().Notice().Msg("Service restarted")
|
||||||
|
} else {
|
||||||
|
mainLog.Load().Error().Msg("Service restart failed")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1049,7 +1102,7 @@ func initClientsCmd() *cobra.Command {
|
|||||||
return clientsCmd
|
return clientsCmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
|
func initUpgradeCmd() *cobra.Command {
|
||||||
const (
|
const (
|
||||||
upgradeChannelDev = "dev"
|
upgradeChannelDev = "dev"
|
||||||
upgradeChannelProd = "prod"
|
upgradeChannelProd = "prod"
|
||||||
@@ -1087,6 +1140,14 @@ func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
|
|||||||
mainLog.Load().Error().Msg(err.Error())
|
mainLog.Load().Error().Msg(err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if iface == "" {
|
||||||
|
iface = "auto"
|
||||||
|
}
|
||||||
|
p.preRun()
|
||||||
|
if ir := runningIface(s); ir != nil {
|
||||||
|
p.runningIface = ir.Name
|
||||||
|
p.requiredMultiNICsConfig = ir.All
|
||||||
|
}
|
||||||
|
|
||||||
svcInstalled := true
|
svcInstalled := true
|
||||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||||
@@ -1121,23 +1182,56 @@ func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
|
|||||||
mainLog.Load().Fatal().Err(err).Msg("failed to update current binary")
|
mainLog.Load().Fatal().Err(err).Msg("failed to update current binary")
|
||||||
}
|
}
|
||||||
|
|
||||||
// we run the actual commands to make sure all the logic we want is executed
|
|
||||||
doRestart := func() bool {
|
doRestart := func() bool {
|
||||||
|
if !svcInstalled {
|
||||||
// run the start command so that we reinit the service
|
return true
|
||||||
// this is to fix the non restarting options on windows for existing clients
|
|
||||||
// we have to reset os.Args, since other commands use it.
|
|
||||||
curCdUID := curCdUID()
|
|
||||||
startArgs := []string{}
|
|
||||||
os.Args = []string{"ctrld", "start"}
|
|
||||||
if curCdUID != "" {
|
|
||||||
startArgs = append(startArgs, fmt.Sprintf("--cd=%s", curCdUID))
|
|
||||||
os.Args = append(os.Args, fmt.Sprintf("--cd=%s", curCdUID))
|
|
||||||
}
|
}
|
||||||
startCmd.Run(startCmd, startArgs)
|
tasks := []task{
|
||||||
|
{s.Stop, true, "Stop"},
|
||||||
|
{func() error {
|
||||||
|
p.router.Cleanup()
|
||||||
|
p.resetDNS()
|
||||||
|
return nil
|
||||||
|
}, false, "Cleanup"},
|
||||||
|
{func() error {
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
|
return nil
|
||||||
|
}, false, "Waiting for service to stop"},
|
||||||
|
}
|
||||||
|
if doTasks(tasks) {
|
||||||
|
|
||||||
return true
|
if router.WaitProcessExited() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
mainLog.Load().Error().Msg("timeout while waiting for service to stop")
|
||||||
|
break loop
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
if status, _ := s.Status(); status == service.StatusStopped {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks = []task{
|
||||||
|
{s.Start, true, "Start"},
|
||||||
|
}
|
||||||
|
if doTasks(tasks) {
|
||||||
|
if dir, err := socketDir(); err == nil {
|
||||||
|
if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil {
|
||||||
|
_, _ = cc.post(ifacePath, nil)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
if svcInstalled {
|
if svcInstalled {
|
||||||
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
|
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
|
||||||
|
|||||||
@@ -40,11 +40,13 @@ func validInterfaces() []string {
|
|||||||
whost := host.NewWmiLocalHost()
|
whost := host.NewWmiLocalHost()
|
||||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||||
|
if instances != nil {
|
||||||
|
defer instances.Close()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
|
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
defer instances.Close()
|
|
||||||
var adapters []string
|
var adapters []string
|
||||||
for _, i := range instances {
|
for _, i := range instances {
|
||||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||||
|
|||||||
@@ -792,6 +792,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
|
|||||||
|
|
||||||
func (p *prog) resetDNS() {
|
func (p *prog) resetDNS() {
|
||||||
if p.runningIface == "" {
|
if p.runningIface == "" {
|
||||||
|
mainLog.Load().Debug().Msg("no running interface, skipping resetDNS")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// See corresponding comments in (*prog).setDNS function.
|
// See corresponding comments in (*prog).setDNS function.
|
||||||
|
|||||||
@@ -52,7 +52,21 @@ func ConfigureWindowsServiceFailureActions(serviceName string) error {
|
|||||||
}
|
}
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
// restart 3 times with a delay of 2 seconds
|
// 1. Retrieve the current config
|
||||||
|
cfg, err := s.Config()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Update the Description
|
||||||
|
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
|
||||||
|
|
||||||
|
// 3. Apply the updated config
|
||||||
|
if err := s.UpdateConfig(cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then proceed with existing actions, e.g. setting failure actions
|
||||||
actions := []mgr.RecoveryAction{
|
actions := []mgr.RecoveryAction{
|
||||||
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds
|
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds
|
||||||
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds
|
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -45,7 +45,6 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
aead.dev/minisign v0.2.0 // indirect
|
aead.dev/minisign v0.2.0 // indirect
|
||||||
github.com/StackExchange/wmi v1.2.1 // indirect
|
|
||||||
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect
|
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||||
|
|||||||
6
go.sum
6
go.sum
@@ -42,8 +42,6 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
|
|||||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||||
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
|
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
|
||||||
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
|
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
|
||||||
github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA=
|
|
||||||
github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
|
|
||||||
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE=
|
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE=
|
||||||
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
|
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
|
||||||
@@ -95,7 +93,6 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f
|
|||||||
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
|
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
|
||||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
|
||||||
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
|
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
|
||||||
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
|
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
|
||||||
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
|
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
|
||||||
@@ -452,7 +449,6 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||||||
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
@@ -492,8 +488,6 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|||||||
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|
||||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/StackExchange/wmi"
|
|
||||||
"github.com/microsoft/wmi/pkg/base/host"
|
"github.com/microsoft/wmi/pkg/base/host"
|
||||||
"github.com/microsoft/wmi/pkg/base/instance"
|
"github.com/microsoft/wmi/pkg/base/instance"
|
||||||
"github.com/microsoft/wmi/pkg/base/query"
|
"github.com/microsoft/wmi/pkg/base/query"
|
||||||
@@ -24,9 +23,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxRetries = 5
|
maxDNSAdapterRetries = 5
|
||||||
retryDelay = 1 * time.Second
|
retryDelayDNSAdapter = 1 * time.Second
|
||||||
defaultTimeout = 5 * time.Second
|
defaultDNSAdapterTimeout = 10 * time.Second
|
||||||
minDNSServers = 1 // Minimum number of DNS servers we want to find
|
minDNSServers = 1 // Minimum number of DNS servers we want to find
|
||||||
NetSetupUnknown uint32 = 0
|
NetSetupUnknown uint32 = 0
|
||||||
NetSetupWorkgroup uint32 = 1
|
NetSetupWorkgroup uint32 = 1
|
||||||
@@ -57,19 +56,18 @@ func dnsFns() []dnsFn {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func dnsFromAdapter() []string {
|
func dnsFromAdapter() []string {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
var ns []string
|
var ns []string
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
//load the logger
|
|
||||||
logger := zerolog.New(io.Discard)
|
logger := zerolog.New(io.Discard)
|
||||||
if ProxyLogger.Load() != nil {
|
if ProxyLogger.Load() != nil {
|
||||||
logger = *ProxyLogger.Load()
|
logger = *ProxyLogger.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < maxRetries; i++ {
|
for i := 0; i < maxDNSAdapterRetries; i++ {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
|
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
|
||||||
@@ -80,12 +78,18 @@ func dnsFromAdapter() []string {
|
|||||||
if err == nil && len(ns) >= minDNSServers {
|
if err == nil && len(ns) >= minDNSServers {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns))
|
"Successfully got DNS servers after %d attempts, found %d servers",
|
||||||
|
i+1, len(ns))
|
||||||
}
|
}
|
||||||
return ns
|
return ns
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log the specific failure reason
|
// if osResolver is not initialized, this is likely a command line run
|
||||||
|
// and ctrld is already on the interface, abort retries
|
||||||
|
if or == nil {
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"Failed to get DNS servers, attempt %d: %v", i+1, err)
|
"Failed to get DNS servers, attempt %d: %v", i+1, err)
|
||||||
@@ -97,17 +101,16 @@ func dnsFromAdapter() []string {
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
case <-time.After(retryDelay):
|
case <-time.After(retryDelayDNSAdapter):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxRetries)
|
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries)
|
||||||
return ns // Return whatever we got, even if insufficient
|
return ns
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDNSServers(ctx context.Context) ([]string, error) {
|
func getDNSServers(ctx context.Context) ([]string, error) {
|
||||||
//load the logger
|
|
||||||
logger := zerolog.New(io.Discard)
|
logger := zerolog.New(io.Discard)
|
||||||
if ProxyLogger.Load() != nil {
|
if ProxyLogger.Load() != nil {
|
||||||
logger = *ProxyLogger.Load()
|
logger = *ProxyLogger.Load()
|
||||||
@@ -133,25 +136,18 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
|||||||
var dcServers []string
|
var dcServers []string
|
||||||
isDomain := checkDomainJoined()
|
isDomain := checkDomainJoined()
|
||||||
if isDomain {
|
if isDomain {
|
||||||
|
|
||||||
domainName, err := getLocalADDomain()
|
domainName, err := getLocalADDomain()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"Failed to get local AD domain: %v", err)
|
"Failed to get local AD domain: %v", err)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
// Load netapi32.dll
|
// Load netapi32.dll
|
||||||
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
|
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
|
||||||
dsDcName := netapi32.NewProc("DsGetDcNameW")
|
dsDcName := netapi32.NewProc("DsGetDcNameW")
|
||||||
|
|
||||||
var info *DomainControllerInfo
|
var info *DomainControllerInfo
|
||||||
|
flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME)
|
||||||
|
|
||||||
flags := uint32(DS_RETURN_DNS_NAME |
|
|
||||||
DS_IP_REQUIRED |
|
|
||||||
DS_IS_DNS_NAME)
|
|
||||||
|
|
||||||
// Convert domain name to UTF16 pointer
|
|
||||||
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
@@ -190,15 +186,12 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
|||||||
} else if info != nil {
|
} else if info != nil {
|
||||||
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
|
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
|
||||||
|
|
||||||
// Get DC address
|
|
||||||
if info.DomainControllerAddress != nil {
|
if info.DomainControllerAddress != nil {
|
||||||
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
|
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
|
||||||
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
|
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
|
||||||
|
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"Found domain controller address: %s", dcAddr)
|
"Found domain controller address: %s", dcAddr)
|
||||||
|
|
||||||
// Try to resolve DC
|
|
||||||
if ip := net.ParseIP(dcAddr); ip != nil {
|
if ip := net.ParseIP(dcAddr); ip != nil {
|
||||||
dcServers = append(dcServers, ip.String())
|
dcServers = append(dcServers, ip.String())
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
@@ -210,7 +203,6 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,28 +270,26 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ipStr := ip.String()
|
ipStr := ip.String()
|
||||||
logger := logger.Debug().
|
l := logger.Debug().
|
||||||
Str("ip", ipStr).
|
Str("ip", ipStr).
|
||||||
Str("adapter", aa.FriendlyName())
|
Str("adapter", aa.FriendlyName())
|
||||||
|
|
||||||
if ip.IsLoopback() {
|
if ip.IsLoopback() {
|
||||||
logger.Msg("Skipping loopback IP")
|
l.Msg("Skipping loopback IP")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if seen[ipStr] {
|
if seen[ipStr] {
|
||||||
logger.Msg("Skipping duplicate IP")
|
l.Msg("Skipping duplicate IP")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := addressMap[ipStr]; ok {
|
if _, ok := addressMap[ipStr]; ok {
|
||||||
logger.Msg("Skipping local interface IP")
|
l.Msg("Skipping local interface IP")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
seen[ipStr] = true
|
seen[ipStr] = true
|
||||||
ns = append(ns, ipStr)
|
ns = append(ns, ipStr)
|
||||||
logger.Msg("Added DNS server")
|
l.Msg("Added DNS server")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,7 +320,6 @@ func nameserversFromResolvconf() []string {
|
|||||||
// checkDomainJoined checks if the machine is joined to an Active Directory domain
|
// checkDomainJoined checks if the machine is joined to an Active Directory domain
|
||||||
// Returns whether it's domain joined and the domain name if available
|
// Returns whether it's domain joined and the domain name if available
|
||||||
func checkDomainJoined() bool {
|
func checkDomainJoined() bool {
|
||||||
//load the logger
|
|
||||||
logger := zerolog.New(io.Discard)
|
logger := zerolog.New(io.Discard)
|
||||||
if ProxyLogger.Load() != nil {
|
if ProxyLogger.Load() != nil {
|
||||||
logger = *ProxyLogger.Load()
|
logger = *ProxyLogger.Load()
|
||||||
@@ -348,9 +337,10 @@ func checkDomainJoined() bool {
|
|||||||
|
|
||||||
domainName := windows.UTF16PtrToString(domain)
|
domainName := windows.UTF16PtrToString(domain)
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", domainName, status)
|
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)",
|
||||||
|
domainName, status)
|
||||||
|
|
||||||
// Consider both traditional and cloud domains as valid domain joins
|
// Consider domain or cloud domain as domain-joined
|
||||||
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
|
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
|
||||||
Log(context.Background(), logger.Debug(),
|
Log(context.Background(), logger.Debug(),
|
||||||
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
|
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
|
||||||
@@ -362,36 +352,44 @@ func checkDomainJoined() bool {
|
|||||||
return isDomain
|
return isDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
// Win32_ComputerSystem is the minimal struct for WMI query
|
// getLocalADDomain uses Microsoft's WMI wrappers (github.com/microsoft/wmi/pkg/*)
|
||||||
type Win32_ComputerSystem struct {
|
// to query the Domain field from Win32_ComputerSystem instead of a direct go-ole call.
|
||||||
Domain string
|
|
||||||
}
|
|
||||||
|
|
||||||
// getLocalADDomain tries to detect the AD domain in two ways:
|
|
||||||
// 1. USERDNSDOMAIN env var (often set in AD logon sessions)
|
|
||||||
// 2. WMI Win32_ComputerSystem.Domain
|
|
||||||
func getLocalADDomain() (string, error) {
|
func getLocalADDomain() (string, error) {
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
// 1) Check environment variable
|
// 1) Check environment variable
|
||||||
envDomain := os.Getenv("USERDNSDOMAIN")
|
envDomain := os.Getenv("USERDNSDOMAIN")
|
||||||
if envDomain != "" {
|
if envDomain != "" {
|
||||||
return strings.TrimSpace(envDomain), nil
|
return strings.TrimSpace(envDomain), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) Check WMI (requires Windows + admin privileges or sufficient access)
|
// 2) Query WMI via the microsoft/wmi library
|
||||||
var result []Win32_ComputerSystem
|
whost := host.NewWmiLocalHost()
|
||||||
err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result)
|
q := query.NewWmiQuery("Win32_ComputerSystem")
|
||||||
|
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
|
||||||
|
if instances != nil {
|
||||||
|
defer instances.Close()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("WMI query failed: %v", err)
|
return "", fmt.Errorf("WMI query failed: %v", err)
|
||||||
}
|
}
|
||||||
if len(result) == 0 {
|
|
||||||
|
// If no results, return an error
|
||||||
|
if len(instances) == 0 {
|
||||||
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
|
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
|
||||||
}
|
}
|
||||||
|
|
||||||
domain := strings.TrimSpace(result[0].Domain)
|
// We only care about the first row
|
||||||
if domain == "" {
|
domainVal, err := instances[0].GetProperty("Domain")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("machine does not appear to have a domain set: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
domainName := strings.TrimSpace(fmt.Sprintf("%v", domainVal))
|
||||||
|
if domainName == "" {
|
||||||
return "", fmt.Errorf("machine does not appear to have a domain set")
|
return "", fmt.Errorf("machine does not appear to have a domain set")
|
||||||
}
|
}
|
||||||
return domain, nil
|
return domainName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validInterfaces returns a list of all physical interfaces.
|
// validInterfaces returns a list of all physical interfaces.
|
||||||
@@ -410,12 +408,14 @@ func validInterfaces() map[string]struct{} {
|
|||||||
whost := host.NewWmiLocalHost()
|
whost := host.NewWmiLocalHost()
|
||||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||||
|
if instances != nil {
|
||||||
|
defer instances.Close()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log(context.Background(), logger.Warn(),
|
Log(context.Background(), logger.Warn(),
|
||||||
"failed to get wmi network adapter: %v", err)
|
"failed to get wmi network adapter: %v", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
defer instances.Close()
|
|
||||||
var adapters []string
|
var adapters []string
|
||||||
for _, i := range instances {
|
for _, i := range instances {
|
||||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||||
@@ -470,5 +470,4 @@ func validInterfaces() map[string]struct{} {
|
|||||||
m[ifaceName] = struct{}{}
|
m[ifaceName] = struct{}{}
|
||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
188
resolver.go
188
resolver.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -11,10 +12,9 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
"tailscale.com/net/netmon"
|
"tailscale.com/net/netmon"
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
)
|
)
|
||||||
@@ -48,11 +48,13 @@ const (
|
|||||||
|
|
||||||
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||||
|
|
||||||
// or is the Resolver used for ResolverTypeOS.
|
|
||||||
var or = newResolverWithNameserver(defaultNameservers())
|
|
||||||
|
|
||||||
var localResolver = newLocalResolver()
|
var localResolver = newLocalResolver()
|
||||||
|
|
||||||
|
var (
|
||||||
|
resolverMutex sync.Mutex
|
||||||
|
or *osResolver
|
||||||
|
)
|
||||||
|
|
||||||
func newLocalResolver() Resolver {
|
func newLocalResolver() Resolver {
|
||||||
var nss []string
|
var nss []string
|
||||||
for _, addr := range Rfc1918Addresses() {
|
for _, addr := range Rfc1918Addresses() {
|
||||||
@@ -86,7 +88,6 @@ func availableNameservers() []string {
|
|||||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||||
machineIPsMap := make(map[string]struct{}, len(regularIPs))
|
machineIPsMap := make(map[string]struct{}, len(regularIPs))
|
||||||
|
|
||||||
|
|
||||||
//load the logger
|
//load the logger
|
||||||
logger := zerolog.New(io.Discard)
|
logger := zerolog.New(io.Discard)
|
||||||
if ProxyLogger.Load() != nil {
|
if ProxyLogger.Load() != nil {
|
||||||
@@ -129,6 +130,9 @@ func availableNameservers() []string {
|
|||||||
// calling this function.
|
// calling this function.
|
||||||
func InitializeOsResolver() []string {
|
func InitializeOsResolver() []string {
|
||||||
ns := initializeOsResolver(availableNameservers())
|
ns := initializeOsResolver(availableNameservers())
|
||||||
|
resolverMutex.Lock()
|
||||||
|
defer resolverMutex.Unlock()
|
||||||
|
or = newResolverWithNameserver(ns)
|
||||||
return ns
|
return ns
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,6 +142,7 @@ func InitializeOsResolver() []string {
|
|||||||
// - First available LAN servers are saved and store.
|
// - First available LAN servers are saved and store.
|
||||||
// - Later calls, if no LAN servers available, the saved servers above will be used.
|
// - Later calls, if no LAN servers available, the saved servers above will be used.
|
||||||
func initializeOsResolver(servers []string) []string {
|
func initializeOsResolver(servers []string) []string {
|
||||||
|
|
||||||
var lanNss, publicNss []string
|
var lanNss, publicNss []string
|
||||||
|
|
||||||
// First categorize servers
|
// First categorize servers
|
||||||
@@ -154,171 +159,13 @@ func initializeOsResolver(servers []string) []string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store initial servers immediately
|
|
||||||
if len(lanNss) > 0 {
|
|
||||||
or.initializedLanServers.CompareAndSwap(nil, &lanNss)
|
|
||||||
or.lanServers.Store(&lanNss)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(publicNss) == 0 {
|
if len(publicNss) == 0 {
|
||||||
publicNss = []string{controldPublicDnsWithPort}
|
publicNss = []string{controldPublicDnsWithPort}
|
||||||
}
|
}
|
||||||
or.publicServers.Store(&publicNss)
|
|
||||||
|
|
||||||
// no longer testing servers in the background
|
|
||||||
// if DCHP nameservers are not working, this is outside of our control
|
|
||||||
|
|
||||||
// // Test servers in background and remove failures
|
|
||||||
// go func() {
|
|
||||||
// // Test servers in parallel but maintain order
|
|
||||||
// type result struct {
|
|
||||||
// index int
|
|
||||||
// server string
|
|
||||||
// valid bool
|
|
||||||
// }
|
|
||||||
|
|
||||||
// testServers := func(servers []string) []string {
|
|
||||||
// if len(servers) == 0 {
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
|
|
||||||
// results := make(chan result, len(servers))
|
|
||||||
// var wg sync.WaitGroup
|
|
||||||
|
|
||||||
// for i, server := range servers {
|
|
||||||
// wg.Add(1)
|
|
||||||
// go func(idx int, s string) {
|
|
||||||
// defer wg.Done()
|
|
||||||
// results <- result{
|
|
||||||
// index: idx,
|
|
||||||
// server: s,
|
|
||||||
// valid: testNameServerFn(s),
|
|
||||||
// }
|
|
||||||
// }(i, server)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// go func() {
|
|
||||||
// wg.Wait()
|
|
||||||
// close(results)
|
|
||||||
// }()
|
|
||||||
|
|
||||||
// // Collect results maintaining original order
|
|
||||||
// validServers := make([]string, 0, len(servers))
|
|
||||||
// ordered := make([]result, 0, len(servers))
|
|
||||||
// for r := range results {
|
|
||||||
// ordered = append(ordered, r)
|
|
||||||
// }
|
|
||||||
// slices.SortFunc(ordered, func(a, b result) int {
|
|
||||||
// return a.index - b.index
|
|
||||||
// })
|
|
||||||
// for _, r := range ordered {
|
|
||||||
// if r.valid {
|
|
||||||
// validServers = append(validServers, r.server)
|
|
||||||
// } else {
|
|
||||||
// ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing")
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// return validServers
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Test and update LAN servers
|
|
||||||
// if validLanNss := testServers(lanNss); len(validLanNss) > 0 {
|
|
||||||
// or.lanServers.Store(&validLanNss)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Test and update public servers
|
|
||||||
// validPublicNss := testServers(publicNss)
|
|
||||||
// if len(validPublicNss) == 0 {
|
|
||||||
// validPublicNss = []string{controldPublicDnsWithPort}
|
|
||||||
// }
|
|
||||||
// or.publicServers.Store(&validPublicNss)
|
|
||||||
// }()
|
|
||||||
|
|
||||||
return slices.Concat(lanNss, publicNss)
|
return slices.Concat(lanNss, publicNss)
|
||||||
}
|
}
|
||||||
|
|
||||||
// // testNameserverFn sends a test query to DNS nameserver to check if the server is available.
|
|
||||||
// var testNameServerFn = testNameserver
|
|
||||||
|
|
||||||
// // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
|
|
||||||
// func testNameserver(addr string) bool {
|
|
||||||
// // Skip link-local addresses without scope IDs and deprecated site-local addresses
|
|
||||||
// if ip, err := netip.ParseAddr(addr); err == nil {
|
|
||||||
// if ip.Is6() {
|
|
||||||
// if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") {
|
|
||||||
// ProxyLogger.Load().Debug().
|
|
||||||
// Str("nameserver", addr).
|
|
||||||
// Msg("skipping link-local IPv6 address without scope ID")
|
|
||||||
// return false
|
|
||||||
// }
|
|
||||||
// // Skip deprecated site-local addresses (fec0::/10)
|
|
||||||
// if strings.HasPrefix(ip.String(), "fec0:") {
|
|
||||||
// ProxyLogger.Load().Debug().
|
|
||||||
// Str("nameserver", addr).
|
|
||||||
// Msg("skipping deprecated site-local IPv6 address")
|
|
||||||
// return false
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// ProxyLogger.Load().Debug().
|
|
||||||
// Str("input_addr", addr).
|
|
||||||
// Msg("testing nameserver")
|
|
||||||
|
|
||||||
// // Handle both IPv4 and IPv6 addresses
|
|
||||||
// serverAddr := addr
|
|
||||||
// host, port, err := net.SplitHostPort(addr)
|
|
||||||
// if err != nil {
|
|
||||||
// // No port in address, add default port 53
|
|
||||||
// serverAddr = net.JoinHostPort(addr, "53")
|
|
||||||
// } else if port == "" {
|
|
||||||
// // Has split markers but empty port
|
|
||||||
// serverAddr = net.JoinHostPort(host, "53")
|
|
||||||
// }
|
|
||||||
|
|
||||||
// ProxyLogger.Load().Debug().
|
|
||||||
// Str("server_addr", serverAddr).
|
|
||||||
// Msg("using server address")
|
|
||||||
|
|
||||||
// // Test domains that are likely to exist and respond quickly
|
|
||||||
// testDomains := []struct {
|
|
||||||
// name string
|
|
||||||
// qtype uint16
|
|
||||||
// }{
|
|
||||||
// {".", dns.TypeNS}, // Root NS query - should always work
|
|
||||||
// {"controld.com.", dns.TypeA}, // Fallback to a reliable domain
|
|
||||||
// }
|
|
||||||
|
|
||||||
// client := &dns.Client{
|
|
||||||
// Timeout: 2 * time.Second,
|
|
||||||
// Net: "udp",
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Try each test query until one succeeds
|
|
||||||
// for _, test := range testDomains {
|
|
||||||
// msg := new(dns.Msg)
|
|
||||||
// msg.SetQuestion(test.name, test.qtype)
|
|
||||||
// msg.RecursionDesired = true
|
|
||||||
|
|
||||||
// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
// resp, _, err := client.ExchangeContext(ctx, msg, serverAddr)
|
|
||||||
// cancel()
|
|
||||||
|
|
||||||
// if err == nil && resp != nil {
|
|
||||||
// return true
|
|
||||||
// }
|
|
||||||
|
|
||||||
// ProxyLogger.Load().Error().
|
|
||||||
// Err(err).
|
|
||||||
// Str("nameserver", serverAddr).
|
|
||||||
// Str("test_domain", test.name).
|
|
||||||
// Str("query_type", dns.TypeToString[test.qtype]).
|
|
||||||
// Msg("DNS availability test failed")
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return false
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Resolver is the interface that wraps the basic DNS operations.
|
// Resolver is the interface that wraps the basic DNS operations.
|
||||||
//
|
//
|
||||||
// Resolve resolves the DNS query, return the result and the corresponding error.
|
// Resolve resolves the DNS query, return the result and the corresponding error.
|
||||||
@@ -339,6 +186,9 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
|||||||
case ResolverTypeDOQ:
|
case ResolverTypeDOQ:
|
||||||
return &doqResolver{uc: uc}, nil
|
return &doqResolver{uc: uc}, nil
|
||||||
case ResolverTypeOS:
|
case ResolverTypeOS:
|
||||||
|
if or == nil {
|
||||||
|
or = newResolverWithNameserver(defaultNameservers())
|
||||||
|
}
|
||||||
return or, nil
|
return or, nil
|
||||||
case ResolverTypeLegacy:
|
case ResolverTypeLegacy:
|
||||||
return &legacyResolver{uc: uc}, nil
|
return &legacyResolver{uc: uc}, nil
|
||||||
@@ -351,9 +201,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type osResolver struct {
|
type osResolver struct {
|
||||||
initializedLanServers atomic.Pointer[[]string]
|
lanServers atomic.Pointer[[]string]
|
||||||
lanServers atomic.Pointer[[]string]
|
publicServers atomic.Pointer[[]string]
|
||||||
publicServers atomic.Pointer[[]string]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type osResolverResult struct {
|
type osResolverResult struct {
|
||||||
@@ -504,7 +353,10 @@ func LookupIP(domain string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||||
nss := defaultNameservers()
|
if or == nil {
|
||||||
|
or = newResolverWithNameserver(defaultNameservers())
|
||||||
|
}
|
||||||
|
nss := *or.lanServers.Load()
|
||||||
if withBootstrapDNS {
|
if withBootstrapDNS {
|
||||||
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
|
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user