From 8b605da861d0f72ace36158b10b9362516c444e1 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 5 Aug 2025 14:37:21 +0700 Subject: [PATCH] refactor: convert rootCmd from global to local variable - Add appVersion variable to store curVersion() result during init - Change initCLI() to return *cobra.Command - Move rootCmd creation inside initCLI() as local variable - Replace all rootCmd.Version usage with appVersion variable - Update Main() function to capture returned rootCmd from initCLI() - Remove sync.Once guard from tests and use initCLI() directly - Remove sync import from test file as it's no longer needed This refactoring improves encapsulation by eliminating global state, reduces version computation overhead, and simplifies test setup by removing the need for sync.Once guards. All tests pass and the application builds successfully. --- cmd/cli/cli.go | 32 ++++++++++++++++++-------------- cmd/cli/commands_test.go | 29 +++++++++-------------------- cmd/cli/control_server.go | 2 +- cmd/cli/dns_proxy.go | 2 +- cmd/cli/main.go | 2 +- cmd/cli/prog.go | 4 ++-- 6 files changed, 32 insertions(+), 39 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 6bb7e9b..602c391 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -61,6 +61,8 @@ var ( defaultConfigFile = "ctrld.toml" rootCertPool *x509.CertPool errSelfCheckNoAnswer = errors.New("no response from ctrld listener. You can try to re-launch with flag --skip_self_checks") + // Store version once during init to avoid repeated calls to curVersion() + appVersion = curVersion() ) var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"} @@ -83,15 +85,6 @@ _/ ___\ __\_ __ \ | / __ | \/ dns forwarding proxy \/ ` -var rootCmd = &cobra.Command{ - Use: "ctrld", - Short: strings.TrimLeft(rootShortDesc, "\n"), - Version: curVersion(), - PersistentPreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() - }, -} - func curVersion() string { if version != "dev" && !strings.HasPrefix(version, "v") { version = "v" + version @@ -105,12 +98,21 @@ func curVersion() string { return fmt.Sprintf("%s-%s", version, commit) } -func initCLI() { +func initCLI() *cobra.Command { // Enable opening via explorer.exe on Windows. // See: https://github.com/spf13/cobra/issues/844. cobra.MousetrapHelpText = "" cobra.EnableCommandSorting = false + rootCmd := &cobra.Command{ + Use: "ctrld", + Short: strings.TrimLeft(rootShortDesc, "\n"), + Version: appVersion, + PersistentPreRun: func(cmd *cobra.Command, args []string) { + initConsoleLogging() + }, + } + rootCmd.PersistentFlags().CountVarP( &verbose, "verbose", @@ -132,6 +134,8 @@ func initCLI() { InitClientsCmd(rootCmd) InitUpgradeCmd(rootCmd) InitLogCmd(rootCmd) + + return rootCmd } // isMobile reports whether the current OS is a mobile platform. @@ -603,12 +607,12 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) bo.LogLongerThan = 30 * time.Second ctx := ctrld.LoggerCtx(context.Background(), logger) - resolverConfig, err := controld.FetchResolverConfig(ctx, cdUID, rootCmd.Version, cdDev) + resolverConfig, err := controld.FetchResolverConfig(ctx, cdUID, appVersion, cdDev) for { if errUrlNetworkError(err) { bo.BackOff(ctx, err) logger.Warn().Msg("could not fetch resolver using bootstrap DNS, retrying...") - resolverConfig, err = controld.FetchResolverConfig(ctx, cdUID, rootCmd.Version, cdDev) + resolverConfig, err = controld.FetchResolverConfig(ctx, cdUID, appVersion, cdDev) continue } break @@ -1391,7 +1395,7 @@ func cdUIDFromProvToken() string { req := &controld.UtilityOrgRequest{ProvToken: cdOrg, Hostname: customHostname} // Process provision token if provided. loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) - resolverConfig, err := controld.FetchResolverUID(loggerCtx, req, rootCmd.Version, cdDev) + resolverConfig, err := controld.FetchResolverUID(loggerCtx, req, appVersion, cdDev) if err != nil { mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid with provision token: %s", cdOrg) } @@ -1715,7 +1719,7 @@ func runningIface(s service.Service) *ifaceResponse { // doValidateCdRemoteConfig fetches and validates custom config for cdUID. func doValidateCdRemoteConfig(cdUID string, fatal bool) error { loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) - rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) + rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev) if err != nil { logger := mainLog.Load().Fatal() if !fatal { diff --git a/cmd/cli/commands_test.go b/cmd/cli/commands_test.go index 683aa79..98ac760 100644 --- a/cmd/cli/commands_test.go +++ b/cmd/cli/commands_test.go @@ -2,7 +2,6 @@ package cli import ( "bytes" - "sync" "testing" "github.com/spf13/cobra" @@ -10,20 +9,10 @@ import ( "github.com/stretchr/testify/require" ) -// setupTestCLI initializes the CLI for testing, ensuring it's only done once -var cliInitOnce sync.Once - -func setupTestCLI() { - cliInitOnce.Do(func() { - initCLI() - }) -} - // TestBasicCommandStructure tests the actual root command structure func TestBasicCommandStructure(t *testing.T) { - // Test the actual global rootCmd that's used in the application - // Initialize the CLI to set up the root command - setupTestCLI() + // Test the actual root command that's returned from initCLI() + rootCmd := initCLI() // Test that root command has basic properties assert.Equal(t, "ctrld", rootCmd.Use) @@ -93,7 +82,7 @@ func TestServiceCommandSubCommands(t *testing.T) { // TestCommandHelp tests basic help functionality func TestCommandHelp(t *testing.T) { // Initialize the CLI to set up the root command - setupTestCLI() + rootCmd := initCLI() // Test help command execution var buf bytes.Buffer @@ -109,7 +98,7 @@ func TestCommandHelp(t *testing.T) { // TestCommandVersion tests version command func TestCommandVersion(t *testing.T) { // Initialize the CLI to set up the root command - setupTestCLI() + rootCmd := initCLI() var buf bytes.Buffer rootCmd.SetOut(&buf) @@ -125,7 +114,7 @@ func TestCommandVersion(t *testing.T) { // TestCommandErrorHandling tests error handling func TestCommandErrorHandling(t *testing.T) { // Initialize the CLI to set up the root command - setupTestCLI() + rootCmd := initCLI() // Test invalid flag instead of invalid command rootCmd.SetArgs([]string{"--invalid-flag"}) @@ -136,7 +125,7 @@ func TestCommandErrorHandling(t *testing.T) { // TestCommandFlags tests flag functionality func TestCommandFlags(t *testing.T) { // Initialize the CLI to set up the root command - setupTestCLI() + rootCmd := initCLI() // Test that root command has expected flags verboseFlag := rootCmd.PersistentFlags().Lookup("verbose") @@ -151,7 +140,7 @@ func TestCommandFlags(t *testing.T) { // TestCommandExecution tests basic command execution func TestCommandExecution(t *testing.T) { // Initialize the CLI to set up the root command - setupTestCLI() + rootCmd := initCLI() // Test that root command can be executed (help command) var buf bytes.Buffer @@ -167,7 +156,7 @@ func TestCommandExecution(t *testing.T) { // TestCommandArgs tests argument handling func TestCommandArgs(t *testing.T) { // Initialize the CLI to set up the root command - setupTestCLI() + rootCmd := initCLI() // Test that root command can handle arguments properly // Test with no args (should succeed) @@ -183,7 +172,7 @@ func TestCommandArgs(t *testing.T) { // TestCommandSubcommands tests subcommand functionality func TestCommandSubcommands(t *testing.T) { // Initialize the CLI to set up the root command - setupTestCLI() + rootCmd := initCLI() // Test that root command has subcommands commands := rootCmd.Commands() diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index de3a27a..848ecf6 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -218,7 +218,7 @@ func (p *prog) registerControlServerHandler() { loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) // Re-fetch pin code from API. - if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev); rc != nil { + if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev); rc != nil { if rc.DeactivationPin != nil { cdDeactivationPin.Store(*rc.DeactivationPin) } else { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 12a4be4..298a80d 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1097,7 +1097,7 @@ func (p *prog) doSelfUninstall(pr *proxyResponse) { if p.refusedQueryCount > selfUninstallMaxQueries { p.checkingSelfUninstall = true loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) - _, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) + _, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev) logger.Debug().Msg("maximum number of refused queries reached, checking device status") selfUninstallCheck(err, p, logger) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index b3bda67..91fab80 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -60,7 +60,7 @@ func init() { func Main() { ctrld.InitConfig(v, "ctrld") - initCLI() + rootCmd := initCLI() if err := rootCmd.Execute(); err != nil { mainLog.Load().Error().Msg(err.Error()) os.Exit(1) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 55e7751..f7586ab 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -304,7 +304,7 @@ func (p *prog) apiConfigReload() { doReloadApiConfig := func(forced bool, logger *ctrld.Logger) { loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) - resolverConfig, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) + resolverConfig, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev) selfUninstallCheck(err, p, logger) if err != nil { logger.Warn().Err(err).Msg("could not fetch resolver config") @@ -362,7 +362,7 @@ func (p *prog) apiConfigReload() { } if cfgErr != nil { logger.Warn().Err(err).Msg("skipping invalid custom config") - if _, err := controld.UpdateCustomLastFailed(loggerCtx, cdUID, rootCmd.Version, cdDev, true); err != nil { + if _, err := controld.UpdateCustomLastFailed(loggerCtx, cdUID, appVersion, cdDev, true); err != nil { logger.Error().Err(err).Msg("could not mark custom last update failed") } return