mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6586fd360 | ||
|
|
33a6db2599 | ||
|
|
70b0c4f7b9 | ||
|
|
5af3ec4f7b | ||
|
|
79476add12 | ||
|
|
1634a06330 | ||
|
|
a007394f60 | ||
|
|
62a0ba8731 | ||
|
|
e8d3ed1acd | ||
|
|
8b98faa441 | ||
|
|
30320ec9c7 | ||
|
|
5f4a399850 | ||
|
|
82e0d4b0c4 | ||
|
|
95a9df826d | ||
|
|
3b71d26cf3 | ||
|
|
c233ad9b1b | ||
|
|
12d6484b1c | ||
|
|
bc7b1cc6d8 | ||
|
|
ec684348ed | ||
|
|
18a19a3aa2 | ||
|
|
905f2d08c5 | ||
|
|
04947b4d87 | ||
|
|
72bf80533e | ||
|
|
9ddedf926e | ||
|
|
139dd62ff3 | ||
|
|
50ef00526e | ||
|
|
80cf79b9cb | ||
|
|
e6ad39b070 | ||
|
|
56f9c72569 | ||
|
|
dc48c908b8 | ||
|
|
9b0f0e792a | ||
|
|
b3eebb19b6 | ||
|
|
c24589a5be | ||
|
|
1e1c5a4dc8 | ||
|
|
339023421a |
@@ -105,9 +105,11 @@ Available Commands:
|
||||
start Quick start service and configure DNS on interface
|
||||
stop Quick stop service and remove DNS from interface
|
||||
restart Restart the ctrld service
|
||||
reload Reload the ctrld service
|
||||
status Show status of the ctrld service
|
||||
uninstall Stop and uninstall the ctrld service
|
||||
clients Manage clients
|
||||
upgrade Upgrading ctrld to latest version
|
||||
|
||||
Flags:
|
||||
-h, --help help for ctrld
|
||||
|
||||
310
cmd/cli/cli.go
310
cmd/cli/cli.go
@@ -48,6 +48,11 @@ import (
|
||||
|
||||
// selfCheckInternalTestDomain is used for testing ctrld self response to clients.
|
||||
const selfCheckInternalTestDomain = "ctrld" + loopTestDomain
|
||||
const (
|
||||
windowsForwardersFilename = ".forwarders.txt"
|
||||
oldBinSuffix = "_previous"
|
||||
oldLogSuffix = ".1"
|
||||
)
|
||||
|
||||
var (
|
||||
version = "dev"
|
||||
@@ -110,7 +115,7 @@ func initCLI() {
|
||||
&verbose,
|
||||
"verbose",
|
||||
"v",
|
||||
`verbose log output, "-v" basic logging, "-vv" debug level logging`,
|
||||
`verbose log output, "-v" basic logging, "-vv" debug logging`,
|
||||
)
|
||||
rootCmd.PersistentFlags().BoolVarP(
|
||||
&silent,
|
||||
@@ -158,7 +163,10 @@ func initCLI() {
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Install and start the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Long: `Install and start the ctrld service
|
||||
|
||||
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
checkStrFlagEmpty(cmd, cdUidFlagName)
|
||||
checkStrFlagEmpty(cmd, cdOrgFlagName)
|
||||
@@ -182,8 +190,9 @@ func initCLI() {
|
||||
return
|
||||
}
|
||||
|
||||
status, _ := s.Status()
|
||||
status, err := s.Status()
|
||||
isCtrldRunning := status == service.StatusRunning
|
||||
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
|
||||
|
||||
// If pin code was set, do not allow running start command.
|
||||
if isCtrldRunning {
|
||||
@@ -192,39 +201,56 @@ func initCLI() {
|
||||
}
|
||||
}
|
||||
|
||||
if cdUID != "" {
|
||||
rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID)
|
||||
if !startOnly {
|
||||
startOnly = len(osArgs) == 0
|
||||
}
|
||||
// If user run "ctrld start" and ctrld is already installed, starting existing service.
|
||||
if startOnly && isCtrldInstalled {
|
||||
tryReadingConfigWithNotice(false, true)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
// validateCdRemoteConfig clobbers v, saving it here to restore later.
|
||||
oldV := v
|
||||
if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil {
|
||||
if errors.As(err, &viper.ConfigParseError{}) {
|
||||
if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 {
|
||||
tmpDir := os.TempDir()
|
||||
tmpConfFile := filepath.Join(tmpDir, "ctrld.toml")
|
||||
errorLogged := false
|
||||
// Write remote config to a temporary file to get details error.
|
||||
if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil {
|
||||
if de := decoderErrorFromTomlFile(tmpConfFile); de != nil {
|
||||
row, col := de.Position()
|
||||
mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error())
|
||||
errorLogged = true
|
||||
}
|
||||
_ = os.Remove(tmpConfFile)
|
||||
}
|
||||
// If we could not log details error, emit what we have already got.
|
||||
if !errorLogged {
|
||||
mainLog.Load().Error().Msgf("failed to parse custom config: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err)
|
||||
|
||||
initLogging()
|
||||
tasks := []task{
|
||||
resetDnsTask(p, s),
|
||||
{s.Stop, false},
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error {
|
||||
return saveCurrentStaticDNS(i)
|
||||
})
|
||||
return nil
|
||||
}, false},
|
||||
{s.Start, true},
|
||||
{noticeWritingControlDConfig, false},
|
||||
}
|
||||
mainLog.Load().Notice().Msg("Starting existing ctrld service")
|
||||
if doTasks(tasks) {
|
||||
mainLog.Load().Notice().Msg("Service started")
|
||||
sockDir, err := socketDir()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Failed to get socket directory")
|
||||
os.Exit(1)
|
||||
}
|
||||
mainLog.Load().Warn().Msg("disregarding invalid custom config")
|
||||
if cc := newSocketControlClient(s, sockDir); cc != nil {
|
||||
if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK {
|
||||
if iface == "auto" {
|
||||
iface = defaultIfaceName()
|
||||
}
|
||||
logger := mainLog.Load().With().Str("iface", iface).Logger()
|
||||
logger.Debug().Msg("setting DNS successfully")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to start existing ctrld service")
|
||||
os.Exit(1)
|
||||
}
|
||||
v = oldV
|
||||
return
|
||||
}
|
||||
|
||||
if cdUID != "" {
|
||||
doValidateCdRemoteConfig(cdUID)
|
||||
} else if uid := cdUIDFromProvToken(); uid != "" {
|
||||
cdUID = uid
|
||||
mainLog.Load().Debug().Msg("using uid from provision token")
|
||||
@@ -399,6 +425,8 @@ func initCLI() {
|
||||
startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id")
|
||||
startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
|
||||
startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`)
|
||||
startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service")
|
||||
_ = startCmd.Flags().MarkHidden("start_only")
|
||||
|
||||
routerCmd := &cobra.Command{
|
||||
Use: "setup",
|
||||
@@ -485,7 +513,10 @@ func initCLI() {
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
p := &prog{router: router.New(&cfg, runInCdMode())}
|
||||
cdUID = curCdUID()
|
||||
cdMode := cdUID != ""
|
||||
|
||||
p := &prog{router: router.New(&cfg, cdMode)}
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
@@ -497,6 +528,10 @@ func initCLI() {
|
||||
}
|
||||
initLogging()
|
||||
|
||||
if cdMode {
|
||||
doValidateCdRemoteConfig(cdUID)
|
||||
}
|
||||
|
||||
iface = runningIface(s)
|
||||
tasks := []task{
|
||||
{s.Stop, false},
|
||||
@@ -623,11 +658,72 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
uninstall(p, s)
|
||||
if cleanup {
|
||||
var files []string
|
||||
// Config file.
|
||||
files = append(files, v.ConfigFileUsed())
|
||||
// Log file.
|
||||
logFile := normalizeLogFilePath(cfg.Service.LogPath)
|
||||
files = append(files, logFile)
|
||||
// Backup log file.
|
||||
oldLogFile := logFile + oldLogSuffix
|
||||
if _, err := os.Stat(oldLogFile); err == nil {
|
||||
files = append(files, oldLogFile)
|
||||
}
|
||||
// Socket files.
|
||||
if dir, _ := socketDir(); dir != "" {
|
||||
files = append(files, filepath.Join(dir, ctrldControlUnixSock))
|
||||
files = append(files, filepath.Join(dir, ctrldLogUnixSock))
|
||||
}
|
||||
// Static DNS settings files.
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
file := savedStaticDnsSettingsFilePath(i)
|
||||
if _, err := os.Stat(file); err == nil {
|
||||
files = append(files, file)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// Windows forwarders file.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
files = append(files, absHomeDir(windowsForwardersFilename))
|
||||
}
|
||||
// Binary itself.
|
||||
bin, _ := os.Executable()
|
||||
if bin != "" && supportedSelfDelete {
|
||||
files = append(files, bin)
|
||||
}
|
||||
// Backup file after upgrading.
|
||||
oldBin := bin + oldBinSuffix
|
||||
if _, err := os.Stat(oldBin); err == nil {
|
||||
files = append(files, oldBin)
|
||||
}
|
||||
for _, file := range files {
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
if err := os.Remove(file); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to remove file")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("file removed: %s", file)
|
||||
}
|
||||
}
|
||||
if err := selfDeleteExe(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to remove file")
|
||||
} else {
|
||||
if !supportedSelfDelete {
|
||||
mainLog.Load().Debug().Msgf("file removed: %s", bin)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`)
|
||||
uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for uninstalling ctrld`)
|
||||
_ = uninstallCmd.Flags().MarkHidden("pin")
|
||||
uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`)
|
||||
|
||||
listIfacesCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
@@ -697,7 +793,13 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Quick start service and configure DNS on interface",
|
||||
Long: `Quick start service and configure DNS on interface
|
||||
|
||||
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if len(os.Args) == 2 {
|
||||
startOnly = true
|
||||
}
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
@@ -776,7 +878,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
},
|
||||
}
|
||||
uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
uninstallCmdAlias.Flags().AddFlagSet(stopCmd.Flags())
|
||||
uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags())
|
||||
rootCmd.AddCommand(uninstallCmdAlias)
|
||||
|
||||
listClientsCmd := &cobra.Command{
|
||||
@@ -894,7 +996,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
svcInstalled = false
|
||||
}
|
||||
oldBin := bin + "_previous"
|
||||
oldBin := bin + oldBinSuffix
|
||||
baseUrl := upgradeChannel[upgradeChannelDefault]
|
||||
if len(args) > 0 {
|
||||
channel := args[0]
|
||||
@@ -1033,12 +1135,14 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
}
|
||||
waitCh := make(chan struct{})
|
||||
p := &prog{
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
reloadCh: make(chan struct{}),
|
||||
reloadDoneCh: make(chan struct{}),
|
||||
cfg: &cfg,
|
||||
appCallback: appCallback,
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
reloadCh: make(chan struct{}),
|
||||
reloadDoneCh: make(chan struct{}),
|
||||
dnsWatcherStopCh: make(chan struct{}),
|
||||
apiReloadCh: make(chan *ctrld.Config),
|
||||
cfg: &cfg,
|
||||
appCallback: appCallback,
|
||||
}
|
||||
if homedir == "" {
|
||||
if dir, err := userHomeDir(); err == nil {
|
||||
@@ -1128,36 +1232,13 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
return
|
||||
}
|
||||
|
||||
uninstallIfInvalidCdUID := func() {
|
||||
cdLogger := mainLog.Load().With().Str("mode", "cd").Logger()
|
||||
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
s, err := newService(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
cdLogger.Warn().Err(err).Msg("failed to create new service")
|
||||
return
|
||||
}
|
||||
if netIface, _ := netInterface(iface); netIface != nil {
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
cdLogger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return
|
||||
}
|
||||
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
cdLogger.Warn().Err(err).Msg("something went wrong while restoring DNS")
|
||||
} else {
|
||||
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
|
||||
}
|
||||
}
|
||||
|
||||
tasks := []task{{s.Uninstall, true}}
|
||||
if doTasks(tasks) {
|
||||
cdLogger.Info().Msg("uninstalled service")
|
||||
}
|
||||
cdLogger.Fatal().Err(uer).Msg("failed to fetch resolver config")
|
||||
return
|
||||
}
|
||||
cdLogger := mainLog.Load().With().Str("mode", "cd").Logger()
|
||||
// Performs self-uninstallation if the ControlD device does not exist.
|
||||
var uer *controld.UtilityErrorResponse
|
||||
if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
_ = uninstallInvalidCdUID(p, cdLogger, false)
|
||||
}
|
||||
uninstallIfInvalidCdUID()
|
||||
cdLogger.Fatal().Err(err).Msg("failed to fetch resolver config")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1168,7 +1249,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := writeConfigFile(); err != nil {
|
||||
if err := writeConfigFile(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to write config file")
|
||||
} else {
|
||||
mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile)
|
||||
@@ -1257,12 +1338,16 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
|
||||
close(waitCh)
|
||||
<-stopCh
|
||||
|
||||
// Wait goroutines which watches/manipulates DNS settings terminated,
|
||||
// ensuring that changes to DNS since here won't be reverted.
|
||||
p.dnsWg.Wait()
|
||||
for _, f := range p.onStopped {
|
||||
f()
|
||||
}
|
||||
}
|
||||
|
||||
func writeConfigFile() error {
|
||||
func writeConfigFile(cfg *ctrld.Config) error {
|
||||
if cfu := v.ConfigFileUsed(); cfu != "" {
|
||||
defaultConfigFile = cfu
|
||||
} else if configPath != "" {
|
||||
@@ -1315,7 +1400,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
|
||||
}
|
||||
nop := zerolog.Nop()
|
||||
_, _ = tryUpdateListenerConfig(&cfg, &nop, true)
|
||||
if err := writeConfigFile(); err != nil {
|
||||
if err := writeConfigFile(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
|
||||
} else {
|
||||
fp, err := filepath.Abs(defaultConfigFile)
|
||||
@@ -1639,9 +1724,10 @@ func selfCheckStatus(s service.Service, homedir, sockDir string) (bool, service.
|
||||
}
|
||||
|
||||
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
ctrld.SetConfigNameWithPath(v, "ctrld", homedir)
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
} else {
|
||||
v.SetConfigFile(defaultConfigFile)
|
||||
}
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("failed to re-read configuration file: %s", v.ConfigFileUsed())
|
||||
@@ -2375,7 +2461,7 @@ func doGenerateNextDNSConfig(uid string) error {
|
||||
mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile)
|
||||
generateNextDNSConfig(uid)
|
||||
updateListenerConfig(&cfg)
|
||||
return writeConfigFile()
|
||||
return writeConfigFile(&cfg)
|
||||
}
|
||||
|
||||
func noticeWritingControlDConfig() error {
|
||||
@@ -2423,7 +2509,7 @@ func checkDeactivationPin(s service.Service, stopCh chan struct{}) error {
|
||||
return nil // the server is running older version of ctrld
|
||||
}
|
||||
}
|
||||
mainLog.Load().Error().Msg(errInvalidDeactivationPin.Error())
|
||||
mainLog.Load().Error().Err(err).Msg(errInvalidDeactivationPin.Error())
|
||||
return errInvalidDeactivationPin
|
||||
}
|
||||
|
||||
@@ -2482,6 +2568,11 @@ func absHomeDir(filename string) string {
|
||||
|
||||
// runInCdMode reports whether ctrld service is running in cd mode.
|
||||
func runInCdMode() bool {
|
||||
return curCdUID() != ""
|
||||
}
|
||||
|
||||
// curCdUID returns the current ControlD UID used by running ctrld process.
|
||||
func curCdUID() string {
|
||||
if s, _ := newService(&prog{}, svcConfig); s != nil {
|
||||
if dir, _ := socketDir(); dir != "" {
|
||||
cc := newSocketControlClient(s, dir)
|
||||
@@ -2489,12 +2580,13 @@ func runInCdMode() bool {
|
||||
resp, _ := cc.post(cdPath, nil)
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode == http.StatusOK
|
||||
buf, _ := io.ReadAll(resp.Body)
|
||||
return string(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
|
||||
// goArm returns the GOARM value for the binary.
|
||||
@@ -2557,6 +2649,9 @@ func resetDnsTask(p *prog, s service.Service) task {
|
||||
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
|
||||
isCtrldRunning := status == service.StatusRunning
|
||||
return task{func() error {
|
||||
if iface == "" {
|
||||
return nil
|
||||
}
|
||||
// Always reset DNS first, ensuring DNS setting is in a good state.
|
||||
// resetDNS must use the "iface" value of current running ctrld
|
||||
// process to reset what setDNS has done properly.
|
||||
@@ -2572,3 +2667,60 @@ func resetDnsTask(p *prog, s service.Service) task {
|
||||
return nil
|
||||
}, false}
|
||||
}
|
||||
|
||||
// doValidateCdRemoteConfig fetches and validates custom config for cdUID.
|
||||
func doValidateCdRemoteConfig(cdUID string) {
|
||||
rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID)
|
||||
}
|
||||
// validateCdRemoteConfig clobbers v, saving it here to restore later.
|
||||
oldV := v
|
||||
if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil {
|
||||
if errors.As(err, &viper.ConfigParseError{}) {
|
||||
if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 {
|
||||
tmpDir := os.TempDir()
|
||||
tmpConfFile := filepath.Join(tmpDir, "ctrld.toml")
|
||||
errorLogged := false
|
||||
// Write remote config to a temporary file to get details error.
|
||||
if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil {
|
||||
if de := decoderErrorFromTomlFile(tmpConfFile); de != nil {
|
||||
row, col := de.Position()
|
||||
mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error())
|
||||
errorLogged = true
|
||||
}
|
||||
_ = os.Remove(tmpConfFile)
|
||||
}
|
||||
// If we could not log details error, emit what we have already got.
|
||||
if !errorLogged {
|
||||
mainLog.Load().Error().Msgf("failed to parse custom config: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err)
|
||||
}
|
||||
mainLog.Load().Warn().Msg("disregarding invalid custom config")
|
||||
}
|
||||
v = oldV
|
||||
}
|
||||
|
||||
// uninstallInvalidCdUID performs self-uninstallation because the ControlD device does not exist.
|
||||
func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool {
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("failed to create new service")
|
||||
return false
|
||||
}
|
||||
|
||||
p.resetDNS()
|
||||
|
||||
tasks := []task{{s.Uninstall, true}}
|
||||
if doTasks(tasks) {
|
||||
logger.Info().Msg("uninstalled service")
|
||||
if doStop {
|
||||
_ = s.Stop()
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ func Test_writeConfigFile(t *testing.T) {
|
||||
_, err := os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
assert.NoError(t, writeConfigFile())
|
||||
assert.NoError(t, writeConfigFile(&cfg))
|
||||
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -73,7 +73,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
sort.Slice(clients, func(i, j int) bool {
|
||||
return clients[i].IP.Less(clients[j].IP)
|
||||
})
|
||||
if p.cfg.Service.MetricsQueryStats {
|
||||
if p.metricsQueryStats.Load() {
|
||||
for _, client := range clients {
|
||||
client.IncludeQueryCount = true
|
||||
dm := &dto.Metric{}
|
||||
@@ -178,6 +178,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if cdUID != "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(cdUID))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
@@ -32,6 +33,9 @@ const (
|
||||
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
|
||||
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
|
||||
EDNS0_OPTION_MAC = 0xFDE9
|
||||
|
||||
// selfUninstallMaxQueries is number of REFUSED queries seen before checking for self-uninstallation.
|
||||
selfUninstallMaxQueries = 32
|
||||
)
|
||||
|
||||
var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
@@ -89,6 +93,7 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
}
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
reqId := requestID()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
|
||||
@@ -143,6 +148,7 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
failoverRcodes: failoverRcode,
|
||||
ufr: ur,
|
||||
})
|
||||
go p.doSelfUninstall(pr.answer)
|
||||
answer = pr.answer
|
||||
rtt := time.Since(t)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
||||
@@ -836,6 +842,51 @@ func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
// doSelfUninstall performs self-uninstall if these condition met:
|
||||
//
|
||||
// - There is only 1 ControlD upstream in-use.
|
||||
// - Number of refused queries seen so far equals to selfUninstallMaxQueries.
|
||||
// - The cdUID is deleted.
|
||||
func (p *prog) doSelfUninstall(answer *dns.Msg) {
|
||||
if !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused {
|
||||
return
|
||||
}
|
||||
|
||||
p.selfUninstallMu.Lock()
|
||||
defer p.selfUninstallMu.Unlock()
|
||||
if p.checkingSelfUninstall {
|
||||
return
|
||||
}
|
||||
|
||||
logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger()
|
||||
if p.refusedQueryCount > selfUninstallMaxQueries {
|
||||
p.checkingSelfUninstall = true
|
||||
_, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
logger.Debug().Msg("maximum number of refused queries reached, checking device status")
|
||||
selfUninstallCheck(err, p, logger)
|
||||
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
||||
}
|
||||
// Cool-of period to prevent abusing the API.
|
||||
go p.selfUninstallCoolOfPeriod()
|
||||
return
|
||||
}
|
||||
p.refusedQueryCount++
|
||||
}
|
||||
|
||||
// selfUninstallCoolOfPeriod waits for 30 minutes before
|
||||
// calling API again for checking ControlD device status.
|
||||
func (p *prog) selfUninstallCoolOfPeriod() {
|
||||
t := time.NewTimer(time.Minute * 30)
|
||||
defer t.Stop()
|
||||
<-t.C
|
||||
p.selfUninstallMu.Lock()
|
||||
p.checkingSelfUninstall = false
|
||||
p.refusedQueryCount = 0
|
||||
p.selfUninstallMu.Unlock()
|
||||
}
|
||||
|
||||
// queryFromSelf reports whether the input IP is from device running ctrld.
|
||||
func queryFromSelf(ip string) bool {
|
||||
netIP := netip.MustParseAddr(ip)
|
||||
|
||||
@@ -36,6 +36,8 @@ var (
|
||||
cdUpstreamProto string
|
||||
deactivationPin int64
|
||||
skipSelfChecks bool
|
||||
cleanup bool
|
||||
startOnly bool
|
||||
|
||||
mainLog atomic.Pointer[zerolog.Logger]
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
@@ -63,8 +65,11 @@ func Main() {
|
||||
}
|
||||
|
||||
func normalizeLogFilePath(logFilePath string) string {
|
||||
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
|
||||
return logFilePath
|
||||
// In cleanup mode, we always want the full log file path.
|
||||
if !cleanup {
|
||||
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
|
||||
return logFilePath
|
||||
}
|
||||
}
|
||||
if homedir != "" {
|
||||
return filepath.Join(homedir, logFilePath)
|
||||
@@ -121,14 +126,14 @@ func initLoggingWithBackup(doBackup bool) {
|
||||
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
|
||||
if doBackup {
|
||||
// Backup old log file with .1 suffix.
|
||||
if err := os.Rename(logFilePath, logFilePath+".1"); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) {
|
||||
mainLog.Load().Error().Msgf("could not backup old log file: %v", err)
|
||||
} else {
|
||||
// Backup was created, set flags for truncating old log file.
|
||||
flags = os.O_CREATE | os.O_RDWR
|
||||
}
|
||||
}
|
||||
logFile, err := os.OpenFile(logFilePath, flags, os.FileMode(0o600))
|
||||
logFile, err := openLogFile(logFilePath, flags)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msgf("failed to create log file: %v", err)
|
||||
os.Exit(1)
|
||||
|
||||
@@ -107,7 +107,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
// Register queries count stats if enabled.
|
||||
if cfg.Service.MetricsQueryStats {
|
||||
if p.metricsQueryStats.Load() {
|
||||
reg.MustRegister(statsQueriesCount)
|
||||
reg.MustRegister(statsClientQueriesCount)
|
||||
}
|
||||
|
||||
@@ -43,20 +43,32 @@ func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one, which includes:
|
||||
//
|
||||
// - en0: physical wireless
|
||||
// - en1: Thunderbolt 1
|
||||
// - en2: Thunderbolt 2
|
||||
// - en3: Thunderbolt 3
|
||||
// - en4: Thunderbolt 4
|
||||
//
|
||||
// For full list, see: https://unix.stackexchange.com/questions/603506/what-are-these-ifconfig-interfaces-on-macos
|
||||
func validInterface(iface *net.Interface) bool {
|
||||
switch iface.Name {
|
||||
case "en0", "en1", "en2", "en3", "en4":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
b, err := exec.Command("networksetup", "-listallhardwareports").Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return parseListAllHardwarePorts(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
// parseListAllHardwarePorts parses output of "networksetup -listallhardwareports"
|
||||
// and returns map presents all hardware ports.
|
||||
func parseListAllHardwarePorts(r io.Reader) map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
after, ok := strings.CutPrefix(line, "Device: ")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
m[after] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -57,3 +58,47 @@ func Test_networkServiceName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const listallhardwareportsOutput = `
|
||||
Hardware Port: Ethernet Adapter (en6)
|
||||
Device: en6
|
||||
Ethernet Address: 3a:3e:fc:1e:ab:41
|
||||
|
||||
Hardware Port: Ethernet Adapter (en7)
|
||||
Device: en7
|
||||
Ethernet Address: 3a:3e:fc:1e:ab:42
|
||||
|
||||
Hardware Port: Thunderbolt Bridge
|
||||
Device: bridge0
|
||||
Ethernet Address: 36:21:bb:3a:7a:40
|
||||
|
||||
Hardware Port: Wi-Fi
|
||||
Device: en0
|
||||
Ethernet Address: a0:78:17:68:56:3f
|
||||
|
||||
Hardware Port: Thunderbolt 1
|
||||
Device: en1
|
||||
Ethernet Address: 36:21:bb:3a:7a:40
|
||||
|
||||
Hardware Port: Thunderbolt 2
|
||||
Device: en2
|
||||
Ethernet Address: 36:21:bb:3a:7a:44
|
||||
|
||||
VLAN Configurations
|
||||
===================
|
||||
`
|
||||
|
||||
func Test_parseListAllHardwarePorts(t *testing.T) {
|
||||
expected := map[string]struct{}{
|
||||
"en0": {},
|
||||
"en1": {},
|
||||
"en2": {},
|
||||
"en6": {},
|
||||
"en7": {},
|
||||
"bridge0": {},
|
||||
}
|
||||
m := parseListAllHardwarePorts(strings.NewReader(listallhardwareportsOutput))
|
||||
if !maps.Equal(m, expected) {
|
||||
t.Errorf("unexpected output, want: %v, got: %v", expected, m)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,4 +6,6 @@ import "net"
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error { return nil }
|
||||
|
||||
func validInterface(iface *net.Interface) bool { return true }
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
|
||||
|
||||
func validInterfacesMap() map[string]struct{} { return nil }
|
||||
|
||||
@@ -10,7 +10,7 @@ func patchNetIfaceName(iface *net.Interface) error {
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// On Windows, only physical interfaces are considered valid.
|
||||
func validInterface(iface *net.Interface) bool {
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
if iface == nil {
|
||||
return false
|
||||
}
|
||||
@@ -19,3 +19,5 @@ func validInterface(iface *net.Interface) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validInterfacesMap() map[string]struct{} { return nil }
|
||||
|
||||
@@ -24,6 +24,8 @@ import (
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
|
||||
// allocate loopback ip
|
||||
// sudo ip a add 127.0.0.2/24 dev lo
|
||||
func allocateIP(ip string) error {
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
forwardersFilename = ".forwarders.txt"
|
||||
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
@@ -40,7 +39,7 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||
// Configuring the Dns server to forward queries to ctrld instead.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
file := absHomeDir(forwardersFilename)
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
oldForwardersContent, _ := os.ReadFile(file)
|
||||
if err := os.WriteFile(file, []byte(strings.Join(nameservers, ",")), 0600); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
||||
@@ -72,7 +71,7 @@ func resetDNS(iface *net.Interface) error {
|
||||
resetDNSOnce.Do(func() {
|
||||
// See corresponding comment in setDNS.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
file := absHomeDir(forwardersFilename)
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
|
||||
|
||||
248
cmd/cli/prog.go
248
cmd/cli/prog.go
@@ -12,19 +12,24 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/spf13/viper"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
@@ -38,6 +43,7 @@ const (
|
||||
upstreamPrefix = "upstream."
|
||||
upstreamOS = upstreamPrefix + "os"
|
||||
upstreamPrivate = upstreamPrefix + "private"
|
||||
dnsWatchdogDefaultInterval = 20 * time.Second
|
||||
)
|
||||
|
||||
// ControlSocketName returns name for control unix socket.
|
||||
@@ -62,15 +68,19 @@ var svcConfig = &service.Config{
|
||||
var useSystemdResolved = false
|
||||
|
||||
type prog struct {
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
reloadCh chan struct{} // For Windows.
|
||||
reloadDoneCh chan struct{}
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
csSetDnsDone chan struct{}
|
||||
csSetDnsOk bool
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
reloadCh chan struct{} // For Windows.
|
||||
reloadDoneCh chan struct{}
|
||||
apiReloadCh chan *ctrld.Config
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
csSetDnsDone chan struct{}
|
||||
csSetDnsOk bool
|
||||
dnsWatchDogOnce sync.Once
|
||||
dnsWg sync.WaitGroup
|
||||
dnsWatcherStopCh chan struct{}
|
||||
|
||||
cfg *ctrld.Config
|
||||
localUpstreams []string
|
||||
@@ -84,6 +94,12 @@ type prog struct {
|
||||
router router.Router
|
||||
ptrLoopGuard *loopGuard
|
||||
lanLoopGuard *loopGuard
|
||||
metricsQueryStats atomic.Bool
|
||||
|
||||
selfUninstallMu sync.Mutex
|
||||
refusedQueryCount int
|
||||
canSelfUninstall atomic.Bool
|
||||
checkingSelfUninstall bool
|
||||
|
||||
loopMu sync.Mutex
|
||||
loop map[string]bool
|
||||
@@ -117,11 +133,15 @@ func (p *prog) runWait() {
|
||||
p.run(reload, reloadCh)
|
||||
reload = true
|
||||
}()
|
||||
|
||||
var newCfg *ctrld.Config
|
||||
select {
|
||||
case sig := <-reloadSigCh:
|
||||
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
|
||||
case <-p.reloadCh:
|
||||
logger.Notice().Msg("reloading...")
|
||||
case apiCfg := <-p.apiReloadCh:
|
||||
newCfg = apiCfg
|
||||
case <-p.stopCh:
|
||||
close(reloadCh)
|
||||
return
|
||||
@@ -131,28 +151,31 @@ func (p *prog) runWait() {
|
||||
close(reloadCh)
|
||||
<-done
|
||||
}
|
||||
newCfg := &ctrld.Config{}
|
||||
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
logger.Err(err).Msg("could not read new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if err := v.Unmarshal(&newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not unmarshal new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if cdUID != "" {
|
||||
if err := processCDFlags(newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not fetch ControlD config")
|
||||
|
||||
if newCfg == nil {
|
||||
newCfg = &ctrld.Config{}
|
||||
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
logger.Err(err).Msg("could not read new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if err := v.Unmarshal(&newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not unmarshal new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if cdUID != "" {
|
||||
if err := processCDFlags(newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not fetch ControlD config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
waitOldRunDone()
|
||||
@@ -178,6 +201,10 @@ func (p *prog) runWait() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := writeConfigFile(newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not write new config")
|
||||
}
|
||||
|
||||
// This needs to be done here, otherwise, the DNS handler may observe an invalid
|
||||
// upstream config because its initialization function have not been called yet.
|
||||
mainLog.Load().Debug().Msg("setup upstream with new config")
|
||||
@@ -188,6 +215,7 @@ func (p *prog) runWait() {
|
||||
p.mu.Unlock()
|
||||
|
||||
logger.Notice().Msg("reloading config successfully")
|
||||
|
||||
select {
|
||||
case p.reloadDoneCh <- struct{}{}:
|
||||
default:
|
||||
@@ -214,12 +242,67 @@ func (p *prog) postRun() {
|
||||
}
|
||||
}
|
||||
|
||||
// apiConfigReload calls API to check for latest config update then reload ctrld if necessary.
|
||||
func (p *prog) apiConfigReload() {
|
||||
if cdUID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
secs := 3600
|
||||
if p.cfg.Service.RefetchTime != nil && *p.cfg.Service.RefetchTime > 0 {
|
||||
secs = *p.cfg.Service.RefetchTime
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(time.Duration(secs) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
logger := mainLog.Load().With().Str("mode", "api-reload").Logger()
|
||||
logger.Debug().Msg("starting custom config reload timer")
|
||||
lastUpdated := time.Now().Unix()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
selfUninstallCheck(err, p, logger)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
||||
continue
|
||||
}
|
||||
|
||||
if resolverConfig.Ctrld.CustomConfig == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated {
|
||||
lastUpdated = time.Now().Unix()
|
||||
cfg := &ctrld.Config{}
|
||||
if err := validateCdRemoteConfig(resolverConfig, cfg); err != nil {
|
||||
logger.Warn().Err(err).Msg("skipping invalid custom config")
|
||||
if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil {
|
||||
logger.Error().Err(err).Msg("could not mark custom last update failed")
|
||||
}
|
||||
break
|
||||
}
|
||||
setListenerDefaultValue(cfg)
|
||||
logger.Debug().Msg("custom config changes detected, reloading...")
|
||||
p.apiReloadCh <- cfg
|
||||
} else {
|
||||
logger.Debug().Msg("custom config does not change")
|
||||
}
|
||||
case <-p.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
||||
localUpstreams := make([]string, 0, len(cfg.Upstream))
|
||||
ptrNameservers := make([]string, 0, len(cfg.Upstream))
|
||||
isControlDUpstream := false
|
||||
for n := range cfg.Upstream {
|
||||
uc := cfg.Upstream[n]
|
||||
uc.Init()
|
||||
isControlDUpstream = isControlDUpstream || uc.IsControlD()
|
||||
if uc.BootstrapIP == "" {
|
||||
uc.SetupBootstrapIP()
|
||||
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
||||
@@ -236,6 +319,10 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
||||
ptrNameservers = append(ptrNameservers, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
// Self-uninstallation is ok If there is only 1 ControlD upstream, and no remote config.
|
||||
if len(cfg.Upstream) == 1 && isControlDUpstream {
|
||||
p.canSelfUninstall.Store(true)
|
||||
}
|
||||
p.localUpstreams = localUpstreams
|
||||
p.ptrNameservers = ptrNameservers
|
||||
}
|
||||
@@ -271,6 +358,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
p.cacheFlushDomainsMap = nil
|
||||
p.metricsQueryStats.Store(p.cfg.Service.MetricsQueryStats)
|
||||
if p.cfg.Service.CacheEnable {
|
||||
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
||||
if err != nil {
|
||||
@@ -397,6 +485,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
if p.logConn != nil {
|
||||
_ = p.logConn.Close()
|
||||
}
|
||||
go p.apiConfigReload()
|
||||
p.postRun()
|
||||
}
|
||||
wg.Wait()
|
||||
@@ -510,13 +599,86 @@ func (p *prog) setDNS() {
|
||||
for i := range nameservers {
|
||||
servers[i] = netip.MustParseAddr(nameservers[i])
|
||||
}
|
||||
go watchResolvConf(netIface, servers, setResolvConf)
|
||||
p.dnsWg.Add(1)
|
||||
go func() {
|
||||
defer p.dnsWg.Done()
|
||||
p.watchResolvConf(netIface, servers, setResolvConf)
|
||||
}()
|
||||
}
|
||||
if allIfaces {
|
||||
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
|
||||
return setDnsIgnoreUnusableInterface(i, nameservers)
|
||||
})
|
||||
}
|
||||
if p.dnsWatchdogEnabled() {
|
||||
p.dnsWg.Add(1)
|
||||
go func() {
|
||||
defer p.dnsWg.Done()
|
||||
p.dnsWatchdog(netIface, nameservers, allIfaces)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// dnsWatchdogEnabled reports whether DNS watchdog is enabled.
|
||||
func (p *prog) dnsWatchdogEnabled() bool {
|
||||
if ptr := p.cfg.Service.DnsWatchdogEnabled; ptr != nil {
|
||||
return *ptr
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// dnsWatchdogDuration returns the time duration between each DNS watchdog loop.
|
||||
func (p *prog) dnsWatchdogDuration() time.Duration {
|
||||
if ptr := p.cfg.Service.DnsWatchdogInvterval; ptr != nil {
|
||||
if (*ptr).Seconds() > 0 {
|
||||
return *ptr
|
||||
}
|
||||
}
|
||||
return dnsWatchdogDefaultInterval
|
||||
}
|
||||
|
||||
// dnsWatchdog watches for DNS changes on Darwin and Windows then re-applying ctrld's settings.
|
||||
// This is only works when deactivation pin set.
|
||||
func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces bool) {
|
||||
if !requiredMultiNICsConfig() {
|
||||
return
|
||||
}
|
||||
|
||||
p.dnsWatchDogOnce.Do(func() {
|
||||
mainLog.Load().Debug().Msg("start DNS settings watchdog")
|
||||
ns := nameservers
|
||||
slices.Sort(ns)
|
||||
ticker := time.NewTicker(p.dnsWatchdogDuration())
|
||||
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
|
||||
for {
|
||||
select {
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-p.stopCh:
|
||||
mainLog.Load().Debug().Msg("stop dns watchdog")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if dnsChanged(iface, ns) {
|
||||
logger.Debug().Msg("DNS settings were changed, re-applying settings")
|
||||
if err := setDNS(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
|
||||
}
|
||||
}
|
||||
if allIfaces {
|
||||
withEachPhysicalInterfaces(iface.Name, "re-applying DNS", func(i *net.Interface) error {
|
||||
if dnsChanged(i, ns) {
|
||||
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *prog) resetDNS() {
|
||||
@@ -727,13 +889,14 @@ func canBeLocalUpstream(addr string) bool {
|
||||
// the interface that matches excludeIfaceName. The context is used to clarify the
|
||||
// log message when error happens.
|
||||
func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) {
|
||||
validIfacesMap := validInterfacesMap()
|
||||
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
// Skip loopback/virtual interface.
|
||||
if i.IsLoopback() || len(i.HardwareAddr) == 0 {
|
||||
return
|
||||
}
|
||||
// Skip invalid interface.
|
||||
if !validInterface(i.Interface) {
|
||||
if !validInterface(i.Interface, validIfacesMap) {
|
||||
return
|
||||
}
|
||||
netIface := i.Interface
|
||||
@@ -747,7 +910,9 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.
|
||||
}
|
||||
// TODO: investigate whether we should report this error?
|
||||
if err := f(netIface); err == nil {
|
||||
mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name)
|
||||
if context != "" {
|
||||
mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name)
|
||||
}
|
||||
} else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) {
|
||||
mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name)
|
||||
}
|
||||
@@ -806,3 +971,24 @@ func savedStaticNameservers(iface *net.Interface) []string {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dnsChanged reports whether DNS settings for given interface was changed.
|
||||
// The caller must sort the nameservers before calling this function.
|
||||
func dnsChanged(iface *net.Interface, nameservers []string) bool {
|
||||
curNameservers, _ := currentStaticDNS(iface)
|
||||
slices.Sort(curNameservers)
|
||||
return !slices.Equal(curNameservers, nameservers)
|
||||
}
|
||||
|
||||
// selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then.
|
||||
func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
|
||||
var uer *controld.UtilityErrorResponse
|
||||
if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
// Ensure all DNS watchers goroutine are terminated, so it won't mess up with self-uninstall.
|
||||
close(p.dnsWatcherStopCh)
|
||||
p.dnsWg.Wait()
|
||||
|
||||
// Perform self-uninstall now.
|
||||
selfUninstall(p, logger)
|
||||
}
|
||||
}
|
||||
|
||||
57
cmd/cli/prog_test.go
Normal file
57
cmd/cli/prog_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_prog_dnsWatchdogEnabled(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{}}
|
||||
|
||||
// Default value is true.
|
||||
assert.True(t, p.dnsWatchdogEnabled())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
}{
|
||||
{"enabled", true},
|
||||
{"disabled", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p.cfg.Service.DnsWatchdogEnabled = &tc.enabled
|
||||
assert.Equal(t, tc.enabled, p.dnsWatchdogEnabled())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_dnsWatchdogInterval(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{}}
|
||||
|
||||
// Default value is 20s.
|
||||
assert.Equal(t, dnsWatchdogDefaultInterval, p.dnsWatchdogDuration())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
expected time.Duration
|
||||
}{
|
||||
{"valid", time.Minute, time.Minute},
|
||||
{"zero", 0, dnsWatchdogDefaultInterval},
|
||||
{"nagative", time.Duration(-1 * time.Minute), dnsWatchdogDefaultInterval},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p.cfg.Service.DnsWatchdogInvterval = &tc.duration
|
||||
assert.Equal(t, tc.expected, p.dnsWatchdogDuration())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -51,7 +51,7 @@ var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
|
||||
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
|
||||
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
|
||||
if p.cfg.Service.MetricsQueryStats {
|
||||
if p.metricsQueryStats.Load() {
|
||||
c.WithLabelValues(lvs...).Inc()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,15 +8,15 @@ import (
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
const (
|
||||
resolvConfPath = "/etc/resolv.conf"
|
||||
resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
)
|
||||
|
||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
func watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
mainLog.Load().Debug().Msg("start watching /etc/resolv.conf file")
|
||||
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
resolvConfPath := "/etc/resolv.conf"
|
||||
// Evaluating symbolics link to watch the target file that /etc/resolv.conf point to.
|
||||
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
|
||||
resolvConfPath = rp
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath)
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
||||
@@ -28,12 +28,17 @@ func watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface
|
||||
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
||||
watchDir := filepath.Dir(resolvConfPath)
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not add /etc/resolv.conf to watcher list")
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-p.stopCh:
|
||||
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
|
||||
@@ -3,15 +3,44 @@ package cli
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
servers := make([]string, len(ns))
|
||||
for i := range ns {
|
||||
servers[i] = ns[i].String()
|
||||
}
|
||||
return setDNS(iface, servers)
|
||||
if err := setDNS(iface, servers); err != nil {
|
||||
return err
|
||||
}
|
||||
slices.Sort(servers)
|
||||
curNs := currentDNS(iface)
|
||||
slices.Sort(curNs)
|
||||
if !slices.Equal(curNs, servers) {
|
||||
c, err := resolvconffile.ParseFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Nameservers = ns
|
||||
f, err := os.Create(resolvConfPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := c.Write(f); err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
|
||||
7
cmd/cli/self_delete_others.go
Normal file
7
cmd/cli/self_delete_others.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
var supportedSelfDelete = true
|
||||
|
||||
func selfDeleteExe() error { return nil }
|
||||
134
cmd/cli/self_delete_windows.go
Normal file
134
cmd/cli/self_delete_windows.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copied from https://github.com/secur30nly/go-self-delete
|
||||
// with modification to suitable for ctrld usage.
|
||||
|
||||
/*
|
||||
License: MIT Licence
|
||||
|
||||
References:
|
||||
- https://github.com/LloydLabs/delete-self-poc
|
||||
- https://twitter.com/jonasLyk/status/1350401461985955840
|
||||
*/
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var supportedSelfDelete = false
|
||||
|
||||
type FILE_RENAME_INFO struct {
|
||||
Union struct {
|
||||
ReplaceIfExists bool
|
||||
Flags uint32
|
||||
}
|
||||
RootDirectory windows.Handle
|
||||
FileNameLength uint32
|
||||
FileName [1]uint16
|
||||
}
|
||||
|
||||
type FILE_DISPOSITION_INFO struct {
|
||||
DeleteFile bool
|
||||
}
|
||||
|
||||
func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
|
||||
handle, err := windows.CreateFile(
|
||||
pwPath,
|
||||
windows.DELETE,
|
||||
0,
|
||||
nil,
|
||||
windows.OPEN_EXISTING,
|
||||
windows.FILE_ATTRIBUTE_NORMAL,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func dsRenameHandle(hHandle windows.Handle) error {
|
||||
var fRename FILE_RENAME_INFO
|
||||
DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lpwStream := &DS_STREAM_RENAME[0]
|
||||
fRename.FileNameLength = uint32(unsafe.Sizeof(lpwStream))
|
||||
|
||||
windows.NewLazyDLL("kernel32.dll").NewProc("RtlCopyMemory").Call(
|
||||
uintptr(unsafe.Pointer(&fRename.FileName[0])),
|
||||
uintptr(unsafe.Pointer(lpwStream)),
|
||||
unsafe.Sizeof(lpwStream),
|
||||
)
|
||||
|
||||
err = windows.SetFileInformationByHandle(
|
||||
hHandle,
|
||||
windows.FileRenameInfo,
|
||||
(*byte)(unsafe.Pointer(&fRename)),
|
||||
uint32(unsafe.Sizeof(fRename)+unsafe.Sizeof(lpwStream)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func dsDepositeHandle(hHandle windows.Handle) error {
|
||||
var fDelete FILE_DISPOSITION_INFO
|
||||
fDelete.DeleteFile = true
|
||||
|
||||
err := windows.SetFileInformationByHandle(
|
||||
hHandle,
|
||||
windows.FileDispositionInfo,
|
||||
(*byte)(unsafe.Pointer(&fDelete)),
|
||||
uint32(unsafe.Sizeof(fDelete)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func selfDeleteExe() error {
|
||||
var wcPath [windows.MAX_PATH + 1]uint16
|
||||
var hCurrent windows.Handle
|
||||
|
||||
_, err := windows.GetModuleFileName(0, &wcPath[0], windows.MAX_PATH)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||
if err != nil || hCurrent == windows.InvalidHandle {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dsRenameHandle(hCurrent); err != nil {
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
return err
|
||||
}
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
|
||||
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||
if err != nil || hCurrent == windows.InvalidHandle {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dsDepositeHandle(hCurrent); err != nil {
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
return err
|
||||
}
|
||||
|
||||
return windows.CloseHandle(hCurrent)
|
||||
}
|
||||
16
cmd/cli/self_kill_others.go
Normal file
16
cmd/cli/self_kill_others.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, false) {
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
45
cmd/cli/self_kill_unix.go
Normal file
45
cmd/cli/self_kill_unix.go
Normal file
@@ -0,0 +1,45 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||
if runtime.GOOS == "linux" {
|
||||
selfUninstallLinux(p, logger)
|
||||
}
|
||||
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("could not determine executable")
|
||||
}
|
||||
args := []string{"uninstall"}
|
||||
if !deactivationPinNotSet() {
|
||||
args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin))
|
||||
}
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
if err := cmd.Start(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("could not start self uninstall command")
|
||||
}
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
_ = cmd.Wait()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func selfUninstallLinux(p *prog, logger zerolog.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, true) {
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,9 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
|
||||
return &sysV{s}, nil
|
||||
case s.Platform() == "linux-systemd":
|
||||
return &systemd{s}, nil
|
||||
case s.Platform() == "darwin-launchd":
|
||||
return newLaunchd(s), nil
|
||||
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@@ -113,7 +116,7 @@ func (s *procd) Status() (service.Status, error) {
|
||||
return service.StatusRunning, nil
|
||||
}
|
||||
|
||||
// procd wraps a service.Service, and provide status command to
|
||||
// systemd wraps a service.Service, and provide status command to
|
||||
// report the status correctly.
|
||||
type systemd struct {
|
||||
service.Service
|
||||
@@ -127,6 +130,29 @@ func (s *systemd) Status() (service.Status, error) {
|
||||
return s.Service.Status()
|
||||
}
|
||||
|
||||
func newLaunchd(s service.Service) *launchd {
|
||||
return &launchd{
|
||||
Service: s,
|
||||
statusErrMsg: "Permission denied",
|
||||
}
|
||||
}
|
||||
|
||||
// launchd wraps a service.Service, and provide status command to
|
||||
// report the status correctly when not running as root on Darwin.
|
||||
//
|
||||
// TODO: remove this wrapper once https://github.com/kardianos/service/issues/400 fixed.
|
||||
type launchd struct {
|
||||
service.Service
|
||||
statusErrMsg string
|
||||
}
|
||||
|
||||
func (l *launchd) Status() (service.Status, error) {
|
||||
if os.Geteuid() != 0 {
|
||||
return service.StatusUnknown, errors.New(l.statusErrMsg)
|
||||
}
|
||||
return l.Service.Status()
|
||||
}
|
||||
|
||||
type task struct {
|
||||
f func() error
|
||||
abortOnError bool
|
||||
|
||||
@@ -9,3 +9,7 @@ import (
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
return os.Geteuid() == 0, nil
|
||||
}
|
||||
|
||||
func openLogFile(path string, flags int) (*os.File, error) {
|
||||
return os.OpenFile(path, flags, os.FileMode(0o600))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package cli
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
var sid *windows.SID
|
||||
@@ -22,3 +27,55 @@ func hasElevatedPrivilege() (bool, error) {
|
||||
token := windows.Token(0)
|
||||
return token.IsMember(sid)
|
||||
}
|
||||
|
||||
func openLogFile(path string, mode int) (*os.File, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
||||
}
|
||||
|
||||
pathP, err := syscall.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var access uint32
|
||||
switch mode & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) {
|
||||
case os.O_RDONLY:
|
||||
access = windows.GENERIC_READ
|
||||
case os.O_WRONLY:
|
||||
access = windows.GENERIC_WRITE
|
||||
case os.O_RDWR:
|
||||
access = windows.GENERIC_READ | windows.GENERIC_WRITE
|
||||
}
|
||||
if mode&os.O_CREATE != 0 {
|
||||
access |= windows.GENERIC_WRITE
|
||||
}
|
||||
if mode&os.O_APPEND != 0 {
|
||||
access &^= windows.GENERIC_WRITE
|
||||
access |= windows.FILE_APPEND_DATA
|
||||
}
|
||||
|
||||
shareMode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
|
||||
|
||||
var sa *syscall.SecurityAttributes
|
||||
|
||||
var createMode uint32
|
||||
switch {
|
||||
case mode&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL):
|
||||
createMode = windows.CREATE_NEW
|
||||
case mode&(os.O_CREATE|os.O_TRUNC) == (os.O_CREATE | os.O_TRUNC):
|
||||
createMode = windows.CREATE_ALWAYS
|
||||
case mode&os.O_CREATE == os.O_CREATE:
|
||||
createMode = windows.OPEN_ALWAYS
|
||||
case mode&os.O_TRUNC == os.O_TRUNC:
|
||||
createMode = windows.TRUNCATE_EXISTING
|
||||
default:
|
||||
createMode = windows.OPEN_EXISTING
|
||||
}
|
||||
|
||||
handle, err := syscall.CreateFile(pathP, access, shareMode, sa, createMode, syscall.FILE_ATTRIBUTE_NORMAL, 0)
|
||||
if err != nil {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: err}
|
||||
}
|
||||
|
||||
return os.NewFile(uintptr(handle), path), nil
|
||||
}
|
||||
|
||||
62
config.go
62
config.go
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/tsaddr"
|
||||
@@ -188,27 +189,30 @@ func (c *Config) FirstUpstream() *UpstreamConfig {
|
||||
|
||||
// ServiceConfig specifies the general ctrld config.
|
||||
type ServiceConfig struct {
|
||||
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
||||
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
||||
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
||||
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
||||
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
||||
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
||||
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
|
||||
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
|
||||
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
|
||||
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
||||
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
||||
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
||||
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
||||
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
||||
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
||||
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
|
||||
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
|
||||
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
|
||||
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
|
||||
DnsWatchdogEnabled *bool `mapstructure:"dns_watchdog_enabled" toml:"dns_watchdog_enabled,omitempty"`
|
||||
DnsWatchdogInvterval *time.Duration `mapstructure:"dns_watchdog_interval" toml:"dns_watchdog_interval,omitempty"`
|
||||
RefetchTime *int `mapstructure:"refetch_time" toml:"refetch_time,omitempty"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
|
||||
// NetworkConfig specifies configuration for networks where ctrld will handle requests.
|
||||
@@ -316,7 +320,7 @@ func (uc *UpstreamConfig) Init() {
|
||||
}
|
||||
}
|
||||
if uc.IPStack == "" {
|
||||
if uc.isControlD() {
|
||||
if uc.IsControlD() {
|
||||
uc.IPStack = IpStackSplit
|
||||
} else {
|
||||
uc.IPStack = IpStackBoth
|
||||
@@ -354,7 +358,7 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
if uc.isControlD() || uc.isNextDNS() {
|
||||
if uc.IsControlD() || uc.isNextDNS() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -401,7 +405,7 @@ func (uc *UpstreamConfig) UID() string {
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
|
||||
isControlD := uc.isControlD()
|
||||
isControlD := uc.IsControlD()
|
||||
for {
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
|
||||
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
|
||||
@@ -486,6 +490,13 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
ClientSessionCache: tls.NewLRUClientSessionCache(0),
|
||||
}
|
||||
|
||||
// Prevent bad tcp connection hanging the requests for too long.
|
||||
// See: https://github.com/golang/go/issues/36026
|
||||
if t2, err := http2.ConfigureTransports(transport); err == nil {
|
||||
t2.ReadIdleTimeout = 10 * time.Second
|
||||
t2.PingTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
dialerTimeoutMs := 2000
|
||||
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
|
||||
dialerTimeoutMs = uc.Timeout
|
||||
@@ -572,7 +583,8 @@ func (uc *UpstreamConfig) ping() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) isControlD() bool {
|
||||
// IsControlD reports whether this is a ControlD upstream.
|
||||
func (uc *UpstreamConfig) IsControlD() bool {
|
||||
domain := uc.Domain
|
||||
if domain == "" {
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/spf13/viper"
|
||||
@@ -22,6 +23,8 @@ func TestLoadConfig(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "info", cfg.Service.LogLevel)
|
||||
assert.Equal(t, "/path/to/log.log", cfg.Service.LogPath)
|
||||
assert.Equal(t, false, *cfg.Service.DnsWatchdogEnabled)
|
||||
assert.Equal(t, time.Duration(20*time.Second), *cfg.Service.DnsWatchdogInvterval)
|
||||
|
||||
assert.Len(t, cfg.Network, 2)
|
||||
assert.Contains(t, cfg.Network, "0")
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
# - Non-cgo ctrld binary.
|
||||
#
|
||||
# CI_COMMIT_TAG is used to set the version of ctrld binary.
|
||||
FROM golang:1.20-bullseye as base
|
||||
FROM golang:bullseye as base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -252,6 +252,35 @@ Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus
|
||||
- Required: no
|
||||
- Default: ""
|
||||
|
||||
### dns_watchdog_enabled
|
||||
Checking DNS changes to network interfaces and reverting to ctrld's own settings.
|
||||
|
||||
The DNS watchdog process only runs on Windows and MacOS.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
### dns_watchdog_interval
|
||||
Time duration between each DNS watchdog iteration.
|
||||
|
||||
A duration string is a possibly signed sequence of decimal numbers, each with optional fraction and a unit suffix,
|
||||
such as "300ms", "-1.5h" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
|
||||
If the time duration is non-positive, default value will be used.
|
||||
|
||||
- Type: time duration string
|
||||
- Required: no
|
||||
- Default: 20s
|
||||
|
||||
### refetch_time
|
||||
Time in seconds between each iteration that reloads custom config if changed.
|
||||
|
||||
The value must be a positive number, any invalid value will be ignored and default value will be used.
|
||||
- Type: number
|
||||
- Required: no
|
||||
- Default: 3600
|
||||
|
||||
## Upstream
|
||||
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.
|
||||
|
||||
|
||||
2
doh.go
2
doh.go
@@ -147,7 +147,7 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
|
||||
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
|
||||
printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != ""
|
||||
switch {
|
||||
case uc.isControlD():
|
||||
case uc.IsControlD():
|
||||
dohHeader = newControlDHeaders(ci)
|
||||
case uc.isNextDNS():
|
||||
dohHeader = newNextDNSHeaders(ci)
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
//go:build qf
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type doqResolver struct {
|
||||
uc *UpstreamConfig
|
||||
}
|
||||
|
||||
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
return nil, errors.New("DoQ is not supported")
|
||||
}
|
||||
2
dot.go
2
dot.go
@@ -18,7 +18,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
// dns.controld.dev first. By using a dialer with custom resolver,
|
||||
// we ensure that we can always resolve the bootstrap domain
|
||||
// regardless of the machine DNS status.
|
||||
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
|
||||
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
|
||||
@@ -122,8 +122,8 @@ func (m *mdns) probeLoop(conns []*net.UDPConn, remoteAddr net.Addr, quitCh chan
|
||||
bo := backoff.NewBackoff("mdns probe", func(format string, args ...any) {}, time.Second*30)
|
||||
for {
|
||||
err := m.probe(conns, remoteAddr)
|
||||
if isErrNetUnreachableOrInvalid(err) {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: network unreachable or invalid", remoteAddr)
|
||||
if shouldStopProbing(err) {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: %v", remoteAddr, err)
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
@@ -165,7 +165,7 @@ func (m *mdns) readLoop(conn *net.UDPConn) {
|
||||
}
|
||||
|
||||
var ip, name string
|
||||
rrs := make([]dns.RR, 0, len(msg.Answer)+len(msg.Extra))
|
||||
var rrs []dns.RR
|
||||
rrs = append(rrs, msg.Answer...)
|
||||
rrs = append(rrs, msg.Extra...)
|
||||
for _, rr := range rrs {
|
||||
@@ -273,10 +273,14 @@ func multicastInterfaces() ([]net.Interface, error) {
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
func isErrNetUnreachableOrInvalid(err error) bool {
|
||||
// shouldStopProbing reports whether ctrld should stop probing mdns.
|
||||
func shouldStopProbing(err error) bool {
|
||||
var se *os.SyscallError
|
||||
if errors.As(err, &se) {
|
||||
return se.Err == syscall.ENETUNREACH || se.Err == syscall.EINVAL
|
||||
switch se.Err {
|
||||
case syscall.ENETUNREACH, syscall.EINVAL, syscall.EPERM:
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -26,14 +26,15 @@ const (
|
||||
apiDomainDev = "api.controld.dev"
|
||||
resolverDataURLCom = "https://api.controld.com/utility"
|
||||
resolverDataURLDev = "https://api.controld.dev/utility"
|
||||
InvalidConfigCode = 40401
|
||||
InvalidConfigCode = 40402
|
||||
)
|
||||
|
||||
// ResolverConfig represents Control D resolver data.
|
||||
type ResolverConfig struct {
|
||||
DOH string `json:"doh"`
|
||||
Ctrld struct {
|
||||
CustomConfig string `json:"custom_config"`
|
||||
CustomConfig string `json:"custom_config"`
|
||||
CustomLastUpdate int64 `json:"custom_last_update"`
|
||||
} `json:"ctrld"`
|
||||
Exclude []string `json:"exclude"`
|
||||
UID string `json:"uid"`
|
||||
@@ -76,17 +77,28 @@ func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, e
|
||||
req.ClientID = clientID
|
||||
}
|
||||
body, _ := json.Marshal(req)
|
||||
return postUtilityAPI(version, cdDev, bytes.NewReader(body))
|
||||
return postUtilityAPI(version, cdDev, false, bytes.NewReader(body))
|
||||
}
|
||||
|
||||
// FetchResolverUID fetch resolver uid from provision token.
|
||||
func FetchResolverUID(pt, version string, cdDev bool) (*ResolverConfig, error) {
|
||||
hostname, _ := os.Hostname()
|
||||
body, _ := json.Marshal(utilityOrgRequest{ProvToken: pt, Hostname: hostname})
|
||||
return postUtilityAPI(version, cdDev, bytes.NewReader(body))
|
||||
return postUtilityAPI(version, cdDev, false, bytes.NewReader(body))
|
||||
}
|
||||
|
||||
func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig, error) {
|
||||
// UpdateCustomLastFailed calls API to mark custom config is bad.
|
||||
func UpdateCustomLastFailed(rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) {
|
||||
uid, clientID := ParseRawUID(rawUID)
|
||||
req := utilityRequest{UID: uid}
|
||||
if clientID != "" {
|
||||
req.ClientID = clientID
|
||||
}
|
||||
body, _ := json.Marshal(req)
|
||||
return postUtilityAPI(version, cdDev, true, bytes.NewReader(body))
|
||||
}
|
||||
|
||||
func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) {
|
||||
apiUrl := resolverDataURLCom
|
||||
if cdDev {
|
||||
apiUrl = resolverDataURLDev
|
||||
@@ -98,6 +110,9 @@ func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig
|
||||
q := req.URL.Query()
|
||||
q.Set("platform", "ctrld")
|
||||
q.Set("version", version)
|
||||
if lastUpdatedFailed {
|
||||
q.Set("custom_last_failed", "1")
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -164,6 +165,16 @@ func HomeDir() (string, error) {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Dir(exe), nil
|
||||
case edgeos.Name:
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Using binary directory as home dir if it is located in /config.
|
||||
// Otherwise, fallback to old behavior for compatibility.
|
||||
if strings.HasPrefix(exe, "/config/") {
|
||||
return filepath.Dir(exe), nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
90
resolver.go
90
resolver.go
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -30,18 +31,19 @@ const (
|
||||
ResolverTypePrivate = "private"
|
||||
)
|
||||
|
||||
const bootstrapDNS = "76.76.2.22"
|
||||
const (
|
||||
controldBootstrapDns = "76.76.2.22"
|
||||
controldPublicDns = "76.76.2.0"
|
||||
)
|
||||
|
||||
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||
|
||||
// or is the Resolver used for ResolverTypeOS.
|
||||
var or = &osResolver{nameservers: defaultNameservers()}
|
||||
|
||||
// defaultNameservers returns nameservers used by the OS.
|
||||
// If no nameservers can be found, ctrld bootstrap nameserver will be used.
|
||||
// defaultNameservers returns OS nameservers plus ControlD public DNS.
|
||||
func defaultNameservers() []string {
|
||||
ns := nameservers()
|
||||
if len(ns) == 0 {
|
||||
ns = append(ns, net.JoinHostPort(bootstrapDNS, "53"))
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
@@ -51,10 +53,27 @@ func defaultNameservers() []string {
|
||||
// It's the caller's responsibility to ensure the system DNS is in a clean state before
|
||||
// calling this function.
|
||||
func InitializeOsResolver() []string {
|
||||
or.nameservers = defaultNameservers()
|
||||
or.nameservers = or.nameservers[:0]
|
||||
for _, ns := range defaultNameservers() {
|
||||
if testNameserver(ns) {
|
||||
or.nameservers = append(or.nameservers, ns)
|
||||
}
|
||||
}
|
||||
or.nameservers = append(or.nameservers, controldPublicDnsWithPort)
|
||||
return or.nameservers
|
||||
}
|
||||
|
||||
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
|
||||
func testNameserver(addr string) bool {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
client := new(dns.Client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_, _, err := client.ExchangeContext(ctx, msg, addr)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Resolver is the interface that wraps the basic DNS operations.
|
||||
//
|
||||
// Resolve resolves the DNS query, return the result and the corresponding error.
|
||||
@@ -89,8 +108,9 @@ type osResolver struct {
|
||||
}
|
||||
|
||||
type osResolverResult struct {
|
||||
answer *dns.Msg
|
||||
err error
|
||||
answer *dns.Msg
|
||||
err error
|
||||
isControlDPublicDNS bool
|
||||
}
|
||||
|
||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||
@@ -116,19 +136,34 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
go func(server string) {
|
||||
defer wg.Done()
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
||||
ch <- &osResolverResult{answer: answer, err: err}
|
||||
ch <- &osResolverResult{answer: answer, err: err, isControlDPublicDNS: server == controldPublicDnsWithPort}
|
||||
}(server)
|
||||
}
|
||||
|
||||
var (
|
||||
nonSuccessAnswer *dns.Msg
|
||||
controldSuccessAnswer *dns.Msg
|
||||
)
|
||||
errs := make([]error, 0, numServers)
|
||||
for res := range ch {
|
||||
if res.err == nil {
|
||||
cancel()
|
||||
return res.answer, res.err
|
||||
switch {
|
||||
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
|
||||
if res.isControlDPublicDNS {
|
||||
controldSuccessAnswer = res.answer // only use ControlD answer as last one.
|
||||
} else {
|
||||
cancel()
|
||||
return res.answer, nil
|
||||
}
|
||||
case res.answer != nil:
|
||||
nonSuccessAnswer = res.answer
|
||||
}
|
||||
errs = append(errs, res.err)
|
||||
}
|
||||
|
||||
for _, answer := range []*dns.Msg{controldSuccessAnswer, nonSuccessAnswer} {
|
||||
if answer != nil {
|
||||
return answer, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
@@ -138,7 +173,7 @@ type legacyResolver struct {
|
||||
|
||||
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
// See comment in (*dotResolver).resolve method.
|
||||
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
|
||||
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
@@ -176,7 +211,7 @@ func LookupIP(domain string) []string {
|
||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||
resolver := &osResolver{nameservers: nameservers()}
|
||||
if withBootstrapDNS {
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...)
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
|
||||
timeoutMs := 2000
|
||||
@@ -252,7 +287,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
|
||||
// - Input servers.
|
||||
func NewBootstrapResolver(servers ...string) Resolver {
|
||||
resolver := &osResolver{nameservers: nameservers()}
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
||||
resolver.nameservers = append([]string{controldPublicDnsWithPort}, resolver.nameservers...)
|
||||
for _, ns := range servers {
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...)
|
||||
}
|
||||
@@ -279,11 +314,11 @@ func NewPrivateResolver() Resolver {
|
||||
// - Direct listener that has ctrld as an upstream (e.g: dnsmasq).
|
||||
//
|
||||
// causing the query always succeed.
|
||||
if sliceContains(resolveConfNss, host) {
|
||||
if slices.Contains(resolveConfNss, host) {
|
||||
continue
|
||||
}
|
||||
// Ignoring local RFC 1918 addresses.
|
||||
if sliceContains(localRfc1918Addrs, host) {
|
||||
if slices.Contains(localRfc1918Addrs, host) {
|
||||
continue
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
@@ -335,20 +370,3 @@ func newDialer(dnsAddress string) *net.Dialer {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(cuonglm): use slices.Contains once upgrading to go1.21
|
||||
// sliceContains reports whether v is present in s.
|
||||
func sliceContains[S ~[]E, E comparable](s S, v E) bool {
|
||||
return sliceIndex(s, v) >= 0
|
||||
}
|
||||
|
||||
// sliceIndex returns the index of the first occurrence of v in s,
|
||||
// or -1 if not present.
|
||||
func sliceIndex[S ~[]E, E comparable](s S, v E) int {
|
||||
for i := range s {
|
||||
if v == s[i] {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -28,6 +30,57 @@ func Test_osResolver_Resolve(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
ns := make([]string, 0, 2)
|
||||
servers := make([]*dns.Server, 0, 2)
|
||||
successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(msg, dns.RcodeSuccess)
|
||||
w.WriteMsg(m)
|
||||
})
|
||||
nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc {
|
||||
return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(msg, rcode)
|
||||
w.WriteMsg(m)
|
||||
})
|
||||
}
|
||||
|
||||
handlers := []dns.Handler{
|
||||
nonSuccessHandlerWithRcode(dns.RcodeRefused),
|
||||
nonSuccessHandlerWithRcode(dns.RcodeNameError),
|
||||
successHandler,
|
||||
}
|
||||
for i := range handlers {
|
||||
pc, err := net.ListenPacket("udp", ":0")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i])
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
ns = append(ns, addr)
|
||||
servers = append(servers, s)
|
||||
}
|
||||
defer func() {
|
||||
for _, server := range servers {
|
||||
server.Shutdown()
|
||||
}
|
||||
}()
|
||||
resolver := &osResolver{nameservers: ns}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
answer, err := resolver.Resolve(context.Background(), msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode])
|
||||
}
|
||||
}
|
||||
|
||||
func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -51,3 +104,33 @@ func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.Handler, opts ...func(*dns.Server)) (*dns.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
server := &dns.Server{
|
||||
PacketConn: pc,
|
||||
ReadTimeout: time.Hour,
|
||||
WriteTimeout: time.Hour,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
server.NotifyStartedFunc = waitLock.Unlock
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(server)
|
||||
}
|
||||
|
||||
addr, closer := pc.LocalAddr().String(), pc
|
||||
go func() {
|
||||
if err := server.ActivateAndServe(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
closer.Close()
|
||||
}()
|
||||
|
||||
waitLock.Lock()
|
||||
return server, addr, nil
|
||||
}
|
||||
|
||||
@@ -27,6 +27,8 @@ var sampleConfigContent = `
|
||||
[service]
|
||||
log_level = "info"
|
||||
log_path = "/path/to/log.log"
|
||||
dns_watchdog_enabled = false
|
||||
dns_watchdog_interval = "20s"
|
||||
|
||||
[network.0]
|
||||
name = "Home Wifi"
|
||||
|
||||
Reference in New Issue
Block a user