diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 8809f25..1c7cbd5 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -152,7 +152,7 @@ func initCLI() { dir, err := userHomeDir() if err != nil { - log.Fatalf("failed to get config dir: %v", dir) + log.Fatalf("failed to get config dir: %v", err) } for _, config := range configs { ctrld.SetConfigNameWithPath(v, config.name, dir) diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index 1b7f25a..7388983 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -167,8 +167,10 @@ func (p *prog) deAllocateIP() error { } func (p *prog) setDNS() { - // On router, ctrld run as a DNS provider, it does not have to change system DNS. - if router.Name() != "" { + switch router.Name() { + case router.DDWrt, router.OpenWrt, router.Ubios: + // On router, ctrld run as a DNS forwarder, it does not have to change system DNS. + // Except for Merlin, which has WAN DNS setup on boot for NTP. return } if cfg.Listener == nil || cfg.Listener["0"] == nil { @@ -199,8 +201,9 @@ func (p *prog) setDNS() { } func (p *prog) resetDNS() { - // See comment in p.setDNS method. - if router.Name() != "" { + switch router.Name() { + case router.DDWrt, router.OpenWrt, router.Ubios: + // See comment in p.setDNS method. return } if iface == "" { diff --git a/internal/router/dnsmasq.go b/internal/router/dnsmasq.go index 35d9022..9c6fd2b 100644 --- a/internal/router/dnsmasq.go +++ b/internal/router/dnsmasq.go @@ -4,3 +4,31 @@ const dnsMasqConfigContent = `# GENERATED BY ctrld - DO NOT MODIFY no-resolv server=127.0.0.1#5353 ` + +const merlinDNSMasqPostConfPath = "/jffs/scripts/dnsmasq.postconf" +const merlinDNSMasqPostConfMarker = `# GENERATED BY ctrld - EOF` + +const merlinDNSMasqPostConf = `# GENERATED BY ctrld - DO NOT MODIFY + +#!/bin/sh + +config_file="$1" +. /usr/sbin/helper.sh + +pid=$(cat /tmp/ctrld.pid 2>/dev/null) +if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then + pc_delete "servers-file" "$config_file" # no WAN DNS settings + pc_append "no-resolv" "$config_file" # do not read /etc/resolv.conf + pc_append "server=127.0.0.1#5354" "$config_file" # use ctrld as upstream + + + # For John fork + pc_delete "resolv-file" "$config_file" # no WAN DNS settings + + # Change /etc/resolv.conf, which may be changed by WAN DNS setup + pc_delete "nameserver" /etc/resolv.conf + pc_append "nameserver 127.0.0.1" /etc/resolv.conf + + exit 0 +fi +` diff --git a/internal/router/merlin.go b/internal/router/merlin.go new file mode 100644 index 0000000..0048d17 --- /dev/null +++ b/internal/router/merlin.go @@ -0,0 +1,76 @@ +package router + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "strings" + "unicode" +) + +func setupMerlin() error { + buf, err := os.ReadFile(merlinDNSMasqPostConfPath) + // Already setup. + if bytes.Contains(buf, []byte(merlinDNSMasqPostConfMarker)) { + return nil + } + if err != nil && !os.IsNotExist(err) { + return err + } + + data := strings.Join([]string{ + merlinDNSMasqPostConf, + "\n", + merlinDNSMasqPostConfMarker, + "\n", + string(buf), + }, "\n") + // Write dnsmasq post conf file. + if err := os.WriteFile(merlinDNSMasqPostConfPath, []byte(data), 0750); err != nil { + return err + } + // Restart dnsmasq service. + if err := merlinRestartDNSMasq(); err != nil { + return err + } + return nil +} + +func cleanupMerlin() error { + buf, err := os.ReadFile(merlinDNSMasqPostConf) + if err != nil && !os.IsNotExist(err) { + return err + } + // Restore dnsmasq post conf file. + if err := os.WriteFile(merlinDNSMasqPostConfPath, merlinParsePostConf(buf), 0750); err != nil { + return err + } + // Restart dnsmasq service. + if err := merlinRestartDNSMasq(); err != nil { + return err + } + return nil +} + +func postInstallMerlin() error { + return nil +} + +func merlinRestartDNSMasq() error { + if out, err := exec.Command("service", "restart_dnsmasq").CombinedOutput(); err != nil { + return fmt.Errorf("restart_dnsmasq: %s, %w", string(out), err) + } + return nil +} + +func merlinParsePostConf(buf []byte) []byte { + if len(buf) == 0 { + return nil + } + parts := bytes.Split(buf, []byte(merlinDNSMasqPostConfMarker)) + if len(parts) != 1 { + return bytes.TrimLeftFunc(parts[1], unicode.IsSpace) + } + return buf +} diff --git a/internal/router/merlin_test.go b/internal/router/merlin_test.go new file mode 100644 index 0000000..2a3c241 --- /dev/null +++ b/internal/router/merlin_test.go @@ -0,0 +1,38 @@ +package router + +import ( + "bytes" + "strings" + "testing" +) + +func Test_merlinParsePostConf(t *testing.T) { + origContent := "# foo" + data := strings.Join([]string{ + merlinDNSMasqPostConf, + "\n", + merlinDNSMasqPostConfMarker, + "\n", + }, "\n") + + tests := []struct { + name string + data string + expected string + }{ + {"empty", "", ""}, + {"no ctrld", origContent, origContent}, + {"ctrld with data", data + origContent, origContent}, + {"ctrld without data", data, ""}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + //t.Parallel() + if got := merlinParsePostConf([]byte(tc.data)); !bytes.Equal(got, []byte(tc.expected)) { + t.Errorf("unexpected result, want: %q, got: %q", tc.expected, string(got)) + } + }) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 9987314..bca465c 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -40,9 +40,11 @@ func Configure(c *ctrld.Config) error { switch name { case DDWrt: return setupDDWrt() + case Merlin: + return setupMerlin() case OpenWrt: return setupOpenWrt() - case Merlin, Ubios: + case Ubios: default: return ErrNotSupported } @@ -72,9 +74,12 @@ func PostInstall() error { switch name { case DDWrt: return postInstallDDWrt() + case Merlin: + return postInstallMerlin() case OpenWrt: return postInstallOpenWrt() - case Merlin, Ubios: + + case Ubios: } return nil } @@ -83,11 +88,13 @@ func PostInstall() error { func Cleanup() error { name := Name() switch name { - case OpenWrt: - return cleanupOpenWrt() case DDWrt: return cleanupDDWrt() - case Merlin, Ubios: + case Merlin: + return cleanupMerlin() + case OpenWrt: + return cleanupOpenWrt() + case Ubios: } return nil } @@ -98,7 +105,9 @@ func ListenAddress() string { switch name { case DDWrt, OpenWrt: return "127.0.0.1:5353" - case Merlin, Ubios: + case Merlin: + return "127.0.0.1:5354" + case Ubios: } return "" } diff --git a/internal/router/service.go b/internal/router/service.go index 8845564..7c8bb40 100644 --- a/internal/router/service.go +++ b/internal/router/service.go @@ -7,16 +7,27 @@ import ( ) func init() { - system := &linuxSystemService{ - name: "ddwrt", - detect: func() bool { return Name() == DDWrt }, - interactive: func() bool { - is, _ := isInteractive() - return is + systems := []service.System{ + &linuxSystemService{ + name: "ddwrt", + detect: func() bool { return Name() == DDWrt }, + interactive: func() bool { + is, _ := isInteractive() + return is + }, + new: newddwrtService, + }, + &linuxSystemService{ + name: "merlin", + detect: func() bool { return Name() == Merlin }, + interactive: func() bool { + is, _ := isInteractive() + return is + }, + new: newMerlinService, }, - new: newddwrtService, } - systems := append([]service.System{system}, service.AvailableSystems()...) + systems = append(systems, service.AvailableSystems()...) service.ChooseSystem(systems...) } diff --git a/internal/router/service_ddwrt.go b/internal/router/service_ddwrt.go index 097b3e6..953035e 100644 --- a/internal/router/service_ddwrt.go +++ b/internal/router/service_ddwrt.go @@ -157,7 +157,11 @@ func (s *ddwrtSvc) Run() (err error) { return err } - var sigChan = make(chan os.Signal, 3) + if interactice, _ := isInteractive(); !interactice { + signal.Ignore(syscall.SIGHUP) + signal.Ignore(sigCHLD) + } + var sigChan = make(chan os.Signal, 2) signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt) <-sigChan diff --git a/internal/router/service_merlin.go b/internal/router/service_merlin.go new file mode 100644 index 0000000..3c5e2de --- /dev/null +++ b/internal/router/service_merlin.go @@ -0,0 +1,308 @@ +package router + +import ( + "bytes" + "errors" + "fmt" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strings" + "syscall" + "text/template" + + "github.com/kardianos/service" +) + +const merlinJFFSScriptPath = "/jffs/scripts/services-start" + +type merlinSvc struct { + i service.Interface + platform string + *service.Config +} + +func newMerlinService(i service.Interface, platform string, c *service.Config) (service.Service, error) { + s := &merlinSvc{ + i: i, + platform: platform, + Config: c, + } + return s, nil +} + +func (s *merlinSvc) String() string { + if len(s.DisplayName) > 0 { + return s.DisplayName + } + return s.Name +} + +func (s *merlinSvc) Platform() string { + return s.platform +} + +func (s *merlinSvc) configPath() string { + path, err := os.Executable() + if err != nil { + return "" + } + return path + ".startup" +} + +func (s *merlinSvc) template() *template.Template { + return template.Must(template.New("").Parse(merlinSvcScript)) +} + +func (s *merlinSvc) Install() error { + exePath, err := os.Executable() + if err != nil { + return err + } + + if !strings.HasPrefix(exePath, "/jffs/") { + return errors.New("could not install service outside /jffs") + } + if _, err := nvram("set", "jffs2_scripts=1"); err != nil { + return err + } + if _, err := nvram("commit"); err != nil { + return err + } + + confPath := s.configPath() + if _, err := os.Stat(confPath); err == nil { + return fmt.Errorf("already installed: %s", confPath) + } + + var to = &struct { + *service.Config + Path string + }{ + s.Config, + exePath, + } + + f, err := os.Create(confPath) + if err != nil { + return fmt.Errorf("os.Create: %w", err) + } + defer f.Close() + + if err := s.template().Execute(f, to); err != nil { + return fmt.Errorf("s.template.Execute: %w", err) + } + + if err = os.Chmod(confPath, 0755); err != nil { + return fmt.Errorf("os.Chmod: startup script: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(merlinJFFSScriptPath), 0755); err != nil { + return fmt.Errorf("os.MkdirAll: %w", err) + } + if _, err := os.Stat(merlinJFFSScriptPath); os.IsNotExist(err) { + if err := os.WriteFile(merlinJFFSScriptPath, []byte("#!/bin/sh\n"), 0755); err != nil { + return err + } + } + if err := os.Chmod(merlinJFFSScriptPath, 0755); err != nil { + return fmt.Errorf("os.Chmod: jffs script: %w", err) + } + + tmpScript, err := os.CreateTemp("", "ctrld_install") + if err != nil { + return fmt.Errorf("os.CreateTemp: %w", err) + } + defer os.Remove(tmpScript.Name()) + defer tmpScript.Close() + + if _, err := tmpScript.WriteString(merlinAddStartupScript); err != nil { + return fmt.Errorf("tmpScript.WriteString: %w", err) + } + if err := tmpScript.Close(); err != nil { + return fmt.Errorf("tmpScript.Close: %w", err) + } + if err := exec.Command("sh", tmpScript.Name(), s.configPath()+" start", merlinJFFSScriptPath).Run(); err != nil { + return fmt.Errorf("exec.Command: add startup script: %w", err) + } + + return nil +} + +func (s *merlinSvc) Uninstall() error { + if err := os.Remove(s.configPath()); err != nil { + return fmt.Errorf("os.Remove: %w", err) + } + tmpScript, err := os.CreateTemp("", "ctrld_uninstall") + if err != nil { + return fmt.Errorf("os.CreateTemp: %w", err) + } + defer os.Remove(tmpScript.Name()) + defer tmpScript.Close() + + if _, err := tmpScript.WriteString(merlinRemoveStartupScript); err != nil { + return fmt.Errorf("tmpScript.WriteString: %w", err) + } + if err := tmpScript.Close(); err != nil { + return fmt.Errorf("tmpScript.Close: %w", err) + } + if err := exec.Command("sh", tmpScript.Name(), s.configPath()+" start", merlinJFFSScriptPath).Run(); err != nil { + return fmt.Errorf("exec.Command: %w", err) + } + return nil +} + +func (s *merlinSvc) Logger(errs chan<- error) (service.Logger, error) { + if service.Interactive() { + return service.ConsoleLogger, nil + } + return s.SystemLogger(errs) +} + +func (s *merlinSvc) SystemLogger(errs chan<- error) (service.Logger, error) { + return newSysLogger(s.Name, errs) +} + +func (s *merlinSvc) Run() (err error) { + err = s.i.Start(s) + if err != nil { + return err + } + + if interactice, _ := isInteractive(); !interactice { + signal.Ignore(syscall.SIGHUP) + signal.Ignore(sigCHLD) + } + + var sigChan = make(chan os.Signal, 3) + signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt) + <-sigChan + + return s.i.Stop(s) +} + +func (s *merlinSvc) Status() (service.Status, error) { + if _, err := os.Stat(s.configPath()); os.IsNotExist(err) { + return service.StatusUnknown, service.ErrNotInstalled + } + out, err := exec.Command(s.configPath(), "status").CombinedOutput() + if err != nil { + return service.StatusUnknown, err + } + switch string(bytes.TrimSpace(out)) { + case "running": + return service.StatusRunning, nil + default: + return service.StatusStopped, nil + } +} + +func (s *merlinSvc) Start() error { + return exec.Command(s.configPath(), "start").Run() +} + +func (s *merlinSvc) Stop() error { + return exec.Command(s.configPath(), "stop").Run() +} + +func (s *merlinSvc) Restart() error { + err := s.Stop() + if err != nil { + return err + } + return s.Start() +} + +const merlinSvcScript = `#!/bin/sh + +name="{{.Name}}" +cmd="{{.Path}}{{range .Arguments}} {{.}}{{end}}" +pid_file="/tmp/$name.pid" + +get_pid() { + cat "$pid_file" +} + +is_running() { + [ -f "$pid_file" ] && ps | grep -q "^ *$(get_pid) " +} + +case "$1" in + start) + if is_running; then + logger -c "Already started" + else + logger -c "Starting $name" + if [ -f /rom/ca-bundle.crt ]; then + # For John’s fork + export SSL_CERT_FILE=/rom/ca-bundle.crt + fi + $cmd & + echo $! > "$pid_file" + chmod 600 "$pid_file" + if ! is_running; then + logger -c "Failed to start $name" + exit 1 + fi + fi + ;; + stop) + if is_running; then + logger -c "Stopping $name..." + kill "$(get_pid)" + for _ in 1 2 3 4 5; do + if ! is_running; then + logger -c "stopped" + if [ -f "$pid_file" ]; then + rm "$pid_file" + fi + exit 0 + fi + printf "." + sleep 2 + done + logger -c "failed to stop $name" + exit 1 + fi + exit 1 + ;; + restart) + $0 stop + $0 start + ;; + status) + if is_running; then + echo "running" + else + echo "stopped" + exit 1 + fi + ;; + *) + echo "Usage: $0 {start|stop|restart|status}" + exit 1 + ;; +esac +exit 0 +` + +const merlinAddStartupScript = `#!/bin/sh + +line=$1 +file=$2 + +. /usr/sbin/helper.sh + +pc_append "$line" "$file" +` + +const merlinRemoveStartupScript = `#!/bin/sh + +line=$1 +file=$2 + +. /usr/sbin/helper.sh + +pc_delete "$line" "$file" +` diff --git a/internal/router/signal.go b/internal/router/signal.go new file mode 100644 index 0000000..f6f11ed --- /dev/null +++ b/internal/router/signal.go @@ -0,0 +1,7 @@ +//go:build !windows + +package router + +import "syscall" + +const sigCHLD = syscall.SIGCHLD diff --git a/internal/router/signal_windows.go b/internal/router/signal_windows.go new file mode 100644 index 0000000..6526575 --- /dev/null +++ b/internal/router/signal_windows.go @@ -0,0 +1,5 @@ +package router + +import "syscall" + +const sigCHLD = syscall.SIGHUP diff --git a/internal/router/syslog.go b/internal/router/syslog.go new file mode 100644 index 0000000..008bbeb --- /dev/null +++ b/internal/router/syslog.go @@ -0,0 +1,49 @@ +//go:build linux || darwin || freebsd + +package router + +import ( + "fmt" + "log/syslog" + + "github.com/kardianos/service" +) + +func newSysLogger(name string, errs chan<- error) (service.Logger, error) { + w, err := syslog.New(syslog.LOG_INFO, name) + if err != nil { + return nil, err + } + return sysLogger{w, errs}, nil +} + +type sysLogger struct { + *syslog.Writer + errs chan<- error +} + +func (s sysLogger) send(err error) error { + if err != nil && s.errs != nil { + s.errs <- err + } + return err +} + +func (s sysLogger) Error(v ...interface{}) error { + return s.send(s.Writer.Err(fmt.Sprint(v...))) +} +func (s sysLogger) Warning(v ...interface{}) error { + return s.send(s.Writer.Warning(fmt.Sprint(v...))) +} +func (s sysLogger) Info(v ...interface{}) error { + return s.send(s.Writer.Info(fmt.Sprint(v...))) +} +func (s sysLogger) Errorf(format string, a ...interface{}) error { + return s.send(s.Writer.Err(fmt.Sprintf(format, a...))) +} +func (s sysLogger) Warningf(format string, a ...interface{}) error { + return s.send(s.Writer.Warning(fmt.Sprintf(format, a...))) +} +func (s sysLogger) Infof(format string, a ...interface{}) error { + return s.send(s.Writer.Info(fmt.Sprintf(format, a...))) +} diff --git a/internal/router/syslog_windows.go b/internal/router/syslog_windows.go new file mode 100644 index 0000000..ecac969 --- /dev/null +++ b/internal/router/syslog_windows.go @@ -0,0 +1,7 @@ +package router + +import "github.com/kardianos/service" + +func newSysLogger(name string, errs chan<- error) (service.Logger, error) { + return service.ConsoleLogger, nil +}