diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index ab0fc2f..f9c69cc 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -176,9 +176,10 @@ func initCLI() { rootCmd.AddCommand(runCmd) startCmd := &cobra.Command{ - Use: "start", - Short: "Start the ctrld service", - Args: cobra.NoArgs, + PreRun: checkHasElevatedPrivilege, + Use: "start", + Short: "Start the ctrld service", + Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { sc := &service.Config{} *sc = *svcConfig @@ -239,9 +240,10 @@ func initCLI() { startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) stopCmd := &cobra.Command{ - Use: "stop", - Short: "Stop the ctrld service", - Args: cobra.NoArgs, + PreRun: checkHasElevatedPrivilege, + Use: "stop", + Short: "Stop the ctrld service", + Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { s, err := service.New(&prog{}, svcConfig) if err != nil { @@ -256,9 +258,10 @@ func initCLI() { stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`) restartCmd := &cobra.Command{ - Use: "restart", - Short: "Restart the ctrld service", - Args: cobra.NoArgs, + PreRun: checkHasElevatedPrivilege, + Use: "restart", + Short: "Restart the ctrld service", + Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { s, err := service.New(&prog{}, svcConfig) if err != nil { @@ -298,9 +301,10 @@ func initCLI() { } uninstallCmd := &cobra.Command{ - Use: "uninstall", - Short: "Uninstall the ctrld service", - Args: cobra.NoArgs, + PreRun: checkHasElevatedPrivilege, + Use: "uninstall", + Short: "Uninstall the ctrld service", + Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { s, err := service.New(&prog{}, svcConfig) if err != nil { @@ -379,8 +383,9 @@ func initCLI() { serviceCmd.AddCommand(interfacesCmd) rootCmd.AddCommand(serviceCmd) startCmdAlias := &cobra.Command{ - Use: "start", - Short: "Quick start service and configure DNS on interface", + PreRun: checkHasElevatedPrivilege, + Use: "start", + Short: "Quick start service and configure DNS on interface", Run: func(cmd *cobra.Command, args []string) { if !cmd.Flags().Changed("iface") { os.Args = append(os.Args, "--iface="+ifaceStartStop) @@ -392,8 +397,9 @@ func initCLI() { startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) rootCmd.AddCommand(startCmdAlias) stopCmdAlias := &cobra.Command{ - Use: "stop", - Short: "Quick stop service and remove DNS from interface", + PreRun: checkHasElevatedPrivilege, + Use: "stop", + Short: "Quick stop service and remove DNS from interface", Run: func(cmd *cobra.Command, args []string) { if !cmd.Flags().Changed("iface") { os.Args = append(os.Args, "--iface="+ifaceStartStop) diff --git a/cmd/ctrld/service.go b/cmd/ctrld/service.go index 7307a89..14834c6 100644 --- a/cmd/ctrld/service.go +++ b/cmd/ctrld/service.go @@ -3,6 +3,8 @@ package main import ( "fmt" "os" + + "github.com/spf13/cobra" ) func stderrMsg(msg string) { @@ -29,3 +31,15 @@ func doTasks(tasks []task) bool { } return true } + +func checkHasElevatedPrivilege(cmd *cobra.Command, args []string) { + ok, err := hasElevatedPrivilege() + if err != nil { + fmt.Printf("could not detect user privilege: %v", err) + return + } + if !ok { + fmt.Println("Please relaunch process with admin/root privilege.") + os.Exit(1) + } +} diff --git a/cmd/ctrld/service_others.go b/cmd/ctrld/service_others.go new file mode 100644 index 0000000..82a6ea3 --- /dev/null +++ b/cmd/ctrld/service_others.go @@ -0,0 +1,11 @@ +//go:build !windows + +package main + +import ( + "os" +) + +func hasElevatedPrivilege() (bool, error) { + return os.Geteuid() == 0, nil +} diff --git a/cmd/ctrld/service_windows.go b/cmd/ctrld/service_windows.go new file mode 100644 index 0000000..0ce8d3a --- /dev/null +++ b/cmd/ctrld/service_windows.go @@ -0,0 +1,24 @@ +package main + +import "golang.org/x/sys/windows" + +func hasElevatedPrivilege() (bool, error) { + var sid *windows.SID + if err := windows.AllocateAndInitializeSid( + &windows.SECURITY_NT_AUTHORITY, + 2, + windows.SECURITY_BUILTIN_DOMAIN_RID, + windows.DOMAIN_ALIAS_RID_ADMINS, + 0, + 0, + 0, + 0, + 0, + 0, + &sid, + ); err != nil { + return false, err + } + token := windows.Token(0) + return token.IsMember(sid) +}