diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 6d6360f..5c7795f 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "net/netip" + "net/url" "os" "os/exec" "path/filepath" @@ -23,11 +24,13 @@ import ( "sync" "time" + "github.com/Masterminds/semver" "github.com/cuonglm/osinfo" "github.com/fsnotify/fsnotify" "github.com/go-playground/validator/v10" "github.com/kardianos/service" "github.com/miekg/dns" + "github.com/minio/selfupdate" "github.com/olekukonko/tablewriter" "github.com/pelletier/go-toml/v2" "github.com/rs/zerolog" @@ -44,6 +47,9 @@ import ( "github.com/Control-D-Inc/ctrld/internal/router" ) +// selfCheckInternalTestDomain is used for testing ctrld self response to clients. +const selfCheckInternalTestDomain = "ctrld" + loopTestDomain + var ( version = "dev" commit = "none" @@ -170,12 +176,63 @@ func initCLI() { } setDependencies(sc) sc.Arguments = append([]string{"run"}, osArgs...) + + p := &prog{ + router: router.New(&cfg, cdUID != ""), + cfg: &cfg, + } + s, err := newService(p, sc) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + + status, err := s.Status() + isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) + + // If pin code was set, do not allow running start command. + if status == service.StatusRunning { + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { + os.Exit(deactivationPinInvalidExitCode) + } + } + if cdUID != "" { - if _, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); err != nil { + rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + if err != nil { mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) } + // validateCdRemoteConfig clobbers v, saving it here to restore later. + oldV := v + if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil { + if errors.As(err, &viper.ConfigParseError{}) { + if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 { + tmpDir := os.TempDir() + tmpConfFile := filepath.Join(tmpDir, "ctrld.toml") + errorLogged := false + // Write remote config to a temporary file to get details error. + if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil { + if de := decoderErrorFromTomlFile(tmpConfFile); de != nil { + row, col := de.Position() + mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error()) + errorLogged = true + } + _ = os.Remove(tmpConfFile) + } + // If we could not log details error, emit what we have already got. + if !errorLogged { + mainLog.Load().Error().Msgf("failed to parse custom config: %v", err) + } + } + } else { + mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err) + } + mainLog.Load().Warn().Msg("disregarding invalid custom config") + } + v = oldV } else if uid := cdUIDFromProvToken(); uid != "" { cdUID = uid + mainLog.Load().Debug().Msg("using uid from provision token") removeProvTokenFromArgs(sc) // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. sc.Arguments = append(sc.Arguments, "--cd="+cdUID) @@ -184,10 +241,6 @@ func initCLI() { validateCdUpstreamProtocol() } - p := &prog{ - router: router.New(&cfg, cdUID != ""), - cfg: &cfg, - } if err := p.router.ConfigureService(sc); err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to configure service on router") } @@ -253,22 +306,6 @@ func initCLI() { sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile) } - s, err := newService(p, sc) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - - status, err := s.Status() - isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) - - // If pin code was set, do not allow running start command. - if status == service.StatusRunning { - if err := checkDeactivationPin(s); isCheckDeactivationPinErr(err) { - os.Exit(deactivationPinInvalidExitCode) - } - } - if router.Name() != "" && iface != "" { mainLog.Load().Debug().Msg("cleaning up router before installing") _ = p.router.Cleanup() @@ -334,7 +371,7 @@ func initCLI() { } // If ctrld service is running but selfCheckStatus failed, it could be related // to user's system firewall configuration, notice users about it. - if status == service.StatusRunning { + if status == service.StatusRunning && err == nil { _, _ = mainLog.Load().Write(marker) mainLog.Load().Write([]byte(`ctrld service was running, but a DNS query could not be sent to its listener`)) mainLog.Load().Write([]byte(`Please check your system firewall if it is configured to block/intercept/redirect DNS queries`)) @@ -406,14 +443,14 @@ func initCLI() { Run: func(cmd *cobra.Command, args []string) { readConfig(false) v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, cdUID != "")} + p := &prog{router: router.New(&cfg, runInCdMode())} s, err := newService(p, svcConfig) if err != nil { mainLog.Load().Error().Msg(err.Error()) return } initLogging() - if err := checkDeactivationPin(s); isCheckDeactivationPinErr(err) { + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) } if doTasks([]task{{s.Stop, true}}) { @@ -563,7 +600,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, Run: func(cmd *cobra.Command, args []string) { readConfig(false) v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, cdUID != "")} + p := &prog{router: router.New(&cfg, runInCdMode())} s, err := newService(p, svcConfig) if err != nil { mainLog.Load().Error().Msg(err.Error()) @@ -572,7 +609,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if iface == "" { iface = "auto" } - if err := checkDeactivationPin(s); isCheckDeactivationPinErr(err) { + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) } uninstall(p, s) @@ -816,6 +853,117 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, } clientsCmd.AddCommand(listClientsCmd) rootCmd.AddCommand(clientsCmd) + + const ( + upgradeChannelDev = "dev" + upgradeChannelProd = "prod" + upgradeChannelDefault = "default" + ) + upgradeChannel := map[string]string{ + upgradeChannelDefault: "https://dl.controld.dev", + upgradeChannelDev: "https://dl.controld.dev", + upgradeChannelProd: "https://dl.controld.com", + } + if isStableVersion(curVersion()) { + upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd] + } + upgradeCmd := &cobra.Command{ + Use: "upgrade", + Short: "Upgrading ctrld to latest version", + ValidArgs: []string{upgradeChannelDev, upgradeChannelProd}, + Args: cobra.MaximumNArgs(1), + PreRun: func(cmd *cobra.Command, args []string) { + initConsoleLogging() + checkHasElevatedPrivilege() + }, + Run: func(cmd *cobra.Command, args []string) { + s, err := newService(&prog{}, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + bin, err := os.Executable() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to get current ctrld binary path") + } + oldBin := bin + "_previous" + urlString := upgradeChannel[upgradeChannelDefault] + if len(args) > 0 { + channel := args[0] + switch channel { + case upgradeChannelProd, upgradeChannelDev: // ok + default: + mainLog.Load().Fatal().Msgf("uprade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev) + } + urlString = upgradeChannel[channel] + } + dlUrl := fmt.Sprintf("%s/%s-%s/ctrld", urlString, runtime.GOOS, runtime.GOARCH) + if runtime.GOOS == "windows" { + dlUrl += ".exe" + } + mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl) + resp, err := http.Get(dlUrl) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to download binary") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + mainLog.Load().Fatal().Msgf("could not download binary: %s", http.StatusText(resp.StatusCode)) + } + mainLog.Load().Debug().Msg("Updating current binary") + if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil { + if rerr := selfupdate.RollbackError(err); rerr != nil { + mainLog.Load().Error().Err(rerr).Msg("could not rollback old binary") + } + mainLog.Load().Fatal().Err(err).Msg("failed to update current binary") + } + + doRestart := func() bool { + tasks := []task{ + {s.Stop, false}, + {s.Start, false}, + } + if doTasks(tasks) { + if dir, err := socketDir(); err == nil { + return newSocketControlClient(s, dir) != nil + } + } + return false + } + mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") + if doRestart() { + _ = os.Remove(oldBin) + _ = os.Chmod(bin, 0755) + ver := "unknown version" + out, err := exec.Command(bin, "--version").CombinedOutput() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version") + } + if after, found := strings.CutPrefix(string(out), "ctrld version "); found { + ver = after + } + mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver) + return + } + + mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin) + if err := os.Remove(bin); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to remove new binary") + } + if err := os.Rename(oldBin, bin); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to restore old binary") + } + if doRestart() { + mainLog.Load().Notice().Msg("Restored previous binary successfully") + return + } + }, + } + rootCmd.AddCommand(upgradeCmd) } // isMobile reports whether the current OS is a mobile platform. @@ -828,6 +976,15 @@ func isAndroid() bool { return runtime.GOOS == "android" } +// isStableVersion reports whether vs is a stable semantic version. +func isStableVersion(vs string) bool { + v, err := semver.NewVersion(vs) + if err != nil { + return false + } + return v.Prerelease() == "" +} + // RunCobraCommand runs ctrld cli. func RunCobraCommand(cmd *cobra.Command) { noConfigStart = isNoConfigStart(cmd) @@ -852,9 +1009,9 @@ func RunMobile(appConfig *AppConfig, appCallback *AppCallback, stopCh chan struc } // CheckDeactivationPin checks if deactivation pin is valid -func CheckDeactivationPin(pin int64) int { +func CheckDeactivationPin(pin int64, stopCh chan struct{}) int { deactivationPin = pin - if err := checkDeactivationPin(nil); isCheckDeactivationPinErr(err) { + if err := checkDeactivationPin(nil, stopCh); isCheckDeactivationPinErr(err) { return deactivationPinInvalidExitCode } return 0 @@ -912,7 +1069,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { writeDefaultConfig := !noConfigStart && configBase64 == "" tryReadingConfig(writeDefaultConfig) - readBase64Config(configBase64) + if err := readBase64Config(configBase64); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to read base64 config") + } processNoConfigFlags(noConfigStart) p.mu.Lock() if err := v.Unmarshal(&cfg); err != nil { @@ -935,7 +1094,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } p.router = router.New(&cfg, cdUID != "") - cs, err := newControlServer(filepath.Join(sockDir, ctrldControlUnixSock)) + cs, err := newControlServer(filepath.Join(sockDir, ControlSocketName())) if err != nil { mainLog.Load().Warn().Err(err).Msg("could not create control server") } @@ -1141,7 +1300,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { } // If error is viper.ConfigFileNotFoundError, write default config. - if _, ok := err.(viper.ConfigFileNotFoundError); ok { + if errors.As(err, &viper.ConfigFileNotFoundError{}) { if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) } @@ -1162,13 +1321,11 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { return false } - if _, ok := err.(viper.ConfigParseError); ok { - if f, _ := os.Open(v.ConfigFileUsed()); f != nil { - var i any - if err, ok := toml.NewDecoder(f).Decode(&i).(*toml.DecodeError); ok { - row, col := err.Position() - mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err) - } + // If error is viper.ConfigParseError, emit details line and column number. + if errors.As(err, &viper.ConfigParseError{}) { + if de := decoderErrorFromTomlFile(v.ConfigFileUsed()); de != nil { + row, col := de.Position() + mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err) } } @@ -1177,13 +1334,27 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { return false } -func readBase64Config(configBase64 string) { +// decoderErrorFromTomlFile parses the invalid toml file, returning the details decoder error. +func decoderErrorFromTomlFile(cf string) *toml.DecodeError { + if f, _ := os.Open(cf); f != nil { + defer f.Close() + var i any + var de *toml.DecodeError + if err := toml.NewDecoder(f).Decode(&i); err != nil && errors.As(err, &de) { + return de + } + } + return nil +} + +// readBase64Config reads ctrld config from the base64 input string. +func readBase64Config(configBase64 string) error { if configBase64 == "" { - return + return nil } configStr, err := base64.StdEncoding.DecodeString(configBase64) if err != nil { - mainLog.Load().Fatal().Msgf("invalid base64 config: %v", err) + return fmt.Errorf("invalid base64 config: %w", err) } // readBase64Config is called when: @@ -1194,9 +1365,7 @@ func readBase64Config(configBase64 string) { // So we need to re-create viper instance to discard old one. v = viper.NewWithOptions(viper.KeyDelimiter("::")) v.SetConfigType("toml") - if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil { - mainLog.Load().Fatal().Msgf("failed to read base64 config: %v", err) - } + return v.ReadConfig(bytes.NewReader(configStr)) } func processNoConfigFlags(noConfigStart bool) { @@ -1286,42 +1455,76 @@ func processCDFlags(cfg *ctrld.Config) error { // Fetch config, unmarshal to cfg. if resolverConfig.Ctrld.CustomConfig != "" { logger.Info().Msg("using defined custom config of Control-D resolver") - readBase64Config(resolverConfig.Ctrld.CustomConfig) - if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) + if err := validateCdRemoteConfig(resolverConfig, cfg); err == nil { + setListenerDefaultValue(cfg) + return nil } - } else { - cfg.Network = make(map[string]*ctrld.NetworkConfig) - cfg.Network["0"] = &ctrld.NetworkConfig{ - Name: "Network 0", - Cidrs: []string{"0.0.0.0/0"}, - } - cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) - cfg.Upstream["0"] = &ctrld.UpstreamConfig{ - Endpoint: resolverConfig.DOH, - Type: cdUpstreamProto, - Timeout: 5000, - } - rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) - for _, domain := range resolverConfig.Exclude { - rules = append(rules, ctrld.Rule{domain: []string{}}) - } - cfg.Listener = make(map[string]*ctrld.ListenerConfig) - lc := &ctrld.ListenerConfig{ - Policy: &ctrld.ListenerPolicyConfig{ - Name: "My Policy", - Rules: rules, - }, - } - cfg.Listener["0"] = lc + mainLog.Load().Err(err).Msg("disregarding invalid custom config") } + + bootstrapIP := func(endpoint string) string { + u, err := url.Parse(endpoint) + if err != nil { + logger.Warn().Err(err).Msgf("no bootstrap IP for invalid endpoint: %s", endpoint) + return "" + } + switch { + case dns.IsSubDomain(ctrld.FreeDnsDomain, u.Host): + return ctrld.FreeDNSBoostrapIP + case dns.IsSubDomain(ctrld.PremiumDnsDomain, u.Host): + return ctrld.PremiumDNSBoostrapIP + } + return "" + } + cfg.Network = make(map[string]*ctrld.NetworkConfig) + cfg.Network["0"] = &ctrld.NetworkConfig{ + Name: "Network 0", + Cidrs: []string{"0.0.0.0/0"}, + } + cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) + cfg.Upstream["0"] = &ctrld.UpstreamConfig{ + BootstrapIP: bootstrapIP(resolverConfig.DOH), + Endpoint: resolverConfig.DOH, + Type: cdUpstreamProto, + Timeout: 5000, + } + rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) + for _, domain := range resolverConfig.Exclude { + rules = append(rules, ctrld.Rule{domain: []string{}}) + } + cfg.Listener = make(map[string]*ctrld.ListenerConfig) + lc := &ctrld.ListenerConfig{ + Policy: &ctrld.ListenerPolicyConfig{ + Name: "My Policy", + Rules: rules, + }, + } + cfg.Listener["0"] = lc + // Set default value. + setListenerDefaultValue(cfg) + + return nil +} + +// setListenerDefaultValue sets the default value for cfg.Listener if none existed. +func setListenerDefaultValue(cfg *ctrld.Config) { if len(cfg.Listener) == 0 { cfg.Listener = map[string]*ctrld.ListenerConfig{ "0": {IP: "", Port: 0}, } } - return nil +} + +// validateCdRemoteConfig validates the custom config from ControlD if defined. +func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) error { + if rc.Ctrld.CustomConfig == "" { + return nil + } + if err := readBase64Config(rc.Ctrld.CustomConfig); err != nil { + return err + } + return v.Unmarshal(&cfg) } func processListenFlag() { @@ -1398,6 +1601,13 @@ func defaultIfaceName() string { // selfCheckStatus performs the end-to-end DNS test by sending query to ctrld listener. // It returns a boolean to indicate whether the check is succeeded, the actual status // of ctrld service, and an additional error if any. +// +// We perform two tests: +// +// - Internal testing, ensuring query could be sent from client -> ctrld. +// - External testing, ensuring query could be sent from ctrld -> upstream. +// +// Self-check is considered success only if both tests are ok. func selfCheckStatus(s service.Service) (bool, service.Status, error) { status, err := s.Status() if err != nil { @@ -1480,8 +1690,9 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { }) v.WatchConfig() var ( - lastAnswer *dns.Msg - lastErr error + lastAnswer *dns.Msg + lastErr error + internalTested bool ) for i := 0; i < maxAttempts; i++ { mu.Lock() @@ -1494,6 +1705,9 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { mu.Unlock() lc := cfg.FirstListener() domain = cfg.FirstUpstream().VerifyDomain() + if !internalTested { + domain = selfCheckInternalTestDomain + } if domain == "" { continue } @@ -1503,7 +1717,13 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { m.RecursionDesired = true r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))) if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 { - mainLog.Load().Debug().Msgf("self-check against %q succeeded", domain) + internalTested = domain == selfCheckInternalTestDomain + if internalTested { + mainLog.Load().Debug().Msgf("internal self-check against %q succeeded", domain) + continue // internal domain test ok, continue with external test. + } else { + mainLog.Load().Debug().Msgf("external self-check against %q succeeded", domain) + } return true, status, nil } // Return early if this is a connection refused. @@ -1515,6 +1735,12 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", exErr)) } mainLog.Load().Debug().Msgf("self-check against %q failed", domain) + // Ping all upstreams to provide better error message to users. + for name, uc := range cfg.Upstream { + if err := uc.ErrorPing(); err != nil { + mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint) + } + } lc := cfg.FirstListener() addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)) marker := strings.Repeat("=", 32) @@ -1629,6 +1855,10 @@ func readConfigWithNotice(writeDefaultConfig, notice bool) { } func uninstall(p *prog, s service.Service) { + if _, err := s.Status(); err != nil && errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Error().Msg(err.Error()) + return + } tasks := []task{ {s.Stop, false}, {s.Uninstall, true}, @@ -1677,6 +1907,11 @@ func fieldErrorMsg(fe validator.FieldError) string { return fmt.Sprintf("must define at least %s element", fe.Param()) } return fmt.Sprintf("minimum value: %q", fe.Param()) + case "max": + if fe.Kind() == reflect.Map || fe.Kind() == reflect.Slice { + return fmt.Sprintf("exceeded maximum number of elements: %s", fe.Param()) + } + return fmt.Sprintf("maximum value: %q", fe.Param()) case "len": if fe.Kind() == reflect.Slice { return fmt.Sprintf("must have at least %s element", fe.Param()) @@ -1802,10 +2037,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata if cdMode { firstLn.IP = mobileListenerIp() firstLn.Port = mobileListenerPort() - // TODO: use clear(lcc) once upgrading to go 1.21 - for k := range lcc { - delete(lcc, k) - } + clear(lcc) updated = true } } @@ -2027,6 +2259,7 @@ func cdUIDFromProvToken() string { if cdOrg == "" { return "" } + // Process provision token if provided. resolverConfig, err := controld.FetchResolverUID(cdOrg, rootCmd.Version, cdDev) if err != nil { @@ -2065,6 +2298,8 @@ func newSocketControlClient(s service.Service, dir string) *controlClient { ctx := context.Background() cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + timeout := time.NewTimer(30 * time.Second) + defer timeout.Stop() // The socket control server may not start yet, so attempt to ping // it until we got a response. For each iteration, check ctrld status @@ -2072,7 +2307,6 @@ func newSocketControlClient(s service.Service, dir string) *controlClient { for { curStatus, err := s.Status() if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not get service status while doing self-check") return nil } if curStatus != service.StatusRunning { @@ -2084,12 +2318,37 @@ func newSocketControlClient(s service.Service, dir string) *controlClient { } // The socket control server is not ready yet, backoff for waiting it to be ready. bo.BackOff(ctx, err) + select { + case <-timeout.C: + return nil + default: + } continue } return cc } +func newSocketControlClientMobile(dir string, stopCh chan struct{}) *controlClient { + bo := backoff.NewBackoff("self-check", logf, 3*time.Second) + bo.LogLongerThan = 3 * time.Second + ctx := context.Background() + cc := newControlClient(filepath.Join(dir, ControlSocketName())) + for { + select { + case <-stopCh: + return nil + default: + _, err := cc.post("/", nil) + if err == nil { + return cc + } else { + bo.BackOff(ctx, err) + } + } + } +} + // checkStrFlagEmpty validates if a string flag was set to an empty string. // If yes, emitting a fatal error message. func checkStrFlagEmpty(cmd *cobra.Command, flagName string) { @@ -2170,7 +2429,7 @@ var errInvalidDeactivationPin = errors.New("deactivation pin is invalid") var errRequiredDeactivationPin = errors.New("deactivation pin is required to stop or uninstall the service") // checkDeactivationPin validates if the deactivation pin matches one in ControlD config. -func checkDeactivationPin(s service.Service) error { +func checkDeactivationPin(s service.Service, stopCh chan struct{}) error { dir, err := socketDir() if err != nil { mainLog.Load().Err(err).Msg("could not check deactivation pin") @@ -2178,7 +2437,7 @@ func checkDeactivationPin(s service.Service) error { } var cc *controlClient if s == nil { - cc = newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + cc = newSocketControlClientMobile(dir, stopCh) } else { cc = newSocketControlClient(s, dir) } @@ -2259,3 +2518,20 @@ func absHomeDir(filename string) string { } return filepath.Join(dir, filename) } + +// runInCdMode reports whether ctrld service is running in cd mode. +func runInCdMode() bool { + if s, _ := newService(&prog{}, svcConfig); s != nil { + if dir, _ := socketDir(); dir != "" { + cc := newSocketControlClient(s, dir) + if cc != nil { + resp, _ := cc.post(cdPath, nil) + if resp != nil { + defer resp.Body.Close() + return resp.StatusCode == http.StatusOK + } + } + } + } + return false +} diff --git a/cmd/cli/cli_test.go b/cmd/cli/cli_test.go index 01f2586..fcede32 100644 --- a/cmd/cli/cli_test.go +++ b/cmd/cli/cli_test.go @@ -21,3 +21,26 @@ func Test_writeConfigFile(t *testing.T) { _, err = os.Stat(configPath) require.NoError(t, err) } + +func Test_isStableVersion(t *testing.T) { + tests := []struct { + name string + ver string + isStable bool + }{ + {"stable", "v1.3.5", true}, + {"pre", "v1.3.5-next", false}, + {"pre with commit hash", "v1.3.5-next-asdf", false}, + {"dev", "dev", false}, + {"empty", "dev", false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isStableVersion(tc.ver); got != tc.isStable { + t.Errorf("unexpected result for %s, want: %v, got: %v", tc.ver, tc.isStable, got) + } + }) + } +} diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 28c20a6..4d243bf 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -21,6 +21,7 @@ const ( startedPath = "/started" reloadPath = "/reload" deactivationPath = "/deactivation" + cdPath = "/cd" ) type controlServer struct { @@ -171,6 +172,13 @@ func (p *prog) registerControlServerHandler() { } w.WriteHeader(code) })) + p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + if cdUID != "" { + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusBadRequest) + })) } func jsonResponse(next http.Handler) http.Handler { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 0b6282e..a5242c5 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -101,6 +101,15 @@ func (p *prog) serveDNS(listenerNum string) error { go p.detectLoop(m) q := m.Question[0] domain := canonicalName(q.Name) + if domain == selfCheckInternalTestDomain { + answer := resolveInternalDomainTestQuery(ctx, domain, m) + _ = w.WriteMsg(answer) + return + } + if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil { + p.cache.Purge() + ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain) + } remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) ci := p.getClientInfo(remoteIP, m) ci.ClientIDPref = p.cfg.Service.ClientIDPref @@ -282,7 +291,7 @@ networkRules: macRules: for _, rule := range lc.Policy.Macs { for source, targets := range rule { - if source != "" && strings.EqualFold(source, srcMac) { + if source != "" && (strings.EqualFold(source, srcMac) || wildcardMatches(strings.ToLower(source), strings.ToLower(srcMac))) { matchedPolicy = lc.Policy.Name matchedNetwork = source networkTargets = targets @@ -590,7 +599,8 @@ func canonicalName(fqdn string) string { return q } -func wildcardMatches(wildcard, domain string) bool { +// wildcardMatches reports whether string str matches the wildcard pattern. +func wildcardMatches(wildcard, str string) bool { // Wildcard match. wildCardParts := strings.Split(wildcard, "*") if len(wildCardParts) != 2 { @@ -600,15 +610,15 @@ func wildcardMatches(wildcard, domain string) bool { switch { case len(wildCardParts[0]) > 0 && len(wildCardParts[1]) > 0: // Domain must match both prefix and suffix. - return strings.HasPrefix(domain, wildCardParts[0]) && strings.HasSuffix(domain, wildCardParts[1]) + return strings.HasPrefix(str, wildCardParts[0]) && strings.HasSuffix(str, wildCardParts[1]) case len(wildCardParts[1]) > 0: // Only suffix must match. - return strings.HasSuffix(domain, wildCardParts[1]) + return strings.HasSuffix(str, wildCardParts[1]) case len(wildCardParts[0]) > 0: // Only prefix must match. - return strings.HasPrefix(domain, wildCardParts[0]) + return strings.HasPrefix(str, wildCardParts[0]) } return false @@ -806,6 +816,13 @@ func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo { ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac) } ci.Self = queryFromSelf(ci.IP) + // If this is a query from self, but ci.IP is not loopback IP, + // try using hostname mapping for lookback IP if presents. + if ci.Self { + if name := p.ciTable.LocalHostname(); name != "" { + ci.Hostname = name + } + } p.spoofLoopbackIpInClientInfo(ci) return ci } @@ -936,3 +953,21 @@ func isWanClient(na net.Addr) bool { !ip.IsLinkLocalMulticast() && !tsaddr.CGNATRange().Contains(ip) } + +// resolveInternalDomainTestQuery resolves internal test domain query, returning the answer to the caller. +func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.Msg) *dns.Msg { + ctrld.Log(ctx, mainLog.Load().Debug(), "internal domain test query") + + q := m.Question[0] + answer := new(dns.Msg) + rrStr := fmt.Sprintf("%s A %s", domain, net.IPv4zero) + if q.Qtype == dns.TypeAAAA { + rrStr = fmt.Sprintf("%s AAAA %s", domain, net.IPv6zero) + } + rr, err := dns.NewRR(rrStr) + if err == nil { + answer.Answer = append(answer.Answer, rr) + } + answer.SetReply(m) + return answer +} diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 52d3edb..cb2e459 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -22,14 +22,21 @@ func Test_wildcardMatches(t *testing.T) { domain string match bool }{ - {"prefix parent should not match", "*.windscribe.com", "windscribe.com", false}, - {"prefix", "*.windscribe.com", "anything.windscribe.com", true}, - {"prefix not match other domain", "*.windscribe.com", "example.com", false}, - {"prefix not match domain in name", "*.windscribe.com", "wwindscribe.com", false}, - {"suffix", "suffix.*", "suffix.windscribe.com", true}, - {"suffix not match other", "suffix.*", "suffix1.windscribe.com", false}, - {"both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true}, - {"both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false}, + {"domain - prefix parent should not match", "*.windscribe.com", "windscribe.com", false}, + {"domain - prefix", "*.windscribe.com", "anything.windscribe.com", true}, + {"domain - prefix not match other s", "*.windscribe.com", "example.com", false}, + {"domain - prefix not match s in name", "*.windscribe.com", "wwindscribe.com", false}, + {"domain - suffix", "suffix.*", "suffix.windscribe.com", true}, + {"domain - suffix not match other", "suffix.*", "suffix1.windscribe.com", false}, + {"domain - both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true}, + {"domain - both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false}, + {"mac - prefix", "*:98:05:b4:2b", "d4:67:98:05:b4:2b", true}, + {"mac - prefix not match other s", "*:98:05:b4:2b", "0d:ba:54:09:94:2c", false}, + {"mac - prefix not match s in name", "*:98:05:b4:2b", "e4:67:97:05:b4:2b", false}, + {"mac - suffix", "d4:67:98:*", "d4:67:98:05:b4:2b", true}, + {"mac - suffix not match other", "d4:67:98:*", "d4:67:97:15:b4:2b", false}, + {"mac - both", "d4:67:98:*:b4:2b", "d4:67:98:05:b4:2b", true}, + {"mac - both not match", "d4:67:98:*:b4:2b", "d4:67:97:05:c4:2b", false}, } for _, tc := range tests { diff --git a/cmd/cli/os_darwin.go b/cmd/cli/os_darwin.go index 7ce4aa1..f319056 100644 --- a/cmd/cli/os_darwin.go +++ b/cmd/cli/os_darwin.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "os/exec" + "strings" "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) @@ -30,6 +31,18 @@ func deAllocateIP(ip string) error { return nil } +// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable. +func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error { + if err := setDNS(iface, nameservers); err != nil { + // TODO: investiate whether we can detect this without relying on error message. + if strings.Contains(err.Error(), " is not a recognized network service") { + return nil + } + return err + } + return nil +} + // set the dns server for the provided network interface // networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1 // TODO(cuonglm): use system API @@ -43,6 +56,18 @@ func setDNS(iface *net.Interface, nameservers []string) error { return nil } +// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable. +func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { + if err := resetDNS(iface); err != nil { + // TODO: investiate whether we can detect this without relying on error message. + if strings.Contains(err.Error(), " is not a recognized network service") { + return nil + } + return err + } + return nil +} + // TODO(cuonglm): use system API func resetDNS(iface *net.Interface) error { if ns := savedStaticNameservers(iface); len(ns) > 0 { diff --git a/cmd/cli/os_freebsd.go b/cmd/cli/os_freebsd.go index 216b36f..cc5ff92 100644 --- a/cmd/cli/os_freebsd.go +++ b/cmd/cli/os_freebsd.go @@ -29,6 +29,11 @@ func deAllocateIP(ip string) error { return nil } +// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable. +func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error { + return setDNS(iface, nameservers) +} + // set the dns server for the provided network interface func setDNS(iface *net.Interface, nameservers []string) error { r, err := dns.NewOSConfigurator(logf, iface.Name) @@ -49,6 +54,11 @@ func setDNS(iface *net.Interface, nameservers []string) error { return nil } +// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable. +func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { + return resetDNS(iface) +} + func resetDNS(iface *net.Interface) error { r, err := dns.NewOSConfigurator(logf, iface.Name) if err != nil { diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index 3d9bffd..a36311d 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -9,6 +9,7 @@ import ( "net" "net/netip" "os/exec" + "slices" "strings" "syscall" "time" @@ -45,7 +46,11 @@ func deAllocateIP(ip string) error { const maxSetDNSAttempts = 5 -// set the dns server for the provided network interface +// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable. +func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error { + return setDNS(iface, nameservers) +} + func setDNS(iface *net.Interface, nameservers []string) error { r, err := dns.NewOSConfigurator(logf, iface.Name) if err != nil { @@ -115,6 +120,11 @@ func setDNS(iface *net.Interface, nameservers []string) error { return nil } +// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable. +func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { + return resetDNS(iface) +} + func resetDNS(iface *net.Interface) (err error) { defer func() { if err == nil { @@ -276,8 +286,7 @@ func ignoringEINTR(fn func() error) error { func isSubSet(s1, s2 []string) bool { ok := true for _, ns := range s1 { - // TODO(cuonglm): use slices.Contains once upgrading to go1.21 - if sliceContains(s2, ns) { + if slices.Contains(s2, ns) { continue } ok = false @@ -285,19 +294,3 @@ func isSubSet(s1, s2 []string) bool { } return ok } - -// sliceContains reports whether v is present in s. -func sliceContains[S ~[]E, E comparable](s S, v E) bool { - return sliceIndex(s, v) >= 0 -} - -// sliceIndex returns the index of the first occurrence of v in s, -// or -1 if not present. -func sliceIndex[S ~[]E, E comparable](s S, v E) int { - for i := range s { - if v == s[i] { - return i - } - } - return -1 -} diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 56097f8..d2f1dd2 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -26,6 +26,12 @@ var ( resetDNSOnce sync.Once ) +// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable. +func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error { + return setDNS(iface, nameservers) +} + +// setDNS sets the dns server for the provided network interface func setDNS(iface *net.Interface, nameservers []string) error { if len(nameservers) == 0 { return errors.New("empty DNS nameservers") @@ -61,6 +67,11 @@ func setDNS(iface *net.Interface, nameservers []string) error { return nil } +// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable. +func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { + return resetDNS(iface) +} + // TODO(cuonglm): should we use system API? func resetDNS(iface *net.Interface) error { resetDNSOnce.Do(func() { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 6febff8..b3f3abf 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -33,11 +33,22 @@ const ( defaultSemaphoreCap = 256 ctrldLogUnixSock = "ctrld_start.sock" ctrldControlUnixSock = "ctrld_control.sock" - upstreamPrefix = "upstream." - upstreamOS = upstreamPrefix + "os" - upstreamPrivate = upstreamPrefix + "private" + // iOS unix socket name max length is 11. + ctrldControlUnixSockMobile = "cd.sock" + upstreamPrefix = "upstream." + upstreamOS = upstreamPrefix + "os" + upstreamPrivate = upstreamPrefix + "private" ) +// ControlSocketName returns name for control unix socket. +func ControlSocketName() string { + if isMobile() { + return ctrldControlUnixSockMobile + } else { + return ctrldControlUnixSock + } +} + var logf = func(format string, args ...any) { mainLog.Load().Debug().Msgf(format, args...) } @@ -59,17 +70,18 @@ type prog struct { logConn net.Conn cs *controlServer - cfg *ctrld.Config - localUpstreams []string - ptrNameservers []string - appCallback *AppCallback - cache dnscache.Cacher - sema semaphore - ciTable *clientinfo.Table - um *upstreamMonitor - router router.Router - ptrLoopGuard *loopGuard - lanLoopGuard *loopGuard + cfg *ctrld.Config + localUpstreams []string + ptrNameservers []string + appCallback *AppCallback + cache dnscache.Cacher + cacheFlushDomainsMap map[string]struct{} + sema semaphore + ciTable *clientinfo.Table + um *upstreamMonitor + router router.Router + ptrLoopGuard *loopGuard + lanLoopGuard *loopGuard loopMu sync.Mutex loop map[string]bool @@ -242,12 +254,17 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.loop = make(map[string]bool) p.lanLoopGuard = newLoopGuard() p.ptrLoopGuard = newLoopGuard() + p.cacheFlushDomainsMap = nil if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled") } else { p.cache = cacher + p.cacheFlushDomainsMap = make(map[string]struct{}, 256) + for _, domain := range p.cfg.Service.CacheFlushDomains { + p.cacheFlushDomainsMap[canonicalName(domain)] = struct{}{} + } } } @@ -477,7 +494,7 @@ func (p *prog) setDNS() { } if allIfaces { withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error { - return setDNS(i, nameservers) + return setDnsIgnoreUnusableInterface(i, nameservers) }) } } @@ -509,7 +526,7 @@ func (p *prog) resetDNS() { } logger.Debug().Msg("Restoring DNS successfully") if allIfaces { - withEachPhysicalInterfaces(netIface.Name, "reset DNS", resetDNS) + withEachPhysicalInterfaces(netIface.Name, "reset DNS", resetDnsIgnoreUnusableInterface) } } diff --git a/cmd/cli/prog_linux.go b/cmd/cli/prog_linux.go index 2b9c69d..cdb3c0e 100644 --- a/cmd/cli/prog_linux.go +++ b/cmd/cli/prog_linux.go @@ -1,6 +1,8 @@ package cli import ( + "os" + "github.com/kardianos/service" "github.com/Control-D-Inc/ctrld/internal/dns" @@ -10,6 +12,10 @@ func init() { if r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo"); err == nil { useSystemdResolved = r.Mode() == "systemd-resolved" } + // Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911 + if os.Getenv("QUIC_GO_DISABLE_ECN") == "" { + os.Setenv("QUIC_GO_DISABLE_ECN", "true") + } } func setDependencies(svc *service.Config) { diff --git a/cmd/cli/service.go b/cmd/cli/service.go index c6ed68c..ef37796 100644 --- a/cmd/cli/service.go +++ b/cmd/cli/service.go @@ -20,7 +20,7 @@ func newService(i service.Interface, c *service.Config) (service.Service, error) return nil, err } switch { - case router.IsOldOpenwrt(): + case router.IsOldOpenwrt(), router.IsNetGearOrbi(): return &procd{&sysV{s}}, nil case router.IsGLiNet(): return &sysV{s}, nil diff --git a/cmd/ctrld_library/main.go b/cmd/ctrld_library/main.go index ec42b9c..49f5b26 100644 --- a/cmd/ctrld_library/main.go +++ b/cmd/ctrld_library/main.go @@ -61,8 +61,13 @@ func mapCallback(callback AppCallback) cli.AppCallback { } } -func (c *Controller) Stop(Pin int64) int { - errorCode := cli.CheckDeactivationPin(Pin) +func (c *Controller) Stop(restart bool, pin int64) int { + var errorCode = 0 + // Force disconnect without checking pin. + // In iOS restart is required if vpn detects no connectivity after network change. + if !restart { + errorCode = cli.CheckDeactivationPin(pin, c.stopCh) + } if errorCode == 0 && c.stopCh != nil { close(c.stopCh) c.stopCh = nil diff --git a/config.go b/config.go index cb38096..8c99a8e 100644 --- a/config.go +++ b/config.go @@ -46,6 +46,15 @@ const ( // depending on the record type of the DNS query. IpStackSplit = "split" + // FreeDnsDomain is the domain name of free ControlD service. + FreeDnsDomain = "freedns.controld.com" + // FreeDNSBoostrapIP is the IP address of freedns.controld.com. + FreeDNSBoostrapIP = "76.76.2.11" + // PremiumDnsDomain is the domain name of premium ControlD service. + PremiumDnsDomain = "dns.controld.com" + // PremiumDNSBoostrapIP is the IP address of dns.controld.com. + PremiumDNSBoostrapIP = "76.76.2.22" + controlDComDomain = "controld.com" controlDNetDomain = "controld.net" controlDDevDomain = "controld.dev" @@ -104,14 +113,14 @@ func InitConfig(v *viper.Viper, name string) { }) v.SetDefault("upstream", map[string]*UpstreamConfig{ "0": { - BootstrapIP: "76.76.2.11", + BootstrapIP: FreeDNSBoostrapIP, Name: "Control D - Anti-Malware", Type: ResolverTypeDOH, Endpoint: "https://freedns.controld.com/p1", Timeout: 5000, }, "1": { - BootstrapIP: "76.76.2.11", + BootstrapIP: FreeDNSBoostrapIP, Name: "Control D - No Ads", Type: ResolverTypeDOQ, Endpoint: "p2.freedns.controld.com", @@ -179,26 +188,27 @@ func (c *Config) FirstUpstream() *UpstreamConfig { // ServiceConfig specifies the general ctrld config. type ServiceConfig struct { - LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` - LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` - CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` - CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` - CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` - CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` - MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` - DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` - DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` - DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` - DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"` - DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` - DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` - DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` - DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"` - ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"` - MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"` - MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"` - Daemon bool `mapstructure:"-" toml:"-"` - AllocateIP bool `mapstructure:"-" toml:"-"` + LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` + LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` + CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` + CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` + CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` + CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` + CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"` + MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` + DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` + DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` + DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` + DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"` + DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` + DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` + DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` + DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"` + ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"` + MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"` + MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"` + Daemon bool `mapstructure:"-" toml:"-"` + AllocateIP bool `mapstructure:"-" toml:"-"` } // NetworkConfig specifies configuration for networks where ctrld will handle requests. @@ -285,6 +295,7 @@ type Rule map[string][]string // Init initialized necessary values for an UpstreamConfig. func (uc *UpstreamConfig) Init() { + uc.initDoHScheme() uc.uid = upstreamUID() if u, err := url.Parse(uc.Endpoint); err == nil { uc.Domain = u.Host @@ -510,35 +521,55 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { // Ping warms up the connection to DoH/DoH3 upstream. func (uc *UpstreamConfig) Ping() { + _ = uc.ping() +} + +// ErrorPing is like Ping, but return an error if any. +func (uc *UpstreamConfig) ErrorPing() error { + return uc.ping() +} + +func (uc *UpstreamConfig) ping() error { switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: default: - return + return nil } - ping := func(t http.RoundTripper) { + ping := func(t http.RoundTripper) error { if t == nil { - return + return nil } ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - req, _ := http.NewRequestWithContext(ctx, "HEAD", uc.Endpoint, nil) - resp, _ := t.RoundTrip(req) - if resp == nil { - return + req, err := http.NewRequestWithContext(ctx, "HEAD", uc.Endpoint, nil) + if err != nil { + return err + } + resp, err := t.RoundTrip(req) + if err != nil { + return err } defer resp.Body.Close() _, _ = io.Copy(io.Discard, resp.Body) + return nil } for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} { switch uc.Type { case ResolverTypeDOH: - ping(uc.dohTransport(typ)) + + if err := ping(uc.dohTransport(typ)); err != nil { + return err + } case ResolverTypeDOH3: - ping(uc.doh3Transport(typ)) + if err := ping(uc.doh3Transport(typ)); err != nil { + return err + } } } + + return nil } func (uc *UpstreamConfig) isControlD() bool { @@ -631,6 +662,18 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { return "tcp-tls", "udp" } +// initDoHScheme initializes the endpoint scheme for DoH/DoH3 upstream if not present. +func (uc *UpstreamConfig) initDoHScheme() { + switch uc.Type { + case ResolverTypeDOH, ResolverTypeDOH3: + default: + return + } + if !strings.HasPrefix(uc.Endpoint, "https://") { + uc.Endpoint = "https://" + uc.Endpoint + } +} + // Init initialized necessary values for an ListenerConfig. func (lc *ListenerConfig) Init() { if lc.Policy != nil { @@ -683,6 +726,7 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) { return } + uc.initDoHScheme() // DoH/DoH3 requires endpoint is an HTTP url. if uc.Type == ResolverTypeDOH || uc.Type == ResolverTypeDOH3 { u, err := url.Parse(uc.Endpoint) @@ -690,10 +734,6 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) { sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "") return } - if u.Scheme != "http" && u.Scheme != "https" { - sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "") - return - } } } diff --git a/config_test.go b/config_test.go index d66556f..83a1e13 100644 --- a/config_test.go +++ b/config_test.go @@ -1,6 +1,7 @@ package ctrld_test import ( + "fmt" "os" "strings" "testing" @@ -102,6 +103,8 @@ func TestConfigValidation(t *testing.T) { {"invalid lease file format", configWithInvalidLeaseFileFormat(t), true}, {"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true}, {"invalid client id pref", configWithInvalidClientIDPref(t), true}, + {"doh endpoint without scheme", dohUpstreamEndpointWithoutScheme(t), false}, + {"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true}, } for _, tc := range tests { @@ -167,6 +170,12 @@ func invalidUpstreamType(t *testing.T) *ctrld.Config { return cfg } +func dohUpstreamEndpointWithoutScheme(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Upstream["0"].Endpoint = "freedns.controld.com/p1" + return cfg +} + func invalidUpstreamTimeout(t *testing.T) *ctrld.Config { cfg := defaultConfig(t) cfg.Upstream["0"].Timeout = -1 @@ -258,7 +267,7 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config { func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config { cfg := defaultConfig(t) - cfg.Upstream["0"].Endpoint = "1.1.1.1" + cfg.Upstream["0"].Endpoint = "/1.1.1.1" cfg.Upstream["0"].Type = ctrld.ResolverTypeDOH return cfg } @@ -268,3 +277,12 @@ func configWithInvalidClientIDPref(t *testing.T) *ctrld.Config { cfg.Service.ClientIDPref = "foo" return cfg } + +func configWithInvalidFlushCacheDomain(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Service.CacheFlushDomains = make([]string, 257) + for i := range cfg.Service.CacheFlushDomains { + cfg.Service.CacheFlushDomains[i] = fmt.Sprintf("%d.com", i) + } + return cfg +} diff --git a/docs/config.md b/docs/config.md index 5d099ea..d9c1dae 100644 --- a/docs/config.md +++ b/docs/config.md @@ -157,6 +157,13 @@ stale cached records (regardless of their TTLs) until upstream comes online. - Required: no - Default: false +### cache_flush_domains +When `ctrld` receives query with domain name in `cache_flush_domains`, the local cache will be discarded +before serving the query. + +- Type: array of strings +- Required: no + ### max_concurrent_requests The number of concurrent requests that will be handled, must be a non-negative integer. Tweaking this value depends on the capacity of your system. @@ -220,7 +227,7 @@ DHCP leases file format. - Type: string - Required: no -- Valid values: `dnsmasq`, `isc-dhcp` +- Valid values: `dnsmasq`, `isc-dhcp`, `kea-dhcp4` - Default: "" ### client_id_preference @@ -531,7 +538,7 @@ And within each policy, the rules are processed from top to bottom. ### failover_rcodes For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`. -- Type: array of string +- Type: array of strings - Required: no - Default: [] - diff --git a/doh.go b/doh.go index 239fd6f..bddc583 100644 --- a/doh.go +++ b/doh.go @@ -60,17 +60,10 @@ func init() { } } -// TODO: use sync.OnceValue when upgrading to go1.21 -var xCdOsValueOnce sync.Once -var xCdOsValue string - -func dohOsHeaderValue() string { - xCdOsValueOnce.Do(func() { - oi := osinfo.New() - xCdOsValue = strings.Join([]string{EncodeOsNameMap[runtime.GOOS], EncodeArchNameMap[runtime.GOARCH], oi.Dist}, "-") - }) - return xCdOsValue -} +var dohOsHeaderValue = sync.OnceValue(func() string { + oi := osinfo.New() + return strings.Join([]string{EncodeOsNameMap[runtime.GOOS], EncodeArchNameMap[runtime.GOARCH], oi.Dist}, "-") +})() func newDohResolver(uc *UpstreamConfig) *dohResolver { r := &dohResolver{ @@ -172,7 +165,6 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) { // newControlDHeaders returns DoH/Doh3 HTTP request headers for ControlD upstream. func newControlDHeaders(ci *ClientInfo) http.Header { header := make(http.Header) - header.Set(dohOsHeader, dohOsHeaderValue()) if ci.Mac != "" { header.Set(dohMacHeader, ci.Mac) } @@ -183,7 +175,7 @@ func newControlDHeaders(ci *ClientInfo) http.Header { header.Set(dohHostHeader, ci.Hostname) } if ci.Self { - header.Set(dohOsHeader, dohOsHeaderValue()) + header.Set(dohOsHeader, dohOsHeaderValue) } switch ci.ClientIDPref { case "mac": diff --git a/doh_test.go b/doh_test.go index d233498..8d3e011 100644 --- a/doh_test.go +++ b/doh_test.go @@ -6,7 +6,7 @@ import ( ) func Test_dohOsHeaderValue(t *testing.T) { - val := dohOsHeaderValue() + val := dohOsHeaderValue if val == "" { t.Fatalf("empty %s", dohOsHeader) } diff --git a/go.mod b/go.mod index 0476717..bfe6060 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/Control-D-Inc/ctrld go 1.21 require ( + github.com/Masterminds/semver v1.5.0 github.com/coreos/go-systemd/v22 v22.5.0 github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf github.com/frankban/quicktest v1.14.5 @@ -17,26 +18,28 @@ require ( github.com/kardianos/service v1.2.1 github.com/mdlayher/ndp v1.0.1 github.com/miekg/dns v1.1.55 + github.com/minio/selfupdate v0.6.0 github.com/olekukonko/tablewriter v0.0.5 github.com/pelletier/go-toml/v2 v2.0.8 github.com/prometheus/client_golang v1.15.1 github.com/prometheus/client_model v0.4.0 github.com/prometheus/prom2json v1.3.3 - github.com/quic-go/quic-go v0.41.0 + github.com/quic-go/quic-go v0.42.0 github.com/rs/zerolog v1.28.0 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.16.0 github.com/stretchr/testify v1.8.3 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/net v0.17.0 + golang.org/x/net v0.23.0 golang.org/x/sync v0.2.0 - golang.org/x/sys v0.13.0 + golang.org/x/sys v0.18.0 golang.zx2c4.com/wireguard/windows v0.5.3 tailscale.com v1.44.0 ) require ( + aead.dev/minisign v0.2.0 // indirect github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect @@ -77,14 +80,14 @@ require ( github.com/subosito/gotenv v1.4.2 // indirect github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 // indirect github.com/vishvananda/netns v0.0.4 // indirect - go.uber.org/mock v0.3.0 // indirect + go.uber.org/mock v0.4.0 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect - golang.org/x/crypto v0.14.0 // indirect + golang.org/x/crypto v0.21.0 // indirect golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect golang.org/x/mod v0.11.0 // indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.9.1 // indirect - google.golang.org/protobuf v1.30.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 6ab5340..22f00e9 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +aead.dev/minisign v0.2.0 h1:kAWrq/hBRu4AARY6AlciO83xhNnW9UaC8YipS2uhLPk= +aead.dev/minisign v0.2.0/go.mod h1:zdq6LdSd9TbuSxchxwhpA9zEb9YXcVGoE8JakuiGaIQ= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= @@ -38,6 +40,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be h1:qBKVRi7Mom5heOkyZ+NCIu9HZBiNCsRqrRe5t9pooik= github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 h1:Kk6a4nehpJ3UuJRqlA3JxYxBZEqCeOmATOvrbT4p9RA= @@ -222,6 +226,8 @@ github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo= github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY= +github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU= +github.com/minio/selfupdate v0.6.0/go.mod h1:bO02GTIPCMQFTEvE5h4DjYB58bCoZ35XLeBf0buTDdM= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= @@ -253,8 +259,8 @@ github.com/prometheus/prom2json v1.3.3 h1:IYfSMiZ7sSOfliBoo89PcufjWO4eAR0gznGcET github.com/prometheus/prom2json v1.3.3/go.mod h1:Pv4yIPktEkK7btWsrUTWDDDrnpUrAELaOCj+oFwlgmc= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k= -github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA= +github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM= +github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -310,8 +316,8 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= -go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= -go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -319,11 +325,13 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -394,8 +402,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -429,6 +437,7 @@ golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -454,6 +463,7 @@ golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -465,8 +475,9 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -476,11 +487,13 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -626,8 +639,8 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 1a775ea..780334b 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -14,6 +14,11 @@ import ( "github.com/Control-D-Inc/ctrld/internal/controld" ) +const ( + ipV4Loopback = "127.0.0.1" + ipv6Loopback = "::1" +) + // IpResolver is the interface for retrieving IP from Mac. type IpResolver interface { fmt.Stringer @@ -224,6 +229,7 @@ func (t *Table) init() { cancel() }() go t.ndp.listen(ctx) + go t.ndp.subscribe(ctx) } // PTR lookup. if t.discoverPTR() { @@ -321,6 +327,16 @@ func (t *Table) LookupRFC1918IPv4(mac string) string { return "" } +// LocalHostname returns the localhost hostname associated with loopback IP. +func (t *Table) LocalHostname() string { + for _, ip := range []string{ipV4Loopback, ipv6Loopback} { + if name := t.LookupHostname(ip, ""); name != "" { + return name + } + } + return "" +} + type macEntry struct { mac string src string diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index 9d1f339..147ad29 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -353,8 +353,8 @@ func (d *dhcp) addSelf() { return } hostname = normalizeHostname(hostname) - d.ip2name.Store("127.0.0.1", hostname) - d.ip2name.Store("::1", hostname) + d.ip2name.Store(ipV4Loopback, hostname) + d.ip2name.Store(ipv6Loopback, hostname) found := false interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) { mac := i.HardwareAddr.String() @@ -375,15 +375,17 @@ func (d *dhcp) addSelf() { d.mac.Store(ip.String(), mac) d.ip.Store(mac, ip.String()) if ip.To4() != nil { - d.mac.Store("127.0.0.1", mac) + d.mac.Store(ipV4Loopback, mac) } else { - d.mac.Store("::1", mac) + d.mac.Store(ipv6Loopback, mac) } d.mac2name.Store(mac, hostname) d.ip2name.Store(ip.String(), hostname) // If we have self IP set, and this IP is it, use this IP only. if ip.String() == d.selfIP { found = true + d.mac.Store(ipV4Loopback, mac) + d.mac.Store(ipv6Loopback, mac) } } }) diff --git a/internal/clientinfo/hostsfile.go b/internal/clientinfo/hostsfile.go index c758f3b..d96229d 100644 --- a/internal/clientinfo/hostsfile.go +++ b/internal/clientinfo/hostsfile.go @@ -95,7 +95,7 @@ func (hf *hostsFile) LookupHostnameByIP(ip string) string { hf.mu.Lock() defer hf.mu.Unlock() if names := hf.m[ip]; len(names) > 0 { - isLoopback := ip == "127.0.0.1" || ip == "::1" + isLoopback := ip == ipV4Loopback || ip == ipv6Loopback for _, hostname := range names { name := normalizeHostname(hostname) // Ignoring ipv4/ipv6 loopback entry. diff --git a/internal/clientinfo/ndp.go b/internal/clientinfo/ndp.go index 600b54c..9d9155d 100644 --- a/internal/clientinfo/ndp.go +++ b/internal/clientinfo/ndp.go @@ -15,6 +15,7 @@ import ( "github.com/mdlayher/ndp" "github.com/Control-D-Inc/ctrld" + ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) // ndpDiscover provides client discovery functionality using NDP protocol. @@ -69,15 +70,45 @@ func (nd *ndpDiscover) List() []string { return ips } +// saveInfo saves ip and mac info to mapping table. +func (nd *ndpDiscover) saveInfo(ip, mac string) { + ip = normalizeIP(ip) + // Store ip => map mapping, + nd.mac.Store(ip, mac) + + // Do not store mac => ip mapping if new ip is a link local unicast. + if ctrldnet.IsLinkLocalUnicastIPv6(ip) { + return + } + + // If there is old ip => mac mapping, delete it. + if old, existed := nd.ip.Load(mac); existed { + oldIP := old.(string) + if oldIP != ip { + nd.mac.Delete(oldIP) + } + } + // Store mac => ip mapping. + nd.ip.Store(mac, ip) +} + // listen listens on ipv6 link local for Neighbor Solicitation message // to update new neighbors information to ndp table. func (nd *ndpDiscover) listen(ctx context.Context) { - ifi, err := firstInterfaceWithV6LinkLocal() + ifis, err := allInterfacesWithV6LinkLocal() if err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("failed to find valid ipv6") + ctrld.ProxyLogger.Load().Debug().Err(err).Msg("failed to find valid ipv6 interfaces") return } - c, ip, err := ndp.Listen(ifi, ndp.LinkLocal) + for _, ifi := range ifis { + go func(ifi *net.Interface) { + nd.listenOnInterface(ctx, ifi) + }(ifi) + } +} + +func (nd *ndpDiscover) listenOnInterface(ctx context.Context, ifi *net.Interface) { + c, ip, err := ndp.Listen(ifi, ndp.Unspecified) if err != nil { ctrld.ProxyLogger.Load().Debug().Err(err).Msg("ndp listen failed") return @@ -111,8 +142,7 @@ func (nd *ndpDiscover) listen(ctx context.Context) { for _, opt := range am.Options { if lla, ok := opt.(*ndp.LinkLayerAddress); ok { mac := lla.Addr.String() - nd.mac.Store(fromIP, mac) - nd.ip.Store(mac, fromIP) + nd.saveInfo(fromIP, mac) } } } @@ -127,8 +157,7 @@ func (nd *ndpDiscover) scanWindows(r io.Reader) { continue } if mac := parseMAC(fields[1]); mac != "" { - nd.mac.Store(fields[0], mac) - nd.ip.Store(mac, fields[0]) + nd.saveInfo(fields[0], mac) } } } @@ -147,8 +176,7 @@ func (nd *ndpDiscover) scanUnix(r io.Reader) { if idx := strings.IndexByte(ip, '%'); idx != -1 { ip = ip[:idx] } - nd.mac.Store(ip, mac) - nd.ip.Store(mac, ip) + nd.saveInfo(ip, mac) } } } @@ -183,14 +211,15 @@ func parseMAC(mac string) string { return hw.String() } -// firstInterfaceWithV6LinkLocal returns the first interface which is capable of using NDP. -func firstInterfaceWithV6LinkLocal() (*net.Interface, error) { +// allInterfacesWithV6LinkLocal returns all interfaces which is capable of using NDP. +func allInterfacesWithV6LinkLocal() ([]*net.Interface, error) { ifis, err := net.Interfaces() if err != nil { return nil, err } - + res := make([]*net.Interface, 0, len(ifis)) for _, ifi := range ifis { + ifi := ifi // Skip if iface is down/loopback/non-multicast. if ifi.Flags&net.FlagUp == 0 || ifi.Flags&net.FlagLoopback != 0 || ifi.Flags&net.FlagMulticast == 0 { continue @@ -211,9 +240,10 @@ func firstInterfaceWithV6LinkLocal() (*net.Interface, error) { return nil, fmt.Errorf("invalid ip address: %s", ipNet.String()) } if ip.Is6() && !ip.Is4In6() { - return &ifi, nil + res = append(res, &ifi) + break } } } - return nil, errors.New("no interface can be used") + return res, nil } diff --git a/internal/clientinfo/ndp_linux.go b/internal/clientinfo/ndp_linux.go index 713a7e3..ebd416f 100644 --- a/internal/clientinfo/ndp_linux.go +++ b/internal/clientinfo/ndp_linux.go @@ -1,7 +1,10 @@ package clientinfo import ( + "context" + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" "github.com/Control-D-Inc/ctrld" ) @@ -15,10 +18,47 @@ func (nd *ndpDiscover) scan() { } for _, n := range neighs { + // Skipping non-reachable neighbors. + if n.State&netlink.NUD_REACHABLE == 0 { + continue + } ip := n.IP.String() mac := n.HardwareAddr.String() - nd.mac.Store(ip, mac) - nd.ip.Store(mac, ip) + nd.saveInfo(ip, mac) + } +} + +// subscribe watches NDP table changes and update new information to local table. +func (nd *ndpDiscover) subscribe(ctx context.Context) { + ch := make(chan netlink.NeighUpdate) + done := make(chan struct{}) + defer close(done) + if err := netlink.NeighSubscribe(ch, done); err != nil { + ctrld.ProxyLogger.Load().Err(err).Msg("could not perform neighbor subscribing") + return + } + for { + select { + case <-ctx.Done(): + return + case nu := <-ch: + if nu.Family != netlink.FAMILY_V6 { + continue + } + ip := normalizeIP(nu.IP.String()) + if nu.Type == unix.RTM_DELNEIGH { + ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor: %s", ip) + nd.mac.Delete(ip) + continue + } + mac := nu.HardwareAddr.String() + switch nu.State { + case netlink.NUD_REACHABLE: + nd.saveInfo(ip, mac) + case netlink.NUD_FAILED: + ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor with failed state: %s", ip) + nd.mac.Delete(ip) + } + } } - } diff --git a/internal/clientinfo/ndp_others.go b/internal/clientinfo/ndp_others.go index 05ac322..007407b 100644 --- a/internal/clientinfo/ndp_others.go +++ b/internal/clientinfo/ndp_others.go @@ -4,6 +4,7 @@ package clientinfo import ( "bytes" + "context" "os/exec" "runtime" @@ -29,3 +30,7 @@ func (nd *ndpDiscover) scan() { nd.scanUnix(bytes.NewReader(data)) } } + +// subscribe watches NDP table changes and update new information to local table. +// This is a stub method, and only works on Linux at this moment. +func (nd *ndpDiscover) subscribe(ctx context.Context) {} diff --git a/internal/clientinfo/ndp_test.go b/internal/clientinfo/ndp_test.go index c8cd398..ca924b9 100644 --- a/internal/clientinfo/ndp_test.go +++ b/internal/clientinfo/ndp_test.go @@ -45,20 +45,22 @@ ff02::c 33-33-00-00-00-0c Permanent nd.scanWindows(r) count := 0 + expectedCount := 6 nd.mac.Range(func(key, value any) bool { count++ return true }) - if count != 6 { - t.Errorf("unexpected count, want 6, got: %d", count) + if count != expectedCount { + t.Errorf("unexpected count, want %d, got: %d", expectedCount, count) } count = 0 + expectedCount = 4 nd.ip.Range(func(key, value any) bool { count++ return true }) - if count != 5 { - t.Errorf("unexpected count, want 5, got: %d", count) + if count != expectedCount { + t.Errorf("unexpected count, want %d, got: %d", expectedCount, count) } } diff --git a/internal/dnscache/cache.go b/internal/dnscache/cache.go index 4aa7f69..af8883e 100644 --- a/internal/dnscache/cache.go +++ b/internal/dnscache/cache.go @@ -12,6 +12,7 @@ import ( type Cacher interface { Get(Key) *Value Add(Key, *Value) + Purge() } // Key is the caching key for DNS message. @@ -34,15 +35,22 @@ type LRUCache struct { cacher *lru.ARCCache[Key, *Value] } +// Get looks up key's value from cache. func (l *LRUCache) Get(key Key) *Value { v, _ := l.cacher.Get(key) return v } +// Add adds a value to cache. func (l *LRUCache) Add(key Key, value *Value) { l.cacher.Add(key, value) } +// Purge clears the cache. +func (l *LRUCache) Purge() { + l.cacher.Purge() +} + // NewLRUCache creates a new LRUCache instance with given size. func NewLRUCache(size int) (*LRUCache, error) { cacher, err := lru.NewARC[Key, *Value](size) diff --git a/internal/net/net.go b/internal/net/net.go index 770c3db..3a81849 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -115,6 +115,15 @@ func IsIPv6(ip string) bool { return parsedIP != nil && parsedIP.To4() == nil && parsedIP.To16() != nil } +// IsLinkLocalUnicastIPv6 checks if the provided IP is a link local unicast v6 address. +func IsLinkLocalUnicastIPv6(ip string) bool { + parsedIP := net.ParseIP(ip) + if parsedIP == nil || parsedIP.To4() != nil || parsedIP.To16() == nil { + return false + } + return parsedIP.To16().IsLinkLocalUnicast() +} + type parallelDialerResult struct { conn net.Conn err error diff --git a/internal/router/dnsmasq/dnsmasq.go b/internal/router/dnsmasq/dnsmasq.go index c2f8845..55c62e8 100644 --- a/internal/router/dnsmasq/dnsmasq.go +++ b/internal/router/dnsmasq/dnsmasq.go @@ -10,6 +10,8 @@ import ( "github.com/Control-D-Inc/ctrld" ) +const CtrldMarker = `# GENERATED BY ctrld - DO NOT MODIFY` + const ConfigContentTmpl = `# GENERATED BY ctrld - DO NOT MODIFY no-resolv {{- range .Upstreams}} diff --git a/internal/router/netgear_orbi_voxel/procd.go b/internal/router/netgear_orbi_voxel/procd.go new file mode 100644 index 0000000..750a17d --- /dev/null +++ b/internal/router/netgear_orbi_voxel/procd.go @@ -0,0 +1,22 @@ +package netgear + +const openWrtScript = `#!/bin/sh /etc/rc.common +USE_PROCD=1 +# After dnsmasq starts +START=61 +# Before network stops +STOP=89 +cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}" +name="{{.Name}}" +pid_file="/var/run/${name}.pid" + +start_service() { + echo "Starting ${name}" + procd_open_instance + procd_set_param command ${cmd} + procd_set_param respawn # respawn automatically if something died + procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop + procd_close_instance + echo "${name} has been started" +} +` diff --git a/internal/router/netgear_orbi_voxel/voxel.go b/internal/router/netgear_orbi_voxel/voxel.go new file mode 100644 index 0000000..4338f9c --- /dev/null +++ b/internal/router/netgear_orbi_voxel/voxel.go @@ -0,0 +1,220 @@ +package netgear + +import ( + "bufio" + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/kardianos/service" + + "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" + "github.com/Control-D-Inc/ctrld/internal/router/nvram" +) + +const ( + Name = "netgear_orbi_voxel" + netgearOrbiVoxelDNSMasqConfigPath = "/etc/dnsmasq.conf" + netgearOrbiVoxelHomedir = "/mnt/bitdefender" + netgearOrbiVoxelStartupScript = "/mnt/bitdefender/rc.user" + netgearOrbiVoxelStartupScriptBackup = "/mnt/bitdefender/rc.user.bak" + netgearOrbiVoxelStartupScriptMarker = "\n# GENERATED BY ctrld" +) + +var nvramKvMap = map[string]string{ + "dns_hijack": "0", // Disable dns hijacking +} + +type NetgearOrbiVoxel struct { + cfg *ctrld.Config +} + +// New returns a router.Router for configuring/setup/run ctrld on ddwrt routers. +func New(cfg *ctrld.Config) *NetgearOrbiVoxel { + return &NetgearOrbiVoxel{cfg: cfg} +} + +func (d *NetgearOrbiVoxel) ConfigureService(svc *service.Config) error { + if err := d.checkInstalledDir(); err != nil { + return err + } + svc.Option["SysvScript"] = openWrtScript + return nil +} + +func (d *NetgearOrbiVoxel) Install(_ *service.Config) error { + // Ignoring error here at this moment is ok, since everything will be wiped out on reboot. + _ = exec.Command("/etc/init.d/ctrld", "enable").Run() + if err := d.checkInstalledDir(); err != nil { + return err + } + if err := backupVoxelStartupScript(); err != nil { + return fmt.Errorf("backup startup script: %w", err) + } + if err := writeVoxelStartupScript(); err != nil { + return fmt.Errorf("writing startup script: %w", err) + } + return nil +} + +func (d *NetgearOrbiVoxel) Uninstall(_ *service.Config) error { + if err := os.Remove(netgearOrbiVoxelStartupScript); err != nil && !os.IsNotExist(err) { + return err + } + err := os.Rename(netgearOrbiVoxelStartupScriptBackup, netgearOrbiVoxelStartupScript) + if err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +func (d *NetgearOrbiVoxel) PreRun() error { + return nil +} + +func (d *NetgearOrbiVoxel) Setup() error { + if d.cfg.FirstListener().IsDirectDnsListener() { + return nil + } + // Already setup. + if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" { + return nil + } + + data, err := dnsmasq.ConfTmplWithCacheDisabled(dnsmasq.ConfigContentTmpl, d.cfg, false) + if err != nil { + return err + } + currentConfig, _ := os.ReadFile(netgearOrbiVoxelDNSMasqConfigPath) + configContent := append(currentConfig, data...) + if err := os.WriteFile(netgearOrbiVoxelDNSMasqConfigPath, configContent, 0600); err != nil { + return err + } + // Restart dnsmasq service. + if err := restartDNSMasq(); err != nil { + return err + } + + if err := nvram.SetKV(nvramKvMap, nvram.CtrldSetupKey); err != nil { + return err + } + + return nil +} + +func (d *NetgearOrbiVoxel) Cleanup() error { + if d.cfg.FirstListener().IsDirectDnsListener() { + return nil + } + if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val != "1" { + return nil // was restored, nothing to do. + } + + // Restore old configs. + if err := nvram.Restore(nvramKvMap, nvram.CtrldSetupKey); err != nil { + return err + } + + // Restore dnsmasq config. + if err := restoreDnsmasqConf(); err != nil { + return err + } + + // Restart dnsmasq service. + if err := restartDNSMasq(); err != nil { + return err + } + return nil +} + +// checkInstalledDir checks that ctrld binary was installed in the correct directory. +func (d *NetgearOrbiVoxel) checkInstalledDir() error { + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("checkHomeDir: failed to get binary path %w", err) + } + if !strings.HasSuffix(filepath.Dir(exePath), netgearOrbiVoxelHomedir) { + return fmt.Errorf("checkHomeDir: could not install service outside %s", netgearOrbiVoxelHomedir) + } + return nil +} + +// backupVoxelStartupScript creates a backup of original startup script if existed. +func backupVoxelStartupScript() error { + // Do nothing if the startup script was modified by ctrld. + script, _ := os.ReadFile(netgearOrbiVoxelStartupScript) + if bytes.Contains(script, []byte(netgearOrbiVoxelStartupScriptMarker)) { + return nil + } + err := os.Rename(netgearOrbiVoxelStartupScript, netgearOrbiVoxelStartupScriptBackup) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("backupVoxelStartupScript: %w", err) + } + return nil +} + +// writeVoxelStartupScript writes startup script to re-install ctrld upon reboot. +// See: https://github.com/SVoxel/ORBI-RBK50/pull/7 +func writeVoxelStartupScript() error { + exe, err := os.Executable() + if err != nil { + return fmt.Errorf("configure service: failed to get binary path %w", err) + } + // This is called when "ctrld start ..." runs, so recording + // the same command line arguments to use in startup script. + argStr := strings.Join(os.Args[1:], " ") + script, _ := os.ReadFile(netgearOrbiVoxelStartupScriptBackup) + script = append(script, fmt.Sprintf("%s\n%q %s\n", netgearOrbiVoxelStartupScriptMarker, exe, argStr)...) + f, err := os.Create(netgearOrbiVoxelStartupScript) + if err != nil { + return fmt.Errorf("failed to create startup script: %w", err) + } + defer f.Close() + + if _, err := f.Write(script); err != nil { + return fmt.Errorf("failed to write startup script: %w", err) + } + if err := f.Close(); err != nil { + return fmt.Errorf("failed to save startup script: %w", err) + } + return nil +} + +// restoreDnsmasqConf restores original dnsmasq configuration. +func restoreDnsmasqConf() error { + f, err := os.Open(netgearOrbiVoxelDNSMasqConfigPath) + if err != nil { + return err + } + defer f.Close() + + var bs []byte + buf := bytes.NewBuffer(bs) + + removed := false + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if line == dnsmasq.CtrldMarker { + removed = true + } + if !removed { + _, err := buf.WriteString(line + "\n") + if err != nil { + return err + } + } + } + return os.WriteFile(netgearOrbiVoxelDNSMasqConfigPath, buf.Bytes(), 0644) +} + +func restartDNSMasq() error { + if out, err := exec.Command("/etc/init.d/dnsmasq", "restart").CombinedOutput(); err != nil { + return fmt.Errorf("restartDNSMasq: %s, %w", string(out), err) + } + return nil +} diff --git a/internal/router/os_config_freebsd.go b/internal/router/os_config_freebsd.go new file mode 100644 index 0000000..9066191 --- /dev/null +++ b/internal/router/os_config_freebsd.go @@ -0,0 +1,40 @@ +package router + +import ( + "encoding/xml" + "os" +) + +// Config represents /conf/config.xml file found on pfsense/opnsense. +type Config struct { + PfsenseUnbound *string `xml:"unbound>enable,omitempty"` + OPNsenseUnbound *string `xml:"OPNsense>unboundplus>general>enabled,omitempty"` + Dnsmasq *string `xml:"dnsmasq>enable,omitempty"` +} + +// DnsmasqEnabled reports whether dnsmasq is enabled. +func (c *Config) DnsmasqEnabled() bool { + if isPfsense() { // pfsense only set the attribute if dnsmasq is enabled. + return c.Dnsmasq != nil + } + return c.Dnsmasq != nil && *c.Dnsmasq == "1" +} + +// UnboundEnabled reports whether unbound is enabled. +func (c *Config) UnboundEnabled() bool { + if isPfsense() { // pfsense only set the attribute if unbound is enabled. + return c.PfsenseUnbound != nil + } + return c.OPNsenseUnbound != nil && *c.OPNsenseUnbound == "1" +} + +// currentConfig does unmarshalling /conf/config.xml file, +// return the corresponding *Config represent it. +func currentConfig() (*Config, error) { + buf, _ := os.ReadFile("/conf/config.xml") + c := Config{} + if err := xml.Unmarshal(buf, &c); err != nil { + return nil, err + } + return &c, nil +} diff --git a/internal/router/os_freebsd.go b/internal/router/os_freebsd.go index c38eebc..9a79188 100644 --- a/internal/router/os_freebsd.go +++ b/internal/router/os_freebsd.go @@ -111,8 +111,16 @@ func (or *osRouter) Setup() error { func (or *osRouter) Cleanup() error { if or.cdMode { - _ = exec.Command(unboundRcPath, "onerestart").Run() - _ = exec.Command(dnsmasqRcPath, "onerestart").Run() + c, err := currentConfig() + if err != nil { + return err + } + if c.UnboundEnabled() { + _ = exec.Command(unboundRcPath, "onerestart").Run() + } + if c.DnsmasqEnabled() { + _ = exec.Command(dnsmasqRcPath, "onerestart").Run() + } } return nil } diff --git a/internal/router/router.go b/internal/router/router.go index 2990cd7..18b7a90 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -18,6 +18,7 @@ import ( "github.com/Control-D-Inc/ctrld/internal/router/edgeos" "github.com/Control-D-Inc/ctrld/internal/router/firewalla" "github.com/Control-D-Inc/ctrld/internal/router/merlin" + netgear "github.com/Control-D-Inc/ctrld/internal/router/netgear_orbi_voxel" "github.com/Control-D-Inc/ctrld/internal/router/openwrt" "github.com/Control-D-Inc/ctrld/internal/router/synology" "github.com/Control-D-Inc/ctrld/internal/router/tomato" @@ -66,10 +67,17 @@ func New(cfg *ctrld.Config, cdMode bool) Router { return tomato.New(cfg) case firewalla.Name: return firewalla.New(cfg) + case netgear.Name: + return netgear.New(cfg) } return newOsRouter(cfg, cdMode) } +// IsNetGearOrbi reports whether the router is a Netgear Orbi router. +func IsNetGearOrbi() bool { + return Name() == netgear.Name +} + // IsGLiNet reports whether the router is an GL.iNet router. func IsGLiNet() bool { if Name() != openwrt.Name { @@ -145,7 +153,7 @@ func LocalResolverIP() string { // HomeDir returns the home directory of ctrld on current router. func HomeDir() (string, error) { switch Name() { - case ddwrt.Name, merlin.Name, tomato.Name: + case ddwrt.Name, firewalla.Name, merlin.Name, netgear.Name, tomato.Name: exe, err := os.Executable() if err != nil { return "", err @@ -198,6 +206,9 @@ func distroName() string { case bytes.HasPrefix(unameO(), []byte("ASUSWRT-Merlin")): return merlin.Name case haveFile("/etc/openwrt_version"): + if haveFile("/bin/config") { // TODO: is there any more reliable way? + return netgear.Name + } return openwrt.Name case isUbios(): return ubios.Name