diff --git a/cmd/cli/ad_others.go b/cmd/cli/ad_others.go index 6a7417f..b23476f 100644 --- a/cmd/cli/ad_others.go +++ b/cmd/cli/ad_others.go @@ -8,3 +8,8 @@ import ( // addExtraSplitDnsRule adds split DNS rule if present. func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false } + +// getActiveDirectoryDomain returns AD domain name of this computer. +func getActiveDirectoryDomain() (string, error) { + return "", nil +} diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go index d7374d0..66180a9 100644 --- a/cmd/cli/ad_windows.go +++ b/cmd/cli/ad_windows.go @@ -1,9 +1,14 @@ package cli import ( - "fmt" + "io" + "log" + "os" "strings" + "github.com/microsoft/wmi/pkg/base/host" + hh "github.com/microsoft/wmi/pkg/hardware/host" + "github.com/Control-D-Inc/ctrld" ) @@ -21,29 +26,48 @@ func addExtraSplitDnsRule(cfg *ctrld.Config) bool { // Network rules are lowercase during toml config marshaling, // lowercase the domain here too for consistency. domain = strings.ToLower(domain) + domainRuleAdded := addSplitDnsRule(cfg, domain) + wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, ".")) + return domainRuleAdded || wildcardDomainRuleRuleAdded +} + +// addSplitDnsRule adds split-rule for given domain if there's no existed rule. +// The return value indicates whether the split-rule was added or not. +func addSplitDnsRule(cfg *ctrld.Config, domain string) bool { for n, lc := range cfg.Listener { if lc.Policy == nil { lc.Policy = &ctrld.ListenerPolicyConfig{} } - domainRule := "*." + strings.TrimPrefix(domain, ".") for _, rule := range lc.Policy.Rules { - if _, ok := rule[domainRule]; ok { - mainLog.Load().Debug().Msgf("domain rule already exist for listener.%s", n) + if _, ok := rule[domain]; ok { + mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n) return false } } - mainLog.Load().Debug().Msgf("adding active directory domain for listener.%s", n) - lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domainRule: []string{}}) + mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n) + lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}}) } return true } // getActiveDirectoryDomain returns AD domain name of this computer. func getActiveDirectoryDomain() (string, error) { - cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }" - output, err := powershell(cmd) - if err != nil { - return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output)) + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + whost := host.NewWmiLocalHost() + cs, err := hh.GetComputerSystem(whost) + if cs != nil { + defer cs.Close() } - return string(output), nil + if err != nil { + return "", err + } + pod, err := cs.GetPropertyPartOfDomain() + if err != nil { + return "", err + } + if pod { + return cs.GetPropertyDomain() + } + return "", nil } diff --git a/cmd/cli/ad_windows_test.go b/cmd/cli/ad_windows_test.go new file mode 100644 index 0000000..6abd25f --- /dev/null +++ b/cmd/cli/ad_windows_test.go @@ -0,0 +1,71 @@ +package cli + +import ( + "fmt" + "testing" + "time" + + "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/testhelper" + "github.com/stretchr/testify/assert" +) + +func Test_getActiveDirectoryDomain(t *testing.T) { + start := time.Now() + domain, err := getActiveDirectoryDomain() + if err != nil { + t.Fatal(err) + } + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + domainPowershell, err := getActiveDirectoryDomainPowershell() + if err != nil { + t.Fatal(err) + } + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + if domain != domainPowershell { + t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain) + } +} + +func getActiveDirectoryDomainPowershell() (string, error) { + cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }" + output, err := powershell(cmd) + if err != nil { + return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output)) + } + return string(output), nil +} + +func Test_addSplitDnsRule(t *testing.T) { + newCfg := func(domains ...string) *ctrld.Config { + cfg := testhelper.SampleConfig(t) + lc := cfg.Listener["0"] + for _, domain := range domains { + lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}}) + } + return cfg + } + tests := []struct { + name string + cfg *ctrld.Config + domain string + added bool + }{ + {"added", newCfg(), "example.com", true}, + {"TLD existed", newCfg("example.com"), "*.example.com", true}, + {"wildcard existed", newCfg("*.example.com"), "example.com", true}, + {"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false}, + {"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + added := addSplitDnsRule(tc.cfg, tc.domain) + assert.Equal(t, tc.added, added) + }) + } +} diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 502014e..9d01206 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -30,18 +30,14 @@ import ( "github.com/go-playground/validator/v10" "github.com/kardianos/service" "github.com/miekg/dns" - "github.com/minio/selfupdate" - "github.com/olekukonko/tablewriter" "github.com/pelletier/go-toml/v2" "github.com/rs/zerolog" "github.com/spf13/cobra" - "github.com/spf13/pflag" "github.com/spf13/viper" "tailscale.com/logtail/backoff" "tailscale.com/net/netmon" "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/clientinfo" "github.com/Control-D-Inc/ctrld/internal/controld" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" "github.com/Control-D-Inc/ctrld/internal/router" @@ -129,977 +125,18 @@ func initCLI() { rootCmd.SetHelpCommand(&cobra.Command{Hidden: true}) rootCmd.CompletionOptions.HiddenDefaultCmd = true - runCmd := &cobra.Command{ - Use: "run", - Short: "Run the DNS proxy server", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - RunCobraCommand(cmd) - }, - } - runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon") - runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") - runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") - runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") - runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") - runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") - runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") - runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") - runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") - runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") - runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") - runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") - runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") - _ = runCmd.Flags().MarkHidden("dev") - runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") - _ = runCmd.Flags().MarkHidden("homedir") - runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) - _ = runCmd.Flags().MarkHidden("iface") - runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) - - runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true} - rootCmd.AddCommand(runCmd) - - startCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "start", - Short: "Install and start the ctrld service", - Long: `Install and start the ctrld service - -NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - checkStrFlagEmpty(cmd, cdUidFlagName) - checkStrFlagEmpty(cmd, cdOrgFlagName) - validateCdAndNextDNSFlags() - sc := &service.Config{} - *sc = *svcConfig - osArgs := os.Args[2:] - if os.Args[1] == "service" { - osArgs = os.Args[3:] - } - setDependencies(sc) - sc.Arguments = append([]string{"run"}, osArgs...) - - p := &prog{ - router: router.New(&cfg, cdUID != ""), - cfg: &cfg, - } - s, err := newService(p, sc) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - - status, err := s.Status() - isCtrldRunning := status == service.StatusRunning - isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) - - // Get current running iface, if any. - var currentIface string - - // If pin code was set, do not allow running start command. - if isCtrldRunning { - if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { - os.Exit(deactivationPinInvalidExitCode) - } - currentIface = runningIface(s) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - reportSetDnsOk := func(sockDir string) { - if cc := newSocketControlClient(ctx, s, sockDir); cc != nil { - if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK { - if iface == "auto" { - iface = defaultIfaceName() - } - logger := mainLog.Load().With().Str("iface", iface).Logger() - logger.Debug().Msg("setting DNS successfully") - } - } - } - - // No config path, generating config in HOME directory. - noConfigStart := isNoConfigStart(cmd) - writeDefaultConfig := !noConfigStart && configBase64 == "" - - logServerStarted := make(chan struct{}) - // A buffer channel to gather log output from runCmd and report - // to user in case self-check process failed. - runCmdLogCh := make(chan string, 256) - ud, err := userHomeDir() - sockDir := ud - if err != nil { - mainLog.Load().Warn().Msg("log server did not start") - close(logServerStarted) - } else { - setWorkingDirectory(sc, ud) - if configPath == "" && writeDefaultConfig { - defaultConfigFile = filepath.Join(ud, defaultConfigFile) - } - sc.Arguments = append(sc.Arguments, "--homedir="+ud) - if d, err := socketDir(); err == nil { - sockDir = d - } - sockPath := filepath.Join(sockDir, ctrldLogUnixSock) - _ = os.Remove(sockPath) - go func() { - defer func() { - close(runCmdLogCh) - _ = os.Remove(sockPath) - }() - close(logServerStarted) - if conn := runLogServer(sockPath); conn != nil { - // Enough buffer for log message, we don't produce - // such long log message, but just in case. - buf := make([]byte, 1024) - for { - n, err := conn.Read(buf) - if err != nil { - return - } - msg := string(buf[:n]) - if _, _, found := strings.Cut(msg, msgExit); found { - cancel() - } - runCmdLogCh <- msg - } - } - }() - } - <-logServerStarted - - if !startOnly { - startOnly = len(osArgs) == 0 - } - // If user run "ctrld start" and ctrld is already installed, starting existing service. - if startOnly && isCtrldInstalled { - tryReadingConfigWithNotice(false, true) - if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) - } - - initLogging() - tasks := []task{ - {s.Stop, false}, - resetDnsTask(p, s, isCtrldInstalled, currentIface), - {func() error { - // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { - if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { - return err - } - return nil - }) - return nil - }, false}, - {s.Start, true}, - {noticeWritingControlDConfig, false}, - } - mainLog.Load().Notice().Msg("Starting existing ctrld service") - if doTasks(tasks) { - mainLog.Load().Notice().Msg("Service started") - sockDir, err := socketDir() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Failed to get socket directory") - os.Exit(1) - } - reportSetDnsOk(sockDir) - } else { - mainLog.Load().Error().Err(err).Msg("Failed to start existing ctrld service") - os.Exit(1) - } - return - } - - if cdUID != "" { - doValidateCdRemoteConfig(cdUID) - } else if uid := cdUIDFromProvToken(); uid != "" { - cdUID = uid - mainLog.Load().Debug().Msg("using uid from provision token") - removeOrgFlagsFromArgs(sc) - // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. - sc.Arguments = append(sc.Arguments, "--cd="+cdUID) - } - if cdUID != "" { - validateCdUpstreamProtocol() - } - - if err := p.router.ConfigureService(sc); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to configure service on router") - } - - if configPath != "" { - v.SetConfigFile(configPath) - } - - tryReadingConfigWithNotice(writeDefaultConfig, true) - - if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) - } - - initLogging() - - if nextdns != "" { - removeNextDNSFromArgs(sc) - } - - // Explicitly passing config, so on system where home directory could not be obtained, - // or sub-process env is different with the parent, we still behave correctly and use - // the expected config file. - if configPath == "" { - sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile) - } - - if router.Name() != "" && iface != "" { - mainLog.Load().Debug().Msg("cleaning up router before installing") - _ = p.router.Cleanup() - } - - tasks := []task{ - {s.Stop, false}, - {func() error { return doGenerateNextDNSConfig(nextdns) }, true}, - {func() error { return ensureUninstall(s) }, false}, - resetDnsTask(p, s, isCtrldInstalled, currentIface), - {func() error { - // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { - if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { - return err - } - return nil - }) - return nil - }, false}, - {s.Install, false}, - {s.Start, true}, - // Note that startCmd do not actually write ControlD config, but the config file was - // generated after s.Start, so we notice users here for consistent with nextdns mode. - {noticeWritingControlDConfig, false}, - } - mainLog.Load().Notice().Msg("Starting service") - if doTasks(tasks) { - if err := p.router.Install(sc); err != nil { - mainLog.Load().Warn().Err(err).Msg("post installation failed, please check system/service log for details error") - return - } - - ok, status, err := selfCheckStatus(ctx, s, sockDir) - switch { - case ok && status == service.StatusRunning: - mainLog.Load().Notice().Msg("Service started") - default: - marker := bytes.Repeat([]byte("="), 32) - // If ctrld service is not running, emitting log obtained from ctrld process. - if status != service.StatusRunning || ctx.Err() != nil { - mainLog.Load().Error().Msg("ctrld service may not have started due to an error or misconfiguration, service log:") - _, _ = mainLog.Load().Write(marker) - haveLog := false - for msg := range runCmdLogCh { - _, _ = mainLog.Load().Write([]byte(strings.ReplaceAll(msg, msgExit, ""))) - haveLog = true - } - // If we're unable to get log from "ctrld run", notice users about it. - if !haveLog { - mainLog.Load().Write([]byte(`"`)) - } - } - // Report any error if occurred. - if err != nil { - _, _ = mainLog.Load().Write(marker) - msg := fmt.Sprintf("An error occurred while performing test query: %s", err) - mainLog.Load().Write([]byte(msg)) - } - // If ctrld service is running but selfCheckStatus failed, it could be related - // to user's system firewall configuration, notice users about it. - if status == service.StatusRunning && err == nil { - _, _ = mainLog.Load().Write(marker) - mainLog.Load().Write([]byte(`ctrld service was running, but a DNS query could not be sent to its listener`)) - mainLog.Load().Write([]byte(`Please check your system firewall if it is configured to block/intercept/redirect DNS queries`)) - } - - _, _ = mainLog.Load().Write(marker) - uninstall(p, s) - os.Exit(1) - } - reportSetDnsOk(sockDir) - } - }, - } - // Keep these flags in sync with runCmd above, except for "-d"/"--nextdns". - startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") - startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") - startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") - startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") - startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") - startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") - startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") - startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") - startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") - startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") - startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") - startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") - _ = startCmd.Flags().MarkHidden("dev") - startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) - startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") - startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) - startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`) - startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service") - _ = startCmd.Flags().MarkHidden("start_only") - - routerCmd := &cobra.Command{ - Use: "setup", - Run: func(cmd *cobra.Command, _ []string) { - exe, err := os.Executable() - if err != nil { - mainLog.Load().Fatal().Msgf("could not find executable path: %v", err) - os.Exit(1) - } - flags := make([]string, 0) - cmd.Flags().Visit(func(flag *pflag.Flag) { - flags = append(flags, fmt.Sprintf("--%s=%s", flag.Name, flag.Value)) - }) - cmdArgs := []string{"start"} - cmdArgs = append(cmdArgs, flags...) - command := exec.Command(exe, cmdArgs...) - command.Stdout = os.Stdout - command.Stderr = os.Stderr - command.Stdin = os.Stdin - if err := command.Run(); err != nil { - mainLog.Load().Fatal().Msg(err.Error()) - } - }, - } - routerCmd.Flags().AddFlagSet(startCmd.Flags()) - routerCmd.Hidden = true - rootCmd.AddCommand(routerCmd) - - stopCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "stop", - Short: "Stop the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - readConfig(false) - v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, runInCdMode())} - s, err := newService(p, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - initLogging() - if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { - os.Exit(deactivationPinInvalidExitCode) - } - if doTasks([]task{{s.Stop, true}}) { - p.router.Cleanup() - p.resetDNS() - if router.WaitProcessExited() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - for { - select { - case <-ctx.Done(): - mainLog.Load().Error().Msg("timeout while waiting for service to stop") - return - default: - } - time.Sleep(time.Second) - if status, _ := s.Status(); status == service.StatusStopped { - break - } - } - } - mainLog.Load().Notice().Msg("Service stopped") - } - }, - } - stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`) - stopCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`) - _ = stopCmd.Flags().MarkHidden("pin") - - restartCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "restart", - Short: "Restart the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - readConfig(false) - v.Unmarshal(&cfg) - cdUID = curCdUID() - cdMode := cdUID != "" - - p := &prog{router: router.New(&cfg, cdMode)} - s, err := newService(p, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("service not installed") - return - } - initLogging() - - if cdMode { - doValidateCdRemoteConfig(cdUID) - } - - iface = runningIface(s) - tasks := []task{ - {s.Stop, false}, - {s.Start, true}, - } - if doTasks(tasks) { - dir, err := socketDir() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server") - return - } - cc := newSocketControlClient(context.TODO(), s, dir) - if cc == nil { - mainLog.Load().Notice().Msg("Service was not restarted") - os.Exit(1) - } - _, _ = cc.post(ifacePath, nil) - mainLog.Load().Notice().Msg("Service restarted") - } - }, - } - - reloadCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "reload", - Short: "Reload the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - dir, err := socketDir() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") - } - cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) - resp, err := cc.post(reloadPath, nil) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld") - } - defer resp.Body.Close() - switch resp.StatusCode { - case http.StatusOK: - mainLog.Load().Notice().Msg("Service reloaded") - case http.StatusCreated: - s, err := newService(&prog{}, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.") - mainLog.Load().Warn().Msg("Restarting service") - if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("Service not installed") - return - } - restartCmd.Run(cmd, args) - default: - buf, err := io.ReadAll(resp.Body) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("could not read response from control server") - } - mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf)) - } - }, - } - statusCmd := &cobra.Command{ - Use: "status", - Short: "Show status of the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - s, err := newService(&prog{}, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - status, err := s.Status() - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - os.Exit(1) - } - switch status { - case service.StatusUnknown: - mainLog.Load().Notice().Msg("Unknown status") - os.Exit(2) - case service.StatusRunning: - mainLog.Load().Notice().Msg("Service is running") - os.Exit(0) - case service.StatusStopped: - mainLog.Load().Notice().Msg("Service is stopped") - os.Exit(1) - } - }, - } - if runtime.GOOS == "darwin" { - // On darwin, running status command without privileges may return wrong information. - statusCmd.PreRun = func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - } - } - - uninstallCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "uninstall", - Short: "Stop and uninstall the ctrld service", - Long: `Stop and uninstall the ctrld service. - -NOTE: Uninstalling will set DNS to values provided by DHCP.`, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - readConfig(false) - v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, runInCdMode())} - s, err := newService(p, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - if iface == "" { - iface = "auto" - } - if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { - os.Exit(deactivationPinInvalidExitCode) - } - uninstall(p, s) - if cleanup { - var files []string - // Config file. - files = append(files, v.ConfigFileUsed()) - // Log file and backup log file. - // For safety, only process if log file path is absolute. - if logFile := normalizeLogFilePath(cfg.Service.LogPath); filepath.IsAbs(logFile) { - files = append(files, logFile) - oldLogFile := logFile + oldLogSuffix - if _, err := os.Stat(oldLogFile); err == nil { - files = append(files, oldLogFile) - } - } - // Socket files. - if dir, _ := socketDir(); dir != "" { - files = append(files, filepath.Join(dir, ctrldControlUnixSock)) - files = append(files, filepath.Join(dir, ctrldLogUnixSock)) - } - // Static DNS settings files. - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { - file := savedStaticDnsSettingsFilePath(i) - if _, err := os.Stat(file); err == nil { - files = append(files, file) - } - return nil - }) - // Windows forwarders file. - if windowsHasLocalDnsServerRunning() { - files = append(files, absHomeDir(windowsForwardersFilename)) - } - // Binary itself. - bin, _ := os.Executable() - if bin != "" && supportedSelfDelete { - files = append(files, bin) - } - // Backup file after upgrading. - oldBin := bin + oldBinSuffix - if _, err := os.Stat(oldBin); err == nil { - files = append(files, oldBin) - } - for _, file := range files { - if file == "" { - continue - } - if err := os.Remove(file); err != nil { - if os.IsNotExist(err) { - continue - } - mainLog.Load().Warn().Err(err).Msg("failed to remove file") - } else { - mainLog.Load().Debug().Msgf("file removed: %s", file) - } - } - if err := selfDeleteExe(); err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to remove file") - } else { - if !supportedSelfDelete { - mainLog.Load().Debug().Msgf("file removed: %s", bin) - } - } - } - }, - } - uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`) - uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for uninstalling ctrld`) - _ = uninstallCmd.Flags().MarkHidden("pin") - uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`) - - listIfacesCmd := &cobra.Command{ - Use: "list", - Short: "List network interfaces of the host", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - err := netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { - fmt.Printf("Index : %d\n", i.Index) - fmt.Printf("Name : %s\n", i.Name) - addrs, _ := i.Addrs() - for i, ipaddr := range addrs { - if i == 0 { - fmt.Printf("Addrs : %v\n", ipaddr) - continue - } - fmt.Printf(" %v\n", ipaddr) - } - for i, dns := range currentDNS(i.Interface) { - if i == 0 { - fmt.Printf("DNS : %s\n", dns) - continue - } - fmt.Printf(" : %s\n", dns) - } - println() - }) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - } - }, - } - interfacesCmd := &cobra.Command{ - Use: "interfaces", - Short: "Manage network interfaces", - Args: cobra.OnlyValidArgs, - ValidArgs: []string{ - listIfacesCmd.Use, - }, - } - interfacesCmd.AddCommand(listIfacesCmd) - - serviceCmd := &cobra.Command{ - Use: "service", - Short: "Manage ctrld service", - Args: cobra.OnlyValidArgs, - ValidArgs: []string{ - startCmd.Use, - stopCmd.Use, - restartCmd.Use, - reloadCmd.Use, - statusCmd.Use, - uninstallCmd.Use, - interfacesCmd.Use, - }, - } - serviceCmd.AddCommand(startCmd) - serviceCmd.AddCommand(stopCmd) - serviceCmd.AddCommand(restartCmd) - serviceCmd.AddCommand(reloadCmd) - serviceCmd.AddCommand(statusCmd) - serviceCmd.AddCommand(uninstallCmd) - serviceCmd.AddCommand(interfacesCmd) - rootCmd.AddCommand(serviceCmd) - startCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "start", - Short: "Quick start service and configure DNS on interface", - Long: `Quick start service and configure DNS on interface - -NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, - Run: func(cmd *cobra.Command, args []string) { - if len(os.Args) == 2 { - startOnly = true - } - if !cmd.Flags().Changed("iface") { - os.Args = append(os.Args, "--iface="+ifaceStartStop) - } - iface = ifaceStartStop - startCmd.Run(cmd, args) - }, - } - startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`) - startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) - rootCmd.AddCommand(startCmdAlias) - stopCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "stop", - Short: "Quick stop service and remove DNS from interface", - Run: func(cmd *cobra.Command, args []string) { - if !cmd.Flags().Changed("iface") { - os.Args = append(os.Args, "--iface="+ifaceStartStop) - } - iface = ifaceStartStop - stopCmd.Run(cmd, args) - }, - } - stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) - stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags()) - rootCmd.AddCommand(stopCmdAlias) - - restartCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "restart", - Short: "Restart the ctrld service", - Run: func(cmd *cobra.Command, args []string) { - restartCmd.Run(cmd, args) - }, - } - rootCmd.AddCommand(restartCmdAlias) - - reloadCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "reload", - Short: "Reload the ctrld service", - Run: func(cmd *cobra.Command, args []string) { - reloadCmd.Run(cmd, args) - }, - } - rootCmd.AddCommand(reloadCmdAlias) - - statusCmdAlias := &cobra.Command{ - Use: "status", - Short: "Show status of the ctrld service", - Args: cobra.NoArgs, - Run: statusCmd.Run, - } - rootCmd.AddCommand(statusCmdAlias) - - uninstallCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "uninstall", - Short: "Stop and uninstall the ctrld service", - Long: `Stop and uninstall the ctrld service. - -NOTE: Uninstalling will set DNS to values provided by DHCP.`, - Run: func(cmd *cobra.Command, args []string) { - if !cmd.Flags().Changed("iface") { - os.Args = append(os.Args, "--iface="+ifaceStartStop) - } - iface = ifaceStartStop - uninstallCmd.Run(cmd, args) - }, - } - uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) - uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags()) - rootCmd.AddCommand(uninstallCmdAlias) - - listClientsCmd := &cobra.Command{ - Use: "list", - Short: "List clients that ctrld discovered", - Args: cobra.NoArgs, - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Run: func(cmd *cobra.Command, args []string) { - dir, err := socketDir() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") - } - cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) - resp, err := cc.post(listClientsPath, nil) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to get clients list") - } - defer resp.Body.Close() - - var clients []*clientinfo.Client - if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to decode clients list result") - } - map2Slice := func(m map[string]struct{}) []string { - s := make([]string, 0, len(m)) - for k := range m { - if k == "" { // skip empty source from output. - continue - } - s = append(s, k) - } - sort.Strings(s) - return s - } - // If metrics is enabled, server set this for all clients, so we can check only the first one. - // Ideally, we may have a field in response to indicate that query count should be shown, but - // it would break earlier version of ctrld, which only look list of clients in response. - withQueryCount := len(clients) > 0 && clients[0].IncludeQueryCount - data := make([][]string, len(clients)) - for i, c := range clients { - row := []string{ - c.IP.String(), - c.Hostname, - c.Mac, - strings.Join(map2Slice(c.Source), ","), - } - if withQueryCount { - row = append(row, strconv.FormatInt(c.QueryCount, 10)) - } - data[i] = row - } - table := tablewriter.NewWriter(os.Stdout) - headers := []string{"IP", "Hostname", "Mac", "Discovered"} - if withQueryCount { - headers = append(headers, "Queries") - } - table.SetHeader(headers) - table.SetAutoFormatHeaders(false) - table.AppendBulk(data) - table.Render() - }, - } - clientsCmd := &cobra.Command{ - Use: "clients", - Short: "Manage clients", - Args: cobra.OnlyValidArgs, - ValidArgs: []string{ - listClientsCmd.Use, - }, - } - clientsCmd.AddCommand(listClientsCmd) - rootCmd.AddCommand(clientsCmd) - - const ( - upgradeChannelDev = "dev" - upgradeChannelProd = "prod" - upgradeChannelDefault = "default" - ) - upgradeChannel := map[string]string{ - upgradeChannelDefault: "https://dl.controld.dev", - upgradeChannelDev: "https://dl.controld.dev", - upgradeChannelProd: "https://dl.controld.com", - } - if isStableVersion(curVersion()) { - upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd] - } - upgradeCmd := &cobra.Command{ - Use: "upgrade", - Short: "Upgrading ctrld to latest version", - ValidArgs: []string{upgradeChannelDev, upgradeChannelProd}, - Args: cobra.MaximumNArgs(1), - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Run: func(cmd *cobra.Command, args []string) { - bin, err := os.Executable() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to get current ctrld binary path") - } - sc := &service.Config{} - *sc = *svcConfig - sc.Executable = bin - readConfig(false) - v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, runInCdMode())} - s, err := newService(p, sc) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - - svcInstalled := true - if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { - svcInstalled = false - } - oldBin := bin + oldBinSuffix - baseUrl := upgradeChannel[upgradeChannelDefault] - if len(args) > 0 { - channel := args[0] - switch channel { - case upgradeChannelProd, upgradeChannelDev: // ok - default: - mainLog.Load().Fatal().Msgf("uprade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev) - } - baseUrl = upgradeChannel[channel] - } - dlUrl := upgradeUrl(baseUrl) - mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl) - resp, err := http.Get(dlUrl) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to download binary") - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - mainLog.Load().Fatal().Msgf("could not download binary: %s", http.StatusText(resp.StatusCode)) - } - mainLog.Load().Debug().Msg("Updating current binary") - if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil { - if rerr := selfupdate.RollbackError(err); rerr != nil { - mainLog.Load().Error().Err(rerr).Msg("could not rollback old binary") - } - mainLog.Load().Fatal().Err(err).Msg("failed to update current binary") - } - - doRestart := func() bool { - if !svcInstalled { - return true - } - tasks := []task{ - {s.Stop, false}, - {s.Start, false}, - } - if doTasks(tasks) { - if dir, err := socketDir(); err == nil { - if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil { - _, _ = cc.post(ifacePath, nil) - return true - } - } - } - return false - } - if svcInstalled { - mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") - } - if doRestart() { - _ = os.Remove(oldBin) - _ = os.Chmod(bin, 0755) - ver := "unknown version" - out, err := exec.Command(bin, "--version").CombinedOutput() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version") - } - if after, found := strings.CutPrefix(string(out), "ctrld version "); found { - ver = after - } - mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver) - return - } - - mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin) - if err := os.Remove(bin); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to remove new binary") - } - if err := os.Rename(oldBin, bin); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to restore old binary") - } - if doRestart() { - mainLog.Load().Notice().Msg("Restored previous binary successfully") - return - } - }, - } - rootCmd.AddCommand(upgradeCmd) + initRunCmd() + startCmd := initStartCmd() + stopCmd := initStopCmd() + restartCmd := initRestartCmd() + reloadCmd := initReloadCmd(restartCmd) + statusCmd := initStatusCmd() + uninstallCmd := initUninstallCmd() + interfacesCmd := initInterfacesCmd() + initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd) + initClientsCmd() + initUpgradeCmd() + initLogCmd() } // isMobile reports whether the current OS is a mobile platform. @@ -1229,7 +266,10 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // Log config do not have thing to validate, so it's safe to init log here, // so it's able to log information in processCDFlags. - initLogging() + logWriters := initLogging() + + // Initializing internal logging after global logging. + p.initInternalLogging(logWriters) mainLog.Load().Info().Msgf("starting ctrld %s", curVersion()) mainLog.Load().Info().Msgf("os: %s", osVersion()) @@ -1261,7 +301,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } if cdUID != "" { validateCdUpstreamProtocol() - if err := processCDFlags(&cfg); err != nil { + if rc, err := processCDFlags(&cfg); err != nil { if isMobile() { appCallback.Exit(err.Error()) return @@ -1269,12 +309,16 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() // Performs self-uninstallation if the ControlD device does not exist. - var uer *controld.UtilityErrorResponse + var uer *controld.ErrorResponse if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { _ = uninstallInvalidCdUID(p, cdLogger, false) } notifyExitToLogServer() cdLogger.Fatal().Err(err).Msg("failed to fetch resolver config") + } else { + p.mu.Lock() + p.rc = rc + p.mu.Unlock() } } @@ -1346,6 +390,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } } } + // Configure Windows service failure actions + _ = ConfigureWindowsServiceFailureActions(ctrldServiceName) }) p.onStopped = append(p.onStopped, func() { for _, lc := range p.cfg.Listener { @@ -1564,7 +610,7 @@ func deactivationPinNotSet() bool { return cdDeactivationPin.Load() == defaultDeactivationPin } -func processCDFlags(cfg *ctrld.Config) error { +func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { logger := mainLog.Load().With().Str("mode", "cd").Logger() logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) @@ -1582,10 +628,10 @@ func processCDFlags(cfg *ctrld.Config) error { } if err != nil { if isMobile() { - return err + return nil, err } logger.Warn().Err(err).Msg("could not fetch resolver config") - return err + return nil, err } if resolverConfig.DeactivationPin != nil { @@ -1601,7 +647,7 @@ func processCDFlags(cfg *ctrld.Config) error { logger.Info().Msg("using defined custom config of Control-D resolver") if err := validateCdRemoteConfig(resolverConfig, cfg); err == nil { setListenerDefaultValue(cfg) - return nil + return resolverConfig, nil } mainLog.Load().Err(err).Msg("disregarding invalid custom config") } @@ -1648,7 +694,7 @@ func processCDFlags(cfg *ctrld.Config) error { // Set default value. setListenerDefaultValue(cfg) - return nil + return resolverConfig, nil } // setListenerDefaultValue sets the default value for cfg.Listener if none existed. @@ -1720,7 +766,7 @@ func netInterface(ifaceName string) (*net.Interface, error) { if iface == nil { return nil, errors.New("interface not found") } - if err := patchNetIfaceName(iface); err != nil { + if _, err := patchNetIfaceName(iface); err != nil { return nil, err } return iface, err @@ -1972,10 +1018,10 @@ func uninstall(p *prog, s service.Service) { return } tasks := []task{ - {s.Stop, false}, - {s.Uninstall, true}, + {s.Stop, false, "Stop"}, + {s.Uninstall, true, "Uninstall"}, } - initLogging() + initInteractiveLogging() if doTasks(tasks) { if err := p.router.ConfigureService(svcConfig); err != nil { mainLog.Load().Fatal().Err(err).Msg("could not configure service") @@ -1985,6 +1031,16 @@ func uninstall(p *prog, s service.Service) { return } p.resetDNS() + + // if present restore the original DNS settings + if netIface, err := netInterface(p.runningIface); err == nil { + if err := restoreDNS(netIface); err != nil { + mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface") + } else { + mainLog.Load().Debug().Msg("Restored DNS on interface successfully") + } + } + if router.Name() != "" { mainLog.Load().Debug().Msg("Router cleanup") } @@ -2107,7 +1163,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata cdMode := cdUID != "" nextdnsMode := nextdns != "" // For Windows server with local Dns server running, we can only try on random local IP. - hasLocalDnsServer := windowsHasLocalDnsServerRunning() + hasLocalDnsServer := hasLocalDnsServerRunning() for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { @@ -2422,6 +1478,12 @@ func removeOrgFlagsFromArgs(sc *service.Config) { // newSocketControlClient returns new control client after control server was started. func newSocketControlClient(ctx context.Context, s service.Service, dir string) *controlClient { + return newSocketControlClientWithTimeout(ctx, s, dir, dialSocketControlServerTimeout) +} + +// newSocketControlClientWithTimeout returns new control client after control server was started. +// The timeoutDuration controls how long to wait for the server. +func newSocketControlClientWithTimeout(ctx context.Context, s service.Service, dir string, timeoutDuration time.Duration) *controlClient { // Return early if service is not running. if status, err := s.Status(); err != nil || status != service.StatusRunning { return nil @@ -2430,7 +1492,7 @@ func newSocketControlClient(ctx context.Context, s service.Service, dir string) bo.LogLongerThan = 10 * time.Second cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) - timeout := time.NewTimer(30 * time.Second) + timeout := time.NewTimer(timeoutDuration) defer timeout.Stop() // The socket control server may not start yet, so attempt to ping @@ -2557,22 +1619,27 @@ var errRequiredDeactivationPin = errors.New("deactivation pin is required to sto // checkDeactivationPin validates if the deactivation pin matches one in ControlD config. func checkDeactivationPin(s service.Service, stopCh chan struct{}) error { + mainLog.Load().Debug().Msg("Checking deactivation pin") dir, err := socketDir() if err != nil { mainLog.Load().Err(err).Msg("could not check deactivation pin") return err } + mainLog.Load().Debug().Msg("Creating control client") var cc *controlClient if s == nil { cc = newSocketControlClientMobile(dir, stopCh) } else { cc = newSocketControlClient(context.TODO(), s, dir) } + mainLog.Load().Debug().Msg("Control client done") if cc == nil { return nil // ctrld is not running. } data, _ := json.Marshal(&deactivationRequest{Pin: deactivationPin}) - resp, _ := cc.post(deactivationPath, bytes.NewReader(data)) + mainLog.Load().Debug().Msg("Posting deactivation request") + resp, err := cc.post(deactivationPath, bytes.NewReader(data)) + mainLog.Load().Debug().Msg("Posting deactivation request done") if resp != nil { switch resp.StatusCode { case http.StatusBadRequest: @@ -2614,21 +1681,6 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M return c.ExchangeContext(ctx, msg, addr) } -// powershell runs the given powershell command. -func powershell(cmd string) ([]byte, error) { - out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput() - return bytes.TrimSpace(out), err -} - -// windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. -func windowsHasLocalDnsServerRunning() bool { - if runtime.GOOS == "windows" { - _, err := powershell("Get-Process -Name DNS") - return err == nil - } - return false -} - // absHomeDir returns the absolute path to given filename using home directory as root dir. func absHomeDir(filename string) string { if homedir != "" { @@ -2649,6 +1701,10 @@ func runInCdMode() bool { // curCdUID returns the current ControlD UID used by running ctrld process. func curCdUID() string { if s, _ := newService(&prog{}, svcConfig); s != nil { + // Configure Windows service failure actions + if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) + } if dir, _ := socketDir(); dir != "" { cc := newSocketControlClient(context.TODO(), s, dir) if cc != nil { @@ -2694,20 +1750,22 @@ func upgradeUrl(baseUrl string) string { } // runningIface returns the value of the iface variable used by ctrld process which is running. -func runningIface(s service.Service) string { +func runningIface(s service.Service) *ifaceResponse { if sockDir, err := socketDir(); err == nil { if cc := newSocketControlClient(context.TODO(), s, sockDir); cc != nil { resp, err := cc.post(ifacePath, nil) if err != nil { - return "" + return nil } defer resp.Body.Close() - if buf, _ := io.ReadAll(resp.Body); len(buf) > 0 { - return string(buf) + res := &ifaceResponse{} + if err := json.NewDecoder(resp.Body).Decode(res); err != nil { + return nil } + return res } } - return "" + return nil } // resetDnsNoLog performs resetting DNS with logging disable. @@ -2725,9 +1783,10 @@ func resetDnsNoLog(p *prog) { } // resetDnsTask returns a task which perform reset DNS operation. -func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, currentRunningIface string) task { +func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task { return task{func() error { if iface == "" { + mainLog.Load().Debug().Msg("no iface, skipping resetDnsTask") return nil } // Always reset DNS first, ensuring DNS setting is in a good state. @@ -2735,9 +1794,12 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, currentRunn // process to reset what setDNS has done properly. oldIface := iface iface = "auto" - if currentRunningIface != "" { - iface = currentRunningIface + p.requiredMultiNICsConfig = requiredMultiNICsConfig() + if ir != nil { + iface = ir.Name + p.requiredMultiNICsConfig = ir.All } + p.runningIface = iface if isCtrldInstalled { mainLog.Load().Debug().Msg("restore system DNS settings") if status, _ := s.Status(); status == service.StatusRunning { @@ -2747,14 +1809,21 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, currentRunn } iface = oldIface return nil - }, false} + }, false, "Reset DNS"} } // doValidateCdRemoteConfig fetches and validates custom config for cdUID. -func doValidateCdRemoteConfig(cdUID string) { +func doValidateCdRemoteConfig(cdUID string, fatal bool) error { rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) if err != nil { - mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) + logger := mainLog.Load().Fatal() + if !fatal { + logger = mainLog.Load().Warn() + } + logger.Err(err).Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) + if !fatal { + return err + } } // validateCdRemoteConfig clobbers v, saving it here to restore later. oldV := v @@ -2784,6 +1853,7 @@ func doValidateCdRemoteConfig(cdUID string) { mainLog.Load().Warn().Msg("disregarding invalid custom config") } v = oldV + return nil } // uninstallInvalidCdUID performs self-uninstallation because the ControlD device does not exist. @@ -2796,7 +1866,7 @@ func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool { p.resetDNS() - tasks := []task{{s.Uninstall, true}} + tasks := []task{{s.Uninstall, true, "Uninstall"}} if doTasks(tasks) { logger.Info().Msg("uninstalled service") if doStop { diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go new file mode 100644 index 0000000..49dfb8f --- /dev/null +++ b/cmd/cli/commands.go @@ -0,0 +1,1362 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "time" + + "github.com/docker/go-units" + "github.com/kardianos/service" + "github.com/minio/selfupdate" + "github.com/olekukonko/tablewriter" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + + "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/clientinfo" + "github.com/Control-D-Inc/ctrld/internal/router" +) + +// dialSocketControlServerTimeout is the default timeout to wait when ping control server. +const dialSocketControlServerTimeout = 30 * time.Second + +func initLogCmd() *cobra.Command { + warnRuntimeLoggingNotEnabled := func() { + mainLog.Load().Warn().Msg("runtime debug logging is not enabled") + mainLog.Load().Warn().Msg(`ctrld may be running without "--cd" flag or logging is already enabled`) + } + logSendCmd := &cobra.Command{ + Use: "send", + Short: "Send runtime debug logs to ControlD", + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + + dir, err := socketDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(sendLogsPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to send logs") + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusServiceUnavailable: + mainLog.Load().Warn().Msg("runtime logs could only be sent once per minute") + return + case http.StatusMovedPermanently: + warnRuntimeLoggingNotEnabled() + return + } + var logs logSentResponse + if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to decode sent logs result") + } + size := units.BytesSize(float64(logs.Size)) + if logs.Error == "" { + mainLog.Load().Notice().Msgf("runtime logs sent successfully (%s)", size) + } else { + mainLog.Load().Error().Msgf("failed to send logs (%s)", size) + mainLog.Load().Error().Msg(logs.Error) + } + }, + } + logViewCmd := &cobra.Command{ + Use: "view", + Short: "View current runtime debug logs", + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + + dir, err := socketDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(viewLogsPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to get logs") + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusMovedPermanently: + warnRuntimeLoggingNotEnabled() + return + case http.StatusBadRequest: + mainLog.Load().Warn().Msg("runtime debugs log is not available") + buf, err := io.ReadAll(resp.Body) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to read response body") + } + mainLog.Load().Warn().Msgf("ctrld process response:\n\n%s\n", string(buf)) + return + case http.StatusOK: + } + var logs logViewResponse + if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to decode view logs result") + } + fmt.Println(logs.Data) + }, + } + logCmd := &cobra.Command{ + Use: "log", + Short: "Manage runtime debug logs", + Args: cobra.OnlyValidArgs, + ValidArgs: []string{ + logSendCmd.Use, + }, + } + logCmd.AddCommand(logSendCmd) + logCmd.AddCommand(logViewCmd) + rootCmd.AddCommand(logCmd) + + return logCmd +} + +func initRunCmd() *cobra.Command { + runCmd := &cobra.Command{ + Use: "run", + Short: "Run the DNS proxy server", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + RunCobraCommand(cmd) + }, + } + runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon") + runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") + runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") + runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") + runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") + runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") + runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") + runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") + runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") + runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") + runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") + runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") + runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") + _ = runCmd.Flags().MarkHidden("dev") + runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") + _ = runCmd.Flags().MarkHidden("homedir") + runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) + _ = runCmd.Flags().MarkHidden("iface") + runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) + + runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true} + rootCmd.AddCommand(runCmd) + + return runCmd +} + +func initStartCmd() *cobra.Command { + startCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "start", + Short: "Install and start the ctrld service", + Long: `Install and start the ctrld service + +NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + checkStrFlagEmpty(cmd, cdUidFlagName) + checkStrFlagEmpty(cmd, cdOrgFlagName) + validateCdAndNextDNSFlags() + sc := &service.Config{} + *sc = *svcConfig + osArgs := os.Args[2:] + if os.Args[1] == "service" { + osArgs = os.Args[3:] + } + setDependencies(sc) + sc.Arguments = append([]string{"run"}, osArgs...) + + p := &prog{ + router: router.New(&cfg, cdUID != ""), + cfg: &cfg, + } + s, err := newService(p, sc) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + + status, err := s.Status() + isCtrldRunning := status == service.StatusRunning + isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) + + // Get current running iface, if any. + var currentIface *ifaceResponse + + // If pin code was set, do not allow running start command. + if isCtrldRunning { + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { + os.Exit(deactivationPinInvalidExitCode) + } + currentIface = runningIface(s) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + reportSetDnsOk := func(sockDir string) { + if cc := newSocketControlClient(ctx, s, sockDir); cc != nil { + if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK { + if iface == "auto" { + iface = defaultIfaceName() + } + res := &ifaceResponse{} + if err := json.NewDecoder(resp.Body).Decode(res); err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to get iface info") + return + } + if res.OK { + name := res.Name + if iff, err := net.InterfaceByName(name); err == nil { + _, _ = patchNetIfaceName(iff) + name = iff.Name + } + logger := mainLog.Load().With().Str("iface", name).Logger() + logger.Debug().Msg("setting DNS successfully") + if res.All { + // Log that DNS is set for other interfaces. + withEachPhysicalInterfaces( + name, + "set DNS", + func(i *net.Interface) error { return nil }, + ) + } + } + } + } + } + + // No config path, generating config in HOME directory. + noConfigStart := isNoConfigStart(cmd) + writeDefaultConfig := !noConfigStart && configBase64 == "" + + logServerStarted := make(chan struct{}) + // A buffer channel to gather log output from runCmd and report + // to user in case self-check process failed. + runCmdLogCh := make(chan string, 256) + ud, err := userHomeDir() + sockDir := ud + if err != nil { + mainLog.Load().Warn().Msg("log server did not start") + close(logServerStarted) + } else { + setWorkingDirectory(sc, ud) + if configPath == "" && writeDefaultConfig { + defaultConfigFile = filepath.Join(ud, defaultConfigFile) + } + sc.Arguments = append(sc.Arguments, "--homedir="+ud) + if d, err := socketDir(); err == nil { + sockDir = d + } + sockPath := filepath.Join(sockDir, ctrldLogUnixSock) + _ = os.Remove(sockPath) + go func() { + defer func() { + close(runCmdLogCh) + _ = os.Remove(sockPath) + }() + close(logServerStarted) + if conn := runLogServer(sockPath); conn != nil { + // Enough buffer for log message, we don't produce + // such long log message, but just in case. + buf := make([]byte, 1024) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + msg := string(buf[:n]) + if _, _, found := strings.Cut(msg, msgExit); found { + cancel() + } + runCmdLogCh <- msg + } + } + }() + } + <-logServerStarted + + if !startOnly { + startOnly = len(osArgs) == 0 + } + // If user run "ctrld start" and ctrld is already installed, starting existing service. + if startOnly && isCtrldInstalled { + tryReadingConfigWithNotice(false, true) + if err := v.Unmarshal(&cfg); err != nil { + mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) + } + + initInteractiveLogging() + tasks := []task{ + {s.Stop, false, "Stop"}, + resetDnsTask(p, s, isCtrldInstalled, currentIface), + {func() error { + // Save current DNS so we can restore later. + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { + return err + } + return nil + }) + return nil + }, false, "Save current DNS"}, + {func() error { + return ConfigureWindowsServiceFailureActions(ctrldServiceName) + }, false, "Configure Windows service failure actions"}, + {s.Start, true, "Start"}, + {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, + } + mainLog.Load().Notice().Msg("Starting existing ctrld service") + if doTasks(tasks) { + mainLog.Load().Notice().Msg("Service started") + sockDir, err := socketDir() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("Failed to get socket directory") + os.Exit(1) + } + reportSetDnsOk(sockDir) + } else { + mainLog.Load().Error().Err(err).Msg("Failed to start existing ctrld service") + os.Exit(1) + } + return + } + + if cdUID != "" { + _ = doValidateCdRemoteConfig(cdUID, true) + } else if uid := cdUIDFromProvToken(); uid != "" { + cdUID = uid + mainLog.Load().Debug().Msg("using uid from provision token") + removeOrgFlagsFromArgs(sc) + // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. + sc.Arguments = append(sc.Arguments, "--cd="+cdUID) + } + if cdUID != "" { + validateCdUpstreamProtocol() + } + + if err := p.router.ConfigureService(sc); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to configure service on router") + } + + if configPath != "" { + v.SetConfigFile(configPath) + } + + tryReadingConfigWithNotice(writeDefaultConfig, true) + + if err := v.Unmarshal(&cfg); err != nil { + mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) + } + + initInteractiveLogging() + + if nextdns != "" { + removeNextDNSFromArgs(sc) + } + + // Explicitly passing config, so on system where home directory could not be obtained, + // or sub-process env is different with the parent, we still behave correctly and use + // the expected config file. + if configPath == "" { + sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile) + } + + if router.Name() != "" && iface != "" { + mainLog.Load().Debug().Msg("cleaning up router before installing") + _ = p.router.Cleanup() + } + + tasks := []task{ + {s.Stop, false, "Stop"}, + {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"}, + {func() error { return ensureUninstall(s) }, false, "Ensure uninstall"}, + resetDnsTask(p, s, isCtrldInstalled, currentIface), + {func() error { + // Save current DNS so we can restore later. + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { + return err + } + return nil + }) + return nil + }, false, "Save current DNS"}, + {s.Install, false, "Install"}, + {func() error { + return ConfigureWindowsServiceFailureActions(ctrldServiceName) + }, false, "Configure Windows service failure actions"}, + {s.Start, true, "Start"}, + // Note that startCmd do not actually write ControlD config, but the config file was + // generated after s.Start, so we notice users here for consistent with nextdns mode. + {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, + } + mainLog.Load().Notice().Msg("Starting service") + if doTasks(tasks) { + if err := p.router.Install(sc); err != nil { + mainLog.Load().Warn().Err(err).Msg("post installation failed, please check system/service log for details error") + return + } + + ok, status, err := selfCheckStatus(ctx, s, sockDir) + switch { + case ok && status == service.StatusRunning: + mainLog.Load().Notice().Msg("Service started") + default: + marker := bytes.Repeat([]byte("="), 32) + // If ctrld service is not running, emitting log obtained from ctrld process. + if status != service.StatusRunning || ctx.Err() != nil { + mainLog.Load().Error().Msg("ctrld service may not have started due to an error or misconfiguration, service log:") + _, _ = mainLog.Load().Write(marker) + haveLog := false + for msg := range runCmdLogCh { + _, _ = mainLog.Load().Write([]byte(strings.ReplaceAll(msg, msgExit, ""))) + haveLog = true + } + // If we're unable to get log from "ctrld run", notice users about it. + if !haveLog { + mainLog.Load().Write([]byte(`"`)) + } + } + // Report any error if occurred. + if err != nil { + _, _ = mainLog.Load().Write(marker) + msg := fmt.Sprintf("An error occurred while performing test query: %s", err) + mainLog.Load().Write([]byte(msg)) + } + // If ctrld service is running but selfCheckStatus failed, it could be related + // to user's system firewall configuration, notice users about it. + if status == service.StatusRunning && err == nil { + _, _ = mainLog.Load().Write(marker) + mainLog.Load().Write([]byte(`ctrld service was running, but a DNS query could not be sent to its listener`)) + mainLog.Load().Write([]byte(`Please check your system firewall if it is configured to block/intercept/redirect DNS queries`)) + } + + _, _ = mainLog.Load().Write(marker) + uninstall(p, s) + os.Exit(1) + } + reportSetDnsOk(sockDir) + } + }, + } + // Keep these flags in sync with runCmd above, except for "-d"/"--nextdns". + startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") + startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") + startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") + startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") + startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") + startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") + startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") + startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") + startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") + startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") + startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") + startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") + _ = startCmd.Flags().MarkHidden("dev") + startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) + startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") + startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) + startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`) + startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service") + _ = startCmd.Flags().MarkHidden("start_only") + + routerCmd := &cobra.Command{ + Use: "setup", + Run: func(cmd *cobra.Command, _ []string) { + exe, err := os.Executable() + if err != nil { + mainLog.Load().Fatal().Msgf("could not find executable path: %v", err) + os.Exit(1) + } + flags := make([]string, 0) + cmd.Flags().Visit(func(flag *pflag.Flag) { + flags = append(flags, fmt.Sprintf("--%s=%s", flag.Name, flag.Value)) + }) + cmdArgs := []string{"start"} + cmdArgs = append(cmdArgs, flags...) + command := exec.Command(exe, cmdArgs...) + command.Stdout = os.Stdout + command.Stderr = os.Stderr + command.Stdin = os.Stdin + if err := command.Run(); err != nil { + mainLog.Load().Fatal().Msg(err.Error()) + } + }, + } + routerCmd.Flags().AddFlagSet(startCmd.Flags()) + routerCmd.Hidden = true + rootCmd.AddCommand(routerCmd) + + startCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "start", + Short: "Quick start service and configure DNS on interface", + Long: `Quick start service and configure DNS on interface + +NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, + Run: func(cmd *cobra.Command, args []string) { + if len(os.Args) == 2 { + startOnly = true + } + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } + iface = ifaceStartStop + startCmd.Run(cmd, args) + }, + } + startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`) + startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) + rootCmd.AddCommand(startCmdAlias) + + return startCmd +} + +func initStopCmd() *cobra.Command { + stopCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "stop", + Short: "Stop the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + readConfig(false) + v.Unmarshal(&cfg) + p := &prog{router: router.New(&cfg, runInCdMode())} + s, err := newService(p, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } + + initInteractiveLogging() + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is already stopped") + return + } + + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { + os.Exit(deactivationPinInvalidExitCode) + } + if doTasks([]task{{s.Stop, true, "Stop"}}) { + p.router.Cleanup() + p.resetDNS() + + // restore DNS settings + if netIface, err := netInterface(p.runningIface); err == nil { + if err := restoreDNS(netIface); err != nil { + mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface") + } else { + mainLog.Load().Debug().Msg("Restored DNS on interface successfully") + } + } + + if router.WaitProcessExited() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + for { + select { + case <-ctx.Done(): + mainLog.Load().Error().Msg("timeout while waiting for service to stop") + return + default: + } + time.Sleep(time.Second) + if status, _ := s.Status(); status == service.StatusStopped { + break + } + } + } + mainLog.Load().Notice().Msg("Service stopped") + } + }, + } + stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`) + stopCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`) + _ = stopCmd.Flags().MarkHidden("pin") + + stopCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "stop", + Short: "Quick stop service and remove DNS from interface", + Run: func(cmd *cobra.Command, args []string) { + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } + iface = ifaceStartStop + stopCmd.Run(cmd, args) + }, + } + stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) + stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags()) + rootCmd.AddCommand(stopCmdAlias) + + return stopCmd +} + +func initRestartCmd() *cobra.Command { + restartCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "restart", + Short: "Restart the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + readConfig(false) + v.Unmarshal(&cfg) + cdUID = curCdUID() + cdMode := cdUID != "" + + p := &prog{router: router.New(&cfg, cdMode)} + s, err := newService(p, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } + + initInteractiveLogging() + + var validateConfigErr error + if cdMode { + validateConfigErr = doValidateCdRemoteConfig(cdUID, false) + } + + if ir := runningIface(s); ir != nil { + iface = ir.Name + } + + doRestart := func() bool { + tasks := []task{ + {s.Stop, true, "Stop"}, + {func() error { + p.router.Cleanup() + p.resetDNS() + return nil + }, false, "Cleanup"}, + {func() error { + time.Sleep(time.Second * 1) + return nil + }, false, "Waiting for service to stop"}, + } + if doTasks(tasks) { + + if router.WaitProcessExited() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + loop: + for { + select { + case <-ctx.Done(): + mainLog.Load().Error().Msg("timeout while waiting for service to stop") + break loop + default: + } + time.Sleep(time.Second) + if status, _ := s.Status(); status == service.StatusStopped { + break + } + } + } + } else { + return false + } + + tasks = []task{ + {s.Start, true, "Start"}, + } + + return doTasks(tasks) + + } + + if doRestart() { + if dir, err := socketDir(); err == nil { + timeout := dialSocketControlServerTimeout + // If we failed to validate remote config above, it's likely that + // we are having problem with network connection. So using a shorter + // timeout than default one for better UX. + if validateConfigErr != nil { + timeout = 5 * time.Second + } + if cc := newSocketControlClientWithTimeout(context.TODO(), s, dir, timeout); cc != nil { + _, _ = cc.post(ifacePath, nil) + } else { + mainLog.Load().Warn().Err(err).Msg("Service was restarted, but ctrld process may not be ready yet") + } + } else { + mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server") + } + mainLog.Load().Notice().Msg("Service restarted") + } else { + mainLog.Load().Error().Msg("Service restart failed") + } + }, + } + + restartCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "restart", + Short: "Restart the ctrld service", + Run: func(cmd *cobra.Command, args []string) { + restartCmd.Run(cmd, args) + }, + } + rootCmd.AddCommand(restartCmdAlias) + + return restartCmd +} + +func initReloadCmd(restartCmd *cobra.Command) *cobra.Command { + reloadCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + + dir, err := socketDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(reloadPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld") + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK: + mainLog.Load().Notice().Msg("Service reloaded") + case http.StatusCreated: + s, err := newService(&prog{}, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.") + mainLog.Load().Warn().Msg("Restarting service") + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("Service not installed") + return + } + restartCmd.Run(cmd, args) + default: + buf, err := io.ReadAll(resp.Body) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("could not read response from control server") + } + mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf)) + } + }, + } + + reloadCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + Run: func(cmd *cobra.Command, args []string) { + reloadCmd.Run(cmd, args) + }, + } + rootCmd.AddCommand(reloadCmdAlias) + + return reloadCmd +} + +func initStatusCmd() *cobra.Command { + statusCmd := &cobra.Command{ + Use: "status", + Short: "Show status of the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + s, err := newService(&prog{}, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + status, err := s.Status() + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + os.Exit(1) + } + switch status { + case service.StatusUnknown: + mainLog.Load().Notice().Msg("Unknown status") + os.Exit(2) + case service.StatusRunning: + mainLog.Load().Notice().Msg("Service is running") + os.Exit(0) + case service.StatusStopped: + mainLog.Load().Notice().Msg("Service is stopped") + os.Exit(1) + } + }, + } + if runtime.GOOS == "darwin" { + // On darwin, running status command without privileges may return wrong information. + statusCmd.PreRun = func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + } + } + + statusCmdAlias := &cobra.Command{ + Use: "status", + Short: "Show status of the ctrld service", + Args: cobra.NoArgs, + Run: statusCmd.Run, + } + rootCmd.AddCommand(statusCmdAlias) + + return statusCmd +} + +func initUninstallCmd() *cobra.Command { + uninstallCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "uninstall", + Short: "Stop and uninstall the ctrld service", + Long: `Stop and uninstall the ctrld service. + +NOTE: Uninstalling will set DNS to values provided by DHCP.`, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + readConfig(false) + v.Unmarshal(&cfg) + p := &prog{router: router.New(&cfg, runInCdMode())} + s, err := newService(p, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { + os.Exit(deactivationPinInvalidExitCode) + } + uninstall(p, s) + if cleanup { + var files []string + // Config file. + files = append(files, v.ConfigFileUsed()) + // Log file and backup log file. + // For safety, only process if log file path is absolute. + if logFile := normalizeLogFilePath(cfg.Service.LogPath); filepath.IsAbs(logFile) { + files = append(files, logFile) + oldLogFile := logFile + oldLogSuffix + if _, err := os.Stat(oldLogFile); err == nil { + files = append(files, oldLogFile) + } + } + // Socket files. + if dir, _ := socketDir(); dir != "" { + files = append(files, filepath.Join(dir, ctrldControlUnixSock)) + files = append(files, filepath.Join(dir, ctrldLogUnixSock)) + } + // Static DNS settings files. + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + file := savedStaticDnsSettingsFilePath(i) + if _, err := os.Stat(file); err == nil { + files = append(files, file) + } + return nil + }) + // Windows forwarders file. + if hasLocalDnsServerRunning() { + files = append(files, absHomeDir(windowsForwardersFilename)) + } + // Binary itself. + bin, _ := os.Executable() + if bin != "" && supportedSelfDelete { + files = append(files, bin) + } + // Backup file after upgrading. + oldBin := bin + oldBinSuffix + if _, err := os.Stat(oldBin); err == nil { + files = append(files, oldBin) + } + for _, file := range files { + if file == "" { + continue + } + if err := os.Remove(file); err != nil { + if os.IsNotExist(err) { + continue + } + mainLog.Load().Warn().Err(err).Msg("failed to remove file") + } else { + mainLog.Load().Debug().Msgf("file removed: %s", file) + } + } + if err := selfDeleteExe(); err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to remove file") + } else { + if !supportedSelfDelete { + mainLog.Load().Debug().Msgf("file removed: %s", bin) + } + } + } + }, + } + uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`) + uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for uninstalling ctrld`) + _ = uninstallCmd.Flags().MarkHidden("pin") + uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`) + + uninstallCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "uninstall", + Short: "Stop and uninstall the ctrld service", + Long: `Stop and uninstall the ctrld service. + +NOTE: Uninstalling will set DNS to values provided by DHCP.`, + Run: func(cmd *cobra.Command, args []string) { + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } + iface = ifaceStartStop + uninstallCmd.Run(cmd, args) + }, + } + uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) + uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags()) + rootCmd.AddCommand(uninstallCmdAlias) + + return uninstallCmd +} + +func initInterfacesCmd() *cobra.Command { + listIfacesCmd := &cobra.Command{ + Use: "list", + Short: "List network interfaces of the host", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + fmt.Printf("Index : %d\n", i.Index) + fmt.Printf("Name : %s\n", i.Name) + addrs, _ := i.Addrs() + for i, ipaddr := range addrs { + if i == 0 { + fmt.Printf("Addrs : %v\n", ipaddr) + continue + } + fmt.Printf(" %v\n", ipaddr) + } + nss, err := currentStaticDNS(i) + if err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to get DNS") + } + if len(nss) == 0 { + nss = currentDNS(i) + } + for i, dns := range nss { + if i == 0 { + fmt.Printf("DNS : %s\n", dns) + continue + } + fmt.Printf(" : %s\n", dns) + } + println() + return nil + }) + }, + } + interfacesCmd := &cobra.Command{ + Use: "interfaces", + Short: "Manage network interfaces", + Args: cobra.OnlyValidArgs, + ValidArgs: []string{ + listIfacesCmd.Use, + }, + } + interfacesCmd.AddCommand(listIfacesCmd) + + return interfacesCmd +} + +func initClientsCmd() *cobra.Command { + listClientsCmd := &cobra.Command{ + Use: "list", + Short: "List clients that ctrld discovered", + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + + dir, err := socketDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(listClientsPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to get clients list") + } + defer resp.Body.Close() + + var clients []*clientinfo.Client + if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to decode clients list result") + } + map2Slice := func(m map[string]struct{}) []string { + s := make([]string, 0, len(m)) + for k := range m { + if k == "" { // skip empty source from output. + continue + } + s = append(s, k) + } + sort.Strings(s) + return s + } + // If metrics is enabled, server set this for all clients, so we can check only the first one. + // Ideally, we may have a field in response to indicate that query count should be shown, but + // it would break earlier version of ctrld, which only look list of clients in response. + withQueryCount := len(clients) > 0 && clients[0].IncludeQueryCount + data := make([][]string, len(clients)) + for i, c := range clients { + row := []string{ + c.IP.String(), + c.Hostname, + c.Mac, + strings.Join(map2Slice(c.Source), ","), + } + if withQueryCount { + row = append(row, strconv.FormatInt(c.QueryCount, 10)) + } + data[i] = row + } + table := tablewriter.NewWriter(os.Stdout) + headers := []string{"IP", "Hostname", "Mac", "Discovered"} + if withQueryCount { + headers = append(headers, "Queries") + } + table.SetHeader(headers) + table.SetAutoFormatHeaders(false) + table.AppendBulk(data) + table.Render() + }, + } + clientsCmd := &cobra.Command{ + Use: "clients", + Short: "Manage clients", + Args: cobra.OnlyValidArgs, + ValidArgs: []string{ + listClientsCmd.Use, + }, + } + clientsCmd.AddCommand(listClientsCmd) + rootCmd.AddCommand(clientsCmd) + + return clientsCmd +} + +func initUpgradeCmd() *cobra.Command { + const ( + upgradeChannelDev = "dev" + upgradeChannelProd = "prod" + upgradeChannelDefault = "default" + ) + upgradeChannel := map[string]string{ + upgradeChannelDefault: "https://dl.controld.dev", + upgradeChannelDev: "https://dl.controld.dev", + upgradeChannelProd: "https://dl.controld.com", + } + if isStableVersion(curVersion()) { + upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd] + } + upgradeCmd := &cobra.Command{ + Use: "upgrade", + Short: "Upgrading ctrld to latest version", + ValidArgs: []string{upgradeChannelDev, upgradeChannelProd}, + Args: cobra.MaximumNArgs(1), + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Run: func(cmd *cobra.Command, args []string) { + bin, err := os.Executable() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to get current ctrld binary path") + } + sc := &service.Config{} + *sc = *svcConfig + sc.Executable = bin + readConfig(false) + v.Unmarshal(&cfg) + p := &prog{router: router.New(&cfg, runInCdMode())} + s, err := newService(p, sc) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } + + svcInstalled := true + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + svcInstalled = false + } + oldBin := bin + oldBinSuffix + baseUrl := upgradeChannel[upgradeChannelDefault] + if len(args) > 0 { + channel := args[0] + switch channel { + case upgradeChannelProd, upgradeChannelDev: // ok + default: + mainLog.Load().Fatal().Msgf("uprade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev) + } + baseUrl = upgradeChannel[channel] + } + dlUrl := upgradeUrl(baseUrl) + mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl) + resp, err := http.Get(dlUrl) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to download binary") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + mainLog.Load().Fatal().Msgf("could not download binary: %s", http.StatusText(resp.StatusCode)) + } + mainLog.Load().Debug().Msg("Updating current binary") + if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil { + if rerr := selfupdate.RollbackError(err); rerr != nil { + mainLog.Load().Error().Err(rerr).Msg("could not rollback old binary") + } + mainLog.Load().Fatal().Err(err).Msg("failed to update current binary") + } + + doRestart := func() bool { + if !svcInstalled { + return true + } + tasks := []task{ + {s.Stop, true, "Stop"}, + {func() error { + p.router.Cleanup() + p.resetDNS() + return nil + }, false, "Cleanup"}, + {func() error { + time.Sleep(time.Second * 1) + return nil + }, false, "Waiting for service to stop"}, + } + if doTasks(tasks) { + + if router.WaitProcessExited() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + loop: + for { + select { + case <-ctx.Done(): + mainLog.Load().Error().Msg("timeout while waiting for service to stop") + break loop + default: + } + time.Sleep(time.Second) + if status, _ := s.Status(); status == service.StatusStopped { + break + } + } + } + } + + tasks = []task{ + {s.Start, true, "Start"}, + } + if doTasks(tasks) { + if dir, err := socketDir(); err == nil { + if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil { + _, _ = cc.post(ifacePath, nil) + return true + } + } + } + return false + } + if svcInstalled { + mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") + } + if doRestart() { + _ = os.Remove(oldBin) + _ = os.Chmod(bin, 0755) + ver := "unknown version" + out, err := exec.Command(bin, "--version").CombinedOutput() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version") + } + if after, found := strings.CutPrefix(string(out), "ctrld version "); found { + ver = after + } + mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver) + return + } + + mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin) + if err := os.Remove(bin); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to remove new binary") + } + if err := os.Rename(oldBin, bin); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to restore old binary") + } + if doRestart() { + mainLog.Load().Notice().Msg("Restored previous binary successfully") + return + } + }, + } + rootCmd.AddCommand(upgradeCmd) + + return upgradeCmd +} + +func initServicesCmd(commands ...*cobra.Command) *cobra.Command { + serviceCmd := &cobra.Command{ + Use: "service", + Short: "Manage ctrld service", + Args: cobra.OnlyValidArgs, + } + serviceCmd.ValidArgs = make([]string, len(commands)) + for i, cmd := range commands { + serviceCmd.ValidArgs[i] = cmd.Use + serviceCmd.AddCommand(cmd) + } + rootCmd.AddCommand(serviceCmd) + + return serviceCmd +} diff --git a/cmd/cli/control_client.go b/cmd/cli/control_client.go index 73002e8..7382d4e 100644 --- a/cmd/cli/control_client.go +++ b/cmd/cli/control_client.go @@ -25,6 +25,10 @@ func newControlClient(addr string) *controlClient { } func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) { + // for log/send, set the timeout to 5 minutes + if path == sendLogsPath { + c.c.Timeout = time.Minute * 5 + } return c.c.Post("http://unix"+path, contentTypeJson, data) } diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index c31fd13..1ea1693 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -3,6 +3,8 @@ package cli import ( "context" "encoding/json" + "fmt" + "io" "net" "net/http" "os" @@ -25,8 +27,16 @@ const ( deactivationPath = "/deactivation" cdPath = "/cd" ifacePath = "/iface" + viewLogsPath = "/log/view" + sendLogsPath = "/log/send" ) +type ifaceResponse struct { + Name string `json:"name"` + All bool `json:"all"` + OK bool `json:"ok"` +} + type controlServer struct { server *http.Server mux *http.ServeMux @@ -201,15 +211,76 @@ func (p *prog) registerControlServerHandler() { w.WriteHeader(http.StatusBadRequest) })) p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + res := &ifaceResponse{Name: iface} // p.setDNS is only called when running as a service if !service.Interactive() { <-p.csSetDnsDone if p.csSetDnsOk { - w.Write([]byte(iface)) - return + res.Name = p.runningIface + res.All = p.requiredMultiNICsConfig + res.OK = true } } - w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(res); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError) + return + } + })) + p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + lr, err := p.logReader() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + defer lr.r.Close() + if lr.size == 0 { + w.WriteHeader(http.StatusMovedPermanently) + return + } + data, err := io.ReadAll(lr.r) + if err != nil { + http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError) + return + } + if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError) + return + } + })) + p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + if time.Since(p.internalLogSent) < logSentInterval { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + r, err := p.logReader() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if r.size == 0 { + w.WriteHeader(http.StatusMovedPermanently) + return + } + req := &controld.LogsRequest{ + UID: cdUID, + Data: r.r, + } + mainLog.Load().Debug().Msg("sending log file to ControlD server") + resp := logSentResponse{Size: r.size} + if err := controld.SendLogs(req, cdDev); err != nil { + mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err) + resp.Error = err.Error() + w.WriteHeader(http.StatusInternalServerError) + } else { + mainLog.Load().Debug().Msg("sending log file successfully") + w.WriteHeader(http.StatusOK) + } + if err := json.NewEncoder(w).Encode(&resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + p.internalLogSent = time.Now() })) } diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 6611975..0bc042e 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/netip" + "os/exec" "runtime" "slices" "strconv" @@ -19,6 +20,7 @@ import ( "golang.org/x/sync/errgroup" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" + "tailscale.com/types/logger" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/controld" @@ -41,7 +43,7 @@ const ( var osUpstreamConfig = &ctrld.UpstreamConfig{ Name: "OS resolver", Type: ctrld.ResolverTypeOS, - Timeout: 2000, + Timeout: 3000, } var privateUpstreamConfig = &ctrld.UpstreamConfig{ @@ -50,6 +52,12 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{ Timeout: 2000, } +var localUpstreamConfig = &ctrld.UpstreamConfig{ + Name: "Local resolver", + Type: ctrld.ResolverTypeLocal, + Timeout: 2000, +} + // proxyRequest contains data for proxying a DNS query to upstream. type proxyRequest struct { msg *dns.Msg @@ -76,7 +84,13 @@ type upstreamForResult struct { srcAddr string } -func (p *prog) serveDNS(listenerNum string) error { +func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { + // Start network monitoring + if err := p.monitorNetworkChanges(mainCtx); err != nil { + mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring") + // Don't return here as we still want DNS service to run + } + listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { @@ -106,11 +120,18 @@ func (p *prog) serveDNS(listenerNum string) error { go p.detectLoop(m) q := m.Question[0] domain := canonicalName(q.Name) - if domain == selfCheckInternalTestDomain { + switch { + case domain == "": + answer := new(dns.Msg) + answer.SetRcode(m, dns.RcodeFormatError) + _ = w.WriteMsg(answer) + return + case domain == selfCheckInternalTestDomain: answer := resolveInternalDomainTestQuery(ctx, domain, m) _ = w.WriteMsg(answer) return } + if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil { p.cache.Purge() ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain) @@ -411,23 +432,19 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) - leaked := false - // If ctrld is going to leak query to OS resolver, check remote upstream in background, - // so ctrld could be back to normal operation as long as the network is back online. - if len(upstreamConfigs) > 0 && p.leakingQuery.Load() { - for n, uc := range upstreamConfigs { - go p.checkUpstream(upstreams[n], uc) - } - upstreamConfigs = nil - leaked = true - ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams) - } - if len(upstreamConfigs) == 0 { upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} upstreams = []string{upstreamOS} } + if p.isAdDomainQuery(req.msg) { + ctrld.Log(ctx, mainLog.Load().Debug(), + "AD domain query detected for %s in domain %s", + req.msg.Question[0].Name, p.adDomain) + upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} + upstreams = []string{upstreamOS} + } + res := &proxyResponse{} // LAN/PTR lookup flow: @@ -438,13 +455,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // 4. Try remote upstream. isLanOrPtrQuery := false if req.ufr.matched { - if leaked { - ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v (leaked)", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) - } else { - ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) - } + ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) } else { switch { + case isSrvLookup(req.msg): + upstreams = []string{upstreamOS} + upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + ctx = ctrld.LanQueryCtx(ctx) + ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams) case isPrivatePtrLookup(req.msg): isLanOrPtrQuery = true if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil { @@ -452,7 +470,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { res.clientInfo = true return res } - upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) + upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs) + ctx = ctrld.LanQueryCtx(ctx) ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams) case isLanHostnameQuery(req.msg): isLanOrPtrQuery = true @@ -461,7 +480,9 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { res.clientInfo = true return res } - upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) + upstreams = []string{upstreamOS} + upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + ctx = ctrld.LanQueryCtx(ctx) ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams) default: ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams) @@ -488,8 +509,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { staleAnswer = answer } } - resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { - ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) + resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { + ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(upstreamConfig) if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") @@ -504,43 +525,53 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } return dnsResolver.Resolve(resolveCtx, msg) } - resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { + resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil { ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request") ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) } - answer, err := resolve1(n, upstreamConfig, msg) + answer, err := resolve1(upstream, upstreamConfig, msg) + // if we have an answer, we should reset the failure count + // we dont use reset here since we dont want to prevent failure counts from being incremented + if answer != nil { + p.um.mu.Lock() + p.um.failureReq[upstream] = 0 + p.um.down[upstream] = false + p.um.mu.Unlock() + return answer + } + + ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") + + // increase failure count when there is no answer + // rehardless of what kind of error we get + p.um.increaseFailureCount(upstream) + if err != nil { - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") - isNetworkErr := errNetworkError(err) - if isNetworkErr { - p.um.increaseFailureCount(upstreams[n]) - if p.um.isDown(upstreams[n]) { - go p.checkUpstream(upstreams[n], upstreamConfig) - } - } // For timeout error (i.e: context deadline exceed), force re-bootstrapping. var e net.Error if errors.As(err, &e) && e.Timeout() { upstreamConfig.ReBootstrap() } - return nil } - return answer + + return nil } for n, upstreamConfig := range upstreamConfigs { if upstreamConfig == nil { continue } + logger := mainLog.Load().Debug(). + Str("upstream", upstreamConfig.String()). + Str("query", req.msg.Question[0].Name). + Bool("is_ad_query", p.isAdDomainQuery(req.msg)). + Bool("is_lan_query", isLanOrPtrQuery) + if p.isLoop(upstreamConfig) { - mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint) + ctrld.Log(ctx, logger, "DNS loop detected") continue } - if p.um.isDown(upstreams[n]) { - ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n]) - continue - } - answer := resolve(n, upstreamConfig, req.msg) + answer := resolve(upstreams[n], upstreamConfig, req.msg) if answer == nil { if serveStaleCache && staleAnswer != nil { ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response") @@ -587,21 +618,49 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { return res } ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) - if cdUID != "" && p.leakOnUpstreamFailure() { - p.leakingQueryMu.Lock() - if !p.leakingQueryWasRun { - p.leakingQueryWasRun = true - go p.performLeakingQuery() + + // if we have no healthy upstreams, trigger recovery flow + if p.recoverOnUpstreamFailure() { + if p.um.countHealthy(upstreams) == 0 { + p.recoveryCancelMu.Lock() + if p.recoveryCancel == nil { + var reason RecoveryReason + if upstreams[0] == upstreamOS { + reason = RecoveryReasonOSFailure + } else { + reason = RecoveryReasonRegularFailure + } + mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) + go p.handleRecovery(reason) + } else { + mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection") + } + p.recoveryCancelMu.Unlock() + } else { + mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger") } - p.leakingQueryMu.Unlock() } + + // attempt query to OS resolver while as a retry catch all + if upstreams[0] != upstreamOS { + ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all") + answer := resolve(upstreamOS, osUpstreamConfig, req.msg) + if answer != nil { + ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful") + res.answer = answer + res.upstream = osUpstreamConfig.Endpoint + return res + } + ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed") + } + answer := new(dns.Msg) answer.SetRcode(req.msg, dns.RcodeServerFailure) res.answer = answer return res } -func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) { +func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) { if len(p.localUpstreams) > 0 { tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams)) tmp = append(tmp, p.localUpstreams...) @@ -620,6 +679,14 @@ func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.U return upstreamConfigs } +func (p *prog) isAdDomainQuery(msg *dns.Msg) bool { + if p.adDomain == "" { + return false + } + cDomainName := canonicalName(msg.Question[0].Name) + return dns.IsSubDomain(p.adDomain, cDomainName) +} + // canonicalName returns canonical name from FQDN with "." trimmed. func canonicalName(fqdn string) string { q := strings.TrimSpace(fqdn) @@ -916,18 +983,6 @@ func (p *prog) selfUninstallCoolOfPeriod() { p.selfUninstallMu.Unlock() } -// performLeakingQuery performs necessary works to leak queries to OS resolver. -func (p *prog) performLeakingQuery() { - mainLog.Load().Warn().Msg("leaking query to OS resolver") - // Signal dns watchers to stop, so changes made below won't be reverted. - p.leakingQuery.Store(true) - p.resetDNS() - ns := ctrld.InitializeOsResolver() - mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns) - p.dnsWg.Wait() - p.setDNS() -} - // forceFetchingAPI sends signal to force syncing API config if run in cd mode, // and the domain == "cdUID.verify.controld.com" func (p *prog) forceFetchingAPI(domain string) { @@ -1056,7 +1111,16 @@ func isLanHostnameQuery(m *dns.Msg) bool { name := strings.TrimSuffix(q.Name, ".") return !strings.Contains(name, ".") || strings.HasSuffix(name, ".domain") || - strings.HasSuffix(name, ".lan") + strings.HasSuffix(name, ".lan") || + strings.HasSuffix(name, ".local") +} + +// isSrvLookup reports whether DNS message is a SRV query. +func isSrvLookup(m *dns.Msg) bool { + if m == nil || len(m.Question) == 0 { + return false + } + return m.Question[0].Qtype == dns.TypeSRV } // isWanClient reports whether the input is a WAN address. @@ -1089,3 +1153,406 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M answer.SetReply(m) return answer } + +// FlushDNSCache flushes the DNS cache on macOS. +func FlushDNSCache() error { + // if not macOS, return + if runtime.GOOS != "darwin" { + return nil + } + + // Flush the DNS cache via mDNSResponder. + // This is typically needed on modern macOS systems. + if out, err := exec.Command("killall", "-HUP", "mDNSResponder").CombinedOutput(); err != nil { + return fmt.Errorf("failed to flush mDNSResponder: %w, output: %s", err, string(out)) + } + + // Optionally, flush the directory services cache. + if out, err := exec.Command("dscacheutil", "-flushcache").CombinedOutput(); err != nil { + return fmt.Errorf("failed to flush dscacheutil: %w, output: %s", err, string(out)) + } + + return nil +} + +// monitorNetworkChanges starts monitoring for network interface changes +func (p *prog) monitorNetworkChanges(ctx context.Context) error { + mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) + if err != nil { + return fmt.Errorf("creating network monitor: %w", err) + } + + mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { + // Get map of valid interfaces + validIfaces := validInterfacesMap() + + isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New) + + mainLog.Load().Debug(). + Interface("old_state", delta.Old). + Interface("new_state", delta.New). + Bool("is_major_change", isMajorChange). + Msg("Network change detected") + + changed := false + activeInterfaceExists := false + var changeIPs []netip.Prefix + // Check each valid interface for changes + for ifaceName := range validIfaces { + oldIface, oldExists := delta.Old.Interface[ifaceName] + newIface, newExists := delta.New.Interface[ifaceName] + if !newExists { + continue + } + + oldIPs := delta.Old.InterfaceIPs[ifaceName] + newIPs := delta.New.InterfaceIPs[ifaceName] + + // if a valid interface did not exist in old + // check that its up and has usable IPs + if !oldExists { + // The interface is new (was not present in the old state). + usableNewIPs := filterUsableIPs(newIPs) + if newIface.IsUp() && len(usableNewIPs) > 0 { + changed = true + changeIPs = usableNewIPs + mainLog.Load().Debug(). + Str("interface", ifaceName). + Interface("new_ips", usableNewIPs). + Msg("Interface newly appeared (was not present in old state)") + break + } + continue + } + + // Filter new IPs to only those that are usable. + usableNewIPs := filterUsableIPs(newIPs) + + // Check if interface is up and has usable IPs. + if newIface.IsUp() && len(usableNewIPs) > 0 { + activeInterfaceExists = true + } + + // Compare interface states and IPs (interfaceIPsEqual will itself filter the IPs). + if !interfaceStatesEqual(&oldIface, &newIface) || !interfaceIPsEqual(oldIPs, newIPs) { + if newIface.IsUp() && len(usableNewIPs) > 0 { + changed = true + changeIPs = usableNewIPs + mainLog.Load().Debug(). + Str("interface", ifaceName). + Interface("old_ips", oldIPs). + Interface("new_ips", usableNewIPs). + Msg("Interface state or IPs changed") + break + } + } + } + + if !changed { + mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected") + return + } + + if !activeInterfaceExists { + mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization") + return + } + + // Get IPs from default route interface in new state + selfIP := defaultRouteIP() + var ipv6 string + + if delta.New.DefaultRouteInterface != "" { + mainLog.Load().Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface]) + for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] { + ipAddr, _ := netip.ParsePrefix(ip.String()) + addr := ipAddr.Addr() + if selfIP == "" && addr.Is4() { + mainLog.Load().Debug().Msgf("checking IP: %s", addr.String()) + if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { + selfIP = addr.String() + } + } + if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { + ipv6 = addr.String() + } + } + } else { + // If no default route interface is set yet, use the changed IPs + mainLog.Load().Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs) + for _, ip := range changeIPs { + ipAddr, _ := netip.ParsePrefix(ip.String()) + addr := ipAddr.Addr() + if selfIP == "" && addr.Is4() { + mainLog.Load().Debug().Msgf("checking IP: %s", addr.String()) + if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { + selfIP = addr.String() + } + } + if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { + ipv6 = addr.String() + } + } + } + + if ip := net.ParseIP(selfIP); ip != nil { + ctrld.SetDefaultLocalIPv4(ip) + if !isMobile() && p.ciTable != nil { + p.ciTable.SetSelfIP(selfIP) + } + } + if ip := net.ParseIP(ipv6); ip != nil { + ctrld.SetDefaultLocalIPv6(ip) + } + mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) + + if p.recoverOnUpstreamFailure() { + p.handleRecovery(RecoveryReasonNetworkChange) + } + }) + + mon.Start() + mainLog.Load().Debug().Msg("Network monitor started") + return nil +} + +// interfaceStatesEqual compares two interface states +func interfaceStatesEqual(a, b *netmon.Interface) bool { + if a == nil || b == nil { + return a == b + } + return a.IsUp() == b.IsUp() +} + +// filterUsableIPs is a helper that returns only "usable" IP prefixes, +// filtering out link-local, loopback, multicast, unspecified, broadcast, or CGNAT addresses. +func filterUsableIPs(prefixes []netip.Prefix) []netip.Prefix { + var usable []netip.Prefix + for _, p := range prefixes { + addr := p.Addr() + if addr.IsLinkLocalUnicast() || + addr.IsLoopback() || + addr.IsMulticast() || + addr.IsUnspecified() || + addr.IsLinkLocalMulticast() || + (addr.Is4() && addr.String() == "255.255.255.255") || + tsaddr.CGNATRange().Contains(addr) { + continue + } + usable = append(usable, p) + } + return usable +} + +// Modified interfaceIPsEqual compares only the usable (non-link local, non-loopback, etc.) IP addresses. +func interfaceIPsEqual(a, b []netip.Prefix) bool { + aUsable := filterUsableIPs(a) + bUsable := filterUsableIPs(b) + if len(aUsable) != len(bUsable) { + return false + } + + aMap := make(map[string]bool) + for _, ip := range aUsable { + aMap[ip.String()] = true + } + for _, ip := range bUsable { + if !aMap[ip.String()] { + return false + } + } + return true +} + +// checkUpstreamOnce sends a test query to the specified upstream. +// Returns nil if the upstream responds successfully. +func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error { + mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream) + + resolver, err := ctrld.NewResolver(uc) + if err != nil { + mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream) + return err + } + + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + + timeout := 1000 * time.Millisecond + if uc.Timeout > 0 { + timeout = time.Millisecond * time.Duration(uc.Timeout) + } + mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + uc.ReBootstrap() + mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) + + start := time.Now() + _, err = resolver.Resolve(ctx, msg) + duration := time.Since(start) + + if err != nil { + mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration) + } else { + mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration) + } + return err +} + +// handleRecovery performs a unified recovery by removing DNS settings, +// canceling existing recovery checks for network changes, but coalescing duplicate +// upstream failure recoveries, waiting for recovery to complete (using a cancellable context without timeout), +// and then re-applying the DNS settings. +func (p *prog) handleRecovery(reason RecoveryReason) { + mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings") + + // For network changes, cancel any existing recovery check because the network state has changed. + if reason == RecoveryReasonNetworkChange { + p.recoveryCancelMu.Lock() + if p.recoveryCancel != nil { + mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)") + p.recoveryCancel() + p.recoveryCancel = nil + } + p.recoveryCancelMu.Unlock() + } else { + // For upstream failures, if a recovery is already in progress, do nothing new. + p.recoveryCancelMu.Lock() + if p.recoveryCancel != nil { + mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger") + p.recoveryCancelMu.Unlock() + return + } + p.recoveryCancelMu.Unlock() + } + + // Create a new recovery context without a fixed timeout. + p.recoveryCancelMu.Lock() + recoveryCtx, cancel := context.WithCancel(context.Background()) + p.recoveryCancel = cancel + p.recoveryCancelMu.Unlock() + + // Immediately remove our DNS settings from the interface. + // set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface + p.recoveryRunning.Store(true) + p.resetDNS() + + // For an OS failure, reinitialize OS resolver nameservers immediately. + if reason == RecoveryReasonOSFailure { + mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers") + ns := ctrld.InitializeOsResolver(true) + if len(ns) == 0 { + mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") + } else { + mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + } + } + + // Build upstream map based on the recovery reason. + upstreams := p.buildRecoveryUpstreams(reason) + + // Wait indefinitely until one of the upstreams recovers. + recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams) + if err != nil { + mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed") + p.recoveryCancelMu.Lock() + p.recoveryCancel = nil + p.recoveryCancelMu.Unlock() + return + } + mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered) + + // reset the upstream failure count and down state + p.um.reset(recovered) + + // For network changes we also reinitialize the OS resolver. + if reason == RecoveryReasonNetworkChange { + ns := ctrld.InitializeOsResolver(true) + if len(ns) == 0 { + mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values") + } else { + mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + } + } + + // Apply our DNS settings back and log the interface state. + p.setDNS() + p.logInterfacesState() + + // allow watchdogs to put the listener back on the interface if its changed for any reason + p.recoveryRunning.Store(false) + + // Clear the recovery cancellation for a clean slate. + p.recoveryCancelMu.Lock() + p.recoveryCancel = nil + p.recoveryCancelMu.Unlock() +} + +// waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers. +// It returns the name of the recovered upstream or an error if the check times out. +func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string]*ctrld.UpstreamConfig) (string, error) { + recoveredCh := make(chan string, 1) + var wg sync.WaitGroup + + mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams)) + + for name, uc := range upstreams { + wg.Add(1) + go func(name string, uc *ctrld.UpstreamConfig) { + defer wg.Done() + mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name) + for { + select { + case <-ctx.Done(): + mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name) + return + default: + // checkUpstreamOnce will reset any failure counters on success. + if err := p.checkUpstreamOnce(name, uc); err == nil { + mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name) + select { + case recoveredCh <- name: + mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name) + default: + mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered") + } + return + } + mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name) + time.Sleep(checkUpstreamBackoffSleep) + } + } + }(name, uc) + } + + var recovered string + select { + case recovered = <-recoveredCh: + case <-ctx.Done(): + return "", ctx.Err() + } + wg.Wait() + return recovered, nil +} + +// buildRecoveryUpstreams constructs the map of upstream configurations to test. +// For OS failures we supply the manual OS resolver upstream configuration. +// For network change or regular failure we use the upstreams defined in p.cfg (ignoring OS). +func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.UpstreamConfig { + upstreams := make(map[string]*ctrld.UpstreamConfig) + switch reason { + case RecoveryReasonOSFailure: + upstreams[upstreamOS] = osUpstreamConfig + case RecoveryReasonNetworkChange, RecoveryReasonRegularFailure: + // Use all configured upstreams except any OS type. + for k, uc := range p.cfg.Upstream { + if uc.Type != ctrld.ResolverTypeOS { + upstreams[upstreamPrefix+k] = uc + } + } + } + return upstreams +} diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 877fb71..eae3dfa 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -75,6 +75,7 @@ func Test_canonicalName(t *testing.T) { func Test_prog_upstreamFor(t *testing.T) { cfg := testhelper.SampleConfig(t) + cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false) p := &prog{cfg: cfg} p.um = newUpstreamMonitor(p.cfg) p.lanLoopGuard = newLoopGuard() @@ -365,6 +366,9 @@ func Test_isLanHostnameQuery(t *testing.T) { {"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false}, {"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false}, {"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false}, + {".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true}, + {".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true}, + {".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true}, } for _, tc := range tests { tc := tc @@ -414,6 +418,26 @@ func Test_isPrivatePtrLookup(t *testing.T) { } } +func Test_isSrvLookup(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + isSrvLookup bool + }{ + {"SRV", newDnsMsgWithHostname("foo", dns.TypeSRV), true}, + {"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isSrvLookup(tc.msg); tc.isSrvLookup != got { + t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got) + } + }) + } +} + func Test_isWanClient(t *testing.T) { tests := []struct { name string diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go new file mode 100644 index 0000000..339d984 --- /dev/null +++ b/cmd/cli/log_writer.go @@ -0,0 +1,186 @@ +package cli + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "strings" + "sync" + "time" + + "github.com/rs/zerolog" + + "github.com/Control-D-Inc/ctrld" +) + +const ( + logWriterSize = 1024 * 1024 * 5 // 5 MB + logWriterSmallSize = 1024 * 1024 * 1 // 1 MB + logWriterInitialSize = 32 * 1024 // 32 KB + logSentInterval = time.Minute + logStartEndMarker = "\n\n=== INIT_END ===\n\n" + logLogEndMarker = "\n\n=== LOG_END ===\n\n" + logWarnEndMarker = "\n\n=== WARN_END ===\n\n" +) + +type logViewResponse struct { + Data string `json:"data"` +} + +type logSentResponse struct { + Size int64 `json:"size"` + Error string `json:"error"` +} + +type logReader struct { + r io.ReadCloser + size int64 +} + +// logWriter is an internal buffer to keep track of runtime log when no logging is enabled. +type logWriter struct { + mu sync.Mutex + buf bytes.Buffer + size int +} + +// newLogWriter creates an internal log writer. +func newLogWriter() *logWriter { + return newLogWriterWithSize(logWriterSize) +} + +// newSmallLogWriter creates an internal log writer with small buffer size. +func newSmallLogWriter() *logWriter { + return newLogWriterWithSize(logWriterSmallSize) +} + +// newLogWriterWithSize creates an internal log writer with a given buffer size. +func newLogWriterWithSize(size int) *logWriter { + lw := &logWriter{size: size} + return lw +} + +func (lw *logWriter) Write(p []byte) (int, error) { + lw.mu.Lock() + defer lw.mu.Unlock() + + // If writing p causes overflows, discard old data. + if lw.buf.Len()+len(p) > lw.size { + buf := lw.buf.Bytes() + buf = buf[:logWriterInitialSize] + if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 { + buf = buf[:idx] + } + lw.buf.Reset() + lw.buf.Write(buf) + lw.buf.WriteString(logStartEndMarker) // indicate that the log was truncated. + } + // If p is bigger than buffer size, truncate p by half until its size is smaller. + for len(p)+lw.buf.Len() > lw.size { + p = p[len(p)/2:] + } + return lw.buf.Write(p) +} + +// initInternalLogging performs internal logging if there's no log enabled. +func (p *prog) initInternalLogging(writers []io.Writer) { + if !p.needInternalLogging() { + return + } + p.initInternalLogWriterOnce.Do(func() { + mainLog.Load().Notice().Msg("internal logging enabled") + p.internalLogWriter = newLogWriter() + p.internalLogSent = time.Now().Add(-logSentInterval) + p.internalWarnLogWriter = newSmallLogWriter() + }) + p.mu.Lock() + lw := p.internalLogWriter + wlw := p.internalWarnLogWriter + p.mu.Unlock() + // If ctrld was run without explicit verbose level, + // run the internal logging at debug level, so we could + // have enough information for troubleshooting. + if verbose == 0 { + for i := range writers { + w := &zerolog.FilteredLevelWriter{ + Writer: zerolog.LevelWriterAdapter{Writer: writers[i]}, + Level: zerolog.NoticeLevel, + } + writers[i] = w + } + zerolog.SetGlobalLevel(zerolog.DebugLevel) + } + writers = append(writers, lw) + writers = append(writers, &zerolog.FilteredLevelWriter{ + Writer: zerolog.LevelWriterAdapter{Writer: wlw}, + Level: zerolog.WarnLevel, + }) + multi := zerolog.MultiLevelWriter(writers...) + l := mainLog.Load().Output(multi).With().Logger() + mainLog.Store(&l) + ctrld.ProxyLogger.Store(&l) +} + +// needInternalLogging reports whether prog needs to run internal logging. +func (p *prog) needInternalLogging() bool { + // Do not run in non-cd mode. + if cdUID == "" { + return false + } + // Do not run if there's already log file. + if p.cfg.Service.LogPath != "" { + return false + } + return true +} + +func (p *prog) logReader() (*logReader, error) { + if p.needInternalLogging() { + p.mu.Lock() + lw := p.internalLogWriter + wlw := p.internalWarnLogWriter + p.mu.Unlock() + if lw == nil { + return nil, errors.New("nil internal log writer") + } + if wlw == nil { + return nil, errors.New("nil internal warn log writer") + } + // Normal log content. + lw.mu.Lock() + lwReader := bytes.NewReader(lw.buf.Bytes()) + lwSize := lw.buf.Len() + lw.mu.Unlock() + // Warn log content. + wlw.mu.Lock() + wlwReader := bytes.NewReader(wlw.buf.Bytes()) + wlwSize := wlw.buf.Len() + wlw.mu.Unlock() + reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logLogEndMarker)), wlwReader) + lr := &logReader{r: io.NopCloser(reader)} + lr.size = int64(lwSize + wlwSize) + if lr.size == 0 { + return nil, errors.New("internal log is empty") + } + return lr, nil + } + if p.cfg.Service.LogPath == "" { + return &logReader{r: io.NopCloser(strings.NewReader(""))}, nil + } + f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath)) + if err != nil { + return nil, err + } + lr := &logReader{r: f} + if st, err := f.Stat(); err == nil { + lr.size = st.Size() + } else { + return nil, fmt.Errorf("f.Stat: %w", err) + } + if lr.size == 0 { + return nil, errors.New("log file is empty") + } + return lr, nil +} diff --git a/cmd/cli/log_writer_test.go b/cmd/cli/log_writer_test.go new file mode 100644 index 0000000..bd48785 --- /dev/null +++ b/cmd/cli/log_writer_test.go @@ -0,0 +1,49 @@ +package cli + +import ( + "strings" + "sync" + "testing" +) + +func Test_logWriter_Write(t *testing.T) { + size := 64 * 1024 + lw := &logWriter{size: size} + lw.buf.Grow(lw.size) + data := strings.Repeat("A", size) + lw.Write([]byte(data)) + if lw.buf.String() != data { + t.Fatalf("unexpected buf content: %v", lw.buf.String()) + } + newData := "B" + halfData := strings.Repeat("A", len(data)/2) + logStartEndMarker + lw.Write([]byte(newData)) + if lw.buf.String() != halfData+newData { + t.Fatalf("unexpected new buf content: %v", lw.buf.String()) + } + + bigData := strings.Repeat("B", 256*1024) + expected := halfData + strings.Repeat("B", 16*1024) + lw.Write([]byte(bigData)) + if lw.buf.String() != expected { + t.Fatalf("unexpected big buf content: %v", lw.buf.String()) + } +} + +func Test_logWriter_ConcurrentWrite(t *testing.T) { + size := 64 * 1024 + lw := &logWriter{size: size} + n := 10 + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + lw.Write([]byte(strings.Repeat("A", i))) + }() + } + wg.Wait() + if lw.buf.Len() > lw.size { + t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String()) + } +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index bafcde1..73a601d 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -101,9 +101,23 @@ func initConsoleLogging() { } // initLogging initializes global logging setup. -func initLogging() { +func initLogging() []io.Writer { zerolog.TimeFieldFormat = time.RFC3339 + ".000" - initLoggingWithBackup(true) + return initLoggingWithBackup(true) +} + +// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded +// to be used for all interactive commands. +// +// Current log file config will also be ignored. +func initInteractiveLogging() { + old := cfg.Service.LogPath + cfg.Service.LogPath = "" + zerolog.TimeFieldFormat = time.RFC3339 + ".000" + initLoggingWithBackup(false) + cfg.Service.LogPath = old + l := zerolog.New(io.Discard) + ctrld.ProxyLogger.Store(&l) } // initLoggingWithBackup initializes log setup base on current config. @@ -112,8 +126,8 @@ func initLogging() { // This is only used in runCmd for special handling in case of logging config // change in cd mode. Without special reason, the caller should use initLogging // wrapper instead of calling this function directly. -func initLoggingWithBackup(doBackup bool) { - writers := []io.Writer{io.Discard} +func initLoggingWithBackup(doBackup bool) []io.Writer { + var writers []io.Writer if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" { // Create parent directory if necessary. if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil { @@ -151,21 +165,22 @@ func initLoggingWithBackup(doBackup bool) { switch { case silent: zerolog.SetGlobalLevel(zerolog.NoLevel) - return + return writers case verbose == 1: logLevel = "info" case verbose > 1: logLevel = "debug" } if logLevel == "" { - return + return writers } level, err := zerolog.ParseLevel(logLevel) if err != nil { mainLog.Load().Warn().Err(err).Msg("could not set log level") - return + return writers } zerolog.SetGlobalLevel(level) + return writers } func initCache() { diff --git a/cmd/cli/net_darwin.go b/cmd/cli/net_darwin.go index ece1862..6233161 100644 --- a/cmd/cli/net_darwin.go +++ b/cmd/cli/net_darwin.go @@ -9,17 +9,18 @@ import ( "strings" ) -func patchNetIfaceName(iface *net.Interface) error { +func patchNetIfaceName(iface *net.Interface) (bool, error) { b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output() if err != nil { - return err + return false, err } + patched := false if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" { + patched = true iface.Name = name - mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface") } - return nil + return patched, nil } func networkServiceName(ifaceName string, r io.Reader) string { diff --git a/cmd/cli/net_linux.go b/cmd/cli/net_linux.go new file mode 100644 index 0000000..ea17d3d --- /dev/null +++ b/cmd/cli/net_linux.go @@ -0,0 +1,52 @@ +package cli + +import ( + "net" + "net/netip" + "os" + "strings" + + "tailscale.com/net/netmon" +) + +func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } + +// validInterface reports whether the *net.Interface is a valid one. +// Only non-virtual interfaces are considered valid. +func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { + _, ok := validIfacesMap[iface.Name] + return ok +} + +// validInterfacesMap returns a set containing non virtual interfaces. +func validInterfacesMap() map[string]struct{} { + m := make(map[string]struct{}) + vis := virtualInterfaces() + netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { + if _, existed := vis[i.Name]; existed { + return + } + m[i.Name] = struct{}{} + }) + // Fallback to default route interface if found nothing. + if len(m) == 0 { + defaultRoute, err := netmon.DefaultRoute() + if err != nil { + return m + } + m[defaultRoute.InterfaceName] = struct{}{} + } + return m +} + +// virtualInterfaces returns a map of virtual interfaces on current machine. +func virtualInterfaces() map[string]struct{} { + s := make(map[string]struct{}) + entries, _ := os.ReadDir("/sys/devices/virtual/net") + for _, entry := range entries { + if entry.IsDir() { + s[strings.TrimSpace(entry.Name())] = struct{}{} + } + } + return s +} diff --git a/cmd/cli/net_others.go b/cmd/cli/net_others.go index 5a66e82..f347278 100644 --- a/cmd/cli/net_others.go +++ b/cmd/cli/net_others.go @@ -1,11 +1,22 @@ -//go:build !darwin && !windows +//go:build !darwin && !windows && !linux package cli -import "net" +import ( + "net" -func patchNetIfaceName(iface *net.Interface) error { return nil } + "tailscale.com/net/netmon" +) + +func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true } -func validInterfacesMap() map[string]struct{} { return nil } +// validInterfacesMap returns a set containing only default route interfaces. +func validInterfacesMap() map[string]struct{} { + defaultRoute, err := netmon.DefaultRoute() + if err != nil { + return nil + } + return map[string]struct{}{defaultRoute.InterfaceName: {}} +} diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index dc13b08..bed06b5 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -1,14 +1,20 @@ package cli import ( - "bufio" - "bytes" + "io" + "log" "net" - "strings" + "os" + + "github.com/microsoft/wmi/pkg/base/host" + "github.com/microsoft/wmi/pkg/base/instance" + "github.com/microsoft/wmi/pkg/base/query" + "github.com/microsoft/wmi/pkg/constant" + "github.com/microsoft/wmi/pkg/hardware/network/netadapter" ) -func patchNetIfaceName(iface *net.Interface) error { - return nil +func patchNetIfaceName(iface *net.Interface) (bool, error) { + return true, nil } // validInterface reports whether the *net.Interface is a valid one. @@ -20,15 +26,68 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo // validInterfacesMap returns a set of all physical interfaces. func validInterfacesMap() map[string]struct{} { - out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name") - if err != nil { - return nil - } m := make(map[string]struct{}) - scanner := bufio.NewScanner(bytes.NewReader(out)) - for scanner.Scan() { - ifaceName := strings.TrimSpace(scanner.Text()) + for _, ifaceName := range validInterfaces() { m[ifaceName] = struct{}{} } return m } + +// validInterfaces returns a list of all physical interfaces. +func validInterfaces() []string { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + whost := host.NewWmiLocalHost() + q := query.NewWmiQuery("MSFT_NetAdapter") + instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) + if instances != nil { + defer instances.Close() + } + if err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter") + return nil + } + var adapters []string + for _, i := range instances { + adapter, err := netadapter.NewNetworkAdapter(i) + if err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to get network adapter") + continue + } + + name, err := adapter.GetPropertyName() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to get interface name") + continue + } + + // From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85) + // + // "Indicates if a connector is present on the network adapter. This value is set to TRUE + // if this is a physical adapter or FALSE if this is not a physical adapter." + physical, err := adapter.GetPropertyConnectorPresent() + if err != nil { + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property") + continue + } + if !physical { + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter") + continue + } + + // Check if it's a hardware interface. Checking only for connector present is not enough + // because some interfaces are not physical but have a connector. + hardware, err := adapter.GetPropertyHardwareInterface() + if err != nil { + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property") + continue + } + if !hardware { + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface") + continue + } + + adapters = append(adapters, name) + } + return adapters +} diff --git a/cmd/cli/net_windows_test.go b/cmd/cli/net_windows_test.go new file mode 100644 index 0000000..a15f119 --- /dev/null +++ b/cmd/cli/net_windows_test.go @@ -0,0 +1,42 @@ +package cli + +import ( + "bufio" + "bytes" + "slices" + "strings" + "testing" + "time" +) + +func Test_validInterfaces(t *testing.T) { + verbose = 3 + initConsoleLogging() + start := time.Now() + ifaces := validInterfaces() + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + ifacesPowershell := validInterfacesPowershell() + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + slices.Sort(ifaces) + slices.Sort(ifacesPowershell) + if !slices.Equal(ifaces, ifacesPowershell) { + t.Fatalf("result mismatch, want: %v, got: %v", ifacesPowershell, ifaces) + } +} + +func validInterfacesPowershell() []string { + out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name") + if err != nil { + return nil + } + var res []string + scanner := bufio.NewScanner(bytes.NewReader(out)) + for scanner.Scan() { + ifaceName := strings.TrimSpace(scanner.Text()) + res = append(res, ifaceName) + } + return res +} diff --git a/cmd/cli/os_darwin.go b/cmd/cli/os_darwin.go index f319056..841be76 100644 --- a/cmd/cli/os_darwin.go +++ b/cmd/cli/os_darwin.go @@ -70,11 +70,6 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { // TODO(cuonglm): use system API func resetDNS(iface *net.Interface) error { - if ns := savedStaticNameservers(iface); len(ns) > 0 { - if err := setDNS(iface, ns); err == nil { - return nil - } - } cmd := "networksetup" args := []string{"-setdnsservers", iface.Name, "empty"} if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil { @@ -83,6 +78,15 @@ func resetDNS(iface *net.Interface) error { return nil } +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { + if ns := savedStaticNameservers(iface); len(ns) > 0 { + err = setDNS(iface, ns) + } + return err +} + func currentDNS(_ *net.Interface) []string { return resolvconffile.NameServers("") } diff --git a/cmd/cli/os_freebsd.go b/cmd/cli/os_freebsd.go index bddffca..72da485 100644 --- a/cmd/cli/os_freebsd.go +++ b/cmd/cli/os_freebsd.go @@ -76,6 +76,12 @@ func resetDNS(iface *net.Interface) error { return nil } +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { + return err +} + func currentDNS(_ *net.Interface) []string { return resolvconffile.NameServers("") } diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index ade5881..3f815e8 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -195,6 +195,12 @@ func resetDNS(iface *net.Interface) (err error) { }) } +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { + return err +} + func currentDNS(iface *net.Interface) []string { for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} { if ns := fn(iface.Name); len(ns) > 0 { diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index b9412b6..990cc57 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -1,23 +1,27 @@ package cli import ( + "bytes" "errors" "fmt" "net" + "net/netip" "os" + "os/exec" "slices" - "strconv" "strings" "sync" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) const ( - v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` - v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` + v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` ) var ( @@ -30,14 +34,6 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e return setDNS(iface, nameservers) } -func setDnsPowershellCmd(iface *net.Interface, nameservers []string) string { - nss := make([]string, 0, len(nameservers)) - for _, ns := range nameservers { - nss = append(nss, strconv.Quote(ns)) - } - return fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ServerAddresses (%s)", iface.Index, strings.Join(nss, ",")) -} - // setDNS sets the dns server for the provided network interface func setDNS(iface *net.Interface, nameservers []string) error { if len(nameservers) == 0 { @@ -46,7 +42,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { setDNSOnce.Do(func() { // If there's a Dns server running, that means we are on AD with Dns feature enabled. // Configuring the Dns server to forward queries to ctrld instead. - if windowsHasLocalDnsServerRunning() { + if hasLocalDnsServerRunning() { file := absHomeDir(windowsForwardersFilename) oldForwardersContent, _ := os.ReadFile(file) hasLocalIPv6Listener := needLocalIPv6Listener() @@ -65,9 +61,36 @@ func setDNS(iface *net.Interface, nameservers []string) error { } } }) - out, err := powershell(setDnsPowershellCmd(iface, nameservers)) + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return fmt.Errorf("%w: %s", err, string(out)) + return fmt.Errorf("setDNS: %w", err) + } + var ( + serversV4 []netip.Addr + serversV6 []netip.Addr + ) + for _, ns := range nameservers { + if addr, err := netip.ParseAddr(ns); err == nil { + if addr.Is4() { + serversV4 = append(serversV4, addr) + } else { + serversV6 = append(serversV6, addr) + } + } + } + + if len(serversV4) == 0 && len(serversV6) == 0 { + return errors.New("invalid DNS nameservers") + } + if len(serversV4) > 0 { + if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil { + return fmt.Errorf("could not set DNS ipv4: %w", err) + } + } + if len(serversV6) > 0 { + if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil { + return fmt.Errorf("could not set DNS ipv6: %w", err) + } } return nil } @@ -81,7 +104,7 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { func resetDNS(iface *net.Interface) error { resetDNSOnce.Do(func() { // See corresponding comment in setDNS. - if windowsHasLocalDnsServerRunning() { + if hasLocalDnsServerRunning() { file := absHomeDir(windowsForwardersFilename) content, err := os.ReadFile(file) if err != nil { @@ -96,14 +119,23 @@ func resetDNS(iface *net.Interface) error { } }) - // Restoring DHCP settings. - cmd := fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ResetServerAddresses", iface.Index) - out, err := powershell(cmd) + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return fmt.Errorf("%w: %s", err, string(out)) + return fmt.Errorf("resetDNS: %w", err) } + // Restoring DHCP settings. + if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil { + return fmt.Errorf("could not reset DNS ipv4: %w", err) + } + if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil { + return fmt.Errorf("could not reset DNS ipv6: %w", err) + } + return nil +} - // If there's static DNS saved, restoring it. +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { if nss := savedStaticNameservers(iface); len(nss) > 0 { v4ns := make([]string, 0, 2) v6ns := make([]string, 0, 2) @@ -120,12 +152,14 @@ func resetDNS(iface *net.Interface) error { continue } mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name) - if err := setDNS(iface, ns); err != nil { + err = setDNS(iface, ns) + + if err != nil { return err } } } - return nil + return err } func currentDNS(iface *net.Interface) []string { @@ -150,25 +184,31 @@ func currentDNS(iface *net.Interface) []string { func currentStaticDNS(iface *net.Interface) ([]string, error) { luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return nil, err + return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err) } guid, err := luid.GUID() if err != nil { - return nil, err + return nil, fmt.Errorf("luid.GUID: %w", err) } var ns []string for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} { - interfaceKeyPath := path + guid.String() found := false + interfaceKeyPath := path + guid.String() + k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE) + if err != nil { + return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err) + } for _, key := range []string{"NameServer", "ProfileNameServer"} { if found { continue } - cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key) - out, err := powershell(cmd) - if err == nil && len(out) > 0 { + value, _, err := k.GetStringValue(key) + if err != nil && !errors.Is(err, registry.ErrNotExist) { + return nil, fmt.Errorf("%s: %w", key, err) + } + if len(value) > 0 { found = true - for _, e := range strings.Split(string(out), ",") { + for _, e := range strings.Split(value, ",") { ns = append(ns, strings.TrimRight(e, "\x00")) } } @@ -216,3 +256,9 @@ func removeDnsServerForwarders(nameservers []string) error { } return nil } + +// powershell runs the given powershell command. +func powershell(cmd string) ([]byte, error) { + out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput() + return bytes.TrimSpace(out), err +} diff --git a/cmd/cli/os_windows_test.go b/cmd/cli/os_windows_test.go new file mode 100644 index 0000000..40be5ed --- /dev/null +++ b/cmd/cli/os_windows_test.go @@ -0,0 +1,68 @@ +package cli + +import ( + "fmt" + "net" + "slices" + "strings" + "testing" + "time" + + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func Test_currentStaticDNS(t *testing.T) { + iface, err := net.InterfaceByName(defaultIfaceName()) + if err != nil { + t.Fatal(err) + } + start := time.Now() + staticDns, err := currentStaticDNS(iface) + if err != nil { + t.Fatal(err) + } + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + staticDnsPowershell, err := currentStaticDnsPowershell(iface) + if err != nil { + t.Fatal(err) + } + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + slices.Sort(staticDns) + slices.Sort(staticDnsPowershell) + if !slices.Equal(staticDns, staticDnsPowershell) { + t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns) + } +} + +func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) { + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return nil, err + } + guid, err := luid.GUID() + if err != nil { + return nil, err + } + var ns []string + for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} { + interfaceKeyPath := path + guid.String() + found := false + for _, key := range []string{"NameServer", "ProfileNameServer"} { + if found { + continue + } + cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key) + out, err := powershell(cmd) + if err == nil && len(out) > 0 { + found = true + for _, e := range strings.Split(string(out), ",") { + ns = append(ns, strings.TrimRight(e, "\x00")) + } + } + } + } + return ns, nil +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 07c7677..8a86bcf 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -45,6 +45,18 @@ const ( upstreamOS = upstreamPrefix + "os" upstreamPrivate = upstreamPrefix + "private" dnsWatchdogDefaultInterval = 20 * time.Second + ctrldServiceName = "ctrld" +) + +// RecoveryReason provides context for why we are waiting for recovery. +// recovery involves removing the listener IP from the interface and +// waiting for the upstreams to work before returning +type RecoveryReason int + +const ( + RecoveryReasonNetworkChange RecoveryReason = iota + RecoveryReasonRegularFailure + RecoveryReasonOSFailure ) // ControlSocketName returns name for control unix socket. @@ -61,8 +73,9 @@ var logf = func(format string, args ...any) { } var svcConfig = &service.Config{ - Name: "ctrld", + Name: ctrldServiceName, DisplayName: "Control-D Helper Service", + Description: "A highly configurable, multi-protocol DNS forwarding proxy", Option: service.KeyValue{}, } @@ -84,21 +97,29 @@ type prog struct { dnsWg sync.WaitGroup dnsWatcherClosedOnce sync.Once dnsWatcherStopCh chan struct{} + rc *controld.ResolverConfig - cfg *ctrld.Config - localUpstreams []string - ptrNameservers []string - appCallback *AppCallback - cache dnscache.Cacher - cacheFlushDomainsMap map[string]struct{} - sema semaphore - ciTable *clientinfo.Table - um *upstreamMonitor - router router.Router - ptrLoopGuard *loopGuard - lanLoopGuard *loopGuard - metricsQueryStats atomic.Bool - queryFromSelfMap sync.Map + cfg *ctrld.Config + localUpstreams []string + ptrNameservers []string + appCallback *AppCallback + cache dnscache.Cacher + cacheFlushDomainsMap map[string]struct{} + sema semaphore + ciTable *clientinfo.Table + um *upstreamMonitor + router router.Router + ptrLoopGuard *loopGuard + lanLoopGuard *loopGuard + metricsQueryStats atomic.Bool + queryFromSelfMap sync.Map + initInternalLogWriterOnce sync.Once + internalLogWriter *logWriter + internalWarnLogWriter *logWriter + internalLogSent time.Time + runningIface string + requiredMultiNICsConfig bool + adDomain string selfUninstallMu sync.Mutex refusedQueryCount int @@ -108,9 +129,9 @@ type prog struct { loopMu sync.Mutex loop map[string]bool - leakingQueryMu sync.Mutex - leakingQueryWasRun bool - leakingQuery atomic.Bool + recoveryCancelMu sync.Mutex + recoveryCancel context.CancelFunc + recoveryRunning atomic.Bool started chan struct{} onStartedDone chan struct{} @@ -162,11 +183,13 @@ func (p *prog) runWait() { if newCfg == nil { newCfg = &ctrld.Config{} + confFile := v.ConfigFileUsed() v := viper.NewWithOptions(viper.KeyDelimiter("::")) ctrld.InitConfig(v, "ctrld") if configPath != "" { - v.SetConfigFile(configPath) + confFile = configPath } + v.SetConfigFile(confFile) if err := v.ReadInConfig(); err != nil { logger.Err(err).Msg("could not read new config") waitOldRunDone() @@ -178,10 +201,14 @@ func (p *prog) runWait() { continue } if cdUID != "" { - if err := processCDFlags(newCfg); err != nil { + if rc, err := processCDFlags(newCfg); err != nil { logger.Err(err).Msg("could not fetch ControlD config") waitOldRunDone() continue + } else { + p.mu.Lock() + p.rc = rc + p.mu.Unlock() } } } @@ -233,6 +260,11 @@ func (p *prog) runWait() { } func (p *prog) preRun() { + if iface == "auto" { + iface = defaultIfaceName() + p.requiredMultiNICsConfig = requiredMultiNICsConfig() + } + p.runningIface = iface if runtime.GOOS == "darwin" { p.onStopped = append(p.onStopped, func() { if !service.Interactive() { @@ -245,11 +277,12 @@ func (p *prog) preRun() { func (p *prog) postRun() { if !service.Interactive() { p.resetDNS() - ns := ctrld.InitializeOsResolver() + ns := ctrld.InitializeOsResolver(false) mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) p.setDNS() p.csSetDnsDone <- struct{}{} close(p.csSetDnsDone) + p.logInterfacesState() } } @@ -288,7 +321,24 @@ func (p *prog) apiConfigReload() { cdDeactivationPin.Store(defaultDeactivationPin) } - if resolverConfig.Ctrld.CustomConfig == "" { + p.mu.Lock() + rc := p.rc + p.rc = resolverConfig + p.mu.Unlock() + noCustomConfig := resolverConfig.Ctrld.CustomConfig == "" + noExcludeListChanged := true + if rc != nil { + slices.Sort(rc.Exclude) + slices.Sort(resolverConfig.Exclude) + noExcludeListChanged = slices.Equal(rc.Exclude, resolverConfig.Exclude) + } + if noCustomConfig && noExcludeListChanged { + return + } + + if noCustomConfig && !noExcludeListChanged { + logger.Debug().Msg("exclude list changes detected, reloading...") + p.apiReloadCh <- nil return } @@ -401,6 +451,10 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } } + if domain, err := getActiveDirectoryDomain(); err == nil && domain != "" && hasLocalDnsServerRunning() { + mainLog.Load().Debug().Msgf("active directory domain: %s", domain) + p.adDomain = domain + } var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) @@ -429,12 +483,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } p.setupUpstream(p.cfg) - p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers) - if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { - mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) - format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) - p.ciTable.AddLeaseFile(leaseFile, format) - } + p.setupClientInfoDiscover(defaultRouteIP()) } // context for managing spawn goroutines. @@ -446,8 +495,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { wg.Add(1) go func() { defer wg.Done() - p.ciTable.Init() - p.ciTable.RefreshLoop(ctx) + p.runClientInfoDiscover(ctx) }() go p.watchLinkState(ctx) } @@ -463,9 +511,10 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) - if err := p.serveDNS(listenerNum); err != nil { + if err := p.serveDNS(ctx, listenerNum); err != nil { mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) } + mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr) }(listenerNum) } go func() { @@ -511,16 +560,33 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if !reload { // Stop writing log to unix socket. consoleWriter.Out = os.Stdout - initLoggingWithBackup(false) + logWriters := initLoggingWithBackup(false) if p.logConn != nil { _ = p.logConn.Close() } go p.apiConfigReload() p.postRun() + p.initInternalLogging(logWriters) } wg.Wait() } +// setupClientInfoDiscover performs necessary works for running client info discover. +func (p *prog) setupClientInfoDiscover(selfIP string) { + p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers) + if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { + mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) + format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) + p.ciTable.AddLeaseFile(leaseFile, format) + } +} + +// runClientInfoDiscover runs the client info discover. +func (p *prog) runClientInfoDiscover(ctx context.Context) { + p.ciTable.Init() + p.ciTable.RefreshLoop(ctx) +} + // metricsEnabled reports whether prometheus exporter is enabled/disabled. func (p *prog) metricsEnabled() bool { return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != "" @@ -529,7 +595,9 @@ func (p *prog) metricsEnabled() bool { func (p *prog) Stop(s service.Service) error { p.stopDnsWatchers() mainLog.Load().Debug().Msg("dns watchers stopped") - mainLog.Load().Info().Msg("Service stopped") + defer func() { + mainLog.Load().Info().Msg("Service stopped") + }() close(p.stopCh) if err := p.deAllocateIP(); err != nil { mainLog.Load().Error().Err(err).Msg("de-allocate ip failed") @@ -579,27 +647,42 @@ func (p *prog) setDNS() { if cfg.Listener == nil { return } - if iface == "" { + if p.runningIface == "" { return } - runningIface := iface + // allIfaces tracks whether we should set DNS for all physical interfaces. - allIfaces := false - if runningIface == "auto" { - runningIface = defaultIfaceName() - // If runningIface is "auto", it means user does not specify "--iface" flag. - // In this case, ctrld has to set DNS for all physical interfaces, so - // thing will still work when user switch from one to the other. - allIfaces = requiredMultiNICsConfig() - } + allIfaces := p.requiredMultiNICsConfig lc := cfg.FirstListener() if lc == nil { return } - logger := mainLog.Load().With().Str("iface", runningIface).Logger() - netIface, err := netInterface(runningIface) - if err != nil { - logger.Error().Err(err).Msg("could not get interface") + logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() + + const maxDNSRetryAttempts = 3 + const retryDelay = 1 * time.Second + var netIface *net.Interface + var err error + for attempt := 1; attempt <= maxDNSRetryAttempts; attempt++ { + netIface, err = netInterface(p.runningIface) + if err == nil { + break + } + if attempt < maxDNSRetryAttempts { + // Try to find a different working interface + newIface := findWorkingInterface(p.runningIface) + if newIface != p.runningIface { + p.runningIface = newIface + logger = mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger.Info().Msg("switched to new interface") + continue + } + + logger.Warn().Err(err).Int("attempt", attempt).Msg("could not get interface, retrying...") + time.Sleep(retryDelay) + continue + } + logger.Error().Err(err).Msg("could not get interface after all attempts") return } if err := setupNetworkManager(); err != nil { @@ -686,12 +769,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces if !requiredMultiNICsConfig() { return } + logger := mainLog.Load().With().Str("iface", iface.Name).Logger() + logger.Debug().Msg("start DNS settings watchdog") - mainLog.Load().Debug().Msg("start DNS settings watchdog") ns := nameservers slices.Sort(ns) ticker := time.NewTicker(p.dnsWatchdogDuration()) - logger := mainLog.Load().With().Str("iface", iface.Name).Logger() + for { select { case <-p.dnsWatcherStopCh: @@ -700,7 +784,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces mainLog.Load().Debug().Msg("stop dns watchdog") return case <-ticker.C: - if p.leakingQuery.Load() { + if p.recoveryRunning.Load() { return } if dnsChanged(iface, ns) { @@ -726,22 +810,19 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces } func (p *prog) resetDNS() { - if iface == "" { + if p.runningIface == "" { + mainLog.Load().Debug().Msg("no running interface, skipping resetDNS") return } - runningIface := iface - allIfaces := false - if runningIface == "auto" { - runningIface = defaultIfaceName() - // See corresponding comments in (*prog).setDNS function. - allIfaces = requiredMultiNICsConfig() - } - logger := mainLog.Load().With().Str("iface", runningIface).Logger() - netIface, err := netInterface(runningIface) + // See corresponding comments in (*prog).setDNS function. + allIfaces := p.requiredMultiNICsConfig + logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() + netIface, err := netInterface(p.runningIface) if err != nil { logger.Error().Err(err).Msg("could not get interface") return } + if err := restoreNetworkManager(); err != nil { logger.Error().Err(err).Msg("could not restore NetworkManager") return @@ -757,16 +838,157 @@ func (p *prog) resetDNS() { } } -// leakOnUpstreamFailure reports whether ctrld should leak query to OS resolver when failed to connect all upstreams. -func (p *prog) leakOnUpstreamFailure() bool { - if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil { - return *ptr +func (p *prog) logInterfacesState() { + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + addrs, err := i.Addrs() + if err != nil { + mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses") + } + nss, err := currentStaticDNS(i) + if err != nil { + mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS") + } + if len(nss) == 0 { + nss = currentDNS(i) + } + mainLog.Load().Debug(). + Any("addrs", addrs). + Strs("nameservers", nss). + Int("index", i.Index). + Msgf("interface state: %s", i.Name) + return nil + }) +} + +// findWorkingInterface looks for a network interface with a valid IP configuration +func findWorkingInterface(currentIface string) string { + // Helper to check if IP is valid (not link-local) + isValidIP := func(ip net.IP) bool { + return ip != nil && + !ip.IsLinkLocalUnicast() && + !ip.IsLinkLocalMulticast() && + !ip.IsLoopback() && + !ip.IsUnspecified() } - // Default is false on routers, since this leaking is only useful for devices that move between networks. - if router.Name() != "" { + + // Helper to check if interface has valid IP configuration + hasValidIPConfig := func(iface *net.Interface) bool { + if iface == nil || iface.Flags&net.FlagUp == 0 { + return false + } + + addrs, err := iface.Addrs() + if err != nil { + mainLog.Load().Debug(). + Str("interface", iface.Name). + Err(err). + Msg("failed to get interface addresses") + return false + } + + for _, addr := range addrs { + // Check for IP network + if ipNet, ok := addr.(*net.IPNet); ok { + if isValidIP(ipNet.IP) { + return true + } + } + } return false } - return true + + // Get default route interface + defaultRoute, err := netmon.DefaultRoute() + if err != nil { + mainLog.Load().Debug(). + Err(err). + Msg("failed to get default route") + } else { + mainLog.Load().Debug(). + Str("default_route_iface", defaultRoute.InterfaceName). + Msg("found default route") + } + + // Get all interfaces + ifaces, err := net.Interfaces() + if err != nil { + mainLog.Load().Error().Err(err).Msg("failed to list network interfaces") + return currentIface // Return current interface as fallback + } + + var firstWorkingIface string + var currentIfaceValid bool + + // Single pass through interfaces + for _, iface := range ifaces { + // Must be physical (has MAC address) + if len(iface.HardwareAddr) == 0 { + continue + } + // Skip interfaces that are: + // - Loopback + // - Not up + // - Point-to-point (like VPN tunnels) + if iface.Flags&net.FlagLoopback != 0 || + iface.Flags&net.FlagUp == 0 || + iface.Flags&net.FlagPointToPoint != 0 { + continue + } + + if !hasValidIPConfig(&iface) { + continue + } + + // Found working physical interface + if err == nil && defaultRoute.InterfaceName == iface.Name { + // Found interface with default route - use it immediately + mainLog.Load().Info(). + Str("old_iface", currentIface). + Str("new_iface", iface.Name). + Msg("switching to interface with default route") + return iface.Name + } + + // Keep track of first working interface as fallback + if firstWorkingIface == "" { + firstWorkingIface = iface.Name + } + + // Check if this is our current interface + if iface.Name == currentIface { + currentIfaceValid = true + } + } + + // Return interfaces in order of preference: + // 1. Current interface if it's still valid + if currentIfaceValid { + mainLog.Load().Debug(). + Str("interface", currentIface). + Msg("keeping current interface") + return currentIface + } + + // 2. First working interface found + if firstWorkingIface != "" { + mainLog.Load().Info(). + Str("old_iface", currentIface). + Str("new_iface", firstWorkingIface). + Msg("switching to first working physical interface") + return firstWorkingIface + } + + // 3. Fall back to current interface if nothing else works + mainLog.Load().Warn(). + Str("current_iface", currentIface). + Msg("no working physical interface found, keeping current") + return currentIface +} + +// recoverOnUpstreamFailure reports whether ctrld should recover from upstream failure. +func (p *prog) recoverOnUpstreamFailure() bool { + // Default is false on routers, since this recovery flow is only useful for devices that move between networks. + return router.Name() == "" } func randomLocalIP() string { @@ -947,7 +1169,7 @@ func canBeLocalUpstream(addr string) bool { func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) { validIfacesMap := validInterfacesMap() netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { - // Skip loopback/virtual interface. + // Skip loopback/virtual/down interface. if i.IsLoopback() || len(i.HardwareAddr) == 0 { return } @@ -956,9 +1178,12 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net. return } netIface := i.Interface - if err := patchNetIfaceName(netIface); err != nil { + if patched, err := patchNetIfaceName(netIface); err != nil { mainLog.Load().Debug().Err(err).Msg("failed to patch net interface name") return + } else if !patched { + // The interface is not functional, skipping. + return } // Skip excluded interface. if netIface.Name == excludeIfaceName { @@ -1025,7 +1250,16 @@ func savedStaticDnsSettingsFilePath(iface *net.Interface) string { func savedStaticNameservers(iface *net.Interface) []string { file := savedStaticDnsSettingsFilePath(iface) if data, _ := os.ReadFile(file); len(data) > 0 { - return strings.Split(string(data), ",") + saveValues := strings.Split(string(data), ",") + returnValues := []string{} + // check each one, if its in loopback range, remove it + for _, v := range saveValues { + if net.ParseIP(v).IsLoopback() { + continue + } + returnValues = append(returnValues, v) + } + return returnValues } return nil } @@ -1044,7 +1278,7 @@ func dnsChanged(iface *net.Interface, nameservers []string) bool { // selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then. func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { - var uer *controld.UtilityErrorResponse + var uer *controld.ErrorResponse if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { p.stopDnsWatchers() diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 6df7be6..0f3f731 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -3,11 +3,38 @@ package cli import ( "net" "net/netip" + "os" "path/filepath" + "strings" + "time" "github.com/fsnotify/fsnotify" ) +// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found. +// Returns nil if no nameservers are found. +func (p *prog) parseResolvConfNameservers(path string) ([]string, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + // Parse the file for "nameserver" lines + var currentNS []string + lines := strings.Split(string(content), "\n") + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "nameserver") { + parts := strings.Fields(trimmed) + if len(parts) >= 2 { + currentNS = append(currentNS, parts[1]) + } + } + } + + return currentNS, nil +} + // watchResolvConf watches any changes to /etc/resolv.conf file, // and reverting to the original config set by ctrld. func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) { @@ -40,7 +67,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) return case event, ok := <-watcher.Events: - if p.leakingQuery.Load() { + if p.recoveryRunning.Load() { return } if !ok { @@ -50,17 +77,81 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f continue } if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { - mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting") - if err := watcher.Remove(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to pause watcher") - continue + mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...") + + // Convert expected nameservers to strings for comparison + expectedNS := make([]string, len(ns)) + for i, addr := range ns { + expectedNS[i] = addr.String() } - if err := setDnsFn(iface, ns); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + + var foundNS []string + var err error + + maxRetries := 1 + for retry := 0; retry < maxRetries; retry++ { + foundNS, err = p.parseResolvConfNameservers(resolvConfPath) + if err != nil { + mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content") + break + } + + // If we found nameservers, break out of retry loop + if len(foundNS) > 0 { + break + } + + // Only retry if we found no nameservers + if retry < maxRetries-1 { + mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries) + select { + case <-p.stopCh: + return + case <-p.dnsWatcherStopCh: + return + case <-time.After(2 * time.Second): + continue + } + } else { + mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries") + } } - if err := watcher.Add(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") - return + + // If we found nameservers, check if they match what we expect + if len(foundNS) > 0 { + // Check if the nameservers match exactly what we expect + matches := len(foundNS) == len(expectedNS) + if matches { + for i := range foundNS { + if foundNS[i] != expectedNS[i] { + matches = false + break + } + } + } + + mainLog.Load().Debug(). + Strs("found", foundNS). + Strs("expected", expectedNS). + Bool("matches", matches). + Msg("checking nameservers") + + // Only revert if the nameservers don't match + if !matches { + if err := watcher.Remove(watchDir); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to pause watcher") + continue + } + + if err := setDnsFn(iface, ns); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + } + + if err := watcher.Add(watchDir); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") + return + } + } } } case err, ok := <-watcher.Errors: diff --git a/cmd/cli/service.go b/cmd/cli/service.go index e4edfaf..82f144c 100644 --- a/cmd/cli/service.go +++ b/cmd/cli/service.go @@ -156,17 +156,18 @@ func (l *launchd) Status() (service.Status, error) { type task struct { f func() error abortOnError bool + Name string } func doTasks(tasks []task) bool { - var prevErr error for _, task := range tasks { + mainLog.Load().Debug().Msgf("Running task %s", task.Name) if err := task.f(); err != nil { if task.abortOnError { - mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error()) + mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err) return false } - prevErr = err + mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err) } } return true diff --git a/cmd/cli/service_others.go b/cmd/cli/service_others.go index f4d73e5..056903c 100644 --- a/cmd/cli/service_others.go +++ b/cmd/cli/service_others.go @@ -13,3 +13,8 @@ func hasElevatedPrivilege() (bool, error) { func openLogFile(path string, flags int) (*os.File, error) { return os.OpenFile(path, flags, os.FileMode(0o600)) } + +// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. +func hasLocalDnsServerRunning() bool { return false } + +func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil } diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index d4e2449..c4df5a5 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -2,9 +2,14 @@ package cli import ( "os" + "runtime" + "strings" "syscall" + "time" + "unsafe" "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc/mgr" ) func hasElevatedPrivilege() (bool, error) { @@ -28,6 +33,67 @@ func hasElevatedPrivilege() (bool, error) { return token.IsMember(sid) } +// ConfigureWindowsServiceFailureActions checks if the given service +// has the correct failure actions configured, and updates them if not. +func ConfigureWindowsServiceFailureActions(serviceName string) error { + if runtime.GOOS != "windows" { + return nil // no-op on non-Windows + } + + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + + s, err := m.OpenService(serviceName) + if err != nil { + return err + } + defer s.Close() + + // 1. Retrieve the current config + cfg, err := s.Config() + if err != nil { + return err + } + + // 2. Update the Description + cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy" + + // 3. Apply the updated config + if err := s.UpdateConfig(cfg); err != nil { + return err + } + + // Then proceed with existing actions, e.g. setting failure actions + actions := []mgr.RecoveryAction{ + {Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds + } + + // Set the recovery actions (3 restarts, reset period = 120). + err = s.SetRecoveryActions(actions, 120) + if err != nil { + return err + } + + // Ensure that failure actions are NOT triggered on user-initiated stops. + var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG + failureActionsFlag.FailureActionsOnNonCrashFailures = 0 + + if err := windows.ChangeServiceConfig2( + s.Handle, + windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG, + (*byte)(unsafe.Pointer(&failureActionsFlag)), + ); err != nil { + return err + } + + return nil +} + func openLogFile(path string, mode int) (*os.File, error) { if len(path) == 0 { return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND} @@ -79,3 +145,23 @@ func openLogFile(path string, mode int) (*os.File, error) { return os.NewFile(uintptr(handle), path), nil } + +const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{})) + +// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. +func hasLocalDnsServerRunning() bool { + h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) + if e != nil { + return false + } + p := windows.ProcessEntry32{Size: processEntrySize} + for { + e := windows.Process32Next(h, &p) + if e != nil { + return false + } + if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" { + return true + } + } +} diff --git a/cmd/cli/service_windows_test.go b/cmd/cli/service_windows_test.go new file mode 100644 index 0000000..67c2725 --- /dev/null +++ b/cmd/cli/service_windows_test.go @@ -0,0 +1,25 @@ +package cli + +import ( + "testing" + "time" +) + +func Test_hasLocalDnsServerRunning(t *testing.T) { + start := time.Now() + hasDns := hasLocalDnsServerRunning() + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + hasDnsPowershell := hasLocalDnsServerRunningPowershell() + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + if hasDns != hasDnsPowershell { + t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns) + } +} + +func hasLocalDnsServerRunningPowershell() bool { + _, err := powershell("Get-Process -Name DNS") + return err == nil +} diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index b17cb32..6e19e38 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -1,18 +1,15 @@ package cli import ( - "context" "sync" "time" - "github.com/miekg/dns" - "github.com/Control-D-Inc/ctrld" ) const ( // maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down. - maxFailureRequest = 100 + maxFailureRequest = 50 // checkUpstreamBackoffSleep is the time interval between each upstream checks. checkUpstreamBackoffSleep = 2 * time.Second ) @@ -21,18 +18,24 @@ const ( type upstreamMonitor struct { cfg *ctrld.Config - mu sync.Mutex + mu sync.RWMutex checking map[string]bool down map[string]bool failureReq map[string]uint64 + recovered map[string]bool + + // failureTimerActive tracks if a timer is already running for a given upstream. + failureTimerActive map[string]bool } func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { um := &upstreamMonitor{ - cfg: cfg, - checking: make(map[string]bool), - down: make(map[string]bool), - failureReq: make(map[string]uint64), + cfg: cfg, + checking: make(map[string]bool), + down: make(map[string]bool), + failureReq: make(map[string]uint64), + recovered: make(map[string]bool), + failureTimerActive: make(map[string]bool), } for n := range cfg.Upstream { upstream := upstreamPrefix + n @@ -42,14 +45,47 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { return um } -// increaseFailureCount increase failed queries count for an upstream by 1. +// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information. +// It uses a timer to debounce failure detection, ensuring that an upstream is marked as down +// within 10 seconds if failures persist, without spawning duplicate goroutines. func (um *upstreamMonitor) increaseFailureCount(upstream string) { um.mu.Lock() defer um.mu.Unlock() + if um.recovered[upstream] { + mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream) + return + } + um.failureReq[upstream] += 1 failedCount := um.failureReq[upstream] - um.down[upstream] = failedCount >= maxFailureRequest + + // Log the updated failure count. + mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount) + + // If this is the first failure and no timer is running, start a 10-second timer. + if failedCount == 1 && !um.failureTimerActive[upstream] { + um.failureTimerActive[upstream] = true + go func(upstream string) { + time.Sleep(10 * time.Second) + um.mu.Lock() + defer um.mu.Unlock() + // If no success occurred during the 10-second window (i.e. counter remains > 0) + // and the upstream is not in a recovered state, mark it as down. + if um.failureReq[upstream] > 0 && !um.recovered[upstream] { + um.down[upstream] = true + mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream]) + } + // Reset the timer flag so that a new timer can be spawned if needed. + um.failureTimerActive[upstream] = false + }(upstream) + } + + // If the failure count quickly reaches the threshold, mark the upstream as down immediately. + if failedCount >= maxFailureRequest { + um.down[upstream] = true + mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount) + } } // isDown reports whether the given upstream is being marked as down. @@ -63,56 +99,28 @@ func (um *upstreamMonitor) isDown(upstream string) bool { // reset marks an upstream as up and set failed queries counter to zero. func (um *upstreamMonitor) reset(upstream string) { um.mu.Lock() - defer um.mu.Unlock() - um.failureReq[upstream] = 0 um.down[upstream] = false -} - -// checkUpstream checks the given upstream status, periodically sending query to upstream -// until successfully. An upstream status/counter will be reset once it becomes reachable. -func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { - p.um.mu.Lock() - isChecking := p.um.checking[upstream] - if isChecking { - p.um.mu.Unlock() - return - } - p.um.checking[upstream] = true - p.um.mu.Unlock() - defer func() { - p.um.mu.Lock() - p.um.checking[upstream] = false - p.um.mu.Unlock() + um.recovered[upstream] = true + um.mu.Unlock() + go func() { + // debounce the recovery to avoid incrementing failure counts already in flight + time.Sleep(1 * time.Second) + um.mu.Lock() + um.recovered[upstream] = false + um.mu.Unlock() }() - - resolver, err := ctrld.NewResolver(uc) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not check upstream") - return - } - msg := new(dns.Msg) - msg.SetQuestion(".", dns.TypeNS) - - check := func() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - uc.ReBootstrap() - _, err := resolver.Resolve(ctx, msg) - return err - } - for { - if err := check(); err == nil { - mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint) - p.um.reset(upstream) - if p.leakingQuery.CompareAndSwap(true, false) { - p.leakingQueryMu.Lock() - p.leakingQueryWasRun = false - p.leakingQueryMu.Unlock() - mainLog.Load().Warn().Msg("stop leaking query") - } - return - } - time.Sleep(checkUpstreamBackoffSleep) - } +} + +// countHealthy returns the number of upstreams in the provided map that are considered healthy. +func (um *upstreamMonitor) countHealthy(upstreams []string) int { + var count int + um.mu.RLock() + for _, upstream := range upstreams { + if !um.down[upstream] { + count++ + } + } + um.mu.RUnlock() + return count } diff --git a/config.go b/config.go index 3f9b2f8..e1454f9 100644 --- a/config.go +++ b/config.go @@ -205,7 +205,7 @@ type ServiceConfig struct { CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"` MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` - DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` + DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp kea-dhcp4"` DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"` DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` @@ -384,7 +384,7 @@ func (uc *UpstreamConfig) IsDiscoverable() bool { return *uc.Discoverable } switch uc.Type { - case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate: + case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate, ResolverTypeLocal: if ip, err := netip.ParseAddr(uc.Domain); err == nil { return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip) } @@ -458,7 +458,7 @@ func (uc *UpstreamConfig) ReBootstrap() { } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { if uc.rebootstrap.CompareAndSwap(false, true) { - ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip") + ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc) } return true, nil }) @@ -886,3 +886,12 @@ func upstreamUID() string { return hex.EncodeToString(b) } } + +// String returns a string representation of the UpstreamConfig for logging. +func (uc *UpstreamConfig) String() string { + if uc == nil { + return "" + } + return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}", + uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack) +} diff --git a/config_internal_test.go b/config_internal_test.go index 6823686..44b7e2f 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -2,12 +2,16 @@ package ctrld import ( "net/url" + "os" "testing" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" ) func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { + l := zerolog.New(os.Stdout) + ProxyLogger.Store(&l) uc := &UpstreamConfig{ Name: "test", Type: ResolverTypeDOH, diff --git a/config_quic.go b/config_quic.go index a6dd8b7..a46780a 100644 --- a/config_quic.go +++ b/config_quic.go @@ -34,7 +34,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() { } func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { - rt := &http3.RoundTripper{} + rt := &http3.Transport{} rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { _, port, _ := net.SplitHostPort(addr) @@ -64,7 +64,7 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr()) return conn, err } - runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) { + runtime.SetFinalizer(rt, func(rt *http3.Transport) { rt.CloseIdleConnections() }) return rt diff --git a/config_test.go b/config_test.go index a20b33c..cd392d5 100644 --- a/config_test.go +++ b/config_test.go @@ -111,6 +111,7 @@ func TestConfigValidation(t *testing.T) { {"doh3 endpoint without type", doh3UpstreamEndpointWithoutType(t), false}, {"sdns endpoint without type", sdnsUpstreamEndpointWithoutType(t), false}, {"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true}, + {"kea dhcp4 format", configWithDhcp4KeaFormat(t), false}, } for _, tc := range tests { @@ -307,6 +308,12 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config { return cfg } +func configWithDhcp4KeaFormat(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Service.DHCPLeaseFileFormat = "kea-dhcp4" + return cfg +} + func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config { cfg := defaultConfig(t) cfg.Upstream["0"].Endpoint = "/1.1.1.1" diff --git a/go.mod b/go.mod index 84b58c4..635261f 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/ameshkov/dnsstamps v1.0.3 github.com/coreos/go-systemd/v22 v22.5.0 github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf + github.com/docker/go-units v0.5.0 github.com/frankban/quicktest v1.14.6 github.com/fsnotify/fsnotify v1.7.0 github.com/go-playground/validator/v10 v10.11.1 @@ -20,6 +21,7 @@ require ( github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 github.com/kardianos/service v1.2.1 github.com/mdlayher/ndp v1.0.1 + github.com/microsoft/wmi v0.24.5 github.com/miekg/dns v1.1.58 github.com/minio/selfupdate v0.6.0 github.com/olekukonko/tablewriter v0.0.5 @@ -27,16 +29,16 @@ require ( github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_model v0.5.0 github.com/prometheus/prom2json v1.3.3 - github.com/quic-go/quic-go v0.42.0 + github.com/quic-go/quic-go v0.48.2 github.com/rs/zerolog v1.28.0 github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.16.0 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/net v0.27.0 - golang.org/x/sync v0.7.0 - golang.org/x/sys v0.22.0 + golang.org/x/net v0.33.0 + golang.org/x/sync v0.10.0 + golang.org/x/sys v0.29.0 golang.zx2c4.com/wireguard/windows v0.5.3 tailscale.com v1.74.0 ) @@ -49,12 +51,14 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 // indirect + github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-playground/locales v0.14.0 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd // indirect + github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jsimonetti/rtnetlink v1.4.0 // indirect @@ -72,10 +76,11 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect - github.com/quic-go/qpack v0.4.0 // indirect + github.com/quic-go/qpack v0.5.1 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/spf13/afero v1.9.5 // indirect @@ -87,10 +92,10 @@ require ( go.uber.org/mock v0.4.0 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect - golang.org/x/crypto v0.25.0 // indirect - golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect + golang.org/x/crypto v0.31.0 // indirect + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/mod v0.19.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/text v0.21.0 // indirect golang.org/x/tools v0.23.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect @@ -99,4 +104,4 @@ require ( replace github.com/mr-karan/doggo => github.com/Windscribe/doggo v0.0.0-20220919152748-2c118fc391f8 -replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be +replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c diff --git a/go.sum b/go.sum index ebb9042..2ac97af 100644 --- a/go.sum +++ b/go.sum @@ -42,8 +42,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= -github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be h1:qBKVRi7Mom5heOkyZ+NCIu9HZBiNCsRqrRe5t9pooik= -github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= +github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE= +github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo= @@ -74,6 +74,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa h1:h8TfIT1xc8FWbwwpmHn1J5i43Y0uZP97GqasGCzSRJk= github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa/go.mod h1:Nx87SkVqTKd8UtT+xu7sM/l+LgXs6c0aHrlKusR+2EQ= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -91,6 +93,8 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= +github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= @@ -162,6 +166,8 @@ github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlG github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= @@ -207,11 +213,10 @@ github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= @@ -227,6 +232,8 @@ github.com/mdlayher/packet v1.1.2 h1:3Up1NG6LZrsgDVn6X4L9Ge/iyRyxFEFD9o6Pr3Q1nQY github.com/mdlayher/packet v1.1.2/go.mod h1:GEu1+n9sG5VtiRE4SydOmX5GTwyyYlteZiFU+x0kew4= github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI= +github.com/microsoft/wmi v0.24.5 h1:NT+WqhjKbEcg3ldmDsRMarWgHGkpeW+gMopSCfON0kM= +github.com/microsoft/wmi v0.24.5/go.mod h1:1zbdSF0A+5OwTUII5p3hN7/K6KF2m3o27pSG6Y51VU8= github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4= github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY= github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU= @@ -245,6 +252,7 @@ github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFu github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -261,10 +269,10 @@ github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/prometheus/prom2json v1.3.3 h1:IYfSMiZ7sSOfliBoo89PcufjWO4eAR0gznGcETyaUgo= github.com/prometheus/prom2json v1.3.3/go.mod h1:Pv4yIPktEkK7btWsrUTWDDDrnpUrAELaOCj+oFwlgmc= -github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= -github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM= -github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= +github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= +github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= +github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -274,7 +282,7 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= -github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/afero v1.9.5 h1:stMpOSZFs//0Lv29HduCmli3GUfpFoF3Y1Q/aXj/wVM= github.com/spf13/afero v1.9.5/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ= @@ -338,8 +346,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -350,8 +358,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -409,8 +417,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -430,8 +438,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -472,16 +480,16 @@ golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -492,8 +500,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 780334b..35d5dbb 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -77,6 +77,7 @@ type Table struct { hostnameResolvers []HostnameResolver refreshers []refresher initOnce sync.Once + stopOnce sync.Once refreshInterval int dhcp *dhcp @@ -90,7 +91,9 @@ type Table struct { vni *virtualNetworkIface svcCfg ctrld.ServiceConfig quitCh chan struct{} + stopCh chan struct{} selfIP string + selfIPLock sync.RWMutex cdUID string ptrNameservers []string } @@ -103,6 +106,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { return &Table{ svcCfg: cfg.Service, quitCh: make(chan struct{}), + stopCh: make(chan struct{}), selfIP: selfIP, cdUID: cdUID, ptrNameservers: ns, @@ -120,24 +124,59 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) { // RefreshLoop runs all the refresher to update new client info data. func (t *Table) RefreshLoop(ctx context.Context) { timer := time.NewTicker(time.Second * time.Duration(t.refreshInterval)) - defer timer.Stop() + defer func() { + timer.Stop() + close(t.quitCh) + }() for { select { case <-timer.C: - for _, r := range t.refreshers { - _ = r.refresh() - } + t.Refresh() + case <-t.stopCh: + return case <-ctx.Done(): - close(t.quitCh) return } } } +// Init initializes all client info discovers. func (t *Table) Init() { t.initOnce.Do(t.init) } +// Refresh forces all discovers to retrieve new data. +func (t *Table) Refresh() { + for _, r := range t.refreshers { + _ = r.refresh() + } +} + +// Stop stops all the discovers. +// It blocks until all the discovers done. +func (t *Table) Stop() { + t.stopOnce.Do(func() { + close(t.stopCh) + }) + <-t.quitCh +} + +// SelfIP returns the selfIP value of the Table in a thread-safe manner. +func (t *Table) SelfIP() string { + t.selfIPLock.RLock() + defer t.selfIPLock.RUnlock() + return t.selfIP +} + +// SetSelfIP sets the selfIP value of the Table in a thread-safe manner. +func (t *Table) SetSelfIP(ip string) { + t.selfIPLock.Lock() + defer t.selfIPLock.Unlock() + t.selfIP = ip + t.dhcp.selfIP = t.selfIP + t.dhcp.addSelf() +} + func (t *Table) init() { // Custom client ID presents, use it as the only source. if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { @@ -381,9 +420,7 @@ func (t *Table) lookupHostnameAll(ip, mac string) []*hostnameEntry { // ListClients returns list of clients discovered by ctrld. func (t *Table) ListClients() []*Client { - for _, r := range t.refreshers { - _ = r.refresh() - } + t.Refresh() ipMap := make(map[string]*Client) il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni} for _, ir := range il { diff --git a/internal/controld/config.go b/internal/controld/config.go index 1bc2512..fbbd9d4 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -25,8 +25,12 @@ import ( const ( apiDomainCom = "api.controld.com" apiDomainDev = "api.controld.dev" - resolverDataURLCom = "https://api.controld.com/utility" - resolverDataURLDev = "https://api.controld.dev/utility" + apiURLCom = "https://api.controld.com" + apiURLDev = "https://api.controld.dev" + resolverDataURLCom = apiURLCom + "/utility" + resolverDataURLDev = apiURLDev + "/utility" + logURLCom = apiURLCom + "/logs" + logURLDev = apiURLDev + "/logs" InvalidConfigCode = 40402 ) @@ -49,14 +53,14 @@ type utilityResponse struct { } `json:"body"` } -type UtilityErrorResponse struct { +type ErrorResponse struct { ErrorField struct { Message string `json:"message"` Code int `json:"code"` } `json:"error"` } -func (u UtilityErrorResponse) Error() string { +func (u ErrorResponse) Error() string { return u.ErrorField.Message } @@ -71,6 +75,12 @@ type UtilityOrgRequest struct { Hostname string `json:"hostname"` } +// LogsRequest contains request data for sending runtime logs to API. +type LogsRequest struct { + UID string `json:"uid"` + Data io.ReadCloser `json:"-"` +} + // FetchResolverConfig fetch Control D config for given uid. func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) { uid, clientID := ParseRawUID(rawUID) @@ -123,6 +133,81 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade } req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/json") + transport := apiTransport(cdDev) + client := http.Client{ + Timeout: 10 * time.Second, + Transport: transport, + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("postUtilityAPI client.Do: %w", err) + } + defer resp.Body.Close() + d := json.NewDecoder(resp.Body) + if resp.StatusCode != http.StatusOK { + errResp := &ErrorResponse{} + if err := d.Decode(errResp); err != nil { + return nil, err + } + return nil, errResp + } + + ur := &utilityResponse{} + if err := d.Decode(ur); err != nil { + return nil, err + } + return &ur.Body.Resolver, nil +} + +// SendLogs sends runtime log to ControlD API. +func SendLogs(lr *LogsRequest, cdDev bool) error { + defer lr.Data.Close() + apiUrl := logURLCom + if cdDev { + apiUrl = logURLDev + } + req, err := http.NewRequest("POST", apiUrl, lr.Data) + if err != nil { + return fmt.Errorf("http.NewRequest: %w", err) + } + q := req.URL.Query() + q.Set("uid", lr.UID) + req.URL.RawQuery = q.Encode() + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + transport := apiTransport(cdDev) + client := http.Client{ + Timeout: 300 * time.Second, + Transport: transport, + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("SendLogs client.Do: %w", err) + } + defer resp.Body.Close() + d := json.NewDecoder(resp.Body) + if resp.StatusCode != http.StatusOK { + errResp := &ErrorResponse{} + if err := d.Decode(errResp); err != nil { + return err + } + return errResp + } + _, _ = io.Copy(io.Discard, resp.Body) + return nil +} + +// ParseRawUID parse the input raw UID, returning real UID and ClientID. +// The raw UID can have 2 forms: +// +// - +// - / +func ParseRawUID(rawUID string) (string, string) { + uid, clientID, _ := strings.Cut(rawUID, "/") + return uid, clientID +} + +// apiTransport returns an HTTP transport for connecting to ControlD API endpoint. +func apiTransport(cdDev bool) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { apiDomain := apiDomainCom @@ -143,41 +228,8 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade d := &ctrldnet.ParallelDialer{} return d.DialContext(ctx, network, addrs) } - if router.Name() == ddwrt.Name || runtime.GOOS == "android" { transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()} } - client := http.Client{ - Timeout: 10 * time.Second, - Transport: transport, - } - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("client.Do: %w", err) - } - defer resp.Body.Close() - d := json.NewDecoder(resp.Body) - if resp.StatusCode != http.StatusOK { - errResp := &UtilityErrorResponse{} - if err := d.Decode(errResp); err != nil { - return nil, err - } - return nil, errResp - } - - ur := &utilityResponse{} - if err := d.Decode(ur); err != nil { - return nil, err - } - return &ur.Body.Resolver, nil -} - -// ParseRawUID parse the input raw UID, returning real UID and ClientID. -// The raw UID can have 2 forms: -// -// - -// - / -func ParseRawUID(rawUID string) (string, string) { - uid, clientID, _ := strings.Cut(rawUID, "/") - return uid, clientID + return transport } diff --git a/nameservers_bsd.go b/nameservers_bsd.go index 2beebd0..b835060 100644 --- a/nameservers_bsd.go +++ b/nameservers_bsd.go @@ -1,19 +1,16 @@ -//go:build darwin || dragonfly || freebsd || netbsd || openbsd +//go:build dragonfly || freebsd || netbsd || openbsd package ctrld import ( "net" - "os/exec" - "runtime" - "strings" "syscall" "golang.org/x/net/route" ) func dnsFns() []dnsFn { - return []dnsFn{dnsFromRIB, dnsFromIPConfig} + return []dnsFn{dnsFromRIB} } func dnsFromRIB() []string { @@ -49,18 +46,6 @@ func dnsFromRIB() []string { return dns } -func dnsFromIPConfig() []string { - if runtime.GOOS != "darwin" { - return nil - } - cmd := exec.Command("ipconfig", "getoption", "", "domain_name_server") - out, _ := cmd.Output() - if ip := net.ParseIP(strings.TrimSpace(string(out))); ip != nil { - return []string{ip.String()} - } - return nil -} - func toNetIP(addr route.Addr) net.IP { switch t := addr.(type) { case *route.Inet4Addr: diff --git a/nameservers_darwin.go b/nameservers_darwin.go new file mode 100644 index 0000000..d536d78 --- /dev/null +++ b/nameservers_darwin.go @@ -0,0 +1,217 @@ +//go:build darwin + +package ctrld + +import ( + "bufio" + "bytes" + "context" + "fmt" + "net" + "os/exec" + "regexp" + "slices" + "strings" + "time" + + "tailscale.com/net/netmon" + + "github.com/Control-D-Inc/ctrld/internal/resolvconffile" +) + +func dnsFns() []dnsFn { + return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers} +} + +// dnsFromResolvConf reads nameservers from /etc/resolv.conf +func dnsFromResolvConf() []string { + const ( + maxRetries = 10 + retryInterval = 100 * time.Millisecond + ) + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var dns []string + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(retryInterval) + } + + nss := resolvconffile.NameServers("") + var localDNS []string + seen := make(map[string]bool) + + for _, ns := range nss { + if ip := net.ParseIP(ns); ip != nil { + // skip loopback IPs + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + if ip.String() == ipStr { + continue + } + } + if !seen[ip.String()] { + seen[ip.String()] = true + localDNS = append(localDNS, ip.String()) + } + } + } + + // If we successfully read the file and found nameservers, return them + if len(localDNS) > 0 { + return localDNS + } + } + + return dns +} + +func getDNSFromScutil() []string { + logger := *ProxyLogger.Load() + + const ( + maxRetries = 10 + retryInterval = 100 * time.Millisecond + ) + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var nameservers []string + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(retryInterval) + } + + cmd := exec.Command("scutil", "--dns") + output, err := cmd.Output() + if err != nil { + Log(context.Background(), logger.Error(), "failed to execute scutil --dns (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + var localDNS []string + seen := make(map[string]bool) + + scanner := bufio.NewScanner(bytes.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "nameserver[") { + parts := strings.Split(line, ":") + if len(parts) == 2 { + ns := strings.TrimSpace(parts[1]) + if ip := net.ParseIP(ns); ip != nil { + // skip loopback IPs + isLocal := false + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + if ip.String() == ipStr { + isLocal = true + break + } + } + if !isLocal && !seen[ip.String()] { + seen[ip.String()] = true + localDNS = append(localDNS, ip.String()) + } + } + } + } + } + + if err := scanner.Err(); err != nil { + Log(context.Background(), logger.Error(), "error scanning scutil output (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + // If we successfully read the output and found nameservers, return them + if len(localDNS) > 0 { + return localDNS + } + } + + return nameservers +} + +func getDHCPNameservers(iface string) ([]string, error) { + // Run the ipconfig command for the given interface. + cmd := exec.Command("ipconfig", "getpacket", iface) + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("error running ipconfig: %v", err) + } + + // Look for a line like: + // domain_name_servers = 192.168.1.1 8.8.8.8; + re := regexp.MustCompile(`domain_name_servers\s*=\s*(.*);`) + matches := re.FindStringSubmatch(string(output)) + if len(matches) < 2 { + return nil, fmt.Errorf("no DHCP nameservers found") + } + + // Split the nameservers by whitespace. + nameservers := strings.Fields(matches[1]) + return nameservers, nil +} + +func getAllDHCPNameservers() []string { + interfaces, err := net.Interfaces() + if err != nil { + return nil + } + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var allNameservers []string + seen := make(map[string]bool) + + for _, iface := range interfaces { + // Skip interfaces that are: + // - down + // - loopback + // - not physical (virtual) + // - point-to-point (like VPN interfaces) + // - without MAC address (non-physical) + if iface.Flags&net.FlagUp == 0 || + iface.Flags&net.FlagLoopback != 0 || + iface.Flags&net.FlagPointToPoint != 0 || + (iface.Flags&net.FlagBroadcast == 0 && + iface.Flags&net.FlagMulticast == 0) || + len(iface.HardwareAddr) == 0 || + strings.HasPrefix(iface.Name, "utun") || + strings.HasPrefix(iface.Name, "llw") || + strings.HasPrefix(iface.Name, "awdl") { + continue + } + + // Verify it's a valid MAC address (should be 6 bytes for IEEE 802 MAC-48) + if len(iface.HardwareAddr) != 6 { + continue + } + + nameservers, err := getDHCPNameservers(iface.Name) + if err != nil { + continue + } + + // Add unique nameservers to the result, skipping local IPs + for _, ns := range nameservers { + if ip := net.ParseIP(ns); ip != nil { + // skip loopback and local IPs + isLocal := false + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + if ip.String() == v.String() { + isLocal = true + break + } + } + if !isLocal && !seen[ns] { + seen[ns] = true + allNameservers = append(allNameservers, ns) + } + } + } + } + + return allNameservers +} diff --git a/nameservers_windows.go b/nameservers_windows.go index 150f252..54fb8b6 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -1,44 +1,473 @@ package ctrld import ( + "context" + "fmt" + "io" + "log" + "net" + "os" + "strings" "syscall" + "time" + "unsafe" + "github.com/microsoft/wmi/pkg/base/host" + "github.com/microsoft/wmi/pkg/base/instance" + "github.com/microsoft/wmi/pkg/base/query" + "github.com/microsoft/wmi/pkg/constant" + "github.com/microsoft/wmi/pkg/hardware/network/netadapter" + "github.com/rs/zerolog" + "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) +const ( + maxDNSAdapterRetries = 5 + retryDelayDNSAdapter = 1 * time.Second + defaultDNSAdapterTimeout = 10 * time.Second + minDNSServers = 1 // Minimum number of DNS servers we want to find + NetSetupUnknown uint32 = 0 + NetSetupWorkgroup uint32 = 1 + NetSetupDomain uint32 = 2 + NetSetupCloudDomain uint32 = 3 + DS_FORCE_REDISCOVERY = 0x00000001 + DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010 + DS_BACKGROUND_ONLY = 0x00000100 + DS_IP_REQUIRED = 0x00000200 + DS_IS_DNS_NAME = 0x00020000 + DS_RETURN_DNS_NAME = 0x40000000 +) + +type DomainControllerInfo struct { + DomainControllerName *uint16 + DomainControllerAddress *uint16 + DomainControllerAddressType uint32 + DomainGuid windows.GUID + DomainName *uint16 + DnsForestName *uint16 + Flags uint32 + DcSiteName *uint16 + ClientSiteName *uint16 +} + func dnsFns() []dnsFn { return []dnsFn{dnsFromAdapter} } func dnsFromAdapter() []string { - aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix) - if err != nil { - return nil + ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout) + defer cancel() + + var ns []string + var err error + + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() } + + for i := 0; i < maxDNSAdapterRetries; i++ { + if ctx.Err() != nil { + Log(context.Background(), logger.Debug(), + "dnsFromAdapter lookup cancelled or timed out, attempt %d", i) + return nil + } + + ns, err = getDNSServers(ctx) + if err == nil && len(ns) >= minDNSServers { + if i > 0 { + Log(context.Background(), logger.Debug(), + "Successfully got DNS servers after %d attempts, found %d servers", + i+1, len(ns)) + } + return ns + } + + // if osResolver is not initialized, this is likely a command line run + // and ctrld is already on the interface, abort retries + if or == nil { + return ns + } + + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get DNS servers, attempt %d: %v", i+1, err) + } else { + Log(context.Background(), logger.Debug(), + "Got insufficient DNS servers, retrying, found %d servers", len(ns)) + } + + select { + case <-ctx.Done(): + return nil + case <-time.After(retryDelayDNSAdapter): + } + } + + Log(context.Background(), logger.Debug(), + "Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries) + return ns +} + +func getDNSServers(ctx context.Context) ([]string, error) { + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + // Check context before making the call + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // Get DNS servers from adapters (existing method) + flags := winipcfg.GAAFlagIncludeGateways | + winipcfg.GAAFlagIncludePrefix + + aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags) + if err != nil { + return nil, fmt.Errorf("getting adapters: %w", err) + } + + Log(context.Background(), logger.Debug(), + "Found network adapters, count=%d", len(aas)) + + // Try to get domain controller info if domain-joined + var dcServers []string + isDomain := checkDomainJoined() + if isDomain { + domainName, err := getLocalADDomain() + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get local AD domain: %v", err) + } else { + // Load netapi32.dll + netapi32 := windows.NewLazySystemDLL("netapi32.dll") + dsDcName := netapi32.NewProc("DsGetDcNameW") + + var info *DomainControllerInfo + flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME) + + domainUTF16, err := windows.UTF16PtrFromString(domainName) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to convert domain name to UTF16: %v", err) + } else { + Log(context.Background(), logger.Debug(), + "Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) + + // Call DsGetDcNameW with domain name + ret, _, err := dsDcName.Call( + 0, // ComputerName - can be NULL + uintptr(unsafe.Pointer(domainUTF16)), // DomainName + 0, // DomainGuid - not needed + 0, // SiteName - not needed + uintptr(flags), // Flags + uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output + + if ret != 0 { + switch ret { + case 1355: // ERROR_NO_SUCH_DOMAIN + Log(context.Background(), logger.Debug(), + "Domain not found: %s (%d)", domainName, ret) + case 1311: // ERROR_NO_LOGON_SERVERS + Log(context.Background(), logger.Debug(), + "No logon servers available for domain: %s (%d)", domainName, ret) + case 1004: // ERROR_DC_NOT_FOUND + Log(context.Background(), logger.Debug(), + "Domain controller not found for domain: %s (%d)", domainName, ret) + case 1722: // RPC_S_SERVER_UNAVAILABLE + Log(context.Background(), logger.Debug(), + "RPC server unavailable for domain: %s (%d)", domainName, ret) + default: + Log(context.Background(), logger.Debug(), + "Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) + } + } else if info != nil { + defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info))) + + if info.DomainControllerAddress != nil { + dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress) + dcAddr = strings.TrimPrefix(dcAddr, "\\\\") + Log(context.Background(), logger.Debug(), + "Found domain controller address: %s", dcAddr) + + if ip := net.ParseIP(dcAddr); ip != nil { + dcServers = append(dcServers, ip.String()) + Log(context.Background(), logger.Debug(), + "Added domain controller DNS servers: %v", dcServers) + } + } else { + Log(context.Background(), logger.Debug(), + "No domain controller address found") + } + } + } + } + } + + // Continue with existing adapter DNS collection ns := make([]string, 0, len(aas)*2) seen := make(map[string]bool) addressMap := make(map[string]struct{}) + + // Collect all local IPs for _, aa := range aas { + if aa.OperStatus != winipcfg.IfOperStatusUp { + Log(context.Background(), logger.Debug(), + "Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) + continue + } + + // Skip if software loopback or other non-physical types + // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows + if aa.IfType == winipcfg.IfTypeSoftwareLoopback { + Log(context.Background(), logger.Debug(), + "Skipping %s (software loopback)", aa.FriendlyName()) + continue + } + + Log(context.Background(), logger.Debug(), + "Processing adapter %s", aa.FriendlyName()) + for a := aa.FirstUnicastAddress; a != nil; a = a.Next { - addressMap[a.Address.IP().String()] = struct{}{} + ip := a.Address.IP().String() + addressMap[ip] = struct{}{} + Log(context.Background(), logger.Debug(), + "Added local IP %s from adapter %s", ip, aa.FriendlyName()) } } + + validInterfacesMap := validInterfaces() + + // Collect DNS servers for _, aa := range aas { + if aa.OperStatus != winipcfg.IfOperStatusUp { + continue + } + + // Skip if software loopback or other non-physical types + // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows + if aa.IfType == winipcfg.IfTypeSoftwareLoopback { + Log(context.Background(), logger.Debug(), + "Skipping %s (software loopback)", aa.FriendlyName()) + continue + } + + // if not in the validInterfacesMap, skip + if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok { + Log(context.Background(), logger.Debug(), + "Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) + continue + } + for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { ip := dns.Address.IP() - if ip == nil || ip.IsLoopback() || seen[ip.String()] { + if ip == nil { + Log(context.Background(), logger.Debug(), + "Skipping nil IP from adapter %s", aa.FriendlyName()) continue } - if _, ok := addressMap[ip.String()]; ok { + + ipStr := ip.String() + l := logger.Debug(). + Str("ip", ipStr). + Str("adapter", aa.FriendlyName()) + + if ip.IsLoopback() { + l.Msg("Skipping loopback IP") continue } - seen[ip.String()] = true - ns = append(ns, ip.String()) + if seen[ipStr] { + l.Msg("Skipping duplicate IP") + continue + } + if _, ok := addressMap[ipStr]; ok { + l.Msg("Skipping local interface IP") + continue + } + + seen[ipStr] = true + ns = append(ns, ipStr) + l.Msg("Added DNS server") } } - return ns + + // Add DC servers if they're not already in the list + for _, dcServer := range dcServers { + if !seen[dcServer] { + seen[dcServer] = true + ns = append(ns, dcServer) + Log(context.Background(), logger.Debug(), + "Added additional domain controller DNS server: %s", dcServer) + } + } + + if len(ns) == 0 { + return nil, fmt.Errorf("no valid DNS servers found") + } + + Log(context.Background(), logger.Debug(), + "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", + len(ns), ns, len(dcServers)) + return ns, nil } func nameserversFromResolvconf() []string { return nil } + +// checkDomainJoined checks if the machine is joined to an Active Directory domain +// Returns whether it's domain joined and the domain name if available +func checkDomainJoined() bool { + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + var domain *uint16 + var status uint32 + + err := windows.NetGetJoinInformation(nil, &domain, &status) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get domain join status: %v", err) + return false + } + defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) + + domainName := windows.UTF16PtrToString(domain) + Log(context.Background(), logger.Debug(), + "Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", + domainName, status) + + // Consider domain or cloud domain as domain-joined + isDomain := status == NetSetupDomain || status == NetSetupCloudDomain + Log(context.Background(), logger.Debug(), + "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", + status, + status == NetSetupDomain, + status == NetSetupCloudDomain, + isDomain) + + return isDomain +} + +// getLocalADDomain uses Microsoft's WMI wrappers (github.com/microsoft/wmi/pkg/*) +// to query the Domain field from Win32_ComputerSystem instead of a direct go-ole call. +func getLocalADDomain() (string, error) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + // 1) Check environment variable + envDomain := os.Getenv("USERDNSDOMAIN") + if envDomain != "" { + return strings.TrimSpace(envDomain), nil + } + + // 2) Query WMI via the microsoft/wmi library + whost := host.NewWmiLocalHost() + q := query.NewWmiQuery("Win32_ComputerSystem") + instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q) + if instances != nil { + defer instances.Close() + } + if err != nil { + return "", fmt.Errorf("WMI query failed: %v", err) + } + + // If no results, return an error + if len(instances) == 0 { + return "", fmt.Errorf("no rows returned from Win32_ComputerSystem") + } + + // We only care about the first row + domainVal, err := instances[0].GetProperty("Domain") + if err != nil { + return "", fmt.Errorf("machine does not appear to have a domain set: %v", err) + } + + domainName := strings.TrimSpace(fmt.Sprintf("%v", domainVal)) + if domainName == "" { + return "", fmt.Errorf("machine does not appear to have a domain set") + } + return domainName, nil +} + +// validInterfaces returns a list of all physical interfaces. +// this is a duplicate of what is in net_windows.go, we should +// clean this up so there is only one version +func validInterfaces() map[string]struct{} { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + + whost := host.NewWmiLocalHost() + q := query.NewWmiQuery("MSFT_NetAdapter") + instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) + if instances != nil { + defer instances.Close() + } + if err != nil { + Log(context.Background(), logger.Warn(), + "failed to get wmi network adapter: %v", err) + return nil + } + var adapters []string + for _, i := range instances { + adapter, err := netadapter.NewNetworkAdapter(i) + if err != nil { + Log(context.Background(), logger.Warn(), + "failed to get network adapter: %v", err) + continue + } + + name, err := adapter.GetPropertyName() + if err != nil { + Log(context.Background(), logger.Warn(), + "failed to get interface name: %v", err) + continue + } + + // From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85) + // + // "Indicates if a connector is present on the network adapter. This value is set to TRUE + // if this is a physical adapter or FALSE if this is not a physical adapter." + physical, err := adapter.GetPropertyConnectorPresent() + if err != nil { + Log(context.Background(), logger.Debug(), + "failed to get network adapter connector present property: %v", err) + continue + } + if !physical { + Log(context.Background(), logger.Debug(), + "skipping non-physical adapter: %s", name) + continue + } + + // Check if it's a hardware interface. Checking only for connector present is not enough + // because some interfaces are not physical but have a connector. + hardware, err := adapter.GetPropertyHardwareInterface() + if err != nil { + Log(context.Background(), logger.Debug(), + "failed to get network adapter hardware interface property: %v", err) + continue + } + if !hardware { + Log(context.Background(), logger.Debug(), + "skipping non-hardware interface: %s", name) + continue + } + + adapters = append(adapters, name) + } + + m := make(map[string]struct{}) + for _, ifaceName := range adapters { + m[ifaceName] = struct{}{} + } + return m +} diff --git a/resolver.go b/resolver.go index f54edfb..19ebc1f 100644 --- a/resolver.go +++ b/resolver.go @@ -4,15 +4,17 @@ import ( "context" "errors" "fmt" + "io" "net" "net/netip" + "runtime" "slices" - "strings" "sync" "sync/atomic" "time" "github.com/miekg/dns" + "github.com/rs/zerolog" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" ) @@ -30,8 +32,10 @@ const ( ResolverTypeOS = "os" // ResolverTypeLegacy specifies legacy resolver. ResolverTypeLegacy = "legacy" - // ResolverTypePrivate is like ResolverTypeOS, but use for local resolver only. + // ResolverTypePrivate is like ResolverTypeOS, but use for private resolver only. ResolverTypePrivate = "private" + // ResolverTypeLocal is like ResolverTypeOS, but use for local resolver only. + ResolverTypeLocal = "local" // ResolverTypeSDNS specifies resolver with information encoded using DNS Stamps. // See: https://dnscrypt.info/stamps-specifications/ ResolverTypeSDNS = "sdns" @@ -44,8 +48,30 @@ const ( var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") -// or is the Resolver used for ResolverTypeOS. -var or = newResolverWithNameserver(defaultNameservers()) +var localResolver = newLocalResolver() + +var ( + resolverMutex sync.Mutex + or *osResolver + defaultLocalIPv4 atomic.Value // holds net.IP (IPv4) + defaultLocalIPv6 atomic.Value // holds net.IP (IPv6) +) + +func newLocalResolver() Resolver { + var nss []string + for _, addr := range Rfc1918Addresses() { + nss = append(nss, net.JoinHostPort(addr, "53")) + } + return NewResolverWithNameserver(nss) +} + +// LanQueryCtxKey is the context.Context key to indicate that the request is for LAN network. +type LanQueryCtxKey struct{} + +// LanQueryCtx returns a context.Context with LanQueryCtxKey set. +func LanQueryCtx(ctx context.Context) context.Context { + return context.WithValue(ctx, LanQueryCtxKey{}, true) +} // defaultNameservers is like nameservers with each element formed "ip:53". func defaultNameservers() []string { @@ -63,17 +89,39 @@ func availableNameservers() []string { // Ignore local addresses to prevent loop. regularIPs, loopbackIPs, _ := netmon.LocalAddresses() machineIPsMap := make(map[string]struct{}, len(regularIPs)) - for _, v := range slices.Concat(regularIPs, loopbackIPs) { - machineIPsMap[v.String()] = struct{}{} + + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() } - for _, ns := range nameservers() { + Log(context.Background(), logger.Debug(), + "Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs) + + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + machineIPsMap[ipStr] = struct{}{} + Log(context.Background(), logger.Debug(), + "Added local IP to OS resolverexclusion map: %s", ipStr) + } + + systemNameservers := nameservers() + Log(context.Background(), logger.Debug(), + "Got system nameservers: %v", systemNameservers) + + for _, ns := range systemNameservers { if _, ok := machineIPsMap[ns]; ok { + Log(context.Background(), logger.Debug(), + "Skipping local nameserver: %s", ns) continue } - if testNameserver(ns) { - nss = append(nss, ns) - } + nss = append(nss, ns) + Log(context.Background(), logger.Debug(), + "Added non-local nameserver: %s", ns) } + + Log(context.Background(), logger.Debug(), + "Final available nameservers: %v", nss) return nss } @@ -82,77 +130,47 @@ func availableNameservers() []string { // // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. -func InitializeOsResolver() []string { - return initializeOsResolver(availableNameservers()) +func InitializeOsResolver(guardAgainstNoNameservers bool) []string { + nameservers := availableNameservers() + // if no nameservers, return empty slice so we dont remove all nameservers + if len(nameservers) == 0 && guardAgainstNoNameservers { + return []string{} + } + ns := initializeOsResolver(nameservers) + resolverMutex.Lock() + defer resolverMutex.Unlock() + or = newResolverWithNameserver(ns) + return ns } + +// initializeOsResolver performs logic for choosing OS resolver nameserver. +// The logic: +// +// - First available LAN servers are saved and store. +// - Later calls, if no LAN servers available, the saved servers above will be used. func initializeOsResolver(servers []string) []string { - var ( - nss []string - publicNss []string - ) - var ( - lastLanServer netip.Addr - curLanServer netip.Addr - curLanServerAvailable bool - ) - if p := or.currentLanServer.Load(); p != nil { - curLanServer = *p - or.currentLanServer.Store(nil) - } - if p := or.lastLanServer.Load(); p != nil { - lastLanServer = *p - or.lastLanServer.Store(nil) - } + + var lanNss, publicNss []string + + // First categorize servers for _, ns := range servers { addr, err := netip.ParseAddr(ns) if err != nil { continue } server := net.JoinHostPort(ns, "53") - // Always use new public nameserver. - if !isLanAddr(addr) { - publicNss = append(publicNss, server) - nss = append(nss, server) - continue - } - // For LAN server, storing only current and last LAN server if any. - if addr.Compare(curLanServer) == 0 { - curLanServerAvailable = true + if isLanAddr(addr) { + lanNss = append(lanNss, server) } else { - if addr.Compare(lastLanServer) == 0 { - or.lastLanServer.Store(&addr) - } else { - if or.currentLanServer.CompareAndSwap(nil, &addr) { - nss = append(nss, server) - } - } + publicNss = append(publicNss, server) } } - // Store current LAN server as last one only if it's still available. - if curLanServerAvailable && curLanServer.IsValid() { - or.lastLanServer.Store(&curLanServer) - nss = append(nss, net.JoinHostPort(curLanServer.String(), "53")) - } - if len(publicNss) == 0 { - publicNss = append(publicNss, controldPublicDnsWithPort) - nss = append(nss, controldPublicDnsWithPort) - } - or.publicServer.Store(&publicNss) - return nss -} -// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. -func testNameserver(addr string) bool { - msg := new(dns.Msg) - msg.SetQuestion("controld.com.", dns.TypeNS) - client := new(dns.Client) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - _, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(addr, "53")) - if err != nil { - ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr) + if len(publicNss) == 0 { + publicNss = []string{controldPublicDnsWithPort} } - return err == nil + + return slices.Concat(lanNss, publicNss) } // Resolver is the interface that wraps the basic DNS operations. @@ -175,19 +193,23 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { case ResolverTypeDOQ: return &doqResolver{uc: uc}, nil case ResolverTypeOS: + if or == nil { + or = newResolverWithNameserver(defaultNameservers()) + } return or, nil case ResolverTypeLegacy: return &legacyResolver{uc: uc}, nil case ResolverTypePrivate: return NewPrivateResolver(), nil + case ResolverTypeLocal: + return localResolver, nil } return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ) } type osResolver struct { - currentLanServer atomic.Pointer[netip.Addr] - lastLanServer atomic.Pointer[netip.Addr] - publicServer atomic.Pointer[[]string] + lanServers atomic.Pointer[[]string] + publicServers atomic.Pointer[[]string] } type osResolverResult struct { @@ -197,26 +219,75 @@ type osResolverResult struct { lan bool } +type publicResponse struct { + answer *dns.Msg + server string +} + +// SetDefaultLocalIPv4 updates the stored local IPv4. +func SetDefaultLocalIPv4(ip net.IP) { + Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip) + defaultLocalIPv4.Store(ip) +} + +// SetDefaultLocalIPv6 updates the stored local IPv6. +func SetDefaultLocalIPv6(ip net.IP) { + Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip) + defaultLocalIPv6.Store(ip) +} + +// GetDefaultLocalIPv4 returns the stored local IPv4 or nil if none. +func GetDefaultLocalIPv4() net.IP { + if v := defaultLocalIPv4.Load(); v != nil { + return v.(net.IP) + } + return nil +} + +// GetDefaultLocalIPv6 returns the stored local IPv6 or nil if none. +func GetDefaultLocalIPv6() net.IP { + if v := defaultLocalIPv6.Load(); v != nil { + return v.(net.IP) + } + return nil +} + +// customDNSExchange wraps the DNS exchange to use our debug dialer. +// It uses dns.ExchangeWithConn so that our custom dialer is used directly. +func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desiredLocalIP net.IP) (*dns.Msg, time.Duration, error) { + baseDialer := &net.Dialer{ + Timeout: 3 * time.Second, + Resolver: &net.Resolver{PreferGo: true}, + } + if desiredLocalIP != nil { + baseDialer.LocalAddr = &net.UDPAddr{IP: desiredLocalIP, Port: 0} + } + dnsClient := &dns.Client{Net: "udp"} + dnsClient.Dialer = baseDialer + return dnsClient.ExchangeContext(ctx, msg, server) +} + // Resolve resolves DNS queries using pre-configured nameservers. // Query is sent to all nameservers concurrently, and the first // success response will be returned. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - publicServers := *o.publicServer.Load() - nss := make([]string, 0, 2) - if p := o.currentLanServer.Load(); p != nil { - nss = append(nss, net.JoinHostPort(p.String(), "53")) - } - if p := o.lastLanServer.Load(); p != nil { - nss = append(nss, net.JoinHostPort(p.String(), "53")) + publicServers := *o.publicServers.Load() + var nss []string + if p := o.lanServers.Load(); p != nil { + nss = append(nss, (*p)...) } numServers := len(nss) + len(publicServers) + // If this is a LAN query, skip public DNS. + lan, ok := ctx.Value(LanQueryCtxKey{}).(bool) + if ok && lan { + numServers -= len(publicServers) + } if numServers == 0 { return nil, errors.New("no nameservers available") } ctx, cancel := context.WithCancel(ctx) defer cancel() - dnsClient := &dns.Client{Net: "udp"} ch := make(chan *osResolverResult, numServers) wg := &sync.WaitGroup{} wg.Add(numServers) @@ -229,57 +300,86 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error for _, server := range servers { go func(server string) { defer wg.Done() - answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) + var answer *dns.Msg + var err error + var localOSResolverIP net.IP + if runtime.GOOS == "darwin" { + host, _, err := net.SplitHostPort(server) + if err == nil { + ip := net.ParseIP(host) + if ip != nil && ip.To4() == nil { + // IPv6 nameserver; use default IPv6 address (if set) + localOSResolverIP = GetDefaultLocalIPv6() + } else { + localOSResolverIP = GetDefaultLocalIPv4() + } + } + } + answer, _, err = customDNSExchange(ctx, msg.Copy(), server, localOSResolverIP) ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan} }(server) } } do(nss, true) - do(publicServers, false) + if !lan { + do(publicServers, false) + } logAnswer := func(server string) { - if before, _, found := strings.Cut(server, ":"); found { - server = before + host, _, err := net.SplitHostPort(server) + if err != nil { + // If splitting fails, fallback to the original server string + host = server } - Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", server) + Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host) } var ( nonSuccessAnswer *dns.Msg nonSuccessServer string controldSuccessAnswer *dns.Msg - publicServerAnswer *dns.Msg - publicServer string + publicResponses []publicResponse ) errs := make([]error, 0, numServers) for res := range ch { switch { case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: switch { - case res.server == controldPublicDnsWithPort: - controldSuccessAnswer = res.answer // only use ControlD answer as last one. - case !res.lan && publicServerAnswer == nil: - publicServerAnswer = res.answer // use public DNS answer after LAN server.. - publicServer = res.server - default: + case res.lan: + // Always prefer LAN responses immediately + Log(ctx, ProxyLogger.Load().Debug(), "using LAN answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil + case res.server == controldPublicDnsWithPort: + controldSuccessAnswer = res.answer + case !res.lan: + publicResponses = append(publicResponses, publicResponse{ + answer: res.answer, + server: res.server, + }) } case res.answer != nil: nonSuccessAnswer = res.answer nonSuccessServer = res.server + Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d", + res.server, res.answer.Rcode) } errs = append(errs, res.err) } - if publicServerAnswer != nil { - logAnswer(publicServer) - return publicServerAnswer, nil + + if len(publicResponses) > 0 { + resp := publicResponses[0] + Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", resp.server) + logAnswer(resp.server) + return resp.answer, nil } if controldSuccessAnswer != nil { + Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort) logAnswer(controldPublicDnsWithPort) return controldSuccessAnswer, nil } if nonSuccessAnswer != nil { + Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer) logAnswer(nonSuccessServer) return nonSuccessAnswer, nil } @@ -328,7 +428,11 @@ func LookupIP(domain string) []string { } func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { - nss := defaultNameservers() + if or == nil { + or = newResolverWithNameserver(defaultNameservers()) + } + nss := *or.lanServers.Load() + nss = append(nss, *or.publicServers.Load()...) if withBootstrapDNS { nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...) } @@ -467,17 +571,19 @@ func NewResolverWithNameserver(nameservers []string) Resolver { // The caller must ensure each server in list is formed "ip:53". func newResolverWithNameserver(nameservers []string) *osResolver { r := &osResolver{} - nss := slices.Sorted(slices.Values(nameservers)) - for i, ns := range nss { + var publicNss []string + var lanNss []string + for _, ns := range slices.Sorted(slices.Values(nameservers)) { ip, _, _ := net.SplitHostPort(ns) addr, _ := netip.ParseAddr(ip) if isLanAddr(addr) { - r.currentLanServer.Store(&addr) - nss = slices.Delete(nss, i, i+1) - break + lanNss = append(lanNss, ns) + } else { + publicNss = append(publicNss, ns) } } - r.publicServer.Store(&nss) + r.lanServers.Store(&lanNss) + r.publicServers.Store(&publicNss) return r } diff --git a/resolver_test.go b/resolver_test.go index 7b1a49d..e96e875 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -3,13 +3,10 @@ package ctrld import ( "context" "net" - "slices" "sync" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/miekg/dns" ) @@ -20,7 +17,7 @@ func Test_osResolver_Resolve(t *testing.T) { go func() { defer cancel() resolver := &osResolver{} - resolver.publicServer.Store(&[]string{"127.0.0.127:5353"}) + resolver.publicServers.Store(&[]string{"127.0.0.127:5353"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -34,26 +31,51 @@ func Test_osResolver_Resolve(t *testing.T) { } } +func Test_osResolver_ResolveLanHostname(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + reqId := "req-id" + ctx = context.WithValue(ctx, ReqIdCtxKey{}, reqId) + ctx = LanQueryCtx(ctx) + + go func(ctx context.Context) { + defer cancel() + id, ok := ctx.Value(ReqIdCtxKey{}).(string) + if !ok || id != reqId { + t.Error("missing request id") + return + } + lan, ok := ctx.Value(LanQueryCtxKey{}).(bool) + if !ok || !lan { + t.Error("not a LAN query") + return + } + resolver := &osResolver{} + resolver.publicServers.Store(&[]string{"76.76.2.0:53"}) + m := new(dns.Msg) + m.SetQuestion("controld.com.", dns.TypeA) + m.RecursionDesired = true + _, err := resolver.Resolve(ctx, m) + if err == nil { + t.Error("os resolver succeeded unexpectedly") + return + } + }(ctx) + + select { + case <-time.After(10 * time.Second): + t.Error("os resolver hangs") + case <-ctx.Done(): + } +} + func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { ns := make([]string, 0, 2) servers := make([]*dns.Server, 0, 2) - successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(msg, dns.RcodeSuccess) - w.WriteMsg(m) - }) - nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc { - return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(msg, rcode) - w.WriteMsg(m) - }) - } - handlers := []dns.Handler{ nonSuccessHandlerWithRcode(dns.RcodeRefused), nonSuccessHandlerWithRcode(dns.RcodeNameError), - successHandler, + successHandler(), } for i := range handlers { pc, err := net.ListenPacket("udp", ":0") @@ -74,7 +96,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { } }() resolver := &osResolver{} - resolver.publicServer.Store(&ns) + resolver.publicServers.Store(&ns) msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) answer, err := resolver.Resolve(context.Background(), msg) @@ -93,7 +115,7 @@ func Test_osResolver_InitializationRace(t *testing.T) { for range n { go func() { defer wg.Done() - InitializeOsResolver() + InitializeOsResolver(false) }() } wg.Wait() @@ -153,41 +175,18 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H return server, addr, nil } -func Test_initializeOsResolver(t *testing.T) { - lanServer1 := "192.168.1.1" - lanServer2 := "10.0.10.69" - wanServer := "1.1.1.1" - publicServers := []string{net.JoinHostPort(wanServer, "53")} - - // First initialization. - initializeOsResolver([]string{lanServer1, wanServer}) - p := or.currentLanServer.Load() - assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) - - // No new LAN server, current LAN server -> last LAN server. - initializeOsResolver([]string{lanServer1, wanServer}) - p = or.currentLanServer.Load() - assert.Nil(t, p) - p = or.lastLanServer.Load() - assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) - - // New LAN server detected. - initializeOsResolver([]string{lanServer2, lanServer1, wanServer}) - p = or.currentLanServer.Load() - assert.NotNil(t, p) - assert.Equal(t, lanServer2, p.String()) - p = or.lastLanServer.Load() - assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) - - // No LAN server available. - initializeOsResolver([]string{wanServer}) - assert.Nil(t, or.currentLanServer.Load()) - assert.Nil(t, or.lastLanServer.Load()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) +func successHandler() dns.HandlerFunc { + return func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, dns.RcodeSuccess) + w.WriteMsg(m) + } +} + +func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc { + return func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, rcode) + w.WriteMsg(m) + } }