From 96085147ffb9f852207c69b493caf8f22cb4f495 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 21 May 2024 17:08:18 +0700 Subject: [PATCH] all: preserve DNS settings when running "ctrld restart" By attempting to reset DNS before starting new ctrld process. This way, ctrld will read the correct system DNS settings before changing itself. While at it, some optimizations are made: - "ctrld start" won't set DNS anymore, since "ctrld run" has already did this, start command could just query socket control server and emittin proper message to users. - The gateway won't be included as nameservers on Windows anymore, since the GetAdaptersAddresses Windows API always returns the correct DNS servers of the interfaces. - The nameservers list that OS resolver is using will be shown during ctrld startup, making it easier for debugging. --- cmd/cli/cli.go | 34 ++++++++++++++++++++++++++++++++-- cmd/cli/control_server.go | 14 ++++++++++++++ cmd/cli/prog.go | 12 ++++++++++++ nameservers_windows.go | 6 +----- resolver.go | 6 +++++- 5 files changed, 64 insertions(+), 8 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 45292b4..c0ef45c 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -377,7 +377,15 @@ func initCLI() { uninstall(p, s) os.Exit(1) } - p.setDNS() + if cc := newSocketControlClient(s, sockDir); cc != nil { + if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK { + if iface == "auto" { + iface = defaultIfaceName() + } + logger := mainLog.Load().With().Str("iface", iface).Logger() + logger.Debug().Msg("setting DNS successfully") + } + } } }, } @@ -482,7 +490,10 @@ func initCLI() { Short: "Restart the ctrld service", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { - s, err := newService(&prog{}, svcConfig) + readConfig(false) + v.Unmarshal(&cfg) + p := &prog{router: router.New(&cfg, runInCdMode())} + s, err := newService(p, svcConfig) if err != nil { mainLog.Load().Error().Msg(err.Error()) return @@ -493,8 +504,10 @@ func initCLI() { } initLogging() + iface = runningIface(s) tasks := []task{ {s.Stop, false}, + {func() error { p.resetDNS(); return nil }, false}, {s.Start, true}, } if doTasks(tasks) { @@ -2511,3 +2524,20 @@ func upgradeUrl(baseUrl string) string { } return dlUrl } + +// runningIface returns the value of the iface variable used by ctrld process which is running. +func runningIface(s service.Service) string { + if sockDir, err := socketDir(); err == nil { + if cc := newSocketControlClient(s, sockDir); cc != nil { + resp, err := cc.post(ifacePath, nil) + if err != nil { + return "" + } + defer resp.Body.Close() + if buf, _ := io.ReadAll(resp.Body); len(buf) > 0 { + return string(buf) + } + } + } + return "" +} diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 4d243bf..66a38a3 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -10,6 +10,8 @@ import ( "sort" "time" + "github.com/kardianos/service" + dto "github.com/prometheus/client_model/go" "github.com/Control-D-Inc/ctrld" @@ -22,6 +24,7 @@ const ( reloadPath = "/reload" deactivationPath = "/deactivation" cdPath = "/cd" + ifacePath = "/iface" ) type controlServer struct { @@ -179,6 +182,17 @@ func (p *prog) registerControlServerHandler() { } w.WriteHeader(http.StatusBadRequest) })) + p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + // p.setDNS is only called when running as a service + if !service.Interactive() { + <-p.csSetDnsDone + if p.csSetDnsOk { + w.Write([]byte(iface)) + return + } + } + w.WriteHeader(http.StatusBadRequest) + })) } func jsonResponse(next http.Handler) http.Handler { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 2e74a98..d2ea7a9 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -69,6 +69,8 @@ type prog struct { reloadDoneCh chan struct{} logConn net.Conn cs *controlServer + csSetDnsDone chan struct{} + csSetDnsOk bool cfg *ctrld.Config localUpstreams []string @@ -204,6 +206,7 @@ func (p *prog) preRun() { } func (p *prog) postRun() { + mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ctrld.OsNameservers) if !service.Interactive() { p.setDNS() } @@ -253,6 +256,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if !reload { p.started = make(chan struct{}, numListeners) if p.cs != nil { + p.csSetDnsDone = make(chan struct{}, 1) p.registerControlServerHandler() if err := p.cs.start(); err != nil { mainLog.Load().Warn().Err(err).Msg("could not start control server") @@ -435,6 +439,13 @@ func (p *prog) deAllocateIP() error { } func (p *prog) setDNS() { + setDnsOK := false + defer func() { + p.csSetDnsOk = setDnsOK + p.csSetDnsDone <- struct{}{} + close(p.csSetDnsDone) + }() + if cfg.Listener == nil { return } @@ -489,6 +500,7 @@ func (p *prog) setDNS() { logger.Error().Err(err).Msgf("could not set DNS for interface") return } + setDnsOK = true logger.Debug().Msg("setting DNS successfully") if shouldWatchResolvconf() { servers := make([]netip.Addr, len(nameservers)) diff --git a/nameservers_windows.go b/nameservers_windows.go index ea9b347..ded4658 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -4,9 +4,8 @@ import ( "net" "syscall" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) func dnsFns() []dnsFn { @@ -52,9 +51,6 @@ func dnsFromAdapter() []string { for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { do(dns.Address) } - for gw := aa.FirstGatewayAddress; gw != nil; gw = gw.Next { - do(gw.Address) - } } return ns } diff --git a/resolver.go b/resolver.go index 0a4569e..f9951d7 100644 --- a/resolver.go +++ b/resolver.go @@ -32,8 +32,12 @@ const ( const bootstrapDNS = "76.76.2.22" +// OsNameservers is the list of DNS nameservers used by OS resolver. +// This reads OS settings at the time ctrld process starts. +var OsNameservers = defaultNameservers() + // or is the Resolver used for ResolverTypeOS. -var or = &osResolver{nameservers: defaultNameservers()} +var or = &osResolver{nameservers: OsNameservers} // defaultNameservers returns OS nameservers plus ctrld bootstrap nameserver. func defaultNameservers() []string {