diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 048212a..3733f71 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -13,6 +13,7 @@ import ( "os/exec" "path/filepath" "runtime" + "slices" "sort" "strconv" "strings" @@ -206,6 +207,7 @@ func initStartCmd() *cobra.Command { NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, Args: func(cmd *cobra.Command, args []string) error { + args = filterEmptyStrings(args) if len(args) > 0 { return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") @@ -219,6 +221,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c sc := &service.Config{} *sc = *svcConfig osArgs := os.Args[2:] + osArgs = filterEmptyStrings(osArgs) if os.Args[1] == "service" { osArgs = os.Args[3:] } @@ -566,6 +569,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, Args: func(cmd *cobra.Command, args []string) error { + args = filterEmptyStrings(args) if len(args) > 0 { return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") @@ -1381,3 +1385,11 @@ func initServicesCmd(commands ...*cobra.Command) *cobra.Command { return serviceCmd } + +// filterEmptyStrings removes empty strings from a slice of strings. +// It returns a new slice containing only non-empty strings. +func filterEmptyStrings(slice []string) []string { + return slices.DeleteFunc(slice, func(s string) bool { + return s == "" + }) +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index dd8de9f..90f403d 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -35,6 +35,7 @@ import ( "github.com/Control-D-Inc/ctrld/internal/controld" "github.com/Control-D-Inc/ctrld/internal/dnscache" "github.com/Control-D-Inc/ctrld/internal/router" + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" ) const ( @@ -328,7 +329,7 @@ func (p *prog) apiConfigReload() { // Performing self-upgrade check for production version. if isStable { - selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger) + _ = selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger) } if resolverConfig.DeactivationPin != nil { @@ -607,6 +608,12 @@ func (p *prog) setupClientInfoDiscover(selfIP string) { format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) p.ciTable.AddLeaseFile(leaseFile, format) } + if leaseFiles := dnsmasq.AdditionalLeaseFiles(); len(leaseFiles) > 0 { + mainLog.Load().Debug().Msgf("watching additional lease files: %v", leaseFiles) + for _, leaseFile := range leaseFiles { + p.ciTable.AddLeaseFile(leaseFile, ctrld.Dnsmasq) + } + } } // runClientInfoDiscover runs the client info discover. @@ -1467,14 +1474,15 @@ func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { } } -// selfUpgradeCheck checks if the version target vt is greater -// than the current one cv, perform self-upgrade then. +// shouldUpgrade checks if the version target vt is greater than the current one cv. +// Major version upgrades are not allowed to prevent breaking changes. // // The callers must ensure curVer and logger are non-nil. -func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) { +// Returns true if upgrade is allowed, false otherwise. +func shouldUpgrade(vt string, cv *semver.Version, logger *zerolog.Logger) bool { if vt == "" { logger.Debug().Msg("no version target set, skipped checking self-upgrade") - return + return false } vts := vt if !strings.HasPrefix(vts, "v") { @@ -1483,28 +1491,58 @@ func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) { targetVer, err := semver.NewVersion(vts) if err != nil { logger.Warn().Err(err).Msgf("invalid target version, skipped self-upgrade: %s", vt) - return + return false } + + // Prevent major version upgrades to avoid breaking changes + if targetVer.Major() != cv.Major() { + logger.Warn(). + Str("target", vt). + Str("current", cv.String()). + Msgf("major version upgrade not allowed (target: %d, current: %d), skipped self-upgrade", targetVer.Major(), cv.Major()) + return false + } + if !targetVer.GreaterThan(cv) { logger.Debug(). Str("target", vt). Str("current", cv.String()). Msgf("target version is not greater than current one, skipped self-upgrade") - return + return false } + return true +} + +// performUpgrade executes the self-upgrade command. +// Returns true if upgrade was initiated successfully, false otherwise. +func performUpgrade(vt string) bool { exe, err := os.Executable() if err != nil { mainLog.Load().Error().Err(err).Msg("failed to get executable path, skipped self-upgrade") - return + return false } cmd := exec.Command(exe, "upgrade", "prod", "-vv") cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() if err := cmd.Start(); err != nil { mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade") - return + return false } - mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vts) + mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vt) + return true +} + +// selfUpgradeCheck checks if the version target vt is greater +// than the current one cv, perform self-upgrade then. +// Major version upgrades are not allowed to prevent breaking changes. +// +// The callers must ensure curVer and logger are non-nil. +// Returns true if upgrade is allowed and should proceed, false otherwise. +func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) bool { + if shouldUpgrade(vt, cv, logger) { + return performUpgrade(vt) + } + return false } // leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow diff --git a/cmd/cli/prog_test.go b/cmd/cli/prog_test.go index 5f2f8e1..c4ef5c3 100644 --- a/cmd/cli/prog_test.go +++ b/cmd/cli/prog_test.go @@ -1,11 +1,15 @@ package cli import ( + "runtime" "testing" "time" - "github.com/Control-D-Inc/ctrld" + "github.com/Masterminds/semver/v3" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + + "github.com/Control-D-Inc/ctrld" ) func Test_prog_dnsWatchdogEnabled(t *testing.T) { @@ -55,3 +59,215 @@ func Test_prog_dnsWatchdogInterval(t *testing.T) { }) } } + +func Test_shouldUpgrade(t *testing.T) { + // Helper function to create a version + makeVersion := func(v string) *semver.Version { + ver, err := semver.NewVersion(v) + if err != nil { + t.Fatalf("failed to create version %s: %v", v, err) + } + return ver + } + + tests := []struct { + name string + versionTarget string + currentVersion *semver.Version + shouldUpgrade bool + description string + }{ + { + name: "empty version target", + versionTarget: "", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should skip upgrade when version target is empty", + }, + { + name: "invalid version target", + versionTarget: "invalid-version", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should skip upgrade when version target is invalid", + }, + { + name: "same version", + versionTarget: "v1.0.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should skip upgrade when target version equals current version", + }, + { + name: "older version", + versionTarget: "v1.0.0", + currentVersion: makeVersion("v1.1.0"), + shouldUpgrade: false, + description: "should skip upgrade when target version is older than current version", + }, + { + name: "patch upgrade allowed", + versionTarget: "v1.0.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow patch version upgrade within same major version", + }, + { + name: "minor upgrade allowed", + versionTarget: "v1.1.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow minor version upgrade within same major version", + }, + { + name: "major upgrade blocked", + versionTarget: "v2.0.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should block major version upgrade", + }, + { + name: "major downgrade blocked", + versionTarget: "v1.0.0", + currentVersion: makeVersion("v2.0.0"), + shouldUpgrade: false, + description: "should block major version downgrade", + }, + { + name: "version without v prefix", + versionTarget: "1.0.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should handle version target without v prefix", + }, + { + name: "complex version upgrade allowed", + versionTarget: "v1.5.3", + currentVersion: makeVersion("v1.4.2"), + shouldUpgrade: true, + description: "should allow complex version upgrade within same major version", + }, + { + name: "complex major upgrade blocked", + versionTarget: "v3.1.0", + currentVersion: makeVersion("v2.5.3"), + shouldUpgrade: false, + description: "should block complex major version upgrade", + }, + { + name: "pre-release version upgrade allowed", + versionTarget: "v1.0.1-beta.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow pre-release version upgrade within same major version", + }, + { + name: "pre-release major upgrade blocked", + versionTarget: "v2.0.0-alpha.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should block pre-release major version upgrade", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // Create test logger + testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger() + + // Call the function and capture the result + result := shouldUpgrade(tc.versionTarget, tc.currentVersion, &testLogger) + + // Assert the expected result + assert.Equal(t, tc.shouldUpgrade, result, tc.description) + }) + } +} + +func Test_selfUpgradeCheck(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipped due to Windows file locking issue on Github Action runners") + } + + // Helper function to create a version + makeVersion := func(v string) *semver.Version { + ver, err := semver.NewVersion(v) + if err != nil { + t.Fatalf("failed to create version %s: %v", v, err) + } + return ver + } + + tests := []struct { + name string + versionTarget string + currentVersion *semver.Version + shouldUpgrade bool + description string + }{ + { + name: "upgrade allowed", + versionTarget: "v1.0.1", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: true, + description: "should allow upgrade and attempt to perform it", + }, + { + name: "upgrade blocked", + versionTarget: "v2.0.0", + currentVersion: makeVersion("v1.0.0"), + shouldUpgrade: false, + description: "should block upgrade and not attempt to perform it", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // Create test logger + testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger() + + // Call the function and capture the result + result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, &testLogger) + + // Assert the expected result + assert.Equal(t, tc.shouldUpgrade, result, tc.description) + }) + } +} + +func Test_performUpgrade(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipped due to Windows file locking issue on Github Action runners") + } + + tests := []struct { + name string + versionTarget string + expectedResult bool + description string + }{ + { + name: "valid version target", + versionTarget: "v1.0.1", + expectedResult: true, + description: "should attempt to perform upgrade with valid version target", + }, + { + name: "empty version target", + versionTarget: "", + expectedResult: true, + description: "should attempt to perform upgrade even with empty version target", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // Call the function and capture the result + result := performUpgrade(tc.versionTarget) + assert.Equal(t, tc.expectedResult, result, tc.description) + }) + } +} diff --git a/internal/router/dnsmasq/conf.go b/internal/router/dnsmasq/conf.go index b168042..bb81d60 100644 --- a/internal/router/dnsmasq/conf.go +++ b/internal/router/dnsmasq/conf.go @@ -6,6 +6,7 @@ import ( "errors" "io" "os" + "path/filepath" "strings" ) @@ -28,3 +29,62 @@ func interfaceNameFromReader(r io.Reader) (string, error) { } return "", errors.New("not found") } + +// AdditionalConfigFiles returns a list of Dnsmasq configuration files found in the "/tmp/etc" directory. +func AdditionalConfigFiles() []string { + if paths, err := filepath.Glob("/tmp/etc/dnsmasq-*.conf"); err == nil { + return paths + } + return nil +} + +// AdditionalLeaseFiles returns a list of lease file paths corresponding to the Dnsmasq configuration files. +func AdditionalLeaseFiles() []string { + cfgFiles := AdditionalConfigFiles() + if len(cfgFiles) == 0 { + return nil + } + leaseFiles := make([]string, 0, len(cfgFiles)) + for _, cfgFile := range cfgFiles { + if leaseFile := leaseFileFromConfigFileName(cfgFile); leaseFile != "" { + leaseFiles = append(leaseFiles, leaseFile) + + } else { + leaseFiles = append(leaseFiles, defaultLeaseFileFromConfigPath(cfgFile)) + } + } + return leaseFiles +} + +// leaseFileFromConfigFileName retrieves the DHCP lease file path by reading and parsing the provided configuration file. +func leaseFileFromConfigFileName(cfgFile string) string { + if f, err := os.Open(cfgFile); err == nil { + return leaseFileFromReader(f) + } + return "" +} + +// leaseFileFromReader parses the given io.Reader for the "dhcp-leasefile" configuration and returns its value as a string. +func leaseFileFromReader(r io.Reader) string { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "#") { + continue + } + before, after, found := strings.Cut(line, "=") + if !found { + continue + } + if before == "dhcp-leasefile" { + return after + } + } + return "" +} + +// defaultLeaseFileFromConfigPath generates the default lease file path based on the provided configuration file path. +func defaultLeaseFileFromConfigPath(path string) string { + name := filepath.Base(path) + return filepath.Join("/var/lib/misc", strings.TrimSuffix(name, ".conf")+".leases") +} diff --git a/internal/router/dnsmasq/conf_test.go b/internal/router/dnsmasq/conf_test.go index 99a0710..9ca672b 100644 --- a/internal/router/dnsmasq/conf_test.go +++ b/internal/router/dnsmasq/conf_test.go @@ -1,6 +1,7 @@ package dnsmasq import ( + "io" "strings" "testing" ) @@ -44,3 +45,49 @@ interface=eth0 }) } } + +func Test_leaseFileFromReader(t *testing.T) { + tests := []struct { + name string + in io.Reader + expected string + }{ + { + "default", + strings.NewReader(` +dhcp-script=/sbin/dhcpc_lease +dhcp-leasefile=/var/lib/misc/dnsmasq-1.leases +script-arp +`), + "/var/lib/misc/dnsmasq-1.leases", + }, + { + "non-default", + strings.NewReader(` +dhcp-script=/sbin/dhcpc_lease +dhcp-leasefile=/tmp/var/lib/misc/dnsmasq-1.leases +script-arp +`), + "/tmp/var/lib/misc/dnsmasq-1.leases", + }, + { + "missing", + strings.NewReader(` +dhcp-script=/sbin/dhcpc_lease +script-arp +`), + "", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := leaseFileFromReader(tc.in); got != tc.expected { + t.Errorf("leaseFileFromReader() = %v, want %v", got, tc.expected) + } + }) + } + +} diff --git a/internal/router/dnsmasq/dnsmasq.go b/internal/router/dnsmasq/dnsmasq.go index 819bd59..058b0b5 100644 --- a/internal/router/dnsmasq/dnsmasq.go +++ b/internal/router/dnsmasq/dnsmasq.go @@ -4,6 +4,7 @@ import ( "errors" "html/template" "net" + "os" "path/filepath" "strings" @@ -26,9 +27,13 @@ max-cache-ttl=0 {{- end}} ` -const MerlinConfPath = "/tmp/etc/dnsmasq.conf" -const MerlinJffsConfPath = "/jffs/configs/dnsmasq.conf" -const MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf" +const ( + MerlinConfPath = "/tmp/etc/dnsmasq.conf" + MerlinJffsConfDir = "/jffs/configs" + MerlinJffsConfPath = "/jffs/configs/dnsmasq.conf" + MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf" +) + const MerlinPostConfMarker = `# GENERATED BY ctrld - EOF` const MerlinPostConfTmpl = `# GENERATED BY ctrld - DO NOT MODIFY @@ -159,3 +164,27 @@ func FirewallaSelfInterfaces() []*net.Interface { } return ifaces } + +const ( + ubios43ConfPath = "/run/dnsmasq.dhcp.conf.d" + ubios42ConfPath = "/run/dnsmasq.conf.d" + ubios43PidFile = "/run/dnsmasq-main.pid" + ubios42PidFile = "/run/dnsmasq.pid" + UbiosConfName = "zzzctrld.conf" +) + +// UbiosConfPath returns the appropriate configuration path based on the system's directory structure. +func UbiosConfPath() string { + if st, _ := os.Stat(ubios43ConfPath); st != nil && st.IsDir() { + return ubios43ConfPath + } + return ubios42ConfPath +} + +// UbiosPidFile returns the appropriate dnsmasq pid file based on the system's directory structure. +func UbiosPidFile() string { + if st, _ := os.Stat(ubios43PidFile); st != nil && !st.IsDir() { + return ubios43PidFile + } + return ubios42PidFile +} diff --git a/internal/router/edgeos/edgeos.go b/internal/router/edgeos/edgeos.go index 2e229ac..7364ac1 100644 --- a/internal/router/edgeos/edgeos.go +++ b/internal/router/edgeos/edgeos.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "strings" "github.com/kardianos/service" @@ -181,7 +182,7 @@ func ContentFilteringEnabled() bool { // DnsShieldEnabled reports whether DNS Shield is enabled. // See: https://community.ui.com/releases/UniFi-OS-Dream-Machines-3-2-7/251dfc1e-f4dd-4264-a080-3be9d8b9e02b func DnsShieldEnabled() bool { - buf, err := os.ReadFile("/var/run/dnsmasq.conf.d/dns.conf") + buf, err := os.ReadFile(filepath.Join(dnsmasq.UbiosConfPath(), "dns.conf")) if err != nil { return false } diff --git a/internal/router/merlin/merlin.go b/internal/router/merlin/merlin.go index cacc508..c1c6821 100644 --- a/internal/router/merlin/merlin.go +++ b/internal/router/merlin/merlin.go @@ -6,6 +6,7 @@ import ( "io" "os" "os/exec" + "path/filepath" "strings" "time" "unicode" @@ -20,10 +21,18 @@ import ( const Name = "merlin" +// nvramKvMap is a map of NVRAM key-value pairs used to configure and manage Merlin-specific settings. var nvramKvMap = map[string]string{ "dnspriv_enable": "0", // Ensure Merlin native DoT disabled. } +// dnsmasqConfig represents configuration paths for dnsmasq operations in Merlin firmware. +type dnsmasqConfig struct { + confPath string + jffsConfPath string +} + +// Merlin represents a configuration handler for setting up and managing ctrld on Merlin routers. type Merlin struct { cfg *ctrld.Config } @@ -33,18 +42,22 @@ func New(cfg *ctrld.Config) *Merlin { return &Merlin{cfg: cfg} } +// ConfigureService configures the service based on the provided configuration. It returns an error if the configuration fails. func (m *Merlin) ConfigureService(config *service.Config) error { return nil } +// Install sets up the necessary configurations and services required for the Merlin instance to function properly. func (m *Merlin) Install(_ *service.Config) error { return nil } +// Uninstall removes the ctrld-related configurations and services from the Merlin router and reverts to the original state. func (m *Merlin) Uninstall(_ *service.Config) error { return nil } +// PreRun prepares the Merlin instance for operation by waiting for essential services and directories to become available. func (m *Merlin) PreRun() error { // Wait NTP ready. _ = m.Cleanup() @@ -66,6 +79,7 @@ func (m *Merlin) PreRun() error { return nil } +// Setup initializes and configures the Merlin instance for use, including setting up dnsmasq and necessary nvram settings. func (m *Merlin) Setup() error { if m.cfg.FirstListener().IsDirectDnsListener() { return nil @@ -79,35 +93,10 @@ func (m *Merlin) Setup() error { return err } - // Copy current dnsmasq config to /jffs/configs/dnsmasq.conf, - // Then we will run postconf script on this file. - // - // Normally, adding postconf script is enough. However, we see - // reports on some Merlin devices that postconf scripts does not - // work, but manipulating the config directly via /jffs/configs does. - src, err := os.Open(dnsmasq.MerlinConfPath) - if err != nil { - return fmt.Errorf("failed to open dnsmasq config: %w", err) - } - defer src.Close() - - dst, err := os.Create(dnsmasq.MerlinJffsConfPath) - if err != nil { - return fmt.Errorf("failed to create %s: %w", dnsmasq.MerlinJffsConfPath, err) - } - defer dst.Close() - - if _, err := io.Copy(dst, src); err != nil { - return fmt.Errorf("failed to copy current dnsmasq config: %w", err) - } - if err := dst.Close(); err != nil { - return fmt.Errorf("failed to save %s: %w", dnsmasq.MerlinJffsConfPath, err) - } - - // Run postconf script on /jffs/configs/dnsmasq.conf directly. - cmd := exec.Command("/bin/sh", dnsmasq.MerlinPostConfPath, dnsmasq.MerlinJffsConfPath) - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to run post conf: %s: %w", string(out), err) + for _, cfg := range getDnsmasqConfigs() { + if err := m.setupDnsmasq(cfg); err != nil { + return fmt.Errorf("failed to setup dnsmasq: config: %s, error: %w", cfg.confPath, err) + } } // Restart dnsmasq service. @@ -122,6 +111,7 @@ func (m *Merlin) Setup() error { return nil } +// Cleanup restores the original dnsmasq and nvram configurations and restarts dnsmasq if necessary. func (m *Merlin) Cleanup() error { if m.cfg.FirstListener().IsDirectDnsListener() { return nil @@ -143,9 +133,11 @@ func (m *Merlin) Cleanup() error { if err := os.WriteFile(dnsmasq.MerlinPostConfPath, merlinParsePostConf(buf), 0750); err != nil { return err } - // Remove /jffs/configs/dnsmasq.conf file. - if err := os.Remove(dnsmasq.MerlinJffsConfPath); err != nil && !os.IsNotExist(err) { - return err + + for _, cfg := range getDnsmasqConfigs() { + if err := m.cleanupDnsmasqJffs(cfg); err != nil { + return fmt.Errorf("failed to cleanup jffs dnsmasq: config: %s, error: %w", cfg.confPath, err) + } } // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { @@ -154,6 +146,54 @@ func (m *Merlin) Cleanup() error { return nil } +// setupDnsmasq sets up dnsmasq configuration by writing postconf, copying configuration, and running a postconf script. +func (m *Merlin) setupDnsmasq(cfg *dnsmasqConfig) error { + src, err := os.Open(cfg.confPath) + if os.IsNotExist(err) { + return nil // nothing to do if conf file does not exist. + } + if err != nil { + return fmt.Errorf("failed to open dnsmasq config: %w", err) + } + defer src.Close() + + // Copy current dnsmasq config to cfg.jffsConfPath, + // Then we will run postconf script on this file. + // + // Normally, adding postconf script is enough. However, we see + // reports on some Merlin devices that postconf scripts does not + // work, but manipulating the config directly via /jffs/configs does. + dst, err := os.Create(cfg.jffsConfPath) + if err != nil { + return fmt.Errorf("failed to create %s: %w", cfg.jffsConfPath, err) + } + defer dst.Close() + + if _, err := io.Copy(dst, src); err != nil { + return fmt.Errorf("failed to copy current dnsmasq config: %w", err) + } + if err := dst.Close(); err != nil { + return fmt.Errorf("failed to save %s: %w", cfg.jffsConfPath, err) + } + + // Run postconf script on cfg.jffsConfPath directly. + cmd := exec.Command("/bin/sh", dnsmasq.MerlinPostConfPath, cfg.jffsConfPath) + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to run post conf: %s: %w", string(out), err) + } + return nil +} + +// cleanupDnsmasqJffs removes the JFFS configuration file specified in the given dnsmasqConfig, if it exists. +func (m *Merlin) cleanupDnsmasqJffs(cfg *dnsmasqConfig) error { + // Remove cfg.jffsConfPath file. + if err := os.Remove(cfg.jffsConfPath); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +// writeDnsmasqPostconf writes the requireddnsmasqConfigs post-configuration for dnsmasq to enable custom DNS settings with ctrld. func (m *Merlin) writeDnsmasqPostconf() error { buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath) // Already setup. @@ -179,6 +219,8 @@ func (m *Merlin) writeDnsmasqPostconf() error { return os.WriteFile(dnsmasq.MerlinPostConfPath, []byte(data), 0750) } +// restartDNSMasq restarts the dnsmasq service by executing the appropriate system command using "service". +// Returns an error if the command fails or if there is an issue processing the command output. func restartDNSMasq() error { if out, err := exec.Command("service", "restart_dnsmasq").CombinedOutput(); err != nil { return fmt.Errorf("restart_dnsmasq: %s, %w", string(out), err) @@ -186,6 +228,22 @@ func restartDNSMasq() error { return nil } +// getDnsmasqConfigs retrieves a list of dnsmasqConfig containing configuration and JFFS paths for dnsmasq operations. +func getDnsmasqConfigs() []*dnsmasqConfig { + cfgs := []*dnsmasqConfig{ + {dnsmasq.MerlinConfPath, dnsmasq.MerlinJffsConfPath}, + } + for _, path := range dnsmasq.AdditionalConfigFiles() { + jffsConfPath := filepath.Join(dnsmasq.MerlinJffsConfDir, filepath.Base(path)) + cfgs = append(cfgs, &dnsmasqConfig{path, jffsConfPath}) + } + + return cfgs +} + +// merlinParsePostConf parses the dnsmasq post configuration by removing content after the MerlinPostConfMarker, if present. +// If no marker is found, the original buffer is returned unmodified. +// Returns nil if the input buffer is empty. func merlinParsePostConf(buf []byte) []byte { if len(buf) == 0 { return nil @@ -197,6 +255,7 @@ func merlinParsePostConf(buf []byte) []byte { return buf } +// waitDirExists waits until the specified directory exists, polling its existence every second. func waitDirExists(dir string) { for { if _, err := os.Stat(dir); !os.IsNotExist(err) { diff --git a/internal/router/service_ubios.go b/internal/router/service_ubios.go index 8077c07..9ad971d 100644 --- a/internal/router/service_ubios.go +++ b/internal/router/service_ubios.go @@ -13,14 +13,13 @@ import ( "time" "github.com/kardianos/service" + + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" ) // This is a copy of https://github.com/kardianos/service/blob/v1.2.1/service_sysv_linux.go, // with modification for supporting ubios v1 init system. -// Keep in sync with ubios.ubiosDNSMasqConfigPath -const ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf" - type ubiosSvc struct { i service.Interface platform string @@ -86,7 +85,7 @@ func (s *ubiosSvc) Install() error { }{ s.Config, path, - ubiosDNSMasqConfigPath, + filepath.Join(dnsmasq.UbiosConfPath(), dnsmasq.UbiosConfName), } if err := s.template().Execute(f, to); err != nil { diff --git a/internal/router/ubios/ubios.go b/internal/router/ubios/ubios.go index a1f0b6c..cba6842 100644 --- a/internal/router/ubios/ubios.go +++ b/internal/router/ubios/ubios.go @@ -3,6 +3,7 @@ package ubios import ( "bytes" "os" + "path/filepath" "strconv" "github.com/kardianos/service" @@ -12,19 +13,19 @@ import ( "github.com/Control-D-Inc/ctrld/internal/router/edgeos" ) -const ( - Name = "ubios" - ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf" - ubiosDNSMasqDnsConfigPath = "/run/dnsmasq.conf.d/dns.conf" -) +const Name = "ubios" type Ubios struct { - cfg *ctrld.Config + cfg *ctrld.Config + dnsmasqConfPath string } // New returns a router.Router for configuring/setup/run ctrld on Ubios routers. func New(cfg *ctrld.Config) *Ubios { - return &Ubios{cfg: cfg} + return &Ubios{ + cfg: cfg, + dnsmasqConfPath: filepath.Join(dnsmasq.UbiosConfPath(), dnsmasq.UbiosConfName), + } } func (u *Ubios) ConfigureService(config *service.Config) error { @@ -59,7 +60,7 @@ func (u *Ubios) Setup() error { if err != nil { return err } - if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(data), 0600); err != nil { + if err := os.WriteFile(u.dnsmasqConfPath, []byte(data), 0600); err != nil { return err } // Restart dnsmasq service. @@ -74,7 +75,7 @@ func (u *Ubios) Cleanup() error { return nil } // Remove the custom dnsmasq config - if err := os.Remove(ubiosDNSMasqConfigPath); err != nil { + if err := os.Remove(u.dnsmasqConfPath); err != nil { return err } // Restart dnsmasq service. @@ -85,7 +86,7 @@ func (u *Ubios) Cleanup() error { } func restartDNSMasq() error { - buf, err := os.ReadFile("/run/dnsmasq.pid") + buf, err := os.ReadFile(dnsmasq.UbiosPidFile()) if err != nil { return err }