package cli import ( "bytes" "context" "crypto/x509" "encoding/base64" "encoding/json" "errors" "fmt" "io" "net" "net/http" "net/netip" "net/url" "os" "os/exec" "path/filepath" "reflect" "runtime" "runtime/debug" "sort" "strconv" "strings" "sync/atomic" "time" "github.com/Masterminds/semver/v3" "github.com/cuonglm/osinfo" "github.com/go-playground/validator/v10" "github.com/kardianos/service" "github.com/miekg/dns" "github.com/pelletier/go-toml/v2" "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/spf13/viper" "tailscale.com/logtail/backoff" "tailscale.com/net/netmon" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/controld" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" "github.com/Control-D-Inc/ctrld/internal/router" ) // selfCheckInternalTestDomain is used for testing ctrld self response to clients. const selfCheckInternalTestDomain = "ctrld" + loopTestDomain const ( windowsForwardersFilename = ".forwarders.txt" oldBinSuffix = "_previous" oldLogSuffix = ".1" msgExit = "$$EXIT$$" ) var ( version = "dev" commit = "none" ) var ( v = viper.NewWithOptions(viper.KeyDelimiter("::")) defaultConfigFile = "ctrld.toml" rootCertPool *x509.CertPool errSelfCheckNoAnswer = errors.New("no response from ctrld listener. 
You can try to re-launch with flag --skip_self_checks") ) var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"} func isNoConfigStart(cmd *cobra.Command) bool { for _, flagName := range basicModeFlags { if cmd.Flags().Lookup(flagName).Changed { return true } } return false } const rootShortDesc = ` __ .__ .___ _____/ |________| | __| _/ _/ ___\ __\_ __ \ | / __ | \ \___| | | | \/ |__/ /_/ | \___ >__| |__| |____/\____ | \/ dns forwarding proxy \/ ` var rootCmd = &cobra.Command{ Use: "ctrld", Short: strings.TrimLeft(rootShortDesc, "\n"), Version: curVersion(), PersistentPreRun: func(cmd *cobra.Command, args []string) { initConsoleLogging() }, } func curVersion() string { if version != "dev" && !strings.HasPrefix(version, "v") { version = "v" + version } if version != "" && version != "dev" { return version } if len(commit) > 7 { commit = commit[:7] } return fmt.Sprintf("%s-%s", version, commit) } func initCLI() { // Enable opening via explorer.exe on Windows. // See: https://github.com/spf13/cobra/issues/844. cobra.MousetrapHelpText = "" cobra.EnableCommandSorting = false rootCmd.PersistentFlags().CountVarP( &verbose, "verbose", "v", `verbose log output, "-v" basic logging, "-vv" debug logging`, ) rootCmd.PersistentFlags().BoolVarP( &silent, "silent", "s", false, `do not write any log output`, ) rootCmd.SetHelpCommand(&cobra.Command{Hidden: true}) rootCmd.CompletionOptions.HiddenDefaultCmd = true initRunCmd() startCmd := initStartCmd() stopCmd := initStopCmd() restartCmd := initRestartCmd() reloadCmd := initReloadCmd(restartCmd) statusCmd := initStatusCmd() uninstallCmd := initUninstallCmd() interfacesCmd := initInterfacesCmd() initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd) initClientsCmd() initUpgradeCmd() initLogCmd() } // isMobile reports whether the current OS is a mobile platform. 
func isMobile() bool { return runtime.GOOS == "android" || runtime.GOOS == "ios" } // isAndroid reports whether the current OS is Android. func isAndroid() bool { return runtime.GOOS == "android" } // isStableVersion reports whether vs is a stable semantic version. func isStableVersion(vs string) bool { v, err := semver.NewVersion(vs) if err != nil { return false } return v.Prerelease() == "" } // RunCobraCommand runs ctrld cli. func RunCobraCommand(cmd *cobra.Command) { noConfigStart = isNoConfigStart(cmd) checkStrFlagEmpty(cmd, cdUidFlagName) checkStrFlagEmpty(cmd, cdOrgFlagName) run(nil, make(chan struct{})) } // RunMobile runs the ctrld cli on mobile platforms. func RunMobile(appConfig *AppConfig, appCallback *AppCallback, stopCh chan struct{}) { if appConfig == nil { panic("appConfig is nil") } initConsoleLogging() noConfigStart = false homedir = appConfig.HomeDir verbose = appConfig.Verbose if appConfig.ProvisionID != "" { cdOrg = appConfig.ProvisionID } if appConfig.CustomHostname != "" { customHostname = appConfig.CustomHostname } if appConfig.CdUID != "" { cdUID = appConfig.CdUID } cdUpstreamProto = appConfig.UpstreamProto logPath = appConfig.LogPath run(appCallback, stopCh) } // CheckDeactivationPin checks if deactivation pin is valid func CheckDeactivationPin(pin int64, stopCh chan struct{}) int { deactivationPin = pin if err := checkDeactivationPin(nil, stopCh); isCheckDeactivationPinErr(err) { return deactivationPinInvalidExitCode } return 0 } // run runs ctrld cli with given app callback and stop channel. 
func run(appCallback *AppCallback, stopCh chan struct{}) {
	if stopCh == nil {
		mainLog.Load().Fatal().Msg("stopCh is nil")
	}
	// waitCh is closed at the very end of run, once all setup and hooks are in place.
	waitCh := make(chan struct{})
	p := &prog{
		waitCh:           waitCh,
		stopCh:           stopCh,
		pinCodeValidCh:   make(chan struct{}, 1),
		reloadCh:         make(chan struct{}),
		reloadDoneCh:     make(chan struct{}),
		dnsWatcherStopCh: make(chan struct{}),
		apiReloadCh:      make(chan *ctrld.Config),
		apiForceReloadCh: make(chan struct{}),
		cfg:              &cfg,
		appCallback:      appCallback,
	}
	if homedir == "" {
		if dir, err := userHomeDir(); err == nil {
			homedir = dir
		}
	}
	sockDir := homedir
	if d, err := socketDir(); err == nil {
		sockDir = d
	}
	// If something is listening on the log unix socket, mirror console output to it
	// in addition to stdout. A missing socket file is expected and not reported.
	sockPath := filepath.Join(sockDir, ctrldLogUnixSock)
	if addr, err := net.ResolveUnixAddr("unix", sockPath); err == nil {
		if conn, err := net.Dial(addr.Network(), addr.String()); err == nil {
			lc := &logConn{conn: conn}
			consoleWriter.Out = io.MultiWriter(os.Stdout, lc)
			p.logConn = lc
		} else {
			if !errors.Is(err, os.ErrNotExist) {
				mainLog.Load().Warn().Err(err).Msg("unable to create log ipc connection")
			}
		}
	} else {
		mainLog.Load().Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath)
	}
	// notifyExitToLogServer sends msgExit over the log connection, if one was established.
	notifyExitToLogServer := func() {
		if p.logConn != nil {
			_, _ = p.logConn.Write([]byte(msgExit))
		}
	}

	// Daemon mode is rejected on Windows.
	if daemon && runtime.GOOS == "windows" {
		mainLog.Load().Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.")
	}

	if !daemon {
		// We need to call s.Run() as soon as possible to response to the OS manager, so it
		// can see ctrld is running and don't mark ctrld as failed service.
		go func() {
			s, err := newService(p, svcConfig)
			if err != nil {
				mainLog.Load().Fatal().Err(err).Msg("failed create new service")
			}
			if err := s.Run(); err != nil {
				mainLog.Load().Error().Err(err).Msg("failed to start service")
			}
		}()
	}
	// A default config file is only written when neither the no-config flags nor a
	// base64 config are in use.
	writeDefaultConfig := !noConfigStart && configBase64 == ""
	tryReadingConfig(writeDefaultConfig)

	if err := readBase64Config(configBase64); err != nil {
		mainLog.Load().Fatal().Err(err).Msg("failed to read base64 config")
	}
	processNoConfigFlags(noConfigStart)

	// After s.Run() was called, if ctrld is going to be terminated for any reason,
	// write msgExit to p.logConn so others (like "ctrld start") won't have to wait for timeout.
	// cfg is shared with p (p.cfg = &cfg), hence the lock around the unmarshal.
	p.mu.Lock()
	if err := v.Unmarshal(&cfg); err != nil {
		notifyExitToLogServer()
		mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
	}
	p.mu.Unlock()

	processLogAndCacheFlags()

	// Log config do not have thing to validate, so it's safe to init log here,
	// so it's able to log information in processCDFlags.
	p.initLogging(true)

	mainLog.Load().Info().Msgf("starting ctrld %s", curVersion())
	mainLog.Load().Info().Msgf("os: %s", osVersion())

	// Wait for network up.
	if !ctrldnet.Up() {
		notifyExitToLogServer()
		mainLog.Load().Fatal().Msg("network is not up yet")
	}

	p.router = router.New(&cfg, cdUID != "")
	// A failure to create the control server is only logged; p.cs may be nil afterwards.
	cs, err := newControlServer(filepath.Join(sockDir, ControlSocketName()))
	if err != nil {
		mainLog.Load().Warn().Err(err).Msg("could not create control server")
	}
	p.cs = cs

	// Processing --cd flag require connecting to ControlD API, which needs valid
	// time for validating server certificate. Some routers need NTP synchronization
	// to set the current time, so this check must happen before processCDFlags.
	if err := p.router.PreRun(); err != nil {
		notifyExitToLogServer()
		mainLog.Load().Fatal().Err(err).Msg("failed to perform router pre-run check")
	}

	// Remember the log path before processCDFlags may replace cfg, so a change can
	// be detected below.
	oldLogPath := cfg.Service.LogPath
	if uid := cdUIDFromProvToken(); uid != "" {
		cdUID = uid
	}
	if cdUID != "" {
		validateCdUpstreamProtocol()
		if rc, err := processCDFlags(&cfg); err != nil {
			// On mobile, hand the error to the app instead of exiting the process.
			if isMobile() {
				appCallback.Exit(err.Error())
				return
			}

			cdLogger := mainLog.Load().With().Str("mode", "cd").Logger()
			// Performs self-uninstallation if the ControlD device does not exist.
			var uer *controld.ErrorResponse
			if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
				_ = uninstallInvalidCdUID(p, cdLogger, false)
			}
			notifyExitToLogServer()
			cdLogger.Fatal().Err(err).Msg("failed to fetch resolver config")
		} else {
			p.mu.Lock()
			p.rc = rc
			p.mu.Unlock()
		}
	}

	// updated reports whether the listener config was changed and must be persisted.
	updated := updateListenerConfig(&cfg, notifyExitToLogServer)

	// processCDFlags resets cfg, so re-apply the log/cache command line flags.
	if cdUID != "" {
		processLogAndCacheFlags()
	}

	if updated {
		if err := writeConfigFile(&cfg); err != nil {
			notifyExitToLogServer()
			mainLog.Load().Fatal().Err(err).Msg("failed to write config file")
		} else {
			mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile)
		}
	}

	if newLogPath := cfg.Service.LogPath; newLogPath != "" && oldLogPath != newLogPath {
		// After processCDFlags, log config may change, so reset mainLog and re-init logging.
		l := zerolog.New(io.Discard)
		mainLog.Store(&l)

		// Copy logs written so far to new log file if possible.
		if buf, err := os.ReadFile(oldLogPath); err == nil {
			if err := os.WriteFile(newLogPath, buf, os.FileMode(0o600)); err != nil {
				mainLog.Load().Warn().Err(err).Msg("could not copy old log file")
			}
		}
		initLoggingWithBackup(false)
	}

	// validateConfig logs the individual field errors itself.
	if err := validateConfig(&cfg); err != nil {
		notifyExitToLogServer()
		os.Exit(1)
	}
	initCache()

	if daemon {
		exe, err := os.Executable()
		if err != nil {
			mainLog.Load().Error().Err(err).Msg("failed to find the binary")
			notifyExitToLogServer()
			os.Exit(1)
		}
		curDir, err := os.Getwd()
		if err != nil {
			mainLog.Load().Error().Err(err).Msg("failed to get current working directory")
			notifyExitToLogServer()
			os.Exit(1)
		}
		// If running as daemon, re-run the command in background, with daemon off.
		cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...)
		cmd.Dir = curDir
		if err := cmd.Start(); err != nil {
			mainLog.Load().Error().Err(err).Msg("failed to start process as daemon")
			notifyExitToLogServer()
			os.Exit(1)
		}
		mainLog.Load().Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started")
		os.Exit(0)
	}

	// onStarted/onStopped hooks; presumably invoked by prog when the service
	// starts and stops (NOTE(review): confirm in prog).
	// Loopback listener IPs other than 127.0.0.1 are allocated on start and released on stop.
	p.onStarted = append(p.onStarted, func() {
		for _, lc := range p.cfg.Listener {
			if shouldAllocateLoopbackIP(lc.IP) {
				if err := allocateIP(lc.IP); err != nil {
					mainLog.Load().Error().Err(err).Msgf("could not allocate IP: %s", lc.IP)
				}
			}
		}
		// Configure Windows service failure actions
		_ = ConfigureWindowsServiceFailureActions(ctrldServiceName)
	})
	p.onStopped = append(p.onStopped, func() {
		for _, lc := range p.cfg.Listener {
			if shouldAllocateLoopbackIP(lc.IP) {
				if err := deAllocateIP(lc.IP); err != nil {
					mainLog.Load().Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP)
				}
			}
		}
	})
	// On a supported router platform, use its cert pool and, when an interface is
	// configured, run router setup/cleanup with the service lifecycle.
	if platform := router.Name(); platform != "" {
		if cp := router.CertPool(); cp != nil {
			rootCertPool = cp
		}
		if iface != "" {
			p.onStarted = append(p.onStarted, func() {
				mainLog.Load().Debug().Msg("router setup on start")
				if err := p.router.Setup(); err != nil {
					mainLog.Load().Error().Err(err).Msg("could not configure router")
				}
			})
			p.onStopped = append(p.onStopped, func() {
				mainLog.Load().Debug().Msg("router cleanup on stop")
				if err := p.router.Cleanup(); err != nil {
					mainLog.Load().Error().Err(err).Msg("could not cleanup router")
				}
			})
		}
	}
	p.onStopped = append(p.onStopped, func() {
		// restore static DNS settings or DHCP
		p.resetDNS(false, true)
		// Iterate over all physical interfaces and restore static DNS if a saved static config exists.
		withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error {
			file := savedStaticDnsSettingsFilePath(i)
			if _, err := os.Stat(file); err == nil {
				if err := restoreDNS(i); err != nil {
					mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name)
				} else {
					mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name)
				}
			}
			return nil
		})
	})

	// Signal that setup is done, then block until stopCh is closed or receives a value.
	close(waitCh)
	<-stopCh
}

// writeConfigFile encodes cfg as TOML into the config file in use.
//
// The target is v.ConfigFileUsed() if set, otherwise configPath if set, otherwise
// the current defaultConfigFile; defaultConfigFile is updated to the chosen path.
// The file is created or truncated with mode 0644. In cd mode (cdUID != "") an
// "auto-generated" header is written before the TOML body.
func writeConfigFile(cfg *ctrld.Config) error {
	if cfu := v.ConfigFileUsed(); cfu != "" {
		defaultConfigFile = cfu
	} else if configPath != "" {
		defaultConfigFile = configPath
	}
	f, err := os.OpenFile(defaultConfigFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0o644))
	if err != nil {
		return err
	}
	defer f.Close()
	if cdUID != "" {
		if _, err := f.WriteString("# AUTO-GENERATED VIA CD FLAG - DO NOT MODIFY\n\n"); err != nil {
			return err
		}
	}
	enc := toml.NewEncoder(f).SetIndentTables(true)
	if err := enc.Encode(&cfg); err != nil {
		return err
	}
	// Close explicitly so write errors surfaced on close are returned; the deferred
	// Close then only returns an (ignored) already-closed error.
	if err := f.Close(); err != nil {
		return err
	}
	return nil
}

// readConfigFile reads in config file.
//
// - It writes default config file if config file not found if writeDefaultConfig is true.
// - It emits notice message to user if notice is true.
//
// It returns true only when an existing config file was read successfully.
// Parse errors and other read errors are fatal.
func readConfigFile(writeDefaultConfig, notice bool) bool {
	// If err == nil, there's a config supplied via `--config`, no default config written.
	err := v.ReadInConfig()
	if err == nil {
		if notice {
			mainLog.Load().Notice().Msg("Reading config: " + v.ConfigFileUsed())
		}
		mainLog.Load().Info().Msg("loading config file from: " + v.ConfigFileUsed())
		defaultConfigFile = v.ConfigFileUsed()
		return true
	}

	if !writeDefaultConfig {
		return false
	}

	// If error is viper.ConfigFileNotFoundError, write default config.
	if errors.As(err, &viper.ConfigFileNotFoundError{}) {
		if err := v.Unmarshal(&cfg); err != nil {
			mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
		}
		// Pick working listener addresses silently before persisting the default config.
		nop := zerolog.Nop()
		_, _ = tryUpdateListenerConfig(&cfg, &nop, func() {}, true)
		addExtraSplitDnsRule(&cfg)
		if err := writeConfigFile(&cfg); err != nil {
			mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
		} else {
			fp, err := filepath.Abs(defaultConfigFile)
			if err != nil {
				mainLog.Load().Fatal().Msgf("failed to get default config file path: %v", err)
			}
			if cdUID == "" && nextdns == "" {
				mainLog.Load().Notice().Msg("Generating controld default config: " + fp)
			}
			mainLog.Load().Info().Msg("writing default config file to: " + fp)
		}
		return false
	}

	// If error is viper.ConfigParseError, emit details line and column number.
	if errors.As(err, &viper.ConfigParseError{}) {
		if de := decoderErrorFromTomlFile(v.ConfigFileUsed()); de != nil {
			row, col := de.Position()
			mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err)
		}
	}

	// Otherwise, report fatal error and exit.
	mainLog.Load().Fatal().Msgf("failed to decode config file: %v", err)
	return false
}

// decoderErrorFromTomlFile parses the invalid toml file, returning the details decoder error.
// It returns nil if the file cannot be opened or the decode error is not a *toml.DecodeError.
func decoderErrorFromTomlFile(cf string) *toml.DecodeError {
	if f, _ := os.Open(cf); f != nil {
		defer f.Close()
		var i any
		var de *toml.DecodeError
		if err := toml.NewDecoder(f).Decode(&i); err != nil && errors.As(err, &de) {
			return de
		}
	}
	return nil
}

// readBase64Config reads ctrld config from the base64 input string.
func readBase64Config(configBase64 string) error { if configBase64 == "" { return nil } configStr, err := base64.StdEncoding.DecodeString(configBase64) if err != nil { return fmt.Errorf("invalid base64 config: %w", err) } // readBase64Config is called when: // // - "--base64_config" flag set. // - Reading custom config when "--cd" flag set. // // So we need to re-create viper instance to discard old one. v = viper.NewWithOptions(viper.KeyDelimiter("::")) v.SetConfigType("toml") return v.ReadConfig(bytes.NewReader(configStr)) } func processNoConfigFlags(noConfigStart bool) { if !noConfigStart { return } if listenAddress == "" || primaryUpstream == "" { mainLog.Load().Fatal().Msg(`"listen" and "primary_upstream" flags must be set in no config mode`) } processListenFlag() endpointAndTyp := func(endpoint string) (string, string) { typ := ctrld.ResolverTypeFromEndpoint(endpoint) endpoint = strings.TrimPrefix(endpoint, "quic://") if after, found := strings.CutPrefix(endpoint, "h3://"); found { endpoint = "https://" + after } return endpoint, typ } pEndpoint, pType := endpointAndTyp(primaryUpstream) puc := &ctrld.UpstreamConfig{ Name: pEndpoint, Endpoint: pEndpoint, Type: pType, Timeout: 5000, } puc.Init() upstream := map[string]*ctrld.UpstreamConfig{"0": puc} if secondaryUpstream != "" { sEndpoint, sType := endpointAndTyp(secondaryUpstream) suc := &ctrld.UpstreamConfig{ Name: sEndpoint, Endpoint: sEndpoint, Type: sType, Timeout: 5000, } suc.Init() upstream["1"] = suc rules := make([]ctrld.Rule, 0, len(domains)) for _, domain := range domains { rules = append(rules, ctrld.Rule{domain: []string{"upstream.1"}}) } lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"] lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules} } v.Set("upstream", upstream) } // defaultDeactivationPin is the default value for cdDeactivationPin. // If cdDeactivationPin equals to this default, it means the pin code is not set from Control D API. 
const defaultDeactivationPin = -1 // cdDeactivationPin is used in cd mode to decide whether stop and uninstall commands can be run. var cdDeactivationPin atomic.Int64 func init() { cdDeactivationPin.Store(defaultDeactivationPin) } // deactivationPinSet indicates if cdDeactivationPin is non-default.. func deactivationPinSet() bool { return cdDeactivationPin.Load() != defaultDeactivationPin } func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { logger := mainLog.Load().With().Str("mode", "cd").Logger() logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) bo.LogLongerThan = 30 * time.Second ctx := context.Background() resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) for { if errUrlNetworkError(err) { bo.BackOff(ctx, err) logger.Warn().Msg("could not fetch resolver using bootstrap DNS, retrying...") resolverConfig, err = controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) continue } break } if err != nil { if isMobile() { return nil, err } logger.Warn().Err(err).Msg("could not fetch resolver config") return nil, err } if resolverConfig.DeactivationPin != nil { logger.Debug().Msg("saving deactivation pin") cdDeactivationPin.Store(*resolverConfig.DeactivationPin) } logger.Info().Msg("generating ctrld config from Control-D configuration") *cfg = ctrld.Config{} // Fetch config, unmarshal to cfg. 
if resolverConfig.Ctrld.CustomConfig != "" { logger.Info().Msg("using defined custom config of Control-D resolver") var cfgErr error if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil { setListenerDefaultValue(cfg) setNetworkDefaultValue(cfg) if cfgErr = validateConfig(cfg); cfgErr == nil { return resolverConfig, nil } } mainLog.Load().Warn().Err(err).Msg("disregarding invalid custom config") } bootstrapIP := func(endpoint string) string { u, err := url.Parse(endpoint) if err != nil { logger.Warn().Err(err).Msgf("no bootstrap IP for invalid endpoint: %s", endpoint) return "" } switch { case dns.IsSubDomain(ctrld.FreeDnsDomain, u.Host): return ctrld.FreeDNSBoostrapIP case dns.IsSubDomain(ctrld.PremiumDnsDomain, u.Host): return ctrld.PremiumDNSBoostrapIP } return "" } cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) cfg.Upstream["0"] = &ctrld.UpstreamConfig{ BootstrapIP: bootstrapIP(resolverConfig.DOH), Endpoint: resolverConfig.DOH, Type: cdUpstreamProto, Timeout: 5000, } rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) for _, domain := range resolverConfig.Exclude { rules = append(rules, ctrld.Rule{domain: []string{}}) } cfg.Listener = make(map[string]*ctrld.ListenerConfig) lc := &ctrld.ListenerConfig{ Policy: &ctrld.ListenerPolicyConfig{ Name: "My Policy", Rules: rules, }, } cfg.Listener["0"] = lc // Set default value. setListenerDefaultValue(cfg) setNetworkDefaultValue(cfg) return resolverConfig, nil } // setListenerDefaultValue sets the default value for cfg.Listener if none existed. func setListenerDefaultValue(cfg *ctrld.Config) { if len(cfg.Listener) == 0 { cfg.Listener = map[string]*ctrld.ListenerConfig{ "0": {IP: "", Port: 0}, } } } // setListenerDefaultValue sets the default value for cfg.Listener if none existed. 
func setNetworkDefaultValue(cfg *ctrld.Config) { if len(cfg.Network) == 0 { cfg.Network = map[string]*ctrld.NetworkConfig{ "0": { Name: "Network 0", Cidrs: []string{"0.0.0.0/0"}, }, } } } // validateCdRemoteConfig validates the custom config from ControlD if defined. // This only validate the config syntax. To validate the config rules, calling // validateConfig with the cfg after calling this function. func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) error { if rc.Ctrld.CustomConfig == "" { return nil } if err := readBase64Config(rc.Ctrld.CustomConfig); err != nil { return err } return v.Unmarshal(&cfg) } func processListenFlag() { if listenAddress == "" { return } host, portStr, err := net.SplitHostPort(listenAddress) if err != nil { mainLog.Load().Fatal().Msgf("invalid listener address: %v", err) } port, err := strconv.Atoi(portStr) if err != nil { mainLog.Load().Fatal().Msgf("invalid port number: %v", err) } lc := &ctrld.ListenerConfig{ IP: host, Port: port, } v.Set("listener", map[string]*ctrld.ListenerConfig{ "0": lc, }) } func processLogAndCacheFlags() { if logPath != "" { cfg.Service.LogPath = logPath } if logPath != "" && cfg.Service.LogLevel == "" { cfg.Service.LogLevel = "debug" } if cacheSize != 0 { cfg.Service.CacheEnable = true cfg.Service.CacheSize = cacheSize } v.Set("service", cfg.Service) } func netInterface(ifaceName string) (*net.Interface, error) { if ifaceName == "auto" { ifaceName = defaultIfaceName() } var iface *net.Interface err := netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { if i.Name == ifaceName { iface = i.Interface } }) if iface == nil { return nil, errors.New("interface not found") } if _, err := patchNetIfaceName(iface); err != nil { return nil, err } return iface, err } func defaultIfaceName() string { if ifaceName := router.DefaultInterfaceName(); ifaceName != "" { return ifaceName } dri, err := netmon.DefaultRouteInterface() if err != nil { // On WSL 1, the route table 
does not have any default route. But the fact that // it only uses /etc/resolv.conf for setup DNS, so we can use "lo" here. if oi := osinfo.New(); strings.Contains(oi.String(), "Microsoft") { return "lo" } // On linux, it could be either resolvconf or systemd which is managing DNS settings, // so the interface name does not matter if there's no default route interface. if runtime.GOOS == "linux" { return "lo" } mainLog.Load().Debug().Err(err).Msg("no default route interface found") return "" } return dri } // selfCheckStatus performs the end-to-end DNS test by sending query to ctrld listener. // It returns a boolean to indicate whether the check is succeeded, the actual status // of ctrld service, and an additional error if any. // // We perform two tests: // // - Internal testing, ensuring query could be sent from client -> ctrld. // - External testing, ensuring query could be sent from ctrld -> upstream. // // Self-check is considered success only if both tests are ok. func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bool, service.Status, error) { status, err := s.Status() if err != nil { mainLog.Load().Warn().Err(err).Msg("could not get service status") return false, service.StatusUnknown, err } // If ctrld is not running, do nothing, just return the status as-is. if status != service.StatusRunning { return false, status, nil } // Skip self checks if set. 
if skipSelfChecks { return true, status, nil } mainLog.Load().Debug().Msg("waiting for ctrld listener to be ready") cc := newSocketControlClient(ctx, s, sockDir) if cc == nil { return false, status, errors.New("could not connect to control server") } v = viper.NewWithOptions(viper.KeyDelimiter("::")) if configPath != "" { v.SetConfigFile(configPath) } else { v.SetConfigFile(defaultConfigFile) } if err := v.ReadInConfig(); err != nil { mainLog.Load().Error().Err(err).Msgf("failed to re-read configuration file: %s", v.ConfigFileUsed()) return false, status, err } cfg = ctrld.Config{} if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Error().Err(err).Msg("failed to update new config") return false, status, err } selfCheckExternalDomain := cfg.FirstUpstream().VerifyDomain() if selfCheckExternalDomain == "" { // Nothing to do, return the status as-is. return true, status, nil } mainLog.Load().Debug().Msg("ctrld listener is ready") lc := cfg.FirstListener() addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)) mainLog.Load().Debug().Msgf("performing listener test, sending queries to %s", addr) if err := selfCheckResolveDomain(context.TODO(), addr, "internal", selfCheckInternalTestDomain); err != nil { return false, status, err } if err := selfCheckResolveDomain(context.TODO(), addr, "external", selfCheckExternalDomain); err != nil { return false, status, err } return true, status, nil } // selfCheckResolveDomain performs DNS test query against ctrld listener. 
func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain string) error { bo := backoff.NewBackoff("self-check", logf, 10*time.Second) bo.LogLongerThan = 500 * time.Millisecond maxAttempts := 10 c := new(dns.Client) var ( lastAnswer *dns.Msg lastErr error ) oi := osinfo.New() for i := 0; i < maxAttempts; i++ { if domain == "" { return errors.New("empty test domain") } m := new(dns.Msg) m.SetQuestion(domain+".", dns.TypeA) m.RecursionDesired = true r, _, exErr := exchangeContextWithTimeout(c, 5*time.Second, m, addr) if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 { mainLog.Load().Debug().Msgf("%s self-check against %q succeeded", scope, domain) return nil } // Return early if this is a connection refused. if errConnectionRefused(exErr) { return exErr } // Return early if this is MacOS 15.0 and error is timeout error. var e net.Error if oi.Name == "darwin" && oi.Version == "15.0" && errors.As(exErr, &e) && e.Timeout() { mainLog.Load().Warn().Msg("MacOS 15.0 Sequoia has a bug with the firewall which may prevent ctrld from starting. Disable the MacOS firewall and try again") return exErr } lastAnswer = r lastErr = exErr bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", exErr)) } mainLog.Load().Debug().Msgf("self-check against %q failed", domain) // Ping all upstreams to provide better error message to users. 
for name, uc := range cfg.Upstream { if err := uc.ErrorPing(); err != nil { mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint) } } marker := strings.Repeat("=", 32) mainLog.Load().Debug().Msg(marker) mainLog.Load().Debug().Msgf("listener address : %s", addr) mainLog.Load().Debug().Msgf("last error : %v", lastErr) if lastAnswer != nil { mainLog.Load().Debug().Msgf("last answer from ctrld :") mainLog.Load().Debug().Msg(marker) for _, s := range strings.Split(lastAnswer.String(), "\n") { mainLog.Load().Debug().Msgf("%s", s) } } return errSelfCheckNoAnswer } func userHomeDir() (string, error) { dir, err := router.HomeDir() if err != nil { return "", err } if dir != "" { return dir, nil } // viper will expand for us. if runtime.GOOS == "windows" { // If we're on windows, use the install path for this. exePath, err := os.Executable() if err != nil { return "", err } return filepath.Dir(exePath), nil } // Mobile platform should provide a rw dir path for this. if isMobile() { return homedir, nil } dir = "/etc/controld" if err := os.MkdirAll(dir, 0750); err != nil { return os.UserHomeDir() // fallback to user home directory } if ok, _ := dirWritable(dir); !ok { return os.UserHomeDir() } return dir, nil } // socketDir returns directory that ctrld will create socket file for running controlServer. func socketDir() (string, error) { switch { case runtime.GOOS == "windows", isMobile(): return userHomeDir() } dir := "/var/run" if ok, _ := dirWritable(dir); !ok { return userHomeDir() } return dir, nil } // tryReadingConfig is like tryReadingConfigWithNotice, with notice set to false. func tryReadingConfig(writeDefaultConfig bool) { tryReadingConfigWithNotice(writeDefaultConfig, false) } // tryReadingConfigWithNotice tries reading in config files, either specified by user or from default // locations. If notice is true, emitting a notice message to user which config file was read. 
func tryReadingConfigWithNotice(writeDefaultConfig, notice bool) { // --config is specified. if configPath != "" { v.SetConfigFile(configPath) readConfigFile(false, notice) return } // no config start or base64 config mode. if !writeDefaultConfig { return } readConfigWithNotice(writeDefaultConfig, notice) } // readConfig calls readConfigWithNotice with notice set to false. func readConfig(writeDefaultConfig bool) { readConfigWithNotice(writeDefaultConfig, false) } // readConfigWithNotice calls readConfigFile with config file set to ctrld.toml // or config.toml for compatible with earlier versions of ctrld. func readConfigWithNotice(writeDefaultConfig, notice bool) { configs := []struct { name string written bool }{ // For compatibility, we check for config.toml first, but only read it if exists. {"config", false}, {"ctrld", writeDefaultConfig}, } dir, err := userHomeDir() if err != nil { mainLog.Load().Fatal().Msgf("failed to get user home dir: %v", err) } for _, config := range configs { ctrld.SetConfigNameWithPath(v, config.name, dir) v.SetConfigFile(configPath) if readConfigFile(config.written, notice) { break } } } func uninstall(p *prog, s service.Service) { if _, err := s.Status(); err != nil && errors.Is(err, service.ErrNotInstalled) { mainLog.Load().Error().Msg(err.Error()) return } tasks := []task{ {s.Stop, false, "Stop"}, {s.Uninstall, true, "Uninstall"}, } initInteractiveLogging() if doTasks(tasks) { if err := p.router.ConfigureService(svcConfig); err != nil { mainLog.Load().Fatal().Err(err).Msg("could not configure service") } if err := p.router.Uninstall(svcConfig); err != nil { mainLog.Load().Warn().Err(err).Msg("post uninstallation failed, please check system/service log for details error") return } // restore static DNS settings or DHCP p.resetDNS(false, true) // Iterate over all physical interfaces and restore DNS if a saved static config exists. 
withEachPhysicalInterfaces(p.runningIface, "restore static DNS", func(i *net.Interface) error { file := savedStaticDnsSettingsFilePath(i) if _, err := os.Stat(file); err == nil { if err := restoreDNS(i); err != nil { mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) } else { mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) err = os.Remove(file) if err != nil { mainLog.Load().Debug().Err(err).Msgf("Could not remove saved static DNS file for interface %s", i.Name) } } } return nil }) if router.Name() != "" { mainLog.Load().Debug().Msg("Router cleanup") } // Stop already did router.Cleanup and report any error if happens, // ignoring error here to prevent false positive. _ = p.router.Cleanup() mainLog.Load().Notice().Msg("Service uninstalled") return } } func validateConfig(cfg *ctrld.Config) error { if err := ctrld.ValidateConfig(validator.New(), cfg); err != nil { var ve validator.ValidationErrors if errors.As(err, &ve) { for _, fe := range ve { mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe)) } } return err } return nil } // NOTE: Add more case here once new validation tag is used in ctrld.Config struct. 
func fieldErrorMsg(fe validator.FieldError) string { switch fe.Tag() { case "oneof": return fmt.Sprintf("must be one of: %q", fe.Param()) case "min": if fe.Kind() == reflect.Map || fe.Kind() == reflect.Slice { return fmt.Sprintf("must define at least %s element", fe.Param()) } return fmt.Sprintf("minimum value: %q", fe.Param()) case "max": if fe.Kind() == reflect.Map || fe.Kind() == reflect.Slice { return fmt.Sprintf("exceeded maximum number of elements: %s", fe.Param()) } return fmt.Sprintf("maximum value: %q", fe.Param()) case "len": if fe.Kind() == reflect.Slice { return fmt.Sprintf("must have at least %s element", fe.Param()) } return fmt.Sprintf("minimum len: %q", fe.Param()) case "gte": return fmt.Sprintf("must be greater than or equal to: %s", fe.Param()) case "cidr": return fmt.Sprintf("invalid value: %s", fe.Value()) case "required_unless", "required": return "value is required" case "dnsrcode": return fmt.Sprintf("invalid DNS rcode value: %s", fe.Value()) case "ipstack": ipStacks := []string{ctrld.IpStackV4, ctrld.IpStackV6, ctrld.IpStackSplit, ctrld.IpStackBoth} return fmt.Sprintf("must be one of: %q", strings.Join(ipStacks, " ")) case "iporempty": return fmt.Sprintf("invalid IP format: %s", fe.Value()) case "file": return fmt.Sprintf("filed does not exist: %s", fe.Value()) case "http_url": return fmt.Sprintf("invalid http/https url: %s", fe.Value()) } return "" } func isLoopback(ipStr string) bool { ip := net.ParseIP(ipStr) if ip == nil { return false } return ip.IsLoopback() } func shouldAllocateLoopbackIP(ipStr string) bool { ip := net.ParseIP(ipStr) if ip == nil || ip.To4() == nil { return false } return ip.IsLoopback() && ip.String() != "127.0.0.1" } type listenerConfigCheck struct { IP bool Port bool } // mobileListenerPort returns hardcoded port for mobile platforms. 
func mobileListenerPort() int { if isAndroid() { return 5354 } return 53 } // mobileListenerIp returns hardcoded listener ip for mobile platforms func mobileListenerIp() string { if isAndroid() { return "0.0.0.0" } return "127.0.0.1" } // updateListenerConfig updates the config for listeners if not defined, // or defined but invalid to be used, e.g: using loopback address other // than 127.0.0.1 with systemd-resolved. func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool { updated, _ := tryUpdateListenerConfig(cfg, nil, notifyToLogServerFunc, true) if addExtraSplitDnsRule(cfg) { updated = true } return updated } // tryUpdateListenerConfig tries updating listener config with a working one. // If fatal is true, and there's listen address conflicted, the function do // fatal error. func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, notifyFunc func(), fatal bool) (updated, ok bool) { ok = true lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" nextdnsMode := nextdns != "" // For Windows server with local Dns server running, we can only try on random local IP. hasLocalDnsServer := hasLocalDnsServerRunning() notRouter := router.Name() == "" isDesktop := ctrld.IsDesktopPlatform() for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { listener.IP = "0.0.0.0" // Windows Server lies to us that we could listen on 0.0.0.0:53 // even there's a process already done that, stick to local IP only. // // For desktop clients, also stick the listener to the local IP only. // Listening on 0.0.0.0 would expose it to the entire local network, potentially // creating security vulnerabilities (such as DNS amplification or abusing). if hasLocalDnsServer || isDesktop { listener.IP = "127.0.0.1" } lcc[n].IP = true } if listener.Port == 0 { listener.Port = 53 lcc[n].Port = true } // In cd mode, we always try to pick an ip:port pair to work. // Same if nextdns resolver is used. 
// // Except on Windows Server with local Dns running, // we could only listen on random local IP port 53. if cdMode || nextdnsMode { lcc[n].IP = true lcc[n].Port = true if hasLocalDnsServer { lcc[n].Port = false } } updated = updated || lcc[n].IP || lcc[n].Port } il := mainLog.Load() if infoLogger != nil { il = infoLogger } if isMobile() { // On Mobile, only use first listener, ignore others. firstLn := cfg.FirstListener() for k := range cfg.Listener { if cfg.Listener[k] != firstLn { delete(cfg.Listener, k) } } if cdMode { firstLn.IP = mobileListenerIp() firstLn.Port = mobileListenerPort() clear(lcc) updated = true } } var closers []io.Closer defer func() { for _, closer := range closers { _ = closer.Close() } }() // tryListen attempts to listen on given udp and tcp address. // Created listeners will be kept in listeners slice above, and close // before function finished. tryListen := func(addr string) error { udpLn, udpErr := net.ListenPacket("udp", addr) if udpLn != nil { closers = append(closers, udpLn) } tcpLn, tcpErr := net.Listen("tcp", addr) if tcpLn != nil { closers = append(closers, tcpLn) } return errors.Join(udpErr, tcpErr) } logMsg := func(e *zerolog.Event, listenerNum int, format string, v ...any) { e.MsgFunc(func() string { return fmt.Sprintf("listener.%d %s", listenerNum, fmt.Sprintf(format, v...)) }) } listeners := make([]int, 0, len(cfg.Listener)) for k := range cfg.Listener { n, err := strconv.Atoi(k) if err != nil { continue } listeners = append(listeners, n) } sort.Ints(listeners) for _, n := range listeners { listener := cfg.Listener[strconv.Itoa(n)] check := lcc[strconv.Itoa(n)] oldIP := listener.IP oldPort := listener.Port isZeroIP := listener.IP == "0.0.0.0" || listener.IP == "::" // Check if we could listen on the current IP + Port, if not, try following thing, pick first one success: // - Try 127.0.0.1:53 // - Pick a random port until success. 
localhostIP := func(ipStr string) string { if ip := net.ParseIP(ipStr); ip != nil && ip.To4() == nil { return "::1" } return "127.0.0.1" } // On firewalla, we don't need to check localhost, because the lo interface is excluded in dnsmasq // config, so we can always listen on localhost port 53, but no traffic could be routed there. tryLocalhost := !isLoopback(listener.IP) && router.CanListenLocalhost() tryAllPort53 := true tryOldIPPort5354 := true tryPort5354 := true if hasLocalDnsServer { tryAllPort53 = false tryOldIPPort5354 = false tryPort5354 = false } // if not running on a router, we should not try to listen on any port other than 53 // if we do, this will break the dns resolution for the system. if notRouter { tryOldIPPort5354 = false tryPort5354 = false } attempts := 0 maxAttempts := 10 for { if attempts == maxAttempts { notifyFunc() logMsg(mainLog.Load().Fatal(), n, "could not find available listen ip and port") } addr := net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)) err := tryListen(addr) if err == nil { break } logMsg(il.Info(), n, "error listening on address: %s, error: %v", addr, err) if !check.IP && !check.Port { if fatal { notifyFunc() logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err) } ok = false break } if tryAllPort53 { tryAllPort53 = false if check.IP { listener.IP = "0.0.0.0" } if check.Port { listener.Port = 53 } if check.IP { logMsg(il.Info(), n, "could not listen on address: %s, trying: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))) } continue } if tryLocalhost { tryLocalhost = false if check.IP { listener.IP = localhostIP(listener.IP) } if check.Port { listener.Port = 53 } if check.IP { logMsg(il.Info(), n, "could not listen on address: %s, trying localhost: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))) } continue } if tryOldIPPort5354 { tryOldIPPort5354 = false if check.IP { listener.IP = oldIP } if check.Port { listener.Port = 5354 } logMsg(il.Info(), n, "could not 
listen on address: %s, trying current ip with port 5354", addr) continue } if tryPort5354 { tryPort5354 = false if check.IP { listener.IP = "0.0.0.0" } if check.Port { listener.Port = 5354 } logMsg(il.Info(), n, "could not listen on address: %s, trying 0.0.0.0:5354", addr) continue } if check.IP && !isZeroIP { // for "0.0.0.0" or "::", we only need to try new port. listener.IP = randomLocalIP() } else { listener.IP = oldIP } // if we are not running on a router, we should not try to listen on any port other than 53 // if we do, this will break the dns resolution for the system. if check.Port && !notRouter { listener.Port = randomPort() } else { listener.Port = oldPort } if listener.IP == oldIP && listener.Port == oldPort { if fatal { notifyFunc() logMsg(mainLog.Load().Fatal(), n, "could not listen on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err) } ok = false break } logMsg(il.Info(), n, "could not listen on address: %s, pick a random ip+port", addr) attempts++ } } if !ok { return } // Specific case for systemd-resolved. if useSystemdResolved { if listener := cfg.FirstListener(); listener != nil && listener.Port == 53 { n := listeners[0] // systemd-resolved does not allow forwarding DNS queries from 127.0.0.53 to loopback // ip address, other than "127.0.0.1", so trying to listen on default route interface // address instead. 
if ip := net.ParseIP(listener.IP); ip != nil && ip.IsLoopback() && ip.String() != "127.0.0.1" { logMsg(il.Info(), n, "using loopback interface do not work with systemd-resolved") found := false if netIface, _ := net.InterfaceByName(defaultIfaceName()); netIface != nil { addrs, _ := netIface.Addrs() for _, addr := range addrs { if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.To4() != nil { addr := net.JoinHostPort(netIP.IP.String(), strconv.Itoa(listener.Port)) if err := tryListen(addr); err == nil { found = true listener.IP = netIP.IP.String() logMsg(il.Info(), n, "use %s as listener address", listener.IP) break } } } } if !found { notifyFunc() logMsg(mainLog.Load().Fatal(), n, "could not use %q as DNS nameserver with systemd resolved", listener.IP) } } } } return } func dirWritable(dir string) (bool, error) { f, err := os.CreateTemp(dir, "") if err != nil { return false, err } defer os.Remove(f.Name()) return true, f.Close() } func osVersion() string { oi := osinfo.New() if runtime.GOOS == "freebsd" { if ver, _, found := strings.Cut(oi.String(), ":"); found { return ver } } return oi.String() } // cdUIDFromProvToken fetch UID from ControlD API using provision token. func cdUIDFromProvToken() string { // --cd flag supersedes --cd-org, ignore it if both are supplied. if cdUID != "" { return "" } // --cd-org is empty, nothing to do. if cdOrg == "" { return "" } // Validate custom hostname if provided. if customHostname != "" && !validHostname(customHostname) { mainLog.Load().Fatal().Msgf("invalid custom hostname: %q", customHostname) } req := &controld.UtilityOrgRequest{ProvToken: cdOrg, Hostname: customHostname} // Process provision token if provided. resolverConfig, err := controld.FetchResolverUID(req, rootCmd.Version, cdDev) if err != nil { mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid with provision token: %s", cdOrg) } return resolverConfig.UID } // removeOrgFlagsFromArgs removes organization flags from command line arguments. 
// The flags are: // // - "--cd-org" // - "--custom-hostname" // // This is necessary because "ctrld run" only need a valid UID, which could be fetched // using "--cd-org". So if "ctrld start" have already been called with "--cd-org", we // already have a valid UID to pass to "ctrld run", so we don't have to force "ctrld run" // to re-do the already done job. func removeOrgFlagsFromArgs(sc *service.Config) { a := sc.Arguments[:0] skip := false for _, x := range sc.Arguments { if skip { skip = false continue } // For "--cd-org XXX"/"--custom-hostname XXX", skip them and mark next arg skipped. if x == "--"+cdOrgFlagName || x == "--"+customHostnameFlagName { skip = true continue } // For "--cd-org=XXX"/"--custom-hostname=XXX", just skip them. if strings.HasPrefix(x, "--"+cdOrgFlagName+"=") || strings.HasPrefix(x, "--"+customHostnameFlagName+"=") { continue } a = append(a, x) } sc.Arguments = a } // newSocketControlClient returns new control client after control server was started. func newSocketControlClient(ctx context.Context, s service.Service, dir string) *controlClient { return newSocketControlClientWithTimeout(ctx, s, dir, dialSocketControlServerTimeout) } // newSocketControlClientWithTimeout returns new control client after control server was started. // The timeoutDuration controls how long to wait for the server. func newSocketControlClientWithTimeout(ctx context.Context, s service.Service, dir string, timeoutDuration time.Duration) *controlClient { // Return early if service is not running. if status, err := s.Status(); err != nil || status != service.StatusRunning { return nil } bo := backoff.NewBackoff("self-check", logf, 10*time.Second) bo.LogLongerThan = 10 * time.Second cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) timeout := time.NewTimer(timeoutDuration) defer timeout.Stop() // The socket control server may not start yet, so attempt to ping // it until we got a response. 
for { _, err := cc.post(startedPath, nil) if err == nil { // Server was started, stop pinging. break } // The socket control server is not ready yet, backoff for waiting it to be ready. bo.BackOff(ctx, err) select { case <-timeout.C: return nil case <-ctx.Done(): return nil default: } } return cc } func newSocketControlClientMobile(dir string, stopCh chan struct{}) *controlClient { bo := backoff.NewBackoff("self-check", logf, 3*time.Second) bo.LogLongerThan = 3 * time.Second ctx := context.Background() cc := newControlClient(filepath.Join(dir, ControlSocketName())) for { select { case <-stopCh: return nil default: _, err := cc.post("/", nil) if err == nil { return cc } else { bo.BackOff(ctx, err) } } } } // checkStrFlagEmpty validates if a string flag was set to an empty string. // If yes, emitting a fatal error message. func checkStrFlagEmpty(cmd *cobra.Command, flagName string) { fl := cmd.Flags().Lookup(flagName) if !fl.Changed || fl.Value.Type() != "string" { return } if fl.Value.String() == "" { mainLog.Load().Fatal().Msgf(`flag "--%s" value must be non-empty`, fl.Name) } } func validateCdUpstreamProtocol() { if cdUID == "" { return } switch cdUpstreamProto { case ctrld.ResolverTypeDOH, ctrld.ResolverTypeDOH3: default: mainLog.Load().Fatal().Msg(`flag "--protocol" must be "doh" or "doh3"`) } } func validateCdAndNextDNSFlags() { if (cdUID != "" || cdOrg != "") && nextdns != "" { mainLog.Load().Fatal().Msgf("--%s/--%s could not be used with --%s", cdUidFlagName, cdOrgFlagName, nextdnsFlagName) } } // removeNextDNSFromArgs removes the --nextdns from command line arguments. func removeNextDNSFromArgs(sc *service.Config) { a := sc.Arguments[:0] skip := false for _, x := range sc.Arguments { if skip { skip = false continue } // For "--nextdns XXX", skip it and mark next arg skipped. if x == "--"+nextdnsFlagName { skip = true continue } // For "--nextdns=XXX", just skip it. 
if strings.HasPrefix(x, "--"+nextdnsFlagName+"=") { continue } a = append(a, x) } sc.Arguments = a } // doGenerateNextDNSConfig generates a working config with nextdns resolver. func doGenerateNextDNSConfig(uid string) error { if uid == "" { return nil } mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile) generateNextDNSConfig(uid) updateListenerConfig(&cfg, func() {}) return writeConfigFile(&cfg) } func noticeWritingControlDConfig() error { if cdUID != "" { mainLog.Load().Notice().Msgf("Generating controld config: %s", defaultConfigFile) } return nil } // deactivationPinInvalidExitCode indicates exit code due to invalid pin code. const deactivationPinInvalidExitCode = 126 // errInvalidDeactivationPin indicates that the deactivation pin is invalid. var errInvalidDeactivationPin = errors.New("deactivation pin is invalid") // errRequiredDeactivationPin indicates that the deactivation pin is required but not provided by users. var errRequiredDeactivationPin = errors.New("deactivation pin is required to stop or uninstall the service") // checkDeactivationPin validates if the deactivation pin matches one in ControlD config. func checkDeactivationPin(s service.Service, stopCh chan struct{}) error { mainLog.Load().Debug().Msg("Checking deactivation pin") dir, err := socketDir() if err != nil { mainLog.Load().Err(err).Msg("could not check deactivation pin") return err } mainLog.Load().Debug().Msg("Creating control client") var cc *controlClient if s == nil { cc = newSocketControlClientMobile(dir, stopCh) } else { cc = newSocketControlClient(context.TODO(), s, dir) } mainLog.Load().Debug().Msg("Control client done") if cc == nil { return nil // ctrld is not running. 
} data, _ := json.Marshal(&deactivationRequest{Pin: deactivationPin}) mainLog.Load().Debug().Msg("Posting deactivation request") resp, err := cc.post(deactivationPath, bytes.NewReader(data)) mainLog.Load().Debug().Msg("Posting deactivation request done") if resp != nil { switch resp.StatusCode { case http.StatusBadRequest: mainLog.Load().Error().Msg(errRequiredDeactivationPin.Error()) return errRequiredDeactivationPin // pin is required case http.StatusOK: return nil // valid pin case http.StatusNotFound: return nil // the server is running older version of ctrld } } mainLog.Load().Error().Err(err).Msg(errInvalidDeactivationPin.Error()) return errInvalidDeactivationPin } // isCheckDeactivationPinErr reports whether there is an error during check deactivation pin process. func isCheckDeactivationPinErr(err error) bool { return errors.Is(err, errInvalidDeactivationPin) || errors.Is(err, errRequiredDeactivationPin) } // ensureUninstall ensures that s.Uninstall will remove ctrld service from system completely. func ensureUninstall(s service.Service) error { maxAttempts := 10 var err error for i := 0; i < maxAttempts; i++ { err = s.Uninstall() if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { return nil } time.Sleep(time.Second) } return errors.Join(err, errors.New("uninstall failed")) } // exchangeContextWithTimeout wraps c.ExchangeContext with the given timeout. func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() return c.ExchangeContext(ctx, msg, addr) } // absHomeDir returns the absolute path to given filename using home directory as root dir. 
func absHomeDir(filename string) string { if homedir != "" { return filepath.Join(homedir, filename) } dir, err := userHomeDir() if err != nil { return filename } return filepath.Join(dir, filename) } // runInCdMode reports whether ctrld service is running in cd mode. func runInCdMode() bool { return curCdUID() != "" } // curCdUID returns the current ControlD UID used by running ctrld process. func curCdUID() string { if s, _ := newService(&prog{}, svcConfig); s != nil { // Configure Windows service failure actions if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { mainLog.Load().Debug().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) } if dir, _ := socketDir(); dir != "" { cc := newSocketControlClient(context.TODO(), s, dir) if cc != nil { resp, _ := cc.post(cdPath, nil) if resp != nil { defer resp.Body.Close() buf, _ := io.ReadAll(resp.Body) return string(buf) } } } } return "" } // goArm returns the GOARM value for the binary. func goArm() string { if runtime.GOARCH != "arm" { return "" } if bi, ok := debug.ReadBuildInfo(); ok { for _, setting := range bi.Settings { if setting.Key == "GOARM" { return setting.Value } } } // Use ARM v5 as a fallback, since it works on all others. return "5" } // upgradeUrl returns the url for downloading new ctrld binary. func upgradeUrl(baseUrl string) string { dlPath := fmt.Sprintf("%s-%s/ctrld", runtime.GOOS, runtime.GOARCH) // Use arm version set during build time, v5 binary can be run on higher arm version system. if armVersion := goArm(); armVersion != "" { dlPath = fmt.Sprintf("%s-%sv%s/ctrld", runtime.GOOS, runtime.GOARCH, armVersion) } // linux/amd64 has nocgo version, to support systems that missing some libc (like openwrt). 
if !cgoEnabled && runtime.GOOS == "linux" && runtime.GOARCH == "amd64" { dlPath = fmt.Sprintf("%s-%s-nocgo/ctrld", runtime.GOOS, runtime.GOARCH) } dlUrl := fmt.Sprintf("%s/%s", baseUrl, dlPath) if runtime.GOOS == "windows" { dlUrl += ".exe" } return dlUrl } // runningIface returns the value of the iface variable used by ctrld process which is running. func runningIface(s service.Service) *ifaceResponse { if sockDir, err := socketDir(); err == nil { if cc := newSocketControlClient(context.TODO(), s, sockDir); cc != nil { resp, err := cc.post(ifacePath, nil) if err != nil { return nil } defer resp.Body.Close() res := &ifaceResponse{} if err := json.NewDecoder(resp.Body).Decode(res); err != nil { return nil } return res } } return nil } // doValidateCdRemoteConfig fetches and validates custom config for cdUID. func doValidateCdRemoteConfig(cdUID string, fatal bool) error { rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) if err != nil { logger := mainLog.Load().Fatal() if !fatal { logger = mainLog.Load().Warn() } logger.Err(err).Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) if !fatal { return err } } // return earlier if there's no custom config. if rc.Ctrld.CustomConfig == "" { return nil } // validateCdRemoteConfig clobbers v, saving it here to restore later. oldV := v var cfgErr error remoteCfg := &ctrld.Config{} if cfgErr = validateCdRemoteConfig(rc, remoteCfg); cfgErr == nil { setListenerDefaultValue(remoteCfg) setNetworkDefaultValue(remoteCfg) cfgErr = validateConfig(remoteCfg) } else { if errors.As(cfgErr, &viper.ConfigParseError{}) { if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 { tmpDir := os.TempDir() tmpConfFile := filepath.Join(tmpDir, "ctrld.toml") errorLogged := false // Write remote config to a temporary file to get details error. 
if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil { if de := decoderErrorFromTomlFile(tmpConfFile); de != nil { row, col := de.Position() mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error()) errorLogged = true } _ = os.Remove(tmpConfFile) } // If we could not log details error, emit what we have already got. if !errorLogged { mainLog.Load().Error().Msgf("failed to parse custom config: %v", cfgErr) } } } else { mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err) } } if cfgErr != nil { mainLog.Load().Warn().Msg("disregarding invalid custom config") } v = oldV return nil } // uninstallInvalidCdUID performs self-uninstallation because the ControlD device does not exist. func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool { s, err := newService(p, svcConfig) if err != nil { logger.Warn().Err(err).Msg("failed to create new service") return false } // restore static DNS settings or DHCP p.resetDNS(false, true) tasks := []task{{s.Uninstall, true, "Uninstall"}} if doTasks(tasks) { logger.Info().Msg("uninstalled service") if doStop { _ = s.Stop() } return true } return false }