diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 45292b4..c0ef45c 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -377,7 +377,15 @@ func initCLI() { uninstall(p, s) os.Exit(1) } - p.setDNS() + if cc := newSocketControlClient(s, sockDir); cc != nil { + if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK { + if iface == "auto" { + iface = defaultIfaceName() + } + logger := mainLog.Load().With().Str("iface", iface).Logger() + logger.Debug().Msg("setting DNS successfully") + } + } } }, } @@ -482,7 +490,10 @@ func initCLI() { Short: "Restart the ctrld service", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { - s, err := newService(&prog{}, svcConfig) + readConfig(false) + v.Unmarshal(&cfg) + p := &prog{router: router.New(&cfg, runInCdMode())} + s, err := newService(p, svcConfig) if err != nil { mainLog.Load().Error().Msg(err.Error()) return @@ -493,8 +504,10 @@ func initCLI() { } initLogging() + iface = runningIface(s) tasks := []task{ {s.Stop, false}, + {func() error { p.resetDNS(); return nil }, false}, {s.Start, true}, } if doTasks(tasks) { @@ -2511,3 +2524,20 @@ func upgradeUrl(baseUrl string) string { } return dlUrl } + +// runningIface returns the value of the iface variable used by ctrld process which is running. +func runningIface(s service.Service) string { + if sockDir, err := socketDir(); err == nil { + if cc := newSocketControlClient(s, sockDir); cc != nil { + resp, err := cc.post(ifacePath, nil) + if err != nil { + return "" + } + defer resp.Body.Close() + if buf, _ := io.ReadAll(resp.Body); len(buf) > 0 { + return string(buf) + } + } + } + return "" +} diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 4d243bf..66a38a3 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -10,6 +10,8 @@ import ( "sort" "time" + "github.com/kardianos/service" + dto "github.com/prometheus/client_model/go" "github.com/Control-D-Inc/ctrld" @@ -22,6 +24,7 @@ const ( reloadPath = "/reload" deactivationPath = "/deactivation" cdPath = "/cd" + ifacePath = "/iface" ) type controlServer struct { @@ -179,6 +182,17 @@ func (p *prog) registerControlServerHandler() { } w.WriteHeader(http.StatusBadRequest) })) + p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + // p.setDNS is only called when running as a service + if !service.Interactive() { + <-p.csSetDnsDone + if p.csSetDnsOk { + w.Write([]byte(iface)) + return + } + } + w.WriteHeader(http.StatusBadRequest) + })) } func jsonResponse(next http.Handler) http.Handler { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 2e74a98..d2ea7a9 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -69,6 +69,8 @@ type prog struct { reloadDoneCh chan struct{} logConn net.Conn cs *controlServer + csSetDnsDone chan struct{} + csSetDnsOk bool cfg *ctrld.Config localUpstreams []string @@ -204,6 +206,7 @@ func (p *prog) preRun() { } func (p *prog) postRun() { + mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ctrld.OsNameservers) if !service.Interactive() { p.setDNS() } @@ -253,6 +256,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if !reload { p.started = make(chan struct{}, numListeners) if p.cs != nil { + p.csSetDnsDone = make(chan struct{}, 1) p.registerControlServerHandler() if err := p.cs.start(); err != nil { mainLog.Load().Warn().Err(err).Msg("could not start control server") @@ -435,6 +439,13 @@ func (p *prog) deAllocateIP() error { } func (p *prog) setDNS() { + setDnsOK := false + defer func() { + p.csSetDnsOk = setDnsOK + p.csSetDnsDone <- struct{}{} + close(p.csSetDnsDone) + }() + if cfg.Listener == nil { return } @@ -489,6 +500,7 @@ func (p *prog) setDNS() { logger.Error().Err(err).Msgf("could not set DNS for interface") return } + setDnsOK = true logger.Debug().Msg("setting DNS successfully") if shouldWatchResolvconf() { servers := make([]netip.Addr, len(nameservers)) diff --git a/nameservers_windows.go b/nameservers_windows.go index ea9b347..ded4658 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -4,9 +4,8 @@ import ( "net" "syscall" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) func dnsFns() []dnsFn { @@ -52,9 +51,6 @@ func dnsFromAdapter() []string { for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { do(dns.Address) } - for gw := aa.FirstGatewayAddress; gw != nil; gw = gw.Next { - do(gw.Address) - } } return ns } diff --git a/resolver.go b/resolver.go index 0a4569e..f9951d7 100644 --- a/resolver.go +++ b/resolver.go @@ -32,8 +32,12 @@ const ( const bootstrapDNS = "76.76.2.22" +// OsNameservers is the list of DNS nameservers used by OS resolver. +// This reads OS settings at the time ctrld process starts. +var OsNameservers = defaultNameservers() + // or is the Resolver used for ResolverTypeOS. -var or = &osResolver{nameservers: defaultNameservers()} +var or = &osResolver{nameservers: OsNameservers} // defaultNameservers returns OS nameservers plus ctrld bootstrap nameserver. func defaultNameservers() []string {