diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3a989ad..b4b44d4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: with: go-version: ${{ matrix.go }} - run: "go test -race ./..." - - uses: dominikh/staticcheck-action@v1.2.0 + - uses: dominikh/staticcheck-action@v1.3.1 with: version: "2024.1.1" install-go: false diff --git a/cmd/cli/cgo.go b/cmd/cli/cgo.go new file mode 100644 index 0000000..9979523 --- /dev/null +++ b/cmd/cli/cgo.go @@ -0,0 +1,5 @@ +//go:build cgo + +package cli + +const cgoEnabled = true diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 9d01206..eb3b910 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -61,7 +61,7 @@ var ( v = viper.NewWithOptions(viper.KeyDelimiter("::")) defaultConfigFile = "ctrld.toml" rootCertPool *x509.CertPool - errSelfCheckNoAnswer = errors.New("no answer from ctrld listener") + errSelfCheckNoAnswer = errors.New("no response from ctrld listener. You can try to re-launch with flag --skip_self_checks") ) var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"} @@ -222,10 +222,16 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { lc := &logConn{conn: conn} consoleWriter.Out = io.MultiWriter(os.Stdout, lc) p.logConn = lc + } else { + mainLog.Load().Warn().Err(err).Msgf("unable to create log ipc connection") } + } else { + mainLog.Load().Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath) } notifyExitToLogServer := func() { - _, _ = p.logConn.Write([]byte(msgExit)) + if p.logConn != nil { + _, _ = p.logConn.Write([]byte(msgExit)) + } } if daemon && runtime.GOOS == "windows" { @@ -266,10 +272,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // Log config do not have thing to validate, so it's safe to init log here, // so it's able to log information in processCDFlags. - logWriters := initLogging() - - // Initializing internal logging after global logging. - p.initInternalLogging(logWriters) + p.initLogging(true) mainLog.Load().Info().Msgf("starting ctrld %s", curVersion()) mainLog.Load().Info().Msgf("os: %s", osVersion()) @@ -322,7 +325,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } } - updated := updateListenerConfig(&cfg) + updated := updateListenerConfig(&cfg, notifyExitToLogServer) if cdUID != "" { processLogAndCacheFlags() @@ -418,7 +421,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if err := p.router.Cleanup(); err != nil { mainLog.Load().Error().Err(err).Msg("could not cleanup router") } - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) }) } } @@ -484,7 +488,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) } nop := zerolog.Nop() - _, _ = tryUpdateListenerConfig(&cfg, &nop, true) + _, _ = tryUpdateListenerConfig(&cfg, &nop, func() {}, true) addExtraSplitDnsRule(&cfg) if err := writeConfigFile(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) @@ -645,11 +649,15 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { // Fetch config, unmarshal to cfg. if resolverConfig.Ctrld.CustomConfig != "" { logger.Info().Msg("using defined custom config of Control-D resolver") - if err := validateCdRemoteConfig(resolverConfig, cfg); err == nil { + var cfgErr error + if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil { setListenerDefaultValue(cfg) - return resolverConfig, nil + setNetworkDefaultValue(cfg) + if cfgErr = validateConfig(cfg); cfgErr == nil { + return resolverConfig, nil + } } - mainLog.Load().Err(err).Msg("disregarding invalid custom config") + mainLog.Load().Warn().Err(err).Msg("disregarding invalid custom config") } bootstrapIP := func(endpoint string) string { @@ -666,11 +674,7 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { } return "" } - cfg.Network = make(map[string]*ctrld.NetworkConfig) - cfg.Network["0"] = &ctrld.NetworkConfig{ - Name: "Network 0", - Cidrs: []string{"0.0.0.0/0"}, - } + cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) cfg.Upstream["0"] = &ctrld.UpstreamConfig{ BootstrapIP: bootstrapIP(resolverConfig.DOH), @@ -693,6 +697,7 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { // Set default value. setListenerDefaultValue(cfg) + setNetworkDefaultValue(cfg) return resolverConfig, nil } @@ -706,7 +711,21 @@ func setListenerDefaultValue(cfg *ctrld.Config) { } } +// setListenerDefaultValue sets the default value for cfg.Listener if none existed. +func setNetworkDefaultValue(cfg *ctrld.Config) { + if len(cfg.Network) == 0 { + cfg.Network = map[string]*ctrld.NetworkConfig{ + "0": { + Name: "Network 0", + Cidrs: []string{"0.0.0.0/0"}, + }, + } + } +} + // validateCdRemoteConfig validates the custom config from ControlD if defined. +// This only validate the config syntax. To validate the config rules, calling +// validateConfig with the cfg after calling this function. func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) error { if rc.Ctrld.CustomConfig == "" { return nil @@ -783,7 +802,13 @@ func defaultIfaceName() string { if oi := osinfo.New(); strings.Contains(oi.String(), "Microsoft") { return "lo" } - mainLog.Load().Fatal().Err(err).Msg("failed to get default route interface") + // On linux, it could be either resolvconf or systemd which is managing DNS settings, + // so the interface name does not matter if there's no default route interface. + if runtime.GOOS == "linux" { + return "lo" + } + mainLog.Load().Debug().Err(err).Msg("no default route interface found") + return "" } return dri } @@ -843,10 +868,12 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo } mainLog.Load().Debug().Msg("ctrld listener is ready") - mainLog.Load().Debug().Msg("performing self-check") lc := cfg.FirstListener() addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)) + + mainLog.Load().Debug().Msgf("performing listener test, sending queries to %s", addr) + if err := selfCheckResolveDomain(context.TODO(), addr, "internal", selfCheckInternalTestDomain); err != nil { return false, status, err } @@ -860,7 +887,7 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain string) error { bo := backoff.NewBackoff("self-check", logf, 10*time.Second) bo.LogLongerThan = 500 * time.Millisecond - maxAttempts := 20 + maxAttempts := 10 c := new(dns.Client) var ( @@ -876,7 +903,7 @@ func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain stri m := new(dns.Msg) m.SetQuestion(domain+".", dns.TypeA) m.RecursionDesired = true - r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, addr) + r, _, exErr := exchangeContextWithTimeout(c, 5*time.Second, m, addr) if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 { mainLog.Load().Debug().Msgf("%s self-check against %q succeeded", scope, domain) return nil @@ -1030,16 +1057,25 @@ func uninstall(p *prog, s service.Service) { mainLog.Load().Warn().Err(err).Msg("post uninstallation failed, please check system/service log for details error") return } - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) - // if present restore the original DNS settings - if netIface, err := netInterface(p.runningIface); err == nil { - if err := restoreDNS(netIface); err != nil { - mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface") - } else { - mainLog.Load().Debug().Msg("Restored DNS on interface successfully") + // Iterate over all physical interfaces and restore DNS if a saved static config exists. + withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error { + file := savedStaticDnsSettingsFilePath(i) + if _, err := os.Stat(file); err == nil { + if err := restoreDNS(i); err != nil { + mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) + } else { + mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) + err = os.Remove(file) + if err != nil { + mainLog.Load().Debug().Err(err).Msgf("Could not remove saved static DNS file for interface %s", i.Name) + } + } } - } + return nil + }) if router.Name() != "" { mainLog.Load().Debug().Msg("Router cleanup") @@ -1146,8 +1182,8 @@ func mobileListenerIp() string { // updateListenerConfig updates the config for listeners if not defined, // or defined but invalid to be used, e.g: using loopback address other // than 127.0.0.1 with systemd-resolved. -func updateListenerConfig(cfg *ctrld.Config) bool { - updated, _ := tryUpdateListenerConfig(cfg, nil, true) +func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool { + updated, _ := tryUpdateListenerConfig(cfg, nil, notifyToLogServerFunc, true) if addExtraSplitDnsRule(cfg) { updated = true } @@ -1157,13 +1193,14 @@ func updateListenerConfig(cfg *ctrld.Config) bool { // tryUpdateListenerConfig tries updating listener config with a working one. // If fatal is true, and there's listen address conflicted, the function do // fatal error. -func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fatal bool) (updated, ok bool) { +func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, notifyFunc func(), fatal bool) (updated, ok bool) { ok = true lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" nextdnsMode := nextdns != "" // For Windows server with local Dns server running, we can only try on random local IP. hasLocalDnsServer := hasLocalDnsServerRunning() + notRouter := router.Name() == "" for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { @@ -1193,6 +1230,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata } updated = updated || lcc[n].IP || lcc[n].Port } + il := mainLog.Load() if infoLogger != nil { il = infoLogger @@ -1277,10 +1315,17 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata tryOldIPPort5354 = false tryPort5354 = false } + // if not running on a router, we should not try to listen on any port other than 53 + // if we do, this will break the dns resolution for the system. + if notRouter { + tryOldIPPort5354 = false + tryPort5354 = false + } attempts := 0 maxAttempts := 10 for { if attempts == maxAttempts { + notifyFunc() logMsg(mainLog.Load().Fatal(), n, "could not find available listen ip and port") } addr := net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)) @@ -1288,8 +1333,12 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata if err == nil { break } + + logMsg(il.Info(), n, "error listening on address: %s, error: %v", addr, err) + if !check.IP && !check.Port { if fatal { + notifyFunc() logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err) } ok = false @@ -1348,14 +1397,17 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata } else { listener.IP = oldIP } - if check.Port { + // if we are not running on a router, we should not try to listen on any port other than 53 + // if we do, this will break the dns resolution for the system. + if check.Port && !notRouter { listener.Port = randomPort() } else { listener.Port = oldPort } if listener.IP == oldIP && listener.Port == oldPort { if fatal { - logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err) + notifyFunc() + logMsg(mainLog.Load().Fatal(), n, "could not listen on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err) } ok = false break @@ -1393,6 +1445,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata } } if !found { + notifyFunc() logMsg(mainLog.Load().Fatal(), n, "could not use %q as DNS nameserver with systemd resolved", listener.IP) } } @@ -1597,7 +1650,7 @@ func doGenerateNextDNSConfig(uid string) error { } mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile) generateNextDNSConfig(uid) - updateListenerConfig(&cfg) + updateListenerConfig(&cfg, func() {}) return writeConfigFile(&cfg) } @@ -1739,9 +1792,14 @@ func goArm() string { // upgradeUrl returns the url for downloading new ctrld binary. func upgradeUrl(baseUrl string) string { dlPath := fmt.Sprintf("%s-%s/ctrld", runtime.GOOS, runtime.GOARCH) + // Use arm version set during build time, v5 binary can be run on higher arm version system. if armVersion := goArm(); armVersion != "" { dlPath = fmt.Sprintf("%s-%sv%s/ctrld", runtime.GOOS, runtime.GOARCH, armVersion) } + // linux/amd64 has nocgo version, to support systems that missing some libc (like openwrt). + if !cgoEnabled && runtime.GOOS == "linux" && runtime.GOARCH == "amd64" { + dlPath = fmt.Sprintf("%s-%s-nocgo/ctrld", runtime.GOOS, runtime.GOARCH) + } dlUrl := fmt.Sprintf("%s/%s", baseUrl, dlPath) if runtime.GOOS == "windows" { dlUrl += ".exe" @@ -1768,50 +1826,6 @@ func runningIface(s service.Service) *ifaceResponse { return nil } -// resetDnsNoLog performs resetting DNS with logging disable. -func resetDnsNoLog(p *prog) { - // Normally, disable log to prevent annoying users. - if verbose < 3 { - lvl := zerolog.GlobalLevel() - zerolog.SetGlobalLevel(zerolog.Disabled) - p.resetDNS() - zerolog.SetGlobalLevel(lvl) - return - } - // For debugging purpose, still emit log. - p.resetDNS() -} - -// resetDnsTask returns a task which perform reset DNS operation. -func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task { - return task{func() error { - if iface == "" { - mainLog.Load().Debug().Msg("no iface, skipping resetDnsTask") - return nil - } - // Always reset DNS first, ensuring DNS setting is in a good state. - // resetDNS must use the "iface" value of current running ctrld - // process to reset what setDNS has done properly. - oldIface := iface - iface = "auto" - p.requiredMultiNICsConfig = requiredMultiNICsConfig() - if ir != nil { - iface = ir.Name - p.requiredMultiNICsConfig = ir.All - } - p.runningIface = iface - if isCtrldInstalled { - mainLog.Load().Debug().Msg("restore system DNS settings") - if status, _ := s.Status(); status == service.StatusRunning { - mainLog.Load().Fatal().Msg("reset DNS while ctrld still running is not safe") - } - resetDnsNoLog(p) - } - iface = oldIface - return nil - }, false, "Reset DNS"} -} - // doValidateCdRemoteConfig fetches and validates custom config for cdUID. func doValidateCdRemoteConfig(cdUID string, fatal bool) error { rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) @@ -1825,10 +1839,22 @@ func doValidateCdRemoteConfig(cdUID string, fatal bool) error { return err } } + + // return earlier if there's no custom config. + if rc.Ctrld.CustomConfig == "" { + return nil + } + // validateCdRemoteConfig clobbers v, saving it here to restore later. oldV := v - if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil { - if errors.As(err, &viper.ConfigParseError{}) { + var cfgErr error + remoteCfg := &ctrld.Config{} + if cfgErr = validateCdRemoteConfig(rc, remoteCfg); cfgErr == nil { + setListenerDefaultValue(remoteCfg) + setNetworkDefaultValue(remoteCfg) + cfgErr = validateConfig(remoteCfg) + } else { + if errors.As(cfgErr, &viper.ConfigParseError{}) { if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 { tmpDir := os.TempDir() tmpConfFile := filepath.Join(tmpDir, "ctrld.toml") @@ -1844,12 +1870,14 @@ func doValidateCdRemoteConfig(cdUID string, fatal bool) error { } // If we could not log details error, emit what we have already got. if !errorLogged { - mainLog.Load().Error().Msgf("failed to parse custom config: %v", err) + mainLog.Load().Error().Msgf("failed to parse custom config: %v", cfgErr) } } } else { mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err) } + } + if cfgErr != nil { mainLog.Load().Warn().Msg("disregarding invalid custom config") } v = oldV @@ -1863,8 +1891,8 @@ func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool { logger.Warn().Err(err).Msg("failed to create new service") return false } - - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) tasks := []task{{s.Uninstall, true, "Uninstall"}} if doTasks(tasks) { diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 49dfb8f..6c0c202 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -205,7 +205,13 @@ func initStartCmd() *cobra.Command { Long: `Install and start the ctrld service NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, - Args: cobra.NoArgs, + Args: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + + "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") + } + return nil + }, Run: func(cmd *cobra.Command, args []string) { checkStrFlagEmpty(cmd, cdUidFlagName) checkStrFlagEmpty(cmd, cdOrgFlagName) @@ -242,6 +248,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c os.Exit(deactivationPinInvalidExitCode) } currentIface = runningIface(s) + mainLog.Load().Debug().Msgf("current interface on start: %v", currentIface) } ctx, cancel := context.WithCancel(context.Background()) @@ -339,13 +346,17 @@ NOTE: running "ctrld start" without any arguments will start already installed c mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) } + // if already running, dont restart + if isCtrldRunning { + mainLog.Load().Notice().Msg("service is already running") + return + } + initInteractiveLogging() tasks := []task{ - {s.Stop, false, "Stop"}, - resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error { if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { return err } @@ -355,7 +366,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c }, false, "Save current DNS"}, {func() error { return ConfigureWindowsServiceFailureActions(ctrldServiceName) - }, false, "Configure Windows service failure actions"}, + }, false, "Configure service failure actions"}, {s.Start, true, "Start"}, {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, } @@ -424,10 +435,10 @@ NOTE: running "ctrld start" without any arguments will start already installed c {s.Stop, false, "Stop"}, {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"}, {func() error { return ensureUninstall(s) }, false, "Ensure uninstall"}, - resetDnsTask(p, s, isCtrldInstalled, currentIface), + //resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error { if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { return err } @@ -451,6 +462,9 @@ NOTE: running "ctrld start" without any arguments will start already installed c return } + // add a small delay to ensure the service is started and did not crash + time.Sleep(1 * time.Second) + ok, status, err := selfCheckStatus(ctx, s, sockDir) switch { case ok && status == service.StatusRunning: @@ -550,6 +564,13 @@ NOTE: running "ctrld start" without any arguments will start already installed c Long: `Quick start service and configure DNS on interface NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, + Args: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + + "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") + } + return nil + }, Run: func(cmd *cobra.Command, args []string) { if len(os.Args) == 2 { startOnly = true @@ -608,16 +629,21 @@ func initStopCmd() *cobra.Command { } if doTasks([]task{{s.Stop, true, "Stop"}}) { p.router.Cleanup() - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) - // restore DNS settings - if netIface, err := netInterface(p.runningIface); err == nil { - if err := restoreDNS(netIface); err != nil { - mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface") - } else { - mainLog.Load().Debug().Msg("Restored DNS on interface successfully") + // Iterate over all physical interfaces and restore static DNS if a saved static config exists. + withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error { + file := savedStaticDnsSettingsFilePath(i) + if _, err := os.Stat(file); err == nil { + if err := restoreDNS(i); err != nil { + mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) + } else { + mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) + } } - } + return nil + }) if router.WaitProcessExited() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) @@ -714,7 +740,8 @@ func initRestartCmd() *cobra.Command { {s.Stop, true, "Stop"}, {func() error { p.router.Cleanup() - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) return nil }, false, "Cleanup"}, {func() error { @@ -994,13 +1021,13 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if os.IsNotExist(err) { continue } - mainLog.Load().Warn().Err(err).Msg("failed to remove file") + mainLog.Load().Warn().Err(err).Msgf("failed to remove file: %s", file) } else { mainLog.Load().Debug().Msgf("file removed: %s", file) } } if err := selfDeleteExe(); err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to remove file") + mainLog.Load().Warn().Err(err).Msg("failed to delete ctrld binary") } else { if !supportedSelfDelete { mainLog.Load().Debug().Msgf("file removed: %s", bin) @@ -1044,9 +1071,16 @@ func initInterfacesCmd() *cobra.Command { Short: "List network interfaces of the host", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + withEachPhysicalInterfaces("", "Interface list", func(i *net.Interface) error { fmt.Printf("Index : %d\n", i.Index) fmt.Printf("Name : %s\n", i.Name) + var status string + if i.Flags&net.FlagUp != 0 { + status = "Up" + } else { + status = "Down" + } + fmt.Printf("Status: %s\n", status) addrs, _ := i.Addrs() for i, ipaddr := range addrs { if i == 0 { @@ -1242,7 +1276,8 @@ func initUpgradeCmd() *cobra.Command { } dlUrl := upgradeUrl(baseUrl) mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl) - resp, err := http.Get(dlUrl) + + resp, err := getWithRetry(dlUrl) if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to download binary") } @@ -1266,7 +1301,8 @@ func initUpgradeCmd() *cobra.Command { {s.Stop, true, "Stop"}, {func() error { p.router.Cleanup() - p.resetDNS() + // restore static DNS settings or DHCP + p.resetDNS(false, true) return nil }, false, "Cleanup"}, {func() error { diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 1ea1693..17f585d 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -79,33 +79,81 @@ func (s *controlServer) register(pattern string, handler http.Handler) { func (p *prog) registerControlServerHandler() { p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + mainLog.Load().Debug().Msg("handling list clients request") + clients := p.ciTable.ListClients() + mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list") + sort.Slice(clients, func(i, j int) bool { return clients[i].IP.Less(clients[j].IP) }) + mainLog.Load().Debug().Msg("sorted clients by IP address") + if p.metricsQueryStats.Load() { - for _, client := range clients { + mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts") + + for idx, client := range clients { + mainLog.Load().Debug(). + Int("index", idx). + Str("ip", client.IP.String()). + Str("mac", client.Mac). + Str("hostname", client.Hostname). + Msg("processing client metrics") + client.IncludeQueryCount = true dm := &dto.Metric{} + + if statsClientQueriesCount.MetricVec == nil { + mainLog.Load().Debug(). + Str("client_ip", client.IP.String()). + Msg("skipping metrics collection: MetricVec is nil") + continue + } + m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues( client.IP.String(), client.Mac, client.Hostname, ) if err != nil { - mainLog.Load().Debug().Err(err).Msgf("could not get metrics for client: %v", client) + mainLog.Load().Debug(). + Err(err). + Str("client_ip", client.IP.String()). + Str("mac", client.Mac). + Str("hostname", client.Hostname). + Msg("failed to get metrics for client") continue } - if err := m.Write(dm); err == nil { + + if err := m.Write(dm); err == nil && dm.Counter != nil { client.QueryCount = int64(dm.Counter.GetValue()) + mainLog.Load().Debug(). + Str("client_ip", client.IP.String()). + Int64("query_count", client.QueryCount). + Msg("successfully collected query count") + } else if err != nil { + mainLog.Load().Debug(). + Err(err). + Str("client_ip", client.IP.String()). + Msg("failed to write metric") } } + } else { + mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts") } if err := json.NewEncoder(w).Encode(&clients); err != nil { + mainLog.Load().Error(). + Err(err). + Int("client_count", len(clients)). + Msg("failed to encode clients response") http.Error(w, err.Error(), http.StatusInternalServerError) return } + + mainLog.Load().Debug(). + Int("client_count", len(clients)). + Msg("successfully sent clients list response") })) p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { select { @@ -250,7 +298,7 @@ func (p *prog) registerControlServerHandler() { } })) p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { - if time.Since(p.internalLogSent) < logSentInterval { + if time.Since(p.internalLogSent) < logWriterSentInterval { w.WriteHeader(http.StatusServiceUnavailable) return } diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 0bc042e..e3dbc26 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -20,12 +20,12 @@ import ( "golang.org/x/sync/errgroup" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" - "tailscale.com/types/logger" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/controld" "github.com/Control-D-Inc/ctrld/internal/dnscache" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" + "github.com/Control-D-Inc/ctrld/internal/router" ) const ( @@ -435,14 +435,17 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { if len(upstreamConfigs) == 0 { upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} upstreams = []string{upstreamOS} - } - - if p.isAdDomainQuery(req.msg) { - ctrld.Log(ctx, mainLog.Load().Debug(), - "AD domain query detected for %s in domain %s", - req.msg.Question[0].Name, p.adDomain) - upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} - upstreams = []string{upstreamOS} + // For OS resolver, local addresses are ignored to prevent possible looping. + // However, on Active Directory Domain Controller, where it has local DNS server + // running and listening on local addresses, these local addresses must be used + // as nameservers, so queries for ADDC could be resolved as expected. + if p.isAdDomainQuery(req.msg) { + ctrld.Log(ctx, mainLog.Load().Debug(), + "AD domain query detected for %s in domain %s", + req.msg.Question[0].Name, p.adDomain) + upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} + upstreams = []string{upstreamOSLocal} + } } res := &proxyResponse{} @@ -458,7 +461,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) } else { switch { - case isSrvLookup(req.msg): + case isSrvLanLookup(req.msg): upstreams = []string{upstreamOS} upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} ctx = ctrld.LanQueryCtx(ctx) @@ -620,7 +623,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) // if we have no healthy upstreams, trigger recovery flow - if p.recoverOnUpstreamFailure() { + if p.leakOnUpstreamFailure() { if p.um.countHealthy(upstreams) == 0 { p.recoveryCancelMu.Lock() if p.recoveryCancel == nil { @@ -639,19 +642,20 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } else { mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger") } - } - // attempt query to OS resolver while as a retry catch all - if upstreams[0] != upstreamOS { - ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all") - answer := resolve(upstreamOS, osUpstreamConfig, req.msg) - if answer != nil { - ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful") - res.answer = answer - res.upstream = osUpstreamConfig.Endpoint - return res + // attempt query to OS resolver while as a retry catch all + // we dont want this to happen if leakOnUpstreamFailure is false + if upstreams[0] != upstreamOS { + ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all") + answer := resolve(upstreamOS, osUpstreamConfig, req.msg) + if answer != nil { + ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful") + res.answer = answer + res.upstream = osUpstreamConfig.Endpoint + return res + } + ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed") } - ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed") } answer := new(dns.Msg) @@ -1108,21 +1112,27 @@ func isLanHostnameQuery(m *dns.Msg) bool { default: return false } - name := strings.TrimSuffix(q.Name, ".") + return isLanHostname(q.Name) +} + +// isSrvLanLookup reports whether DNS message is an SRV query of a LAN hostname. +func isSrvLanLookup(m *dns.Msg) bool { + if m == nil || len(m.Question) == 0 { + return false + } + q := m.Question[0] + return q.Qtype == dns.TypeSRV && isLanHostname(q.Name) +} + +// isLanHostname reports whether name is a LAN hostname. +func isLanHostname(name string) bool { + name = strings.TrimSuffix(name, ".") return !strings.Contains(name, ".") || strings.HasSuffix(name, ".domain") || strings.HasSuffix(name, ".lan") || strings.HasSuffix(name, ".local") } -// isSrvLookup reports whether DNS message is a SRV query. -func isSrvLookup(m *dns.Msg) bool { - if m == nil || len(m.Question) == 0 { - return false - } - return m.Question[0].Qtype == dns.TypeSRV -} - // isWanClient reports whether the input is a WAN address. func isWanClient(na net.Addr) bool { var ip netip.Addr @@ -1177,7 +1187,10 @@ func FlushDNSCache() error { // monitorNetworkChanges starts monitoring for network interface changes func (p *prog) monitorNetworkChanges(ctx context.Context) error { - mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) + mon, err := netmon.New(func(format string, args ...any) { + // Always fetch the latest logger (and inject the prefix) + mainLog.Load().Printf("netmon: "+format, args...) + }) if err != nil { return fmt.Errorf("creating network monitor: %w", err) } @@ -1248,8 +1261,16 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { } } + // if the default route changed, set changed to true + if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface { + changed = true + mainLog.Load().Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface) + } + if !changed { mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected") + // check if the default IPs are still on an interface that is up + ValidateDefaultLocalIPsFromDelta(delta.New) return } @@ -1260,6 +1281,13 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // Get IPs from default route interface in new state selfIP := defaultRouteIP() + + // Ensure that selfIP is an IPv4 address. + // If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it + if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil { + mainLog.Load().Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP) + selfIP = "" + } var ipv6 string if delta.New.DefaultRouteInterface != "" { @@ -1295,7 +1323,8 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { } } - if ip := net.ParseIP(selfIP); ip != nil { + // Only set the IPv4 default if selfIP is a valid IPv4 address. + if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil { ctrld.SetDefaultLocalIPv4(ip) if !isMobile() && p.ciTable != nil { p.ciTable.SetSelfIP(selfIP) @@ -1306,7 +1335,8 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { } mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) - if p.recoverOnUpstreamFailure() { + // we only trigger recovery flow for network changes on non router devices + if router.Name() == "" { p.handleRecovery(RecoveryReasonNetworkChange) } }) @@ -1438,7 +1468,10 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // Immediately remove our DNS settings from the interface. // set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface p.recoveryRunning.Store(true) - p.resetDNS() + // we do not want to restore any static DNS settings + // we must try to get the DHCP values, any static DNS settings + // will be appended to nameservers from the saved interface values + p.resetDNS(false, false) // For an OS failure, reinitialize OS resolver nameservers immediately. if reason == RecoveryReasonOSFailure { @@ -1504,12 +1537,14 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string go func(name string, uc *ctrld.UpstreamConfig) { defer wg.Done() mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name) + attempts := 0 for { select { case <-ctx.Done(): mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name) return default: + attempts++ // checkUpstreamOnce will reset any failure counters on success. if err := p.checkUpstreamOnce(name, uc); err == nil { mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name) @@ -1523,6 +1558,18 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string } mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name) time.Sleep(checkUpstreamBackoffSleep) + + // if this is the upstreamOS and it's the 3rd attempt (or multiple of 3), + // we should try to reinit the OS resolver to ensure we can recover + if name == upstreamOS && attempts%3 == 0 { + mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts) + ns := ctrld.InitializeOsResolver(true) + if len(ns) == 0 { + mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") + } else { + mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + } + } } } }(name, uc) @@ -1556,3 +1603,32 @@ func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.U } return upstreams } + +// ValidateDefaultLocalIPsFromDelta checks if the default local IPv4 and IPv6 stored +// are still present in the new network state (provided by delta.New). +// If a stored default IP is no longer active, it resets that default (sets it to nil) +// so that it won't be used in subsequent custom dialer contexts. +func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) { + currentIPv4 := ctrld.GetDefaultLocalIPv4() + currentIPv6 := ctrld.GetDefaultLocalIPv6() + + // Build a map of active IP addresses from the new state. + activeIPs := make(map[string]bool) + for _, prefixes := range newState.InterfaceIPs { + for _, prefix := range prefixes { + activeIPs[prefix.Addr().String()] = true + } + } + + // Check if the default IPv4 is still active. + if currentIPv4 != nil && !activeIPs[currentIPv4.String()] { + mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", currentIPv4) + ctrld.SetDefaultLocalIPv4(nil) + } + + // Check if the default IPv6 is still active. + if currentIPv6 != nil && !activeIPs[currentIPv6.String()] { + mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. Resetting.", currentIPv6) + ctrld.SetDefaultLocalIPv6(nil) + } +} diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index eae3dfa..4a4e5b4 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -418,20 +418,21 @@ func Test_isPrivatePtrLookup(t *testing.T) { } } -func Test_isSrvLookup(t *testing.T) { +func Test_isSrvLanLookup(t *testing.T) { tests := []struct { name string msg *dns.Msg isSrvLookup bool }{ - {"SRV", newDnsMsgWithHostname("foo", dns.TypeSRV), true}, + {"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true}, {"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false}, + {"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false}, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - if got := isSrvLookup(tc.msg); tc.isSrvLookup != got { + if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got { t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got) } }) diff --git a/cmd/cli/library.go b/cmd/cli/library.go index d302644..a5ba389 100644 --- a/cmd/cli/library.go +++ b/cmd/cli/library.go @@ -1,5 +1,12 @@ package cli +import ( + "fmt" + "net" + "net/http" + "time" +) + // AppCallback provides hooks for injecting certain functionalities // from mobile platforms to main ctrld cli. type AppCallback struct { @@ -17,3 +24,55 @@ type AppConfig struct { Verbose int LogPath string } + +const ( + defaultHTTPTimeout = 30 * time.Second + defaultMaxRetries = 3 +) + +// httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback +func httpClientWithFallback(timeout time.Duration) *http.Client { + return &http.Client{ + Timeout: timeout, + Transport: &http.Transport{ + // Prefer IPv4 over IPv6 + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + FallbackDelay: 1 * time.Millisecond, // Very small delay to prefer IPv4 + }).DialContext, + }, + } +} + +// doWithRetry performs an HTTP request with retries +func doWithRetry(req *http.Request, maxRetries int) (*http.Response, error) { + var lastErr error + client := httpClientWithFallback(defaultHTTPTimeout) + + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(time.Second * time.Duration(attempt+1)) // Exponential backoff + } + + resp, err := client.Do(req) + if err == nil { + return resp, nil + } + lastErr = err + mainLog.Load().Debug().Err(err). + Str("method", req.Method). + Str("url", req.URL.String()). + Msgf("HTTP request attempt %d/%d failed", attempt+1, maxRetries) + } + return nil, fmt.Errorf("failed after %d attempts to %s %s: %v", maxRetries, req.Method, req.URL, lastErr) +} + +// Helper for making GET requests with retries +func getWithRetry(url string) (*http.Response, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + return doWithRetry(req, defaultMaxRetries) +} diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 339d984..ab6b855 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -16,13 +16,12 @@ import ( ) const ( - logWriterSize = 1024 * 1024 * 5 // 5 MB - logWriterSmallSize = 1024 * 1024 * 1 // 1 MB - logWriterInitialSize = 32 * 1024 // 32 KB - logSentInterval = time.Minute - logStartEndMarker = "\n\n=== INIT_END ===\n\n" - logLogEndMarker = "\n\n=== LOG_END ===\n\n" - logWarnEndMarker = "\n\n=== WARN_END ===\n\n" + logWriterSize = 1024 * 1024 * 5 // 5 MB + logWriterSmallSize = 1024 * 1024 * 1 // 1 MB + logWriterInitialSize = 32 * 1024 // 32 KB + logWriterSentInterval = time.Minute + logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n" + logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n" ) type logViewResponse struct { @@ -69,13 +68,23 @@ func (lw *logWriter) Write(p []byte) (int, error) { // If writing p causes overflows, discard old data. if lw.buf.Len()+len(p) > lw.size { buf := lw.buf.Bytes() - buf = buf[:logWriterInitialSize] - if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 { - buf = buf[:idx] + haveEndMarker := false + // If there's init end marker already, preserve the data til the marker. + if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 { + buf = buf[:idx+len(logWriterInitEndMarker)] + haveEndMarker = true + } else { + // Otherwise, preserve the initial size data. + buf = buf[:logWriterInitialSize] + if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 { + buf = buf[:idx] + } } lw.buf.Reset() lw.buf.Write(buf) - lw.buf.WriteString(logStartEndMarker) // indicate that the log was truncated. + if !haveEndMarker { + lw.buf.WriteString(logWriterInitEndMarker) // indicate that the log was truncated. + } } // If p is bigger than buffer size, truncate p by half until its size is smaller. for len(p)+lw.buf.Len() > lw.size { @@ -84,6 +93,15 @@ func (lw *logWriter) Write(p []byte) (int, error) { return lw.buf.Write(p) } +// initLogging initializes global logging setup. +func (p *prog) initLogging(backup bool) { + zerolog.TimeFieldFormat = time.RFC3339 + ".000" + logWriters := initLoggingWithBackup(backup) + + // Initializing internal logging after global logging. + p.initInternalLogging(logWriters) +} + // initInternalLogging performs internal logging if there's no log enabled. func (p *prog) initInternalLogging(writers []io.Writer) { if !p.needInternalLogging() { @@ -92,7 +110,7 @@ func (p *prog) initInternalLogging(writers []io.Writer) { p.initInternalLogWriterOnce.Do(func() { mainLog.Load().Notice().Msg("internal logging enabled") p.internalLogWriter = newLogWriter() - p.internalLogSent = time.Now().Add(-logSentInterval) + p.internalLogSent = time.Now().Add(-logWriterSentInterval) p.internalWarnLogWriter = newSmallLogWriter() }) p.mu.Lock() @@ -158,7 +176,7 @@ func (p *prog) logReader() (*logReader, error) { wlwReader := bytes.NewReader(wlw.buf.Bytes()) wlwSize := wlw.buf.Len() wlw.mu.Unlock() - reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logLogEndMarker)), wlwReader) + reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader) lr := &logReader{r: io.NopCloser(reader)} lr.size = int64(lwSize + wlwSize) if lr.size == 0 { diff --git a/cmd/cli/log_writer_test.go b/cmd/cli/log_writer_test.go index bd48785..5336d4e 100644 --- a/cmd/cli/log_writer_test.go +++ b/cmd/cli/log_writer_test.go @@ -16,7 +16,7 @@ func Test_logWriter_Write(t *testing.T) { t.Fatalf("unexpected buf content: %v", lw.buf.String()) } newData := "B" - halfData := strings.Repeat("A", len(data)/2) + logStartEndMarker + halfData := strings.Repeat("A", len(data)/2) + logWriterInitEndMarker lw.Write([]byte(newData)) if lw.buf.String() != halfData+newData { t.Fatalf("unexpected new buf content: %v", lw.buf.String()) @@ -47,3 +47,39 @@ func Test_logWriter_ConcurrentWrite(t *testing.T) { t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String()) } } + +func Test_logWriter_MarkerInitEnd(t *testing.T) { + size := 64 * 1024 + lw := &logWriter{size: size} + lw.buf.Grow(lw.size) + + paddingSize := 10 + // Writing half of the size, minus len(end marker) and padding size. + dataSize := size/2 - len(logWriterInitEndMarker) - paddingSize + data := strings.Repeat("A", dataSize) + // Inserting newline for making partial init data + data += "\n" + // Filling left over buffer to make the log full. + // The data length: len(end marker) + padding size - 1 (for newline above) + size/2 + data += strings.Repeat("A", len(logWriterInitEndMarker)+paddingSize-1+(size/2)) + lw.Write([]byte(data)) + if lw.buf.String() != data { + t.Fatalf("unexpected buf content: %v", lw.buf.String()) + } + lw.Write([]byte("B")) + lw.Write([]byte(strings.Repeat("B", 256*1024))) + firstIdx := strings.Index(lw.buf.String(), logWriterInitEndMarker) + lastIdx := strings.LastIndex(lw.buf.String(), logWriterInitEndMarker) + // Check if init end marker present. + if firstIdx == -1 || lastIdx == -1 { + t.Fatalf("missing init end marker: %s", lw.buf.String()) + } + // Check if init end marker appears only once. + if firstIdx != lastIdx { + t.Fatalf("log init end marker appears more than once: %s", lw.buf.String()) + } + // Ensure that we have the correct init log data. + if !strings.Contains(lw.buf.String(), strings.Repeat("A", dataSize)+logWriterInitEndMarker) { + t.Fatalf("unexpected log content: %s", lw.buf.String()) + } +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 73a601d..6a8cb62 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -88,24 +88,21 @@ func initConsoleLogging() { multi := zerolog.MultiLevelWriter(consoleWriter) l := mainLog.Load().Output(multi).With().Timestamp().Logger() mainLog.Store(&l) + switch { case silent: zerolog.SetGlobalLevel(zerolog.NoLevel) case verbose == 1: + ctrld.ProxyLogger.Store(&l) zerolog.SetGlobalLevel(zerolog.InfoLevel) case verbose > 1: + ctrld.ProxyLogger.Store(&l) zerolog.SetGlobalLevel(zerolog.DebugLevel) default: zerolog.SetGlobalLevel(zerolog.NoticeLevel) } } -// initLogging initializes global logging setup. -func initLogging() []io.Writer { - zerolog.TimeFieldFormat = time.RFC3339 + ".000" - return initLoggingWithBackup(true) -} - // initInteractiveLogging is like initLogging, but the ProxyLogger is discarded // to be used for all interactive commands. // diff --git a/cmd/cli/nocgo.go b/cmd/cli/nocgo.go new file mode 100644 index 0000000..2596d09 --- /dev/null +++ b/cmd/cli/nocgo.go @@ -0,0 +1,5 @@ +//go:build !cgo + +package cli + +const cgoEnabled = false diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index 3f815e8..e2302a3 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -72,34 +72,25 @@ func setDNS(iface *net.Interface, nameservers []string) error { SearchDomains: []dnsname.FQDN{}, } trySystemdResolve := false - for i := 0; i < maxSetDNSAttempts; i++ { - if err := r.SetDNS(osConfig); err != nil { - if strings.Contains(err.Error(), "Rejected send message") && - strings.Contains(err.Error(), "org.freedesktop.network1.Manager") { - mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS") - trySystemdResolve = true - break - } - // This error happens on read-only file system, which causes ctrld failed to create backup - // for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore - // DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file. - // The error format is controlled by us, so checking for error string is fine. - // See: ../../internal/dns/direct.go:L278 - if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) { - return nil - } - return err + if err := r.SetDNS(osConfig); err != nil { + if strings.Contains(err.Error(), "Rejected send message") && + strings.Contains(err.Error(), "org.freedesktop.network1.Manager") { + mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS") + trySystemdResolve = true + goto systemdResolve } - if useSystemdResolved { - if out, err := exec.Command("systemctl", "restart", "systemd-resolved").CombinedOutput(); err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not restart systemd-resolved: %s", string(out)) - } - } - currentNS := currentDNS(iface) - if isSubSet(nameservers, currentNS) { + // This error happens on read-only file system, which causes ctrld failed to create backup + // for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore + // DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file. + // The error format is controlled by us, so checking for error string is fine. + // See: ../../internal/dns/direct.go:L278 + if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) { return nil } + return err } + +systemdResolve: if trySystemdResolve { // Stop systemd-networkd and retry setting DNS. if out, err := exec.Command("systemctl", "stop", "systemd-networkd").CombinedOutput(); err != nil { @@ -119,8 +110,8 @@ func setDNS(iface *net.Interface, nameservers []string) error { } time.Sleep(time.Second) } + mainLog.Load().Debug().Msg("DNS was not set for some reason") } - mainLog.Load().Debug().Msg("DNS was not set for some reason") return nil } @@ -169,6 +160,7 @@ func resetDNS(iface *net.Interface) (err error) { } // TODO(cuonglm): handle DHCPv6 properly. + mainLog.Load().Debug().Msg("checking for IPv6 availability") if ctrldnet.IPv6Available(ctx) { c := client6.NewClient() conversation, err := c.Exchange(iface.Name) @@ -188,6 +180,8 @@ func resetDNS(iface *net.Interface) (err error) { } } } + } else { + mainLog.Load().Debug().Msg("IPv6 is not available") } return ignoringEINTR(func() error { diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 990cc57..e1bcd9a 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -43,21 +43,42 @@ func setDNS(iface *net.Interface, nameservers []string) error { // If there's a Dns server running, that means we are on AD with Dns feature enabled. // Configuring the Dns server to forward queries to ctrld instead. if hasLocalDnsServerRunning() { + mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders") + file := absHomeDir(windowsForwardersFilename) - oldForwardersContent, _ := os.ReadFile(file) + mainLog.Load().Debug().Msgf("Using forwarders file: %s", file) + + oldForwardersContent, err := os.ReadFile(file) + if err != nil { + mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file") + } else { + mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent)) + } + hasLocalIPv6Listener := needLocalIPv6Listener() + mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status") + forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool { if !hasLocalIPv6Listener { return false } return s == "::1" }) + mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list") + if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil { mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings") + } else { + mainLog.Load().Debug().Msg("Successfully wrote new forwarders file") } + oldForwarders := strings.Split(string(oldForwardersContent), ",") + mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders") + if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil { mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings") + } else { + mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders") } } }) @@ -147,15 +168,32 @@ func restoreDNS(iface *net.Interface) (err error) { } } - for _, ns := range [][]string{v4ns, v6ns} { - if len(ns) == 0 { - continue - } - mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name) - err = setDNS(iface, ns) + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return fmt.Errorf("restoreDNS: %w", err) + } - if err != nil { - return err + if len(v4ns) > 0 { + mainLog.Load().Debug().Msgf("restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns) + if err := setDNS(iface, v4ns); err != nil { + return fmt.Errorf("restoreDNS (IPv4): %w", err) + } + } else { + mainLog.Load().Debug().Msgf("restoring IPv4 DHCP for interface %q", iface.Name) + if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil { + return fmt.Errorf("restoreDNS (IPv4 clear): %w", err) + } + } + + if len(v6ns) > 0 { + mainLog.Load().Debug().Msgf("restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns) + if err := setDNS(iface, v6ns); err != nil { + return fmt.Errorf("restoreDNS (IPv6): %w", err) + } + } else { + mainLog.Load().Debug().Msgf("restoring IPv6 DHCP for interface %q", iface.Name) + if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil { + return fmt.Errorf("restoreDNS (IPv6 clear): %w", err) } } } @@ -180,43 +218,69 @@ func currentDNS(iface *net.Interface) []string { return ns } -// currentStaticDNS returns the current static DNS settings of given interface. +// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys +// like "NameServer" and "ProfileNameServer". func currentStaticDNS(iface *net.Interface) ([]string, error) { luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err) + return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err) } guid, err := luid.GUID() if err != nil { - return nil, fmt.Errorf("luid.GUID: %w", err) + return nil, fmt.Errorf("fallback luid.GUID: %w", err) } + var ns []string - for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} { - found := false + keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} + for _, path := range keyPaths { interfaceKeyPath := path + guid.String() k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE) if err != nil { - return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err) + mainLog.Load().Debug().Err(err).Msgf("failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name) + continue } - for _, key := range []string{"NameServer", "ProfileNameServer"} { - if found { - continue - } - value, _, err := k.GetStringValue(key) - if err != nil && !errors.Is(err, registry.ErrNotExist) { - return nil, fmt.Errorf("%s: %w", key, err) - } - if len(value) > 0 { - found = true - for _, e := range strings.Split(value, ",") { - ns = append(ns, strings.TrimRight(e, "\x00")) + func() { + defer k.Close() + for _, keyName := range []string{"NameServer", "ProfileNameServer"} { + value, _, err := k.GetStringValue(keyName) + if err != nil && !errors.Is(err, registry.ErrNotExist) { + mainLog.Load().Debug().Err(err).Msgf("error reading %s registry key", keyName) + continue + } + if len(value) > 0 { + mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value) + parsed := parseDNSServers(value) + for _, pns := range parsed { + if !slices.Contains(ns, pns) { + ns = append(ns, pns) + } + } } } - } + }() + } + if len(ns) == 0 { + mainLog.Load().Debug().Msgf("no static DNS values found for interface %q", iface.Name) } return ns, nil } +// parseDNSServers splits a DNS server string that may be comma- or space-separated, +// and trims any extraneous whitespace or null characters. +func parseDNSServers(val string) []string { + fields := strings.FieldsFunc(val, func(r rune) bool { + return r == ' ' || r == ',' + }) + var servers []string + for _, f := range fields { + trimmed := strings.TrimSpace(f) + if len(trimmed) > 0 { + servers = append(servers, trimmed) + } + } + return servers +} + // addDnsServerForwarders adds given nameservers to DNS server forwarders list, // and also removing old forwarders if provided. func addDnsServerForwarders(nameservers, old []string) error { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 8a86bcf..be9b0ae 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -43,7 +43,7 @@ const ( ctrldControlUnixSockMobile = "cd.sock" upstreamPrefix = "upstream." upstreamOS = upstreamPrefix + "os" - upstreamPrivate = upstreamPrefix + "private" + upstreamOSLocal = upstreamOS + ".local" dnsWatchdogDefaultInterval = 20 * time.Second ctrldServiceName = "ctrld" ) @@ -120,6 +120,7 @@ type prog struct { runningIface string requiredMultiNICsConfig bool adDomain string + runningOnDomainController bool selfUninstallMu sync.Mutex refusedQueryCount int @@ -268,7 +269,7 @@ func (p *prog) preRun() { if runtime.GOOS == "darwin" { p.onStopped = append(p.onStopped, func() { if !service.Interactive() { - p.resetDNS() + p.resetDNS(false, true) } }) } @@ -276,7 +277,12 @@ func (p *prog) preRun() { func (p *prog) postRun() { if !service.Interactive() { - p.resetDNS() + if runtime.GOOS == "windows" { + isDC, roleInt := isRunningOnDomainController() + p.runningOnDomainController = isDC + mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt) + } + p.resetDNS(false, false) ns := ctrld.InitializeOsResolver(false) mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) p.setDNS() @@ -345,14 +351,19 @@ func (p *prog) apiConfigReload() { if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated || forced { lastUpdated = time.Now().Unix() cfg := &ctrld.Config{} - if err := validateCdRemoteConfig(resolverConfig, cfg); err != nil { + var cfgErr error + if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil { + setListenerDefaultValue(cfg) + setNetworkDefaultValue(cfg) + cfgErr = validateConfig(cfg) + } + if cfgErr != nil { logger.Warn().Err(err).Msg("skipping invalid custom config") if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil { logger.Error().Err(err).Msg("could not mark custom last update failed") } return } - setListenerDefaultValue(cfg) logger.Debug().Msg("custom config changes detected, reloading...") p.apiReloadCh <- cfg } else { @@ -560,13 +571,12 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if !reload { // Stop writing log to unix socket. consoleWriter.Out = os.Stdout - logWriters := initLoggingWithBackup(false) + p.initLogging(false) if p.logConn != nil { _ = p.logConn.Close() } go p.apiConfigReload() p.postRun() - p.initInternalLogging(logWriters) } wg.Wait() } @@ -647,16 +657,74 @@ func (p *prog) setDNS() { if cfg.Listener == nil { return } - if p.runningIface == "" { - return - } - - // allIfaces tracks whether we should set DNS for all physical interfaces. - allIfaces := p.requiredMultiNICsConfig lc := cfg.FirstListener() if lc == nil { return } + ns := lc.IP + switch { + case lc.IsDirectDnsListener(): + // If ctrld is direct listener, use 127.0.0.1 as nameserver. + ns = "127.0.0.1" + case lc.Port != 53: + ns = "127.0.0.1" + if resolver := router.LocalResolverIP(); resolver != "" { + ns = resolver + } + default: + // If we ever reach here, it means ctrld is running on lc.IP port 53, + // so we could just use lc.IP as nameserver. + } + + nameservers := []string{ns} + if needRFC1918Listeners(lc) { + nameservers = append(nameservers, ctrld.Rfc1918Addresses()...) + } + if needLocalIPv6Listener() { + nameservers = append(nameservers, "::1") + } + + slices.Sort(nameservers) + + netIfaceName := "" + netIface := p.setDnsForRunningIface(nameservers) + if netIface != nil { + netIfaceName = netIface.Name + } + setDnsOK = true + + if p.requiredMultiNICsConfig { + withEachPhysicalInterfaces(netIfaceName, "set DNS", func(i *net.Interface) error { + return setDnsIgnoreUnusableInterface(i, nameservers) + }) + } + // resolvconf file is only useful when we have default route interface, + // then set DNS on this interface will push change to /etc/resolv.conf file. + if netIface != nil && shouldWatchResolvconf() { + servers := make([]netip.Addr, len(nameservers)) + for i := range nameservers { + servers[i] = netip.MustParseAddr(nameservers[i]) + } + p.dnsWg.Add(1) + go func() { + defer p.dnsWg.Done() + p.watchResolvConf(netIface, servers, setResolvConf) + }() + } + if p.dnsWatchdogEnabled() { + p.dnsWg.Add(1) + go func() { + defer p.dnsWg.Done() + p.dnsWatchdog(netIface, nameservers) + }() + } +} + +func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.Interface) { + if p.runningIface == "" { + return + } + logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() const maxDNSRetryAttempts = 3 @@ -690,59 +758,14 @@ func (p *prog) setDNS() { return } + runningIface = netIface logger.Debug().Msg("setting DNS for interface") - ns := lc.IP - switch { - case lc.IsDirectDnsListener(): - // If ctrld is direct listener, use 127.0.0.1 as nameserver. - ns = "127.0.0.1" - case lc.Port != 53: - ns = "127.0.0.1" - if resolver := router.LocalResolverIP(); resolver != "" { - ns = resolver - } - default: - // If we ever reach here, it means ctrld is running on lc.IP port 53, - // so we could just use lc.IP as nameserver. - } - - nameservers := []string{ns} - if needRFC1918Listeners(lc) { - nameservers = append(nameservers, ctrld.Rfc1918Addresses()...) - } - if needLocalIPv6Listener() { - nameservers = append(nameservers, "::1") - } - slices.Sort(nameservers) if err := setDNS(netIface, nameservers); err != nil { logger.Error().Err(err).Msgf("could not set DNS for interface") return } - setDnsOK = true logger.Debug().Msg("setting DNS successfully") - if allIfaces { - withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error { - return setDnsIgnoreUnusableInterface(i, nameservers) - }) - } - if shouldWatchResolvconf() { - servers := make([]netip.Addr, len(nameservers)) - for i := range nameservers { - servers[i] = netip.MustParseAddr(nameservers[i]) - } - p.dnsWg.Add(1) - go func() { - defer p.dnsWg.Done() - p.watchResolvConf(netIface, servers, setResolvConf) - }() - } - if p.dnsWatchdogEnabled() { - p.dnsWg.Add(1) - go func() { - defer p.dnsWg.Done() - p.dnsWatchdog(netIface, nameservers, allIfaces) - }() - } + return } // dnsWatchdogEnabled reports whether DNS watchdog is enabled. @@ -765,12 +788,12 @@ func (p *prog) dnsWatchdogDuration() time.Duration { // dnsWatchdog watches for DNS changes on Darwin and Windows then re-applying ctrld's settings. // This is only works when deactivation pin set. -func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces bool) { +func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { if !requiredMultiNICsConfig() { return } - logger := mainLog.Load().With().Str("iface", iface.Name).Logger() - logger.Debug().Msg("start DNS settings watchdog") + + mainLog.Load().Debug().Msg("start DNS settings watchdog") ns := nameservers slices.Sort(ns) @@ -788,14 +811,56 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces return } if dnsChanged(iface, ns) { - logger.Debug().Msg("DNS settings were changed, re-applying settings") + mainLog.Load().Debug().Msg("DNS settings were changed, re-applying settings") + // Check if the interface already has static DNS servers configured. + // currentStaticDNS is an OS-dependent helper that returns the current static DNS. + staticDNS, err := currentStaticDNS(iface) + if err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name) + } else if len(staticDNS) > 0 { + //filter out loopback addresses + staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool { + return net.ParseIP(s).IsLoopback() + }) + // if we have a static config and no saved IPs already, save them + if len(staticDNS) > 0 && len(savedStaticNameservers(iface)) == 0 { + // Save these static DNS values so that they can be restored later. + if err := saveCurrentStaticDNS(iface); err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name) + } + } + } if err := setDNS(iface, ns); err != nil { mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings") } } - if allIfaces { - withEachPhysicalInterfaces(iface.Name, "", func(i *net.Interface) error { + if p.requiredMultiNICsConfig { + ifaceName := "" + if iface != nil { + ifaceName = iface.Name + } + withEachPhysicalInterfaces(ifaceName, "", func(i *net.Interface) error { if dnsChanged(i, ns) { + + // Check if the interface already has static DNS servers configured. + // currentStaticDNS is an OS-dependent helper that returns the current static DNS. + staticDNS, err := currentStaticDNS(i) + if err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name) + } else if len(staticDNS) > 0 { + //filter out loopback addresses + staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool { + return net.ParseIP(s).IsLoopback() + }) + // if we have a static config and no saved IPs already, save them + if len(staticDNS) > 0 && len(savedStaticNameservers(i)) == 0 { + // Save these static DNS values so that they can be restored later. + if err := saveCurrentStaticDNS(i); err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name) + } + } + } + if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil { mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings") } else { @@ -809,33 +874,78 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces } } -func (p *prog) resetDNS() { +// resetDNS performs a DNS reset for all interfaces. +func (p *prog) resetDNS(isStart bool, restoreStatic bool) { + netIfaceName := "" + if netIface := p.resetDNSForRunningIface(isStart, restoreStatic); netIface != nil { + netIfaceName = netIface.Name + } + // See corresponding comments in (*prog).setDNS function. + if p.requiredMultiNICsConfig { + withEachPhysicalInterfaces(netIfaceName, "reset DNS", resetDnsIgnoreUnusableInterface) + } +} + +// resetDNSForRunningIface performs a DNS reset on the running interface. +// The parameter isStart indicates whether this is being called as part of a start (or restart) +// command. When true, we check if the current static DNS configuration already differs from the +// service listener (127.0.0.1). If so, we assume that an admin has manually changed the interface's +// static DNS settings and we do not override them using the potentially out-of-date saved file. +// Otherwise, we restore the saved configuration (if any) or reset to DHCP. +func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runningIface *net.Interface) { if p.runningIface == "" { mainLog.Load().Debug().Msg("no running interface, skipping resetDNS") return } - // See corresponding comments in (*prog).setDNS function. - allIfaces := p.requiredMultiNICsConfig logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() netIface, err := netInterface(p.runningIface) if err != nil { logger.Error().Err(err).Msg("could not get interface") return } - + runningIface = netIface if err := restoreNetworkManager(); err != nil { logger.Error().Err(err).Msg("could not restore NetworkManager") return } - logger.Debug().Msg("Restoring DNS for interface") - if err := resetDNS(netIface); err != nil { - logger.Error().Err(err).Msgf("could not reset DNS") - return + + // If starting, check the current static DNS configuration. + if isStart { + current, err := currentStaticDNS(netIface) + if err != nil { + logger.Warn().Err(err).Msg("unable to obtain current static DNS configuration; proceeding to restore saved config") + } else if len(current) > 0 { + // If any static DNS value is not our own listener, assume an admin override. + hasManualConfig := false + for _, ns := range current { + if ns != "127.0.0.1" && ns != "::1" { + hasManualConfig = true + break + } + } + if hasManualConfig { + logger.Debug().Msgf("Detected manual DNS configuration on interface %q: %v; not overriding with saved configuration", netIface.Name, current) + return + } + } } - logger.Debug().Msg("Restoring DNS successfully") - if allIfaces { - withEachPhysicalInterfaces(netIface.Name, "reset DNS", resetDnsIgnoreUnusableInterface) + + // Default logic: if there is a saved static DNS configuration, restore it. + saved := savedStaticNameservers(netIface) + if len(saved) > 0 && restoreStatic { + logger.Debug().Msgf("Restoring interface %q from saved static config: %v", netIface.Name, saved) + if err := setDNS(netIface, saved); err != nil { + logger.Error().Err(err).Msgf("failed to restore static DNS config on interface %q", netIface.Name) + return + } + } else { + logger.Debug().Msgf("No saved static DNS config for interface %q; resetting to DHCP", netIface.Name) + if err := resetDNS(netIface); err != nil { + logger.Error().Err(err).Msgf("failed to reset DNS to DHCP on interface %q", netIface.Name) + return + } } + return } func (p *prog) logInterfacesState() { @@ -985,12 +1095,6 @@ func findWorkingInterface(currentIface string) string { return currentIface } -// recoverOnUpstreamFailure reports whether ctrld should recover from upstream failure. -func (p *prog) recoverOnUpstreamFailure() bool { - // Default is false on routers, since this recovery flow is only useful for devices that move between networks. - return router.Name() == "" -} - func randomLocalIP() string { n := rand.Intn(254-2) + 2 return fmt.Sprintf("127.0.0.%d", n) @@ -1192,7 +1296,7 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net. // TODO: investigate whether we should report this error? if err := f(netIface); err == nil { if context != "" { - mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name) + mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", context, i.Name) } } else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) { mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name) @@ -1215,19 +1319,38 @@ var errSaveCurrentStaticDNSNotSupported = errors.New("saving current DNS is not // saveCurrentStaticDNS saves the current static DNS settings for restoring later. // Only works on Windows and Mac. func saveCurrentStaticDNS(iface *net.Interface) error { + if iface == nil { + mainLog.Load().Debug().Msg("could not save current static DNS settings for nil interface") + return nil + } switch runtime.GOOS { case "windows", "darwin": default: return errSaveCurrentStaticDNSNotSupported } file := savedStaticDnsSettingsFilePath(iface) - ns, _ := currentStaticDNS(iface) + ns, err := currentStaticDNS(iface) + if err != nil { + mainLog.Load().Warn().Err(err).Msgf("could not get current static DNS settings for %q", iface.Name) + return err + } if len(ns) == 0 { + mainLog.Load().Debug().Msgf("no static DNS settings for %q, removing old static DNS settings file", iface.Name) _ = os.Remove(file) // removing old static DNS settings return nil } + //filter out loopback addresses + ns = slices.DeleteFunc(ns, func(s string) bool { + return net.ParseIP(s).IsLoopback() + }) + //if we now have no static DNS settings and the file already exists + // return and do not save the file + if len(ns) == 0 { + mainLog.Load().Debug().Msgf("loopback on %q, skipping saving static DNS settings", iface.Name) + return nil + } if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) { - mainLog.Load().Warn().Err(err).Msg("could not remove old static DNS settings file") + mainLog.Load().Warn().Err(err).Msgf("could not remove old static DNS settings file: %s", file) } nss := strings.Join(ns, ",") mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, nss) @@ -1241,6 +1364,9 @@ func saveCurrentStaticDNS(iface *net.Interface) error { // savedStaticDnsSettingsFilePath returns the path to saved DNS settings of the given interface. func savedStaticDnsSettingsFilePath(iface *net.Interface) string { + if iface == nil { + return "" + } return absHomeDir(".dns_" + iface.Name) } @@ -1248,6 +1374,10 @@ func savedStaticDnsSettingsFilePath(iface *net.Interface) string { // //lint:ignore U1000 use in os_windows.go and os_darwin.go func savedStaticNameservers(iface *net.Interface) []string { + if iface == nil { + mainLog.Load().Debug().Msg("could not get saved static DNS settings for nil interface") + return nil + } file := savedStaticDnsSettingsFilePath(iface) if data, _ := os.ReadFile(file); len(data) > 0 { saveValues := strings.Split(string(data), ",") @@ -1265,8 +1395,13 @@ func savedStaticNameservers(iface *net.Interface) []string { } // dnsChanged reports whether DNS settings for given interface was changed. +// It returns false for a nil iface. +// // The caller must sort the nameservers before calling this function. func dnsChanged(iface *net.Interface, nameservers []string) bool { + if iface == nil { + return false + } curNameservers, _ := currentStaticDNS(iface) slices.Sort(curNameservers) if !slices.Equal(curNameservers, nameservers) { @@ -1286,3 +1421,36 @@ func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { selfUninstall(p, logger) } } + +// leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow +// when upstream failures occur. +func (p *prog) leakOnUpstreamFailure() bool { + if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil { + return *ptr + } + // Default is false on routers, since this leaking is only useful for devices that move between networks. + if router.Name() != "" { + return false + } + // if we are running on ADDC, we should not leak on upstream failure + if p.runningOnDomainController { + return false + } + return true +} + +// Domain controller role values from Win32_ComputerSystem +// https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/win32-computersystem +const ( + BackupDomainController = 4 + PrimaryDomainController = 5 +) + +// isRunningOnDomainController checks if the current machine is a domain controller +// by querying the DomainRole property from Win32_ComputerSystem via WMI. +func isRunningOnDomainController() (bool, int) { + if runtime.GOOS != "windows" { + return false, 0 + } + return isRunningOnDomainControllerWindows() +} diff --git a/cmd/cli/prog_linux.go b/cmd/cli/prog_linux.go index b987ed3..cc0046b 100644 --- a/cmd/cli/prog_linux.go +++ b/cmd/cli/prog_linux.go @@ -13,6 +13,7 @@ import ( "tailscale.com/health" "github.com/Control-D-Inc/ctrld/internal/dns" + "github.com/Control-D-Inc/ctrld/internal/router" ) func init() { @@ -39,6 +40,9 @@ func setDependencies(svc *service.Config) { svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service") } } + if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 { + svc.Dependencies = append(svc.Dependencies, routerDeps...) + } } func setWorkingDirectory(svc *service.Config, dir string) { diff --git a/cmd/cli/prog_others.go b/cmd/cli/prog_others.go index 92f3a9f..9026318 100644 --- a/cmd/cli/prog_others.go +++ b/cmd/cli/prog_others.go @@ -1,4 +1,4 @@ -//go:build !linux && !freebsd && !darwin +//go:build !linux && !freebsd && !darwin && !windows package cli diff --git a/cmd/cli/prog_windows.go b/cmd/cli/prog_windows.go new file mode 100644 index 0000000..e448625 --- /dev/null +++ b/cmd/cli/prog_windows.go @@ -0,0 +1,14 @@ +package cli + +import "github.com/kardianos/service" + +func setDependencies(svc *service.Config) { + if hasLocalDnsServerRunning() { + svc.Dependencies = []string{"DNS"} + } +} + +func setWorkingDirectory(svc *service.Config, dir string) { + // WorkingDirectory is not supported on Windows. + svc.WorkingDirectory = dir +} diff --git a/cmd/cli/service.go b/cmd/cli/service.go index 82f144c..f03146d 100644 --- a/cmd/cli/service.go +++ b/cmd/cli/service.go @@ -6,10 +6,12 @@ import ( "fmt" "os" "os/exec" + "runtime" "github.com/kardianos/service" "github.com/Control-D-Inc/ctrld/internal/router" + "github.com/Control-D-Inc/ctrld/internal/router/openwrt" ) // newService wraps service.New call to return service.Service @@ -167,7 +169,11 @@ func doTasks(tasks []task) bool { mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err) return false } - mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err) + // if this is darwin stop command, dont print debug + // since launchctl complains on every start + if runtime.GOOS != "darwin" || task.Name != "Stop" { + mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err) + } } } return true @@ -188,6 +194,13 @@ func checkHasElevatedPrivilege() { func unixSystemVServiceStatus() (service.Status, error) { out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput() if err != nil { + // Specific case for openwrt >= 24.10, it returns non-success code + // for above status command, which may not right. + if router.Name() == openwrt.Name { + if string(bytes.ToLower(bytes.TrimSpace(out))) == "inactive" { + return service.StatusStopped, nil + } + } return service.StatusUnknown, nil } diff --git a/cmd/cli/service_others.go b/cmd/cli/service_others.go index 056903c..954b228 100644 --- a/cmd/cli/service_others.go +++ b/cmd/cli/service_others.go @@ -18,3 +18,5 @@ func openLogFile(path string, flags int) (*os.File, error) { func hasLocalDnsServerRunning() bool { return false } func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil } + +func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 } diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index c4df5a5..fddb0ef 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -2,12 +2,18 @@ package cli import ( "os" + "reflect" "runtime" + "strconv" "strings" "syscall" "time" "unsafe" + "github.com/microsoft/wmi/pkg/base/host" + "github.com/microsoft/wmi/pkg/base/instance" + "github.com/microsoft/wmi/pkg/base/query" + "github.com/microsoft/wmi/pkg/constant" "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc/mgr" ) @@ -165,3 +171,57 @@ func hasLocalDnsServerRunning() bool { } } } + +func isRunningOnDomainControllerWindows() (bool, int) { + whost := host.NewWmiLocalHost() + q := query.NewWmiQuery("Win32_ComputerSystem") + instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q) + if err != nil { + mainLog.Load().Debug().Err(err).Msg("WMI query failed") + return false, 0 + } + if instances == nil { + mainLog.Load().Debug().Msg("WMI query returned nil instances") + return false, 0 + } + defer instances.Close() + + if len(instances) == 0 { + mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem") + return false, 0 + } + + val, err := instances[0].GetProperty("DomainRole") + if err != nil { + mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property") + return false, 0 + } + if val == nil { + mainLog.Load().Debug().Msg("DomainRole property is nil") + return false, 0 + } + + // Safely handle varied types: string or integer + var roleInt int + switch v := val.(type) { + case string: + // "4", "5", etc. + parsed, parseErr := strconv.Atoi(v) + if parseErr != nil { + mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v) + return false, 0 + } + roleInt = parsed + case int8, int16, int32, int64: + roleInt = int(reflect.ValueOf(v).Int()) + case uint8, uint16, uint32, uint64: + roleInt = int(reflect.ValueOf(v).Uint()) + default: + mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v) + return false, 0 + } + + // Check if role indicates a domain controller + isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController + return isDC, roleInt +} diff --git a/cmd/ctrld/main.go b/cmd/ctrld/main.go index af204ad..1f761e6 100644 --- a/cmd/ctrld/main.go +++ b/cmd/ctrld/main.go @@ -1,7 +1,13 @@ package main -import "github.com/Control-D-Inc/ctrld/cmd/cli" +import ( + "os" + + "github.com/Control-D-Inc/ctrld/cmd/cli" +) func main() { cli.Main() + // make sure we exit with 0 if there are no errors + os.Exit(0) } diff --git a/config.go b/config.go index e1454f9..2e85e76 100644 --- a/config.go +++ b/config.go @@ -529,7 +529,7 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { for i := range addrs { dialAddrs[i] = net.JoinHostPort(addrs[i], port) } - conn, err := pd.DialContext(ctx, network, dialAddrs) + conn, err := pd.DialContext(ctx, network, dialAddrs, ProxyLogger.Load()) if err != nil { return nil, err } diff --git a/docs/config.md b/docs/config.md index c4090ce..4f50af1 100644 --- a/docs/config.md +++ b/docs/config.md @@ -166,7 +166,6 @@ before serving the query. ### max_concurrent_requests The number of concurrent requests that will be handled, must be a non-negative integer. -Tweaking this value depends on the capacity of your system. - Type: number - Required: no @@ -253,9 +252,7 @@ Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus - Default: "" ### dns_watchdog_enabled -Checking DNS changes to network interfaces and reverting to ctrld's own settings. - -The DNS watchdog process only runs on Windows and MacOS, while in `--cd` mode. +Watches all physical interfaces for DNS changes and reverts them to ctrld's settings.The DNS watchdog process only runs on Windows and MacOS. - Type: boolean - Required: no @@ -274,7 +271,7 @@ If the time duration is non-positive, default value will be used. - Default: 20s ### refetch_time -Time in seconds between each iteration that reloads custom config if changed. +Time in seconds between each iteration that reloads custom config from the API. The value must be a positive number, any invalid value will be ignored and default value will be used. - Type: number @@ -282,7 +279,7 @@ The value must be a positive number, any invalid value will be ignored and defau - Default: 3600 ### leak_on_upstream_failure -Once ctrld is "offline", mean ctrld could not connect to any upstream, next queries will be leaked to OS resolver. +If a remote upstream fails to resolve a query or is unreachable, `ctrld` will forward the queries to the default DNS resolver on the network. If failures persist, `ctrld` will remove itself from all networking interfaces until connectivity is restored. - Type: boolean - Required: no @@ -531,6 +528,15 @@ rules = [ ] ``` +If there is no explicitly defined rules, LAN queries will be handled solely by the OS resolver. + +These following domains are considered LAN queries: + +- Queries does not have dot `.` in domain name, like `machine1`, `example`, ... (1) +- Queries have domain ends with: `.domain`, `.lan`, `.local`. (2) +- All `SRV` queries of LAN hostname (1) + (2). +- `PTR` queries with private IPs. + --- Note that the order of matching preference: diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 35d5dbb..06449e1 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -207,11 +207,10 @@ func (t *Table) init() { } for platform, discover := range discovers { if err := discover.refresh(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msgf("could not init %s discover", platform) - } else { - t.hostnameResolvers = append(t.hostnameResolvers, discover) - t.refreshers = append(t.refreshers, discover) + ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to init %s discover", platform) } + t.hostnameResolvers = append(t.hostnameResolvers, discover) + t.refreshers = append(t.refreshers, discover) } } // Hosts file mapping. @@ -423,17 +422,27 @@ func (t *Table) ListClients() []*Client { t.Refresh() ipMap := make(map[string]*Client) il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni} + for _, ir := range il { + if ir == nil { + continue + } + for _, ip := range ir.List() { - c, ok := ipMap[ip] - if !ok { - c = &Client{ - IP: netip.MustParseAddr(ip), - Source: map[string]struct{}{ir.String(): {}}, + // Validate IP before using MustParseAddr + if addr, err := netip.ParseAddr(ip); err == nil { + c, ok := ipMap[ip] + if !ok { + c = &Client{ + IP: addr, + Source: map[string]struct{}{}, + } + ipMap[ip] = c + } + // Safely get source name + if src := ir.String(); src != "" { + c.Source[src] = struct{}{} } - ipMap[ip] = c - } else { - c.Source[ir.String()] = struct{}{} } } } diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index 3c8af6e..e009e01 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -92,6 +92,11 @@ func (m *mdns) init(quitCh chan struct{}) error { return err } + // Check if IPv6 is available once and use the result for the rest of the function. + ctrld.ProxyLogger.Load().Debug().Msgf("checking for IPv6 availability in mdns init") + ipv6 := ctrldnet.IPv6Available(context.Background()) + ctrld.ProxyLogger.Load().Debug().Msgf("IPv6 is %v in mdns init", ipv6) + v4ConnList := make([]*net.UDPConn, 0, len(ifaces)) v6ConnList := make([]*net.UDPConn, 0, len(ifaces)) for _, iface := range ifaces { @@ -102,7 +107,8 @@ func (m *mdns) init(quitCh chan struct{}) error { v4ConnList = append(v4ConnList, conn) go m.readLoop(conn) } - if ctrldnet.IPv6Available(context.Background()) { + + if ipv6 { if conn, err := net.ListenMulticastUDP("udp6", &iface, mdnsV6Addr); err == nil { v6ConnList = append(v6ConnList, conn) go m.readLoop(conn) diff --git a/internal/clientinfo/mdns_services.go b/internal/clientinfo/mdns_services.go index d7869c8..e9e30e8 100644 --- a/internal/clientinfo/mdns_services.go +++ b/internal/clientinfo/mdns_services.go @@ -67,4 +67,16 @@ var services = [...]string{ // Merlin "_alexa._tcp", + + // Newer Android TV devices + "_androidtvremote2._tcp.local.", + + // https://esphome.io/ + "_esphomelib._tcp.local.", + + // https://www.home-assistant.io/ + "_home-assistant._tcp.local.", + + // https://kno.wled.ge/ + "_wled._tcp.local.", } diff --git a/internal/clientinfo/ubios.go b/internal/clientinfo/ubios.go index 1a60de0..0ffd6e5 100644 --- a/internal/clientinfo/ubios.go +++ b/internal/clientinfo/ubios.go @@ -3,6 +3,7 @@ package clientinfo import ( "bytes" "encoding/json" + "fmt" "io" "os/exec" "strings" @@ -44,9 +45,9 @@ func (u *ubiosDiscover) refreshDevices() error { cmd := exec.Command("/usr/bin/mongo", "localhost:27117/ace", "--quiet", "--eval", ` DBQuery.shellBatchSize = 256; db.user.find({name: {$exists: true, $ne: ""}}, {_id:0, mac:1, name:1});`) - b, err := cmd.Output() + b, err := cmd.CombinedOutput() if err != nil { - return err + return fmt.Errorf("out: %s, err: %w", string(b), err) } return u.storeDevices(bytes.NewReader(b)) } diff --git a/internal/controld/config.go b/internal/controld/config.go index fbbd9d4..5e65fdb 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -32,6 +32,8 @@ const ( logURLCom = apiURLCom + "/logs" logURLDev = apiURLDev + "/logs" InvalidConfigCode = 40402 + defaultTimeout = 20 * time.Second + sendLogTimeout = 300 * time.Second ) // ResolverConfig represents Control D resolver data. @@ -135,7 +137,7 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade req.Header.Add("Content-Type", "application/json") transport := apiTransport(cdDev) client := http.Client{ - Timeout: 10 * time.Second, + Timeout: defaultTimeout, Transport: transport, } resp, err := client.Do(req) @@ -176,7 +178,7 @@ func SendLogs(lr *LogsRequest, cdDev bool) error { req.Header.Add("Content-Type", "application/x-www-form-urlencoded") transport := apiTransport(cdDev) client := http.Client{ - Timeout: 300 * time.Second, + Timeout: sendLogTimeout, Transport: transport, } resp, err := client.Do(req) @@ -214,19 +216,55 @@ func apiTransport(cdDev bool) *http.Transport { if cdDev { apiDomain = apiDomainDev } + + // First try IPv4 + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + ips := ctrld.LookupIP(apiDomain) if len(ips) == 0 { - ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, connecting to %s", apiDomain, addr) - return ctrldnet.Dialer.DialContext(ctx, network, addr) + ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, falling back to direct connection to %s", apiDomain, addr) + return dialer.DialContext(ctx, network, addr) } - ctrld.ProxyLogger.Load().Debug().Msgf("API IPs: %v", ips) + + // Separate IPv4 and IPv6 addresses + var ipv4s, ipv6s []string + for _, ip := range ips { + if strings.Contains(ip, ":") { + ipv6s = append(ipv6s, ip) + } else { + ipv4s = append(ipv4s, ip) + } + } + _, port, _ := net.SplitHostPort(addr) - addrs := make([]string, len(ips)) - for i := range ips { - addrs[i] = net.JoinHostPort(ips[i], port) + + // Try IPv4 first + if len(ipv4s) > 0 { + addrs := make([]string, len(ipv4s)) + for i, ip := range ipv4s { + addrs[i] = net.JoinHostPort(ip, port) + } + d := &ctrldnet.ParallelDialer{} + if conn, err := d.DialContext(ctx, "tcp4", addrs, ctrld.ProxyLogger.Load()); err == nil { + return conn, nil + } } - d := &ctrldnet.ParallelDialer{} - return d.DialContext(ctx, network, addrs) + + // Fall back to IPv6 if available + if len(ipv6s) > 0 { + addrs := make([]string, len(ipv6s)) + for i, ip := range ipv6s { + addrs[i] = net.JoinHostPort(ip, port) + } + d := &ctrldnet.ParallelDialer{} + return d.DialContext(ctx, "tcp6", addrs, ctrld.ProxyLogger.Load()) + } + + // Final fallback to direct connection + return dialer.DialContext(ctx, network, addr) } if router.Name() == ddwrt.Name || runtime.GOOS == "android" { transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()} diff --git a/internal/net/net.go b/internal/net/net.go index 3a81849..2693fbf 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -3,6 +3,7 @@ package net import ( "context" "errors" + "io" "net" "os" "os/signal" @@ -11,6 +12,7 @@ import ( "syscall" "time" + "github.com/rs/zerolog" "tailscale.com/logtail/backoff" ) @@ -26,7 +28,8 @@ var Dialer = &net.Dialer{ Dial: func(ctx context.Context, network, address string) (net.Conn, error) { d := ParallelDialer{} d.Timeout = 10 * time.Second - return d.DialContext(ctx, "udp", []string{v4BootstrapDNS, v6BootstrapDNS}) + l := zerolog.New(io.Discard) + return d.DialContext(ctx, "udp", []string{v4BootstrapDNS, v6BootstrapDNS}, &l) }, }, } @@ -49,8 +52,12 @@ func init() { } func supportIPv6(ctx context.Context) bool { - _, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "443")) - return err == nil + c, err := probeStackDialer.DialContext(ctx, "tcp6", v6BootstrapDNS) + if err != nil { + return false + } + c.Close() + return true } func supportListenIPv6Local() bool { @@ -133,7 +140,7 @@ type ParallelDialer struct { net.Dialer } -func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string) (net.Conn, error) { +func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string, logger *zerolog.Logger) (net.Conn, error) { if len(addrs) == 0 { return nil, errors.New("empty addresses") } @@ -153,11 +160,16 @@ func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs for _, addr := range addrs { go func(addr string) { defer wg.Done() + logger.Debug().Msgf("dialing to %s", addr) conn, err := d.Dialer.DialContext(ctx, network, addr) + if err != nil { + logger.Debug().Msgf("failed to dial %s: %v", addr, err) + } select { case ch <- ¶llelDialerResult{conn: conn, err: err}: case <-done: if conn != nil { + logger.Debug().Msgf("connection closed: %s", conn.RemoteAddr()) conn.Close() } } @@ -168,6 +180,7 @@ func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs for res := range ch { if res.err == nil { cancel() + logger.Debug().Msgf("connected to %s", res.conn.RemoteAddr()) return res.conn, res.err } errs = append(errs, res.err) diff --git a/internal/router/openwrt/openwrt.go b/internal/router/openwrt/openwrt.go index ad98db9..73f5a06 100644 --- a/internal/router/openwrt/openwrt.go +++ b/internal/router/openwrt/openwrt.go @@ -2,10 +2,13 @@ package openwrt import ( "bytes" + "encoding/json" "errors" "fmt" + "io" "os" "os/exec" + "path/filepath" "strings" "github.com/kardianos/service" @@ -15,10 +18,13 @@ import ( ) const ( - Name = "openwrt" - openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf" + Name = "openwrt" + openwrtDNSMasqConfigName = "ctrld.conf" + openwrtDNSMasqDefaultConfigDir = "/tmp/dnsmasq.d" ) +var openwrtDnsmasqDefaultConfigPath = filepath.Join(openwrtDNSMasqDefaultConfigDir, openwrtDNSMasqConfigName) + type Openwrt struct { cfg *ctrld.Config dnsmasqCacheSize string @@ -67,7 +73,7 @@ func (o *Openwrt) Setup() error { if err != nil { return err } - if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(data), 0600); err != nil { + if err := os.WriteFile(dnsmasqConfPathFromUbus(), []byte(data), 0600); err != nil { return err } // Restart dnsmasq service. @@ -82,7 +88,7 @@ func (o *Openwrt) Cleanup() error { return nil } // Remove the custom dnsmasq config - if err := os.Remove(openwrtDNSMasqConfigPath); err != nil { + if err := os.Remove(dnsmasqConfPathFromUbus()); err != nil { return err } @@ -126,3 +132,60 @@ func uci(args ...string) (string, error) { } return strings.TrimSpace(stdout.String()), nil } + +// openwrtServiceList represents openwrt services config. +type openwrtServiceList struct { + Dnsmasq dnsmasqConf `json:"dnsmasq"` +} + +// dnsmasqConf represents dnsmasq config. +type dnsmasqConf struct { + Instances map[string]confInstances `json:"instances"` +} + +// confInstances represents an instance config of a service. +type confInstances struct { + Mount map[string]string `json:"mount"` +} + +// dnsmasqConfPath returns the dnsmasq config path. +// +// Since version 24.10, openwrt makes some changes to dnsmasq to support +// multiple instances of dnsmasq. This change causes breaking changes to +// software which depends on the default dnsmasq path. +// +// There are some discussion/PRs in openwrt repo to address this: +// +// - https://github.com/openwrt/openwrt/pull/16806 +// - https://github.com/openwrt/openwrt/pull/16890 +// +// In the meantime, workaround this problem by querying the actual config path +// by querying ubus service list. +func dnsmasqConfPath(r io.Reader) string { + var svc openwrtServiceList + if err := json.NewDecoder(r).Decode(&svc); err != nil { + return openwrtDnsmasqDefaultConfigPath + } + for _, inst := range svc.Dnsmasq.Instances { + for mount := range inst.Mount { + dirName := filepath.Base(mount) + parts := strings.Split(dirName, ".") + if len(parts) < 2 { + continue + } + if parts[0] == "dnsmasq" && parts[len(parts)-1] == "d" { + return filepath.Join(mount, openwrtDNSMasqConfigName) + } + } + } + return openwrtDnsmasqDefaultConfigPath +} + +// dnsmasqConfPathFromUbus get dnsmasq config path from ubus service list. +func dnsmasqConfPathFromUbus() string { + output, err := exec.Command("ubus", "call", "service", "list").Output() + if err != nil { + return openwrtDnsmasqDefaultConfigPath + } + return dnsmasqConfPath(bytes.NewReader(output)) +} diff --git a/internal/router/openwrt/openwrt_test.go b/internal/router/openwrt/openwrt_test.go new file mode 100644 index 0000000..8b260e8 --- /dev/null +++ b/internal/router/openwrt/openwrt_test.go @@ -0,0 +1,58 @@ +package openwrt + +import ( + "io" + "path/filepath" + "strings" + "testing" +) + +// Sample output from https://github.com/openwrt/openwrt/pull/16806#issuecomment-2448255734 +const ubusDnsmasqBefore2410 = `{ + "dnsmasq": { + "instances": { + "guest_dns": { + "mount": { + "/tmp/dnsmasq.d": "0", + "/var/run/dnsmasq/": "1" + } + } + } + } +}` + +const ubusDnsmasq2410 = `{ + "dnsmasq": { + "instances": { + "guest_dns": { + "mount": { + "/tmp/dnsmasq.guest_dns.d": "0", + "/var/run/dnsmasq/": "1" + } + } + } + } +}` + +func Test_dnsmasqConfPath(t *testing.T) { + var dnsmasq2410expected = filepath.Join("/tmp/dnsmasq.guest_dns.d", openwrtDNSMasqConfigName) + tests := []struct { + name string + in io.Reader + expected string + }{ + {"empty", strings.NewReader(""), openwrtDnsmasqDefaultConfigPath}, + {"invalid", strings.NewReader("}}"), openwrtDnsmasqDefaultConfigPath}, + {"before 24.10", strings.NewReader(ubusDnsmasqBefore2410), openwrtDnsmasqDefaultConfigPath}, + {"24.10", strings.NewReader(ubusDnsmasq2410), dnsmasq2410expected}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := dnsmasqConfPath(tc.in); got != tc.expected { + t.Errorf("dnsmasqConfPath() = %v, want %v", got, tc.expected) + } + }) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 4b335a6..2d8c462 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -215,6 +215,20 @@ func LeaseFilesDir() string { return "" } +// ServiceDependencies returns list of dependencies that ctrld services needs on this router. +// See https://pkg.go.dev/github.com/kardianos/service#Config for list format. +func ServiceDependencies() []string { + if Name() == ubios.Name { + // On Ubios, ctrld needs to start after unifi-mongodb, + // so it can query custom client info mapping. + return []string{ + "Wants=unifi-mongodb.service", + "After=unifi-mongodb.service", + } + } + return nil +} + func distroName() string { switch { case bytes.HasPrefix(unameO(), []byte("DD-WRT")): diff --git a/internal/router/service_tomato.go b/internal/router/service_tomato.go index 1a7151a..2cf5939 100644 --- a/internal/router/service_tomato.go +++ b/internal/router/service_tomato.go @@ -45,11 +45,15 @@ func (s *tomatoSvc) Platform() string { } func (s *tomatoSvc) configPath() string { - path, err := os.Executable() - if err != nil { - return "" + bin := s.Config.Executable + if bin == "" { + path, err := os.Executable() + if err != nil { + return "" + } + bin = path } - return path + ".startup" + return bin + ".startup" } func (s *tomatoSvc) template() *template.Template { diff --git a/internal/router/service_ubios.go b/internal/router/service_ubios.go index 0b49cd2..8077c07 100644 --- a/internal/router/service_ubios.go +++ b/internal/router/service_ubios.go @@ -219,6 +219,8 @@ const ubiosBootSystemdService = `[Unit] Description=Run ctrld On Startup UDM Wants=network-online.target After=network-online.target +Wants=unifi-mongodb +After=unifi-mongodb StartLimitIntervalSec=500 StartLimitBurst=5 diff --git a/log.go b/log.go index c521163..14c82e8 100644 --- a/log.go +++ b/log.go @@ -9,11 +9,6 @@ import ( "github.com/rs/zerolog" ) -func init() { - l := zerolog.New(io.Discard) - ProxyLogger.Store(&l) -} - // ProxyLog emits the log record for proxy operations. // The caller should set it only once. // DEPRECATED: use ProxyLogger instead. diff --git a/nameservers_darwin.go b/nameservers_darwin.go index d536d78..b6b1543 100644 --- a/nameservers_darwin.go +++ b/nameservers_darwin.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "fmt" + "io" "net" "os/exec" "regexp" @@ -155,6 +156,8 @@ func getDHCPNameservers(iface string) ([]string, error) { } func getAllDHCPNameservers() []string { + logger := *ProxyLogger.Load() + interfaces, err := net.Interfaces() if err != nil { return nil @@ -213,5 +216,67 @@ func getAllDHCPNameservers() []string { } } + // if we have static DNS servers saved for the current default route, we should add them to the list + drIfaceName, err := netmon.DefaultRouteInterface() + Log(context.Background(), logger.Debug(), "checking for static DNS servers for default route interface: %s", drIfaceName) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get default route interface: %v", err) + } else { + drIface, err := net.InterfaceByName(drIfaceName) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get interface by name %s: %v", drIfaceName, err) + } else if drIface != nil { + if _, err := patchNetIfaceName(drIface); err != nil { + Log(context.Background(), logger.Debug(), + "Failed to patch interface name %s: %v", drIfaceName, err) + } + staticNs, file := SavedStaticNameservers(drIface) + Log(context.Background(), logger.Debug(), + "static dns servers from %s: %v", file, staticNs) + if len(staticNs) > 0 { + Log(context.Background(), logger.Debug(), + "Adding static DNS servers from %s: %v", drIface.Name, staticNs) + allNameservers = append(allNameservers, staticNs...) + } + } + } + return allNameservers } + +func patchNetIfaceName(iface *net.Interface) (bool, error) { + b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output() + if err != nil { + return false, err + } + + patched := false + if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" { + patched = true + iface.Name = name + } + return patched, nil +} + +func networkServiceName(ifaceName string, r io.Reader) string { + scanner := bufio.NewScanner(r) + prevLine := "" + for scanner.Scan() { + line := scanner.Text() + if strings.Contains(line, "*") { + // Network services is disabled. + continue + } + if !strings.Contains(line, "Device: "+ifaceName) { + prevLine = line + continue + } + parts := strings.SplitN(prevLine, " ", 2) + if len(parts) == 2 { + return strings.TrimSpace(parts[1]) + } + } + return "" +} diff --git a/nameservers_windows.go b/nameservers_windows.go index 54fb8b6..0c47e58 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -17,9 +17,9 @@ import ( "github.com/microsoft/wmi/pkg/base/query" "github.com/microsoft/wmi/pkg/constant" "github.com/microsoft/wmi/pkg/hardware/network/netadapter" - "github.com/rs/zerolog" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "tailscale.com/net/netmon" ) const ( @@ -62,10 +62,7 @@ func dnsFromAdapter() []string { var ns []string var err error - logger := zerolog.New(io.Discard) - if ProxyLogger.Load() != nil { - logger = *ProxyLogger.Load() - } + logger := *ProxyLogger.Load() for i := 0; i < maxDNSAdapterRetries; i++ { if ctx.Err() != nil { @@ -111,10 +108,8 @@ func dnsFromAdapter() []string { } func getDNSServers(ctx context.Context) ([]string, error) { - logger := zerolog.New(io.Discard) - if ProxyLogger.Load() != nil { - logger = *ProxyLogger.Load() - } + logger := *ProxyLogger.Load() + // Check context before making the call if ctx.Err() != nil { return nil, ctx.Err() @@ -303,6 +298,28 @@ func getDNSServers(ctx context.Context) ([]string, error) { } } + // if we have static DNS servers saved for the current default route, we should add them to the list + drIfaceName, err := netmon.DefaultRouteInterface() + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get default route interface: %v", err) + } else { + drIface, err := net.InterfaceByName(drIfaceName) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get interface by name %s: %v", drIfaceName, err) + } else { + staticNs, file := SavedStaticNameservers(drIface) + Log(context.Background(), logger.Debug(), + "static dns servers from %s: %v", file, staticNs) + if len(staticNs) > 0 { + Log(context.Background(), logger.Debug(), + "Adding static DNS servers from %s: %v", drIfaceName, staticNs) + ns = append(ns, staticNs...) + } + } + } + if len(ns) == 0 { return nil, fmt.Errorf("no valid DNS servers found") } @@ -320,10 +337,8 @@ func nameserversFromResolvconf() []string { // checkDomainJoined checks if the machine is joined to an Active Directory domain // Returns whether it's domain joined and the domain name if available func checkDomainJoined() bool { - logger := zerolog.New(io.Discard) - if ProxyLogger.Load() != nil { - logger = *ProxyLogger.Load() - } + logger := *ProxyLogger.Load() + var domain *uint16 var status uint32 @@ -400,10 +415,7 @@ func validInterfaces() map[string]struct{} { defer log.SetOutput(os.Stderr) //load the logger - logger := zerolog.New(io.Discard) - if ProxyLogger.Load() != nil { - logger = *ProxyLogger.Load() - } + logger := *ProxyLogger.Load() whost := host.NewWmiLocalHost() q := query.NewWmiQuery("MSFT_NetAdapter") diff --git a/net.go b/net.go index 3ae3bb5..449620d 100644 --- a/net.go +++ b/net.go @@ -18,6 +18,7 @@ const ipv6ProbingInterval = 10 * time.Second func hasIPv6() bool { hasIPv6Once.Do(func() { + Log(context.Background(), ProxyLogger.Load().Debug(), "checking for IPv6 availability once") ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() val := ctrldnet.IPv6Available(ctx) @@ -43,6 +44,7 @@ func probingIPv6(ctx context.Context, old bool) { if ipv6Available.CompareAndSwap(old, cur) { old = cur } + Log(ctx, ProxyLogger.Load().Debug(), "IPv6 availability: %v", cur) }() } } diff --git a/resolver.go b/resolver.go index 19ebc1f..677738b 100644 --- a/resolver.go +++ b/resolver.go @@ -48,7 +48,15 @@ const ( var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") -var localResolver = newLocalResolver() +var localResolver Resolver + +func init() { + // Initializing ProxyLogger here, so other places don't have to do nil check. + l := zerolog.New(io.Discard) + ProxyLogger.Store(&l) + + localResolver = newLocalResolver() +} var ( resolverMutex sync.Mutex @@ -91,10 +99,8 @@ func availableNameservers() []string { machineIPsMap := make(map[string]struct{}, len(regularIPs)) //load the logger - logger := zerolog.New(io.Discard) - if ProxyLogger.Load() != nil { - logger = *ProxyLogger.Load() - } + logger := *ProxyLogger.Load() + Log(context.Background(), logger.Debug(), "Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs) @@ -193,9 +199,12 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { case ResolverTypeDOQ: return &doqResolver{uc: uc}, nil case ResolverTypeOS: + resolverMutex.Lock() if or == nil { + ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver") or = newResolverWithNameserver(defaultNameservers()) } + resolverMutex.Unlock() return or, nil case ResolverTypeLegacy: return &legacyResolver{uc: uc}, nil @@ -277,14 +286,29 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error nss = append(nss, (*p)...) } numServers := len(nss) + len(publicServers) + // If this is a LAN query, skip public DNS. lan, ok := ctx.Value(LanQueryCtxKey{}).(bool) + + // remove controldPublicDnsWithPort from publicServers for LAN queries + // this is to prevent DoS for high frequency local requests if ok && lan { - numServers -= len(publicServers) + if index := slices.Index(publicServers, controldPublicDnsWithPort); index != -1 { + publicServers = slices.Delete(publicServers, index, index+1) + numServers-- + } } + question := "" + if msg != nil && len(msg.Question) > 0 { + question = msg.Question[0].Name + } + Log(ctx, ProxyLogger.Load().Debug(), "os resolver query for %s with nameservers: %v public: %v", question, nss, publicServers) + + // New check: If no resolvers are available, return an error. if numServers == 0 { - return nil, errors.New("no nameservers available") + return nil, errors.New("no nameservers available for query") } + ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -320,10 +344,6 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error }(server) } } - do(nss, true) - if !lan { - do(publicServers, false) - } logAnswer := func(server string) { host, _, err := net.SplitHostPort(server) @@ -333,6 +353,18 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host) } + + // try local nameservers + if len(nss) > 0 { + do(nss, true) + } + + // we must always try the public servers too, since DCHP may have only public servers + // this is okay to do since we always prefer LAN nameserver responses + if len(publicServers) > 0 { + do(publicServers, false) + } + var ( nonSuccessAnswer *dns.Msg nonSuccessServer string @@ -353,33 +385,49 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error case res.server == controldPublicDnsWithPort: controldSuccessAnswer = res.answer case !res.lan: + // if there are no LAN nameservers, we should not wait + // just use the first response + if len(nss) == 0 { + Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", res.server) + cancel() + logAnswer(res.server) + return res.answer, nil + } publicResponses = append(publicResponses, publicResponse{ answer: res.answer, server: res.server, }) } case res.answer != nil: - nonSuccessAnswer = res.answer - nonSuccessServer = res.server Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d", res.server, res.answer.Rcode) + // When there are no LAN nameservers, we should not wait + // for other nameservers to respond. + if len(nss) == 0 { + Log(ctx, ProxyLogger.Load().Debug(), "no lan nameservers using public non success answer") + cancel() + logAnswer(res.server) + return res.answer, nil + } + nonSuccessAnswer = res.answer + nonSuccessServer = res.server } errs = append(errs, res.err) } if len(publicResponses) > 0 { resp := publicResponses[0] - Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", resp.server) + Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", resp.server) logAnswer(resp.server) return resp.answer, nil } if controldSuccessAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort) + Log(ctx, ProxyLogger.Load().Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort) logAnswer(controldPublicDnsWithPort) return controldSuccessAnswer, nil } if nonSuccessAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer) + Log(ctx, ProxyLogger.Load().Debug(), "using non-success answer from: %s", nonSuccessServer) logAnswer(nonSuccessServer) return nonSuccessAnswer, nil } @@ -428,9 +476,13 @@ func LookupIP(domain string) []string { } func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { + resolverMutex.Lock() if or == nil { + ProxyLogger.Load().Debug().Msgf("Initialize OS resolver in lookupIP") or = newResolverWithNameserver(defaultNameservers()) } + resolverMutex.Unlock() + nss := *or.lanServers.Load() nss = append(nss, *or.publicServers.Load()...) if withBootstrapDNS { @@ -510,6 +562,9 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) // - Gateway IP address (depends on OS). // - Input servers. func NewBootstrapResolver(servers ...string) Resolver { + logger := *ProxyLogger.Load() + + Log(context.Background(), logger.Debug(), "NewBootstrapResolver called with servers: %v", servers) nss := defaultNameservers() nss = append([]string{controldPublicDnsWithPort}, nss...) for _, ns := range servers { @@ -526,6 +581,11 @@ func NewBootstrapResolver(servers ...string) Resolver { // // This is useful for doing PTR lookup in LAN network. func NewPrivateResolver() Resolver { + + logger := *ProxyLogger.Load() + + Log(context.Background(), logger.Debug(), "NewPrivateResolver called") + nss := defaultNameservers() resolveConfNss := nameserversFromResolvconf() localRfc1918Addrs := Rfc1918Addresses() @@ -570,6 +630,9 @@ func NewResolverWithNameserver(nameservers []string) Resolver { // newResolverWithNameserver returns an OS resolver from given nameservers list. // The caller must ensure each server in list is formed "ip:53". func newResolverWithNameserver(nameservers []string) *osResolver { + logger := *ProxyLogger.Load() + + Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers) r := &osResolver{} var publicNss []string var lanNss []string diff --git a/resolver_test.go b/resolver_test.go index e96e875..fb6831b 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -70,41 +70,59 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) { } func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { - ns := make([]string, 0, 2) - servers := make([]*dns.Server, 0, 2) - handlers := []dns.Handler{ + // Set up a LAN nameserver that returns a success response. + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") // 127.0.0.1 is considered LAN (loopback) + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, successHandler()) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + // Set up two public nameservers that return non-success responses. + publicHandlers := []dns.Handler{ nonSuccessHandlerWithRcode(dns.RcodeRefused), nonSuccessHandlerWithRcode(dns.RcodeNameError), - successHandler(), } - for i := range handlers { + var publicNS []string + var publicServers []*dns.Server + for _, handler := range publicHandlers { pc, err := net.ListenPacket("udp", ":0") if err != nil { - t.Fatalf("unexpected error: %v", err) + t.Fatalf("failed to listen on public address: %v", err) } - - s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i]) + s, addr, err := runLocalPacketConnTestServer(t, pc, handler) if err != nil { - t.Fatalf("unexpected error: %v", err) + t.Fatalf("failed to run public test server: %v", err) } - ns = append(ns, addr) - servers = append(servers, s) + publicNS = append(publicNS, addr) + publicServers = append(publicServers, s) } defer func() { - for _, server := range servers { - server.Shutdown() + for _, s := range publicServers { + s.Shutdown() } }() + + // We now create an osResolver which has both a LAN and public nameserver. resolver := &osResolver{} - resolver.publicServers.Store(&ns) + // Explicitly store the LAN nameserver. + resolver.lanServers.Store(&[]string{lanAddr}) + // And store the public nameservers. + resolver.publicServers.Store(&publicNS) + msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) answer, err := resolver.Resolve(context.Background(), msg) if err != nil { t.Fatal(err) } + + // Since a LAN nameserver is available and returns a success answer, we expect RcodeSuccess. if answer.Rcode != dns.RcodeSuccess { - t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode]) + t.Errorf("expected a success answer from LAN nameserver (RcodeSuccess) but got: %s", dns.RcodeToString[answer.Rcode]) } } diff --git a/staticdns.go b/staticdns.go new file mode 100644 index 0000000..1bfd556 --- /dev/null +++ b/staticdns.go @@ -0,0 +1,79 @@ +package ctrld + +import ( + "net" + "os" + "path/filepath" + "runtime" + "strings" +) + +var homedir string + +// absHomeDir returns the absolute path to given filename using home directory as root dir. +func absHomeDir(filename string) string { + if homedir != "" { + return filepath.Join(homedir, filename) + } + dir, err := userHomeDir() + if err != nil { + return filename + } + return filepath.Join(dir, filename) +} + +func dirWritable(dir string) (bool, error) { + f, err := os.CreateTemp(dir, "") + if err != nil { + return false, err + } + defer os.Remove(f.Name()) + return true, f.Close() +} + +func userHomeDir() (string, error) { + // viper will expand for us. + if runtime.GOOS == "windows" { + // If we're on windows, use the install path for this. + exePath, err := os.Executable() + if err != nil { + return "", err + } + + return filepath.Dir(exePath), nil + } + dir := "/etc/controld" + if err := os.MkdirAll(dir, 0750); err != nil { + return os.UserHomeDir() // fallback to user home directory + } + if ok, _ := dirWritable(dir); !ok { + return os.UserHomeDir() + } + return dir, nil +} + +// SavedStaticDnsSettingsFilePath returns the file path where the static DNS settings +// for the provided interface are saved. +func SavedStaticDnsSettingsFilePath(iface *net.Interface) string { + // The file is stored in the user home directory under a hidden file. + return absHomeDir(".dns_" + iface.Name) +} + +// SavedStaticNameservers returns the stored static nameservers for the given interface. +func SavedStaticNameservers(iface *net.Interface) ([]string, string) { + file := SavedStaticDnsSettingsFilePath(iface) + data, err := os.ReadFile(file) + if err != nil || len(data) == 0 { + return nil, file + } + saveValues := strings.Split(string(data), ",") + var ns []string + for _, v := range saveValues { + // Skip any IP that is loopback + if ip := net.ParseIP(v); ip != nil && ip.IsLoopback() { + continue + } + ns = append(ns, v) + } + return ns, file +}