diff --git a/client_info.go b/client_info.go index f32526a..05d2910 100644 --- a/client_info.go +++ b/client_info.go @@ -5,10 +5,11 @@ type ClientInfoCtxKey struct{} // ClientInfo represents ctrld's clients information. type ClientInfo struct { - Mac string - IP string - Hostname string - Self bool + Mac string + IP string + Hostname string + Self bool + ClientIDPref string } // LeaseFileFormat specifies the format of DHCP lease file. diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index b67504c..3f76c80 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -144,6 +144,7 @@ func initCLI() { _ = runCmd.Flags().MarkHidden("homedir") runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) _ = runCmd.Flags().MarkHidden("iface") + runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) rootCmd.AddCommand(runCmd) @@ -158,6 +159,7 @@ func initCLI() { Run: func(cmd *cobra.Command, args []string) { checkStrFlagEmpty(cmd, cdUidFlagName) checkStrFlagEmpty(cmd, cdOrgFlagName) + validateCdAndNextDNSFlags() sc := &service.Config{} *sc = *svcConfig osArgs := os.Args[2:] @@ -176,6 +178,9 @@ func initCLI() { // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. sc.Arguments = append(sc.Arguments, "--cd="+cdUID) } + if cdUID != "" { + validateCdUpstreamProtocol() + } p := &prog{ router: router.New(&cfg, cdUID != ""), @@ -223,7 +228,7 @@ func initCLI() { }() } - tryReadingConfig(writeDefaultConfig) + tryReadingConfigWithNotice(writeDefaultConfig, true) if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) @@ -231,6 +236,10 @@ func initCLI() { initLogging() + if nextdns != "" { + removeNextDNSFromArgs(sc) + } + // Explicitly passing config, so on system where home directory could not be obtained, // or sub-process env is different with the parent, we still behave correctly and use // the expected config file. @@ -244,16 +253,20 @@ func initCLI() { return } - if router.Name() != "" { + if router.Name() != "" && iface != "" { mainLog.Load().Debug().Msg("cleaning up router before installing") _ = p.router.Cleanup() } tasks := []task{ {s.Stop, false}, + {func() error { return doGenerateNextDNSConfig(nextdns) }, true}, {s.Uninstall, false}, {s.Install, false}, {s.Start, true}, + // Note that startCmd do not actually write ControlD config, but the config file was + // generated after s.Start, so we notice users here for consistent with nextdns mode. + {noticeWritingControlDConfig, false}, } mainLog.Load().Notice().Msg("Starting service") if doTasks(tasks) { @@ -281,7 +294,7 @@ func initCLI() { } }, } - // Keep these flags in sync with runCmd above, except for "-d". + // Keep these flags in sync with runCmd above, except for "-d"/"--nextdns". startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") @@ -295,6 +308,8 @@ func initCLI() { startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") _ = startCmd.Flags().MarkHidden("dev") startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) + startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") + startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) routerCmd := &cobra.Command{ Use: "setup", @@ -367,6 +382,10 @@ func initCLI() { mainLog.Load().Error().Msg(err.Error()) return } + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } initLogging() tasks := []task{ @@ -388,6 +407,50 @@ func initCLI() { }, } + reloadCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + initConsoleLogging() + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + dir, err := userHomeDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(reloadPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld") + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK: + mainLog.Load().Notice().Msg("Service reloaded") + case http.StatusCreated: + s, err := newService(&prog{}, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.") + mainLog.Load().Warn().Msg("Restarting service") + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("Service not installed") + return + } + restartCmd.Run(cmd, args) + default: + buf, err := io.ReadAll(resp.Body) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("could not read response from control server") + } + mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf)) + } + }, + } statusCmd := &cobra.Command{ Use: "status", Short: "Show status of the ctrld service", @@ -503,9 +566,10 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, Short: "Manage ctrld service", Args: cobra.OnlyValidArgs, ValidArgs: []string{ - statusCmd.Use, + startCmd.Use, stopCmd.Use, restartCmd.Use, + reloadCmd.Use, statusCmd.Use, uninstallCmd.Use, interfacesCmd.Use, @@ -514,6 +578,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, serviceCmd.AddCommand(startCmd) serviceCmd.AddCommand(stopCmd) serviceCmd.AddCommand(restartCmd) + serviceCmd.AddCommand(reloadCmd) serviceCmd.AddCommand(statusCmd) serviceCmd.AddCommand(uninstallCmd) serviceCmd.AddCommand(interfacesCmd) @@ -568,6 +633,19 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, } rootCmd.AddCommand(restartCmdAlias) + reloadCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + initConsoleLogging() + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + Run: func(cmd *cobra.Command, args []string) { + reloadCmd.Run(cmd, args) + }, + } + rootCmd.AddCommand(reloadCmdAlias) + statusCmdAlias := &cobra.Command{ Use: "status", Short: "Show status of the ctrld service", @@ -688,6 +766,7 @@ func RunMobile(appConfig *AppConfig, appCallback *AppCallback, stopCh chan struc homedir = appConfig.HomeDir verbose = appConfig.Verbose cdUID = appConfig.CdUID + cdUpstreamProto = ctrld.ResolverTypeDOH logPath = appConfig.LogPath run(appCallback, stopCh) } @@ -699,10 +778,12 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } waitCh := make(chan struct{}) p := &prog{ - waitCh: waitCh, - stopCh: stopCh, - cfg: &cfg, - appCallback: appCallback, + waitCh: waitCh, + stopCh: stopCh, + reloadCh: make(chan struct{}), + reloadDoneCh: make(chan struct{}), + cfg: &cfg, + appCallback: appCallback, } if homedir == "" { if dir, err := userHomeDir(); err == nil { @@ -740,9 +821,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { readBase64Config(configBase64) processNoConfigFlags(noConfigStart) + p.mu.Lock() if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) } + p.mu.Unlock() processLogAndCacheFlags() @@ -777,14 +860,47 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cdUID = uid } if cdUID != "" { - err := processCDFlags() - if err != nil { - appCallback.Exit(err.Error()) - return + validateCdUpstreamProtocol() + if err := processCDFlags(&cfg); err != nil { + if isMobile() { + appCallback.Exit(err.Error()) + return + } + + uninstallIfInvalidCdUID := func() { + cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() + if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode { + s, err := newService(&prog{}, svcConfig) + if err != nil { + cdLogger.Warn().Err(err).Msg("failed to create new service") + return + } + if netIface, _ := netInterface(iface); netIface != nil { + if err := restoreNetworkManager(); err != nil { + cdLogger.Error().Err(err).Msg("could not restore NetworkManager") + return + } + cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface") + if err := resetDNS(netIface); err != nil { + cdLogger.Warn().Err(err).Msg("something went wrong while restoring DNS") + } else { + cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully") + } + } + + tasks := []task{{s.Uninstall, true}} + if doTasks(tasks) { + cdLogger.Info().Msg("uninstalled service") + } + cdLogger.Fatal().Err(uer).Msg("failed to fetch resolver config") + return + } + } + uninstallIfInvalidCdUID() } } - updated := updateListenerConfig() + updated := updateListenerConfig(&cfg) if cdUID != "" { processLogAndCacheFlags() @@ -812,7 +928,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { initLoggingWithBackup(false) } - validateConfig(&cfg) + if err := validateConfig(&cfg); err != nil { + os.Exit(1) + } initCache() if daemon { @@ -859,19 +977,21 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if cp := router.CertPool(); cp != nil { rootCertPool = cp } - p.onStarted = append(p.onStarted, func() { - mainLog.Load().Debug().Msg("router setup on start") - if err := p.router.Setup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not configure router") - } - }) - p.onStopped = append(p.onStopped, func() { - mainLog.Load().Debug().Msg("router cleanup on stop") - if err := p.router.Cleanup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not cleanup router") - } - p.resetDNS() - }) + if iface != "" { + p.onStarted = append(p.onStarted, func() { + mainLog.Load().Debug().Msg("router setup on start") + if err := p.router.Setup(); err != nil { + mainLog.Load().Error().Err(err).Msg("could not configure router") + } + }) + p.onStopped = append(p.onStopped, func() { + mainLog.Load().Debug().Msg("router cleanup on stop") + if err := p.router.Cleanup(); err != nil { + mainLog.Load().Error().Err(err).Msg("could not cleanup router") + } + p.resetDNS() + }) + } } close(waitCh) @@ -907,10 +1027,17 @@ func writeConfigFile() error { return nil } -func readConfigFile(writeDefaultConfig bool) bool { +// readConfigFile reads in config file. +// +// - It writes default config file if config file not found if writeDefaultConfig is true. +// - It emits notice message to user if notice is true. +func readConfigFile(writeDefaultConfig, notice bool) bool { // If err == nil, there's a config supplied via `--config`, no default config written. err := v.ReadInConfig() if err == nil { + if notice { + mainLog.Load().Notice().Msg("Reading config: " + v.ConfigFileUsed()) + } mainLog.Load().Info().Msg("loading config file from: " + v.ConfigFileUsed()) defaultConfigFile = v.ConfigFileUsed() return true @@ -925,7 +1052,8 @@ func readConfigFile(writeDefaultConfig bool) bool { if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) } - _ = updateListenerConfig() + nop := zerolog.Nop() + _, _ = tryUpdateListenerConfig(&cfg, &nop, true) if err := writeConfigFile(); err != nil { mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) } else { @@ -933,6 +1061,9 @@ func readConfigFile(writeDefaultConfig bool) bool { if err != nil { mainLog.Load().Fatal().Msgf("failed to get default config file path: %v", err) } + if cdUID == "" && nextdns == "" { + mainLog.Load().Notice().Msg("Generating controld default config: " + fp) + } mainLog.Load().Info().Msg("writing default config file to: " + fp) } return false @@ -1015,7 +1146,7 @@ func processNoConfigFlags(noConfigStart bool) { v.Set("upstream", upstream) } -func processCDFlags() error { +func processCDFlags(cfg *ctrld.Config) error { logger := mainLog.Load().With().Str("mode", "cd").Logger() logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) @@ -1031,44 +1162,17 @@ func processCDFlags() error { } break } - if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode { - s, err := newService(&prog{}, svcConfig) - if err != nil { - logger.Warn().Err(err).Msg("failed to create new service") - return nil - } - if netIface, _ := netInterface(iface); netIface != nil { - if err := restoreNetworkManager(); err != nil { - logger.Error().Err(err).Msg("could not restore NetworkManager") - return nil - } - logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface") - if err := resetDNS(netIface); err != nil { - logger.Warn().Err(err).Msg("something went wrong while restoring DNS") - } else { - logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully") - } - } - - tasks := []task{{s.Uninstall, true}} - if doTasks(tasks) { - logger.Info().Msg("uninstalled service") - } - event := logger.Fatal() - if isMobile() { - event = logger.Warn() - } - event.Err(uer).Msg("failed to fetch resolver config") - return uer - } if err != nil { + if isMobile() { + return err + } logger.Warn().Err(err).Msg("could not fetch resolver config") - return nil + return err } logger.Info().Msg("generating ctrld config from Control-D configuration") - cfg = ctrld.Config{} + *cfg = ctrld.Config{} // Fetch config, unmarshal to cfg. if resolverConfig.Ctrld.CustomConfig != "" { logger.Info().Msg("using defined custom config of Control-D resolver") @@ -1085,7 +1189,7 @@ func processCDFlags() error { cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) cfg.Upstream["0"] = &ctrld.UpstreamConfig{ Endpoint: resolverConfig.DOH, - Type: ctrld.ResolverTypeDOH, + Type: cdUpstreamProto, Timeout: 5000, } rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) @@ -1213,6 +1317,11 @@ func selfCheckStatus(s service.Service) service.Status { return service.StatusUnknown } + // Not a ctrld upstream, return status as-is. + if cfg.FirstUpstream().VerifyDomain() == "" { + return status + } + mainLog.Load().Debug().Msg("ctrld listener is ready") mainLog.Load().Debug().Msg("performing self-check") bo := backoff.NewBackoff("self-check", logf, 10*time.Second) @@ -1338,21 +1447,35 @@ func userHomeDir() (string, error) { return dir, nil } +// tryReadingConfig is like tryReadingConfigWithNotice, with notice set to false. func tryReadingConfig(writeDefaultConfig bool) { + tryReadingConfigWithNotice(writeDefaultConfig, false) +} + +// tryReadingConfigWithNotice tries reading in config files, either specified by user or from default +// locations. If notice is true, emitting a notice message to user which config file was read. +func tryReadingConfigWithNotice(writeDefaultConfig, notice bool) { // --config is specified. if configPath != "" { v.SetConfigFile(configPath) - readConfigFile(false) + readConfigFile(false, notice) return } // no config start or base64 config mode. if !writeDefaultConfig { return } - readConfig(writeDefaultConfig) + readConfigWithNotice(writeDefaultConfig, notice) } +// readConfig calls readConfigWithNotice with notice set to false. func readConfig(writeDefaultConfig bool) { + readConfigWithNotice(writeDefaultConfig, false) +} + +// readConfigWithNotice calls readConfigFile with config file set to ctrld.toml +// or config.toml for compatible with earlier versions of ctrld. +func readConfigWithNotice(writeDefaultConfig, notice bool) { configs := []struct { name string written bool @@ -1369,7 +1492,7 @@ func readConfig(writeDefaultConfig bool) { for _, config := range configs { ctrld.SetConfigNameWithPath(v, config.name, dir) v.SetConfigFile(configPath) - if readConfigFile(config.written) { + if readConfigFile(config.written, notice) { break } } @@ -1401,18 +1524,17 @@ func uninstall(p *prog, s service.Service) { } } -func validateConfig(cfg *ctrld.Config) { - err := ctrld.ValidateConfig(validator.New(), cfg) - if err == nil { - return - } - var ve validator.ValidationErrors - if errors.As(err, &ve) { - for _, fe := range ve { - mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe)) +func validateConfig(cfg *ctrld.Config) error { + if err := ctrld.ValidateConfig(validator.New(), cfg); err != nil { + var ve validator.ValidationErrors + if errors.As(err, &ve) { + for _, fe := range ve { + mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe)) + } } + return err } - os.Exit(1) + return nil } // NOTE: Add more case here once new validation tag is used in ctrld.Config struct. @@ -1483,9 +1605,19 @@ func mobileListenerPort() int { // updateListenerConfig updates the config for listeners if not defined, // or defined but invalid to be used, e.g: using loopback address other // than 127.0.0.1 with systemd-resolved. -func updateListenerConfig() (updated bool) { +func updateListenerConfig(cfg *ctrld.Config) bool { + updated, _ := tryUpdateListenerConfig(cfg, nil, true) + return updated +} + +// tryUpdateListenerConfig tries updating listener config with a working one. +// If fatal is true, and there's listen address conflicted, the function do +// fatal error. +func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fatal bool) (updated, ok bool) { + ok = true lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" + nextdnsMode := nextdns != "" for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { @@ -1497,12 +1629,17 @@ func updateListenerConfig() (updated bool) { lcc[n].Port = true } // In cd mode, we always try to pick an ip:port pair to work. - if cdMode { + // Same if nextdns resolver is used. + if cdMode || nextdnsMode { lcc[n].IP = true lcc[n].Port = true } updated = updated || lcc[n].IP || lcc[n].Port } + il := mainLog.Load() + if infoLogger != nil { + il = infoLogger + } if isMobile() { // On Mobile, only use first listener, ignore others. firstLn := cfg.FirstListener() @@ -1594,7 +1731,11 @@ func updateListenerConfig() (updated bool) { break } if !check.IP && !check.Port { - logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err) + if fatal { + logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err) + } + ok = false + break } if tryAllPort53 { tryAllPort53 = false @@ -1605,7 +1746,7 @@ func updateListenerConfig() (updated bool) { listener.Port = 53 } if check.IP { - logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))) + logMsg(il.Info(), n, "could not listen on address: %s, trying: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))) } continue } @@ -1618,7 +1759,7 @@ func updateListenerConfig() (updated bool) { listener.Port = 53 } if check.IP { - logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying localhost: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))) + logMsg(il.Info(), n, "could not listen on address: %s, trying localhost: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))) } continue } @@ -1630,7 +1771,7 @@ func updateListenerConfig() (updated bool) { if check.Port { listener.Port = 5354 } - logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying current ip with port 5354", addr) + logMsg(il.Info(), n, "could not listen on address: %s, trying current ip with port 5354", addr) continue } if tryPort5354 { @@ -1641,7 +1782,7 @@ func updateListenerConfig() (updated bool) { if check.Port { listener.Port = 5354 } - logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying 0.0.0.0:5354", addr) + logMsg(il.Info(), n, "could not listen on address: %s, trying 0.0.0.0:5354", addr) continue } if check.IP && !isZeroIP { // for "0.0.0.0" or "::", we only need to try new port. @@ -1655,12 +1796,19 @@ func updateListenerConfig() (updated bool) { listener.Port = oldPort } if listener.IP == oldIP && listener.Port == oldPort { - logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err) + if fatal { + logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err) + } + ok = false + break } - logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, pick a random ip+port", addr) + logMsg(il.Info(), n, "could not listen on address: %s, pick a random ip+port", addr) attempts++ } } + if !ok { + return + } // Specific case for systemd-resolved. if useSystemdResolved { @@ -1670,7 +1818,7 @@ func updateListenerConfig() (updated bool) { // ip address, other than "127.0.0.1", so trying to listen on default route interface // address instead. if ip := net.ParseIP(listener.IP); ip != nil && ip.IsLoopback() && ip.String() != "127.0.0.1" { - logMsg(mainLog.Load().Warn(), n, "using loopback interface do not work with systemd-resolved") + logMsg(il.Info(), n, "using loopback interface do not work with systemd-resolved") found := false if netIface, _ := net.InterfaceByName(defaultIfaceName()); netIface != nil { addrs, _ := netIface.Addrs() @@ -1680,7 +1828,7 @@ func updateListenerConfig() (updated bool) { if err := tryListen(addr); err == nil { found = true listener.IP = netIP.IP.String() - logMsg(mainLog.Load().Warn(), n, "use %s as listener address", listener.IP) + logMsg(il.Info(), n, "use %s as listener address", listener.IP) break } } @@ -1742,12 +1890,12 @@ func removeProvTokenFromArgs(sc *service.Config) { continue } // For "--cd-org XXX", skip it and mark next arg skipped. - if x == cdOrgFlagName { + if x == "--"+cdOrgFlagName { skip = true continue } // For "--cd-org=XXX", just skip it. - if strings.HasPrefix(x, cdOrgFlagName+"=") { + if strings.HasPrefix(x, "--"+cdOrgFlagName+"=") { continue } a = append(a, x) @@ -1795,6 +1943,64 @@ func checkStrFlagEmpty(cmd *cobra.Command, flagName string) { return } if fl.Value.String() == "" { - mainLog.Load().Fatal().Msgf(`flag "--%s"" value must be non-empty`, fl.Name) + mainLog.Load().Fatal().Msgf(`flag "--%s" value must be non-empty`, fl.Name) } } + +func validateCdUpstreamProtocol() { + if cdUID == "" { + return + } + switch cdUpstreamProto { + case ctrld.ResolverTypeDOH, ctrld.ResolverTypeDOH3: + default: + mainLog.Load().Fatal().Msg(`flag "--protocol" must be "doh" or "doh3"`) + } +} + +func validateCdAndNextDNSFlags() { + if (cdUID != "" || cdOrg != "") && nextdns != "" { + mainLog.Load().Fatal().Msgf("--%s/--%s could not be used with --%s", cdUidFlagName, cdOrgFlagName, nextdnsFlagName) + } +} + +// removeNextDNSFromArgs removes the --nextdns from command line arguments. +func removeNextDNSFromArgs(sc *service.Config) { + a := sc.Arguments[:0] + skip := false + for _, x := range sc.Arguments { + if skip { + skip = false + continue + } + // For "--nextdns XXX", skip it and mark next arg skipped. + if x == "--"+nextdnsFlagName { + skip = true + continue + } + // For "--nextdns=XXX", just skip it. + if strings.HasPrefix(x, "--"+nextdnsFlagName+"=") { + continue + } + a = append(a, x) + } + sc.Arguments = a +} + +// doGenerateNextDNSConfig generates a working config with nextdns resolver. +func doGenerateNextDNSConfig(uid string) error { + if uid == "" { + return nil + } + mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile) + generateNextDNSConfig(uid) + updateListenerConfig(&cfg) + return writeConfigFile() +} + +func noticeWritingControlDConfig() error { + if cdUID != "" { + mainLog.Load().Notice().Msgf("Generating controld config: %s", defaultConfigFile) + } + return nil +} diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 5f5ac51..5ee7112 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -6,14 +6,18 @@ import ( "net" "net/http" "os" + "reflect" "sort" "time" + + "github.com/Control-D-Inc/ctrld" ) const ( contentTypeJson = "application/json" listClientsPath = "/clients" startedPath = "/started" + reloadPath = "/reload" ) type controlServer struct { @@ -75,6 +79,52 @@ func (p *prog) registerControlServerHandler() { w.WriteHeader(http.StatusRequestTimeout) } })) + p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + listeners := make(map[string]*ctrld.ListenerConfig) + p.mu.Lock() + for k, v := range p.cfg.Listener { + listeners[k] = &ctrld.ListenerConfig{ + IP: v.IP, + Port: v.Port, + } + } + oldSvc := p.cfg.Service + p.mu.Unlock() + if err := p.sendReloadSignal(); err != nil { + mainLog.Load().Err(err).Msg("could not send reload signal") + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + select { + case <-p.reloadDoneCh: + case <-time.After(5 * time.Second): + http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError) + return + } + + p.mu.Lock() + defer p.mu.Unlock() + + // Checking for cases that we could not do a reload. + + // 1. Listener config ip or port changes. + for k, v := range p.cfg.Listener { + l := listeners[k] + if l == nil || l.IP != v.IP || l.Port != v.Port { + w.WriteHeader(http.StatusCreated) + return + } + } + + // 2. Service config changes. + if !reflect.DeepEqual(oldSvc, p.cfg.Service) { + w.WriteHeader(http.StatusCreated) + return + } + + // Otherwise, reload is done. + w.WriteHeader(http.StatusOK) + })) } func jsonResponse(next http.Handler) http.Handler { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 12cf781..2b0f94d 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/hex" + "errors" "fmt" "net" "net/netip" @@ -17,6 +18,7 @@ import ( "golang.org/x/sync/errgroup" "tailscale.com/net/interfaces" "tailscale.com/net/netaddr" + "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/dnscache" @@ -25,6 +27,7 @@ import ( const ( staleTTL = 60 * time.Second + localTTL = 3600 * time.Second // EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option. // https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81 // This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification. @@ -37,6 +40,29 @@ var osUpstreamConfig = &ctrld.UpstreamConfig{ Timeout: 2000, } +var privateUpstreamConfig = &ctrld.UpstreamConfig{ + Name: "Private resolver", + Type: ctrld.ResolverTypePrivate, + Timeout: 2000, +} + +// proxyRequest contains data for proxying a DNS query to upstream. +type proxyRequest struct { + msg *dns.Msg + ci *ctrld.ClientInfo + failoverRcodes []int + ufr *upstreamForResult +} + +// upstreamForResult represents the result of processing rules for a request. +type upstreamForResult struct { + upstreams []string + matchedPolicy string + matchedNetwork string + matchedRule string + matched bool +} + func (p *prog) serveDNS(listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated @@ -44,36 +70,58 @@ func (p *prog) serveDNS(listenerNum string) error { mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") return allocErr } - var failoverRcodes []int - if listenerConfig.Policy != nil { - failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers - } + handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { p.sema.acquire() defer p.sema.release() + if len(m.Question) == 0 { + answer := new(dns.Msg) + answer.SetRcode(m, dns.RcodeFormatError) + _ = w.WriteMsg(answer) + return + } + reqId := requestID() + ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) + if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { + ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) + answer := new(dns.Msg) + answer.SetRcode(m, dns.RcodeRefused) + _ = w.WriteMsg(answer) + return + } go p.detectLoop(m) q := m.Question[0] domain := canonicalName(q.Name) - reqId := requestID() remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) ci := p.getClientInfo(remoteIP, m) + ci.ClientIDPref = p.cfg.Service.ClientIDPref + stripClientSubnet(m) remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String()) t := time.Now() - ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) ctrld.Log(ctx, mainLog.Load().Debug(), "%s received query: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain) - upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, domain) + res := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain) var answer *dns.Msg - if !matched && listenerConfig.Restricted { + if !res.matched && listenerConfig.Restricted { + ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String()) answer = new(dns.Msg) answer.SetRcode(m, dns.RcodeRefused) } else { - answer = p.proxy(ctx, upstreams, failoverRcodes, m, ci) + var failoverRcode []int + if listenerConfig.Policy != nil { + failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers + } + answer = p.proxy(ctx, &proxyRequest{ + msg: m, + ci: ci, + failoverRcodes: failoverRcode, + ufr: res, + }) rtt := time.Since(t) ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) } if err := w.WriteMsg(answer); err != nil { - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveUDP: failed to send DNS response to client") + ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client") } }) @@ -99,7 +147,7 @@ func (p *prog) serveDNS(listenerNum string) error { // addresses of the machine. So ctrld could receive queries from LAN clients. if needRFC1918Listeners(listenerConfig) { g.Go(func() error { - for _, addr := range rfc1918Addresses() { + for _, addr := range ctrld.Rfc1918Addresses() { func() { listenAddr := net.JoinHostPort(addr, strconv.Itoa(listenerConfig.Port)) s, errCh := runDNSServer(listenAddr, proto, handler) @@ -146,27 +194,24 @@ func (p *prog) serveDNS(listenerNum string) error { // Though domain policy has higher priority than network policy, it is still // processed later, because policy logging want to know whether a network rule // is disregarded in favor of the domain level rule. -func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) { +func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) { upstreams := []string{upstreamPrefix + defaultUpstreamNum} matchedPolicy := "no policy" matchedNetwork := "no network" matchedRule := "no rule" matched := false + res = &upstreamForResult{} defer func() { - if !matched && lc.Restricted { - ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", addr.String()) - return - } - if matched { - ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", matchedPolicy, matchedNetwork, matchedRule, upstreams) - } else { - ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams) - } + res.upstreams = upstreams + res.matched = matched + res.matchedPolicy = matchedPolicy + res.matchedNetwork = matchedNetwork + res.matchedRule = matchedRule }() if lc.Policy == nil { - return upstreams, false + return } do := func(policyUpstreams []string) { @@ -202,6 +247,19 @@ networkRules: } } +macRules: + for _, rule := range lc.Policy.Macs { + for source, targets := range rule { + if source != "" && strings.EqualFold(source, srcMac) { + matchedPolicy = lc.Policy.Name + matchedNetwork = source + networkTargets = targets + matched = true + break macRules + } + } + } + for _, rule := range lc.Policy.Rules { // There's only one entry per rule, config validation ensures this. for source, targets := range rule { @@ -213,7 +271,7 @@ networkRules: matchedRule = source do(targets) matched = true - return upstreams, matched + return } } } @@ -222,26 +280,134 @@ networkRules: do(networkTargets) } - return upstreams, matched + return } -func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo) *dns.Msg { +func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg { + cDomainName := msg.Question[0].Name + locked := p.ptrLoopGuard.TryLock(cDomainName) + defer p.ptrLoopGuard.Unlock(cDomainName) + if !locked { + return nil + } + ip := ipFromARPA(cDomainName) + if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" { + answer := new(dns.Msg) + answer.SetReply(msg) + answer.Compress = true + answer.Answer = []dns.RR{&dns.PTR{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + }, + Ptr: dns.Fqdn(name), + }} + ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table") + ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + Mac: p.ciTable.LookupMac(ip.String()), + IP: ip.String(), + Hostname: name, + }) + return answer + } + return nil +} + +func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg { + q := msg.Question[0] + hostname := strings.TrimSuffix(q.Name, ".") + locked := p.lanLoopGuard.TryLock(hostname) + defer p.lanLoopGuard.Unlock(hostname) + if !locked { + return nil + } + if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil { + answer := new(dns.Msg) + answer.SetReply(msg) + answer.Compress = true + switch { + case ip.Is4(): + answer.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: uint32(localTTL.Seconds()), + }, + A: ip.AsSlice(), + }} + case ip.Is6(): + answer.Answer = []dns.RR{&dns.AAAA{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: uint32(localTTL.Seconds()), + }, + AAAA: ip.AsSlice(), + }} + } + ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table") + ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + Mac: p.ciTable.LookupMac(ip.String()), + IP: ip.String(), + Hostname: hostname, + }) + return answer + } + return nil +} + +func (p *prog) proxy(ctx context.Context, req *proxyRequest) *dns.Msg { var staleAnswer *dns.Msg + upstreams := req.ufr.upstreams serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) if len(upstreamConfigs) == 0 { upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} upstreams = []string{upstreamOS} } + + // LAN/PTR lookup flow: + // + // 1. If there's matching rule, follow it. + // 2. Try from client info table. + // 3. Try private resolver. + // 4. Try remote upstream. + isLanOrPtrQuery := false + if req.ufr.matched { + ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) + } else { + switch { + case isPrivatePtrLookup(req.msg): + isLanOrPtrQuery = true + if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil { + return answer + } + upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) + ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using upstreams: %v", upstreams) + case isLanHostnameQuery(req.msg): + isLanOrPtrQuery = true + if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil { + return answer + } + upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) + ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using upstreams: %v", upstreams) + default: + ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams) + } + } + // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 - if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { + if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR { for _, upstream := range upstreams { - cachedValue := p.cache.Get(dnscache.NewKey(msg, upstream)) + cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream)) if cachedValue == nil { continue } answer := cachedValue.Msg.Copy() - answer.SetRcode(msg, answer.Rcode) + answer.SetRcode(req.msg, answer.Rcode) now := time.Now() if cachedValue.Expire.After(now) { ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response") @@ -268,9 +434,9 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i return dnsResolver.Resolve(resolveCtx, msg) } resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { - if upstreamConfig.UpstreamSendClientInfo() && ci != nil { + if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil { ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request") - ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, ci) + ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) } answer, err := resolve1(n, upstreamConfig, msg) if err != nil { @@ -281,6 +447,11 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i go p.um.checkUpstream(upstreams[n], upstreamConfig) } } + // For timeout error (i.e: context deadline exceed), force re-bootstrapping. + var e net.Error + if errors.As(err, &e) && e.Timeout() { + upstreamConfig.ReBootstrap() + } return nil } return answer @@ -297,7 +468,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n]) continue } - answer := resolve(n, upstreamConfig, msg) + answer := resolve(n, upstreamConfig, req.msg) if answer == nil { if serveStaleCache && staleAnswer != nil { ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response") @@ -307,7 +478,13 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i } continue } - if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) { + // We are doing LAN/PTR lookup using private resolver, so always process next one. + // Except for the last, we want to send response instead of saying all upstream failed. + if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 { + ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n]) + continue + } + if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) { ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream") continue } @@ -315,7 +492,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i // set compression, as it is not set by default when unpacking answer.Compress = true - if p.cache != nil { + if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR { ttl := ttlFromMsg(answer) now := time.Now() expired := now.Add(time.Duration(ttl) * time.Second) @@ -323,17 +500,27 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i expired = now.Add(time.Duration(cachedTTL) * time.Second) } setCachedAnswerTTL(answer, now, expired) - p.cache.Add(dnscache.NewKey(msg, upstreams[n]), dnscache.NewValue(answer, expired)) + p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired)) ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response") } return answer } ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) answer := new(dns.Msg) - answer.SetRcode(msg, dns.RcodeServerFailure) + answer.SetRcode(req.msg, dns.RcodeServerFailure) return answer } +func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) { + if len(p.localUpstreams) > 0 { + tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams)) + tmp = append(tmp, p.localUpstreams...) + tmp = append(tmp, upstreams...) + return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp) + } + return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...) +} + func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig { upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams)) for _, upstream := range upstreams { @@ -454,6 +641,23 @@ func ipAndMacFromMsg(msg *dns.Msg) (string, string) { return ip, mac } +// stripClientSubnet removes EDNS0_SUBNET from DNS message if the IP is RFC1918 or loopback address, +// passing them to upstream is pointless, these cannot be used by anything on the WAN. +func stripClientSubnet(msg *dns.Msg) { + if opt := msg.IsEdns0(); opt != nil { + opts := make([]dns.EDNS0, 0, len(opt.Option)) + for _, s := range opt.Option { + if e, ok := s.(*dns.EDNS0_SUBNET); ok && (e.Address.IsPrivate() || e.Address.IsLoopback()) { + continue + } + opts = append(opts, s) + } + if len(opts) != len(opt.Option) { + opt.Option = opts + } + } +} + func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr { if ci != nil && ci.IP != "" { switch addr := addr.(type) { @@ -531,23 +735,44 @@ func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo { } // If MAC is still empty here, that mean the requests are made from virtual interface, - // like VPN/Wireguard clients, so we use whatever MAC address associated with remoteIP - // (most likely 127.0.0.1), and ci.IP as hostname, so we can distinguish those clients. + // like VPN/Wireguard clients, so we use ci.IP as hostname to distinguish those clients. if ci.Mac == "" { - ci.Mac = p.ciTable.LookupMac(remoteIP) if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" { ci.Hostname = hostname } else { - ci.Hostname = ci.IP - p.ciTable.StoreVPNClient(ci) + // Only use IP as hostname for IPv4 clients. + // For Android devices, when it joins the network, it uses ctrld to resolve + // its private DNS once and never reaches ctrld again. For each time, it uses + // a different IPv6 address, which causes hundreds/thousands different client + // IDs created for the same device, which is pointless. + // + // TODO(cuonglm): investigate whether this can be a false positive for other clients? + if !ctrldnet.IsIPv6(ci.IP) { + ci.Hostname = ci.IP + p.ciTable.StoreVPNClient(ci) + } } } else { ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac) } ci.Self = queryFromSelf(ci.IP) + p.spoofLoopbackIpInClientInfo(ci) return ci } +// spoofLoopbackIpInClientInfo replaces loopback IPs in client info. +// +// - Preference IPv4. +// - Preference RFC1918. +func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) { + if ip := net.ParseIP(ci.IP); ip == nil || !ip.IsLoopback() { + return + } + if ip := p.ciTable.LookupRFC1918IPv4(ci.Mac); ip != "" { + ci.IP = ip + } +} + // queryFromSelf reports whether the input IP is from device running ctrld. func queryFromSelf(ip string) bool { netIP := netip.MustParseAddr(ip) @@ -578,17 +803,86 @@ func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool { return lc.IP == "127.0.0.1" && lc.Port == 53 } -func rfc1918Addresses() []string { - var res []string - interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) { - addrs, _ := i.Addrs() - for _, addr := range addrs { - ipNet, ok := addr.(*net.IPNet) - if !ok || !ipNet.IP.IsPrivate() { - continue - } - res = append(res, ipNet.IP.String()) +// ipFromARPA parses a FQDN arpa domain and return the IP address if valid. +func ipFromARPA(arpa string) net.IP { + if arpa, ok := strings.CutSuffix(arpa, ".in-addr.arpa."); ok { + if ptrIP := net.ParseIP(arpa); ptrIP != nil { + return net.IP{ptrIP[15], ptrIP[14], ptrIP[13], ptrIP[12]} } - }) - return res + } + if arpa, ok := strings.CutSuffix(arpa, ".ip6.arpa."); ok { + l := net.IPv6len * 2 + base := 16 + ip := make(net.IP, net.IPv6len) + for i := 0; i < l && arpa != ""; i++ { + idx := strings.LastIndexByte(arpa, '.') + off := idx + 1 + if idx == -1 { + idx = 0 + off = 0 + } else if idx == len(arpa)-1 { + return nil + } + n, err := strconv.ParseUint(arpa[off:], base, 8) + if err != nil { + return nil + } + b := byte(n) + ii := i / 2 + if i&1 == 1 { + b |= ip[ii] << 4 + } + ip[ii] = b + arpa = arpa[:idx] + } + return ip + } + return nil +} + +// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN/CGNAT network. +func isPrivatePtrLookup(m *dns.Msg) bool { + if m == nil || len(m.Question) == 0 { + return false + } + q := m.Question[0] + if ip := ipFromARPA(q.Name); ip != nil { + if addr, ok := netip.AddrFromSlice(ip); ok { + return addr.IsPrivate() || + addr.IsLoopback() || + addr.IsLinkLocalUnicast() || + tsaddr.CGNATRange().Contains(addr) + } + } + return false +} + +// isLanHostnameQuery reports whether DNS message is an A/AAAA query with LAN hostname. +func isLanHostnameQuery(m *dns.Msg) bool { + if m == nil || len(m.Question) == 0 { + return false + } + q := m.Question[0] + switch q.Qtype { + case dns.TypeA, dns.TypeAAAA: + default: + return false + } + name := strings.TrimSuffix(q.Name, ".") + return !strings.Contains(name, ".") || + strings.HasSuffix(name, ".domain") || + strings.HasSuffix(name, ".lan") +} + +// isWanClient reports whether the input is a WAN address. +func isWanClient(na net.Addr) bool { + var ip netip.Addr + if ap, err := netip.ParseAddrPort(na.String()); err == nil { + ip = ap.Addr() + } + return !ip.IsLoopback() && + !ip.IsPrivate() && + !ip.IsLinkLocalUnicast() && + !ip.IsLinkLocalMulticast() && + !tsaddr.CGNATRange().Contains(ip) } diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 674d486..bd73d17 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -67,8 +67,11 @@ func Test_canonicalName(t *testing.T) { func Test_prog_upstreamFor(t *testing.T) { cfg := testhelper.SampleConfig(t) - prog := &prog{cfg: cfg} - for _, nc := range prog.cfg.Network { + p := &prog{cfg: cfg} + p.um = newUpstreamMonitor(p.cfg) + p.lanLoopGuard = newLoopGuard() + p.ptrLoopGuard = newLoopGuard() + for _, nc := range p.cfg.Network { for _, cidr := range nc.Cidrs { _, ipNet, err := net.ParseCIDR(cidr) if err != nil { @@ -81,6 +84,7 @@ func Test_prog_upstreamFor(t *testing.T) { tests := []struct { name string ip string + mac string defaultUpstreamNum string lc *ctrld.ListenerConfig domain string @@ -88,11 +92,14 @@ func Test_prog_upstreamFor(t *testing.T) { matched bool testLogMsg string }{ - {"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""}, - {"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""}, - {"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""}, - {"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""}, - {"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"}, + {"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""}, + {"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""}, + {"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""}, + {"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""}, + {"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"}, + {"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"}, + {"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"}, + {"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"}, } for _, tc := range tests { @@ -111,9 +118,13 @@ func Test_prog_upstreamFor(t *testing.T) { require.NoError(t, err) require.NotNil(t, addr) ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID()) - upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain) - assert.Equal(t, tc.matched, matched) - assert.Equal(t, tc.upstreams, upstreams) + ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain) + p.proxy(ctx, &proxyRequest{ + msg: newDnsMsgWithHostname("foo", dns.TypeA), + ufr: ufr, + }) + assert.Equal(t, tc.matched, ufr.matched) + assert.Equal(t, tc.upstreams, ufr.upstreams) if tc.testLogMsg != "" { assert.Contains(t, logOutput.String(), tc.testLogMsg) } @@ -149,8 +160,32 @@ func TestCache(t *testing.T) { answer2.SetRcode(msg, dns.RcodeRefused) prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute))) - got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil) - got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil) + req1 := &proxyRequest{ + msg: msg, + ci: nil, + failoverRcodes: nil, + ufr: &upstreamForResult{ + upstreams: []string{"upstream.1"}, + matchedPolicy: "", + matchedNetwork: "", + matchedRule: "", + matched: false, + }, + } + req2 := &proxyRequest{ + msg: msg, + ci: nil, + failoverRcodes: nil, + ufr: &upstreamForResult{ + upstreams: []string{"upstream.0"}, + matchedPolicy: "", + matchedNetwork: "", + matchedRule: "", + matched: false, + }, + } + got1 := prog.proxy(context.Background(), req1) + got2 := prog.proxy(context.Background(), req2) assert.NotSame(t, got1, got2) assert.Equal(t, answer1.Rcode, got1.Rcode) assert.Equal(t, answer2.Rcode, got2.Rcode) @@ -234,3 +269,165 @@ func Test_remoteAddrFromMsg(t *testing.T) { }) } } + +func Test_ipFromARPA(t *testing.T) { + tests := []struct { + IP string + ARPA string + }{ + {"1.2.3.4", "4.3.2.1.in-addr.arpa."}, + {"245.110.36.114", "114.36.110.245.in-addr.arpa."}, + {"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."}, + {"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."}, + {"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."}, + {"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."}, + {"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."}, + {"", "asd.in-addr.arpa."}, + {"", "asd.ip6.arpa."}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.IP, func(t *testing.T) { + t.Parallel() + if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) { + t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got) + } + }) + } +} + +func newDnsMsgWithClientIP(ip string) *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}} + o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)}) + m.Extra = append(m.Extra, o) + return m +} + +func Test_stripClientSubnet(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + wantSubnet bool + }{ + {"no edns0", new(dns.Msg), false}, + {"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false}, + {"loopback IP v6", newDnsMsgWithClientIP("::1"), false}, + {"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false}, + {"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false}, + {"public IP", newDnsMsgWithClientIP("1.1.1.1"), true}, + {"invalid IP", newDnsMsgWithClientIP(""), true}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + stripClientSubnet(tc.msg) + hasSubnet := false + if opt := tc.msg.IsEdns0(); opt != nil { + for _, s := range opt.Option { + if _, ok := s.(*dns.EDNS0_SUBNET); ok { + hasSubnet = true + } + } + } + if tc.wantSubnet != hasSubnet { + t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet) + } + }) + } +} + +func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg { + m := new(dns.Msg) + m.SetQuestion(hostname, typ) + return m +} + +func Test_isLanHostnameQuery(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + isLanHostnameQuery bool + }{ + {"A", newDnsMsgWithHostname("foo", dns.TypeA), true}, + {"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true}, + {"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false}, + {"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false}, + {"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got { + t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got) + } + }) + } +} + +func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg { + t.Helper() + m := new(dns.Msg) + ptr, err := dns.ReverseAddr(ip) + if err != nil { + t.Fatal(err) + } + m.SetQuestion(ptr, dns.TypePTR) + return m +} + +func Test_isPrivatePtrLookup(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + isPrivatePtrLookup bool + }{ + // RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as + {"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true}, + {"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true}, + {"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true}, + {"CGNAT", newDnsMsgPtr("100.66.27.28", t), true}, + {"Loopback", newDnsMsgPtr("127.0.0.1", t), true}, + {"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true}, + {"Public IP", newDnsMsgPtr("8.8.8.8", t), false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got { + t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got) + } + }) + } +} + +func Test_isWanClient(t *testing.T) { + tests := []struct { + name string + addr net.Addr + isWanClient bool + }{ + // RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as + {"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false}, + {"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false}, + {"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false}, + {"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false}, + {"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false}, + {"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false}, + {"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isWanClient(tc.addr); tc.isWanClient != got { + t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got) + } + }) + } +} diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index 87dabf8..06a7e03 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -3,6 +3,7 @@ package cli import ( "context" "strings" + "sync" "time" "github.com/miekg/dns" @@ -15,6 +16,36 @@ const ( loopTestQtype = dns.TypeTXT ) +// newLoopGuard returns new loopGuard. +func newLoopGuard() *loopGuard { + return &loopGuard{inflight: make(map[string]struct{})} +} + +// loopGuard guards against DNS loop, ensuring only one query +// for a given domain is processed at a time. +type loopGuard struct { + mu sync.Mutex + inflight map[string]struct{} +} + +// TryLock marks the domain as being processed. +func (lg *loopGuard) TryLock(domain string) bool { + lg.mu.Lock() + defer lg.mu.Unlock() + if _, inflight := lg.inflight[domain]; !inflight { + lg.inflight[domain] = struct{}{} + return true + } + return false +} + +// Unlock marks the domain as being done. +func (lg *loopGuard) Unlock(domain string) { + lg.mu.Lock() + defer lg.mu.Unlock() + delete(lg.inflight, domain) +} + // isLoop reports whether the given upstream config is detected as having DNS loop. func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool { p.loopMu.Lock() @@ -56,7 +87,15 @@ func (p *prog) checkDnsLoop() { mainLog.Load().Debug().Msg("start checking DNS loop") upstream := make(map[string]*ctrld.UpstreamConfig) p.loopMu.Lock() - for _, uc := range p.cfg.Upstream { + for n, uc := range p.cfg.Upstream { + if p.um.isDown("upstream." + n) { + continue + } + // Do not send test query to external upstream. + if !canBeLocalUpstream(uc.Domain) { + mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n) + continue + } uid := uc.UID() p.loop[uid] = false upstream[uid] = uc @@ -79,13 +118,15 @@ func (p *prog) checkDnsLoop() { } // checkDnsLoopTicker performs p.checkDnsLoop every minute. -func (p *prog) checkDnsLoopTicker() { +func (p *prog) checkDnsLoopTicker(ctx context.Context) { timer := time.NewTicker(time.Minute) defer timer.Stop() for { select { case <-p.stopCh: return + case <-ctx.Done(): + return case <-timer.C: p.checkDnsLoop() } diff --git a/cmd/cli/loop_test.go b/cmd/cli/loop_test.go new file mode 100644 index 0000000..b2c8404 --- /dev/null +++ b/cmd/cli/loop_test.go @@ -0,0 +1,42 @@ +package cli + +import ( + "sync" + "sync/atomic" + "testing" +) + +func Test_loopGuard(t *testing.T) { + lg := newLoopGuard() + key := "foo" + + var i atomic.Int64 + var started atomic.Int64 + n := 1000 + do := func() { + locked := lg.TryLock(key) + defer lg.Unlock(key) + started.Add(1) + for started.Load() < 2 { + // Wait until at least 2 goroutines started, otherwise, on system with heavy load, + // or having only 1 CPU, all goroutines can be scheduled to run consequently. + } + if locked { + i.Add(1) + } + } + + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + do() + }() + } + wg.Wait() + + if i.Load() == int64(n) { + t.Fatalf("i must not be increased %d times", n) + } +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index f4439a5..3f1ef8b 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -32,6 +32,8 @@ var ( cdDev bool iface string ifaceStartStop string + nextdns string + cdUpstreamProto string mainLog atomic.Pointer[zerolog.Logger] consoleWriter zerolog.ConsoleWriter @@ -39,8 +41,9 @@ var ( ) const ( - cdUidFlagName = "cd" - cdOrgFlagName = "cd-org" + cdUidFlagName = "cd" + cdOrgFlagName = "cd-org" + nextdnsFlagName = "nextdns" ) func init() { @@ -93,6 +96,7 @@ func initConsoleLogging() { // initLogging initializes global logging setup. func initLogging() { + zerolog.TimeFieldFormat = time.RFC3339 + ".000" initLoggingWithBackup(true) } @@ -131,7 +135,7 @@ func initLoggingWithBackup(doBackup bool) { } writers = append(writers, consoleWriter) multi := zerolog.MultiLevelWriter(writers...) - l := mainLog.Load().Output(multi).With().Timestamp().Logger() + l := mainLog.Load().Output(multi).With().Logger() mainLog.Store(&l) // TODO: find a better way. ctrld.ProxyLogger.Store(&l) diff --git a/cmd/cli/netlink_linux.go b/cmd/cli/netlink_linux.go index 0faae84..d757f8b 100644 --- a/cmd/cli/netlink_linux.go +++ b/cmd/cli/netlink_linux.go @@ -1,11 +1,13 @@ package cli import ( + "context" + "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) -func (p *prog) watchLinkState() { +func (p *prog) watchLinkState(ctx context.Context) { ch := make(chan netlink.LinkUpdate) done := make(chan struct{}) defer close(done) @@ -13,14 +15,19 @@ func (p *prog) watchLinkState() { mainLog.Load().Warn().Err(err).Msg("could not subscribe link") return } - for lu := range ch { - if lu.Change == 0xFFFFFFFF { - continue - } - if lu.Change&unix.IFF_UP != 0 { - mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") - for _, uc := range p.cfg.Upstream { - uc.ReBootstrap() + for { + select { + case <-ctx.Done(): + return + case lu := <-ch: + if lu.Change == 0xFFFFFFFF { + continue + } + if lu.Change&unix.IFF_UP != 0 { + mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") + for _, uc := range p.cfg.Upstream { + uc.ReBootstrap() + } } } } diff --git a/cmd/cli/netlink_others.go b/cmd/cli/netlink_others.go index f0afd21..5a298b9 100644 --- a/cmd/cli/netlink_others.go +++ b/cmd/cli/netlink_others.go @@ -2,4 +2,6 @@ package cli -func (p *prog) watchLinkState() {} +import "context" + +func (p *prog) watchLinkState(ctx context.Context) {} diff --git a/cmd/cli/nextdns.go b/cmd/cli/nextdns.go new file mode 100644 index 0000000..f4fed47 --- /dev/null +++ b/cmd/cli/nextdns.go @@ -0,0 +1,31 @@ +package cli + +import ( + "fmt" + + "github.com/Control-D-Inc/ctrld" +) + +const nextdnsURL = "https://dns.nextdns.io" + +func generateNextDNSConfig(uid string) { + if uid == "" { + return + } + mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver") + cfg = ctrld.Config{ + Listener: map[string]*ctrld.ListenerConfig{ + "0": { + IP: "0.0.0.0", + Port: 53, + }, + }, + Upstream: map[string]*ctrld.UpstreamConfig{ + "0": { + Type: ctrld.ResolverTypeDOH3, + Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, uid), + Timeout: 5000, + }, + }, + } +} diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index 7fb692c..3036d03 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -9,10 +9,12 @@ import ( "net" "net/netip" "os/exec" + "path/filepath" "strings" "syscall" "time" + "github.com/fsnotify/fsnotify" "github.com/insomniacslk/dhcp/dhcpv4/nclient4" "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/dhcpv6/client6" @@ -23,7 +25,10 @@ import ( "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) -const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system" +const ( + resolvConfPath = "/etc/resolv.conf" + resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system" +) // allocate loopback ip // sudo ip a add 127.0.0.2/24 dev lo @@ -64,6 +69,11 @@ func setDNS(iface *net.Interface, nameservers []string) error { Nameservers: ns, SearchDomains: []dnsname.FQDN{}, } + defer func() { + if r.Mode() == "direct" { + go watchResolveConf(osConfig) + } + }() trySystemdResolve := false for i := 0; i < maxSetDNSAttempts; i++ { @@ -299,3 +309,59 @@ func sliceIndex[S ~[]E, E comparable](s S, v E) int { } return -1 } + +// watchResolveConf watches any changes to /etc/resolv.conf file, +// and reverting to the original config set by ctrld. +func watchResolveConf(oc dns.OSConfig) { + mainLog.Load().Debug().Msg("start watching /etc/resolv.conf file") + watcher, err := fsnotify.NewWatcher() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf") + return + } + + // We watch /etc instead of /etc/resolv.conf directly, + // see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well + watchDir := filepath.Dir(resolvConfPath) + if err := watcher.Add(watchDir); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not add /etc/resolv.conf to watcher list") + return + } + + r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo") // interface name does not matter. + if err != nil { + mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator") + return + } + + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Name != resolvConfPath { // skip if not /etc/resolv.conf changes. + continue + } + if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { + mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting") + if err := watcher.Remove(watchDir); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to pause watcher") + continue + } + if err := r.SetDNS(oc); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + } + if err := watcher.Add(watchDir); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") + return + } + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf") + } + } +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index e30a03d..878681e 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -2,6 +2,7 @@ package cli import ( "bytes" + "context" "errors" "fmt" "math/rand" @@ -16,7 +17,9 @@ import ( "syscall" "github.com/kardianos/service" + "github.com/spf13/viper" "tailscale.com/net/interfaces" + "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/clientinfo" @@ -30,6 +33,7 @@ const ( ctrldControlUnixSock = "ctrld_control.sock" upstreamPrefix = "upstream." upstreamOS = upstreamPrefix + "os" + upstreamPrivate = upstreamPrefix + "private" ) var logf = func(format string, args ...any) { @@ -45,19 +49,25 @@ var svcConfig = &service.Config{ var useSystemdResolved = false type prog struct { - mu sync.Mutex - waitCh chan struct{} - stopCh chan struct{} - logConn net.Conn - cs *controlServer + mu sync.Mutex + waitCh chan struct{} + stopCh chan struct{} + reloadCh chan struct{} // For Windows. + reloadDoneCh chan struct{} + logConn net.Conn + cs *controlServer - cfg *ctrld.Config - appCallback *AppCallback - cache dnscache.Cacher - sema semaphore - ciTable *clientinfo.Table - um *upstreamMonitor - router router.Router + cfg *ctrld.Config + localUpstreams []string + ptrNameservers []string + appCallback *AppCallback + cache dnscache.Cacher + sema semaphore + ciTable *clientinfo.Table + um *upstreamMonitor + router router.Router + ptrLoopGuard *loopGuard + lanLoopGuard *loopGuard loopMu sync.Mutex loop map[string]bool @@ -69,11 +79,106 @@ type prog struct { } func (p *prog) Start(s service.Service) error { - p.cfg = &cfg - go p.run() + go p.runWait() return nil } +// runWait runs ctrld components, waiting for signal to reload. +func (p *prog) runWait() { + p.mu.Lock() + p.cfg = &cfg + p.mu.Unlock() + reloadSigCh := make(chan os.Signal, 1) + notifyReloadSigCh(reloadSigCh) + + reload := false + logger := mainLog.Load() + for { + reloadCh := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(done) + p.run(reload, reloadCh) + reload = true + }() + select { + case sig := <-reloadSigCh: + logger.Notice().Msgf("got signal: %s, reloading...", sig.String()) + case <-p.reloadCh: + logger.Notice().Msg("reloading...") + case <-p.stopCh: + close(reloadCh) + return + } + + waitOldRunDone := func() { + close(reloadCh) + <-done + } + newCfg := &ctrld.Config{} + v := viper.NewWithOptions(viper.KeyDelimiter("::")) + ctrld.InitConfig(v, "ctrld") + if configPath != "" { + v.SetConfigFile(configPath) + } + if err := v.ReadInConfig(); err != nil { + logger.Err(err).Msg("could not read new config") + waitOldRunDone() + continue + } + if err := v.Unmarshal(&newCfg); err != nil { + logger.Err(err).Msg("could not unmarshal new config") + waitOldRunDone() + continue + } + if cdUID != "" { + if err := processCDFlags(newCfg); err != nil { + logger.Err(err).Msg("could not fetch ControlD config") + waitOldRunDone() + continue + } + } + + waitOldRunDone() + + p.mu.Lock() + curListener := p.cfg.Listener + p.mu.Unlock() + + for n, lc := range newCfg.Listener { + curLc := curListener[n] + if curLc == nil { + continue + } + if lc.IP == "" { + lc.IP = curLc.IP + } + if lc.Port == 0 { + lc.Port = curLc.Port + } + } + if err := validateConfig(newCfg); err != nil { + logger.Err(err).Msg("invalid config") + continue + } + + // This needs to be done here, otherwise, the DNS handler may observe an invalid + // upstream config because its initialization function have not been called yet. + mainLog.Load().Debug().Msg("setup upstream with new config") + p.setupUpstream(newCfg) + + p.mu.Lock() + *p.cfg = *newCfg + p.mu.Unlock() + + logger.Notice().Msg("reloading config successfully") + select { + case p.reloadDoneCh <- struct{}{}: + default: + } + } +} + func (p *prog) preRun() { if !service.Interactive() { p.setDNS() @@ -87,14 +192,54 @@ func (p *prog) preRun() { } } -func (p *prog) run() { +func (p *prog) setupUpstream(cfg *ctrld.Config) { + localUpstreams := make([]string, 0, len(cfg.Upstream)) + ptrNameservers := make([]string, 0, len(cfg.Upstream)) + for n := range cfg.Upstream { + uc := cfg.Upstream[n] + uc.Init() + if uc.BootstrapIP == "" { + uc.SetupBootstrapIP() + mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) + } else { + mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) + } + uc.SetCertPool(rootCertPool) + go uc.Ping() + + if canBeLocalUpstream(uc.Domain) { + localUpstreams = append(localUpstreams, upstreamPrefix+n) + } + if uc.IsDiscoverable() { + ptrNameservers = append(ptrNameservers, uc.Endpoint) + } + } + p.localUpstreams = localUpstreams + p.ptrNameservers = ptrNameservers +} + +// run runs the ctrld main components. +// +// The reload boolean indicates that the function is run when ctrld first start +// or when ctrld receive reloading signal. Platform specifics setup is only done +// on started, mean reload is "false". +// +// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded, +// so all listeners could be terminated and re-spawned again. +func (p *prog) run(reload bool, reloadCh chan struct{}) { // Wait the caller to signal that we can do our logic. <-p.waitCh - p.preRun() + if !reload { + p.preRun() + } numListeners := len(p.cfg.Listener) - p.started = make(chan struct{}, numListeners) + if !reload { + p.started = make(chan struct{}, numListeners) + } p.onStartedDone = make(chan struct{}) p.loop = make(map[string]bool) + p.lanLoopGuard = newLoopGuard() + p.ptrLoopGuard = newLoopGuard() if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { @@ -103,15 +248,7 @@ func (p *prog) run() { p.cache = cacher } } - p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)} - if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil { - n := *mcr - if n == 0 { - p.sema = &noopSemaphore{} - } else { - p.sema = &chanSemaphore{ready: make(chan struct{}, n)} - } - } + var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) @@ -127,74 +264,102 @@ func (p *prog) run() { } p.um = newUpstreamMonitor(p.cfg) - for n := range p.cfg.Upstream { - uc := p.cfg.Upstream[n] - uc.Init() - if uc.BootstrapIP == "" { - uc.SetupBootstrapIP() - mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) - } else { - mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) + + if !reload { + p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)} + if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil { + n := *mcr + if n == 0 { + p.sema = &noopSemaphore{} + } else { + p.sema = &chanSemaphore{ready: make(chan struct{}, n)} + } + } + p.setupUpstream(p.cfg) + p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers) + if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { + mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) + format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) + p.ciTable.AddLeaseFile(leaseFile, format) } - uc.SetCertPool(rootCertPool) - go uc.Ping() } - p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID) - if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { - mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) - format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) - p.ciTable.AddLeaseFile(leaseFile, format) - } + // context for managing spawn goroutines. + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + // Newer versions of android and iOS denies permission which breaks connectivity. - if !isMobile() { + if !isMobile() && !reload { + wg.Add(1) go func() { + defer wg.Done() p.ciTable.Init() - p.ciTable.RefreshLoop(p.stopCh) + p.ciTable.RefreshLoop(ctx) }() - go p.watchLinkState() + go p.watchLinkState(ctx) } for listenerNum := range p.cfg.Listener { p.cfg.Listener[listenerNum].Init() - go func(listenerNum string) { - defer wg.Done() - listenerConfig := p.cfg.Listener[listenerNum] - upstreamConfig := p.cfg.Upstream[listenerNum] - if upstreamConfig == nil { - mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) + if !reload { + go func(listenerNum string) { + listenerConfig := p.cfg.Listener[listenerNum] + upstreamConfig := p.cfg.Upstream[listenerNum] + if upstreamConfig == nil { + mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) + } + addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) + mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) + if err := p.serveDNS(listenerNum); err != nil { + mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) + } + }(listenerNum) + } + go func() { + defer func() { + cancelFunc() + wg.Done() + }() + select { + case <-p.stopCh: + case <-ctx.Done(): + case <-reloadCh: } - addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) - mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) - if err := p.serveDNS(listenerNum); err != nil { - mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) - } - }(listenerNum) + }() } - for i := 0; i < numListeners; i++ { - <-p.started + if !reload { + for i := 0; i < numListeners; i++ { + <-p.started + } + for _, f := range p.onStarted { + f() + } } - for _, f := range p.onStarted { - f() - } - // Check for possible DNS loop. - p.checkDnsLoop() + close(p.onStartedDone) - // Start check DNS loop ticker. - go p.checkDnsLoopTicker() + wg.Add(1) + go func() { + defer wg.Done() + // Check for possible DNS loop. + p.checkDnsLoop() + // Start check DNS loop ticker. + p.checkDnsLoopTicker(ctx) + }() - // Stop writing log to unix socket. - consoleWriter.Out = os.Stdout - initLoggingWithBackup(false) - if p.logConn != nil { - _ = p.logConn.Close() - } - if p.cs != nil { - p.registerControlServerHandler() - if err := p.cs.start(); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not start control server") + if !reload { + // Stop writing log to unix socket. + consoleWriter.Out = os.Stdout + initLoggingWithBackup(false) + if p.logConn != nil { + _ = p.logConn.Close() + } + if p.cs != nil { + p.registerControlServerHandler() + if err := p.cs.start(); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not start control server") + } } } wg.Wait() @@ -276,7 +441,7 @@ func (p *prog) setDNS() { nameservers := []string{ns} if needRFC1918Listeners(lc) { - nameservers = append(nameservers, rfc1918Addresses()...) + nameservers = append(nameservers, ctrld.Rfc1918Addresses()...) } if err := setDNS(netIface, nameservers); err != nil { logger.Error().Err(err).Msgf("could not set DNS for interface") @@ -462,3 +627,11 @@ func defaultRouteIP() string { mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP") return ip } + +// canBeLocalUpstream reports whether the IP address can be used as a local upstream. +func canBeLocalUpstream(addr string) bool { + if ip, err := netip.ParseAddr(addr); err == nil { + return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip) + } + return false +} diff --git a/cmd/cli/prog_linux.go b/cmd/cli/prog_linux.go index ed28561..2b9c69d 100644 --- a/cmd/cli/prog_linux.go +++ b/cmd/cli/prog_linux.go @@ -19,7 +19,6 @@ func setDependencies(svc *service.Config) { "Wants=NetworkManager-wait-online.service", "After=NetworkManager-wait-online.service", "Wants=systemd-networkd-wait-online.service", - "After=systemd-networkd-wait-online.service", "Wants=nss-lookup.target", "After=nss-lookup.target", } diff --git a/cmd/cli/reload_others.go b/cmd/cli/reload_others.go new file mode 100644 index 0000000..0977af9 --- /dev/null +++ b/cmd/cli/reload_others.go @@ -0,0 +1,17 @@ +//go:build !windows + +package cli + +import ( + "os" + "os/signal" + "syscall" +) + +func notifyReloadSigCh(ch chan os.Signal) { + signal.Notify(ch, syscall.SIGUSR1) +} + +func (p *prog) sendReloadSignal() error { + return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1) +} diff --git a/cmd/cli/reload_windows.go b/cmd/cli/reload_windows.go new file mode 100644 index 0000000..0e817e4 --- /dev/null +++ b/cmd/cli/reload_windows.go @@ -0,0 +1,18 @@ +package cli + +import ( + "errors" + "os" + "time" +) + +func notifyReloadSigCh(ch chan os.Signal) {} + +func (p *prog) sendReloadSignal() error { + select { + case p.reloadCh <- struct{}{}: + return nil + case <-time.After(5 * time.Second): + } + return errors.New("timeout while sending reload signal") +} diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 4b3ee69..83087a4 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -7,7 +7,6 @@ import ( "time" "github.com/miekg/dns" - "tailscale.com/logtail/backoff" "github.com/Control-D-Inc/ctrld" ) @@ -15,8 +14,8 @@ import ( const ( // maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down. maxFailureRequest = 100 - // checkUpstreamMaxBackoff is the max backoff time when checking upstream status. - checkUpstreamMaxBackoff = 2 * time.Minute + // checkUpstreamBackoffSleep is the time interval between each upstream checks. + checkUpstreamBackoffSleep = 2 * time.Second ) // upstreamMonitor performs monitoring upstreams health. @@ -76,7 +75,6 @@ func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConf um.checking[upstream] = true um.mu.Unlock() - bo := backoff.NewBackoff("checkUpstream", logf, checkUpstreamMaxBackoff) resolver, err := ctrld.NewResolver(uc) if err != nil { mainLog.Load().Warn().Err(err).Msg("could not check upstream") @@ -84,15 +82,20 @@ func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConf } msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) - ctx := context.Background() - for { + check := func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + uc.ReBootstrap() _, err := resolver.Resolve(ctx, msg) - if err == nil { + return err + } + for { + if err := check(); err == nil { mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint) um.reset(upstream) return } - bo.BackOff(ctx, err) + time.Sleep(checkUpstreamBackoffSleep) } } diff --git a/config.go b/config.go index 21d636c..5baa10d 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,7 @@ import ( "math/rand" "net" "net/http" + "net/netip" "net/url" "os" "runtime" @@ -26,6 +27,7 @@ import ( "github.com/spf13/viper" "golang.org/x/sync/singleflight" "tailscale.com/logtail/backoff" + "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld/internal/dnsrcode" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" @@ -82,6 +84,16 @@ func InitConfig(v *viper.Viper, name string) { "0": { IP: "", Port: 0, + Policy: &ListenerPolicyConfig{ + Name: "Main Policy", + Networks: []Rule{ + {"network.0": []string{"upstream.0"}}, + }, + Rules: []Rule{ + {"example.com": []string{"upstream.0"}}, + {"*.ads.com": []string{"upstream.1"}}, + }, + }, }, }) v.SetDefault("network", map[string]*NetworkConfig{ @@ -181,6 +193,7 @@ type ServiceConfig struct { DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` + ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"` Daemon bool `mapstructure:"-" toml:"-"` AllocateIP bool `mapstructure:"-" toml:"-"` } @@ -204,6 +217,9 @@ type UpstreamConfig struct { // The caller should not access this field directly. // Use UpstreamSendClientInfo instead. SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"` + // The caller should not access this field directly. + // Use IsDiscoverable instead. + Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"` g singleflight.Group rebootstrap atomic.Bool @@ -224,10 +240,11 @@ type UpstreamConfig struct { // ListenerConfig specifies the networks configuration that ctrld will run on. type ListenerConfig struct { - IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"` - Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"` - Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"` - Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"` + IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"` + Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"` + Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"` + AllowWanClients bool `mapstructure:"allow_wan_clients" toml:"allow_wan_clients,omitempty"` + Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"` } // IsDirectDnsListener reports whether ctrld can be a direct listener on port 53. @@ -253,6 +270,7 @@ type ListenerPolicyConfig struct { Name string `mapstructure:"name" toml:"name,omitempty"` Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"` Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"` + Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"` FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"` FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"` } @@ -322,13 +340,28 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool { } switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: - if uc.isControlD() { + if uc.isControlD() || uc.isNextDNS() { return true } } return false } +// IsDiscoverable reports whether the upstream can be used for PTR discovery. +// The caller must ensure uc.Init() was called before calling this. +func (uc *UpstreamConfig) IsDiscoverable() bool { + if uc.Discoverable != nil { + return *uc.Discoverable + } + switch uc.Type { + case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate: + if ip, err := netip.ParseAddr(uc.Domain); err == nil { + return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip) + } + } + return false +} + // BootstrapIPs returns the bootstrap IPs list of upstreams. func (uc *UpstreamConfig) BootstrapIPs() []string { return uc.bootstrapIPs @@ -394,8 +427,9 @@ func (uc *UpstreamConfig) ReBootstrap() { return } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { - ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip") - uc.rebootstrap.Store(true) + if uc.rebootstrap.CompareAndSwap(false, true) { + ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip") + } return true, nil }) } @@ -519,6 +553,16 @@ func (uc *UpstreamConfig) isControlD() bool { return false } +func (uc *UpstreamConfig) isNextDNS() bool { + domain := uc.Domain + if domain == "" { + if u, err := url.Parse(uc.Endpoint); err == nil { + domain = u.Hostname() + } + } + return domain == "dns.nextdns.io" +} + func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { uc.transportOnce.Do(func() { uc.SetupTransport() diff --git a/config_internal_test.go b/config_internal_test.go index 89cec19..96beddc 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -279,6 +279,61 @@ func TestUpstreamConfig_UpstreamSendClientInfo(t *testing.T) { } } +func TestUpstreamConfig_IsDiscoverable(t *testing.T) { + tests := []struct { + name string + uc *UpstreamConfig + discoverable bool + }{ + { + "loopback", + &UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy}, + true, + }, + { + "rfc1918", + &UpstreamConfig{Endpoint: "192.168.1.1", Type: ResolverTypeLegacy}, + true, + }, + { + "CGNAT", + &UpstreamConfig{Endpoint: "100.66.67.68", Type: ResolverTypeLegacy}, + true, + }, + { + "Public IP", + &UpstreamConfig{Endpoint: "8.8.8.8", Type: ResolverTypeLegacy}, + false, + }, + { + "override discoverable", + &UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(false)}, + false, + }, + { + "override non-public", + &UpstreamConfig{Endpoint: "1.1.1.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(true)}, + true, + }, + { + "non-legacy upstream", + &UpstreamConfig{Endpoint: "https://192.168.1.1/custom-doh", Type: ResolverTypeDOH}, + false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tc.uc.Init() + if got := tc.uc.IsDiscoverable(); got != tc.discoverable { + t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got) + } + }) + } +} + func ptrBool(b bool) *bool { return &b } diff --git a/config_quic.go b/config_quic.go index cd3eaee..5103231 100644 --- a/config_quic.go +++ b/config_quic.go @@ -10,13 +10,10 @@ import ( "net/http" "runtime" "sync" - "time" "github.com/miekg/dns" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" - - ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) func (uc *UpstreamConfig) setupDOH3Transport() { @@ -29,9 +26,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() { uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6) case IpStackSplit: uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if ctrldnet.IPv6Available(ctx) { + if hasIPv6() { uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6) } else { uc.http3RoundTripper6 = uc.http3RoundTripper4 @@ -127,11 +122,6 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t close(ch) }() - udpConn, err := net.ListenUDP("udp", nil) - if err != nil { - return nil, err - } - for _, addr := range addrs { go func(addr string) { defer wg.Done() @@ -140,6 +130,11 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t ch <- ¶llelDialerResult{conn: nil, err: err} return } + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + ch <- ¶llelDialerResult{conn: nil, err: err} + return + } conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg) select { case ch <- ¶llelDialerResult{conn: conn, err: err}: @@ -147,6 +142,9 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t if conn != nil { conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "") } + if udpConn != nil { + udpConn.Close() + } } }(addr) } diff --git a/config_test.go b/config_test.go index ca57372..ff20bc2 100644 --- a/config_test.go +++ b/config_test.go @@ -54,7 +54,12 @@ func TestLoadDefaultConfig(t *testing.T) { cfg := defaultConfig(t) validate := validator.New() require.NoError(t, ctrld.ValidateConfig(validate, cfg)) - assert.Len(t, cfg.Listener, 1) + if assert.Len(t, cfg.Listener, 1) { + l0 := cfg.Listener["0"] + require.NotNil(t, l0.Policy) + assert.Len(t, l0.Policy.Networks, 1) + assert.Len(t, l0.Policy.Rules, 2) + } assert.Len(t, cfg.Upstream, 2) } @@ -96,6 +101,7 @@ func TestConfigValidation(t *testing.T) { {"lease file format required if lease file exist", configWithExistedLeaseFile(t), true}, {"invalid lease file format", configWithInvalidLeaseFileFormat(t), true}, {"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true}, + {"invalid client id pref", configWithInvalidClientIDPref(t), true}, } for _, tc := range tests { @@ -233,3 +239,9 @@ func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config { cfg.Upstream["0"].Type = ctrld.ResolverTypeDOH return cfg } + +func configWithInvalidClientIDPref(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Service.ClientIDPref = "foo" + return cfg +} diff --git a/docs/config.md b/docs/config.md index 35fbda5..e5b3945 100644 --- a/docs/config.md +++ b/docs/config.md @@ -215,6 +215,17 @@ DHCP leases file format. - Valid values: `dnsmasq`, `isc-dhcp` - Default: "" +### client_id_preference +Decide how the client ID is generated + +If `host` -> client id will only use the hostname i.e.`hash(hostname)`. +If `mac` -> client id will only use the MAC address `hash(mac)`. +Else -> client ID will use both Mac and Hostname i.e. `hash(mac + host) +- Type: string +- Required: no +- Valid values: `mac`, `host` +- Default: "" + ## Upstream The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to. @@ -319,6 +330,24 @@ If `ip_stack` is empty, or undefined: - Default value is `both` for non-Control D resolvers. - Default value is `split` for Control D resolvers. +### send_client_info +Specifying whether to include client info when sending query to upstream. + +- Type: boolean +- Required: no +- Default: + - `true` for ControlD upstreams. + - `false` for other upstreams. + +### discoverable +Specifying whether the upstream can be used for PTR discovery. + +- Type: boolean +- Required: no +- Default: + - `true` for loopback/RFC1918/CGNAT IP address. + - `false` for public IP address. + ## Network The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs. @@ -376,7 +405,14 @@ Port number that the listener will listen on for incoming requests. If `port` is - Default: 0 or 53 or 5354 (depending on platform) ### restricted -If set to `true` makes the listener `REFUSE` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`. +If set to `true`, makes the listener `REFUSED` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`. + +- Type: bool +- Required: no +- Default: false + +### allow_wan_clients +The listener `REFUSED` DNS queries from WAN clients by default. If set to `true`, makes the listener replies to them. - Type: bool - Required: no @@ -386,7 +422,15 @@ If set to `true` makes the listener `REFUSE` DNS queries from all source IP addr Allows `ctrld` to set policy rules to determine which upstreams the requests will be forwarded to. If no `policy` is defined or the requests do not match any policy rules, it will be forwarded to corresponding upstream of the listener. For example, the request to `listener.0` will be forwarded to `upstream.0`. -The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either the `network` or a domain. Value is the list of the upstreams. For example: +The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either: + + - Network. + - Domain. + - Mac Address. + +Value is the list of the upstreams. + +For example: ```toml [listener.0.policy] @@ -400,12 +444,18 @@ rules = [ {"*.local" = ["upstream.1"]}, {"test.com" = ["upstream.2", "upstream.1"]}, ] + +macs = [ + {"14:54:4a:8e:08:2d" = ["upstream.3"]}, +] ``` Above policy will: -- Forward requests on `listener.0` from `network.0` to `upstream.1`. + - Forward requests on `listener.0` for `.local` suffixed domains to `upstream.1`. - Forward requests on `listener.0` for `test.com` to `upstream.2`. If timeout is reached, retry on `upstream.1`. +- Forward requests on `listener.0` from client with Mac `14:54:4a:8e:08:2d` to `upstream.3`. +- Forward requests on `listener.0` from `network.0` to `upstream.1`. - All other requests on `listener.0` that do not match above conditions will be forwarded to `upstream.0`. An empty upstream would not route the request to any defined upstreams, and use the OS default resolver. @@ -419,6 +469,18 @@ rules = [ ] ``` +--- + +Note that the order of matching preference: + +``` +rules => macs => networks +``` + +And within each policy, the rules are processed from top to bottom. + +--- + #### name `name` is the name for the policy. @@ -440,6 +502,13 @@ rules = [ - Required: no - Default: [] +### macs: +`macs` is the list of mac rules within the policy. Mac address value is case-insensitive. + +- Type: array of macs +- Required: no +- Default: [] + ### failover_rcodes For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`. diff --git a/doh.go b/doh.go index e0aa363..25ed2cb 100644 --- a/doh.go +++ b/doh.go @@ -18,11 +18,12 @@ import ( ) const ( - dohMacHeader = "x-cd-mac" - dohIPHeader = "x-cd-ip" - dohHostHeader = "x-cd-host" - dohOsHeader = "x-cd-os" - headerApplicationDNS = "application/dns-message" + dohMacHeader = "x-cd-mac" + dohIPHeader = "x-cd-ip" + dohHostHeader = "x-cd-host" + dohOsHeader = "x-cd-os" + dohClientIDPrefHeader = "x-cd-cpref" + headerApplicationDNS = "application/dns-message" ) // EncodeOsNameMap provides mapping from OS name to a shorter string, used for encoding x-cd-os value. @@ -76,7 +77,6 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver { endpoint: uc.u, isDoH3: uc.Type == ResolverTypeDOH3, http3RoundTripper: uc.http3RoundTripper, - sendClientInfo: uc.UpstreamSendClientInfo(), uc: uc, } return r @@ -87,9 +87,9 @@ type dohResolver struct { endpoint *url.URL isDoH3 bool http3RoundTripper http.RoundTripper - sendClientInfo bool } +// Resolve performs DNS query with given DNS message using DOH protocol. func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { data, err := msg.Pack() if err != nil { @@ -106,7 +106,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if err != nil { return nil, fmt.Errorf("could not create request: %w", err) } - addHeader(ctx, req, r.sendClientInfo) + addHeader(ctx, req, r.uc) dnsTyp := uint16(0) if len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype @@ -146,26 +146,19 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro return answer, nil } -func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { +func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) { req.Header.Set("Content-Type", headerApplicationDNS) req.Header.Set("Accept", headerApplicationDNS) - req.Header.Set(dohOsHeader, dohOsHeaderValue()) printed := false - if sendClientInfo { + if uc.UpstreamSendClientInfo() { if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil { printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != "" - if ci.Mac != "" { - req.Header.Set(dohMacHeader, ci.Mac) - } - if ci.IP != "" { - req.Header.Set(dohIPHeader, ci.IP) - } - if ci.Hostname != "" { - req.Header.Set(dohHostHeader, ci.Hostname) - } - if ci.Self { - req.Header.Set(dohOsHeader, dohOsHeaderValue()) + switch { + case uc.isControlD(): + addControlDHeaders(req, ci) + case uc.isNextDNS(): + addNextDNSHeaders(req, ci) } } } @@ -173,3 +166,41 @@ func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header") } } + +// addControlDHeaders set DoH/Doh3 HTTP request headers for ControlD upstream. +func addControlDHeaders(req *http.Request, ci *ClientInfo) { + req.Header.Set(dohOsHeader, dohOsHeaderValue()) + if ci.Mac != "" { + req.Header.Set(dohMacHeader, ci.Mac) + } + if ci.IP != "" { + req.Header.Set(dohIPHeader, ci.IP) + } + if ci.Hostname != "" { + req.Header.Set(dohHostHeader, ci.Hostname) + } + if ci.Self { + req.Header.Set(dohOsHeader, dohOsHeaderValue()) + } + switch ci.ClientIDPref { + case "mac": + req.Header.Set(dohClientIDPrefHeader, "1") + case "host": + req.Header.Set(dohClientIDPrefHeader, "2") + } +} + +// addNextDNSHeaders set DoH/Doh3 HTTP request headers for nextdns upstream. +// https://github.com/nextdns/nextdns/blob/v1.41.0/resolver/doh.go#L100 +func addNextDNSHeaders(req *http.Request, ci *ClientInfo) { + if ci.Mac != "" { + // https: //github.com/nextdns/nextdns/blob/v1.41.0/run.go#L543 + req.Header.Set("X-Device-Model", "mac:"+ci.Mac[:8]) + } + if ci.IP != "" { + req.Header.Set("X-Device-Ip", ci.IP) + } + if ci.Hostname != "" { + req.Header.Set("X-Device-Name", ci.Hostname) + } +} diff --git a/go.mod b/go.mod index 58ba1e4..fec32ef 100644 --- a/go.mod +++ b/go.mod @@ -25,9 +25,9 @@ require ( github.com/spf13/viper v1.16.0 github.com/stretchr/testify v1.8.3 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/net v0.10.0 + golang.org/x/net v0.17.0 golang.org/x/sync v0.2.0 - golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a + golang.org/x/sys v0.13.0 golang.zx2c4.com/wireguard/windows v0.5.3 tailscale.com v1.44.0 ) @@ -70,11 +70,10 @@ require ( github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 // indirect github.com/vishvananda/netns v0.0.4 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect - golang.org/x/crypto v0.9.0 // indirect + golang.org/x/crypto v0.14.0 // indirect golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect - golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda // indirect golang.org/x/mod v0.10.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.9.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 409133a..c792103 100644 --- a/go.sum +++ b/go.sum @@ -55,8 +55,6 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8 github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs= -github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf h1:40DHYsri+d1bnroFDU2FQAeq68f3kAlOzlQ93kCf26Q= github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -302,8 +300,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -331,8 +329,6 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= -golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda h1:O+EUvnBNPwI4eLthn8W5K+cS8zQZfgTABPLNm6Bna34= -golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda/go.mod h1:aAjjkJNdrh3PMckS4B10TGS2nag27cbKR1y2BpUxsiY= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= @@ -378,8 +374,8 @@ golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -452,8 +448,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a h1:qMsju+PNttu/NMbq8bQ9waDdxgJMu9QNoUDuhnBaYt0= -golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -463,8 +459,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 3e92fd1..07e4cf0 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -1,8 +1,11 @@ package clientinfo import ( + "context" "fmt" + "net" "net/netip" + "strconv" "strings" "sync" "time" @@ -68,25 +71,27 @@ type Table struct { refreshers []refresher initOnce sync.Once - dhcp *dhcp - merlin *merlinDiscover - arp *arpDiscover - ptr *ptrDiscover - mdns *mdns - hf *hostsFile - vni *virtualNetworkIface - cfg *ctrld.Config - quitCh chan struct{} - selfIP string - cdUID string + dhcp *dhcp + merlin *merlinDiscover + arp *arpDiscover + ptr *ptrDiscover + mdns *mdns + hf *hostsFile + vni *virtualNetworkIface + svcCfg ctrld.ServiceConfig + quitCh chan struct{} + selfIP string + cdUID string + ptrNameservers []string } -func NewTable(cfg *ctrld.Config, selfIP, cdUID string) *Table { +func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { return &Table{ - cfg: cfg, - quitCh: make(chan struct{}), - selfIP: selfIP, - cdUID: cdUID, + svcCfg: cfg.Service, + quitCh: make(chan struct{}), + selfIP: selfIP, + cdUID: cdUID, + ptrNameservers: ns, } } @@ -97,7 +102,7 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) { clientInfoFiles[name] = format } -func (t *Table) RefreshLoop(stopCh chan struct{}) { +func (t *Table) RefreshLoop(ctx context.Context) { timer := time.NewTicker(time.Minute * 5) defer timer.Stop() for { @@ -106,7 +111,7 @@ func (t *Table) RefreshLoop(stopCh chan struct{}) { for _, r := range t.refreshers { _ = r.refresh() } - case <-stopCh: + case <-ctx.Done(): close(t.quitCh) return } @@ -182,6 +187,26 @@ func (t *Table) init() { // PTR lookup. if t.discoverPTR() { t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} + if len(t.ptrNameservers) > 0 { + nss := make([]string, 0, len(t.ptrNameservers)) + for _, ns := range t.ptrNameservers { + host, port := ns, "53" + if h, p, err := net.SplitHostPort(ns); err == nil { + host, port = h, p + } + // Only use valid ip:port pair. + if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { + nss = append(nss, net.JoinHostPort(host, port)) + } else { + ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) + } + } + if len(nss) > 0 { + t.ptr.resolver = ctrld.NewResolverWithNameserver(nss) + ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss) + } + + } ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") if err := t.ptr.refresh(); err != nil { ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init PTR discover") @@ -240,6 +265,21 @@ func (t *Table) LookupHostname(ip, mac string) string { return "" } +// LookupRFC1918IPv4 returns the RFC1918 IPv4 address for the given MAC address, if any. +func (t *Table) LookupRFC1918IPv4(mac string) string { + t.initOnce.Do(t.init) + for _, r := range t.ipResolvers { + ip, err := netip.ParseAddr(r.LookupIP(mac)) + if err != nil || ip.Is6() { + continue + } + if ip.IsPrivate() { + return ip.String() + } + } + return "" +} + type macEntry struct { mac string src string @@ -338,39 +378,60 @@ func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) { t.vni.ip2name.Store(ci.IP, ci.Hostname) } +// ipFinder is the interface for retrieving IP address from hostname. +type ipFinder interface { + lookupIPByHostname(name string, v6 bool) string +} + +// LookupIPByHostname returns the ip address of given hostname. +// If v6 is true, return IPv6 instead of default IPv4. +func (t *Table) LookupIPByHostname(hostname string, v6 bool) *netip.Addr { + if t == nil { + return nil + } + for _, finder := range []ipFinder{t.hf, t.ptr, t.mdns, t.dhcp} { + if addr := finder.lookupIPByHostname(hostname, v6); addr != "" { + if ip, err := netip.ParseAddr(addr); err == nil { + return &ip + } + } + } + return nil +} + func (t *Table) discoverDHCP() bool { - if t.cfg.Service.DiscoverDHCP == nil { + if t.svcCfg.DiscoverDHCP == nil { return true } - return *t.cfg.Service.DiscoverDHCP + return *t.svcCfg.DiscoverDHCP } func (t *Table) discoverARP() bool { - if t.cfg.Service.DiscoverARP == nil { + if t.svcCfg.DiscoverARP == nil { return true } - return *t.cfg.Service.DiscoverARP + return *t.svcCfg.DiscoverARP } func (t *Table) discoverMDNS() bool { - if t.cfg.Service.DiscoverMDNS == nil { + if t.svcCfg.DiscoverMDNS == nil { return true } - return *t.cfg.Service.DiscoverMDNS + return *t.svcCfg.DiscoverMDNS } func (t *Table) discoverPTR() bool { - if t.cfg.Service.DiscoverPtr == nil { + if t.svcCfg.DiscoverPtr == nil { return true } - return *t.cfg.Service.DiscoverPtr + return *t.svcCfg.DiscoverPtr } func (t *Table) discoverHosts() bool { - if t.cfg.Service.DiscoverHosts == nil { + if t.svcCfg.DiscoverHosts == nil { return true } - return *t.cfg.Service.DiscoverHosts + return *t.svcCfg.DiscoverHosts } // normalizeIP normalizes the ip parsed from dnsmasq/dhcpd lease file. diff --git a/internal/clientinfo/client_info_test.go b/internal/clientinfo/client_info_test.go index 79e5912..e6575f2 100644 --- a/internal/clientinfo/client_info_test.go +++ b/internal/clientinfo/client_info_test.go @@ -25,3 +25,22 @@ func Test_normalizeIP(t *testing.T) { }) } } + +func TestTable_LookupRFC1918IPv4(t *testing.T) { + table := &Table{ + dhcp: &dhcp{}, + arp: &arpDiscover{}, + } + + table.ipResolvers = append(table.ipResolvers, table.dhcp) + table.ipResolvers = append(table.ipResolvers, table.arp) + + macAddress := "cc:19:f9:8a:49:e6" + rfc1918IPv4 := "10.0.10.245" + table.dhcp.ip.Store(macAddress, "127.0.0.1") + table.arp.ip.Store(macAddress, rfc1918IPv4) + + if got := table.LookupRFC1918IPv4(macAddress); got != rfc1918IPv4 { + t.Fatalf("unexpected result, want: %s, got: %s", rfc1918IPv4, got) + } +} diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index 7c1b2cf..ebbeb77 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -8,6 +8,7 @@ import ( "net" "net/netip" "os" + "sort" "strings" "sync" @@ -134,6 +135,39 @@ func (d *dhcp) List() []string { return ips } +func (d *dhcp) lookupIPByHostname(name string, v6 bool) string { + if d == nil { + return "" + } + var ( + rfc1918Addrs []netip.Addr + others []netip.Addr + ) + d.ip2name.Range(func(key, value any) bool { + if value != name { + return true + } + if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + if addr.IsPrivate() { + rfc1918Addrs = append(rfc1918Addrs, addr) + } else { + others = append(others, addr) + } + } + return true + }) + result := [][]netip.Addr{rfc1918Addrs, others} + for _, addrs := range result { + if len(addrs) > 0 { + sort.Slice(addrs, func(i, j int) bool { + return addrs[i].Less(addrs[j]) + }) + return addrs[0].String() + } + } + return "" +} + // AddLeaseFile adds given lease file for reading/watching clients info. func (d *dhcp) addLeaseFile(name string, format ctrld.LeaseFileFormat) error { if d.watcher == nil { diff --git a/internal/clientinfo/dhcp_test.go b/internal/clientinfo/dhcp_test.go index af3a168..359f441 100644 --- a/internal/clientinfo/dhcp_test.go +++ b/internal/clientinfo/dhcp_test.go @@ -86,3 +86,15 @@ lease 192.168.1.2 { }) } } + +func Test_dhcp_lookupIPByHostname(t *testing.T) { + d := &dhcp{} + want := "192.168.1.123" + d.ip2name.Store(want, "foo") + d.ip2name.Store("127.0.0.1", "foo") + d.ip2name.Store("169.254.123.123", "foo") + + if got := d.lookupIPByHostname("foo", false); got != want { + t.Fatalf("unexpected result, want: %s, got: %s", want, got) + } +} diff --git a/internal/clientinfo/hostsfile.go b/internal/clientinfo/hostsfile.go index baf05fb..8c86987 100644 --- a/internal/clientinfo/hostsfile.go +++ b/internal/clientinfo/hostsfile.go @@ -1,6 +1,7 @@ package clientinfo import ( + "net/netip" "os" "sync" @@ -109,6 +110,24 @@ func (hf *hostsFile) String() string { return "hosts" } +func (hf *hostsFile) lookupIPByHostname(name string, v6 bool) string { + if hf == nil { + return "" + } + hf.mu.Lock() + defer hf.mu.Unlock() + for addr, names := range hf.m { + if ip, err := netip.ParseAddr(addr); err == nil && !ip.IsLoopback() { + for _, n := range names { + if n == name && ip.Is6() == v6 { + return ip.String() + } + } + } + } + return "" +} + // isLocalhostName reports whether the given hostname represents localhost. func isLocalhostName(hostname string) bool { switch hostname { diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index 5875b69..3f0a311 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/netip" "os" "sync" "syscall" @@ -59,6 +60,27 @@ func (m *mdns) List() []string { return ips } +func (m *mdns) lookupIPByHostname(name string, v6 bool) string { + if m == nil { + return "" + } + var ip string + m.name.Range(func(key, value any) bool { + if value == name { + if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + ip = addr.String() + //lint:ignore S1008 This is used for readable. + if addr.IsLoopback() { // Continue searching if this is loopback address. + return true + } + return false + } + } + return true + }) + return ip +} + func (m *mdns) init(quitCh chan struct{}) error { ifaces, err := multicastInterfaces() if err != nil { @@ -123,6 +145,10 @@ func (m *mdns) readLoop(conn *net.UDPConn) { if err, ok := err.(*net.OpError); ok && (err.Timeout() || err.Temporary()) { continue } + // Do not complain about use of closed network connection. + if errors.Is(err, net.ErrClosed) { + return + } ctrld.ProxyLogger.Load().Debug().Err(err).Msg("mdns readLoop error") return } diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index 6a9d99b..8e6b3f7 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -2,6 +2,7 @@ package clientinfo import ( "context" + "net/netip" "sync" "sync/atomic" "time" @@ -72,15 +73,16 @@ func (p *ptrDiscover) lookupHostname(ip string) string { msg := new(dns.Msg) addr, err := dns.ReverseAddr(ip) if err != nil { - ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("invalid ip address") + ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address") return "" } msg.SetQuestion(addr, dns.TypePTR) ans, err := p.resolver.Resolve(ctx, msg) if err != nil { - ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup") - p.serverDown.Store(true) - go p.checkServer() + if p.serverDown.CompareAndSwap(false, true) { + ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup") + go p.checkServer() + } return "" } for _, rr := range ans.Answer { @@ -93,6 +95,27 @@ func (p *ptrDiscover) lookupHostname(ip string) string { return "" } +func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string { + if p == nil { + return "" + } + var ip string + p.hostname.Range(func(key, value any) bool { + if value == name { + if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + ip = addr.String() + //lint:ignore S1008 This is used for readable. + if addr.IsLoopback() { // Continue searching if this is loopback address. + return true + } + return false + } + } + return true + }) + return ip +} + // checkServer monitors if the resolver can reach its nameserver. When the nameserver // is reachable, set p.serverDown to false, so p.lookupHostname can continue working. func (p *ptrDiscover) checkServer() { diff --git a/internal/controld/config.go b/internal/controld/config.go index 4e4bc2e..4cc6770 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "os" + "runtime" "strings" "time" @@ -119,7 +120,7 @@ func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig return d.DialContext(ctx, network, addrs) } - if router.Name() == ddwrt.Name { + if router.Name() == ddwrt.Name || runtime.GOOS == "android" { transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()} } client := http.Client{ diff --git a/internal/router/dnsmasq/dnsmasq.go b/internal/router/dnsmasq/dnsmasq.go index 54ba8fd..9fab8b6 100644 --- a/internal/router/dnsmasq/dnsmasq.go +++ b/internal/router/dnsmasq/dnsmasq.go @@ -1,9 +1,12 @@ package dnsmasq import ( + "bytes" "errors" + "fmt" "html/template" "net" + "os" "path/filepath" "strings" @@ -19,6 +22,9 @@ server={{ .IP }}#{{ .Port }} add-mac add-subnet=32,128 {{- end}} +{{- if .CacheDisabled}} +cache-size=0 +{{- end}} ` const MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf" @@ -47,6 +53,8 @@ if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then {{- end}} pc_delete "dnssec" "$config_file" # disable DNSSEC pc_delete "trust-anchor=" "$config_file" # disable DNSSEC + pc_delete "cache-size=" "$config_file" + pc_append "cache-size=0" "$config_file" # disable cache # For John fork pc_delete "resolv-file" "$config_file" # no WAN DNS settings @@ -65,6 +73,10 @@ type Upstream struct { } func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) { + return ConfTmplWitchCacheDisabled(tmplText, cfg, true) +} + +func ConfTmplWitchCacheDisabled(tmplText string, cfg *ctrld.Config, cacheDisabled bool) (string, error) { listener := cfg.FirstListener() if listener == nil { return "", errors.New("missing listener") @@ -74,24 +86,26 @@ func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) { ip = "127.0.0.1" } upstreams := []Upstream{{IP: ip, Port: listener.Port}} - return confTmpl(tmplText, upstreams, cfg.HasUpstreamSendClientInfo()) + return confTmpl(tmplText, upstreams, cfg.HasUpstreamSendClientInfo(), cacheDisabled) } func FirewallaConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) { if lc := cfg.FirstListener(); lc != nil && (lc.IP == "0.0.0.0" || lc.IP == "") { - return confTmpl(tmplText, firewallaUpstreams(lc.Port), cfg.HasUpstreamSendClientInfo()) + return confTmpl(tmplText, firewallaUpstreams(lc.Port), cfg.HasUpstreamSendClientInfo(), true) } return ConfTmpl(tmplText, cfg) } -func confTmpl(tmplText string, upstreams []Upstream, sendClientInfo bool) (string, error) { +func confTmpl(tmplText string, upstreams []Upstream, sendClientInfo, cacheDisabled bool) (string, error) { tmpl := template.Must(template.New("").Parse(tmplText)) var to = &struct { SendClientInfo bool Upstreams []Upstream + CacheDisabled bool }{ SendClientInfo: sendClientInfo, Upstreams: upstreams, + CacheDisabled: cacheDisabled, } var sb strings.Builder if err := tmpl.Execute(&sb, to); err != nil { @@ -117,9 +131,28 @@ func firewallaUpstreams(port int) []Upstream { return upstreams } +// firewallaDnsmasqConfFiles returns dnsmasq config files of all firewalla interfaces. +func firewallaDnsmasqConfFiles() ([]string, error) { + return filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf") +} + +// firewallUpdateConf updates all firewall config files using given function. +func firewallUpdateConf(update func(conf string) error) error { + confFiles, err := firewallaDnsmasqConfFiles() + if err != nil { + return err + } + for _, conf := range confFiles { + if err := update(conf); err != nil { + return fmt.Errorf("%s: %w", conf, err) + } + } + return nil +} + // FirewallaSelfInterfaces returns list of interfaces that will be configured with default dnsmasq setup on Firewalla. func FirewallaSelfInterfaces() []*net.Interface { - matches, err := filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf") + matches, err := firewallaDnsmasqConfFiles() if err != nil { return nil } @@ -133,3 +166,32 @@ func FirewallaSelfInterfaces() []*net.Interface { } return ifaces } + +// FirewallaDisableCache comments out "cache-size" line in all firewalla dnsmasq config files. +func FirewallaDisableCache() error { + return firewallUpdateConf(DisableCache) +} + +// FirewallaEnableCache un-comments out "cache-size" line in all firewalla dnsmasq config files. +func FirewallaEnableCache() error { + return firewallUpdateConf(EnableCache) +} + +// DisableCache comments out "cache-size" line in dnsmasq config file. +func DisableCache(conf string) error { + return replaceFileContent(conf, "\ncache-size=", "\n#cache-size=") +} + +// EnableCache un-comments "cache-size" line in dnsmasq config file. +func EnableCache(conf string) error { + return replaceFileContent(conf, "\n#cache-size=", "\ncache-size=") +} + +func replaceFileContent(filename, old, new string) error { + content, err := os.ReadFile(filename) + if err != nil { + return err + } + content = bytes.ReplaceAll(content, []byte(old), []byte(new)) + return os.WriteFile(filename, content, 0644) +} diff --git a/internal/router/edgeos/edgeos.go b/internal/router/edgeos/edgeos.go index f50f610..3e7003b 100644 --- a/internal/router/edgeos/edgeos.go +++ b/internal/router/edgeos/edgeos.go @@ -8,10 +8,10 @@ import ( "os/exec" "strings" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" + "github.com/kardianos/service" "github.com/Control-D-Inc/ctrld" - "github.com/kardianos/service" + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" ) const ( @@ -95,7 +95,7 @@ func (e *EdgeOS) setupUSG() error { return fmt.Errorf("setupUSG: backup current config: %w", err) } - // Removing all configured upstreams. + // Removing all configured upstreams and cache config. var sb strings.Builder scanner := bufio.NewScanner(bytes.NewReader(buf)) for scanner.Scan() { @@ -109,7 +109,7 @@ func (e *EdgeOS) setupUSG() error { sb.WriteString(line) } - data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg) + data, err := dnsmasq.ConfTmplWitchCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false) if err != nil { return err } @@ -127,7 +127,7 @@ func (e *EdgeOS) setupUSG() error { } func (e *EdgeOS) setupUDM() error { - data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg) + data, err := dnsmasq.ConfTmplWitchCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false) if err != nil { return err } diff --git a/internal/router/firewalla/firewalla.go b/internal/router/firewalla/firewalla.go index cdf6586..66cd15e 100644 --- a/internal/router/firewalla/firewalla.go +++ b/internal/router/firewalla/firewalla.go @@ -65,6 +65,11 @@ func (f *Firewalla) Setup() error { return fmt.Errorf("writing ctrld config: %w", err) } + // Disable dnsmasq cache. + if err := dnsmasq.FirewallaDisableCache(); err != nil { + return err + } + // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return fmt.Errorf("restartDNSMasq: %w", err) @@ -82,6 +87,11 @@ func (f *Firewalla) Cleanup() error { return fmt.Errorf("removing ctrld config: %w", err) } + // Enable dnsmasq cache. + if err := dnsmasq.FirewallaEnableCache(); err != nil { + return err + } + // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return fmt.Errorf("restartDNSMasq: %w", err) diff --git a/internal/router/merlin/merlin.go b/internal/router/merlin/merlin.go index 84ebd1c..8b6a0fc 100644 --- a/internal/router/merlin/merlin.go +++ b/internal/router/merlin/merlin.go @@ -6,6 +6,7 @@ import ( "os" "os/exec" "strings" + "time" "unicode" "github.com/kardianos/service" @@ -44,8 +45,24 @@ func (m *Merlin) Uninstall(_ *service.Config) error { } func (m *Merlin) PreRun() error { + // Wait NTP ready. _ = m.Cleanup() - return ntp.WaitNvram() + if err := ntp.WaitNvram(); err != nil { + return err + } + // Wait until directories mounted. + for _, dir := range []string{"/tmp", "/proc"} { + waitDirExists(dir) + } + // Wait dnsmasq started. + for { + out, _ := exec.Command("pidof", "dnsmasq").CombinedOutput() + if len(bytes.TrimSpace(out)) > 0 { + break + } + time.Sleep(time.Second) + } + return nil } func (m *Merlin) Setup() error { @@ -56,9 +73,6 @@ func (m *Merlin) Setup() error { if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" { return nil } - if _, err := nvram.Run("set", nvram.CtrldSetupKey+"=1"); err != nil { - return err - } buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath) // Already setup. if bytes.Contains(buf, []byte(dnsmasq.MerlinPostConfMarker)) { @@ -140,3 +154,12 @@ func merlinParsePostConf(buf []byte) []byte { } return buf } + +func waitDirExists(dir string) { + for { + if _, err := os.Stat(dir); !os.IsNotExist(err) { + return + } + time.Sleep(time.Second) + } +} diff --git a/internal/router/openwrt/openwrt.go b/internal/router/openwrt/openwrt.go index 83ea884..ad98db9 100644 --- a/internal/router/openwrt/openwrt.go +++ b/internal/router/openwrt/openwrt.go @@ -8,11 +8,10 @@ import ( "os/exec" "strings" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" - "github.com/kardianos/service" "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" ) const ( @@ -20,10 +19,9 @@ const ( openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf" ) -var errUCIEntryNotFound = errors.New("uci: Entry not found") - type Openwrt struct { - cfg *ctrld.Config + cfg *ctrld.Config + dnsmasqCacheSize string } // New returns a router.Router for configuring/setup/run ctrld on Openwrt routers. @@ -52,6 +50,19 @@ func (o *Openwrt) Setup() error { if o.cfg.FirstListener().IsDirectDnsListener() { return nil } + + // Save current dnsmasq config cache size if present. + if cs, err := uci("get", "dhcp.@dnsmasq[0].cachesize"); err == nil { + o.dnsmasqCacheSize = cs + if _, err := uci("delete", "dhcp.@dnsmasq[0].cachesize"); err != nil { + return err + } + // Commit. + if _, err := uci("commit", "dhcp"); err != nil { + return err + } + } + data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, o.cfg) if err != nil { return err @@ -59,10 +70,6 @@ func (o *Openwrt) Setup() error { if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(data), 0600); err != nil { return err } - // Commit. - if _, err := uci("commit"); err != nil { - return err - } // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return err @@ -78,6 +85,18 @@ func (o *Openwrt) Cleanup() error { if err := os.Remove(openwrtDNSMasqConfigPath); err != nil { return err } + + // Restore original value if present. + if o.dnsmasqCacheSize != "" { + if _, err := uci("set", fmt.Sprintf("dhcp.@dnsmasq[0].cachesize=%s", o.dnsmasqCacheSize)); err != nil { + return err + } + // Commit. + if _, err := uci("commit", "dhcp"); err != nil { + return err + } + } + // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return err @@ -92,6 +111,8 @@ func restartDNSMasq() error { return nil } +var errUCIEntryNotFound = errors.New("uci: Entry not found") + func uci(args ...string) (string, error) { cmd := exec.Command("uci", args...) var stdout, stderr bytes.Buffer diff --git a/internal/router/ubios/ubios.go b/internal/router/ubios/ubios.go index b0762db..32c7576 100644 --- a/internal/router/ubios/ubios.go +++ b/internal/router/ubios/ubios.go @@ -5,16 +5,17 @@ import ( "os" "strconv" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" + "github.com/kardianos/service" "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" "github.com/Control-D-Inc/ctrld/internal/router/edgeos" - "github.com/kardianos/service" ) const ( - Name = "ubios" - ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf" + Name = "ubios" + ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf" + ubiosDNSMasqDnsConfigPath = "/run/dnsmasq.conf.d/dns.conf" ) type Ubios struct { @@ -57,6 +58,10 @@ func (u *Ubios) Setup() error { if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(data), 0600); err != nil { return err } + // Disable dnsmasq cache. + if err := dnsmasq.DisableCache(ubiosDNSMasqDnsConfigPath); err != nil { + return err + } // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return err @@ -72,6 +77,10 @@ func (u *Ubios) Cleanup() error { if err := os.Remove(ubiosDNSMasqConfigPath); err != nil { return err } + // Enable dnsmasq cache. + if err := dnsmasq.EnableCache(ubiosDNSMasqDnsConfigPath); err != nil { + return err + } // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return err diff --git a/net.go b/net.go index 110d67e..3ae3bb5 100644 --- a/net.go +++ b/net.go @@ -2,13 +2,10 @@ package ctrld import ( "context" - "errors" "sync" "sync/atomic" "time" - "tailscale.com/logtail/backoff" - ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) @@ -17,30 +14,36 @@ var ( ipv6Available atomic.Bool ) +const ipv6ProbingInterval = 10 * time.Second + func hasIPv6() bool { hasIPv6Once.Do(func() { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() val := ctrldnet.IPv6Available(ctx) ipv6Available.Store(val) - go probingIPv6(val) + go probingIPv6(context.TODO(), val) }) return ipv6Available.Load() } // TODO(cuonglm): doing poll check natively for supported platforms. -func probingIPv6(old bool) { - b := backoff.NewBackoff("probingIPv6", func(format string, args ...any) {}, 30*time.Second) - bCtx := context.Background() +func probingIPv6(ctx context.Context, old bool) { + ticker := time.NewTicker(ipv6ProbingInterval) + defer ticker.Stop() for { - func() { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - cur := ctrldnet.IPv6Available(ctx) - if ipv6Available.CompareAndSwap(old, cur) { - old = cur - } - }() - b.BackOff(bCtx, errors.New("no change")) + select { + case <-ctx.Done(): + return + case <-ticker.C: + func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + cur := ctrldnet.IPv6Available(ctx) + if ipv6Available.CompareAndSwap(old, cur) { + old = cur + } + }() + } } } diff --git a/resolver.go b/resolver.go index 969da86..750679c 100644 --- a/resolver.go +++ b/resolver.go @@ -5,10 +5,12 @@ import ( "errors" "fmt" "net" + "net/netip" "sync" "time" "github.com/miekg/dns" + "tailscale.com/net/interfaces" ) const ( @@ -24,6 +26,8 @@ const ( ResolverTypeOS = "os" // ResolverTypeLegacy specifies legacy resolver. ResolverTypeLegacy = "legacy" + // ResolverTypePrivate is like ResolverTypeOS, but use for local resolver only. + ResolverTypePrivate = "private" ) var bootstrapDNS = "76.76.2.0" @@ -61,6 +65,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { return or, nil case ResolverTypeLegacy: return &legacyResolver{uc: uc}, nil + case ResolverTypePrivate: + return NewPrivateResolver(), nil } return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ) } @@ -74,8 +80,9 @@ type osResolverResult struct { err error } -// Resolve performs DNS resolvers using OS default nameservers. Nameserver is chosen from -// available nameservers with a roundrobin algorithm. +// Resolve resolves DNS queries using pre-configured nameservers. +// Query is sent to all nameservers concurrently, and the first +// success response will be returned. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { numServers := len(o.nameservers) if numServers == 0 { @@ -240,12 +247,16 @@ func NewBootstrapResolver(servers ...string) Resolver { } // NewPrivateResolver returns an OS resolver, which includes only private DNS servers, -// excluding nameservers from /etc/resolv.conf file. +// excluding: +// +// - Nameservers from /etc/resolv.conf file. +// - Nameservers which is local RFC1918 addresses. // // This is useful for doing PTR lookup in LAN network. func NewPrivateResolver() Resolver { nss := nameservers() resolveConfNss := nameserversFromResolvconf() + localRfc1918Addrs := Rfc1918Addresses() n := 0 for _, ns := range nss { host, _, _ := net.SplitHostPort(ns) @@ -258,6 +269,10 @@ func NewPrivateResolver() Resolver { if sliceContains(resolveConfNss, host) { continue } + // Ignoring local RFC 1918 addresses. + if sliceContains(localRfc1918Addrs, host) { + continue + } ip := net.ParseIP(host) if ip != nil && ip.IsPrivate() && !ip.IsLoopback() { nss[n] = ns @@ -265,11 +280,35 @@ func NewPrivateResolver() Resolver { } } nss = nss[:n] - if len(nss) == 0 { + return NewResolverWithNameserver(nss) +} + +// NewResolverWithNameserver returns an OS resolver which uses the given nameservers +// for resolving DNS queries. If nameservers is empty, a dummy resolver will be returned. +// +// Each nameserver must be form "host:port". It's the caller responsibility to ensure all +// nameservers are well formatted by using net.JoinHostPort function. +func NewResolverWithNameserver(nameservers []string) Resolver { + if len(nameservers) == 0 { return &dummyResolver{} } - resolver := &osResolver{nameservers: nss} - return resolver + return &osResolver{nameservers: nameservers} +} + +// Rfc1918Addresses returns the list of local interfaces private IP addresses +func Rfc1918Addresses() []string { + var res []string + interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) { + addrs, _ := i.Addrs() + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok || !ipNet.IP.IsPrivate() { + continue + } + res = append(res, ipNet.IP.String()) + } + }) + return res } func newDialer(dnsAddress string) *net.Dialer { diff --git a/testhelper/config.go b/testhelper/config.go index 5c2e5f4..6199424 100644 --- a/testhelper/config.go +++ b/testhelper/config.go @@ -82,4 +82,8 @@ rules = [ {"*.ru" = ["upstream.1"]}, {"*.local.host" = ["upstream.2", "upstream.0"]}, ] +macs = [ + {"14:45:A0:67:83:0A" = ["upstream.2"]}, + {"14:54:4a:8e:08:2d" = ["upstream.2"]}, +] `