mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
fix upgrade flow
set service on new run, fix duplicate args set service on new run, fix duplicate args revert startCmd in upgrade flow due to pin compat issues make restart reset DNS like upgrade, add debugging to uninstall method debugging debugging debugging debugging debugging WMI remove stackexchange lib, use ms wmi pkg debugging debugging set correct class fix os reolver init issues fix netadapter class use os resolver instead of fetching default nameservers while already running remove debug lines fix lookup IP fix lookup IP fix lookup IP fix lookup IP fix dns namserver retries when not needed
This commit is contained in:
@@ -56,10 +56,12 @@ func getActiveDirectoryDomain() (string, error) {
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
cs, err := hh.GetComputerSystem(whost)
|
||||
if cs != nil {
|
||||
defer cs.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer cs.Close()
|
||||
pod, err := cs.GetPropertyPartOfDomain()
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -126,7 +126,7 @@ func initCLI() {
|
||||
rootCmd.CompletionOptions.HiddenDefaultCmd = true
|
||||
|
||||
initRunCmd()
|
||||
startCmd, startCmdAlias := initStartCmd()
|
||||
startCmd := initStartCmd()
|
||||
stopCmd := initStopCmd()
|
||||
restartCmd := initRestartCmd()
|
||||
reloadCmd := initReloadCmd(restartCmd)
|
||||
@@ -135,7 +135,7 @@ func initCLI() {
|
||||
interfacesCmd := initInterfacesCmd()
|
||||
initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd)
|
||||
initClientsCmd()
|
||||
initUpgradeCmd(startCmdAlias)
|
||||
initUpgradeCmd()
|
||||
initLogCmd()
|
||||
}
|
||||
|
||||
@@ -243,10 +243,6 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
if err := s.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to start service")
|
||||
}
|
||||
// Configure Windows service failure actions
|
||||
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
|
||||
}
|
||||
}()
|
||||
}
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
@@ -394,6 +390,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Configure Windows service failure actions
|
||||
_ = ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||
})
|
||||
p.onStopped = append(p.onStopped, func() {
|
||||
for _, lc := range p.cfg.Listener {
|
||||
@@ -1615,22 +1613,27 @@ var errRequiredDeactivationPin = errors.New("deactivation pin is required to sto
|
||||
|
||||
// checkDeactivationPin validates if the deactivation pin matches one in ControlD config.
|
||||
func checkDeactivationPin(s service.Service, stopCh chan struct{}) error {
|
||||
mainLog.Load().Debug().Msg("Checking deactivation pin")
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
mainLog.Load().Err(err).Msg("could not check deactivation pin")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Msg("Creating control client")
|
||||
var cc *controlClient
|
||||
if s == nil {
|
||||
cc = newSocketControlClientMobile(dir, stopCh)
|
||||
} else {
|
||||
cc = newSocketControlClient(context.TODO(), s, dir)
|
||||
}
|
||||
mainLog.Load().Debug().Msg("Control client done")
|
||||
if cc == nil {
|
||||
return nil // ctrld is not running.
|
||||
}
|
||||
data, _ := json.Marshal(&deactivationRequest{Pin: deactivationPin})
|
||||
resp, _ := cc.post(deactivationPath, bytes.NewReader(data))
|
||||
mainLog.Load().Debug().Msg("Posting deactivation request")
|
||||
resp, err := cc.post(deactivationPath, bytes.NewReader(data))
|
||||
mainLog.Load().Debug().Msg("Posting deactivation request done")
|
||||
if resp != nil {
|
||||
switch resp.StatusCode {
|
||||
case http.StatusBadRequest:
|
||||
@@ -1694,7 +1697,7 @@ func curCdUID() string {
|
||||
if s, _ := newService(&prog{}, svcConfig); s != nil {
|
||||
// Configure Windows service failure actions
|
||||
if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName)
|
||||
}
|
||||
if dir, _ := socketDir(); dir != "" {
|
||||
cc := newSocketControlClient(context.TODO(), s, dir)
|
||||
@@ -1777,6 +1780,7 @@ func resetDnsNoLog(p *prog) {
|
||||
func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task {
|
||||
return task{func() error {
|
||||
if iface == "" {
|
||||
mainLog.Load().Debug().Msg("no iface, skipping resetDnsTask")
|
||||
return nil
|
||||
}
|
||||
// Always reset DNS first, ensuring DNS setting is in a good state.
|
||||
|
||||
@@ -164,7 +164,7 @@ func initRunCmd() *cobra.Command {
|
||||
return runCmd
|
||||
}
|
||||
|
||||
func initStartCmd() (*cobra.Command, *cobra.Command) {
|
||||
func initStartCmd() *cobra.Command {
|
||||
startCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
@@ -391,7 +391,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
|
||||
tasks := []task{
|
||||
{s.Stop, false, "Stop"},
|
||||
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Generate NextDNS config"},
|
||||
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"},
|
||||
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
|
||||
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||
{func() error {
|
||||
@@ -534,7 +534,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
||||
rootCmd.AddCommand(startCmdAlias)
|
||||
|
||||
return startCmd, startCmdAlias
|
||||
return startCmd
|
||||
}
|
||||
|
||||
func initStopCmd() *cobra.Command {
|
||||
@@ -647,6 +647,15 @@ func initRestartCmd() *cobra.Command {
|
||||
mainLog.Load().Warn().Msg("service not installed")
|
||||
return
|
||||
}
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
initLogging()
|
||||
|
||||
if cdMode {
|
||||
@@ -656,11 +665,53 @@ func initRestartCmd() *cobra.Command {
|
||||
if ir := runningIface(s); ir != nil {
|
||||
iface = ir.Name
|
||||
}
|
||||
tasks := []task{
|
||||
{s.Stop, false, "Stop"},
|
||||
{s.Start, true, "Start"},
|
||||
|
||||
doRestart := func() bool {
|
||||
tasks := []task{
|
||||
{s.Stop, true, "Stop"},
|
||||
{func() error {
|
||||
p.router.Cleanup()
|
||||
p.resetDNS()
|
||||
return nil
|
||||
}, false, "Cleanup"},
|
||||
{func() error {
|
||||
time.Sleep(time.Second * 1)
|
||||
return nil
|
||||
}, false, "Waiting for service to stop"},
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
|
||||
if router.WaitProcessExited() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mainLog.Load().Error().Msg("timeout while waiting for service to stop")
|
||||
break loop
|
||||
default:
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
if status, _ := s.Status(); status == service.StatusStopped {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
||||
tasks = []task{
|
||||
{s.Start, true, "Start"},
|
||||
}
|
||||
|
||||
return doTasks(tasks)
|
||||
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
|
||||
if doRestart() {
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server")
|
||||
@@ -668,11 +719,13 @@ func initRestartCmd() *cobra.Command {
|
||||
}
|
||||
cc := newSocketControlClient(context.TODO(), s, dir)
|
||||
if cc == nil {
|
||||
mainLog.Load().Notice().Msg("Service was not restarted")
|
||||
mainLog.Load().Error().Msg("Could not complete service restart")
|
||||
os.Exit(1)
|
||||
}
|
||||
_, _ = cc.post(ifacePath, nil)
|
||||
mainLog.Load().Notice().Msg("Service restarted")
|
||||
} else {
|
||||
mainLog.Load().Error().Msg("Service restart failed")
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -1049,7 +1102,7 @@ func initClientsCmd() *cobra.Command {
|
||||
return clientsCmd
|
||||
}
|
||||
|
||||
func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
|
||||
func initUpgradeCmd() *cobra.Command {
|
||||
const (
|
||||
upgradeChannelDev = "dev"
|
||||
upgradeChannelProd = "prod"
|
||||
@@ -1087,6 +1140,14 @@ func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
svcInstalled := true
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
@@ -1121,23 +1182,56 @@ func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to update current binary")
|
||||
}
|
||||
|
||||
// we run the actual commands to make sure all the logic we want is executed
|
||||
doRestart := func() bool {
|
||||
|
||||
// run the start command so that we reinit the service
|
||||
// this is to fix the non restarting options on windows for existing clients
|
||||
// we have to reset os.Args, since other commands use it.
|
||||
curCdUID := curCdUID()
|
||||
startArgs := []string{}
|
||||
os.Args = []string{"ctrld", "start"}
|
||||
if curCdUID != "" {
|
||||
startArgs = append(startArgs, fmt.Sprintf("--cd=%s", curCdUID))
|
||||
os.Args = append(os.Args, fmt.Sprintf("--cd=%s", curCdUID))
|
||||
if !svcInstalled {
|
||||
return true
|
||||
}
|
||||
startCmd.Run(startCmd, startArgs)
|
||||
tasks := []task{
|
||||
{s.Stop, true, "Stop"},
|
||||
{func() error {
|
||||
p.router.Cleanup()
|
||||
p.resetDNS()
|
||||
return nil
|
||||
}, false, "Cleanup"},
|
||||
{func() error {
|
||||
time.Sleep(time.Second * 1)
|
||||
return nil
|
||||
}, false, "Waiting for service to stop"},
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
|
||||
return true
|
||||
if router.WaitProcessExited() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mainLog.Load().Error().Msg("timeout while waiting for service to stop")
|
||||
break loop
|
||||
default:
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
if status, _ := s.Status(); status == service.StatusStopped {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tasks = []task{
|
||||
{s.Start, true, "Start"},
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
if dir, err := socketDir(); err == nil {
|
||||
if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil {
|
||||
_, _ = cc.post(ifacePath, nil)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
if svcInstalled {
|
||||
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
|
||||
|
||||
@@ -40,11 +40,13 @@ func validInterfaces() []string {
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
|
||||
return nil
|
||||
}
|
||||
defer instances.Close()
|
||||
var adapters []string
|
||||
for _, i := range instances {
|
||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||
|
||||
@@ -792,6 +792,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
|
||||
|
||||
func (p *prog) resetDNS() {
|
||||
if p.runningIface == "" {
|
||||
mainLog.Load().Debug().Msg("no running interface, skipping resetDNS")
|
||||
return
|
||||
}
|
||||
// See corresponding comments in (*prog).setDNS function.
|
||||
|
||||
@@ -52,7 +52,21 @@ func ConfigureWindowsServiceFailureActions(serviceName string) error {
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// restart 3 times with a delay of 2 seconds
|
||||
// 1. Retrieve the current config
|
||||
cfg, err := s.Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Update the Description
|
||||
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
|
||||
|
||||
// 3. Apply the updated config
|
||||
if err := s.UpdateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then proceed with existing actions, e.g. setting failure actions
|
||||
actions := []mgr.RecoveryAction{
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds
|
||||
|
||||
1
go.mod
1
go.mod
@@ -45,7 +45,6 @@ require (
|
||||
|
||||
require (
|
||||
aead.dev/minisign v0.2.0 // indirect
|
||||
github.com/StackExchange/wmi v1.2.1 // indirect
|
||||
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -42,8 +42,6 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
|
||||
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
|
||||
github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA=
|
||||
github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
|
||||
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE=
|
||||
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
|
||||
@@ -95,7 +93,6 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f
|
||||
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
|
||||
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
|
||||
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
|
||||
@@ -452,7 +449,6 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -492,8 +488,6 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/StackExchange/wmi"
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
@@ -24,9 +23,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
maxRetries = 5
|
||||
retryDelay = 1 * time.Second
|
||||
defaultTimeout = 5 * time.Second
|
||||
maxDNSAdapterRetries = 5
|
||||
retryDelayDNSAdapter = 1 * time.Second
|
||||
defaultDNSAdapterTimeout = 10 * time.Second
|
||||
minDNSServers = 1 // Minimum number of DNS servers we want to find
|
||||
NetSetupUnknown uint32 = 0
|
||||
NetSetupWorkgroup uint32 = 1
|
||||
@@ -57,19 +56,18 @@ func dnsFns() []dnsFn {
|
||||
}
|
||||
|
||||
func dnsFromAdapter() []string {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout)
|
||||
defer cancel()
|
||||
|
||||
var ns []string
|
||||
var err error
|
||||
|
||||
//load the logger
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
for i := 0; i < maxDNSAdapterRetries; i++ {
|
||||
if ctx.Err() != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
|
||||
@@ -80,12 +78,18 @@ func dnsFromAdapter() []string {
|
||||
if err == nil && len(ns) >= minDNSServers {
|
||||
if i > 0 {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns))
|
||||
"Successfully got DNS servers after %d attempts, found %d servers",
|
||||
i+1, len(ns))
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
// Log the specific failure reason
|
||||
// if osResolver is not initialized, this is likely a command line run
|
||||
// and ctrld is already on the interface, abort retries
|
||||
if or == nil {
|
||||
return ns
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get DNS servers, attempt %d: %v", i+1, err)
|
||||
@@ -97,17 +101,16 @@ func dnsFromAdapter() []string {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(retryDelay):
|
||||
case <-time.After(retryDelayDNSAdapter):
|
||||
}
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxRetries)
|
||||
return ns // Return whatever we got, even if insufficient
|
||||
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries)
|
||||
return ns
|
||||
}
|
||||
|
||||
func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
//load the logger
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
@@ -133,25 +136,18 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
var dcServers []string
|
||||
isDomain := checkDomainJoined()
|
||||
if isDomain {
|
||||
|
||||
domainName, err := getLocalADDomain()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get local AD domain: %v", err)
|
||||
|
||||
} else {
|
||||
|
||||
// Load netapi32.dll
|
||||
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
|
||||
dsDcName := netapi32.NewProc("DsGetDcNameW")
|
||||
|
||||
var info *DomainControllerInfo
|
||||
flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME)
|
||||
|
||||
flags := uint32(DS_RETURN_DNS_NAME |
|
||||
DS_IP_REQUIRED |
|
||||
DS_IS_DNS_NAME)
|
||||
|
||||
// Convert domain name to UTF16 pointer
|
||||
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
@@ -190,15 +186,12 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
} else if info != nil {
|
||||
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
|
||||
|
||||
// Get DC address
|
||||
if info.DomainControllerAddress != nil {
|
||||
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
|
||||
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Found domain controller address: %s", dcAddr)
|
||||
|
||||
// Try to resolve DC
|
||||
if ip := net.ParseIP(dcAddr); ip != nil {
|
||||
dcServers = append(dcServers, ip.String())
|
||||
Log(context.Background(), logger.Debug(),
|
||||
@@ -210,7 +203,6 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -278,28 +270,26 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
}
|
||||
|
||||
ipStr := ip.String()
|
||||
logger := logger.Debug().
|
||||
l := logger.Debug().
|
||||
Str("ip", ipStr).
|
||||
Str("adapter", aa.FriendlyName())
|
||||
|
||||
if ip.IsLoopback() {
|
||||
logger.Msg("Skipping loopback IP")
|
||||
l.Msg("Skipping loopback IP")
|
||||
continue
|
||||
}
|
||||
|
||||
if seen[ipStr] {
|
||||
logger.Msg("Skipping duplicate IP")
|
||||
l.Msg("Skipping duplicate IP")
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := addressMap[ipStr]; ok {
|
||||
logger.Msg("Skipping local interface IP")
|
||||
l.Msg("Skipping local interface IP")
|
||||
continue
|
||||
}
|
||||
|
||||
seen[ipStr] = true
|
||||
ns = append(ns, ipStr)
|
||||
logger.Msg("Added DNS server")
|
||||
l.Msg("Added DNS server")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -330,7 +320,6 @@ func nameserversFromResolvconf() []string {
|
||||
// checkDomainJoined checks if the machine is joined to an Active Directory domain
|
||||
// Returns whether it's domain joined and the domain name if available
|
||||
func checkDomainJoined() bool {
|
||||
//load the logger
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
@@ -348,9 +337,10 @@ func checkDomainJoined() bool {
|
||||
|
||||
domainName := windows.UTF16PtrToString(domain)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", domainName, status)
|
||||
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)",
|
||||
domainName, status)
|
||||
|
||||
// Consider both traditional and cloud domains as valid domain joins
|
||||
// Consider domain or cloud domain as domain-joined
|
||||
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
|
||||
@@ -362,36 +352,44 @@ func checkDomainJoined() bool {
|
||||
return isDomain
|
||||
}
|
||||
|
||||
// Win32_ComputerSystem is the minimal struct for WMI query
|
||||
type Win32_ComputerSystem struct {
|
||||
Domain string
|
||||
}
|
||||
|
||||
// getLocalADDomain tries to detect the AD domain in two ways:
|
||||
// 1. USERDNSDOMAIN env var (often set in AD logon sessions)
|
||||
// 2. WMI Win32_ComputerSystem.Domain
|
||||
// getLocalADDomain uses Microsoft's WMI wrappers (github.com/microsoft/wmi/pkg/*)
|
||||
// to query the Domain field from Win32_ComputerSystem instead of a direct go-ole call.
|
||||
func getLocalADDomain() (string, error) {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
// 1) Check environment variable
|
||||
envDomain := os.Getenv("USERDNSDOMAIN")
|
||||
if envDomain != "" {
|
||||
return strings.TrimSpace(envDomain), nil
|
||||
}
|
||||
|
||||
// 2) Check WMI (requires Windows + admin privileges or sufficient access)
|
||||
var result []Win32_ComputerSystem
|
||||
err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result)
|
||||
// 2) Query WMI via the microsoft/wmi library
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("Win32_ComputerSystem")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("WMI query failed: %v", err)
|
||||
}
|
||||
if len(result) == 0 {
|
||||
|
||||
// If no results, return an error
|
||||
if len(instances) == 0 {
|
||||
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
|
||||
}
|
||||
|
||||
domain := strings.TrimSpace(result[0].Domain)
|
||||
if domain == "" {
|
||||
// We only care about the first row
|
||||
domainVal, err := instances[0].GetProperty("Domain")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("machine does not appear to have a domain set: %v", err)
|
||||
}
|
||||
|
||||
domainName := strings.TrimSpace(fmt.Sprintf("%v", domainVal))
|
||||
if domainName == "" {
|
||||
return "", fmt.Errorf("machine does not appear to have a domain set")
|
||||
}
|
||||
return domain, nil
|
||||
return domainName, nil
|
||||
}
|
||||
|
||||
// validInterfaces returns a list of all physical interfaces.
|
||||
@@ -410,12 +408,14 @@ func validInterfaces() map[string]struct{} {
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Warn(),
|
||||
"failed to get wmi network adapter: %v", err)
|
||||
return nil
|
||||
}
|
||||
defer instances.Close()
|
||||
var adapters []string
|
||||
for _, i := range instances {
|
||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||
@@ -470,5 +470,4 @@ func validInterfaces() map[string]struct{} {
|
||||
m[ifaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
|
||||
}
|
||||
|
||||
188
resolver.go
188
resolver.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@@ -11,10 +12,9 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"io"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/zerolog"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
@@ -48,11 +48,13 @@ const (
|
||||
|
||||
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||
|
||||
// or is the Resolver used for ResolverTypeOS.
|
||||
var or = newResolverWithNameserver(defaultNameservers())
|
||||
|
||||
var localResolver = newLocalResolver()
|
||||
|
||||
var (
|
||||
resolverMutex sync.Mutex
|
||||
or *osResolver
|
||||
)
|
||||
|
||||
func newLocalResolver() Resolver {
|
||||
var nss []string
|
||||
for _, addr := range Rfc1918Addresses() {
|
||||
@@ -86,7 +88,6 @@ func availableNameservers() []string {
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
machineIPsMap := make(map[string]struct{}, len(regularIPs))
|
||||
|
||||
|
||||
//load the logger
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
@@ -129,6 +130,9 @@ func availableNameservers() []string {
|
||||
// calling this function.
|
||||
func InitializeOsResolver() []string {
|
||||
ns := initializeOsResolver(availableNameservers())
|
||||
resolverMutex.Lock()
|
||||
defer resolverMutex.Unlock()
|
||||
or = newResolverWithNameserver(ns)
|
||||
return ns
|
||||
}
|
||||
|
||||
@@ -138,6 +142,7 @@ func InitializeOsResolver() []string {
|
||||
// - First available LAN servers are saved and store.
|
||||
// - Later calls, if no LAN servers available, the saved servers above will be used.
|
||||
func initializeOsResolver(servers []string) []string {
|
||||
|
||||
var lanNss, publicNss []string
|
||||
|
||||
// First categorize servers
|
||||
@@ -154,171 +159,13 @@ func initializeOsResolver(servers []string) []string {
|
||||
}
|
||||
}
|
||||
|
||||
// Store initial servers immediately
|
||||
if len(lanNss) > 0 {
|
||||
or.initializedLanServers.CompareAndSwap(nil, &lanNss)
|
||||
or.lanServers.Store(&lanNss)
|
||||
}
|
||||
|
||||
if len(publicNss) == 0 {
|
||||
publicNss = []string{controldPublicDnsWithPort}
|
||||
}
|
||||
or.publicServers.Store(&publicNss)
|
||||
|
||||
// no longer testing servers in the background
|
||||
// if DCHP nameservers are not working, this is outside of our control
|
||||
|
||||
// // Test servers in background and remove failures
|
||||
// go func() {
|
||||
// // Test servers in parallel but maintain order
|
||||
// type result struct {
|
||||
// index int
|
||||
// server string
|
||||
// valid bool
|
||||
// }
|
||||
|
||||
// testServers := func(servers []string) []string {
|
||||
// if len(servers) == 0 {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// results := make(chan result, len(servers))
|
||||
// var wg sync.WaitGroup
|
||||
|
||||
// for i, server := range servers {
|
||||
// wg.Add(1)
|
||||
// go func(idx int, s string) {
|
||||
// defer wg.Done()
|
||||
// results <- result{
|
||||
// index: idx,
|
||||
// server: s,
|
||||
// valid: testNameServerFn(s),
|
||||
// }
|
||||
// }(i, server)
|
||||
// }
|
||||
|
||||
// go func() {
|
||||
// wg.Wait()
|
||||
// close(results)
|
||||
// }()
|
||||
|
||||
// // Collect results maintaining original order
|
||||
// validServers := make([]string, 0, len(servers))
|
||||
// ordered := make([]result, 0, len(servers))
|
||||
// for r := range results {
|
||||
// ordered = append(ordered, r)
|
||||
// }
|
||||
// slices.SortFunc(ordered, func(a, b result) int {
|
||||
// return a.index - b.index
|
||||
// })
|
||||
// for _, r := range ordered {
|
||||
// if r.valid {
|
||||
// validServers = append(validServers, r.server)
|
||||
// } else {
|
||||
// ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing")
|
||||
// }
|
||||
// }
|
||||
// return validServers
|
||||
// }
|
||||
|
||||
// // Test and update LAN servers
|
||||
// if validLanNss := testServers(lanNss); len(validLanNss) > 0 {
|
||||
// or.lanServers.Store(&validLanNss)
|
||||
// }
|
||||
|
||||
// // Test and update public servers
|
||||
// validPublicNss := testServers(publicNss)
|
||||
// if len(validPublicNss) == 0 {
|
||||
// validPublicNss = []string{controldPublicDnsWithPort}
|
||||
// }
|
||||
// or.publicServers.Store(&validPublicNss)
|
||||
// }()
|
||||
|
||||
return slices.Concat(lanNss, publicNss)
|
||||
}
|
||||
|
||||
// // testNameserverFn sends a test query to DNS nameserver to check if the server is available.
|
||||
// var testNameServerFn = testNameserver
|
||||
|
||||
// // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
|
||||
// func testNameserver(addr string) bool {
|
||||
// // Skip link-local addresses without scope IDs and deprecated site-local addresses
|
||||
// if ip, err := netip.ParseAddr(addr); err == nil {
|
||||
// if ip.Is6() {
|
||||
// if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") {
|
||||
// ProxyLogger.Load().Debug().
|
||||
// Str("nameserver", addr).
|
||||
// Msg("skipping link-local IPv6 address without scope ID")
|
||||
// return false
|
||||
// }
|
||||
// // Skip deprecated site-local addresses (fec0::/10)
|
||||
// if strings.HasPrefix(ip.String(), "fec0:") {
|
||||
// ProxyLogger.Load().Debug().
|
||||
// Str("nameserver", addr).
|
||||
// Msg("skipping deprecated site-local IPv6 address")
|
||||
// return false
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// ProxyLogger.Load().Debug().
|
||||
// Str("input_addr", addr).
|
||||
// Msg("testing nameserver")
|
||||
|
||||
// // Handle both IPv4 and IPv6 addresses
|
||||
// serverAddr := addr
|
||||
// host, port, err := net.SplitHostPort(addr)
|
||||
// if err != nil {
|
||||
// // No port in address, add default port 53
|
||||
// serverAddr = net.JoinHostPort(addr, "53")
|
||||
// } else if port == "" {
|
||||
// // Has split markers but empty port
|
||||
// serverAddr = net.JoinHostPort(host, "53")
|
||||
// }
|
||||
|
||||
// ProxyLogger.Load().Debug().
|
||||
// Str("server_addr", serverAddr).
|
||||
// Msg("using server address")
|
||||
|
||||
// // Test domains that are likely to exist and respond quickly
|
||||
// testDomains := []struct {
|
||||
// name string
|
||||
// qtype uint16
|
||||
// }{
|
||||
// {".", dns.TypeNS}, // Root NS query - should always work
|
||||
// {"controld.com.", dns.TypeA}, // Fallback to a reliable domain
|
||||
// }
|
||||
|
||||
// client := &dns.Client{
|
||||
// Timeout: 2 * time.Second,
|
||||
// Net: "udp",
|
||||
// }
|
||||
|
||||
// // Try each test query until one succeeds
|
||||
// for _, test := range testDomains {
|
||||
// msg := new(dns.Msg)
|
||||
// msg.SetQuestion(test.name, test.qtype)
|
||||
// msg.RecursionDesired = true
|
||||
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
// resp, _, err := client.ExchangeContext(ctx, msg, serverAddr)
|
||||
// cancel()
|
||||
|
||||
// if err == nil && resp != nil {
|
||||
// return true
|
||||
// }
|
||||
|
||||
// ProxyLogger.Load().Error().
|
||||
// Err(err).
|
||||
// Str("nameserver", serverAddr).
|
||||
// Str("test_domain", test.name).
|
||||
// Str("query_type", dns.TypeToString[test.qtype]).
|
||||
// Msg("DNS availability test failed")
|
||||
// }
|
||||
|
||||
// return false
|
||||
// }
|
||||
|
||||
// Resolver is the interface that wraps the basic DNS operations.
|
||||
//
|
||||
// Resolve resolves the DNS query, return the result and the corresponding error.
|
||||
@@ -339,6 +186,9 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
case ResolverTypeDOQ:
|
||||
return &doqResolver{uc: uc}, nil
|
||||
case ResolverTypeOS:
|
||||
if or == nil {
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
}
|
||||
return or, nil
|
||||
case ResolverTypeLegacy:
|
||||
return &legacyResolver{uc: uc}, nil
|
||||
@@ -351,9 +201,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
}
|
||||
|
||||
type osResolver struct {
|
||||
initializedLanServers atomic.Pointer[[]string]
|
||||
lanServers atomic.Pointer[[]string]
|
||||
publicServers atomic.Pointer[[]string]
|
||||
lanServers atomic.Pointer[[]string]
|
||||
publicServers atomic.Pointer[[]string]
|
||||
}
|
||||
|
||||
type osResolverResult struct {
|
||||
@@ -504,7 +353,10 @@ func LookupIP(domain string) []string {
|
||||
}
|
||||
|
||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||
nss := defaultNameservers()
|
||||
if or == nil {
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
}
|
||||
nss := *or.lanServers.Load()
|
||||
if withBootstrapDNS {
|
||||
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user