mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
all: implement router setup for ddwrt
This commit is contained in:
committed by
Cuong Manh Le
parent
c94be0df35
commit
8a2cdbfaa3
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -30,6 +31,7 @@ import (
|
||||
"tailscale.com/net/interfaces"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/certs"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
@@ -46,6 +48,7 @@ var (
|
||||
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
defaultConfigWritten = false
|
||||
defaultConfigFile = "ctrld.toml"
|
||||
rootCertPool *x509.CertPool
|
||||
)
|
||||
|
||||
var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"}
|
||||
@@ -146,8 +149,13 @@ func initCLI() {
|
||||
{"config", false},
|
||||
{"ctrld", writeDefaultConfig},
|
||||
}
|
||||
|
||||
dir, err := userHomeDir()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get config dir: %v", dir)
|
||||
}
|
||||
for _, config := range configs {
|
||||
ctrld.SetConfigName(v, config.name)
|
||||
ctrld.SetConfigNameWithPath(v, config.name, dir)
|
||||
v.SetConfigFile(configPath)
|
||||
if readConfigFile(config.written) {
|
||||
break
|
||||
@@ -200,15 +208,21 @@ func initCLI() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if router.Name() != "" {
|
||||
mainLog.Debug().Msg("Router setup")
|
||||
err := router.Configure(&cfg)
|
||||
if errors.Is(err, router.ErrNotSupported) {
|
||||
unsupportedPlatformHelp(cmd)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to configure router")
|
||||
if setupRouter {
|
||||
switch platform := router.Name(); {
|
||||
case platform == router.DDWrt:
|
||||
rootCertPool = certs.CACertPool()
|
||||
fallthrough
|
||||
case platform != "":
|
||||
mainLog.Debug().Msg("Router setup")
|
||||
err := router.Configure(&cfg)
|
||||
if errors.Is(err, router.ErrNotSupported) {
|
||||
unsupportedPlatformHelp(cmd)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to configure router")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,6 +244,8 @@ func initCLI() {
|
||||
_ = runCmd.Flags().MarkHidden("homedir")
|
||||
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
_ = runCmd.Flags().MarkHidden("iface")
|
||||
runCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
|
||||
_ = runCmd.Flags().MarkHidden("router")
|
||||
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
@@ -247,7 +263,9 @@ func initCLI() {
|
||||
}
|
||||
setDependencies(sc)
|
||||
sc.Arguments = append([]string{"run"}, osArgs...)
|
||||
router.ConfigureService(sc)
|
||||
if err := router.ConfigureService(sc); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// No config path, generating config in HOME directory.
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
@@ -255,7 +273,7 @@ func initCLI() {
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
if dir, err := os.UserHomeDir(); err == nil {
|
||||
if dir, err := userHomeDir(); err == nil {
|
||||
setWorkingDirectory(sc, dir)
|
||||
if configPath == "" && writeDefaultConfig {
|
||||
defaultConfigFile = filepath.Join(dir, defaultConfigFile)
|
||||
@@ -332,6 +350,8 @@ func initCLI() {
|
||||
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
|
||||
_ = startCmd.Flags().MarkHidden("router")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
PreRun: checkHasElevatedPrivilege,
|
||||
@@ -430,6 +450,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
iface = "auto"
|
||||
}
|
||||
prog.resetDNS()
|
||||
mainLog.Debug().Msg("Router cleanup")
|
||||
if err := router.Cleanup(); err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not cleanup router")
|
||||
}
|
||||
@@ -821,3 +842,23 @@ func selfCheckStatus(status service.Status) service.Status {
|
||||
func unsupportedPlatformHelp(cmd *cobra.Command) {
|
||||
cmd.PrintErrln("Unsupported or incorrectly chosen router platform. Please open an issue and provide all relevant information: https://github.com/Control-D-Inc/ctrld/issues/new")
|
||||
}
|
||||
|
||||
func userHomeDir() (string, error) {
|
||||
switch router.Name() {
|
||||
case router.DDWrt, router.Merlin:
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Dir(exe), nil
|
||||
}
|
||||
// viper will expand for us.
|
||||
if runtime.GOOS == "windows" {
|
||||
return os.UserHomeDir()
|
||||
}
|
||||
dir := "/etc/controld"
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
@@ -53,6 +53,7 @@ func initRouterCLI() {
|
||||
|
||||
cmdArgs := []string{"start"}
|
||||
cmdArgs = append(cmdArgs, osArgs(platform)...)
|
||||
cmdArgs = append(cmdArgs, "--router")
|
||||
command := exec.Command(exe, cmdArgs...)
|
||||
command.Stdout = os.Stdout
|
||||
command.Stderr = os.Stderr
|
||||
|
||||
@@ -29,6 +29,7 @@ var (
|
||||
cdUID string
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
setupRouter bool
|
||||
|
||||
mainLog = zerolog.New(io.Discard)
|
||||
)
|
||||
@@ -50,7 +51,7 @@ func normalizeLogFilePath(logFilePath string) string {
|
||||
if homedir != "" {
|
||||
return filepath.Join(homedir, logFilePath)
|
||||
}
|
||||
dir, _ := os.UserHomeDir()
|
||||
dir, _ := userHomeDir()
|
||||
if dir == "" {
|
||||
return logFilePath
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
var logf = func(format string, args ...any) {
|
||||
@@ -77,6 +78,7 @@ func (p *prog) run() {
|
||||
} else {
|
||||
mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Using bootstrap IP for upstream.%s", n)
|
||||
}
|
||||
uc.SetCertPool(rootCertPool)
|
||||
uc.SetupTransport()
|
||||
}
|
||||
|
||||
@@ -165,6 +167,10 @@ func (p *prog) deAllocateIP() error {
|
||||
}
|
||||
|
||||
func (p *prog) setDNS() {
|
||||
// On router, ctrld run as a DNS provider, it does not have to change system DNS.
|
||||
if router.Name() != "" {
|
||||
return
|
||||
}
|
||||
if cfg.Listener == nil || cfg.Listener["0"] == nil {
|
||||
return
|
||||
}
|
||||
@@ -193,6 +199,10 @@ func (p *prog) setDNS() {
|
||||
}
|
||||
|
||||
func (p *prog) resetDNS() {
|
||||
// See comment in p.setDNS method.
|
||||
if router.Name() != "" {
|
||||
return
|
||||
}
|
||||
if iface == "" {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -24,12 +25,14 @@ type task struct {
|
||||
}
|
||||
|
||||
func doTasks(tasks []task) bool {
|
||||
var prevErr error
|
||||
for _, task := range tasks {
|
||||
if err := task.f(); err != nil {
|
||||
if task.abortOnError {
|
||||
stderrMsg(err.Error())
|
||||
stderrMsg(errors.Join(prevErr, err).Error())
|
||||
return false
|
||||
}
|
||||
prevErr = err
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
20
config.go
20
config.go
@@ -2,6 +2,8 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -20,23 +22,26 @@ import (
|
||||
)
|
||||
|
||||
// SetConfigName set the config name that ctrld will look for.
|
||||
// DEPRECATED: use SetConfigNameWithPath instead.
|
||||
func SetConfigName(v *viper.Viper, name string) {
|
||||
v.SetConfigName(name)
|
||||
|
||||
configPath := "$HOME"
|
||||
// viper has its own way to get user home directory: https://github.com/spf13/viper/blob/v1.14.0/util.go#L134
|
||||
// To be consistent, we prefer os.UserHomeDir instead.
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
configPath = homeDir
|
||||
}
|
||||
SetConfigNameWithPath(v, name, configPath)
|
||||
}
|
||||
|
||||
// SetConfigNameWithPath set the config path and name that ctrld will look for.
|
||||
func SetConfigNameWithPath(v *viper.Viper, name, configPath string) {
|
||||
v.SetConfigName(name)
|
||||
v.AddConfigPath(configPath)
|
||||
v.AddConfigPath(".")
|
||||
}
|
||||
|
||||
// InitConfig initializes default config values for given *viper.Viper instance.
|
||||
func InitConfig(v *viper.Viper, name string) {
|
||||
SetConfigName(v, name)
|
||||
|
||||
v.SetDefault("listener", map[string]*ListenerConfig{
|
||||
"0": {
|
||||
IP: "127.0.0.1",
|
||||
@@ -104,6 +109,7 @@ type UpstreamConfig struct {
|
||||
Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"`
|
||||
transport *http.Transport `mapstructure:"-" toml:"-"`
|
||||
http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"`
|
||||
certPool *x509.CertPool `mapstructure:"-" toml:"-"`
|
||||
|
||||
g singleflight.Group
|
||||
bootstrapIPs []string
|
||||
@@ -152,6 +158,11 @@ func (uc *UpstreamConfig) Init() {
|
||||
}
|
||||
}
|
||||
|
||||
// SetCertPool sets the system cert pool used for TLS connections.
|
||||
func (uc *UpstreamConfig) SetCertPool(cp *x509.CertPool) {
|
||||
uc.certPool = cp
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
@@ -297,6 +308,7 @@ func (uc *UpstreamConfig) setupDOHTransport() {
|
||||
func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
|
||||
uc.transport = http.DefaultTransport.(*http.Transport).Clone()
|
||||
uc.transport.IdleConnTimeout = 5 * time.Second
|
||||
uc.transport.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
|
||||
dialerTimeoutMs := 2000
|
||||
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
|
||||
|
||||
@@ -18,6 +18,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {
|
||||
rt := &http3.RoundTripper{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
host := addr
|
||||
ProxyLog.Debug().Msgf("debug dial context D0H3 %s - %s", addr, bootstrapDNS)
|
||||
|
||||
7
dot.go
7
dot.go
@@ -20,12 +20,13 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
// regardless of the machine DNS status.
|
||||
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
|
||||
dnsClient := &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
Dialer: dialer,
|
||||
Net: "tcp-tls",
|
||||
Dialer: dialer,
|
||||
TLSConfig: &tls.Config{RootCAs: r.uc.certPool},
|
||||
}
|
||||
endpoint := r.uc.Endpoint
|
||||
if r.uc.BootstrapIP != "" {
|
||||
dnsClient.TLSConfig = &tls.Config{ServerName: r.uc.Domain}
|
||||
dnsClient.TLSConfig.ServerName = r.uc.Domain
|
||||
_, port, _ := net.SplitHostPort(endpoint)
|
||||
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
|
||||
}
|
||||
|
||||
3372
internal/certs/cacert.pem
Normal file
3372
internal/certs/cacert.pem
Normal file
File diff suppressed because it is too large
Load Diff
22
internal/certs/root_ca.go
Normal file
22
internal/certs/root_ca.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package certs
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
_ "embed"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed cacert.pem
|
||||
caRoots []byte
|
||||
caCertPoolOnce sync.Once
|
||||
caCertPool *x509.CertPool
|
||||
)
|
||||
|
||||
func CACertPool() *x509.CertPool {
|
||||
caCertPoolOnce.Do(func() {
|
||||
caCertPool = x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caRoots)
|
||||
})
|
||||
return caCertPool
|
||||
}
|
||||
27
internal/certs/root_ca_test.go
Normal file
27
internal/certs/root_ca_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package certs
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCACertPool(t *testing.T) {
|
||||
c := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: CACertPool(),
|
||||
},
|
||||
},
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
resp, err := c.Get("https://freedns.controld.com/p1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if !resp.TLS.HandshakeComplete {
|
||||
t.Error("TLS handshake is not complete")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package controld
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -13,7 +14,9 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/certs"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -114,6 +117,10 @@ func FetchResolverConfig(uid string) (*ResolverConfig, error) {
|
||||
}
|
||||
return ctrldnet.Dialer.DialContext(ctx, proto, addr)
|
||||
}
|
||||
|
||||
if router.Name() == router.DDWrt {
|
||||
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: transport,
|
||||
|
||||
115
internal/router/ddwrt.go
Normal file
115
internal/router/ddwrt.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
nvramCtrldKeyPrefix = "ctrld_"
|
||||
nvramCtrldSetupKey = "ctrld_setup"
|
||||
nvramRCStartupKey = "rc_startup"
|
||||
)
|
||||
|
||||
var ddwrtJffs2NotEnabledErr = errors.New(`could not install service without jffs, follow this guide to enable:
|
||||
|
||||
https://wiki.dd-wrt.com/wiki/index.php/Journalling_Flash_File_System
|
||||
`)
|
||||
|
||||
var nvramKeys = map[string]string{
|
||||
"dns_dnsmasq": "1", // Make dnsmasq running but disable DNS ability, ctrld will replace it.
|
||||
"dnsmasq_options": dnsMasqConfigContent, // Configuration of dnsmasq set by ctrld.
|
||||
"dns_crypt": "0", // Disable DNSCrypt.
|
||||
}
|
||||
|
||||
func setupDDWrt() error {
|
||||
// Already setup.
|
||||
if val, _ := nvram("get", nvramCtrldSetupKey); val == "1" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Backup current value, store ctrld's configs.
|
||||
for key, value := range nvramKeys {
|
||||
old, err := nvram("get", key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", old, err)
|
||||
}
|
||||
if out, err := nvram("set", nvramCtrldKeyPrefix+key+"="+old); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
if out, err := nvram("set", key+"="+value); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
}
|
||||
|
||||
if out, err := nvram("set", nvramCtrldSetupKey+"=1"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
// Commit.
|
||||
if out, err := nvram("commit"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := ddwrtRestartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupDDWrt() error {
|
||||
// Restore old configs.
|
||||
for key := range nvramKeys {
|
||||
ctrldKey := nvramCtrldKeyPrefix + key
|
||||
old, err := nvram("get", ctrldKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", old, err)
|
||||
}
|
||||
_, _ = nvram("unset", ctrldKey)
|
||||
if out, err := nvram("set", key+"="+old); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
}
|
||||
|
||||
if out, err := nvram("unset", "ctrld_setup"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
// Commit.
|
||||
if out, err := nvram("commit"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := ddwrtRestartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postInstallDDWrt() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func nvram(args ...string) (string, error) {
|
||||
cmd := exec.Command("nvram", args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("%s:%w", stderr.String(), err)
|
||||
}
|
||||
return strings.TrimSpace(stdout.String()), nil
|
||||
}
|
||||
|
||||
func ddwrtRestartDNSMasq() error {
|
||||
if out, err := exec.Command("restart_dns").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("restart_dns: %s, %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ddwrtJff2Enabled() bool {
|
||||
out, _ := nvram("get", "enable_jffs2")
|
||||
return out == "1"
|
||||
}
|
||||
6
internal/router/dnsmasq.go
Normal file
6
internal/router/dnsmasq.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package router
|
||||
|
||||
const dnsMasqConfigContent = `# GENERATED BY ctrld - DO NOT MODIFY
|
||||
no-resolv
|
||||
server=127.0.0.1#5353
|
||||
`
|
||||
@@ -12,9 +12,6 @@ import (
|
||||
var errUCIEntryNotFound = errors.New("uci: Entry not found")
|
||||
|
||||
const openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf"
|
||||
const openwrtDNSMasqConfigContent = `# GENERATED BY ctrld - DO NOT MODIFY
|
||||
port=0
|
||||
`
|
||||
|
||||
func setupOpenWrt() error {
|
||||
// Delete dnsmasq port if set.
|
||||
@@ -22,7 +19,7 @@ func setupOpenWrt() error {
|
||||
return err
|
||||
}
|
||||
// Disable dnsmasq as DNS server.
|
||||
if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(openwrtDNSMasqConfigContent), 0600); err != nil {
|
||||
if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(dnsMasqConfigContent), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Commit.
|
||||
@@ -30,7 +27,7 @@ func setupOpenWrt() error {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
if err := openwrtRestartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -42,7 +39,7 @@ func cleanupOpenWrt() error {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
if err := openwrtRestartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -66,7 +63,7 @@ func uci(args ...string) (string, error) {
|
||||
return strings.TrimSpace(stdout.String()), nil
|
||||
}
|
||||
|
||||
func restartDNSMasq() error {
|
||||
func openwrtRestartDNSMasq() error {
|
||||
if out, err := exec.Command("/etc/init.d/dnsmasq", "restart").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%s: %w", string(out), err)
|
||||
}
|
||||
|
||||
@@ -38,9 +38,11 @@ func SupportedPlatforms() []string {
|
||||
func Configure(c *ctrld.Config) error {
|
||||
name := Name()
|
||||
switch name {
|
||||
case DDWrt:
|
||||
return setupDDWrt()
|
||||
case OpenWrt:
|
||||
return setupOpenWrt()
|
||||
case DDWrt, Merlin, Ubios:
|
||||
case Merlin, Ubios:
|
||||
default:
|
||||
return ErrNotSupported
|
||||
}
|
||||
@@ -50,22 +52,29 @@ func Configure(c *ctrld.Config) error {
|
||||
}
|
||||
|
||||
// ConfigureService performs necessary setup for running ctrld as a service on router.
|
||||
func ConfigureService(sc *service.Config) {
|
||||
func ConfigureService(sc *service.Config) error {
|
||||
name := Name()
|
||||
switch name {
|
||||
case DDWrt:
|
||||
if !ddwrtJff2Enabled() {
|
||||
return ddwrtJffs2NotEnabledErr
|
||||
}
|
||||
case OpenWrt:
|
||||
sc.Option["SysvScript"] = openWrtScript
|
||||
case DDWrt, Merlin, Ubios:
|
||||
case Merlin, Ubios:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostInstall performs task after installing ctrld on router.
|
||||
func PostInstall() error {
|
||||
name := Name()
|
||||
switch name {
|
||||
case DDWrt:
|
||||
return postInstallDDWrt()
|
||||
case OpenWrt:
|
||||
return postInstallOpenWrt()
|
||||
case DDWrt, Merlin, Ubios:
|
||||
case Merlin, Ubios:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -76,7 +85,9 @@ func Cleanup() error {
|
||||
switch name {
|
||||
case OpenWrt:
|
||||
return cleanupOpenWrt()
|
||||
case DDWrt, Merlin, Ubios:
|
||||
case DDWrt:
|
||||
return cleanupDDWrt()
|
||||
case Merlin, Ubios:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -85,9 +96,9 @@ func Cleanup() error {
|
||||
func ListenAddress() string {
|
||||
name := Name()
|
||||
switch name {
|
||||
case OpenWrt:
|
||||
return ":53"
|
||||
case DDWrt, Merlin, Ubios:
|
||||
case DDWrt, OpenWrt:
|
||||
return "127.0.0.1:5353"
|
||||
case Merlin, Ubios:
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
49
internal/router/service.go
Normal file
49
internal/router/service.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func init() {
|
||||
system := &linuxSystemService{
|
||||
name: "ddwrt",
|
||||
detect: func() bool { return Name() == DDWrt },
|
||||
interactive: func() bool {
|
||||
is, _ := isInteractive()
|
||||
return is
|
||||
},
|
||||
new: newddwrtService,
|
||||
}
|
||||
systems := append([]service.System{system}, service.AvailableSystems()...)
|
||||
service.ChooseSystem(systems...)
|
||||
}
|
||||
|
||||
type linuxSystemService struct {
|
||||
name string
|
||||
detect func() bool
|
||||
interactive func() bool
|
||||
new func(i service.Interface, platform string, c *service.Config) (service.Service, error)
|
||||
}
|
||||
|
||||
func (sc linuxSystemService) String() string {
|
||||
return sc.name
|
||||
}
|
||||
func (sc linuxSystemService) Detect() bool {
|
||||
return sc.detect()
|
||||
}
|
||||
func (sc linuxSystemService) Interactive() bool {
|
||||
return sc.interactive()
|
||||
}
|
||||
func (sc linuxSystemService) New(i service.Interface, c *service.Config) (service.Service, error) {
|
||||
return sc.new(i, sc.String(), c)
|
||||
}
|
||||
|
||||
func isInteractive() (bool, error) {
|
||||
ppid := os.Getppid()
|
||||
if ppid == 1 {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
289
internal/router/service_ddwrt.go
Normal file
289
internal/router/service_ddwrt.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"text/template"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
type ddwrtSvc struct {
|
||||
i service.Interface
|
||||
platform string
|
||||
*service.Config
|
||||
rcStartup string
|
||||
}
|
||||
|
||||
func newddwrtService(i service.Interface, platform string, c *service.Config) (service.Service, error) {
|
||||
s := &ddwrtSvc{
|
||||
i: i,
|
||||
platform: platform,
|
||||
Config: c,
|
||||
}
|
||||
if err := os.MkdirAll("/jffs/etc/config", 0644); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) String() string {
|
||||
if len(s.DisplayName) > 0 {
|
||||
return s.DisplayName
|
||||
}
|
||||
return s.Name
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Platform() string {
|
||||
return s.platform
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) configPath() string {
|
||||
return fmt.Sprintf("/jffs/etc/config/%s.startup", s.Config.Name)
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) template() *template.Template {
|
||||
return template.Must(template.New("").Parse(ddwrtSvcScript))
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Install() error {
|
||||
confPath := s.configPath()
|
||||
if _, err := os.Stat(confPath); err == nil {
|
||||
return fmt.Errorf("already installed: %s", confPath)
|
||||
}
|
||||
|
||||
path, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(path, "/jffs/") {
|
||||
return errors.New("could not install service outside /jffs")
|
||||
}
|
||||
|
||||
var to = &struct {
|
||||
*service.Config
|
||||
Path string
|
||||
}{
|
||||
s.Config,
|
||||
path,
|
||||
}
|
||||
|
||||
f, err := os.Create(confPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := s.template().Execute(f, to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = os.Chmod(confPath, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
if err := template.Must(template.New("").Parse(ddwrtStartupCmd)).Execute(&sb, to); err != nil {
|
||||
return err
|
||||
}
|
||||
s.rcStartup = sb.String()
|
||||
curVal, err := nvram("get", nvramRCStartupKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := nvram("set", nvramCtrldKeyPrefix+nvramRCStartupKey+"="+curVal); err != nil {
|
||||
return err
|
||||
}
|
||||
val := strings.Join([]string{curVal, s.rcStartup + " &", fmt.Sprintf(`echo $! > "/tmp/%s.pid"`, s.Config.Name)}, "\n")
|
||||
|
||||
if _, err := nvram("set", nvramRCStartupKey+"="+val); err != nil {
|
||||
return err
|
||||
}
|
||||
if out, err := nvram("commit"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Uninstall() error {
|
||||
if err := os.Remove(s.configPath()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctrldStartupKey := nvramCtrldKeyPrefix + nvramRCStartupKey
|
||||
rcStartup, err := nvram("get", ctrldStartupKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = nvram("unset", ctrldStartupKey)
|
||||
if _, err := nvram("set", nvramRCStartupKey+"="+rcStartup); err != nil {
|
||||
return err
|
||||
}
|
||||
if out, err := nvram("commit"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Logger(errs chan<- error) (service.Logger, error) {
|
||||
if service.Interactive() {
|
||||
return service.ConsoleLogger, nil
|
||||
}
|
||||
return s.SystemLogger(errs)
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) SystemLogger(errs chan<- error) (service.Logger, error) {
|
||||
// TODO(cuonglm): detect syslog enable and return proper logger?
|
||||
// this at least works with default configuration.
|
||||
if service.Interactive() {
|
||||
return service.ConsoleLogger, nil
|
||||
|
||||
}
|
||||
return &noopLogger{}, nil
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Run() (err error) {
|
||||
err = s.i.Start(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sigChan = make(chan os.Signal, 3)
|
||||
signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt)
|
||||
<-sigChan
|
||||
|
||||
return s.i.Stop(s)
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Status() (service.Status, error) {
|
||||
if _, err := os.Stat(s.configPath()); os.IsNotExist(err) {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
out, err := exec.Command(s.configPath(), "status").CombinedOutput()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, err
|
||||
}
|
||||
switch string(bytes.TrimSpace(out)) {
|
||||
case "running":
|
||||
return service.StatusRunning, nil
|
||||
default:
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Start() error {
|
||||
return exec.Command(s.configPath(), "start").Run()
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Stop() error {
|
||||
return exec.Command(s.configPath(), "stop").Run()
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Restart() error {
|
||||
err := s.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Start()
|
||||
}
|
||||
|
||||
type noopLogger struct {
|
||||
}
|
||||
|
||||
func (c noopLogger) Error(v ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c noopLogger) Warning(v ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c noopLogger) Info(v ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c noopLogger) Errorf(format string, a ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c noopLogger) Warningf(format string, a ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c noopLogger) Infof(format string, a ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const ddwrtStartupCmd = `{{.Path}}{{range .Arguments}} {{.}}{{end}}`
|
||||
const ddwrtSvcScript = `#!/bin/sh
|
||||
|
||||
name="{{.Name}}"
|
||||
cmd="{{.Path}}{{range .Arguments}} {{.}}{{end}}"
|
||||
pid_file="/tmp/$name.pid"
|
||||
|
||||
get_pid() {
|
||||
cat "$pid_file"
|
||||
}
|
||||
|
||||
is_running() {
|
||||
[ -f "$pid_file" ] && ps | grep -q "^ *$(get_pid) "
|
||||
}
|
||||
|
||||
case "$1" in
|
||||
start)
|
||||
if is_running; then
|
||||
echo "Already started"
|
||||
else
|
||||
echo "Starting $name"
|
||||
$cmd &
|
||||
echo $! > "$pid_file"
|
||||
chmod 600 "$pid_file"
|
||||
if ! is_running; then
|
||||
echo "Failed to start $name"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
stop)
|
||||
if is_running; then
|
||||
echo -n "Stopping $name..."
|
||||
kill "$(get_pid)"
|
||||
for _ in 1 2 3 4 5; do
|
||||
if ! is_running; then
|
||||
echo "stopped"
|
||||
if [ -f "$pid_file" ]; then
|
||||
rm "$pid_file"
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
printf "."
|
||||
sleep 2
|
||||
done
|
||||
echo "failed to stop $name"
|
||||
exit 1
|
||||
fi
|
||||
exit 1
|
||||
;;
|
||||
restart)
|
||||
$0 stop
|
||||
$0 start
|
||||
;;
|
||||
status)
|
||||
if is_running; then
|
||||
echo "running"
|
||||
else
|
||||
echo "stopped"
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {start|stop|restart|status}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
exit 0
|
||||
`
|
||||
Reference in New Issue
Block a user