diff --git a/README.md b/README.md index eddcb41..ef98c38 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/Control-D-Inc/ctrld.svg)](https://pkg.go.dev/github.com/Control-D-Inc/ctrld) [![Go Report Card](https://goreportcard.com/badge/github.com/Control-D-Inc/ctrld)](https://goreportcard.com/report/github.com/Control-D-Inc/ctrld) +![ctrld spash image](/docs/ctrldsplash.png) + A highly configurable DNS forwarding proxy with support for: - Multiple listeners for incoming queries - Multiple upstreams with fallbacks diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 5c7795f..bdae37a 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -18,15 +18,14 @@ import ( "path/filepath" "reflect" "runtime" + "runtime/debug" "sort" "strconv" "strings" - "sync" "time" "github.com/Masterminds/semver" "github.com/cuonglm/osinfo" - "github.com/fsnotify/fsnotify" "github.com/go-playground/validator/v10" "github.com/kardianos/service" "github.com/miekg/dns" @@ -86,7 +85,7 @@ var rootCmd = &cobra.Command{ Use: "ctrld", Short: strings.TrimLeft(rootShortDesc, "\n"), Version: curVersion(), - PreRun: func(cmd *cobra.Command, args []string) { + PersistentPreRun: func(cmd *cobra.Command, args []string) { initConsoleLogging() }, } @@ -127,9 +126,6 @@ func initCLI() { Use: "run", Short: "Run the DNS proxy server", Args: cobra.NoArgs, - PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() - }, Run: func(cmd *cobra.Command, args []string) { RunCobraCommand(cmd) }, @@ -158,7 +154,6 @@ func initCLI() { startCmd := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Use: "start", @@ -187,11 +182,11 @@ func initCLI() { return } - status, err := s.Status() - isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) + status, _ := s.Status() + isCtrldRunning := status == service.StatusRunning // If pin code was set, do not allow running start command. - if status == service.StatusRunning { + if isCtrldRunning { if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) } @@ -255,13 +250,14 @@ func initCLI() { // A buffer channel to gather log output from runCmd and report // to user in case self-check process failed. runCmdLogCh := make(chan string, 256) - if dir, err := userHomeDir(); err == nil { - setWorkingDirectory(sc, dir) + ud, err := userHomeDir() + sockDir := ud + if err == nil { + setWorkingDirectory(sc, ud) if configPath == "" && writeDefaultConfig { - defaultConfigFile = filepath.Join(dir, defaultConfigFile) + defaultConfigFile = filepath.Join(ud, defaultConfigFile) } - sc.Arguments = append(sc.Arguments, "--homedir="+dir) - sockDir := dir + sc.Arguments = append(sc.Arguments, "--homedir="+ud) if d, err := socketDir(); err == nil { sockDir = d } @@ -312,18 +308,11 @@ func initCLI() { } tasks := []task{ + resetDnsTask(p, s), {s.Stop, false}, {func() error { return doGenerateNextDNSConfig(nextdns) }, true}, {func() error { return ensureUninstall(s) }, false}, {func() error { - // If ctrld is installed, we should not save current DNS settings, because: - // - // - The DNS settings was being set by ctrld already. - // - We could not determine the state of DNS settings before installing ctrld. - if isCtrldInstalled { - return nil - } - // Save current DNS so we can restore later. withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error { return saveCurrentStaticDNS(i) @@ -343,7 +332,7 @@ func initCLI() { return } - ok, status, err := selfCheckStatus(s) + ok, status, err := selfCheckStatus(s, ud, sockDir) switch { case ok && status == service.StatusRunning: mainLog.Load().Notice().Msg("Service started") @@ -381,7 +370,15 @@ func initCLI() { uninstall(p, s) os.Exit(1) } - p.setDNS() + if cc := newSocketControlClient(s, sockDir); cc != nil { + if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK { + if iface == "auto" { + iface = defaultIfaceName() + } + logger := mainLog.Load().With().Str("iface", iface).Logger() + logger.Debug().Msg("setting DNS successfully") + } + } } }, } @@ -401,12 +398,10 @@ func initCLI() { startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) + startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`) routerCmd := &cobra.Command{ Use: "setup", - PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() - }, Run: func(cmd *cobra.Command, _ []string) { exe, err := os.Executable() if err != nil { @@ -434,7 +429,6 @@ func initCLI() { stopCmd := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Use: "stop", @@ -456,6 +450,23 @@ func initCLI() { if doTasks([]task{{s.Stop, true}}) { p.router.Cleanup() p.resetDNS() + if router.WaitProcessExited() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + for { + select { + case <-ctx.Done(): + mainLog.Load().Error().Msg("timeout while waiting for service to stop") + return + default: + } + time.Sleep(time.Second) + if status, _ := s.Status(); status == service.StatusStopped { + break + } + } + } mainLog.Load().Notice().Msg("Service stopped") } }, @@ -466,14 +477,16 @@ func initCLI() { restartCmd := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Use: "restart", Short: "Restart the ctrld service", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { - s, err := newService(&prog{}, svcConfig) + readConfig(false) + v.Unmarshal(&cfg) + p := &prog{router: router.New(&cfg, runInCdMode())} + s, err := newService(p, svcConfig) if err != nil { mainLog.Load().Error().Msg(err.Error()) return @@ -484,6 +497,7 @@ func initCLI() { } initLogging() + iface = runningIface(s) tasks := []task{ {s.Stop, false}, {s.Start, true}, @@ -494,10 +508,12 @@ func initCLI() { mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server") return } - if cc := newSocketControlClient(s, dir); cc == nil { + cc := newSocketControlClient(s, dir) + if cc == nil { mainLog.Load().Notice().Msg("Service was not restarted") os.Exit(1) } + _, _ = cc.post(ifacePath, nil) mainLog.Load().Notice().Msg("Service restarted") } }, @@ -505,7 +521,6 @@ func initCLI() { reloadCmd := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Use: "reload", @@ -551,9 +566,6 @@ func initCLI() { Use: "status", Short: "Show status of the ctrld service", Args: cobra.NoArgs, - PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() - }, Run: func(cmd *cobra.Command, args []string) { s, err := newService(&prog{}, svcConfig) if err != nil { @@ -581,14 +593,12 @@ func initCLI() { if runtime.GOOS == "darwin" { // On darwin, running status command without privileges may return wrong information. statusCmd.PreRun = func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() } } uninstallCmd := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Use: "uninstall", @@ -623,9 +633,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, Use: "list", Short: "List network interfaces of the host", Args: cobra.NoArgs, - PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() - }, Run: func(cmd *cobra.Command, args []string) { err := interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) { fmt.Printf("Index : %d\n", i.Index) @@ -686,7 +693,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, rootCmd.AddCommand(serviceCmd) startCmdAlias := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Use: "start", @@ -704,7 +710,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, rootCmd.AddCommand(startCmdAlias) stopCmdAlias := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Use: "stop", @@ -723,7 +728,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, restartCmdAlias := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Use: "restart", @@ -736,7 +740,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, reloadCmdAlias := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Use: "reload", @@ -751,16 +754,12 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, Use: "status", Short: "Show status of the ctrld service", Args: cobra.NoArgs, - PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() - }, - Run: statusCmd.Run, + Run: statusCmd.Run, } rootCmd.AddCommand(statusCmdAlias) uninstallCmdAlias := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Use: "uninstall", @@ -785,7 +784,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, Short: "List clients that ctrld discovered", Args: cobra.NoArgs, PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { @@ -873,25 +871,31 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, ValidArgs: []string{upgradeChannelDev, upgradeChannelProd}, Args: cobra.MaximumNArgs(1), PreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { - s, err := newService(&prog{}, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("service not installed") - return - } bin, err := os.Executable() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to get current ctrld binary path") } + sc := &service.Config{} + *sc = *svcConfig + sc.Executable = bin + readConfig(false) + v.Unmarshal(&cfg) + p := &prog{router: router.New(&cfg, runInCdMode())} + s, err := newService(p, sc) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + + svcInstalled := true + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + svcInstalled = false + } oldBin := bin + "_previous" - urlString := upgradeChannel[upgradeChannelDefault] + baseUrl := upgradeChannel[upgradeChannelDefault] if len(args) > 0 { channel := args[0] switch channel { @@ -899,12 +903,9 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, default: mainLog.Load().Fatal().Msgf("uprade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev) } - urlString = upgradeChannel[channel] - } - dlUrl := fmt.Sprintf("%s/%s-%s/ctrld", urlString, runtime.GOOS, runtime.GOARCH) - if runtime.GOOS == "windows" { - dlUrl += ".exe" + baseUrl = upgradeChannel[channel] } + dlUrl := upgradeUrl(baseUrl) mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl) resp, err := http.Get(dlUrl) if err != nil { @@ -923,18 +924,26 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, } doRestart := func() bool { + if !svcInstalled { + return true + } tasks := []task{ {s.Stop, false}, {s.Start, false}, } if doTasks(tasks) { if dir, err := socketDir(); err == nil { - return newSocketControlClient(s, dir) != nil + if cc := newSocketControlClient(s, dir); cc != nil { + _, _ = cc.post(ifacePath, nil) + return true + } } } return false } - mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") + if svcInstalled { + mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") + } if doRestart() { _ = os.Remove(oldBin) _ = os.Chmod(bin, 0755) @@ -1608,7 +1617,7 @@ func defaultIfaceName() string { // - External testing, ensuring query could be sent from ctrld -> upstream. // // Self-check is considered success only if both tests are ok. -func selfCheckStatus(s service.Service) (bool, service.Status, error) { +func selfCheckStatus(s service.Service, homedir, sockDir string) (bool, service.Status, error) { status, err := s.Status() if err != nil { mainLog.Load().Warn().Err(err).Msg("could not get service status") @@ -1618,117 +1627,80 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { if status != service.StatusRunning { return false, status, nil } - dir, err := socketDir() - if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to check ctrld listener status: could not get home directory") - return false, status, err + // Skip self checks if set. + if skipSelfChecks { + return true, status, nil } + mainLog.Load().Debug().Msg("waiting for ctrld listener to be ready") - cc := newSocketControlClient(s, dir) + cc := newSocketControlClient(s, sockDir) if cc == nil { return false, status, errors.New("could not connect to control server") } - resp, err := cc.post(startedPath, nil) - if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to connect to control server") + v = viper.NewWithOptions(viper.KeyDelimiter("::")) + ctrld.SetConfigNameWithPath(v, "ctrld", homedir) + if configPath != "" { + v.SetConfigFile(configPath) + } + if err := v.ReadInConfig(); err != nil { + mainLog.Load().Error().Err(err).Msgf("failed to re-read configuration file: %s", v.ConfigFileUsed()) return false, status, err } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - mainLog.Load().Error().Msg("ctrld listener is not ready") - return false, status, errors.New("ctrld listener is not ready") + + cfg = ctrld.Config{} + if err := v.Unmarshal(&cfg); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to update new config") + return false, status, err } - // Not a ctrld upstream, return status as-is. - if cfg.FirstUpstream().VerifyDomain() == "" { + selfCheckExternalDomain := cfg.FirstUpstream().VerifyDomain() + if selfCheckExternalDomain == "" { + // Nothing to do, return the status as-is. return true, status, nil } mainLog.Load().Debug().Msg("ctrld listener is ready") mainLog.Load().Debug().Msg("performing self-check") - bo := backoff.NewBackoff("self-check", logf, 10*time.Second) - bo.LogLongerThan = 500 * time.Millisecond - ctx := context.Background() - maxAttempts := 20 - c := new(dns.Client) - var ( - lcChanged map[string]*ctrld.ListenerConfig - ucChanged map[string]*ctrld.UpstreamConfig - mu sync.Mutex - ) - if err := v.ReadInConfig(); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to read new config") - } - if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to update new config") - } - domain := cfg.FirstUpstream().VerifyDomain() - if domain == "" { - // Nothing to do, return the status as-is. - return true, status, nil - } - watcher, err := fsnotify.NewWatcher() - if err != nil { - mainLog.Load().Error().Err(err).Msg("could not watch config change") + lc := cfg.FirstListener() + addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)) + if err := selfCheckResolveDomain(context.TODO(), addr, "internal", selfCheckInternalTestDomain); err != nil { return false, status, err } - defer watcher.Close() + if err := selfCheckResolveDomain(context.TODO(), addr, "external", selfCheckExternalDomain); err != nil { + return false, status, err + } + return true, status, nil +} + +// selfCheckResolveDomain performs DNS test query against ctrld listener. +func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain string) error { + bo := backoff.NewBackoff("self-check", logf, 10*time.Second) + bo.LogLongerThan = 500 * time.Millisecond + maxAttempts := 20 + c := new(dns.Client) - v.OnConfigChange(func(in fsnotify.Event) { - mu.Lock() - defer mu.Unlock() - if err := v.UnmarshalKey("listener", &lcChanged); err != nil { - mainLog.Load().Error().Msgf("failed to unmarshal listener config: %v", err) - return - } - if err := v.UnmarshalKey("upstream", &ucChanged); err != nil { - mainLog.Load().Error().Msgf("failed to unmarshal upstream config: %v", err) - return - } - }) - v.WatchConfig() var ( - lastAnswer *dns.Msg - lastErr error - internalTested bool + lastAnswer *dns.Msg + lastErr error ) - for i := 0; i < maxAttempts; i++ { - mu.Lock() - if lcChanged != nil { - cfg.Listener = lcChanged - } - if ucChanged != nil { - cfg.Upstream = ucChanged - } - mu.Unlock() - lc := cfg.FirstListener() - domain = cfg.FirstUpstream().VerifyDomain() - if !internalTested { - domain = selfCheckInternalTestDomain - } - if domain == "" { - continue - } + for i := 0; i < maxAttempts; i++ { + if domain == "" { + return errors.New("empty test domain") + } m := new(dns.Msg) m.SetQuestion(domain+".", dns.TypeA) m.RecursionDesired = true - r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))) + r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, addr) if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 { - internalTested = domain == selfCheckInternalTestDomain - if internalTested { - mainLog.Load().Debug().Msgf("internal self-check against %q succeeded", domain) - continue // internal domain test ok, continue with external test. - } else { - mainLog.Load().Debug().Msgf("external self-check against %q succeeded", domain) - } - return true, status, nil + mainLog.Load().Debug().Msgf("%s self-check against %q succeeded", scope, domain) + return nil } // Return early if this is a connection refused. if errConnectionRefused(exErr) { - return false, status, exErr + return exErr } lastAnswer = r lastErr = exErr @@ -1741,8 +1713,6 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint) } } - lc := cfg.FirstListener() - addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)) marker := strings.Repeat("=", 32) mainLog.Load().Debug().Msg(marker) mainLog.Load().Debug().Msgf("listener address : %s", addr) @@ -1753,9 +1723,8 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { for _, s := range strings.Split(lastAnswer.String(), "\n") { mainLog.Load().Debug().Msgf("%s", s) } - return false, status, errSelfCheckNoAnswer } - return false, status, lastErr + return errSelfCheckNoAnswer } func userHomeDir() (string, error) { @@ -2293,6 +2262,10 @@ func removeProvTokenFromArgs(sc *service.Config) { // newSocketControlClient returns new control client after control server was started. func newSocketControlClient(s service.Service, dir string) *controlClient { + // Return early if service is not running. + if status, err := s.Status(); err != nil || status != service.StatusRunning { + return nil + } bo := backoff.NewBackoff("self-check", logf, 10*time.Second) bo.LogLongerThan = 10 * time.Second ctx := context.Background() @@ -2302,28 +2275,21 @@ func newSocketControlClient(s service.Service, dir string) *controlClient { defer timeout.Stop() // The socket control server may not start yet, so attempt to ping - // it until we got a response. For each iteration, check ctrld status - // to make sure ctrld is still running. + // it until we got a response. for { - curStatus, err := s.Status() - if err != nil { - return nil - } - if curStatus != service.StatusRunning { - return nil - } - if _, err := cc.post("/", nil); err == nil { + _, err := cc.post(startedPath, nil) + if err == nil { // Server was started, stop pinging. break } // The socket control server is not ready yet, backoff for waiting it to be ready. bo.BackOff(ctx, err) + select { case <-timeout.C: return nil default: } - continue } return cc @@ -2496,11 +2462,6 @@ func powershell(cmd string) ([]byte, error) { // windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. func windowsHasLocalDnsServerRunning() bool { if runtime.GOOS == "windows" { - out, _ := powershell("Get-WindowsFeature -Name DNS") - if !bytes.Contains(bytes.ToLower(out), []byte("installed")) { - return false - } - _, err := powershell("Get-Process -Name DNS") return err == nil } @@ -2535,3 +2496,79 @@ func runInCdMode() bool { } return false } + +// goArm returns the GOARM value for the binary. +func goArm() string { + if runtime.GOARCH != "arm" { + return "" + } + if bi, ok := debug.ReadBuildInfo(); ok { + for _, setting := range bi.Settings { + if setting.Key == "GOARM" { + return setting.Value + } + } + } + // Use ARM v5 as a fallback, since it works on all others. + return "5" +} + +// upgradeUrl returns the url for downloading new ctrld binary. +func upgradeUrl(baseUrl string) string { + dlPath := fmt.Sprintf("%s-%s/ctrld", runtime.GOOS, runtime.GOARCH) + if armVersion := goArm(); armVersion != "" { + dlPath = fmt.Sprintf("%s-%sv%s/ctrld", runtime.GOOS, runtime.GOARCH, armVersion) + } + dlUrl := fmt.Sprintf("%s/%s", baseUrl, dlPath) + if runtime.GOOS == "windows" { + dlUrl += ".exe" + } + return dlUrl +} + +// runningIface returns the value of the iface variable used by ctrld process which is running. +func runningIface(s service.Service) string { + if sockDir, err := socketDir(); err == nil { + if cc := newSocketControlClient(s, sockDir); cc != nil { + resp, err := cc.post(ifacePath, nil) + if err != nil { + return "" + } + defer resp.Body.Close() + if buf, _ := io.ReadAll(resp.Body); len(buf) > 0 { + return string(buf) + } + } + } + return "" +} + +// resetDnsNoLog performs resetting DNS with logging disable. +func resetDnsNoLog(p *prog) { + lvl := zerolog.GlobalLevel() + zerolog.SetGlobalLevel(zerolog.Disabled) + p.resetDNS() + zerolog.SetGlobalLevel(lvl) +} + +// resetDnsTask returns a task which perform reset DNS operation. +func resetDnsTask(p *prog, s service.Service) task { + status, err := s.Status() + isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) + isCtrldRunning := status == service.StatusRunning + return task{func() error { + // Always reset DNS first, ensuring DNS setting is in a good state. + // resetDNS must use the "iface" value of current running ctrld + // process to reset what setDNS has done properly. + oldIface := iface + iface = "auto" + if isCtrldRunning { + iface = runningIface(s) + } + if isCtrldInstalled { + resetDnsNoLog(p) + } + iface = oldIface + return nil + }, false} +} diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 4d243bf..66a38a3 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -10,6 +10,8 @@ import ( "sort" "time" + "github.com/kardianos/service" + dto "github.com/prometheus/client_model/go" "github.com/Control-D-Inc/ctrld" @@ -22,6 +24,7 @@ const ( reloadPath = "/reload" deactivationPath = "/deactivation" cdPath = "/cd" + ifacePath = "/iface" ) type controlServer struct { @@ -179,6 +182,17 @@ func (p *prog) registerControlServerHandler() { } w.WriteHeader(http.StatusBadRequest) })) + p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + // p.setDNS is only called when running as a service + if !service.Interactive() { + <-p.csSetDnsDone + if p.csSetDnsOk { + w.Write([]byte(iface)) + return + } + } + w.WriteHeader(http.StatusBadRequest) + })) } func jsonResponse(next http.Handler) http.Handler { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index a5242c5..9f95812 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -210,12 +210,9 @@ func (p *prog) serveDNS(listenerNum string) error { addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) s, errCh := runDNSServer(addr, proto, handler) defer s.Shutdown() - select { - case err := <-errCh: - return err - case <-time.After(5 * time.Second): - p.started <- struct{}{} - } + + p.started <- struct{}{} + select { case <-p.stopCh: case <-ctx.Done(): @@ -752,20 +749,19 @@ func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-cha Handler: handler, } - waitLock := sync.Mutex{} - waitLock.Lock() - s.NotifyStartedFunc = waitLock.Unlock + startedCh := make(chan struct{}) + s.NotifyStartedFunc = func() { sync.OnceFunc(func() { close(startedCh) })() } errCh := make(chan error) go func() { defer close(errCh) if err := s.ListenAndServe(); err != nil { - waitLock.Unlock() + s.NotifyStartedFunc() mainLog.Load().Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr) errCh <- err } }() - waitLock.Lock() + <-startedCh return s, errCh } diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index 06a7e03..3504bc3 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -105,6 +105,10 @@ func (p *prog) checkDnsLoop() { for uid := range p.loop { msg := loopTestMsg(uid) uc := upstream[uid] + // Skipping upstream which is being marked as down. + if uc == nil { + continue + } resolver, err := ctrld.NewResolver(uc) if err != nil { mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 9c64aa0..279f5f2 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -35,6 +35,7 @@ var ( nextdns string cdUpstreamProto string deactivationPin int64 + skipSelfChecks bool mainLog atomic.Pointer[zerolog.Logger] consoleWriter zerolog.ConsoleWriter diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index d2f1dd2..6441e05 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -41,17 +41,12 @@ func setDNS(iface *net.Interface, nameservers []string) error { // Configuring the Dns server to forward queries to ctrld instead. if windowsHasLocalDnsServerRunning() { file := absHomeDir(forwardersFilename) - if data, _ := os.ReadFile(file); len(data) > 0 { - if err := removeDnsServerForwarders(strings.Split(string(data), ",")); err != nil { - mainLog.Load().Error().Err(err).Msg("could not remove current forwarders settings") - } else { - mainLog.Load().Debug().Msg("removed current forwarders settings.") - } - } + oldForwardersContent, _ := os.ReadFile(file) if err := os.WriteFile(file, []byte(strings.Join(nameservers, ",")), 0600); err != nil { mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings") } - if err := addDnsServerForwarders(nameservers); err != nil { + oldForwarders := strings.Split(string(oldForwardersContent), ",") + if err := addDnsServerForwarders(nameservers, oldForwarders); err != nil { mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings") } } @@ -213,14 +208,32 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) { return ns, nil } -// addDnsServerForwarders adds given nameservers to DNS server forwarders list. -func addDnsServerForwarders(nameservers []string) error { - for _, ns := range nameservers { - cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", ns) - if out, err := powershell(cmd); err != nil { - return fmt.Errorf("%w: %s", err, string(out)) +// addDnsServerForwarders adds given nameservers to DNS server forwarders list, +// and also removing old forwarders if provided. +func addDnsServerForwarders(nameservers, old []string) error { + newForwardersMap := make(map[string]struct{}) + newForwarders := make([]string, len(nameservers)) + for i := range nameservers { + newForwardersMap[nameservers[i]] = struct{}{} + newForwarders[i] = fmt.Sprintf("%q", nameservers[i]) + } + oldForwarders := old[:0] + for _, fwd := range old { + if _, ok := newForwardersMap[fwd]; !ok { + oldForwarders = append(oldForwarders, fwd) } } + // NOTE: It is important to add new forwarder before removing old one. + // Testing on Windows Server 2022 shows that removing forwarder1 + // then adding forwarder2 sometimes ends up adding both of them + // to the forwarders list. + cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ",")) + if len(oldForwarders) > 0 { + cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ",")) + } + if out, err := powershell(cmd); err != nil { + return fmt.Errorf("%w: %s", err, string(out)) + } return nil } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index b3f3abf..8e35575 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -69,6 +69,8 @@ type prog struct { reloadDoneCh chan struct{} logConn net.Conn cs *controlServer + csSetDnsDone chan struct{} + csSetDnsOk bool cfg *ctrld.Config localUpstreams []string @@ -194,9 +196,6 @@ func (p *prog) runWait() { } func (p *prog) preRun() { - if !service.Interactive() { - p.setDNS() - } if runtime.GOOS == "darwin" { p.onStopped = append(p.onStopped, func() { if !service.Interactive() { @@ -206,6 +205,15 @@ func (p *prog) preRun() { } } +func (p *prog) postRun() { + if !service.Interactive() { + p.resetDNS() + ns := ctrld.InitializeOsResolver() + mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) + p.setDNS() + } +} + func (p *prog) setupUpstream(cfg *ctrld.Config) { localUpstreams := make([]string, 0, len(cfg.Upstream)) ptrNameservers := make([]string, 0, len(cfg.Upstream)) @@ -249,6 +257,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { numListeners := len(p.cfg.Listener) if !reload { p.started = make(chan struct{}, numListeners) + if p.cs != nil { + p.csSetDnsDone = make(chan struct{}, 1) + p.registerControlServerHandler() + if err := p.cs.start(); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not start control server") + } + mainLog.Load().Debug().Msgf("control server started: %s", p.cs.addr) + } } p.onStartedDone = make(chan struct{}) p.loop = make(map[string]bool) @@ -381,12 +397,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if p.logConn != nil { _ = p.logConn.Close() } - if p.cs != nil { - p.registerControlServerHandler() - if err := p.cs.start(); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not start control server") - } - } + p.postRun() } wg.Wait() } @@ -430,17 +441,25 @@ func (p *prog) deAllocateIP() error { } func (p *prog) setDNS() { + setDnsOK := false + defer func() { + p.csSetDnsOk = setDnsOK + p.csSetDnsDone <- struct{}{} + close(p.csSetDnsDone) + }() + if cfg.Listener == nil { return } if iface == "" { return } + runningIface := iface // allIfaces tracks whether we should set DNS for all physical interfaces. allIfaces := false - if iface == "auto" { - iface = defaultIfaceName() - // If iface is "auto", it means user does not specify "--iface" flag. + if runningIface == "auto" { + runningIface = defaultIfaceName() + // If runningIface is "auto", it means user does not specify "--iface" flag. // In this case, ctrld has to set DNS for all physical interfaces, so // thing will still work when user switch from one to the other. allIfaces = requiredMultiNICsConfig() @@ -449,8 +468,8 @@ func (p *prog) setDNS() { if lc == nil { return } - logger := mainLog.Load().With().Str("iface", iface).Logger() - netIface, err := netInterface(iface) + logger := mainLog.Load().With().Str("iface", runningIface).Logger() + netIface, err := netInterface(runningIface) if err != nil { logger.Error().Err(err).Msg("could not get interface") return @@ -484,6 +503,7 @@ func (p *prog) setDNS() { logger.Error().Err(err).Msgf("could not set DNS for interface") return } + setDnsOK = true logger.Debug().Msg("setting DNS successfully") if shouldWatchResolvconf() { servers := make([]netip.Addr, len(nameservers)) @@ -503,14 +523,15 @@ func (p *prog) resetDNS() { if iface == "" { return } + runningIface := iface allIfaces := false - if iface == "auto" { - iface = defaultIfaceName() + if runningIface == "auto" { + runningIface = defaultIfaceName() // See corresponding comments in (*prog).setDNS function. allIfaces = requiredMultiNICsConfig() } - logger := mainLog.Load().With().Str("iface", iface).Logger() - netIface, err := netInterface(iface) + logger := mainLog.Load().With().Str("iface", runningIface).Logger() + netIface, err := netInterface(runningIface) if err != nil { logger.Error().Err(err).Msg("could not get interface") return diff --git a/cmd/cli/prog_linux.go b/cmd/cli/prog_linux.go index cdb3c0e..0af906d 100644 --- a/cmd/cli/prog_linux.go +++ b/cmd/cli/prog_linux.go @@ -1,7 +1,12 @@ package cli import ( + "bufio" + "bytes" + "io" "os" + "os/exec" + "strings" "github.com/kardianos/service" @@ -24,12 +29,34 @@ func setDependencies(svc *service.Config) { "After=network-online.target", "Wants=NetworkManager-wait-online.service", "After=NetworkManager-wait-online.service", - "Wants=systemd-networkd-wait-online.service", "Wants=nss-lookup.target", "After=nss-lookup.target", } + if out, _ := exec.Command("networkctl", "--no-pager").CombinedOutput(); len(out) > 0 { + if wantsSystemDNetworkdWaitOnline(bytes.NewReader(out)) { + svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service") + } + } } func setWorkingDirectory(svc *service.Config, dir string) { svc.WorkingDirectory = dir } + +// wantsSystemDNetworkdWaitOnline reports whether "systemd-networkd-wait-online" service +// is required to be added to ctrld dependencies services. +// The input reader r is the output of "networkctl --no-pager" command. +func wantsSystemDNetworkdWaitOnline(r io.Reader) bool { + scanner := bufio.NewScanner(r) + // Skip header + scanner.Scan() + configured := false + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) > 0 && fields[len(fields)-1] == "configured" { + configured = true + break + } + } + return configured +} diff --git a/cmd/cli/prog_linux_test.go b/cmd/cli/prog_linux_test.go new file mode 100644 index 0000000..ecc4cd5 --- /dev/null +++ b/cmd/cli/prog_linux_test.go @@ -0,0 +1,48 @@ +package cli + +import ( + "io" + "strings" + "testing" +) + +const ( + networkctlUnmanagedOutput = `IDX LINK TYPE OPERATIONAL SETUP + 1 lo loopback carrier unmanaged + 2 wlp0s20f3 wlan routable unmanaged + 3 tailscale0 none routable unmanaged + 4 br-9ac33145e060 bridge no-carrier unmanaged + 5 docker0 bridge no-carrier unmanaged + +5 links listed. +` + networkctlManagedOutput = `IDX LINK TYPE OPERATIONAL SETUP + 1 lo loopback carrier unmanaged + 2 wlp0s20f3 wlan routable configured + 3 tailscale0 none routable unmanaged + 4 br-9ac33145e060 bridge no-carrier unmanaged + 5 docker0 bridge no-carrier unmanaged + +5 links listed. +` +) + +func Test_wantsSystemDNetworkdWaitOnline(t *testing.T) { + tests := []struct { + name string + r io.Reader + required bool + }{ + {"unmanaged", strings.NewReader(networkctlUnmanagedOutput), false}, + {"managed", strings.NewReader(networkctlManagedOutput), true}, + {"empty", strings.NewReader(""), false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + if required := wantsSystemDNetworkdWaitOnline(tc.r); required != tc.required { + t.Errorf("wants %v got %v", tc.required, required) + } + }) + } +} diff --git a/cmd/cli/service.go b/cmd/cli/service.go index ef37796..1de206f 100644 --- a/cmd/cli/service.go +++ b/cmd/cli/service.go @@ -21,7 +21,7 @@ func newService(i service.Interface, c *service.Config) (service.Service, error) } switch { case router.IsOldOpenwrt(), router.IsNetGearOrbi(): - return &procd{&sysV{s}}, nil + return &procd{sysV: &sysV{s}, svcConfig: c}, nil case router.IsGLiNet(): return &sysV{s}, nil case s.Platform() == "unix-systemv": @@ -89,18 +89,24 @@ func (s *sysV) Status() (service.Status, error) { // like old GL.iNET Opal router. type procd struct { *sysV + svcConfig *service.Config } func (s *procd) Status() (service.Status, error) { if !s.installed() { return service.StatusUnknown, service.ErrNotInstalled } - exe, err := os.Executable() - if err != nil { - return service.StatusUnknown, nil + bin := s.svcConfig.Executable + if bin == "" { + exe, err := os.Executable() + if err != nil { + return service.StatusUnknown, nil + } + bin = exe } + // Looking for something like "/sbin/ctrld run ". - shellCmd := fmt.Sprintf("ps | grep -q %q", exe+" [r]un ") + shellCmd := fmt.Sprintf("ps | grep -q %q", bin+" [r]un ") if err := exec.Command("sh", "-c", shellCmd).Run(); err != nil { return service.StatusStopped, nil } diff --git a/docs/config.md b/docs/config.md index d9c1dae..5615d30 100644 --- a/docs/config.md +++ b/docs/config.md @@ -336,7 +336,7 @@ The protocol that `ctrld` will use to send DNS requests to upstream. - Type: string - Required: yes - - Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`, `os` + - Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy` ### ip_stack Specifying what kind of ip stack that `ctrld` will use to connect to upstream. diff --git a/docs/ctrldsplash.png b/docs/ctrldsplash.png new file mode 100644 index 0000000..9de7346 Binary files /dev/null and b/docs/ctrldsplash.png differ diff --git a/internal/router/openwrt/procd.go b/internal/router/openwrt/procd.go index 8e74461..bf7253e 100644 --- a/internal/router/openwrt/procd.go +++ b/internal/router/openwrt/procd.go @@ -18,6 +18,7 @@ start_service() { procd_set_param stdout 1 # forward stdout of the command to logd procd_set_param stderr 1 # same for stderr procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop + procd_set_param term_timeout 10 procd_close_instance echo "${name} has been started" } diff --git a/internal/router/router.go b/internal/router/router.go index 18b7a90..bf65e6e 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -98,6 +98,11 @@ func IsOldOpenwrt() bool { return cmd == "" } +// WaitProcessExited reports whether the "ctrld stop" command have to wait until ctrld process exited. +func WaitProcessExited() bool { + return Name() == openwrt.Name +} + var routerPlatform atomic.Pointer[router] type router struct { diff --git a/internal/router/service_merlin.go b/internal/router/service_merlin.go index 76ea938..8ab6d6a 100644 --- a/internal/router/service_merlin.go +++ b/internal/router/service_merlin.go @@ -49,11 +49,15 @@ func (s *merlinSvc) Platform() string { } func (s *merlinSvc) configPath() string { - path, err := os.Executable() - if err != nil { - return "" + bin := s.Config.Executable + if bin == "" { + path, err := os.Executable() + if err != nil { + return "" + } + bin = path } - return path + ".startup" + return bin + ".startup" } func (s *merlinSvc) template() *template.Template { diff --git a/nameservers_windows.go b/nameservers_windows.go index ea9b347..150f252 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -1,12 +1,9 @@ package ctrld import ( - "net" "syscall" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - - "golang.org/x/sys/windows" ) func dnsFns() []dnsFn { @@ -20,40 +17,23 @@ func dnsFromAdapter() []string { } ns := make([]string, 0, len(aas)*2) seen := make(map[string]bool) - do := func(addr windows.SocketAddress) { - sa, err := addr.Sockaddr.Sockaddr() - if err != nil { - return + addressMap := make(map[string]struct{}) + for _, aa := range aas { + for a := aa.FirstUnicastAddress; a != nil; a = a.Next { + addressMap[a.Address.IP().String()] = struct{}{} } - var ip net.IP - switch sa := sa.(type) { - case *syscall.SockaddrInet4: - ip = net.IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]) - case *syscall.SockaddrInet6: - ip = make(net.IP, net.IPv6len) - copy(ip, sa.Addr[:]) - if ip[0] == 0xfe && ip[1] == 0xc0 { - // Ignore these fec0/10 ones. Windows seems to - // populate them as defaults on its misc rando - // interfaces. - return - } - default: - return - - } - if ip.IsLoopback() || seen[ip.String()] { - return - } - seen[ip.String()] = true - ns = append(ns, ip.String()) } for _, aa := range aas { for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { - do(dns.Address) - } - for gw := aa.FirstGatewayAddress; gw != nil; gw = gw.Next { - do(gw.Address) + ip := dns.Address.IP() + if ip == nil || ip.IsLoopback() || seen[ip.String()] { + continue + } + if _, ok := addressMap[ip.String()]; ok { + continue + } + seen[ip.String()] = true + ns = append(ns, ip.String()) } } return ns diff --git a/resolver.go b/resolver.go index 0a4569e..49ac652 100644 --- a/resolver.go +++ b/resolver.go @@ -35,13 +35,26 @@ const bootstrapDNS = "76.76.2.22" // or is the Resolver used for ResolverTypeOS. var or = &osResolver{nameservers: defaultNameservers()} -// defaultNameservers returns OS nameservers plus ctrld bootstrap nameserver. +// defaultNameservers returns nameservers used by the OS. +// If no nameservers can be found, ctrld bootstrap nameserver will be used. func defaultNameservers() []string { ns := nameservers() - ns = append(ns, net.JoinHostPort(bootstrapDNS, "53")) + if len(ns) == 0 { + ns = append(ns, net.JoinHostPort(bootstrapDNS, "53")) + } return ns } +// InitializeOsResolver initializes OS resolver using the current system DNS settings. +// It returns the nameservers that is going to be used by the OS resolver. +// +// It's the caller's responsibility to ensure the system DNS is in a clean state before +// calling this function. +func InitializeOsResolver() []string { + or.nameservers = defaultNameservers() + return or.nameservers +} + // Resolver is the interface that wraps the basic DNS operations. // // Resolve resolves the DNS query, return the result and the corresponding error.