From ce29b5d217624192c975ddc09a073827b180deb6 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 14 Jul 2025 15:28:01 +0700 Subject: [PATCH] refactor: split selfUpgradeCheck into version check and upgrade execution - Move version checking logic to shouldUpgrade for testability - Move upgrade command execution to performUpgrade - selfUpgradeCheck now composes these two for clarity - Update and expand tests: focus on logic, not side effects - Improves maintainability, testability, and separation of concerns --- cmd/cli/prog.go | 51 ++++++++--- cmd/cli/prog_test.go | 209 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 249 insertions(+), 11 deletions(-) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 48a1e07..90f403d 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -329,7 +329,7 @@ func (p *prog) apiConfigReload() { // Performing self-upgrade check for production version. if isStable { - selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger) + _ = selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger) } if resolverConfig.DeactivationPin != nil { @@ -1474,14 +1474,15 @@ func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { } } -// selfUpgradeCheck checks if the version target vt is greater -// than the current one cv, perform self-upgrade then. +// shouldUpgrade checks if the version target vt is greater than the current one cv. +// Major version upgrades are not allowed to prevent breaking changes. // // The callers must ensure curVer and logger are non-nil. -func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) { +// Returns true if upgrade is allowed, false otherwise. +func shouldUpgrade(vt string, cv *semver.Version, logger *zerolog.Logger) bool { if vt == "" { logger.Debug().Msg("no version target set, skipped checking self-upgrade") - return + return false } vts := vt if !strings.HasPrefix(vts, "v") { @@ -1490,28 +1491,58 @@ func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) { targetVer, err := semver.NewVersion(vts) if err != nil { logger.Warn().Err(err).Msgf("invalid target version, skipped self-upgrade: %s", vt) - return + return false } + + // Prevent major version upgrades to avoid breaking changes + if targetVer.Major() != cv.Major() { + logger.Warn(). + Str("target", vt). + Str("current", cv.String()). + Msgf("major version upgrade not allowed (target: %d, current: %d), skipped self-upgrade", targetVer.Major(), cv.Major()) + return false + } + if !targetVer.GreaterThan(cv) { logger.Debug(). Str("target", vt). Str("current", cv.String()). Msgf("target version is not greater than current one, skipped self-upgrade") - return + return false } + return true +} + +// performUpgrade executes the self-upgrade command. +// Returns true if upgrade was initiated successfully, false otherwise. +func performUpgrade(vt string) bool { exe, err := os.Executable() if err != nil { mainLog.Load().Error().Err(err).Msg("failed to get executable path, skipped self-upgrade") - return + return false } cmd := exec.Command(exe, "upgrade", "prod", "-vv") cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() if err := cmd.Start(); err != nil { mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade") - return + return false } - mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vts) + mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vt) + return true +} + +// selfUpgradeCheck checks if the version target vt is greater +// than the current one cv, perform self-upgrade then. +// Major version upgrades are not allowed to prevent breaking changes. +// +// The callers must ensure curVer and logger are non-nil. +// Returns true if upgrade is allowed and should proceed, false otherwise. +func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) bool { + if shouldUpgrade(vt, cv, logger) { + return performUpgrade(vt) + } + return false } // leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow diff --git a/cmd/cli/prog_test.go b/cmd/cli/prog_test.go index 5f2f8e1..1fee462 100644 --- a/cmd/cli/prog_test.go +++ b/cmd/cli/prog_test.go @@ -4,8 +4,11 @@ import ( "testing" "time" - "github.com/Control-D-Inc/ctrld" + "github.com/Masterminds/semver/v3" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + + "github.com/Control-D-Inc/ctrld" ) func Test_prog_dnsWatchdogEnabled(t *testing.T) { @@ -55,3 +58,207 @@ func Test_prog_dnsWatchdogInterval(t *testing.T) { }) } } + +func Test_shouldUpgrade(t *testing.T) { + // Helper function to create a version + makeVersion := func(v string) *semver.Version { + ver, err := semver.NewVersion(v) + if err != nil { + t.Fatalf("failed to create version %s: %v", v, err) + } + return ver + } + + tests := []struct { + name string + versionTarget string + currentVersion *semver.Version + shouldUpgrade bool + description string + }{ + { + name: "empty version target", + versionTarget: "", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should skip upgrade when version target is empty", + }, + { + name: "invalid version target", + versionTarget: "invalid-version", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should skip upgrade when version target is invalid", + }, + { + name: "same version", + versionTarget: "v1.0.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should skip upgrade when target version equals current version", + }, + { + name: "older version", + versionTarget: "v1.0.0", + currentVersion: makeVersion("v1.1.0"), + shouldUpgrade: false, + description: "should skip upgrade when target version is older than current version", + }, + { + name: "patch upgrade allowed", + versionTarget: "v1.0.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow patch version upgrade within same major version", + }, + { + name: "minor upgrade allowed", + versionTarget: "v1.1.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow minor version upgrade within same major version", + }, + { + name: "major upgrade blocked", + versionTarget: "v2.0.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should block major version upgrade", + }, + { + name: "major downgrade blocked", + versionTarget: "v1.0.0", + currentVersion: makeVersion("v2.0.0"), + shouldUpgrade: false, + description: "should block major version downgrade", + }, + { + name: "version without v prefix", + versionTarget: "1.0.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should handle version target without v prefix", + }, + { + name: "complex version upgrade allowed", + versionTarget: "v1.5.3", + currentVersion: makeVersion("v1.4.2"), + shouldUpgrade: true, + description: "should allow complex version upgrade within same major version", + }, + { + name: "complex major upgrade blocked", + versionTarget: "v3.1.0", + currentVersion: makeVersion("v2.5.3"), + shouldUpgrade: false, + description: "should block complex major version upgrade", + }, + { + name: "pre-release version upgrade allowed", + versionTarget: "v1.0.1-beta.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow pre-release version upgrade within same major version", + }, + { + name: "pre-release major upgrade blocked", + versionTarget: "v2.0.0-alpha.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should block pre-release major version upgrade", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // Create test logger + testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger() + + // Call the function and capture the result + result := shouldUpgrade(tc.versionTarget, tc.currentVersion, &testLogger) + + // Assert the expected result + assert.Equal(t, tc.shouldUpgrade, result, tc.description) + }) + } +} + +func Test_selfUpgradeCheck(t *testing.T) { + // Helper function to create a version + makeVersion := func(v string) *semver.Version { + ver, err := semver.NewVersion(v) + if err != nil { + t.Fatalf("failed to create version %s: %v", v, err) + } + return ver + } + + tests := []struct { + name string + versionTarget string + currentVersion *semver.Version + shouldUpgrade bool + description string + }{ + { + name: "upgrade allowed", + versionTarget: "v1.0.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow upgrade and attempt to perform it", + }, + { + name: "upgrade blocked", + versionTarget: "v2.0.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should block upgrade and not attempt to perform it", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // Create test logger + testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger() + + // Call the function and capture the result + result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, &testLogger) + + // Assert the expected result + assert.Equal(t, tc.shouldUpgrade, result, tc.description) + }) + } +} + +func Test_performUpgrade(t *testing.T) { + tests := []struct { + name string + versionTarget string + expectedResult bool + description string + }{ + { + name: "valid version target", + versionTarget: "v1.0.1", + expectedResult: true, + description: "should attempt to perform upgrade with valid version target", + }, + { + name: "empty version target", + versionTarget: "", + expectedResult: true, + description: "should attempt to perform upgrade even with empty version target", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // Call the function and capture the result + result := performUpgrade(tc.versionTarget) + assert.Equal(t, tc.expectedResult, result, tc.description) + }) + } +}