cmd/ctrld: add commands to control ctrld as a system service

Supported actions:

 - start: install and start ctrld as a system service
 - stop: stop the ctrld service
 - restart: restart ctrld service
 - status: show status of ctrld service
 - uninstall: remove ctrld from system service
This commit is contained in:
Cuong Manh Le
2022-12-23 21:35:26 +07:00
committed by Cuong Manh Le
parent 9e7578fb29
commit ec72af1916
8 changed files with 315 additions and 39 deletions
+191 -20
View File
@@ -3,7 +3,6 @@ package main
import (
"bytes"
"encoding/base64"
"fmt"
"log"
"net"
"os"
@@ -25,6 +24,17 @@ var (
defaultConfigWritten = false
)
var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains", "log", "cache_size"}
func isNoConfigStart(cmd *cobra.Command) bool {
for _, flagName := range basicModeFlags {
if cmd.Flags().Lookup(flagName).Changed {
return true
}
}
return false
}
func initCLI() {
// Enable opening via explorer.exe on Windows.
// See: https://github.com/spf13/cobra/issues/844.
@@ -42,7 +52,6 @@ func initCLI() {
`verbose log output, "-v" means query logging enabled, "-vv" means debug level logging enabled`,
)
basicModeFlags := []string{"listen", "primary_upstream", "secondary_upstream", "domains", "log", "cache_size"}
runCmd := &cobra.Command{
Use: "run",
Short: "Run the DNS proxy server",
@@ -51,19 +60,25 @@ func initCLI() {
if daemon && runtime.GOOS == "windows" {
log.Fatal("Cannot run in daemon mode. Please install a Windows service.")
}
if configPath != "" {
v.SetConfigFile(configPath)
}
noConfigStart := func() bool {
for _, flagName := range basicModeFlags {
if cmd.Flags().Lookup(flagName).Changed {
return true
}
}
return false
}()
readConfigFile(!noConfigStart && configBase64 == "")
noConfigStart := isNoConfigStart(cmd)
configs := []struct {
name string
written bool
}{
// For compatibility, we check for config.toml first, but only read it if exists.
{"config", false},
{"ctrld", !noConfigStart && configBase64 == ""},
}
for _, config := range configs {
ctrld.SetConfigName(v, config.name)
v.SetConfigFile(configPath)
if readConfigFile(config.written) {
break
}
}
readBase64Config()
processNoConfigFlags(noConfigStart)
if err := v.Unmarshal(&cfg); err != nil {
@@ -126,8 +141,158 @@ func initCLI() {
rootCmd.AddCommand(runCmd)
startCmd := &cobra.Command{
Use: "start",
Short: "Start the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
cfg := &service.Config{}
*cfg = *svcConfig
cfg.Arguments = append([]string{"run"}, os.Args[3:]...)
if dir, err := os.UserHomeDir(); err == nil {
// WorkingDirectory is not supported on Windows.
cfg.WorkingDirectory = dir
// No config path, generating config in HOME directory.
if configPath == "" && !isNoConfigStart(cmd) && configBase64 == "" {
readConfigFile(true)
}
}
s, err := service.New(&prog{}, cfg)
if err != nil {
stderrMsg(err.Error())
return
}
tasks := []task{
{s.Stop, false},
{s.Uninstall, false},
{s.Install, false},
{s.Start, true},
}
if doTasks(tasks) {
stdoutMsg("Service started")
return
}
},
}
// Keep these flags in sync with runCmd above, except for "-d".
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "base64 encoded config")
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "listener address and port, in format: address:port")
startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "primary upstream endpoint")
startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "secondary upstream endpoint")
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "list of domain to apply in a split DNS policy")
startCmd.Flags().StringVarP(&logPath, "log", "", "", "path to log file")
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
stopCmd := &cobra.Command{
Use: "stop",
Short: "Stop the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
s, err := service.New(&prog{}, svcConfig)
if err != nil {
stderrMsg(err.Error())
return
}
if doTasks([]task{{s.Stop, true}}) {
stdoutMsg("Service stopped")
}
},
}
restartCmd := &cobra.Command{
Use: "restart",
Short: "Restart the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
s, err := service.New(&prog{}, svcConfig)
if err != nil {
stderrMsg(err.Error())
return
}
if doTasks([]task{{s.Restart, true}}) {
stdoutMsg("Service restarted")
}
},
}
statusCmd := &cobra.Command{
Use: "status",
Short: "Show status of the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
s, err := service.New(&prog{}, svcConfig)
if err != nil {
stderrMsg(err.Error())
return
}
status, err := s.Status()
if err != nil {
stderrMsg(err.Error())
return
}
switch status {
case service.StatusUnknown:
stdoutMsg("Unknown status")
case service.StatusRunning:
stdoutMsg("Service is running")
case service.StatusStopped:
stdoutMsg("Service is stopped")
}
},
}
uninstallCmd := &cobra.Command{
Use: "uninstall",
Short: "Uninstall the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
s, err := service.New(&prog{}, svcConfig)
if err != nil {
stderrMsg(err.Error())
return
}
tasks := []task{
{s.Stop, false},
{s.Uninstall, true},
}
if doTasks(tasks) {
stdoutMsg("Service uninstalled")
return
}
},
}
serviceCmd := &cobra.Command{
Use: "service",
Short: "Manage ctrld service",
Args: cobra.OnlyValidArgs,
ValidArgs: []string{
statusCmd.Use,
stopCmd.Use,
restartCmd.Use,
statusCmd.Use,
uninstallCmd.Use,
},
}
serviceCmd.AddCommand(startCmd)
serviceCmd.AddCommand(stopCmd)
serviceCmd.AddCommand(restartCmd)
serviceCmd.AddCommand(statusCmd)
serviceCmd.AddCommand(uninstallCmd)
rootCmd.AddCommand(serviceCmd)
startCmdAlias := &cobra.Command{
Use: "start",
Short: "Alias for service start",
Run: func(cmd *cobra.Command, args []string) {
startCmd.Run(cmd, args)
},
}
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
rootCmd.AddCommand(startCmdAlias)
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
stderrMsg(err.Error())
os.Exit(1)
}
}
@@ -138,22 +303,28 @@ func writeConfigFile() {
if err != nil {
log.Fatalf("unable to marshal config to toml: %v", err)
}
if err := os.WriteFile("config.toml", bs, 0600); err != nil {
if err := os.WriteFile("ctrld.toml", bs, 0600); err != nil {
log.Printf("failed to write config file: %v\n", err)
}
}
func readConfigFile(configWritten bool) {
func readConfigFile(configWritten bool) bool {
err := v.ReadInConfig()
if err == nil || !configWritten {
return
if err == nil {
return true
}
if !configWritten {
return false
}
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
writeConfigFile()
defaultConfigWritten = true
return
return false
}
log.Fatalf("failed to decode config file: %v", err)
return false
}
func readBase64Config() {
+22 -4
View File
@@ -4,8 +4,10 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"time"
"github.com/kardianos/service"
"github.com/rs/zerolog"
"github.com/Control-D-Inc/ctrld"
@@ -32,17 +34,33 @@ var (
)
func main() {
ctrld.InitConfig(v, "config")
ctrld.InitConfig(v, "ctrld")
initCLI()
}
func normalizeLogFilePath(logFilePath string) string {
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
return logFilePath
}
dir, _ := os.UserHomeDir()
if dir == "" {
return logFilePath
}
return filepath.Join(dir, logFilePath)
}
func initLogging() {
writers := []io.Writer{io.Discard}
isLog := cfg.Service.LogLevel != ""
if logPath := cfg.Service.LogPath; logPath != "" {
logFile, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
// Create parent directory if necessary.
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
fmt.Fprintf(os.Stderr, "failed to create log path: %v", err)
os.Exit(1)
}
logFile, err := os.OpenFile(logFilePath, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to creating log file: %v", err)
fmt.Fprintf(os.Stderr, "failed to create log file: %v", err)
os.Exit(1)
}
isLog = true
+31
View File
@@ -0,0 +1,31 @@
package main
import (
"fmt"
"os"
)
func stderrMsg(msg string) {
_, _ = fmt.Fprintln(os.Stderr, msg)
}
func stdoutMsg(msg string) {
_, _ = fmt.Fprintln(os.Stdout, msg)
}
type task struct {
f func() error
abortOnError bool
}
func doTasks(tasks []task) bool {
for _, task := range tasks {
if err := task.f(); err != nil {
if task.abortOnError {
stderrMsg(err.Error())
return false
}
}
}
return true
}