From 6663925c4d576109df349d98b3907714f9afa0ce Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 30 Jun 2025 15:22:25 +0700 Subject: [PATCH 1/6] internal/router: support Merlin Guest Network Pro VLAN By looking for any additional dnsmasq configuration files under /tmp/etc, and handling them like default one. --- cmd/cli/prog.go | 7 ++ internal/router/dnsmasq/conf.go | 60 +++++++++++++ internal/router/dnsmasq/conf_test.go | 47 ++++++++++ internal/router/dnsmasq/dnsmasq.go | 10 ++- internal/router/merlin/merlin.go | 123 ++++++++++++++++++++------- 5 files changed, 212 insertions(+), 35 deletions(-) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index dd8de9f..48a1e07 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -35,6 +35,7 @@ import ( "github.com/Control-D-Inc/ctrld/internal/controld" "github.com/Control-D-Inc/ctrld/internal/dnscache" "github.com/Control-D-Inc/ctrld/internal/router" + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" ) const ( @@ -607,6 +608,12 @@ func (p *prog) setupClientInfoDiscover(selfIP string) { format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) p.ciTable.AddLeaseFile(leaseFile, format) } + if leaseFiles := dnsmasq.AdditionalLeaseFiles(); len(leaseFiles) > 0 { + mainLog.Load().Debug().Msgf("watching additional lease files: %v", leaseFiles) + for _, leaseFile := range leaseFiles { + p.ciTable.AddLeaseFile(leaseFile, ctrld.Dnsmasq) + } + } } // runClientInfoDiscover runs the client info discover. diff --git a/internal/router/dnsmasq/conf.go b/internal/router/dnsmasq/conf.go index b168042..bb81d60 100644 --- a/internal/router/dnsmasq/conf.go +++ b/internal/router/dnsmasq/conf.go @@ -6,6 +6,7 @@ import ( "errors" "io" "os" + "path/filepath" "strings" ) @@ -28,3 +29,62 @@ func interfaceNameFromReader(r io.Reader) (string, error) { } return "", errors.New("not found") } + +// AdditionalConfigFiles returns a list of Dnsmasq configuration files found in the "/tmp/etc" directory. +func AdditionalConfigFiles() []string { + if paths, err := filepath.Glob("/tmp/etc/dnsmasq-*.conf"); err == nil { + return paths + } + return nil +} + +// AdditionalLeaseFiles returns a list of lease file paths corresponding to the Dnsmasq configuration files. +func AdditionalLeaseFiles() []string { + cfgFiles := AdditionalConfigFiles() + if len(cfgFiles) == 0 { + return nil + } + leaseFiles := make([]string, 0, len(cfgFiles)) + for _, cfgFile := range cfgFiles { + if leaseFile := leaseFileFromConfigFileName(cfgFile); leaseFile != "" { + leaseFiles = append(leaseFiles, leaseFile) + + } else { + leaseFiles = append(leaseFiles, defaultLeaseFileFromConfigPath(cfgFile)) + } + } + return leaseFiles +} + +// leaseFileFromConfigFileName retrieves the DHCP lease file path by reading and parsing the provided configuration file. +func leaseFileFromConfigFileName(cfgFile string) string { + if f, err := os.Open(cfgFile); err == nil { + return leaseFileFromReader(f) + } + return "" +} + +// leaseFileFromReader parses the given io.Reader for the "dhcp-leasefile" configuration and returns its value as a string. +func leaseFileFromReader(r io.Reader) string { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "#") { + continue + } + before, after, found := strings.Cut(line, "=") + if !found { + continue + } + if before == "dhcp-leasefile" { + return after + } + } + return "" +} + +// defaultLeaseFileFromConfigPath generates the default lease file path based on the provided configuration file path. +func defaultLeaseFileFromConfigPath(path string) string { + name := filepath.Base(path) + return filepath.Join("/var/lib/misc", strings.TrimSuffix(name, ".conf")+".leases") +} diff --git a/internal/router/dnsmasq/conf_test.go b/internal/router/dnsmasq/conf_test.go index 99a0710..9ca672b 100644 --- a/internal/router/dnsmasq/conf_test.go +++ b/internal/router/dnsmasq/conf_test.go @@ -1,6 +1,7 @@ package dnsmasq import ( + "io" "strings" "testing" ) @@ -44,3 +45,49 @@ interface=eth0 }) } } + +func Test_leaseFileFromReader(t *testing.T) { + tests := []struct { + name string + in io.Reader + expected string + }{ + { + "default", + strings.NewReader(` +dhcp-script=/sbin/dhcpc_lease +dhcp-leasefile=/var/lib/misc/dnsmasq-1.leases +script-arp +`), + "/var/lib/misc/dnsmasq-1.leases", + }, + { + "non-default", + strings.NewReader(` +dhcp-script=/sbin/dhcpc_lease +dhcp-leasefile=/tmp/var/lib/misc/dnsmasq-1.leases +script-arp +`), + "/tmp/var/lib/misc/dnsmasq-1.leases", + }, + { + "missing", + strings.NewReader(` +dhcp-script=/sbin/dhcpc_lease +script-arp +`), + "", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := leaseFileFromReader(tc.in); got != tc.expected { + t.Errorf("leaseFileFromReader() = %v, want %v", got, tc.expected) + } + }) + } + +} diff --git a/internal/router/dnsmasq/dnsmasq.go b/internal/router/dnsmasq/dnsmasq.go index 819bd59..a690ee4 100644 --- a/internal/router/dnsmasq/dnsmasq.go +++ b/internal/router/dnsmasq/dnsmasq.go @@ -26,9 +26,13 @@ max-cache-ttl=0 {{- end}} ` -const MerlinConfPath = "/tmp/etc/dnsmasq.conf" -const MerlinJffsConfPath = "/jffs/configs/dnsmasq.conf" -const MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf" +const ( + MerlinConfPath = "/tmp/etc/dnsmasq.conf" + MerlinJffsConfDir = "/jffs/configs" + MerlinJffsConfPath = "/jffs/configs/dnsmasq.conf" + MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf" +) + const MerlinPostConfMarker = `# GENERATED BY ctrld - EOF` const MerlinPostConfTmpl = `# GENERATED BY ctrld - DO NOT MODIFY diff --git a/internal/router/merlin/merlin.go b/internal/router/merlin/merlin.go index cacc508..c1c6821 100644 --- a/internal/router/merlin/merlin.go +++ b/internal/router/merlin/merlin.go @@ -6,6 +6,7 @@ import ( "io" "os" "os/exec" + "path/filepath" "strings" "time" "unicode" @@ -20,10 +21,18 @@ import ( const Name = "merlin" +// nvramKvMap is a map of NVRAM key-value pairs used to configure and manage Merlin-specific settings. var nvramKvMap = map[string]string{ "dnspriv_enable": "0", // Ensure Merlin native DoT disabled. } +// dnsmasqConfig represents configuration paths for dnsmasq operations in Merlin firmware. +type dnsmasqConfig struct { + confPath string + jffsConfPath string +} + +// Merlin represents a configuration handler for setting up and managing ctrld on Merlin routers. type Merlin struct { cfg *ctrld.Config } @@ -33,18 +42,22 @@ func New(cfg *ctrld.Config) *Merlin { return &Merlin{cfg: cfg} } +// ConfigureService configures the service based on the provided configuration. It returns an error if the configuration fails. func (m *Merlin) ConfigureService(config *service.Config) error { return nil } +// Install sets up the necessary configurations and services required for the Merlin instance to function properly. func (m *Merlin) Install(_ *service.Config) error { return nil } +// Uninstall removes the ctrld-related configurations and services from the Merlin router and reverts to the original state. func (m *Merlin) Uninstall(_ *service.Config) error { return nil } +// PreRun prepares the Merlin instance for operation by waiting for essential services and directories to become available. func (m *Merlin) PreRun() error { // Wait NTP ready. _ = m.Cleanup() @@ -66,6 +79,7 @@ func (m *Merlin) PreRun() error { return nil } +// Setup initializes and configures the Merlin instance for use, including setting up dnsmasq and necessary nvram settings. func (m *Merlin) Setup() error { if m.cfg.FirstListener().IsDirectDnsListener() { return nil @@ -79,35 +93,10 @@ func (m *Merlin) Setup() error { return err } - // Copy current dnsmasq config to /jffs/configs/dnsmasq.conf, - // Then we will run postconf script on this file. - // - // Normally, adding postconf script is enough. However, we see - // reports on some Merlin devices that postconf scripts does not - // work, but manipulating the config directly via /jffs/configs does. - src, err := os.Open(dnsmasq.MerlinConfPath) - if err != nil { - return fmt.Errorf("failed to open dnsmasq config: %w", err) - } - defer src.Close() - - dst, err := os.Create(dnsmasq.MerlinJffsConfPath) - if err != nil { - return fmt.Errorf("failed to create %s: %w", dnsmasq.MerlinJffsConfPath, err) - } - defer dst.Close() - - if _, err := io.Copy(dst, src); err != nil { - return fmt.Errorf("failed to copy current dnsmasq config: %w", err) - } - if err := dst.Close(); err != nil { - return fmt.Errorf("failed to save %s: %w", dnsmasq.MerlinJffsConfPath, err) - } - - // Run postconf script on /jffs/configs/dnsmasq.conf directly. - cmd := exec.Command("/bin/sh", dnsmasq.MerlinPostConfPath, dnsmasq.MerlinJffsConfPath) - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to run post conf: %s: %w", string(out), err) + for _, cfg := range getDnsmasqConfigs() { + if err := m.setupDnsmasq(cfg); err != nil { + return fmt.Errorf("failed to setup dnsmasq: config: %s, error: %w", cfg.confPath, err) + } } // Restart dnsmasq service. @@ -122,6 +111,7 @@ func (m *Merlin) Setup() error { return nil } +// Cleanup restores the original dnsmasq and nvram configurations and restarts dnsmasq if necessary. func (m *Merlin) Cleanup() error { if m.cfg.FirstListener().IsDirectDnsListener() { return nil @@ -143,9 +133,11 @@ func (m *Merlin) Cleanup() error { if err := os.WriteFile(dnsmasq.MerlinPostConfPath, merlinParsePostConf(buf), 0750); err != nil { return err } - // Remove /jffs/configs/dnsmasq.conf file. - if err := os.Remove(dnsmasq.MerlinJffsConfPath); err != nil && !os.IsNotExist(err) { - return err + + for _, cfg := range getDnsmasqConfigs() { + if err := m.cleanupDnsmasqJffs(cfg); err != nil { + return fmt.Errorf("failed to cleanup jffs dnsmasq: config: %s, error: %w", cfg.confPath, err) + } } // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { @@ -154,6 +146,54 @@ func (m *Merlin) Cleanup() error { return nil } +// setupDnsmasq sets up dnsmasq configuration by writing postconf, copying configuration, and running a postconf script. +func (m *Merlin) setupDnsmasq(cfg *dnsmasqConfig) error { + src, err := os.Open(cfg.confPath) + if os.IsNotExist(err) { + return nil // nothing to do if conf file does not exist. + } + if err != nil { + return fmt.Errorf("failed to open dnsmasq config: %w", err) + } + defer src.Close() + + // Copy current dnsmasq config to cfg.jffsConfPath, + // Then we will run postconf script on this file. + // + // Normally, adding postconf script is enough. However, we see + // reports on some Merlin devices that postconf scripts does not + // work, but manipulating the config directly via /jffs/configs does. + dst, err := os.Create(cfg.jffsConfPath) + if err != nil { + return fmt.Errorf("failed to create %s: %w", cfg.jffsConfPath, err) + } + defer dst.Close() + + if _, err := io.Copy(dst, src); err != nil { + return fmt.Errorf("failed to copy current dnsmasq config: %w", err) + } + if err := dst.Close(); err != nil { + return fmt.Errorf("failed to save %s: %w", cfg.jffsConfPath, err) + } + + // Run postconf script on cfg.jffsConfPath directly. + cmd := exec.Command("/bin/sh", dnsmasq.MerlinPostConfPath, cfg.jffsConfPath) + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to run post conf: %s: %w", string(out), err) + } + return nil +} + +// cleanupDnsmasqJffs removes the JFFS configuration file specified in the given dnsmasqConfig, if it exists. +func (m *Merlin) cleanupDnsmasqJffs(cfg *dnsmasqConfig) error { + // Remove cfg.jffsConfPath file. + if err := os.Remove(cfg.jffsConfPath); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +// writeDnsmasqPostconf writes the requireddnsmasqConfigs post-configuration for dnsmasq to enable custom DNS settings with ctrld. func (m *Merlin) writeDnsmasqPostconf() error { buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath) // Already setup. @@ -179,6 +219,8 @@ func (m *Merlin) writeDnsmasqPostconf() error { return os.WriteFile(dnsmasq.MerlinPostConfPath, []byte(data), 0750) } +// restartDNSMasq restarts the dnsmasq service by executing the appropriate system command using "service". +// Returns an error if the command fails or if there is an issue processing the command output. func restartDNSMasq() error { if out, err := exec.Command("service", "restart_dnsmasq").CombinedOutput(); err != nil { return fmt.Errorf("restart_dnsmasq: %s, %w", string(out), err) @@ -186,6 +228,22 @@ func restartDNSMasq() error { return nil } +// getDnsmasqConfigs retrieves a list of dnsmasqConfig containing configuration and JFFS paths for dnsmasq operations. +func getDnsmasqConfigs() []*dnsmasqConfig { + cfgs := []*dnsmasqConfig{ + {dnsmasq.MerlinConfPath, dnsmasq.MerlinJffsConfPath}, + } + for _, path := range dnsmasq.AdditionalConfigFiles() { + jffsConfPath := filepath.Join(dnsmasq.MerlinJffsConfDir, filepath.Base(path)) + cfgs = append(cfgs, &dnsmasqConfig{path, jffsConfPath}) + } + + return cfgs +} + +// merlinParsePostConf parses the dnsmasq post configuration by removing content after the MerlinPostConfMarker, if present. +// If no marker is found, the original buffer is returned unmodified. +// Returns nil if the input buffer is empty. func merlinParsePostConf(buf []byte) []byte { if len(buf) == 0 { return nil @@ -197,6 +255,7 @@ func merlinParsePostConf(buf []byte) []byte { return buf } +// waitDirExists waits until the specified directory exists, polling its existence every second. func waitDirExists(dir string) { for { if _, err := os.Stat(dir); !os.IsNotExist(err) { From de24fa293ec6b5dc93b85095b676494eae1fd7c4 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 8 Jul 2025 20:40:24 +0700 Subject: [PATCH 2/6] internal/router: support Ubios 4.3+ This change improves compatibility with newer UniFi OS versions while maintaining backward compatibility with UniFi OS 4.2 and earlier. The refactoring also reduces code duplication and improves maintainability by centralizing dnsmasq configuration path logic. --- internal/router/dnsmasq/dnsmasq.go | 25 +++++++++++++++++++++++++ internal/router/edgeos/edgeos.go | 3 ++- internal/router/service_ubios.go | 7 +++---- internal/router/ubios/ubios.go | 21 +++++++++++---------- 4 files changed, 41 insertions(+), 15 deletions(-) diff --git a/internal/router/dnsmasq/dnsmasq.go b/internal/router/dnsmasq/dnsmasq.go index a690ee4..058b0b5 100644 --- a/internal/router/dnsmasq/dnsmasq.go +++ b/internal/router/dnsmasq/dnsmasq.go @@ -4,6 +4,7 @@ import ( "errors" "html/template" "net" + "os" "path/filepath" "strings" @@ -163,3 +164,27 @@ func FirewallaSelfInterfaces() []*net.Interface { } return ifaces } + +const ( + ubios43ConfPath = "/run/dnsmasq.dhcp.conf.d" + ubios42ConfPath = "/run/dnsmasq.conf.d" + ubios43PidFile = "/run/dnsmasq-main.pid" + ubios42PidFile = "/run/dnsmasq.pid" + UbiosConfName = "zzzctrld.conf" +) + +// UbiosConfPath returns the appropriate configuration path based on the system's directory structure. +func UbiosConfPath() string { + if st, _ := os.Stat(ubios43ConfPath); st != nil && st.IsDir() { + return ubios43ConfPath + } + return ubios42ConfPath +} + +// UbiosPidFile returns the appropriate dnsmasq pid file based on the system's directory structure. +func UbiosPidFile() string { + if st, _ := os.Stat(ubios43PidFile); st != nil && !st.IsDir() { + return ubios43PidFile + } + return ubios42PidFile +} diff --git a/internal/router/edgeos/edgeos.go b/internal/router/edgeos/edgeos.go index 2e229ac..7364ac1 100644 --- a/internal/router/edgeos/edgeos.go +++ b/internal/router/edgeos/edgeos.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "strings" "github.com/kardianos/service" @@ -181,7 +182,7 @@ func ContentFilteringEnabled() bool { // DnsShieldEnabled reports whether DNS Shield is enabled. // See: https://community.ui.com/releases/UniFi-OS-Dream-Machines-3-2-7/251dfc1e-f4dd-4264-a080-3be9d8b9e02b func DnsShieldEnabled() bool { - buf, err := os.ReadFile("/var/run/dnsmasq.conf.d/dns.conf") + buf, err := os.ReadFile(filepath.Join(dnsmasq.UbiosConfPath(), "dns.conf")) if err != nil { return false } diff --git a/internal/router/service_ubios.go b/internal/router/service_ubios.go index 8077c07..9ad971d 100644 --- a/internal/router/service_ubios.go +++ b/internal/router/service_ubios.go @@ -13,14 +13,13 @@ import ( "time" "github.com/kardianos/service" + + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" ) // This is a copy of https://github.com/kardianos/service/blob/v1.2.1/service_sysv_linux.go, // with modification for supporting ubios v1 init system. -// Keep in sync with ubios.ubiosDNSMasqConfigPath -const ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf" - type ubiosSvc struct { i service.Interface platform string @@ -86,7 +85,7 @@ func (s *ubiosSvc) Install() error { }{ s.Config, path, - ubiosDNSMasqConfigPath, + filepath.Join(dnsmasq.UbiosConfPath(), dnsmasq.UbiosConfName), } if err := s.template().Execute(f, to); err != nil { diff --git a/internal/router/ubios/ubios.go b/internal/router/ubios/ubios.go index a1f0b6c..cba6842 100644 --- a/internal/router/ubios/ubios.go +++ b/internal/router/ubios/ubios.go @@ -3,6 +3,7 @@ package ubios import ( "bytes" "os" + "path/filepath" "strconv" "github.com/kardianos/service" @@ -12,19 +13,19 @@ import ( "github.com/Control-D-Inc/ctrld/internal/router/edgeos" ) -const ( - Name = "ubios" - ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf" - ubiosDNSMasqDnsConfigPath = "/run/dnsmasq.conf.d/dns.conf" -) +const Name = "ubios" type Ubios struct { - cfg *ctrld.Config + cfg *ctrld.Config + dnsmasqConfPath string } // New returns a router.Router for configuring/setup/run ctrld on Ubios routers. func New(cfg *ctrld.Config) *Ubios { - return &Ubios{cfg: cfg} + return &Ubios{ + cfg: cfg, + dnsmasqConfPath: filepath.Join(dnsmasq.UbiosConfPath(), dnsmasq.UbiosConfName), + } } func (u *Ubios) ConfigureService(config *service.Config) error { @@ -59,7 +60,7 @@ func (u *Ubios) Setup() error { if err != nil { return err } - if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(data), 0600); err != nil { + if err := os.WriteFile(u.dnsmasqConfPath, []byte(data), 0600); err != nil { return err } // Restart dnsmasq service. @@ -74,7 +75,7 @@ func (u *Ubios) Cleanup() error { return nil } // Remove the custom dnsmasq config - if err := os.Remove(ubiosDNSMasqConfigPath); err != nil { + if err := os.Remove(u.dnsmasqConfPath); err != nil { return err } // Restart dnsmasq service. @@ -85,7 +86,7 @@ func (u *Ubios) Cleanup() error { } func restartDNSMasq() error { - buf, err := os.ReadFile("/run/dnsmasq.pid") + buf, err := os.ReadFile(dnsmasq.UbiosPidFile()) if err != nil { return err } From ce29b5d217624192c975ddc09a073827b180deb6 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 14 Jul 2025 15:28:01 +0700 Subject: [PATCH 3/6] refactor: split selfUpgradeCheck into version check and upgrade execution - Move version checking logic to shouldUpgrade for testability - Move upgrade command execution to performUpgrade - selfUpgradeCheck now composes these two for clarity - Update and expand tests: focus on logic, not side effects - Improves maintainability, testability, and separation of concerns --- cmd/cli/prog.go | 51 ++++++++--- cmd/cli/prog_test.go | 209 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 249 insertions(+), 11 deletions(-) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 48a1e07..90f403d 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -329,7 +329,7 @@ func (p *prog) apiConfigReload() { // Performing self-upgrade check for production version. if isStable { - selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger) + _ = selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger) } if resolverConfig.DeactivationPin != nil { @@ -1474,14 +1474,15 @@ func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { } } -// selfUpgradeCheck checks if the version target vt is greater -// than the current one cv, perform self-upgrade then. +// shouldUpgrade checks if the version target vt is greater than the current one cv. +// Major version upgrades are not allowed to prevent breaking changes. // // The callers must ensure curVer and logger are non-nil. -func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) { +// Returns true if upgrade is allowed, false otherwise. +func shouldUpgrade(vt string, cv *semver.Version, logger *zerolog.Logger) bool { if vt == "" { logger.Debug().Msg("no version target set, skipped checking self-upgrade") - return + return false } vts := vt if !strings.HasPrefix(vts, "v") { @@ -1490,28 +1491,58 @@ func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) { targetVer, err := semver.NewVersion(vts) if err != nil { logger.Warn().Err(err).Msgf("invalid target version, skipped self-upgrade: %s", vt) - return + return false } + + // Prevent major version upgrades to avoid breaking changes + if targetVer.Major() != cv.Major() { + logger.Warn(). + Str("target", vt). + Str("current", cv.String()). + Msgf("major version upgrade not allowed (target: %d, current: %d), skipped self-upgrade", targetVer.Major(), cv.Major()) + return false + } + if !targetVer.GreaterThan(cv) { logger.Debug(). Str("target", vt). Str("current", cv.String()). Msgf("target version is not greater than current one, skipped self-upgrade") - return + return false } + return true +} + +// performUpgrade executes the self-upgrade command. +// Returns true if upgrade was initiated successfully, false otherwise. +func performUpgrade(vt string) bool { exe, err := os.Executable() if err != nil { mainLog.Load().Error().Err(err).Msg("failed to get executable path, skipped self-upgrade") - return + return false } cmd := exec.Command(exe, "upgrade", "prod", "-vv") cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() if err := cmd.Start(); err != nil { mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade") - return + return false } - mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vts) + mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vt) + return true +} + +// selfUpgradeCheck checks if the version target vt is greater +// than the current one cv, perform self-upgrade then. +// Major version upgrades are not allowed to prevent breaking changes. +// +// The callers must ensure curVer and logger are non-nil. +// Returns true if upgrade is allowed and should proceed, false otherwise. +func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) bool { + if shouldUpgrade(vt, cv, logger) { + return performUpgrade(vt) + } + return false } // leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow diff --git a/cmd/cli/prog_test.go b/cmd/cli/prog_test.go index 5f2f8e1..1fee462 100644 --- a/cmd/cli/prog_test.go +++ b/cmd/cli/prog_test.go @@ -4,8 +4,11 @@ import ( "testing" "time" - "github.com/Control-D-Inc/ctrld" + "github.com/Masterminds/semver/v3" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + + "github.com/Control-D-Inc/ctrld" ) func Test_prog_dnsWatchdogEnabled(t *testing.T) { @@ -55,3 +58,207 @@ func Test_prog_dnsWatchdogInterval(t *testing.T) { }) } } + +func Test_shouldUpgrade(t *testing.T) { + // Helper function to create a version + makeVersion := func(v string) *semver.Version { + ver, err := semver.NewVersion(v) + if err != nil { + t.Fatalf("failed to create version %s: %v", v, err) + } + return ver + } + + tests := []struct { + name string + versionTarget string + currentVersion *semver.Version + shouldUpgrade bool + description string + }{ + { + name: "empty version target", + versionTarget: "", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should skip upgrade when version target is empty", + }, + { + name: "invalid version target", + versionTarget: "invalid-version", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should skip upgrade when version target is invalid", + }, + { + name: "same version", + versionTarget: "v1.0.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should skip upgrade when target version equals current version", + }, + { + name: "older version", + versionTarget: "v1.0.0", + currentVersion: makeVersion("v1.1.0"), + shouldUpgrade: false, + description: "should skip upgrade when target version is older than current version", + }, + { + name: "patch upgrade allowed", + versionTarget: "v1.0.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow patch version upgrade within same major version", + }, + { + name: "minor upgrade allowed", + versionTarget: "v1.1.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow minor version upgrade within same major version", + }, + { + name: "major upgrade blocked", + versionTarget: "v2.0.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should block major version upgrade", + }, + { + name: "major downgrade blocked", + versionTarget: "v1.0.0", + currentVersion: makeVersion("v2.0.0"), + shouldUpgrade: false, + description: "should block major version downgrade", + }, + { + name: "version without v prefix", + versionTarget: "1.0.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should handle version target without v prefix", + }, + { + name: "complex version upgrade allowed", + versionTarget: "v1.5.3", + currentVersion: makeVersion("v1.4.2"), + shouldUpgrade: true, + description: "should allow complex version upgrade within same major version", + }, + { + name: "complex major upgrade blocked", + versionTarget: "v3.1.0", + currentVersion: makeVersion("v2.5.3"), + shouldUpgrade: false, + description: "should block complex major version upgrade", + }, + { + name: "pre-release version upgrade allowed", + versionTarget: "v1.0.1-beta.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow pre-release version upgrade within same major version", + }, + { + name: "pre-release major upgrade blocked", + versionTarget: "v2.0.0-alpha.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should block pre-release major version upgrade", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // Create test logger + testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger() + + // Call the function and capture the result + result := shouldUpgrade(tc.versionTarget, tc.currentVersion, &testLogger) + + // Assert the expected result + assert.Equal(t, tc.shouldUpgrade, result, tc.description) + }) + } +} + +func Test_selfUpgradeCheck(t *testing.T) { + // Helper function to create a version + makeVersion := func(v string) *semver.Version { + ver, err := semver.NewVersion(v) + if err != nil { + t.Fatalf("failed to create version %s: %v", v, err) + } + return ver + } + + tests := []struct { + name string + versionTarget string + currentVersion *semver.Version + shouldUpgrade bool + description string + }{ + { + name: "upgrade allowed", + versionTarget: "v1.0.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow upgrade and attempt to perform it", + }, + { + name: "upgrade blocked", + versionTarget: "v2.0.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should block upgrade and not attempt to perform it", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // Create test logger + testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger() + + // Call the function and capture the result + result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, &testLogger) + + // Assert the expected result + assert.Equal(t, tc.shouldUpgrade, result, tc.description) + }) + } +} + +func Test_performUpgrade(t *testing.T) { + tests := []struct { + name string + versionTarget string + expectedResult bool + description string + }{ + { + name: "valid version target", + versionTarget: "v1.0.1", + expectedResult: true, + description: "should attempt to perform upgrade with valid version target", + }, + { + name: "empty version target", + versionTarget: "", + expectedResult: true, + description: "should attempt to perform upgrade even with empty version target", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // Call the function and capture the result + result := performUpgrade(tc.versionTarget) + assert.Equal(t, tc.expectedResult, result, tc.description) + }) + } +} From 0948161529505bd52c82f6a913bbc546bb21bd7f Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 15 Jul 2025 20:59:57 +0700 Subject: [PATCH 4/6] Avoiding Windows runners file locking issue --- cmd/cli/prog_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cmd/cli/prog_test.go b/cmd/cli/prog_test.go index 1fee462..c4ef5c3 100644 --- a/cmd/cli/prog_test.go +++ b/cmd/cli/prog_test.go @@ -1,6 +1,7 @@ package cli import ( + "runtime" "testing" "time" @@ -185,6 +186,10 @@ func Test_shouldUpgrade(t *testing.T) { } func Test_selfUpgradeCheck(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipped due to Windows file locking issue on Github Action runners") + } + // Helper function to create a version makeVersion := func(v string) *semver.Version { ver, err := semver.NewVersion(v) @@ -233,6 +238,10 @@ func Test_selfUpgradeCheck(t *testing.T) { } func Test_performUpgrade(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipped due to Windows file locking issue on Github Action runners") + } + tests := []struct { name string versionTarget string From e6160912497036d58c9e52debb4f694d3c39f78c Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 15 Jul 2025 21:47:50 +0700 Subject: [PATCH 5/6] cmd/cli: ignore empty positional argument for start command The validation was added during v1.4.0 release, but causing one-liner install failed unexpectedly. --- cmd/cli/commands.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 048212a..2b43320 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -13,6 +13,7 @@ import ( "os/exec" "path/filepath" "runtime" + "slices" "sort" "strconv" "strings" @@ -206,6 +207,9 @@ func initStartCmd() *cobra.Command { NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, Args: func(cmd *cobra.Command, args []string) error { + args = slices.DeleteFunc(args, func(arg string) bool { + return arg == "" + }) if len(args) > 0 { return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") @@ -566,6 +570,9 @@ NOTE: running "ctrld start" without any arguments will start already installed c NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, Args: func(cmd *cobra.Command, args []string) error { + args = slices.DeleteFunc(args, func(arg string) bool { + return arg == "" + }) if len(args) > 0 { return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") From 36a7423634bccd0a894e6a36b0e8ff01891c6773 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 15 Jul 2025 22:49:52 +0700 Subject: [PATCH 6/6] refactor: extract empty string filtering to reusable function - Add filterEmptyStrings utility function for consistent string filtering - Replace inline slices.DeleteFunc calls with filterEmptyStrings - Apply filtering to osArgs in addition to command args - Improves code readability and reduces duplication - Uses slices.DeleteFunc internally for efficient filtering --- cmd/cli/commands.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 2b43320..3733f71 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -207,9 +207,7 @@ func initStartCmd() *cobra.Command { NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, Args: func(cmd *cobra.Command, args []string) error { - args = slices.DeleteFunc(args, func(arg string) bool { - return arg == "" - }) + args = filterEmptyStrings(args) if len(args) > 0 { return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") @@ -223,6 +221,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c sc := &service.Config{} *sc = *svcConfig osArgs := os.Args[2:] + osArgs = filterEmptyStrings(osArgs) if os.Args[1] == "service" { osArgs = os.Args[3:] } @@ -570,9 +569,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, Args: func(cmd *cobra.Command, args []string) error { - args = slices.DeleteFunc(args, func(arg string) bool { - return arg == "" - }) + args = filterEmptyStrings(args) if len(args) > 0 { return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") @@ -1388,3 +1385,11 @@ func initServicesCmd(commands ...*cobra.Command) *cobra.Command { return serviceCmd } + +// filterEmptyStrings removes empty strings from a slice of strings. +// It returns a new slice containing only non-empty strings. +func filterEmptyStrings(slice []string) []string { + return slices.DeleteFunc(slice, func(s string) bool { + return s == "" + }) +}