Compare commits

..

35 Commits

Author SHA1 Message Date
Cuong Manh Le
e6586fd360 Merge pull request #169 from Control-D-Inc/release-branch-v1.3.8
Release branch v1.3.8
2024-09-14 22:07:22 +07:00
Cuong Manh Le
33a6db2599 Configure timeout for HTTP2 transport
Otherwise, a stale TCP connection may still alive for too long, causing
unexpected failed to connect upstream error when network changed.
2024-09-14 21:59:33 +07:00
Cuong Manh Le
70b0c4f7b9 cmd/cli: honoring "iface" value in resetDnsTask
Otherwise, ctrld service command will always do reset DNS while it
should not.
2024-08-26 22:06:55 +07:00
Cuong Manh Le
5af3ec4f7b cmd/cli: ensure DNS goroutines terminated before self-uninstall
Otherwise, these goroutines could mess up with what resetDNS function
do, reverting DHCP DNS settings to ctrld listeners.
2024-08-16 13:50:11 +07:00
Cuong Manh Le
79476add12 Testing nameserver when initializing OS resolver
There are several issues with OS resolver right now:

 - The list of nameservers are obtained un-conditionally from all
   running interfaces.

 - ControlD public DNS query is always be used if response ok.

This could lead to slow query time, and also incorrect result if a
domain is resolved differently between internal DNS and ControlD public
DNS.

To fix these problems:

 - While initializing OS resolver, sending a test query to the
   nameserver to ensure it will response. Unreachable nameserver will
   not be used.

 - Only use ControlD public DNS success response as last one, preferring
   ok response from internal DNS servers.

While at it, also using standard package slices, since ctrld now
requires go1.21 as the minimum version.
2024-08-12 14:16:02 +07:00
Cuong Manh Le
1634a06330 all: change refresh_time -> refetch_time
The custom config is refetched from API, not refresh.
2024-08-12 14:15:49 +07:00
Cuong Manh Le
a007394f60 cmd/cli: ensure goroutines that check DNS terminated
So changes to DNS after ctrld stopped won't be reverted by the goroutine
itself. The problem happens rarely on darwin, because networksetup
command won't propagate config to /etc/resolv.conf if there is no
changes between multiple running.
2024-08-08 01:25:49 +07:00
Cuong Manh Le
62a0ba8731 cmd/cli: fix staticcheck linting 2024-08-08 01:25:22 +07:00
Cuong Manh Le
e8d3ed1acd cmd/cli: use currentStaticDNS when checking DNS changed
The dns watchdog is spawned *after* DNS was set by ctrld, thus it should
use the currentStaticDNS for getting the static DNS, instead of relying
on currentDNS, which could be system wide instead of per interfaces.
2024-08-07 15:54:22 +07:00
Cuong Manh Le
8b98faa441 cmd/cli: do not mask err argument of selfUninstall
The err should be preserved, so if we passed the error around, other
functions could still check for utility error code correctly.
2024-08-07 15:54:22 +07:00
Cuong Manh Le
30320ec9c7 cmd/cli: fix issue with editing /etc/resolv.conf directly on Darwin
On Darwin, modifying /etc/resolv.conf directly does not change interface
network settings. Thus the networksetup command uses to set DNS does not
do anything.

To fix this, after setting DNS using networksetup, re-check the content
of /etc/resolv.conf file to see if the nameservers are what we expected.
Otherwise, re-generate the file with proper nameservers.
2024-08-07 15:54:20 +07:00
Cuong Manh Le
5f4a399850 cmd/cli: extend list of valid interfaces for MacOS
By parsing "networksetup -listallhardwareports" output to get list of
available hardware ports.
2024-08-07 15:51:11 +07:00
Cuong Manh Le
82e0d4b0c4 all: add api driven config reload at runtime 2024-08-07 15:51:11 +07:00
Cuong Manh Le
95a9df826d cmd/cli: extend list of valid interfaces for MacOS 2024-08-07 15:51:11 +07:00
Cuong Manh Le
3b71d26cf3 cmd/cli: change "ctrld start" behavior
Without reading the documentation, users may think that "ctrld start"
will just start ctrld service. However, this is not the case, and may
lead to unexpected result from user's point of view.

This commit changes "ctrld start" to just start already installed ctrld
service, so users won't lost what they did installed before. If there
are any arguments specified, performing the current behavior.
2024-08-07 15:51:11 +07:00
Cuong Manh Le
c233ad9b1b cmd/cli: write new config file on reload 2024-08-07 15:51:11 +07:00
Cuong Manh Le
12d6484b1c Remove quic free file
The quic free build was gone long time ago.
2024-08-07 15:51:11 +07:00
Cuong Manh Le
bc7b1cc6d8 cmd/cli: fix wrong config file reading during self-check
At the time self-check process running, we have already known the exact
config file being used by ctrld service. Thus, we should just re-read
this config file directly instead of guessing the config file.
2024-08-07 15:51:11 +07:00
Cuong Manh Le
ec684348ed cmd/cli: add config to control DNS watchdog 2024-08-07 15:51:11 +07:00
Cuong Manh Le
18a19a3aa2 cmd/cli: cleanup more ctrld generated files
While at it, implement function to open log file on Windows for sharing
delete. So the log file could be backup correctly.

This may fix #303
2024-08-07 15:51:11 +07:00
Cuong Manh Le
905f2d08c5 cmd/cli: fix reset DNS when doing self-uninstall
While at it, also using "ctrld uninstall" on unix platform, ensuring
everything is cleanup properly.
2024-08-07 15:51:11 +07:00
Cuong Manh Le
04947b4d87 cmd/cli: make --cleanup removing more files
While at it, also implementing self-delete function for Windows.
2024-08-07 15:51:11 +07:00
Cuong Manh Le
72bf80533e cmd/cli: always run dns watchdog on Darwin/Windows 2024-08-07 15:51:11 +07:00
Cuong Manh Le
9ddedf926e cmd/cli: fix watching symlink /etc/resolv.conf
Currently, ctrld watches changes to /etc/resolv.conf file, then
reverting to the expected settings. However, if /etc/resolv.conf is a
symlink, changes made to the target file maynot be seen if it's not
under /etc directory.

To fix this, just evaluate the /etc/resolv.conf file before watching it.
2024-08-07 15:51:11 +07:00
Cuong Manh Le
139dd62ff3 cmd/cli: Capitalizing launchd status error message 2024-08-07 15:51:11 +07:00
Cuong Manh Le
50ef00526e cmd/cli: add "--cleanup" flag to remove ctrld's files 2024-08-07 15:51:11 +07:00
Cuong Manh Le
80cf79b9cb all: implement self-uninstall ctrld based on REFUSED queries 2024-08-07 15:51:11 +07:00
Cuong Manh Le
e6ad39b070 cmd/cli: add DNS watchdog on Darwin/Windows
Once per minute, ctrld will check if DNS settings was changed or not. If
yes, re-applying the proper settings for system interfaces.

For now, this is only applied when deactivation_pin was set.
2024-08-07 15:51:11 +07:00
Cuong Manh Le
56f9c72569 Add ControlD public DNS to OS resolver
Since the OS resolver only returns response with NOERROR first, it's
safe to use ControlD public DNS in parallel with system DNS. Local
domains would resolve only though local resolvers, because public ones
will return NXDOMAIN response.
2024-08-07 15:51:09 +07:00
Cuong Manh Le
dc48c908b8 cmd/cli: log validate remote config during "ctrld restart"
The same manner with what ctrld is doing for "ctrld start" command.
2024-08-07 15:28:00 +07:00
Cuong Manh Le
9b0f0e792a cmd/cli: workaround incorrect status data when not root 2024-08-07 15:27:46 +07:00
Cuong Manh Le
b3eebb19b6 internal/router: change default config directory on EdgeOS
So ctrld's own files will survive firmware upgrades.
2024-08-07 15:27:18 +07:00
Cuong Manh Le
c24589a5be internal/clientinfo: avoid heap alloc with mdns read loop
Once resource record (RR)  was used to extract necessary information, it
should be freed in memory. However, the current way that ctrld declare
the RRs causing the slices to be heap allocated, and stay in memory
longer than necessary. On system with low capacity, or firmware that GC
does not run agressively, it may causes the system memory exhausted.

To fix it, prevent RRs to be heap allocated, so they could be freed
immediately after each iterations.
2024-08-07 15:27:07 +07:00
Cuong Manh Le
1e1c5a4dc8 internal/clientinfo: tighten condition to stop probing mdns
If we see permission denied error when probing dns, that mean the
current ctrld process won't be able to do that anyway. So the probing
loop must be terminated to prevent waste of resources, or false positive
from system firewall because of too many failed attempts.
2024-08-07 15:27:02 +07:00
Cuong Manh Le
339023421a docker: bump go version for Dockerfile.debug 2024-08-07 15:26:25 +07:00
38 changed files with 1240 additions and 242 deletions

View File

@@ -105,9 +105,11 @@ Available Commands:
start Quick start service and configure DNS on interface
stop Quick stop service and remove DNS from interface
restart Restart the ctrld service
reload Reload the ctrld service
status Show status of the ctrld service
uninstall Stop and uninstall the ctrld service
clients Manage clients
upgrade Upgrading ctrld to latest version
Flags:
-h, --help help for ctrld

View File

@@ -48,6 +48,11 @@ import (
// selfCheckInternalTestDomain is used for testing ctrld self response to clients.
const selfCheckInternalTestDomain = "ctrld" + loopTestDomain
const (
windowsForwardersFilename = ".forwarders.txt"
oldBinSuffix = "_previous"
oldLogSuffix = ".1"
)
var (
version = "dev"
@@ -110,7 +115,7 @@ func initCLI() {
&verbose,
"verbose",
"v",
`verbose log output, "-v" basic logging, "-vv" debug level logging`,
`verbose log output, "-v" basic logging, "-vv" debug logging`,
)
rootCmd.PersistentFlags().BoolVarP(
&silent,
@@ -158,7 +163,10 @@ func initCLI() {
},
Use: "start",
Short: "Install and start the ctrld service",
Args: cobra.NoArgs,
Long: `Install and start the ctrld service
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
checkStrFlagEmpty(cmd, cdUidFlagName)
checkStrFlagEmpty(cmd, cdOrgFlagName)
@@ -182,8 +190,9 @@ func initCLI() {
return
}
status, _ := s.Status()
status, err := s.Status()
isCtrldRunning := status == service.StatusRunning
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
// If pin code was set, do not allow running start command.
if isCtrldRunning {
@@ -192,39 +201,56 @@ func initCLI() {
}
}
if cdUID != "" {
rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
if err != nil {
mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID)
if !startOnly {
startOnly = len(osArgs) == 0
}
// If user run "ctrld start" and ctrld is already installed, starting existing service.
if startOnly && isCtrldInstalled {
tryReadingConfigWithNotice(false, true)
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
}
// validateCdRemoteConfig clobbers v, saving it here to restore later.
oldV := v
if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil {
if errors.As(err, &viper.ConfigParseError{}) {
if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 {
tmpDir := os.TempDir()
tmpConfFile := filepath.Join(tmpDir, "ctrld.toml")
errorLogged := false
// Write remote config to a temporary file to get details error.
if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil {
if de := decoderErrorFromTomlFile(tmpConfFile); de != nil {
row, col := de.Position()
mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error())
errorLogged = true
}
_ = os.Remove(tmpConfFile)
}
// If we could not log details error, emit what we have already got.
if !errorLogged {
mainLog.Load().Error().Msgf("failed to parse custom config: %v", err)
}
}
} else {
mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err)
initLogging()
tasks := []task{
resetDnsTask(p, s),
{s.Stop, false},
{func() error {
// Save current DNS so we can restore later.
withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error {
return saveCurrentStaticDNS(i)
})
return nil
}, false},
{s.Start, true},
{noticeWritingControlDConfig, false},
}
mainLog.Load().Notice().Msg("Starting existing ctrld service")
if doTasks(tasks) {
mainLog.Load().Notice().Msg("Service started")
sockDir, err := socketDir()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("Failed to get socket directory")
os.Exit(1)
}
mainLog.Load().Warn().Msg("disregarding invalid custom config")
if cc := newSocketControlClient(s, sockDir); cc != nil {
if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK {
if iface == "auto" {
iface = defaultIfaceName()
}
logger := mainLog.Load().With().Str("iface", iface).Logger()
logger.Debug().Msg("setting DNS successfully")
}
}
} else {
mainLog.Load().Error().Err(err).Msg("Failed to start existing ctrld service")
os.Exit(1)
}
v = oldV
return
}
if cdUID != "" {
doValidateCdRemoteConfig(cdUID)
} else if uid := cdUIDFromProvToken(); uid != "" {
cdUID = uid
mainLog.Load().Debug().Msg("using uid from provision token")
@@ -399,6 +425,8 @@ func initCLI() {
startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id")
startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`)
startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service")
_ = startCmd.Flags().MarkHidden("start_only")
routerCmd := &cobra.Command{
Use: "setup",
@@ -485,7 +513,10 @@ func initCLI() {
Run: func(cmd *cobra.Command, args []string) {
readConfig(false)
v.Unmarshal(&cfg)
p := &prog{router: router.New(&cfg, runInCdMode())}
cdUID = curCdUID()
cdMode := cdUID != ""
p := &prog{router: router.New(&cfg, cdMode)}
s, err := newService(p, svcConfig)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
@@ -497,6 +528,10 @@ func initCLI() {
}
initLogging()
if cdMode {
doValidateCdRemoteConfig(cdUID)
}
iface = runningIface(s)
tasks := []task{
{s.Stop, false},
@@ -623,11 +658,72 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
os.Exit(deactivationPinInvalidExitCode)
}
uninstall(p, s)
if cleanup {
var files []string
// Config file.
files = append(files, v.ConfigFileUsed())
// Log file.
logFile := normalizeLogFilePath(cfg.Service.LogPath)
files = append(files, logFile)
// Backup log file.
oldLogFile := logFile + oldLogSuffix
if _, err := os.Stat(oldLogFile); err == nil {
files = append(files, oldLogFile)
}
// Socket files.
if dir, _ := socketDir(); dir != "" {
files = append(files, filepath.Join(dir, ctrldControlUnixSock))
files = append(files, filepath.Join(dir, ctrldLogUnixSock))
}
// Static DNS settings files.
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
file := savedStaticDnsSettingsFilePath(i)
if _, err := os.Stat(file); err == nil {
files = append(files, file)
}
return nil
})
// Windows forwarders file.
if windowsHasLocalDnsServerRunning() {
files = append(files, absHomeDir(windowsForwardersFilename))
}
// Binary itself.
bin, _ := os.Executable()
if bin != "" && supportedSelfDelete {
files = append(files, bin)
}
// Backup file after upgrading.
oldBin := bin + oldBinSuffix
if _, err := os.Stat(oldBin); err == nil {
files = append(files, oldBin)
}
for _, file := range files {
if file == "" {
continue
}
if err := os.Remove(file); err != nil {
if os.IsNotExist(err) {
continue
}
mainLog.Load().Warn().Err(err).Msg("failed to remove file")
} else {
mainLog.Load().Debug().Msgf("file removed: %s", file)
}
}
if err := selfDeleteExe(); err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to remove file")
} else {
if !supportedSelfDelete {
mainLog.Load().Debug().Msgf("file removed: %s", bin)
}
}
}
},
}
uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`)
uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for uninstalling ctrld`)
_ = uninstallCmd.Flags().MarkHidden("pin")
uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`)
listIfacesCmd := &cobra.Command{
Use: "list",
@@ -697,7 +793,13 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
},
Use: "start",
Short: "Quick start service and configure DNS on interface",
Long: `Quick start service and configure DNS on interface
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
Run: func(cmd *cobra.Command, args []string) {
if len(os.Args) == 2 {
startOnly = true
}
if !cmd.Flags().Changed("iface") {
os.Args = append(os.Args, "--iface="+ifaceStartStop)
}
@@ -776,7 +878,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
},
}
uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
uninstallCmdAlias.Flags().AddFlagSet(stopCmd.Flags())
uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags())
rootCmd.AddCommand(uninstallCmdAlias)
listClientsCmd := &cobra.Command{
@@ -894,7 +996,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
svcInstalled = false
}
oldBin := bin + "_previous"
oldBin := bin + oldBinSuffix
baseUrl := upgradeChannel[upgradeChannelDefault]
if len(args) > 0 {
channel := args[0]
@@ -1033,12 +1135,14 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
}
waitCh := make(chan struct{})
p := &prog{
waitCh: waitCh,
stopCh: stopCh,
reloadCh: make(chan struct{}),
reloadDoneCh: make(chan struct{}),
cfg: &cfg,
appCallback: appCallback,
waitCh: waitCh,
stopCh: stopCh,
reloadCh: make(chan struct{}),
reloadDoneCh: make(chan struct{}),
dnsWatcherStopCh: make(chan struct{}),
apiReloadCh: make(chan *ctrld.Config),
cfg: &cfg,
appCallback: appCallback,
}
if homedir == "" {
if dir, err := userHomeDir(); err == nil {
@@ -1128,36 +1232,13 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
return
}
uninstallIfInvalidCdUID := func() {
cdLogger := mainLog.Load().With().Str("mode", "cd").Logger()
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
s, err := newService(&prog{}, svcConfig)
if err != nil {
cdLogger.Warn().Err(err).Msg("failed to create new service")
return
}
if netIface, _ := netInterface(iface); netIface != nil {
if err := restoreNetworkManager(); err != nil {
cdLogger.Error().Err(err).Msg("could not restore NetworkManager")
return
}
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
if err := resetDNS(netIface); err != nil {
cdLogger.Warn().Err(err).Msg("something went wrong while restoring DNS")
} else {
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
}
}
tasks := []task{{s.Uninstall, true}}
if doTasks(tasks) {
cdLogger.Info().Msg("uninstalled service")
}
cdLogger.Fatal().Err(uer).Msg("failed to fetch resolver config")
return
}
cdLogger := mainLog.Load().With().Str("mode", "cd").Logger()
// Performs self-uninstallation if the ControlD device does not exist.
var uer *controld.UtilityErrorResponse
if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
_ = uninstallInvalidCdUID(p, cdLogger, false)
}
uninstallIfInvalidCdUID()
cdLogger.Fatal().Err(err).Msg("failed to fetch resolver config")
}
}
@@ -1168,7 +1249,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
}
if updated {
if err := writeConfigFile(); err != nil {
if err := writeConfigFile(&cfg); err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to write config file")
} else {
mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile)
@@ -1257,12 +1338,16 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
close(waitCh)
<-stopCh
// Wait goroutines which watches/manipulates DNS settings terminated,
// ensuring that changes to DNS since here won't be reverted.
p.dnsWg.Wait()
for _, f := range p.onStopped {
f()
}
}
func writeConfigFile() error {
func writeConfigFile(cfg *ctrld.Config) error {
if cfu := v.ConfigFileUsed(); cfu != "" {
defaultConfigFile = cfu
} else if configPath != "" {
@@ -1315,7 +1400,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
}
nop := zerolog.Nop()
_, _ = tryUpdateListenerConfig(&cfg, &nop, true)
if err := writeConfigFile(); err != nil {
if err := writeConfigFile(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
} else {
fp, err := filepath.Abs(defaultConfigFile)
@@ -1639,9 +1724,10 @@ func selfCheckStatus(s service.Service, homedir, sockDir string) (bool, service.
}
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
ctrld.SetConfigNameWithPath(v, "ctrld", homedir)
if configPath != "" {
v.SetConfigFile(configPath)
} else {
v.SetConfigFile(defaultConfigFile)
}
if err := v.ReadInConfig(); err != nil {
mainLog.Load().Error().Err(err).Msgf("failed to re-read configuration file: %s", v.ConfigFileUsed())
@@ -2375,7 +2461,7 @@ func doGenerateNextDNSConfig(uid string) error {
mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile)
generateNextDNSConfig(uid)
updateListenerConfig(&cfg)
return writeConfigFile()
return writeConfigFile(&cfg)
}
func noticeWritingControlDConfig() error {
@@ -2423,7 +2509,7 @@ func checkDeactivationPin(s service.Service, stopCh chan struct{}) error {
return nil // the server is running older version of ctrld
}
}
mainLog.Load().Error().Msg(errInvalidDeactivationPin.Error())
mainLog.Load().Error().Err(err).Msg(errInvalidDeactivationPin.Error())
return errInvalidDeactivationPin
}
@@ -2482,6 +2568,11 @@ func absHomeDir(filename string) string {
// runInCdMode reports whether ctrld service is running in cd mode.
func runInCdMode() bool {
return curCdUID() != ""
}
// curCdUID returns the current ControlD UID used by running ctrld process.
func curCdUID() string {
if s, _ := newService(&prog{}, svcConfig); s != nil {
if dir, _ := socketDir(); dir != "" {
cc := newSocketControlClient(s, dir)
@@ -2489,12 +2580,13 @@ func runInCdMode() bool {
resp, _ := cc.post(cdPath, nil)
if resp != nil {
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
buf, _ := io.ReadAll(resp.Body)
return string(buf)
}
}
}
}
return false
return ""
}
// goArm returns the GOARM value for the binary.
@@ -2557,6 +2649,9 @@ func resetDnsTask(p *prog, s service.Service) task {
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
isCtrldRunning := status == service.StatusRunning
return task{func() error {
if iface == "" {
return nil
}
// Always reset DNS first, ensuring DNS setting is in a good state.
// resetDNS must use the "iface" value of current running ctrld
// process to reset what setDNS has done properly.
@@ -2572,3 +2667,60 @@ func resetDnsTask(p *prog, s service.Service) task {
return nil
}, false}
}
// doValidateCdRemoteConfig fetches and validates custom config for cdUID.
func doValidateCdRemoteConfig(cdUID string) {
rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
if err != nil {
mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID)
}
// validateCdRemoteConfig clobbers v, saving it here to restore later.
oldV := v
if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil {
if errors.As(err, &viper.ConfigParseError{}) {
if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 {
tmpDir := os.TempDir()
tmpConfFile := filepath.Join(tmpDir, "ctrld.toml")
errorLogged := false
// Write remote config to a temporary file to get details error.
if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil {
if de := decoderErrorFromTomlFile(tmpConfFile); de != nil {
row, col := de.Position()
mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error())
errorLogged = true
}
_ = os.Remove(tmpConfFile)
}
// If we could not log details error, emit what we have already got.
if !errorLogged {
mainLog.Load().Error().Msgf("failed to parse custom config: %v", err)
}
}
} else {
mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err)
}
mainLog.Load().Warn().Msg("disregarding invalid custom config")
}
v = oldV
}
// uninstallInvalidCdUID performs self-uninstallation because the ControlD device does not exist.
func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool {
s, err := newService(p, svcConfig)
if err != nil {
logger.Warn().Err(err).Msg("failed to create new service")
return false
}
p.resetDNS()
tasks := []task{{s.Uninstall, true}}
if doTasks(tasks) {
logger.Info().Msg("uninstalled service")
if doStop {
_ = s.Stop()
}
return true
}
return false
}

View File

@@ -16,7 +16,7 @@ func Test_writeConfigFile(t *testing.T) {
_, err := os.Stat(configPath)
assert.True(t, os.IsNotExist(err))
assert.NoError(t, writeConfigFile())
assert.NoError(t, writeConfigFile(&cfg))
_, err = os.Stat(configPath)
require.NoError(t, err)

View File

@@ -73,7 +73,7 @@ func (p *prog) registerControlServerHandler() {
sort.Slice(clients, func(i, j int) bool {
return clients[i].IP.Less(clients[j].IP)
})
if p.cfg.Service.MetricsQueryStats {
if p.metricsQueryStats.Load() {
for _, client := range clients {
client.IncludeQueryCount = true
dm := &dto.Metric{}
@@ -178,6 +178,7 @@ func (p *prog) registerControlServerHandler() {
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
if cdUID != "" {
w.WriteHeader(http.StatusOK)
w.Write([]byte(cdUID))
return
}
w.WriteHeader(http.StatusBadRequest)

View File

@@ -21,6 +21,7 @@ import (
"tailscale.com/net/tsaddr"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/controld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
@@ -32,6 +33,9 @@ const (
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
EDNS0_OPTION_MAC = 0xFDE9
// selfUninstallMaxQueries is number of REFUSED queries seen before checking for self-uninstallation.
selfUninstallMaxQueries = 32
)
var osUpstreamConfig = &ctrld.UpstreamConfig{
@@ -89,6 +93,7 @@ func (p *prog) serveDNS(listenerNum string) error {
_ = w.WriteMsg(answer)
return
}
listenerConfig := p.cfg.Listener[listenerNum]
reqId := requestID()
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
@@ -143,6 +148,7 @@ func (p *prog) serveDNS(listenerNum string) error {
failoverRcodes: failoverRcode,
ufr: ur,
})
go p.doSelfUninstall(pr.answer)
answer = pr.answer
rtt := time.Since(t)
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
@@ -836,6 +842,51 @@ func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
}
}
// doSelfUninstall performs self-uninstall if these condition met:
//
// - There is only 1 ControlD upstream in-use.
// - Number of refused queries seen so far equals to selfUninstallMaxQueries.
// - The cdUID is deleted.
func (p *prog) doSelfUninstall(answer *dns.Msg) {
if !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused {
return
}
p.selfUninstallMu.Lock()
defer p.selfUninstallMu.Unlock()
if p.checkingSelfUninstall {
return
}
logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger()
if p.refusedQueryCount > selfUninstallMaxQueries {
p.checkingSelfUninstall = true
_, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
logger.Debug().Msg("maximum number of refused queries reached, checking device status")
selfUninstallCheck(err, p, logger)
if err != nil {
logger.Warn().Err(err).Msg("could not fetch resolver config")
}
// Cool-of period to prevent abusing the API.
go p.selfUninstallCoolOfPeriod()
return
}
p.refusedQueryCount++
}
// selfUninstallCoolOfPeriod waits for 30 minutes before
// calling API again for checking ControlD device status.
func (p *prog) selfUninstallCoolOfPeriod() {
t := time.NewTimer(time.Minute * 30)
defer t.Stop()
<-t.C
p.selfUninstallMu.Lock()
p.checkingSelfUninstall = false
p.refusedQueryCount = 0
p.selfUninstallMu.Unlock()
}
// queryFromSelf reports whether the input IP is from device running ctrld.
func queryFromSelf(ip string) bool {
netIP := netip.MustParseAddr(ip)

View File

@@ -36,6 +36,8 @@ var (
cdUpstreamProto string
deactivationPin int64
skipSelfChecks bool
cleanup bool
startOnly bool
mainLog atomic.Pointer[zerolog.Logger]
consoleWriter zerolog.ConsoleWriter
@@ -63,8 +65,11 @@ func Main() {
}
func normalizeLogFilePath(logFilePath string) string {
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
return logFilePath
// In cleanup mode, we always want the full log file path.
if !cleanup {
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
return logFilePath
}
}
if homedir != "" {
return filepath.Join(homedir, logFilePath)
@@ -121,14 +126,14 @@ func initLoggingWithBackup(doBackup bool) {
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
if doBackup {
// Backup old log file with .1 suffix.
if err := os.Rename(logFilePath, logFilePath+".1"); err != nil && !os.IsNotExist(err) {
if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) {
mainLog.Load().Error().Msgf("could not backup old log file: %v", err)
} else {
// Backup was created, set flags for truncating old log file.
flags = os.O_CREATE | os.O_RDWR
}
}
logFile, err := os.OpenFile(logFilePath, flags, os.FileMode(0o600))
logFile, err := openLogFile(logFilePath, flags)
if err != nil {
mainLog.Load().Error().Msgf("failed to create log file: %v", err)
os.Exit(1)

View File

@@ -107,7 +107,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
reg := prometheus.NewRegistry()
// Register queries count stats if enabled.
if cfg.Service.MetricsQueryStats {
if p.metricsQueryStats.Load() {
reg.MustRegister(statsQueriesCount)
reg.MustRegister(statsClientQueriesCount)
}

View File

@@ -43,20 +43,32 @@ func networkServiceName(ifaceName string, r io.Reader) string {
return ""
}
// validInterface reports whether the *net.Interface is a valid one, which includes:
//
// - en0: physical wireless
// - en1: Thunderbolt 1
// - en2: Thunderbolt 2
// - en3: Thunderbolt 3
// - en4: Thunderbolt 4
//
// For full list, see: https://unix.stackexchange.com/questions/603506/what-are-these-ifconfig-interfaces-on-macos
func validInterface(iface *net.Interface) bool {
switch iface.Name {
case "en0", "en1", "en2", "en3", "en4":
return true
default:
return false
}
// validInterface reports whether the *net.Interface is a valid one.
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
_, ok := validIfacesMap[iface.Name]
return ok
}
func validInterfacesMap() map[string]struct{} {
b, err := exec.Command("networksetup", "-listallhardwareports").Output()
if err != nil {
return nil
}
return parseListAllHardwarePorts(bytes.NewReader(b))
}
// parseListAllHardwarePorts parses output of "networksetup -listallhardwareports"
// and returns map presents all hardware ports.
func parseListAllHardwarePorts(r io.Reader) map[string]struct{} {
m := make(map[string]struct{})
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
after, ok := strings.CutPrefix(line, "Device: ")
if !ok {
continue
}
m[after] = struct{}{}
}
return m
}

View File

@@ -1,6 +1,7 @@
package cli
import (
"maps"
"strings"
"testing"
@@ -57,3 +58,47 @@ func Test_networkServiceName(t *testing.T) {
})
}
}
const listallhardwareportsOutput = `
Hardware Port: Ethernet Adapter (en6)
Device: en6
Ethernet Address: 3a:3e:fc:1e:ab:41
Hardware Port: Ethernet Adapter (en7)
Device: en7
Ethernet Address: 3a:3e:fc:1e:ab:42
Hardware Port: Thunderbolt Bridge
Device: bridge0
Ethernet Address: 36:21:bb:3a:7a:40
Hardware Port: Wi-Fi
Device: en0
Ethernet Address: a0:78:17:68:56:3f
Hardware Port: Thunderbolt 1
Device: en1
Ethernet Address: 36:21:bb:3a:7a:40
Hardware Port: Thunderbolt 2
Device: en2
Ethernet Address: 36:21:bb:3a:7a:44
VLAN Configurations
===================
`
func Test_parseListAllHardwarePorts(t *testing.T) {
expected := map[string]struct{}{
"en0": {},
"en1": {},
"en2": {},
"en6": {},
"en7": {},
"bridge0": {},
}
m := parseListAllHardwarePorts(strings.NewReader(listallhardwareportsOutput))
if !maps.Equal(m, expected) {
t.Errorf("unexpected output, want: %v, got: %v", expected, m)
}
}

View File

@@ -6,4 +6,6 @@ import "net"
func patchNetIfaceName(iface *net.Interface) error { return nil }
func validInterface(iface *net.Interface) bool { return true }
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
func validInterfacesMap() map[string]struct{} { return nil }

View File

@@ -10,7 +10,7 @@ func patchNetIfaceName(iface *net.Interface) error {
// validInterface reports whether the *net.Interface is a valid one.
// On Windows, only physical interfaces are considered valid.
func validInterface(iface *net.Interface) bool {
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
if iface == nil {
return false
}
@@ -19,3 +19,5 @@ func validInterface(iface *net.Interface) bool {
}
return false
}
func validInterfacesMap() map[string]struct{} { return nil }

View File

@@ -24,6 +24,8 @@ import (
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
// allocate loopback ip
// sudo ip a add 127.0.0.2/24 dev lo
func allocateIP(ip string) error {

View File

@@ -16,7 +16,6 @@ import (
)
const (
forwardersFilename = ".forwarders.txt"
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
)
@@ -40,7 +39,7 @@ func setDNS(iface *net.Interface, nameservers []string) error {
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
// Configuring the Dns server to forward queries to ctrld instead.
if windowsHasLocalDnsServerRunning() {
file := absHomeDir(forwardersFilename)
file := absHomeDir(windowsForwardersFilename)
oldForwardersContent, _ := os.ReadFile(file)
if err := os.WriteFile(file, []byte(strings.Join(nameservers, ",")), 0600); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
@@ -72,7 +71,7 @@ func resetDNS(iface *net.Interface) error {
resetDNSOnce.Do(func() {
// See corresponding comment in setDNS.
if windowsHasLocalDnsServerRunning() {
file := absHomeDir(forwardersFilename)
file := absHomeDir(windowsForwardersFilename)
content, err := os.ReadFile(file)
if err != nil {
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")

View File

@@ -12,19 +12,24 @@ import (
"net/url"
"os"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/kardianos/service"
"github.com/rs/zerolog"
"github.com/spf13/viper"
"tailscale.com/net/interfaces"
"tailscale.com/net/tsaddr"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
"github.com/Control-D-Inc/ctrld/internal/controld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
"github.com/Control-D-Inc/ctrld/internal/router"
)
@@ -38,6 +43,7 @@ const (
upstreamPrefix = "upstream."
upstreamOS = upstreamPrefix + "os"
upstreamPrivate = upstreamPrefix + "private"
dnsWatchdogDefaultInterval = 20 * time.Second
)
// ControlSocketName returns name for control unix socket.
@@ -62,15 +68,19 @@ var svcConfig = &service.Config{
var useSystemdResolved = false
type prog struct {
mu sync.Mutex
waitCh chan struct{}
stopCh chan struct{}
reloadCh chan struct{} // For Windows.
reloadDoneCh chan struct{}
logConn net.Conn
cs *controlServer
csSetDnsDone chan struct{}
csSetDnsOk bool
mu sync.Mutex
waitCh chan struct{}
stopCh chan struct{}
reloadCh chan struct{} // For Windows.
reloadDoneCh chan struct{}
apiReloadCh chan *ctrld.Config
logConn net.Conn
cs *controlServer
csSetDnsDone chan struct{}
csSetDnsOk bool
dnsWatchDogOnce sync.Once
dnsWg sync.WaitGroup
dnsWatcherStopCh chan struct{}
cfg *ctrld.Config
localUpstreams []string
@@ -84,6 +94,12 @@ type prog struct {
router router.Router
ptrLoopGuard *loopGuard
lanLoopGuard *loopGuard
metricsQueryStats atomic.Bool
selfUninstallMu sync.Mutex
refusedQueryCount int
canSelfUninstall atomic.Bool
checkingSelfUninstall bool
loopMu sync.Mutex
loop map[string]bool
@@ -117,11 +133,15 @@ func (p *prog) runWait() {
p.run(reload, reloadCh)
reload = true
}()
var newCfg *ctrld.Config
select {
case sig := <-reloadSigCh:
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
case <-p.reloadCh:
logger.Notice().Msg("reloading...")
case apiCfg := <-p.apiReloadCh:
newCfg = apiCfg
case <-p.stopCh:
close(reloadCh)
return
@@ -131,28 +151,31 @@ func (p *prog) runWait() {
close(reloadCh)
<-done
}
newCfg := &ctrld.Config{}
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
ctrld.InitConfig(v, "ctrld")
if configPath != "" {
v.SetConfigFile(configPath)
}
if err := v.ReadInConfig(); err != nil {
logger.Err(err).Msg("could not read new config")
waitOldRunDone()
continue
}
if err := v.Unmarshal(&newCfg); err != nil {
logger.Err(err).Msg("could not unmarshal new config")
waitOldRunDone()
continue
}
if cdUID != "" {
if err := processCDFlags(newCfg); err != nil {
logger.Err(err).Msg("could not fetch ControlD config")
if newCfg == nil {
newCfg = &ctrld.Config{}
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
ctrld.InitConfig(v, "ctrld")
if configPath != "" {
v.SetConfigFile(configPath)
}
if err := v.ReadInConfig(); err != nil {
logger.Err(err).Msg("could not read new config")
waitOldRunDone()
continue
}
if err := v.Unmarshal(&newCfg); err != nil {
logger.Err(err).Msg("could not unmarshal new config")
waitOldRunDone()
continue
}
if cdUID != "" {
if err := processCDFlags(newCfg); err != nil {
logger.Err(err).Msg("could not fetch ControlD config")
waitOldRunDone()
continue
}
}
}
waitOldRunDone()
@@ -178,6 +201,10 @@ func (p *prog) runWait() {
continue
}
if err := writeConfigFile(newCfg); err != nil {
logger.Err(err).Msg("could not write new config")
}
// This needs to be done here, otherwise, the DNS handler may observe an invalid
// upstream config because its initialization function have not been called yet.
mainLog.Load().Debug().Msg("setup upstream with new config")
@@ -188,6 +215,7 @@ func (p *prog) runWait() {
p.mu.Unlock()
logger.Notice().Msg("reloading config successfully")
select {
case p.reloadDoneCh <- struct{}{}:
default:
@@ -214,12 +242,67 @@ func (p *prog) postRun() {
}
}
// apiConfigReload calls API to check for latest config update then reload ctrld if necessary.
func (p *prog) apiConfigReload() {
if cdUID == "" {
return
}
secs := 3600
if p.cfg.Service.RefetchTime != nil && *p.cfg.Service.RefetchTime > 0 {
secs = *p.cfg.Service.RefetchTime
}
ticker := time.NewTicker(time.Duration(secs) * time.Second)
defer ticker.Stop()
logger := mainLog.Load().With().Str("mode", "api-reload").Logger()
logger.Debug().Msg("starting custom config reload timer")
lastUpdated := time.Now().Unix()
for {
select {
case <-ticker.C:
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
selfUninstallCheck(err, p, logger)
if err != nil {
logger.Warn().Err(err).Msg("could not fetch resolver config")
continue
}
if resolverConfig.Ctrld.CustomConfig == "" {
continue
}
if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated {
lastUpdated = time.Now().Unix()
cfg := &ctrld.Config{}
if err := validateCdRemoteConfig(resolverConfig, cfg); err != nil {
logger.Warn().Err(err).Msg("skipping invalid custom config")
if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil {
logger.Error().Err(err).Msg("could not mark custom last update failed")
}
break
}
setListenerDefaultValue(cfg)
logger.Debug().Msg("custom config changes detected, reloading...")
p.apiReloadCh <- cfg
} else {
logger.Debug().Msg("custom config does not change")
}
case <-p.stopCh:
return
}
}
}
func (p *prog) setupUpstream(cfg *ctrld.Config) {
localUpstreams := make([]string, 0, len(cfg.Upstream))
ptrNameservers := make([]string, 0, len(cfg.Upstream))
isControlDUpstream := false
for n := range cfg.Upstream {
uc := cfg.Upstream[n]
uc.Init()
isControlDUpstream = isControlDUpstream || uc.IsControlD()
if uc.BootstrapIP == "" {
uc.SetupBootstrapIP()
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
@@ -236,6 +319,10 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) {
ptrNameservers = append(ptrNameservers, uc.Endpoint)
}
}
// Self-uninstallation is ok If there is only 1 ControlD upstream, and no remote config.
if len(cfg.Upstream) == 1 && isControlDUpstream {
p.canSelfUninstall.Store(true)
}
p.localUpstreams = localUpstreams
p.ptrNameservers = ptrNameservers
}
@@ -271,6 +358,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
p.lanLoopGuard = newLoopGuard()
p.ptrLoopGuard = newLoopGuard()
p.cacheFlushDomainsMap = nil
p.metricsQueryStats.Store(p.cfg.Service.MetricsQueryStats)
if p.cfg.Service.CacheEnable {
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
if err != nil {
@@ -397,6 +485,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
if p.logConn != nil {
_ = p.logConn.Close()
}
go p.apiConfigReload()
p.postRun()
}
wg.Wait()
@@ -510,13 +599,86 @@ func (p *prog) setDNS() {
for i := range nameservers {
servers[i] = netip.MustParseAddr(nameservers[i])
}
go watchResolvConf(netIface, servers, setResolvConf)
p.dnsWg.Add(1)
go func() {
defer p.dnsWg.Done()
p.watchResolvConf(netIface, servers, setResolvConf)
}()
}
if allIfaces {
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
return setDnsIgnoreUnusableInterface(i, nameservers)
})
}
if p.dnsWatchdogEnabled() {
p.dnsWg.Add(1)
go func() {
defer p.dnsWg.Done()
p.dnsWatchdog(netIface, nameservers, allIfaces)
}()
}
}
// dnsWatchdogEnabled reports whether DNS watchdog is enabled.
func (p *prog) dnsWatchdogEnabled() bool {
if ptr := p.cfg.Service.DnsWatchdogEnabled; ptr != nil {
return *ptr
}
return true
}
// dnsWatchdogDuration returns the time duration between each DNS watchdog loop.
func (p *prog) dnsWatchdogDuration() time.Duration {
if ptr := p.cfg.Service.DnsWatchdogInvterval; ptr != nil {
if (*ptr).Seconds() > 0 {
return *ptr
}
}
return dnsWatchdogDefaultInterval
}
// dnsWatchdog watches for DNS changes on Darwin and Windows then re-applying ctrld's settings.
// This is only works when deactivation pin set.
func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces bool) {
if !requiredMultiNICsConfig() {
return
}
p.dnsWatchDogOnce.Do(func() {
mainLog.Load().Debug().Msg("start DNS settings watchdog")
ns := nameservers
slices.Sort(ns)
ticker := time.NewTicker(p.dnsWatchdogDuration())
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
for {
select {
case <-p.dnsWatcherStopCh:
return
case <-p.stopCh:
mainLog.Load().Debug().Msg("stop dns watchdog")
return
case <-ticker.C:
if dnsChanged(iface, ns) {
logger.Debug().Msg("DNS settings were changed, re-applying settings")
if err := setDNS(iface, ns); err != nil {
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
}
}
if allIfaces {
withEachPhysicalInterfaces(iface.Name, "re-applying DNS", func(i *net.Interface) error {
if dnsChanged(i, ns) {
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
} else {
mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
}
}
return nil
})
}
}
}
})
}
func (p *prog) resetDNS() {
@@ -727,13 +889,14 @@ func canBeLocalUpstream(addr string) bool {
// the interface that matches excludeIfaceName. The context is used to clarify the
// log message when error happens.
func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) {
validIfacesMap := validInterfacesMap()
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
// Skip loopback/virtual interface.
if i.IsLoopback() || len(i.HardwareAddr) == 0 {
return
}
// Skip invalid interface.
if !validInterface(i.Interface) {
if !validInterface(i.Interface, validIfacesMap) {
return
}
netIface := i.Interface
@@ -747,7 +910,9 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.
}
// TODO: investigate whether we should report this error?
if err := f(netIface); err == nil {
mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name)
if context != "" {
mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name)
}
} else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) {
mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name)
}
@@ -806,3 +971,24 @@ func savedStaticNameservers(iface *net.Interface) []string {
}
return nil
}
// dnsChanged reports whether DNS settings for given interface was changed.
// The caller must sort the nameservers before calling this function.
func dnsChanged(iface *net.Interface, nameservers []string) bool {
curNameservers, _ := currentStaticDNS(iface)
slices.Sort(curNameservers)
return !slices.Equal(curNameservers, nameservers)
}
// selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then.
func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
var uer *controld.UtilityErrorResponse
if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
// Ensure all DNS watchers goroutine are terminated, so it won't mess up with self-uninstall.
close(p.dnsWatcherStopCh)
p.dnsWg.Wait()
// Perform self-uninstall now.
selfUninstall(p, logger)
}
}

57
cmd/cli/prog_test.go Normal file
View File

@@ -0,0 +1,57 @@
package cli
import (
"testing"
"time"
"github.com/Control-D-Inc/ctrld"
"github.com/stretchr/testify/assert"
)
func Test_prog_dnsWatchdogEnabled(t *testing.T) {
p := &prog{cfg: &ctrld.Config{}}
// Default value is true.
assert.True(t, p.dnsWatchdogEnabled())
tests := []struct {
name string
enabled bool
}{
{"enabled", true},
{"disabled", false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
p.cfg.Service.DnsWatchdogEnabled = &tc.enabled
assert.Equal(t, tc.enabled, p.dnsWatchdogEnabled())
})
}
}
func Test_prog_dnsWatchdogInterval(t *testing.T) {
p := &prog{cfg: &ctrld.Config{}}
// Default value is 20s.
assert.Equal(t, dnsWatchdogDefaultInterval, p.dnsWatchdogDuration())
tests := []struct {
name string
duration time.Duration
expected time.Duration
}{
{"valid", time.Minute, time.Minute},
{"zero", 0, dnsWatchdogDefaultInterval},
{"nagative", time.Duration(-1 * time.Minute), dnsWatchdogDefaultInterval},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
p.cfg.Service.DnsWatchdogInvterval = &tc.duration
assert.Equal(t, tc.expected, p.dnsWatchdogDuration())
})
}
}

View File

@@ -51,7 +51,7 @@ var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
if p.cfg.Service.MetricsQueryStats {
if p.metricsQueryStats.Load() {
c.WithLabelValues(lvs...).Inc()
}
}

View File

@@ -8,15 +8,15 @@ import (
"github.com/fsnotify/fsnotify"
)
const (
resolvConfPath = "/etc/resolv.conf"
resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
)
// watchResolvConf watches any changes to /etc/resolv.conf file,
// and reverting to the original config set by ctrld.
func watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
mainLog.Load().Debug().Msg("start watching /etc/resolv.conf file")
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
resolvConfPath := "/etc/resolv.conf"
// Evaluating symbolics link to watch the target file that /etc/resolv.conf point to.
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
resolvConfPath = rp
}
mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath)
watcher, err := fsnotify.NewWatcher()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
@@ -28,12 +28,17 @@ func watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
watchDir := filepath.Dir(resolvConfPath)
if err := watcher.Add(watchDir); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not add /etc/resolv.conf to watcher list")
mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir)
return
}
for {
select {
case <-p.dnsWatcherStopCh:
return
case <-p.stopCh:
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
return
case event, ok := <-watcher.Events:
if !ok {
return

View File

@@ -3,15 +3,44 @@ package cli
import (
"net"
"net/netip"
"os"
"slices"
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
)
const resolvConfPath = "/etc/resolv.conf"
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
servers := make([]string, len(ns))
for i := range ns {
servers[i] = ns[i].String()
}
return setDNS(iface, servers)
if err := setDNS(iface, servers); err != nil {
return err
}
slices.Sort(servers)
curNs := currentDNS(iface)
slices.Sort(curNs)
if !slices.Equal(curNs, servers) {
c, err := resolvconffile.ParseFile(resolvConfPath)
if err != nil {
return err
}
c.Nameservers = ns
f, err := os.Create(resolvConfPath)
if err != nil {
return err
}
defer f.Close()
if err := c.Write(f); err != nil {
return err
}
return f.Close()
}
return nil
}
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.

View File

@@ -0,0 +1,7 @@
//go:build !windows
package cli
var supportedSelfDelete = true
func selfDeleteExe() error { return nil }

View File

@@ -0,0 +1,134 @@
// Copied from https://github.com/secur30nly/go-self-delete
// with modification to suitable for ctrld usage.
/*
License: MIT Licence
References:
- https://github.com/LloydLabs/delete-self-poc
- https://twitter.com/jonasLyk/status/1350401461985955840
*/
package cli
import (
"unsafe"
"golang.org/x/sys/windows"
)
var supportedSelfDelete = false
type FILE_RENAME_INFO struct {
Union struct {
ReplaceIfExists bool
Flags uint32
}
RootDirectory windows.Handle
FileNameLength uint32
FileName [1]uint16
}
type FILE_DISPOSITION_INFO struct {
DeleteFile bool
}
func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
handle, err := windows.CreateFile(
pwPath,
windows.DELETE,
0,
nil,
windows.OPEN_EXISTING,
windows.FILE_ATTRIBUTE_NORMAL,
0,
)
if err != nil {
return 0, err
}
return handle, nil
}
func dsRenameHandle(hHandle windows.Handle) error {
var fRename FILE_RENAME_INFO
DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef")
if err != nil {
return err
}
lpwStream := &DS_STREAM_RENAME[0]
fRename.FileNameLength = uint32(unsafe.Sizeof(lpwStream))
windows.NewLazyDLL("kernel32.dll").NewProc("RtlCopyMemory").Call(
uintptr(unsafe.Pointer(&fRename.FileName[0])),
uintptr(unsafe.Pointer(lpwStream)),
unsafe.Sizeof(lpwStream),
)
err = windows.SetFileInformationByHandle(
hHandle,
windows.FileRenameInfo,
(*byte)(unsafe.Pointer(&fRename)),
uint32(unsafe.Sizeof(fRename)+unsafe.Sizeof(lpwStream)),
)
if err != nil {
return err
}
return nil
}
func dsDepositeHandle(hHandle windows.Handle) error {
var fDelete FILE_DISPOSITION_INFO
fDelete.DeleteFile = true
err := windows.SetFileInformationByHandle(
hHandle,
windows.FileDispositionInfo,
(*byte)(unsafe.Pointer(&fDelete)),
uint32(unsafe.Sizeof(fDelete)),
)
if err != nil {
return err
}
return nil
}
func selfDeleteExe() error {
var wcPath [windows.MAX_PATH + 1]uint16
var hCurrent windows.Handle
_, err := windows.GetModuleFileName(0, &wcPath[0], windows.MAX_PATH)
if err != nil {
return err
}
hCurrent, err = dsOpenHandle(&wcPath[0])
if err != nil || hCurrent == windows.InvalidHandle {
return err
}
if err := dsRenameHandle(hCurrent); err != nil {
_ = windows.CloseHandle(hCurrent)
return err
}
_ = windows.CloseHandle(hCurrent)
hCurrent, err = dsOpenHandle(&wcPath[0])
if err != nil || hCurrent == windows.InvalidHandle {
return err
}
if err := dsDepositeHandle(hCurrent); err != nil {
_ = windows.CloseHandle(hCurrent)
return err
}
return windows.CloseHandle(hCurrent)
}

View File

@@ -0,0 +1,16 @@
//go:build !unix
package cli
import (
"os"
"github.com/rs/zerolog"
)
func selfUninstall(p *prog, logger zerolog.Logger) {
if uninstallInvalidCdUID(p, logger, false) {
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
os.Exit(0)
}
}

45
cmd/cli/self_kill_unix.go Normal file
View File

@@ -0,0 +1,45 @@
//go:build unix
package cli
import (
"fmt"
"os"
"os/exec"
"runtime"
"syscall"
"github.com/rs/zerolog"
)
func selfUninstall(p *prog, logger zerolog.Logger) {
if runtime.GOOS == "linux" {
selfUninstallLinux(p, logger)
}
bin, err := os.Executable()
if err != nil {
logger.Fatal().Err(err).Msg("could not determine executable")
}
args := []string{"uninstall"}
if !deactivationPinNotSet() {
args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin))
}
cmd := exec.Command(bin, args...)
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
if err := cmd.Start(); err != nil {
logger.Fatal().Err(err).Msg("could not start self uninstall command")
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
_ = cmd.Wait()
os.Exit(0)
}
func selfUninstallLinux(p *prog, logger zerolog.Logger) {
if uninstallInvalidCdUID(p, logger, true) {
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
os.Exit(0)
}
}

View File

@@ -28,6 +28,9 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
return &sysV{s}, nil
case s.Platform() == "linux-systemd":
return &systemd{s}, nil
case s.Platform() == "darwin-launchd":
return newLaunchd(s), nil
}
return s, nil
}
@@ -113,7 +116,7 @@ func (s *procd) Status() (service.Status, error) {
return service.StatusRunning, nil
}
// procd wraps a service.Service, and provide status command to
// systemd wraps a service.Service, and provide status command to
// report the status correctly.
type systemd struct {
service.Service
@@ -127,6 +130,29 @@ func (s *systemd) Status() (service.Status, error) {
return s.Service.Status()
}
func newLaunchd(s service.Service) *launchd {
return &launchd{
Service: s,
statusErrMsg: "Permission denied",
}
}
// launchd wraps a service.Service, and provide status command to
// report the status correctly when not running as root on Darwin.
//
// TODO: remove this wrapper once https://github.com/kardianos/service/issues/400 fixed.
type launchd struct {
service.Service
statusErrMsg string
}
func (l *launchd) Status() (service.Status, error) {
if os.Geteuid() != 0 {
return service.StatusUnknown, errors.New(l.statusErrMsg)
}
return l.Service.Status()
}
type task struct {
f func() error
abortOnError bool

View File

@@ -9,3 +9,7 @@ import (
func hasElevatedPrivilege() (bool, error) {
return os.Geteuid() == 0, nil
}
func openLogFile(path string, flags int) (*os.File, error) {
return os.OpenFile(path, flags, os.FileMode(0o600))
}

View File

@@ -1,6 +1,11 @@
package cli
import "golang.org/x/sys/windows"
import (
"os"
"syscall"
"golang.org/x/sys/windows"
)
func hasElevatedPrivilege() (bool, error) {
var sid *windows.SID
@@ -22,3 +27,55 @@ func hasElevatedPrivilege() (bool, error) {
token := windows.Token(0)
return token.IsMember(sid)
}
func openLogFile(path string, mode int) (*os.File, error) {
if len(path) == 0 {
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
}
pathP, err := syscall.UTF16PtrFromString(path)
if err != nil {
return nil, err
}
var access uint32
switch mode & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) {
case os.O_RDONLY:
access = windows.GENERIC_READ
case os.O_WRONLY:
access = windows.GENERIC_WRITE
case os.O_RDWR:
access = windows.GENERIC_READ | windows.GENERIC_WRITE
}
if mode&os.O_CREATE != 0 {
access |= windows.GENERIC_WRITE
}
if mode&os.O_APPEND != 0 {
access &^= windows.GENERIC_WRITE
access |= windows.FILE_APPEND_DATA
}
shareMode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
var sa *syscall.SecurityAttributes
var createMode uint32
switch {
case mode&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL):
createMode = windows.CREATE_NEW
case mode&(os.O_CREATE|os.O_TRUNC) == (os.O_CREATE | os.O_TRUNC):
createMode = windows.CREATE_ALWAYS
case mode&os.O_CREATE == os.O_CREATE:
createMode = windows.OPEN_ALWAYS
case mode&os.O_TRUNC == os.O_TRUNC:
createMode = windows.TRUNCATE_EXISTING
default:
createMode = windows.OPEN_EXISTING
}
handle, err := syscall.CreateFile(pathP, access, shareMode, sa, createMode, syscall.FILE_ATTRIBUTE_NORMAL, 0)
if err != nil {
return nil, &os.PathError{Path: path, Op: "open", Err: err}
}
return os.NewFile(uintptr(handle), path), nil
}

View File

@@ -25,6 +25,7 @@ import (
"github.com/go-playground/validator/v10"
"github.com/miekg/dns"
"github.com/spf13/viper"
"golang.org/x/net/http2"
"golang.org/x/sync/singleflight"
"tailscale.com/logtail/backoff"
"tailscale.com/net/tsaddr"
@@ -188,27 +189,30 @@ func (c *Config) FirstUpstream() *UpstreamConfig {
// ServiceConfig specifies the general ctrld config.
type ServiceConfig struct {
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
Daemon bool `mapstructure:"-" toml:"-"`
AllocateIP bool `mapstructure:"-" toml:"-"`
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
DnsWatchdogEnabled *bool `mapstructure:"dns_watchdog_enabled" toml:"dns_watchdog_enabled,omitempty"`
DnsWatchdogInvterval *time.Duration `mapstructure:"dns_watchdog_interval" toml:"dns_watchdog_interval,omitempty"`
RefetchTime *int `mapstructure:"refetch_time" toml:"refetch_time,omitempty"`
Daemon bool `mapstructure:"-" toml:"-"`
AllocateIP bool `mapstructure:"-" toml:"-"`
}
// NetworkConfig specifies configuration for networks where ctrld will handle requests.
@@ -316,7 +320,7 @@ func (uc *UpstreamConfig) Init() {
}
}
if uc.IPStack == "" {
if uc.isControlD() {
if uc.IsControlD() {
uc.IPStack = IpStackSplit
} else {
uc.IPStack = IpStackBoth
@@ -354,7 +358,7 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
}
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
if uc.isControlD() || uc.isNextDNS() {
if uc.IsControlD() || uc.isNextDNS() {
return true
}
}
@@ -401,7 +405,7 @@ func (uc *UpstreamConfig) UID() string {
// The first usable IP will be used as bootstrap IP of the upstream.
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
isControlD := uc.isControlD()
isControlD := uc.IsControlD()
for {
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
@@ -486,6 +490,13 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
ClientSessionCache: tls.NewLRUClientSessionCache(0),
}
// Prevent bad tcp connection hanging the requests for too long.
// See: https://github.com/golang/go/issues/36026
if t2, err := http2.ConfigureTransports(transport); err == nil {
t2.ReadIdleTimeout = 10 * time.Second
t2.PingTimeout = 5 * time.Second
}
dialerTimeoutMs := 2000
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
dialerTimeoutMs = uc.Timeout
@@ -572,7 +583,8 @@ func (uc *UpstreamConfig) ping() error {
return nil
}
func (uc *UpstreamConfig) isControlD() bool {
// IsControlD reports whether this is a ControlD upstream.
func (uc *UpstreamConfig) IsControlD() bool {
domain := uc.Domain
if domain == "" {
if u, err := url.Parse(uc.Endpoint); err == nil {

View File

@@ -5,6 +5,7 @@ import (
"os"
"strings"
"testing"
"time"
"github.com/go-playground/validator/v10"
"github.com/spf13/viper"
@@ -22,6 +23,8 @@ func TestLoadConfig(t *testing.T) {
assert.Equal(t, "info", cfg.Service.LogLevel)
assert.Equal(t, "/path/to/log.log", cfg.Service.LogPath)
assert.Equal(t, false, *cfg.Service.DnsWatchdogEnabled)
assert.Equal(t, time.Duration(20*time.Second), *cfg.Service.DnsWatchdogInvterval)
assert.Len(t, cfg.Network, 2)
assert.Contains(t, cfg.Network, "0")

View File

@@ -8,7 +8,7 @@
# - Non-cgo ctrld binary.
#
# CI_COMMIT_TAG is used to set the version of ctrld binary.
FROM golang:1.20-bullseye as base
FROM golang:bullseye as base
WORKDIR /app

View File

@@ -252,6 +252,35 @@ Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus
- Required: no
- Default: ""
### dns_watchdog_enabled
Checking DNS changes to network interfaces and reverting to ctrld's own settings.
The DNS watchdog process only runs on Windows and MacOS.
- Type: boolean
- Required: no
- Default: true
### dns_watchdog_interval
Time duration between each DNS watchdog iteration.
A duration string is a possibly signed sequence of decimal numbers, each with optional fraction and a unit suffix,
such as "300ms", "-1.5h" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
If the time duration is non-positive, default value will be used.
- Type: time duration string
- Required: no
- Default: 20s
### refetch_time
Time in seconds between each iteration that reloads custom config if changed.
The value must be a positive number, any invalid value will be ignored and default value will be used.
- Type: number
- Required: no
- Default: 3600
## Upstream
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.

2
doh.go
View File

@@ -147,7 +147,7 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != ""
switch {
case uc.isControlD():
case uc.IsControlD():
dohHeader = newControlDHeaders(ci)
case uc.isNextDNS():
dohHeader = newNextDNSHeaders(ci)

View File

@@ -1,18 +0,0 @@
//go:build qf
package ctrld
import (
"context"
"errors"
"github.com/miekg/dns"
)
type doqResolver struct {
uc *UpstreamConfig
}
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
return nil, errors.New("DoQ is not supported")
}

2
dot.go
View File

@@ -18,7 +18,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
// dns.controld.dev first. By using a dialer with custom resolver,
// we ensure that we can always resolve the bootstrap domain
// regardless of the machine DNS status.
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
dnsTyp := uint16(0)
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype

View File

@@ -122,8 +122,8 @@ func (m *mdns) probeLoop(conns []*net.UDPConn, remoteAddr net.Addr, quitCh chan
bo := backoff.NewBackoff("mdns probe", func(format string, args ...any) {}, time.Second*30)
for {
err := m.probe(conns, remoteAddr)
if isErrNetUnreachableOrInvalid(err) {
ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: network unreachable or invalid", remoteAddr)
if shouldStopProbing(err) {
ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: %v", remoteAddr, err)
break
}
if err != nil {
@@ -165,7 +165,7 @@ func (m *mdns) readLoop(conn *net.UDPConn) {
}
var ip, name string
rrs := make([]dns.RR, 0, len(msg.Answer)+len(msg.Extra))
var rrs []dns.RR
rrs = append(rrs, msg.Answer...)
rrs = append(rrs, msg.Extra...)
for _, rr := range rrs {
@@ -273,10 +273,14 @@ func multicastInterfaces() ([]net.Interface, error) {
return interfaces, nil
}
func isErrNetUnreachableOrInvalid(err error) bool {
// shouldStopProbing reports whether ctrld should stop probing mdns.
func shouldStopProbing(err error) bool {
var se *os.SyscallError
if errors.As(err, &se) {
return se.Err == syscall.ENETUNREACH || se.Err == syscall.EINVAL
switch se.Err {
case syscall.ENETUNREACH, syscall.EINVAL, syscall.EPERM:
return true
}
}
return false
}

View File

@@ -26,14 +26,15 @@ const (
apiDomainDev = "api.controld.dev"
resolverDataURLCom = "https://api.controld.com/utility"
resolverDataURLDev = "https://api.controld.dev/utility"
InvalidConfigCode = 40401
InvalidConfigCode = 40402
)
// ResolverConfig represents Control D resolver data.
type ResolverConfig struct {
DOH string `json:"doh"`
Ctrld struct {
CustomConfig string `json:"custom_config"`
CustomConfig string `json:"custom_config"`
CustomLastUpdate int64 `json:"custom_last_update"`
} `json:"ctrld"`
Exclude []string `json:"exclude"`
UID string `json:"uid"`
@@ -76,17 +77,28 @@ func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, e
req.ClientID = clientID
}
body, _ := json.Marshal(req)
return postUtilityAPI(version, cdDev, bytes.NewReader(body))
return postUtilityAPI(version, cdDev, false, bytes.NewReader(body))
}
// FetchResolverUID fetch resolver uid from provision token.
func FetchResolverUID(pt, version string, cdDev bool) (*ResolverConfig, error) {
hostname, _ := os.Hostname()
body, _ := json.Marshal(utilityOrgRequest{ProvToken: pt, Hostname: hostname})
return postUtilityAPI(version, cdDev, bytes.NewReader(body))
return postUtilityAPI(version, cdDev, false, bytes.NewReader(body))
}
func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig, error) {
// UpdateCustomLastFailed calls API to mark custom config is bad.
func UpdateCustomLastFailed(rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) {
uid, clientID := ParseRawUID(rawUID)
req := utilityRequest{UID: uid}
if clientID != "" {
req.ClientID = clientID
}
body, _ := json.Marshal(req)
return postUtilityAPI(version, cdDev, true, bytes.NewReader(body))
}
func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) {
apiUrl := resolverDataURLCom
if cdDev {
apiUrl = resolverDataURLDev
@@ -98,6 +110,9 @@ func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig
q := req.URL.Query()
q.Set("platform", "ctrld")
q.Set("version", version)
if lastUpdatedFailed {
q.Set("custom_last_failed", "1")
}
req.URL.RawQuery = q.Encode()
req.Header.Add("Content-Type", "application/json")
transport := http.DefaultTransport.(*http.Transport).Clone()

View File

@@ -7,6 +7,7 @@ import (
"os"
"os/exec"
"path/filepath"
"strings"
"sync/atomic"
"github.com/kardianos/service"
@@ -164,6 +165,16 @@ func HomeDir() (string, error) {
return "", err
}
return filepath.Dir(exe), nil
case edgeos.Name:
exe, err := os.Executable()
if err != nil {
return "", err
}
// Using binary directory as home dir if it is located in /config.
// Otherwise, fallback to old behavior for compatibility.
if strings.HasPrefix(exe, "/config/") {
return filepath.Dir(exe), nil
}
}
return "", nil
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"sync"
"time"
@@ -30,18 +31,19 @@ const (
ResolverTypePrivate = "private"
)
const bootstrapDNS = "76.76.2.22"
const (
controldBootstrapDns = "76.76.2.22"
controldPublicDns = "76.76.2.0"
)
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
// or is the Resolver used for ResolverTypeOS.
var or = &osResolver{nameservers: defaultNameservers()}
// defaultNameservers returns nameservers used by the OS.
// If no nameservers can be found, ctrld bootstrap nameserver will be used.
// defaultNameservers returns OS nameservers plus ControlD public DNS.
func defaultNameservers() []string {
ns := nameservers()
if len(ns) == 0 {
ns = append(ns, net.JoinHostPort(bootstrapDNS, "53"))
}
return ns
}
@@ -51,10 +53,27 @@ func defaultNameservers() []string {
// It's the caller's responsibility to ensure the system DNS is in a clean state before
// calling this function.
func InitializeOsResolver() []string {
or.nameservers = defaultNameservers()
or.nameservers = or.nameservers[:0]
for _, ns := range defaultNameservers() {
if testNameserver(ns) {
or.nameservers = append(or.nameservers, ns)
}
}
or.nameservers = append(or.nameservers, controldPublicDnsWithPort)
return or.nameservers
}
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
func testNameserver(addr string) bool {
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
client := new(dns.Client)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, _, err := client.ExchangeContext(ctx, msg, addr)
return err == nil
}
// Resolver is the interface that wraps the basic DNS operations.
//
// Resolve resolves the DNS query, return the result and the corresponding error.
@@ -89,8 +108,9 @@ type osResolver struct {
}
type osResolverResult struct {
answer *dns.Msg
err error
answer *dns.Msg
err error
isControlDPublicDNS bool
}
// Resolve resolves DNS queries using pre-configured nameservers.
@@ -116,19 +136,34 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
go func(server string) {
defer wg.Done()
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
ch <- &osResolverResult{answer: answer, err: err}
ch <- &osResolverResult{answer: answer, err: err, isControlDPublicDNS: server == controldPublicDnsWithPort}
}(server)
}
var (
nonSuccessAnswer *dns.Msg
controldSuccessAnswer *dns.Msg
)
errs := make([]error, 0, numServers)
for res := range ch {
if res.err == nil {
cancel()
return res.answer, res.err
switch {
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
if res.isControlDPublicDNS {
controldSuccessAnswer = res.answer // only use ControlD answer as last one.
} else {
cancel()
return res.answer, nil
}
case res.answer != nil:
nonSuccessAnswer = res.answer
}
errs = append(errs, res.err)
}
for _, answer := range []*dns.Msg{controldSuccessAnswer, nonSuccessAnswer} {
if answer != nil {
return answer, nil
}
}
return nil, errors.Join(errs...)
}
@@ -138,7 +173,7 @@ type legacyResolver struct {
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
// See comment in (*dotResolver).resolve method.
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
dnsTyp := uint16(0)
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
@@ -176,7 +211,7 @@ func LookupIP(domain string) []string {
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
resolver := &osResolver{nameservers: nameservers()}
if withBootstrapDNS {
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...)
}
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
timeoutMs := 2000
@@ -252,7 +287,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
// - Input servers.
func NewBootstrapResolver(servers ...string) Resolver {
resolver := &osResolver{nameservers: nameservers()}
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
resolver.nameservers = append([]string{controldPublicDnsWithPort}, resolver.nameservers...)
for _, ns := range servers {
resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...)
}
@@ -279,11 +314,11 @@ func NewPrivateResolver() Resolver {
// - Direct listener that has ctrld as an upstream (e.g: dnsmasq).
//
// causing the query always succeed.
if sliceContains(resolveConfNss, host) {
if slices.Contains(resolveConfNss, host) {
continue
}
// Ignoring local RFC 1918 addresses.
if sliceContains(localRfc1918Addrs, host) {
if slices.Contains(localRfc1918Addrs, host) {
continue
}
ip := net.ParseIP(host)
@@ -335,20 +370,3 @@ func newDialer(dnsAddress string) *net.Dialer {
},
}
}
// TODO(cuonglm): use slices.Contains once upgrading to go1.21
// sliceContains reports whether v is present in s.
func sliceContains[S ~[]E, E comparable](s S, v E) bool {
return sliceIndex(s, v) >= 0
}
// sliceIndex returns the index of the first occurrence of v in s,
// or -1 if not present.
func sliceIndex[S ~[]E, E comparable](s S, v E) int {
for i := range s {
if v == s[i] {
return i
}
}
return -1
}

View File

@@ -2,6 +2,8 @@ package ctrld
import (
"context"
"net"
"sync"
"testing"
"time"
@@ -28,6 +30,57 @@ func Test_osResolver_Resolve(t *testing.T) {
}
}
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
ns := make([]string, 0, 2)
servers := make([]*dns.Server, 0, 2)
successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(msg, dns.RcodeSuccess)
w.WriteMsg(m)
})
nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc {
return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(msg, rcode)
w.WriteMsg(m)
})
}
handlers := []dns.Handler{
nonSuccessHandlerWithRcode(dns.RcodeRefused),
nonSuccessHandlerWithRcode(dns.RcodeNameError),
successHandler,
}
for i := range handlers {
pc, err := net.ListenPacket("udp", ":0")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i])
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ns = append(ns, addr)
servers = append(servers, s)
}
defer func() {
for _, server := range servers {
server.Shutdown()
}
}()
resolver := &osResolver{nameservers: ns}
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
answer, err := resolver.Resolve(context.Background(), msg)
if err != nil {
t.Fatal(err)
}
if answer.Rcode != dns.RcodeSuccess {
t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode])
}
}
func Test_upstreamTypeFromEndpoint(t *testing.T) {
tests := []struct {
name string
@@ -51,3 +104,33 @@ func Test_upstreamTypeFromEndpoint(t *testing.T) {
})
}
}
func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.Handler, opts ...func(*dns.Server)) (*dns.Server, string, error) {
t.Helper()
server := &dns.Server{
PacketConn: pc,
ReadTimeout: time.Hour,
WriteTimeout: time.Hour,
Handler: handler,
}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
for _, opt := range opts {
opt(server)
}
addr, closer := pc.LocalAddr().String(), pc
go func() {
if err := server.ActivateAndServe(); err != nil {
t.Error(err)
}
closer.Close()
}()
waitLock.Lock()
return server, addr, nil
}

View File

@@ -27,6 +27,8 @@ var sampleConfigContent = `
[service]
log_level = "info"
log_path = "/path/to/log.log"
dns_watchdog_enabled = false
dns_watchdog_interval = "20s"
[network.0]
name = "Home Wifi"