diff --git a/README.md b/README.md index ef98c38..66e70c3 100644 --- a/README.md +++ b/README.md @@ -105,9 +105,11 @@ Available Commands: start Quick start service and configure DNS on interface stop Quick stop service and remove DNS from interface restart Restart the ctrld service + reload Reload the ctrld service status Show status of the ctrld service uninstall Stop and uninstall the ctrld service clients Manage clients + upgrade Upgrading ctrld to latest version Flags: -h, --help help for ctrld diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index bdae37a..987a470 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -48,6 +48,11 @@ import ( // selfCheckInternalTestDomain is used for testing ctrld self response to clients. const selfCheckInternalTestDomain = "ctrld" + loopTestDomain +const ( + windowsForwardersFilename = ".forwarders.txt" + oldBinSuffix = "_previous" + oldLogSuffix = ".1" +) var ( version = "dev" @@ -110,7 +115,7 @@ func initCLI() { &verbose, "verbose", "v", - `verbose log output, "-v" basic logging, "-vv" debug level logging`, + `verbose log output, "-v" basic logging, "-vv" debug logging`, ) rootCmd.PersistentFlags().BoolVarP( &silent, @@ -158,7 +163,10 @@ func initCLI() { }, Use: "start", Short: "Install and start the ctrld service", - Args: cobra.NoArgs, + Long: `Install and start the ctrld service + +NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, + Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { checkStrFlagEmpty(cmd, cdUidFlagName) checkStrFlagEmpty(cmd, cdOrgFlagName) @@ -182,8 +190,9 @@ func initCLI() { return } - status, _ := s.Status() + status, err := s.Status() isCtrldRunning := status == service.StatusRunning + isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) // If pin code was set, do not allow running start command. if isCtrldRunning { @@ -192,39 +201,56 @@ func initCLI() { } } - if cdUID != "" { - rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) - if err != nil { - mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) + if !startOnly { + startOnly = len(osArgs) == 0 + } + // If user run "ctrld start" and ctrld is already installed, starting existing service. + if startOnly && isCtrldInstalled { + tryReadingConfigWithNotice(false, true) + if err := v.Unmarshal(&cfg); err != nil { + mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) } - // validateCdRemoteConfig clobbers v, saving it here to restore later. - oldV := v - if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil { - if errors.As(err, &viper.ConfigParseError{}) { - if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 { - tmpDir := os.TempDir() - tmpConfFile := filepath.Join(tmpDir, "ctrld.toml") - errorLogged := false - // Write remote config to a temporary file to get details error. - if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil { - if de := decoderErrorFromTomlFile(tmpConfFile); de != nil { - row, col := de.Position() - mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error()) - errorLogged = true - } - _ = os.Remove(tmpConfFile) - } - // If we could not log details error, emit what we have already got. - if !errorLogged { - mainLog.Load().Error().Msgf("failed to parse custom config: %v", err) - } - } - } else { - mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err) + + initLogging() + tasks := []task{ + resetDnsTask(p, s), + {s.Stop, false}, + {func() error { + // Save current DNS so we can restore later. + withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error { + return saveCurrentStaticDNS(i) + }) + return nil + }, false}, + {s.Start, true}, + {noticeWritingControlDConfig, false}, + } + mainLog.Load().Notice().Msg("Starting existing ctrld service") + if doTasks(tasks) { + mainLog.Load().Notice().Msg("Service started") + sockDir, err := socketDir() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("Failed to get socket directory") + os.Exit(1) } - mainLog.Load().Warn().Msg("disregarding invalid custom config") + if cc := newSocketControlClient(s, sockDir); cc != nil { + if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK { + if iface == "auto" { + iface = defaultIfaceName() + } + logger := mainLog.Load().With().Str("iface", iface).Logger() + logger.Debug().Msg("setting DNS successfully") + } + } + } else { + mainLog.Load().Error().Err(err).Msg("Failed to start existing ctrld service") + os.Exit(1) } - v = oldV + return + } + + if cdUID != "" { + doValidateCdRemoteConfig(cdUID) } else if uid := cdUIDFromProvToken(); uid != "" { cdUID = uid mainLog.Load().Debug().Msg("using uid from provision token") @@ -399,6 +425,8 @@ func initCLI() { startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`) + startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service") + _ = startCmd.Flags().MarkHidden("start_only") routerCmd := &cobra.Command{ Use: "setup", @@ -485,7 +513,10 @@ func initCLI() { Run: func(cmd *cobra.Command, args []string) { readConfig(false) v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, runInCdMode())} + cdUID = curCdUID() + cdMode := cdUID != "" + + p := &prog{router: router.New(&cfg, cdMode)} s, err := newService(p, svcConfig) if err != nil { mainLog.Load().Error().Msg(err.Error()) @@ -497,6 +528,10 @@ func initCLI() { } initLogging() + if cdMode { + doValidateCdRemoteConfig(cdUID) + } + iface = runningIface(s) tasks := []task{ {s.Stop, false}, @@ -623,11 +658,72 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, os.Exit(deactivationPinInvalidExitCode) } uninstall(p, s) + if cleanup { + var files []string + // Config file. + files = append(files, v.ConfigFileUsed()) + // Log file. + logFile := normalizeLogFilePath(cfg.Service.LogPath) + files = append(files, logFile) + // Backup log file. + oldLogFile := logFile + oldLogSuffix + if _, err := os.Stat(oldLogFile); err == nil { + files = append(files, oldLogFile) + } + // Socket files. + if dir, _ := socketDir(); dir != "" { + files = append(files, filepath.Join(dir, ctrldControlUnixSock)) + files = append(files, filepath.Join(dir, ctrldLogUnixSock)) + } + // Static DNS settings files. + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + file := savedStaticDnsSettingsFilePath(i) + if _, err := os.Stat(file); err == nil { + files = append(files, file) + } + return nil + }) + // Windows forwarders file. + if windowsHasLocalDnsServerRunning() { + files = append(files, absHomeDir(windowsForwardersFilename)) + } + // Binary itself. + bin, _ := os.Executable() + if bin != "" && supportedSelfDelete { + files = append(files, bin) + } + // Backup file after upgrading. + oldBin := bin + oldBinSuffix + if _, err := os.Stat(oldBin); err == nil { + files = append(files, oldBin) + } + for _, file := range files { + if file == "" { + continue + } + if err := os.Remove(file); err != nil { + if os.IsNotExist(err) { + continue + } + mainLog.Load().Warn().Err(err).Msg("failed to remove file") + } else { + mainLog.Load().Debug().Msgf("file removed: %s", file) + } + } + if err := selfDeleteExe(); err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to remove file") + } else { + if !supportedSelfDelete { + mainLog.Load().Debug().Msgf("file removed: %s", bin) + } + } + } }, } uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`) uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for uninstalling ctrld`) _ = uninstallCmd.Flags().MarkHidden("pin") + uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`) listIfacesCmd := &cobra.Command{ Use: "list", @@ -697,7 +793,13 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, }, Use: "start", Short: "Quick start service and configure DNS on interface", + Long: `Quick start service and configure DNS on interface + +NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, Run: func(cmd *cobra.Command, args []string) { + if len(os.Args) == 2 { + startOnly = true + } if !cmd.Flags().Changed("iface") { os.Args = append(os.Args, "--iface="+ifaceStartStop) } @@ -776,7 +878,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, }, } uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) - uninstallCmdAlias.Flags().AddFlagSet(stopCmd.Flags()) + uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags()) rootCmd.AddCommand(uninstallCmdAlias) listClientsCmd := &cobra.Command{ @@ -894,7 +996,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { svcInstalled = false } - oldBin := bin + "_previous" + oldBin := bin + oldBinSuffix baseUrl := upgradeChannel[upgradeChannelDefault] if len(args) > 0 { channel := args[0] @@ -1033,12 +1135,14 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } waitCh := make(chan struct{}) p := &prog{ - waitCh: waitCh, - stopCh: stopCh, - reloadCh: make(chan struct{}), - reloadDoneCh: make(chan struct{}), - cfg: &cfg, - appCallback: appCallback, + waitCh: waitCh, + stopCh: stopCh, + reloadCh: make(chan struct{}), + reloadDoneCh: make(chan struct{}), + dnsWatcherStopCh: make(chan struct{}), + apiReloadCh: make(chan *ctrld.Config), + cfg: &cfg, + appCallback: appCallback, } if homedir == "" { if dir, err := userHomeDir(); err == nil { @@ -1128,36 +1232,13 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { return } - uninstallIfInvalidCdUID := func() { - cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() - if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode { - s, err := newService(&prog{}, svcConfig) - if err != nil { - cdLogger.Warn().Err(err).Msg("failed to create new service") - return - } - if netIface, _ := netInterface(iface); netIface != nil { - if err := restoreNetworkManager(); err != nil { - cdLogger.Error().Err(err).Msg("could not restore NetworkManager") - return - } - cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface") - if err := resetDNS(netIface); err != nil { - cdLogger.Warn().Err(err).Msg("something went wrong while restoring DNS") - } else { - cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully") - } - } - - tasks := []task{{s.Uninstall, true}} - if doTasks(tasks) { - cdLogger.Info().Msg("uninstalled service") - } - cdLogger.Fatal().Err(uer).Msg("failed to fetch resolver config") - return - } + cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() + // Performs self-uninstallation if the ControlD device does not exist. + var uer *controld.UtilityErrorResponse + if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { + _ = uninstallInvalidCdUID(p, cdLogger, false) } - uninstallIfInvalidCdUID() + cdLogger.Fatal().Err(err).Msg("failed to fetch resolver config") } } @@ -1168,7 +1249,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } if updated { - if err := writeConfigFile(); err != nil { + if err := writeConfigFile(&cfg); err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to write config file") } else { mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile) @@ -1257,12 +1338,16 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { close(waitCh) <-stopCh + + // Wait goroutines which watches/manipulates DNS settings terminated, + // ensuring that changes to DNS since here won't be reverted. + p.dnsWg.Wait() for _, f := range p.onStopped { f() } } -func writeConfigFile() error { +func writeConfigFile(cfg *ctrld.Config) error { if cfu := v.ConfigFileUsed(); cfu != "" { defaultConfigFile = cfu } else if configPath != "" { @@ -1315,7 +1400,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { } nop := zerolog.Nop() _, _ = tryUpdateListenerConfig(&cfg, &nop, true) - if err := writeConfigFile(); err != nil { + if err := writeConfigFile(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) } else { fp, err := filepath.Abs(defaultConfigFile) @@ -1639,9 +1724,10 @@ func selfCheckStatus(s service.Service, homedir, sockDir string) (bool, service. } v = viper.NewWithOptions(viper.KeyDelimiter("::")) - ctrld.SetConfigNameWithPath(v, "ctrld", homedir) if configPath != "" { v.SetConfigFile(configPath) + } else { + v.SetConfigFile(defaultConfigFile) } if err := v.ReadInConfig(); err != nil { mainLog.Load().Error().Err(err).Msgf("failed to re-read configuration file: %s", v.ConfigFileUsed()) @@ -2375,7 +2461,7 @@ func doGenerateNextDNSConfig(uid string) error { mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile) generateNextDNSConfig(uid) updateListenerConfig(&cfg) - return writeConfigFile() + return writeConfigFile(&cfg) } func noticeWritingControlDConfig() error { @@ -2423,7 +2509,7 @@ func checkDeactivationPin(s service.Service, stopCh chan struct{}) error { return nil // the server is running older version of ctrld } } - mainLog.Load().Error().Msg(errInvalidDeactivationPin.Error()) + mainLog.Load().Error().Err(err).Msg(errInvalidDeactivationPin.Error()) return errInvalidDeactivationPin } @@ -2482,6 +2568,11 @@ func absHomeDir(filename string) string { // runInCdMode reports whether ctrld service is running in cd mode. func runInCdMode() bool { + return curCdUID() != "" +} + +// curCdUID returns the current ControlD UID used by running ctrld process. +func curCdUID() string { if s, _ := newService(&prog{}, svcConfig); s != nil { if dir, _ := socketDir(); dir != "" { cc := newSocketControlClient(s, dir) @@ -2489,12 +2580,13 @@ func runInCdMode() bool { resp, _ := cc.post(cdPath, nil) if resp != nil { defer resp.Body.Close() - return resp.StatusCode == http.StatusOK + buf, _ := io.ReadAll(resp.Body) + return string(buf) } } } } - return false + return "" } // goArm returns the GOARM value for the binary. @@ -2557,6 +2649,9 @@ func resetDnsTask(p *prog, s service.Service) task { isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) isCtrldRunning := status == service.StatusRunning return task{func() error { + if iface == "" { + return nil + } // Always reset DNS first, ensuring DNS setting is in a good state. // resetDNS must use the "iface" value of current running ctrld // process to reset what setDNS has done properly. @@ -2572,3 +2667,60 @@ func resetDnsTask(p *prog, s service.Service) task { return nil }, false} } + +// doValidateCdRemoteConfig fetches and validates custom config for cdUID. +func doValidateCdRemoteConfig(cdUID string) { + rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + if err != nil { + mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) + } + // validateCdRemoteConfig clobbers v, saving it here to restore later. + oldV := v + if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil { + if errors.As(err, &viper.ConfigParseError{}) { + if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 { + tmpDir := os.TempDir() + tmpConfFile := filepath.Join(tmpDir, "ctrld.toml") + errorLogged := false + // Write remote config to a temporary file to get details error. + if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil { + if de := decoderErrorFromTomlFile(tmpConfFile); de != nil { + row, col := de.Position() + mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error()) + errorLogged = true + } + _ = os.Remove(tmpConfFile) + } + // If we could not log details error, emit what we have already got. + if !errorLogged { + mainLog.Load().Error().Msgf("failed to parse custom config: %v", err) + } + } + } else { + mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err) + } + mainLog.Load().Warn().Msg("disregarding invalid custom config") + } + v = oldV +} + +// uninstallInvalidCdUID performs self-uninstallation because the ControlD device does not exist. +func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool { + s, err := newService(p, svcConfig) + if err != nil { + logger.Warn().Err(err).Msg("failed to create new service") + return false + } + + p.resetDNS() + + tasks := []task{{s.Uninstall, true}} + if doTasks(tasks) { + logger.Info().Msg("uninstalled service") + if doStop { + _ = s.Stop() + } + return true + } + return false +} diff --git a/cmd/cli/cli_test.go b/cmd/cli/cli_test.go index fcede32..eae2673 100644 --- a/cmd/cli/cli_test.go +++ b/cmd/cli/cli_test.go @@ -16,7 +16,7 @@ func Test_writeConfigFile(t *testing.T) { _, err := os.Stat(configPath) assert.True(t, os.IsNotExist(err)) - assert.NoError(t, writeConfigFile()) + assert.NoError(t, writeConfigFile(&cfg)) _, err = os.Stat(configPath) require.NoError(t, err) diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 66a38a3..f69c301 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -73,7 +73,7 @@ func (p *prog) registerControlServerHandler() { sort.Slice(clients, func(i, j int) bool { return clients[i].IP.Less(clients[j].IP) }) - if p.cfg.Service.MetricsQueryStats { + if p.metricsQueryStats.Load() { for _, client := range clients { client.IncludeQueryCount = true dm := &dto.Metric{} @@ -178,6 +178,7 @@ func (p *prog) registerControlServerHandler() { p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { if cdUID != "" { w.WriteHeader(http.StatusOK) + w.Write([]byte(cdUID)) return } w.WriteHeader(http.StatusBadRequest) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 9f95812..a7c62af 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -21,6 +21,7 @@ import ( "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/controld" "github.com/Control-D-Inc/ctrld/internal/dnscache" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) @@ -32,6 +33,9 @@ const ( // https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81 // This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification. EDNS0_OPTION_MAC = 0xFDE9 + + // selfUninstallMaxQueries is number of REFUSED queries seen before checking for self-uninstallation. + selfUninstallMaxQueries = 32 ) var osUpstreamConfig = &ctrld.UpstreamConfig{ @@ -89,6 +93,7 @@ func (p *prog) serveDNS(listenerNum string) error { _ = w.WriteMsg(answer) return } + listenerConfig := p.cfg.Listener[listenerNum] reqId := requestID() ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { @@ -143,6 +148,7 @@ func (p *prog) serveDNS(listenerNum string) error { failoverRcodes: failoverRcode, ufr: ur, }) + go p.doSelfUninstall(pr.answer) answer = pr.answer rtt := time.Since(t) ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) @@ -836,6 +842,51 @@ func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) { } } +// doSelfUninstall performs self-uninstall if these condition met: +// +// - There is only 1 ControlD upstream in-use. +// - Number of refused queries seen so far equals to selfUninstallMaxQueries. +// - The cdUID is deleted. +func (p *prog) doSelfUninstall(answer *dns.Msg) { + if !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused { + return + } + + p.selfUninstallMu.Lock() + defer p.selfUninstallMu.Unlock() + if p.checkingSelfUninstall { + return + } + + logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger() + if p.refusedQueryCount > selfUninstallMaxQueries { + p.checkingSelfUninstall = true + _, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + logger.Debug().Msg("maximum number of refused queries reached, checking device status") + selfUninstallCheck(err, p, logger) + + if err != nil { + logger.Warn().Err(err).Msg("could not fetch resolver config") + } + // Cool-of period to prevent abusing the API. + go p.selfUninstallCoolOfPeriod() + return + } + p.refusedQueryCount++ +} + +// selfUninstallCoolOfPeriod waits for 30 minutes before +// calling API again for checking ControlD device status. +func (p *prog) selfUninstallCoolOfPeriod() { + t := time.NewTimer(time.Minute * 30) + defer t.Stop() + <-t.C + p.selfUninstallMu.Lock() + p.checkingSelfUninstall = false + p.refusedQueryCount = 0 + p.selfUninstallMu.Unlock() +} + // queryFromSelf reports whether the input IP is from device running ctrld. func queryFromSelf(ip string) bool { netIP := netip.MustParseAddr(ip) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 279f5f2..146c58d 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -36,6 +36,8 @@ var ( cdUpstreamProto string deactivationPin int64 skipSelfChecks bool + cleanup bool + startOnly bool mainLog atomic.Pointer[zerolog.Logger] consoleWriter zerolog.ConsoleWriter @@ -63,8 +65,11 @@ func Main() { } func normalizeLogFilePath(logFilePath string) string { - if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() { - return logFilePath + // In cleanup mode, we always want the full log file path. + if !cleanup { + if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() { + return logFilePath + } } if homedir != "" { return filepath.Join(homedir, logFilePath) @@ -121,14 +126,14 @@ func initLoggingWithBackup(doBackup bool) { flags := os.O_CREATE | os.O_RDWR | os.O_APPEND if doBackup { // Backup old log file with .1 suffix. - if err := os.Rename(logFilePath, logFilePath+".1"); err != nil && !os.IsNotExist(err) { + if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) { mainLog.Load().Error().Msgf("could not backup old log file: %v", err) } else { // Backup was created, set flags for truncating old log file. flags = os.O_CREATE | os.O_RDWR } } - logFile, err := os.OpenFile(logFilePath, flags, os.FileMode(0o600)) + logFile, err := openLogFile(logFilePath, flags) if err != nil { mainLog.Load().Error().Msgf("failed to create log file: %v", err) os.Exit(1) diff --git a/cmd/cli/metrics.go b/cmd/cli/metrics.go index ee64975..565cdcc 100644 --- a/cmd/cli/metrics.go +++ b/cmd/cli/metrics.go @@ -107,7 +107,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) { reg := prometheus.NewRegistry() // Register queries count stats if enabled. - if cfg.Service.MetricsQueryStats { + if p.metricsQueryStats.Load() { reg.MustRegister(statsQueriesCount) reg.MustRegister(statsClientQueriesCount) } diff --git a/cmd/cli/net_darwin.go b/cmd/cli/net_darwin.go index 37f8d7b..b58a0bf 100644 --- a/cmd/cli/net_darwin.go +++ b/cmd/cli/net_darwin.go @@ -43,20 +43,32 @@ func networkServiceName(ifaceName string, r io.Reader) string { return "" } -// validInterface reports whether the *net.Interface is a valid one, which includes: -// -// - en0: physical wireless -// - en1: Thunderbolt 1 -// - en2: Thunderbolt 2 -// - en3: Thunderbolt 3 -// - en4: Thunderbolt 4 -// -// For full list, see: https://unix.stackexchange.com/questions/603506/what-are-these-ifconfig-interfaces-on-macos -func validInterface(iface *net.Interface) bool { - switch iface.Name { - case "en0", "en1", "en2", "en3", "en4": - return true - default: - return false - } +// validInterface reports whether the *net.Interface is a valid one. +func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { + _, ok := validIfacesMap[iface.Name] + return ok +} + +func validInterfacesMap() map[string]struct{} { + b, err := exec.Command("networksetup", "-listallhardwareports").Output() + if err != nil { + return nil + } + return parseListAllHardwarePorts(bytes.NewReader(b)) +} + +// parseListAllHardwarePorts parses output of "networksetup -listallhardwareports" +// and returns map presents all hardware ports. +func parseListAllHardwarePorts(r io.Reader) map[string]struct{} { + m := make(map[string]struct{}) + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + after, ok := strings.CutPrefix(line, "Device: ") + if !ok { + continue + } + m[after] = struct{}{} + } + return m } diff --git a/cmd/cli/net_darwin_test.go b/cmd/cli/net_darwin_test.go index 443a9d1..9ef1906 100644 --- a/cmd/cli/net_darwin_test.go +++ b/cmd/cli/net_darwin_test.go @@ -1,6 +1,7 @@ package cli import ( + "maps" "strings" "testing" @@ -57,3 +58,47 @@ func Test_networkServiceName(t *testing.T) { }) } } + +const listallhardwareportsOutput = ` +Hardware Port: Ethernet Adapter (en6) +Device: en6 +Ethernet Address: 3a:3e:fc:1e:ab:41 + +Hardware Port: Ethernet Adapter (en7) +Device: en7 +Ethernet Address: 3a:3e:fc:1e:ab:42 + +Hardware Port: Thunderbolt Bridge +Device: bridge0 +Ethernet Address: 36:21:bb:3a:7a:40 + +Hardware Port: Wi-Fi +Device: en0 +Ethernet Address: a0:78:17:68:56:3f + +Hardware Port: Thunderbolt 1 +Device: en1 +Ethernet Address: 36:21:bb:3a:7a:40 + +Hardware Port: Thunderbolt 2 +Device: en2 +Ethernet Address: 36:21:bb:3a:7a:44 + +VLAN Configurations +=================== +` + +func Test_parseListAllHardwarePorts(t *testing.T) { + expected := map[string]struct{}{ + "en0": {}, + "en1": {}, + "en2": {}, + "en6": {}, + "en7": {}, + "bridge0": {}, + } + m := parseListAllHardwarePorts(strings.NewReader(listallhardwareportsOutput)) + if !maps.Equal(m, expected) { + t.Errorf("unexpected output, want: %v, got: %v", expected, m) + } +} diff --git a/cmd/cli/net_others.go b/cmd/cli/net_others.go index ebe7ba0..5a66e82 100644 --- a/cmd/cli/net_others.go +++ b/cmd/cli/net_others.go @@ -6,4 +6,6 @@ import "net" func patchNetIfaceName(iface *net.Interface) error { return nil } -func validInterface(iface *net.Interface) bool { return true } +func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true } + +func validInterfacesMap() map[string]struct{} { return nil } diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index c75ee32..8ec5a5f 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -10,7 +10,7 @@ func patchNetIfaceName(iface *net.Interface) error { // validInterface reports whether the *net.Interface is a valid one. // On Windows, only physical interfaces are considered valid. -func validInterface(iface *net.Interface) bool { +func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { if iface == nil { return false } @@ -19,3 +19,5 @@ func validInterface(iface *net.Interface) bool { } return false } + +func validInterfacesMap() map[string]struct{} { return nil } diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index a36311d..eff5edf 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -24,6 +24,8 @@ import ( "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) +const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system" + // allocate loopback ip // sudo ip a add 127.0.0.2/24 dev lo func allocateIP(ip string) error { diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 6441e05..e5ac1d2 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -16,7 +16,6 @@ import ( ) const ( - forwardersFilename = ".forwarders.txt" v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` ) @@ -40,7 +39,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { // If there's a Dns server running, that means we are on AD with Dns feature enabled. // Configuring the Dns server to forward queries to ctrld instead. if windowsHasLocalDnsServerRunning() { - file := absHomeDir(forwardersFilename) + file := absHomeDir(windowsForwardersFilename) oldForwardersContent, _ := os.ReadFile(file) if err := os.WriteFile(file, []byte(strings.Join(nameservers, ",")), 0600); err != nil { mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings") @@ -72,7 +71,7 @@ func resetDNS(iface *net.Interface) error { resetDNSOnce.Do(func() { // See corresponding comment in setDNS. if windowsHasLocalDnsServerRunning() { - file := absHomeDir(forwardersFilename) + file := absHomeDir(windowsForwardersFilename) content, err := os.ReadFile(file) if err != nil { mainLog.Load().Error().Err(err).Msg("could not read forwarders settings") diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 8e35575..82daa24 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -12,19 +12,24 @@ import ( "net/url" "os" "runtime" + "slices" "sort" "strconv" "strings" "sync" + "sync/atomic" "syscall" + "time" "github.com/kardianos/service" + "github.com/rs/zerolog" "github.com/spf13/viper" "tailscale.com/net/interfaces" "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/clientinfo" + "github.com/Control-D-Inc/ctrld/internal/controld" "github.com/Control-D-Inc/ctrld/internal/dnscache" "github.com/Control-D-Inc/ctrld/internal/router" ) @@ -38,6 +43,7 @@ const ( upstreamPrefix = "upstream." upstreamOS = upstreamPrefix + "os" upstreamPrivate = upstreamPrefix + "private" + dnsWatchdogDefaultInterval = 20 * time.Second ) // ControlSocketName returns name for control unix socket. @@ -62,15 +68,19 @@ var svcConfig = &service.Config{ var useSystemdResolved = false type prog struct { - mu sync.Mutex - waitCh chan struct{} - stopCh chan struct{} - reloadCh chan struct{} // For Windows. - reloadDoneCh chan struct{} - logConn net.Conn - cs *controlServer - csSetDnsDone chan struct{} - csSetDnsOk bool + mu sync.Mutex + waitCh chan struct{} + stopCh chan struct{} + reloadCh chan struct{} // For Windows. + reloadDoneCh chan struct{} + apiReloadCh chan *ctrld.Config + logConn net.Conn + cs *controlServer + csSetDnsDone chan struct{} + csSetDnsOk bool + dnsWatchDogOnce sync.Once + dnsWg sync.WaitGroup + dnsWatcherStopCh chan struct{} cfg *ctrld.Config localUpstreams []string @@ -84,6 +94,12 @@ type prog struct { router router.Router ptrLoopGuard *loopGuard lanLoopGuard *loopGuard + metricsQueryStats atomic.Bool + + selfUninstallMu sync.Mutex + refusedQueryCount int + canSelfUninstall atomic.Bool + checkingSelfUninstall bool loopMu sync.Mutex loop map[string]bool @@ -117,11 +133,15 @@ func (p *prog) runWait() { p.run(reload, reloadCh) reload = true }() + + var newCfg *ctrld.Config select { case sig := <-reloadSigCh: logger.Notice().Msgf("got signal: %s, reloading...", sig.String()) case <-p.reloadCh: logger.Notice().Msg("reloading...") + case apiCfg := <-p.apiReloadCh: + newCfg = apiCfg case <-p.stopCh: close(reloadCh) return @@ -131,28 +151,31 @@ func (p *prog) runWait() { close(reloadCh) <-done } - newCfg := &ctrld.Config{} - v := viper.NewWithOptions(viper.KeyDelimiter("::")) - ctrld.InitConfig(v, "ctrld") - if configPath != "" { - v.SetConfigFile(configPath) - } - if err := v.ReadInConfig(); err != nil { - logger.Err(err).Msg("could not read new config") - waitOldRunDone() - continue - } - if err := v.Unmarshal(&newCfg); err != nil { - logger.Err(err).Msg("could not unmarshal new config") - waitOldRunDone() - continue - } - if cdUID != "" { - if err := processCDFlags(newCfg); err != nil { - logger.Err(err).Msg("could not fetch ControlD config") + + if newCfg == nil { + newCfg = &ctrld.Config{} + v := viper.NewWithOptions(viper.KeyDelimiter("::")) + ctrld.InitConfig(v, "ctrld") + if configPath != "" { + v.SetConfigFile(configPath) + } + if err := v.ReadInConfig(); err != nil { + logger.Err(err).Msg("could not read new config") waitOldRunDone() continue } + if err := v.Unmarshal(&newCfg); err != nil { + logger.Err(err).Msg("could not unmarshal new config") + waitOldRunDone() + continue + } + if cdUID != "" { + if err := processCDFlags(newCfg); err != nil { + logger.Err(err).Msg("could not fetch ControlD config") + waitOldRunDone() + continue + } + } } waitOldRunDone() @@ -178,6 +201,10 @@ func (p *prog) runWait() { continue } + if err := writeConfigFile(newCfg); err != nil { + logger.Err(err).Msg("could not write new config") + } + // This needs to be done here, otherwise, the DNS handler may observe an invalid // upstream config because its initialization function have not been called yet. mainLog.Load().Debug().Msg("setup upstream with new config") @@ -188,6 +215,7 @@ func (p *prog) runWait() { p.mu.Unlock() logger.Notice().Msg("reloading config successfully") + select { case p.reloadDoneCh <- struct{}{}: default: @@ -214,12 +242,67 @@ func (p *prog) postRun() { } } +// apiConfigReload calls API to check for latest config update then reload ctrld if necessary. +func (p *prog) apiConfigReload() { + if cdUID == "" { + return + } + + secs := 3600 + if p.cfg.Service.RefetchTime != nil && *p.cfg.Service.RefetchTime > 0 { + secs = *p.cfg.Service.RefetchTime + } + + ticker := time.NewTicker(time.Duration(secs) * time.Second) + defer ticker.Stop() + + logger := mainLog.Load().With().Str("mode", "api-reload").Logger() + logger.Debug().Msg("starting custom config reload timer") + lastUpdated := time.Now().Unix() + for { + select { + case <-ticker.C: + resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + selfUninstallCheck(err, p, logger) + if err != nil { + logger.Warn().Err(err).Msg("could not fetch resolver config") + continue + } + + if resolverConfig.Ctrld.CustomConfig == "" { + continue + } + + if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated { + lastUpdated = time.Now().Unix() + cfg := &ctrld.Config{} + if err := validateCdRemoteConfig(resolverConfig, cfg); err != nil { + logger.Warn().Err(err).Msg("skipping invalid custom config") + if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil { + logger.Error().Err(err).Msg("could not mark custom last update failed") + } + break + } + setListenerDefaultValue(cfg) + logger.Debug().Msg("custom config changes detected, reloading...") + p.apiReloadCh <- cfg + } else { + logger.Debug().Msg("custom config does not change") + } + case <-p.stopCh: + return + } + } +} + func (p *prog) setupUpstream(cfg *ctrld.Config) { localUpstreams := make([]string, 0, len(cfg.Upstream)) ptrNameservers := make([]string, 0, len(cfg.Upstream)) + isControlDUpstream := false for n := range cfg.Upstream { uc := cfg.Upstream[n] uc.Init() + isControlDUpstream = isControlDUpstream || uc.IsControlD() if uc.BootstrapIP == "" { uc.SetupBootstrapIP() mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) @@ -236,6 +319,10 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) { ptrNameservers = append(ptrNameservers, uc.Endpoint) } } + // Self-uninstallation is ok If there is only 1 ControlD upstream, and no remote config. + if len(cfg.Upstream) == 1 && isControlDUpstream { + p.canSelfUninstall.Store(true) + } p.localUpstreams = localUpstreams p.ptrNameservers = ptrNameservers } @@ -271,6 +358,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.lanLoopGuard = newLoopGuard() p.ptrLoopGuard = newLoopGuard() p.cacheFlushDomainsMap = nil + p.metricsQueryStats.Store(p.cfg.Service.MetricsQueryStats) if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { @@ -397,6 +485,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if p.logConn != nil { _ = p.logConn.Close() } + go p.apiConfigReload() p.postRun() } wg.Wait() @@ -510,13 +599,86 @@ func (p *prog) setDNS() { for i := range nameservers { servers[i] = netip.MustParseAddr(nameservers[i]) } - go watchResolvConf(netIface, servers, setResolvConf) + p.dnsWg.Add(1) + go func() { + defer p.dnsWg.Done() + p.watchResolvConf(netIface, servers, setResolvConf) + }() } if allIfaces { withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error { return setDnsIgnoreUnusableInterface(i, nameservers) }) } + if p.dnsWatchdogEnabled() { + p.dnsWg.Add(1) + go func() { + defer p.dnsWg.Done() + p.dnsWatchdog(netIface, nameservers, allIfaces) + }() + } +} + +// dnsWatchdogEnabled reports whether DNS watchdog is enabled. +func (p *prog) dnsWatchdogEnabled() bool { + if ptr := p.cfg.Service.DnsWatchdogEnabled; ptr != nil { + return *ptr + } + return true +} + +// dnsWatchdogDuration returns the time duration between each DNS watchdog loop. +func (p *prog) dnsWatchdogDuration() time.Duration { + if ptr := p.cfg.Service.DnsWatchdogInvterval; ptr != nil { + if (*ptr).Seconds() > 0 { + return *ptr + } + } + return dnsWatchdogDefaultInterval +} + +// dnsWatchdog watches for DNS changes on Darwin and Windows then re-applying ctrld's settings. +// This is only works when deactivation pin set. +func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces bool) { + if !requiredMultiNICsConfig() { + return + } + + p.dnsWatchDogOnce.Do(func() { + mainLog.Load().Debug().Msg("start DNS settings watchdog") + ns := nameservers + slices.Sort(ns) + ticker := time.NewTicker(p.dnsWatchdogDuration()) + logger := mainLog.Load().With().Str("iface", iface.Name).Logger() + for { + select { + case <-p.dnsWatcherStopCh: + return + case <-p.stopCh: + mainLog.Load().Debug().Msg("stop dns watchdog") + return + case <-ticker.C: + if dnsChanged(iface, ns) { + logger.Debug().Msg("DNS settings were changed, re-applying settings") + if err := setDNS(iface, ns); err != nil { + mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings") + } + } + if allIfaces { + withEachPhysicalInterfaces(iface.Name, "re-applying DNS", func(i *net.Interface) error { + if dnsChanged(i, ns) { + if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil { + mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings") + } else { + mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name) + } + } + return nil + }) + } + } + } + }) } func (p *prog) resetDNS() { @@ -727,13 +889,14 @@ func canBeLocalUpstream(addr string) bool { // the interface that matches excludeIfaceName. The context is used to clarify the // log message when error happens. func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) { + validIfacesMap := validInterfacesMap() interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) { // Skip loopback/virtual interface. if i.IsLoopback() || len(i.HardwareAddr) == 0 { return } // Skip invalid interface. - if !validInterface(i.Interface) { + if !validInterface(i.Interface, validIfacesMap) { return } netIface := i.Interface @@ -747,7 +910,9 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net. } // TODO: investigate whether we should report this error? if err := f(netIface); err == nil { - mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name) + if context != "" { + mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name) + } } else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) { mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name) } @@ -806,3 +971,24 @@ func savedStaticNameservers(iface *net.Interface) []string { } return nil } + +// dnsChanged reports whether DNS settings for given interface was changed. +// The caller must sort the nameservers before calling this function. +func dnsChanged(iface *net.Interface, nameservers []string) bool { + curNameservers, _ := currentStaticDNS(iface) + slices.Sort(curNameservers) + return !slices.Equal(curNameservers, nameservers) +} + +// selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then. +func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { + var uer *controld.UtilityErrorResponse + if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { + // Ensure all DNS watchers goroutine are terminated, so it won't mess up with self-uninstall. + close(p.dnsWatcherStopCh) + p.dnsWg.Wait() + + // Perform self-uninstall now. + selfUninstall(p, logger) + } +} diff --git a/cmd/cli/prog_test.go b/cmd/cli/prog_test.go new file mode 100644 index 0000000..5f2f8e1 --- /dev/null +++ b/cmd/cli/prog_test.go @@ -0,0 +1,57 @@ +package cli + +import ( + "testing" + "time" + + "github.com/Control-D-Inc/ctrld" + "github.com/stretchr/testify/assert" +) + +func Test_prog_dnsWatchdogEnabled(t *testing.T) { + p := &prog{cfg: &ctrld.Config{}} + + // Default value is true. + assert.True(t, p.dnsWatchdogEnabled()) + + tests := []struct { + name string + enabled bool + }{ + {"enabled", true}, + {"disabled", false}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + p.cfg.Service.DnsWatchdogEnabled = &tc.enabled + assert.Equal(t, tc.enabled, p.dnsWatchdogEnabled()) + }) + } +} + +func Test_prog_dnsWatchdogInterval(t *testing.T) { + p := &prog{cfg: &ctrld.Config{}} + + // Default value is 20s. + assert.Equal(t, dnsWatchdogDefaultInterval, p.dnsWatchdogDuration()) + + tests := []struct { + name string + duration time.Duration + expected time.Duration + }{ + {"valid", time.Minute, time.Minute}, + {"zero", 0, dnsWatchdogDefaultInterval}, + {"nagative", time.Duration(-1 * time.Minute), dnsWatchdogDefaultInterval}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + p.cfg.Service.DnsWatchdogInvterval = &tc.duration + assert.Equal(t, tc.expected, p.dnsWatchdogDuration()) + }) + } +} diff --git a/cmd/cli/prometheus.go b/cmd/cli/prometheus.go index fc2fc5d..9082a58 100644 --- a/cmd/cli/prometheus.go +++ b/cmd/cli/prometheus.go @@ -51,7 +51,7 @@ var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{ // WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled. func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) { - if p.cfg.Service.MetricsQueryStats { + if p.metricsQueryStats.Load() { c.WithLabelValues(lvs...).Inc() } } diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index f09d864..5be34fc 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -8,15 +8,15 @@ import ( "github.com/fsnotify/fsnotify" ) -const ( - resolvConfPath = "/etc/resolv.conf" - resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system" -) - // watchResolvConf watches any changes to /etc/resolv.conf file, // and reverting to the original config set by ctrld. -func watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) { - mainLog.Load().Debug().Msg("start watching /etc/resolv.conf file") +func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) { + resolvConfPath := "/etc/resolv.conf" + // Evaluating symbolics link to watch the target file that /etc/resolv.conf point to. + if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" { + resolvConfPath = rp + } + mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath) watcher, err := fsnotify.NewWatcher() if err != nil { mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf") @@ -28,12 +28,17 @@ func watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface // see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well watchDir := filepath.Dir(resolvConfPath) if err := watcher.Add(watchDir); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not add /etc/resolv.conf to watcher list") + mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir) return } for { select { + case <-p.dnsWatcherStopCh: + return + case <-p.stopCh: + mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) + return case event, ok := <-watcher.Events: if !ok { return diff --git a/cmd/cli/resolvconf_darwin.go b/cmd/cli/resolvconf_darwin.go index 7e26f41..eb70eed 100644 --- a/cmd/cli/resolvconf_darwin.go +++ b/cmd/cli/resolvconf_darwin.go @@ -3,15 +3,44 @@ package cli import ( "net" "net/netip" + "os" + "slices" + + "github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile" ) +const resolvConfPath = "/etc/resolv.conf" + // setResolvConf sets the content of resolv.conf file using the given nameservers list. func setResolvConf(iface *net.Interface, ns []netip.Addr) error { servers := make([]string, len(ns)) for i := range ns { servers[i] = ns[i].String() } - return setDNS(iface, servers) + if err := setDNS(iface, servers); err != nil { + return err + } + slices.Sort(servers) + curNs := currentDNS(iface) + slices.Sort(curNs) + if !slices.Equal(curNs, servers) { + c, err := resolvconffile.ParseFile(resolvConfPath) + if err != nil { + return err + } + c.Nameservers = ns + f, err := os.Create(resolvConfPath) + if err != nil { + return err + } + defer f.Close() + + if err := c.Write(f); err != nil { + return err + } + return f.Close() + } + return nil } // shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator. diff --git a/cmd/cli/self_delete_others.go b/cmd/cli/self_delete_others.go new file mode 100644 index 0000000..02ae977 --- /dev/null +++ b/cmd/cli/self_delete_others.go @@ -0,0 +1,7 @@ +//go:build !windows + +package cli + +var supportedSelfDelete = true + +func selfDeleteExe() error { return nil } diff --git a/cmd/cli/self_delete_windows.go b/cmd/cli/self_delete_windows.go new file mode 100644 index 0000000..c2f2719 --- /dev/null +++ b/cmd/cli/self_delete_windows.go @@ -0,0 +1,134 @@ +// Copied from https://github.com/secur30nly/go-self-delete +// with modification to suitable for ctrld usage. + +/* + License: MIT Licence + + References: + - https://github.com/LloydLabs/delete-self-poc + - https://twitter.com/jonasLyk/status/1350401461985955840 +*/ + +package cli + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +var supportedSelfDelete = false + +type FILE_RENAME_INFO struct { + Union struct { + ReplaceIfExists bool + Flags uint32 + } + RootDirectory windows.Handle + FileNameLength uint32 + FileName [1]uint16 +} + +type FILE_DISPOSITION_INFO struct { + DeleteFile bool +} + +func dsOpenHandle(pwPath *uint16) (windows.Handle, error) { + handle, err := windows.CreateFile( + pwPath, + windows.DELETE, + 0, + nil, + windows.OPEN_EXISTING, + windows.FILE_ATTRIBUTE_NORMAL, + 0, + ) + + if err != nil { + return 0, err + } + + return handle, nil +} + +func dsRenameHandle(hHandle windows.Handle) error { + var fRename FILE_RENAME_INFO + DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef") + + if err != nil { + return err + } + + lpwStream := &DS_STREAM_RENAME[0] + fRename.FileNameLength = uint32(unsafe.Sizeof(lpwStream)) + + windows.NewLazyDLL("kernel32.dll").NewProc("RtlCopyMemory").Call( + uintptr(unsafe.Pointer(&fRename.FileName[0])), + uintptr(unsafe.Pointer(lpwStream)), + unsafe.Sizeof(lpwStream), + ) + + err = windows.SetFileInformationByHandle( + hHandle, + windows.FileRenameInfo, + (*byte)(unsafe.Pointer(&fRename)), + uint32(unsafe.Sizeof(fRename)+unsafe.Sizeof(lpwStream)), + ) + + if err != nil { + return err + } + + return nil +} + +func dsDepositeHandle(hHandle windows.Handle) error { + var fDelete FILE_DISPOSITION_INFO + fDelete.DeleteFile = true + + err := windows.SetFileInformationByHandle( + hHandle, + windows.FileDispositionInfo, + (*byte)(unsafe.Pointer(&fDelete)), + uint32(unsafe.Sizeof(fDelete)), + ) + + if err != nil { + return err + } + + return nil +} + +func selfDeleteExe() error { + var wcPath [windows.MAX_PATH + 1]uint16 + var hCurrent windows.Handle + + _, err := windows.GetModuleFileName(0, &wcPath[0], windows.MAX_PATH) + if err != nil { + return err + } + + hCurrent, err = dsOpenHandle(&wcPath[0]) + if err != nil || hCurrent == windows.InvalidHandle { + return err + } + + if err := dsRenameHandle(hCurrent); err != nil { + _ = windows.CloseHandle(hCurrent) + return err + } + _ = windows.CloseHandle(hCurrent) + + hCurrent, err = dsOpenHandle(&wcPath[0]) + if err != nil || hCurrent == windows.InvalidHandle { + return err + } + + if err := dsDepositeHandle(hCurrent); err != nil { + _ = windows.CloseHandle(hCurrent) + return err + } + + return windows.CloseHandle(hCurrent) +} diff --git a/cmd/cli/self_kill_others.go b/cmd/cli/self_kill_others.go new file mode 100644 index 0000000..e9fb1f8 --- /dev/null +++ b/cmd/cli/self_kill_others.go @@ -0,0 +1,16 @@ +//go:build !unix + +package cli + +import ( + "os" + + "github.com/rs/zerolog" +) + +func selfUninstall(p *prog, logger zerolog.Logger) { + if uninstallInvalidCdUID(p, logger, false) { + logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) + os.Exit(0) + } +} diff --git a/cmd/cli/self_kill_unix.go b/cmd/cli/self_kill_unix.go new file mode 100644 index 0000000..9e494b4 --- /dev/null +++ b/cmd/cli/self_kill_unix.go @@ -0,0 +1,45 @@ +//go:build unix + +package cli + +import ( + "fmt" + "os" + "os/exec" + "runtime" + "syscall" + + "github.com/rs/zerolog" +) + +func selfUninstall(p *prog, logger zerolog.Logger) { + if runtime.GOOS == "linux" { + selfUninstallLinux(p, logger) + } + + bin, err := os.Executable() + if err != nil { + logger.Fatal().Err(err).Msg("could not determine executable") + } + args := []string{"uninstall"} + if !deactivationPinNotSet() { + args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin)) + } + cmd := exec.Command(bin, args...) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + if err := cmd.Start(); err != nil { + logger.Fatal().Err(err).Msg("could not start self uninstall command") + } + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) + _ = cmd.Wait() + os.Exit(0) +} + +func selfUninstallLinux(p *prog, logger zerolog.Logger) { + if uninstallInvalidCdUID(p, logger, true) { + logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) + os.Exit(0) + } +} diff --git a/cmd/cli/service.go b/cmd/cli/service.go index 1de206f..e4edfaf 100644 --- a/cmd/cli/service.go +++ b/cmd/cli/service.go @@ -28,6 +28,9 @@ func newService(i service.Interface, c *service.Config) (service.Service, error) return &sysV{s}, nil case s.Platform() == "linux-systemd": return &systemd{s}, nil + case s.Platform() == "darwin-launchd": + return newLaunchd(s), nil + } return s, nil } @@ -113,7 +116,7 @@ func (s *procd) Status() (service.Status, error) { return service.StatusRunning, nil } -// procd wraps a service.Service, and provide status command to +// systemd wraps a service.Service, and provide status command to // report the status correctly. type systemd struct { service.Service @@ -127,6 +130,29 @@ func (s *systemd) Status() (service.Status, error) { return s.Service.Status() } +func newLaunchd(s service.Service) *launchd { + return &launchd{ + Service: s, + statusErrMsg: "Permission denied", + } +} + +// launchd wraps a service.Service, and provide status command to +// report the status correctly when not running as root on Darwin. +// +// TODO: remove this wrapper once https://github.com/kardianos/service/issues/400 fixed. +type launchd struct { + service.Service + statusErrMsg string +} + +func (l *launchd) Status() (service.Status, error) { + if os.Geteuid() != 0 { + return service.StatusUnknown, errors.New(l.statusErrMsg) + } + return l.Service.Status() +} + type task struct { f func() error abortOnError bool diff --git a/cmd/cli/service_others.go b/cmd/cli/service_others.go index e9522f4..f4d73e5 100644 --- a/cmd/cli/service_others.go +++ b/cmd/cli/service_others.go @@ -9,3 +9,7 @@ import ( func hasElevatedPrivilege() (bool, error) { return os.Geteuid() == 0, nil } + +func openLogFile(path string, flags int) (*os.File, error) { + return os.OpenFile(path, flags, os.FileMode(0o600)) +} diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index a1010a8..d4e2449 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -1,6 +1,11 @@ package cli -import "golang.org/x/sys/windows" +import ( + "os" + "syscall" + + "golang.org/x/sys/windows" +) func hasElevatedPrivilege() (bool, error) { var sid *windows.SID @@ -22,3 +27,55 @@ func hasElevatedPrivilege() (bool, error) { token := windows.Token(0) return token.IsMember(sid) } + +func openLogFile(path string, mode int) (*os.File, error) { + if len(path) == 0 { + return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND} + } + + pathP, err := syscall.UTF16PtrFromString(path) + if err != nil { + return nil, err + } + var access uint32 + switch mode & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) { + case os.O_RDONLY: + access = windows.GENERIC_READ + case os.O_WRONLY: + access = windows.GENERIC_WRITE + case os.O_RDWR: + access = windows.GENERIC_READ | windows.GENERIC_WRITE + } + if mode&os.O_CREATE != 0 { + access |= windows.GENERIC_WRITE + } + if mode&os.O_APPEND != 0 { + access &^= windows.GENERIC_WRITE + access |= windows.FILE_APPEND_DATA + } + + shareMode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE) + + var sa *syscall.SecurityAttributes + + var createMode uint32 + switch { + case mode&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL): + createMode = windows.CREATE_NEW + case mode&(os.O_CREATE|os.O_TRUNC) == (os.O_CREATE | os.O_TRUNC): + createMode = windows.CREATE_ALWAYS + case mode&os.O_CREATE == os.O_CREATE: + createMode = windows.OPEN_ALWAYS + case mode&os.O_TRUNC == os.O_TRUNC: + createMode = windows.TRUNCATE_EXISTING + default: + createMode = windows.OPEN_EXISTING + } + + handle, err := syscall.CreateFile(pathP, access, shareMode, sa, createMode, syscall.FILE_ATTRIBUTE_NORMAL, 0) + if err != nil { + return nil, &os.PathError{Path: path, Op: "open", Err: err} + } + + return os.NewFile(uintptr(handle), path), nil +} diff --git a/config.go b/config.go index 8c99a8e..e09fdad 100644 --- a/config.go +++ b/config.go @@ -25,6 +25,7 @@ import ( "github.com/go-playground/validator/v10" "github.com/miekg/dns" "github.com/spf13/viper" + "golang.org/x/net/http2" "golang.org/x/sync/singleflight" "tailscale.com/logtail/backoff" "tailscale.com/net/tsaddr" @@ -188,27 +189,30 @@ func (c *Config) FirstUpstream() *UpstreamConfig { // ServiceConfig specifies the general ctrld config. type ServiceConfig struct { - LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` - LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` - CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` - CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` - CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` - CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` - CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"` - MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` - DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` - DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` - DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` - DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"` - DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` - DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` - DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` - DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"` - ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"` - MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"` - MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"` - Daemon bool `mapstructure:"-" toml:"-"` - AllocateIP bool `mapstructure:"-" toml:"-"` + LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` + LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` + CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` + CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` + CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` + CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` + CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"` + MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` + DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` + DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` + DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` + DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"` + DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` + DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` + DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` + DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"` + ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"` + MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"` + MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"` + DnsWatchdogEnabled *bool `mapstructure:"dns_watchdog_enabled" toml:"dns_watchdog_enabled,omitempty"` + DnsWatchdogInvterval *time.Duration `mapstructure:"dns_watchdog_interval" toml:"dns_watchdog_interval,omitempty"` + RefetchTime *int `mapstructure:"refetch_time" toml:"refetch_time,omitempty"` + Daemon bool `mapstructure:"-" toml:"-"` + AllocateIP bool `mapstructure:"-" toml:"-"` } // NetworkConfig specifies configuration for networks where ctrld will handle requests. @@ -316,7 +320,7 @@ func (uc *UpstreamConfig) Init() { } } if uc.IPStack == "" { - if uc.isControlD() { + if uc.IsControlD() { uc.IPStack = IpStackSplit } else { uc.IPStack = IpStackBoth @@ -354,7 +358,7 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool { } switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: - if uc.isControlD() || uc.isNextDNS() { + if uc.IsControlD() || uc.isNextDNS() { return true } } @@ -401,7 +405,7 @@ func (uc *UpstreamConfig) UID() string { // The first usable IP will be used as bootstrap IP of the upstream. func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) { b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second) - isControlD := uc.isControlD() + isControlD := uc.IsControlD() for { uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS) // For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses, @@ -486,6 +490,13 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { ClientSessionCache: tls.NewLRUClientSessionCache(0), } + // Prevent bad tcp connection hanging the requests for too long. + // See: https://github.com/golang/go/issues/36026 + if t2, err := http2.ConfigureTransports(transport); err == nil { + t2.ReadIdleTimeout = 10 * time.Second + t2.PingTimeout = 5 * time.Second + } + dialerTimeoutMs := 2000 if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs { dialerTimeoutMs = uc.Timeout @@ -572,7 +583,8 @@ func (uc *UpstreamConfig) ping() error { return nil } -func (uc *UpstreamConfig) isControlD() bool { +// IsControlD reports whether this is a ControlD upstream. +func (uc *UpstreamConfig) IsControlD() bool { domain := uc.Domain if domain == "" { if u, err := url.Parse(uc.Endpoint); err == nil { diff --git a/config_test.go b/config_test.go index 83a1e13..03a1a3f 100644 --- a/config_test.go +++ b/config_test.go @@ -5,6 +5,7 @@ import ( "os" "strings" "testing" + "time" "github.com/go-playground/validator/v10" "github.com/spf13/viper" @@ -22,6 +23,8 @@ func TestLoadConfig(t *testing.T) { assert.Equal(t, "info", cfg.Service.LogLevel) assert.Equal(t, "/path/to/log.log", cfg.Service.LogPath) + assert.Equal(t, false, *cfg.Service.DnsWatchdogEnabled) + assert.Equal(t, time.Duration(20*time.Second), *cfg.Service.DnsWatchdogInvterval) assert.Len(t, cfg.Network, 2) assert.Contains(t, cfg.Network, "0") diff --git a/docker/Dockerfile.debug b/docker/Dockerfile.debug index e7ce172..2ba3602 100644 --- a/docker/Dockerfile.debug +++ b/docker/Dockerfile.debug @@ -8,7 +8,7 @@ # - Non-cgo ctrld binary. # # CI_COMMIT_TAG is used to set the version of ctrld binary. -FROM golang:1.20-bullseye as base +FROM golang:bullseye as base WORKDIR /app diff --git a/docs/config.md b/docs/config.md index 5615d30..8c216ec 100644 --- a/docs/config.md +++ b/docs/config.md @@ -252,6 +252,35 @@ Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus - Required: no - Default: "" +### dns_watchdog_enabled +Checking DNS changes to network interfaces and reverting to ctrld's own settings. + +The DNS watchdog process only runs on Windows and MacOS. + +- Type: boolean +- Required: no +- Default: true + +### dns_watchdog_interval +Time duration between each DNS watchdog iteration. + +A duration string is a possibly signed sequence of decimal numbers, each with optional fraction and a unit suffix, +such as "300ms", "-1.5h" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". + +If the time duration is non-positive, default value will be used. + +- Type: time duration string +- Required: no +- Default: 20s + +### refetch_time +Time in seconds between each iteration that reloads custom config if changed. + +The value must be a positive number, any invalid value will be ignored and default value will be used. +- Type: number +- Required: no +- Default: 3600 + ## Upstream The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to. diff --git a/doh.go b/doh.go index bddc583..d702995 100644 --- a/doh.go +++ b/doh.go @@ -147,7 +147,7 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) { if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil { printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != "" switch { - case uc.isControlD(): + case uc.IsControlD(): dohHeader = newControlDHeaders(ci) case uc.isNextDNS(): dohHeader = newNextDNSHeaders(ci) diff --git a/doq_quic_free.go b/doq_quic_free.go deleted file mode 100644 index 36fd63c..0000000 --- a/doq_quic_free.go +++ /dev/null @@ -1,18 +0,0 @@ -//go:build qf - -package ctrld - -import ( - "context" - "errors" - - "github.com/miekg/dns" -) - -type doqResolver struct { - uc *UpstreamConfig -} - -func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - return nil, errors.New("DoQ is not supported") -} diff --git a/dot.go b/dot.go index 1fef409..c0fe102 100644 --- a/dot.go +++ b/dot.go @@ -18,7 +18,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro // dns.controld.dev first. By using a dialer with custom resolver, // we ensure that we can always resolve the bootstrap domain // regardless of the machine DNS status. - dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53")) + dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53")) dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index 3fb004e..3c8af6e 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -122,8 +122,8 @@ func (m *mdns) probeLoop(conns []*net.UDPConn, remoteAddr net.Addr, quitCh chan bo := backoff.NewBackoff("mdns probe", func(format string, args ...any) {}, time.Second*30) for { err := m.probe(conns, remoteAddr) - if isErrNetUnreachableOrInvalid(err) { - ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: network unreachable or invalid", remoteAddr) + if shouldStopProbing(err) { + ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: %v", remoteAddr, err) break } if err != nil { @@ -165,7 +165,7 @@ func (m *mdns) readLoop(conn *net.UDPConn) { } var ip, name string - rrs := make([]dns.RR, 0, len(msg.Answer)+len(msg.Extra)) + var rrs []dns.RR rrs = append(rrs, msg.Answer...) rrs = append(rrs, msg.Extra...) for _, rr := range rrs { @@ -273,10 +273,14 @@ func multicastInterfaces() ([]net.Interface, error) { return interfaces, nil } -func isErrNetUnreachableOrInvalid(err error) bool { +// shouldStopProbing reports whether ctrld should stop probing mdns. +func shouldStopProbing(err error) bool { var se *os.SyscallError if errors.As(err, &se) { - return se.Err == syscall.ENETUNREACH || se.Err == syscall.EINVAL + switch se.Err { + case syscall.ENETUNREACH, syscall.EINVAL, syscall.EPERM: + return true + } } return false } diff --git a/internal/controld/config.go b/internal/controld/config.go index c095c0c..01e114b 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -26,14 +26,15 @@ const ( apiDomainDev = "api.controld.dev" resolverDataURLCom = "https://api.controld.com/utility" resolverDataURLDev = "https://api.controld.dev/utility" - InvalidConfigCode = 40401 + InvalidConfigCode = 40402 ) // ResolverConfig represents Control D resolver data. type ResolverConfig struct { DOH string `json:"doh"` Ctrld struct { - CustomConfig string `json:"custom_config"` + CustomConfig string `json:"custom_config"` + CustomLastUpdate int64 `json:"custom_last_update"` } `json:"ctrld"` Exclude []string `json:"exclude"` UID string `json:"uid"` @@ -76,17 +77,28 @@ func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, e req.ClientID = clientID } body, _ := json.Marshal(req) - return postUtilityAPI(version, cdDev, bytes.NewReader(body)) + return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) } // FetchResolverUID fetch resolver uid from provision token. func FetchResolverUID(pt, version string, cdDev bool) (*ResolverConfig, error) { hostname, _ := os.Hostname() body, _ := json.Marshal(utilityOrgRequest{ProvToken: pt, Hostname: hostname}) - return postUtilityAPI(version, cdDev, bytes.NewReader(body)) + return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) } -func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig, error) { +// UpdateCustomLastFailed calls API to mark custom config is bad. +func UpdateCustomLastFailed(rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) { + uid, clientID := ParseRawUID(rawUID) + req := utilityRequest{UID: uid} + if clientID != "" { + req.ClientID = clientID + } + body, _ := json.Marshal(req) + return postUtilityAPI(version, cdDev, true, bytes.NewReader(body)) +} + +func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) { apiUrl := resolverDataURLCom if cdDev { apiUrl = resolverDataURLDev @@ -98,6 +110,9 @@ func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig q := req.URL.Query() q.Set("platform", "ctrld") q.Set("version", version) + if lastUpdatedFailed { + q.Set("custom_last_failed", "1") + } req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/json") transport := http.DefaultTransport.(*http.Transport).Clone() diff --git a/internal/router/router.go b/internal/router/router.go index bf65e6e..4b335a6 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "sync/atomic" "github.com/kardianos/service" @@ -164,6 +165,16 @@ func HomeDir() (string, error) { return "", err } return filepath.Dir(exe), nil + case edgeos.Name: + exe, err := os.Executable() + if err != nil { + return "", err + } + // Using binary directory as home dir if it is located in /config. + // Otherwise, fallback to old behavior for compatibility. + if strings.HasPrefix(exe, "/config/") { + return filepath.Dir(exe), nil + } } return "", nil } diff --git a/resolver.go b/resolver.go index 49ac652..d8b7f8d 100644 --- a/resolver.go +++ b/resolver.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/netip" + "slices" "sync" "time" @@ -30,18 +31,19 @@ const ( ResolverTypePrivate = "private" ) -const bootstrapDNS = "76.76.2.22" +const ( + controldBootstrapDns = "76.76.2.22" + controldPublicDns = "76.76.2.0" +) + +var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") // or is the Resolver used for ResolverTypeOS. var or = &osResolver{nameservers: defaultNameservers()} -// defaultNameservers returns nameservers used by the OS. -// If no nameservers can be found, ctrld bootstrap nameserver will be used. +// defaultNameservers returns OS nameservers plus ControlD public DNS. func defaultNameservers() []string { ns := nameservers() - if len(ns) == 0 { - ns = append(ns, net.JoinHostPort(bootstrapDNS, "53")) - } return ns } @@ -51,10 +53,27 @@ func defaultNameservers() []string { // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. func InitializeOsResolver() []string { - or.nameservers = defaultNameservers() + or.nameservers = or.nameservers[:0] + for _, ns := range defaultNameservers() { + if testNameserver(ns) { + or.nameservers = append(or.nameservers, ns) + } + } + or.nameservers = append(or.nameservers, controldPublicDnsWithPort) return or.nameservers } +// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. +func testNameserver(addr string) bool { + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + client := new(dns.Client) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, _, err := client.ExchangeContext(ctx, msg, addr) + return err == nil +} + // Resolver is the interface that wraps the basic DNS operations. // // Resolve resolves the DNS query, return the result and the corresponding error. @@ -89,8 +108,9 @@ type osResolver struct { } type osResolverResult struct { - answer *dns.Msg - err error + answer *dns.Msg + err error + isControlDPublicDNS bool } // Resolve resolves DNS queries using pre-configured nameservers. @@ -116,19 +136,34 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error go func(server string) { defer wg.Done() answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) - ch <- &osResolverResult{answer: answer, err: err} + ch <- &osResolverResult{answer: answer, err: err, isControlDPublicDNS: server == controldPublicDnsWithPort} }(server) } + var ( + nonSuccessAnswer *dns.Msg + controldSuccessAnswer *dns.Msg + ) errs := make([]error, 0, numServers) for res := range ch { - if res.err == nil { - cancel() - return res.answer, res.err + switch { + case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: + if res.isControlDPublicDNS { + controldSuccessAnswer = res.answer // only use ControlD answer as last one. + } else { + cancel() + return res.answer, nil + } + case res.answer != nil: + nonSuccessAnswer = res.answer } errs = append(errs, res.err) } - + for _, answer := range []*dns.Msg{controldSuccessAnswer, nonSuccessAnswer} { + if answer != nil { + return answer, nil + } + } return nil, errors.Join(errs...) } @@ -138,7 +173,7 @@ type legacyResolver struct { func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { // See comment in (*dotResolver).resolve method. - dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53")) + dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53")) dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype @@ -176,7 +211,7 @@ func LookupIP(domain string) []string { func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { resolver := &osResolver{nameservers: nameservers()} if withBootstrapDNS { - resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...) } ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers) timeoutMs := 2000 @@ -252,7 +287,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) // - Input servers. func NewBootstrapResolver(servers ...string) Resolver { resolver := &osResolver{nameservers: nameservers()} - resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + resolver.nameservers = append([]string{controldPublicDnsWithPort}, resolver.nameservers...) for _, ns := range servers { resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...) } @@ -279,11 +314,11 @@ func NewPrivateResolver() Resolver { // - Direct listener that has ctrld as an upstream (e.g: dnsmasq). // // causing the query always succeed. - if sliceContains(resolveConfNss, host) { + if slices.Contains(resolveConfNss, host) { continue } // Ignoring local RFC 1918 addresses. - if sliceContains(localRfc1918Addrs, host) { + if slices.Contains(localRfc1918Addrs, host) { continue } ip := net.ParseIP(host) @@ -335,20 +370,3 @@ func newDialer(dnsAddress string) *net.Dialer { }, } } - -// TODO(cuonglm): use slices.Contains once upgrading to go1.21 -// sliceContains reports whether v is present in s. -func sliceContains[S ~[]E, E comparable](s S, v E) bool { - return sliceIndex(s, v) >= 0 -} - -// sliceIndex returns the index of the first occurrence of v in s, -// or -1 if not present. -func sliceIndex[S ~[]E, E comparable](s S, v E) int { - for i := range s { - if v == s[i] { - return i - } - } - return -1 -} diff --git a/resolver_test.go b/resolver_test.go index 531570b..23c27ae 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -2,6 +2,8 @@ package ctrld import ( "context" + "net" + "sync" "testing" "time" @@ -28,6 +30,57 @@ func Test_osResolver_Resolve(t *testing.T) { } } +func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { + ns := make([]string, 0, 2) + servers := make([]*dns.Server, 0, 2) + successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, dns.RcodeSuccess) + w.WriteMsg(m) + }) + nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc { + return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, rcode) + w.WriteMsg(m) + }) + } + + handlers := []dns.Handler{ + nonSuccessHandlerWithRcode(dns.RcodeRefused), + nonSuccessHandlerWithRcode(dns.RcodeNameError), + successHandler, + } + for i := range handlers { + pc, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i]) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ns = append(ns, addr) + servers = append(servers, s) + } + defer func() { + for _, server := range servers { + server.Shutdown() + } + }() + resolver := &osResolver{nameservers: ns} + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + answer, err := resolver.Resolve(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + if answer.Rcode != dns.RcodeSuccess { + t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode]) + } +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string @@ -51,3 +104,33 @@ func Test_upstreamTypeFromEndpoint(t *testing.T) { }) } } + +func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.Handler, opts ...func(*dns.Server)) (*dns.Server, string, error) { + t.Helper() + + server := &dns.Server{ + PacketConn: pc, + ReadTimeout: time.Hour, + WriteTimeout: time.Hour, + Handler: handler, + } + + waitLock := sync.Mutex{} + waitLock.Lock() + server.NotifyStartedFunc = waitLock.Unlock + + for _, opt := range opts { + opt(server) + } + + addr, closer := pc.LocalAddr().String(), pc + go func() { + if err := server.ActivateAndServe(); err != nil { + t.Error(err) + } + closer.Close() + }() + + waitLock.Lock() + return server, addr, nil +} diff --git a/testhelper/config.go b/testhelper/config.go index 6199424..a39ac62 100644 --- a/testhelper/config.go +++ b/testhelper/config.go @@ -27,6 +27,8 @@ var sampleConfigContent = ` [service] log_level = "info" log_path = "/path/to/log.log" +dns_watchdog_enabled = false +dns_watchdog_interval = "20s" [network.0] name = "Home Wifi"