all: preserve DNS settings when running "ctrld restart"

By attempting to reset DNS before starting new ctrld process. This way,
ctrld will read the correct system DNS settings before changing itself.

While at it, some optimizations are made:

 - "ctrld start" won't set DNS anymore, since "ctrld run" has already did
   this, start command could just query socket control server and emittin
   proper message to users.

 - The gateway won't be included as nameservers on Windows anymore,
   since the GetAdaptersAddresses Windows API always returns the correct
   DNS servers of the interfaces.

 - The nameservers list that OS resolver is using will be shown during
   ctrld startup, making it easier for debugging.
This commit is contained in:
Cuong Manh Le
2024-05-21 17:08:18 +07:00
committed by Cuong Manh Le
parent f3dd344026
commit 96085147ff
5 changed files with 64 additions and 8 deletions

View File

@@ -377,7 +377,15 @@ func initCLI() {
uninstall(p, s)
os.Exit(1)
}
p.setDNS()
if cc := newSocketControlClient(s, sockDir); cc != nil {
if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK {
if iface == "auto" {
iface = defaultIfaceName()
}
logger := mainLog.Load().With().Str("iface", iface).Logger()
logger.Debug().Msg("setting DNS successfully")
}
}
}
},
}
@@ -482,7 +490,10 @@ func initCLI() {
Short: "Restart the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
s, err := newService(&prog{}, svcConfig)
readConfig(false)
v.Unmarshal(&cfg)
p := &prog{router: router.New(&cfg, runInCdMode())}
s, err := newService(p, svcConfig)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return
@@ -493,8 +504,10 @@ func initCLI() {
}
initLogging()
iface = runningIface(s)
tasks := []task{
{s.Stop, false},
{func() error { p.resetDNS(); return nil }, false},
{s.Start, true},
}
if doTasks(tasks) {
@@ -2511,3 +2524,20 @@ func upgradeUrl(baseUrl string) string {
}
return dlUrl
}
// runningIface returns the value of the iface variable used by ctrld process which is running.
func runningIface(s service.Service) string {
if sockDir, err := socketDir(); err == nil {
if cc := newSocketControlClient(s, sockDir); cc != nil {
resp, err := cc.post(ifacePath, nil)
if err != nil {
return ""
}
defer resp.Body.Close()
if buf, _ := io.ReadAll(resp.Body); len(buf) > 0 {
return string(buf)
}
}
}
return ""
}

View File

@@ -10,6 +10,8 @@ import (
"sort"
"time"
"github.com/kardianos/service"
dto "github.com/prometheus/client_model/go"
"github.com/Control-D-Inc/ctrld"
@@ -22,6 +24,7 @@ const (
reloadPath = "/reload"
deactivationPath = "/deactivation"
cdPath = "/cd"
ifacePath = "/iface"
)
type controlServer struct {
@@ -179,6 +182,17 @@ func (p *prog) registerControlServerHandler() {
}
w.WriteHeader(http.StatusBadRequest)
}))
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
// p.setDNS is only called when running as a service
if !service.Interactive() {
<-p.csSetDnsDone
if p.csSetDnsOk {
w.Write([]byte(iface))
return
}
}
w.WriteHeader(http.StatusBadRequest)
}))
}
func jsonResponse(next http.Handler) http.Handler {

View File

@@ -69,6 +69,8 @@ type prog struct {
reloadDoneCh chan struct{}
logConn net.Conn
cs *controlServer
csSetDnsDone chan struct{}
csSetDnsOk bool
cfg *ctrld.Config
localUpstreams []string
@@ -204,6 +206,7 @@ func (p *prog) preRun() {
}
func (p *prog) postRun() {
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ctrld.OsNameservers)
if !service.Interactive() {
p.setDNS()
}
@@ -253,6 +256,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
if !reload {
p.started = make(chan struct{}, numListeners)
if p.cs != nil {
p.csSetDnsDone = make(chan struct{}, 1)
p.registerControlServerHandler()
if err := p.cs.start(); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not start control server")
@@ -435,6 +439,13 @@ func (p *prog) deAllocateIP() error {
}
func (p *prog) setDNS() {
setDnsOK := false
defer func() {
p.csSetDnsOk = setDnsOK
p.csSetDnsDone <- struct{}{}
close(p.csSetDnsDone)
}()
if cfg.Listener == nil {
return
}
@@ -489,6 +500,7 @@ func (p *prog) setDNS() {
logger.Error().Err(err).Msgf("could not set DNS for interface")
return
}
setDnsOK = true
logger.Debug().Msg("setting DNS successfully")
if shouldWatchResolvconf() {
servers := make([]netip.Addr, len(nameservers))

View File

@@ -4,9 +4,8 @@ import (
"net"
"syscall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
func dnsFns() []dnsFn {
@@ -52,9 +51,6 @@ func dnsFromAdapter() []string {
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
do(dns.Address)
}
for gw := aa.FirstGatewayAddress; gw != nil; gw = gw.Next {
do(gw.Address)
}
}
return ns
}

View File

@@ -32,8 +32,12 @@ const (
const bootstrapDNS = "76.76.2.22"
// OsNameservers is the list of DNS nameservers used by OS resolver.
// This reads OS settings at the time ctrld process starts.
var OsNameservers = defaultNameservers()
// or is the Resolver used for ResolverTypeOS.
var or = &osResolver{nameservers: defaultNameservers()}
var or = &osResolver{nameservers: OsNameservers}
// defaultNameservers returns OS nameservers plus ctrld bootstrap nameserver.
func defaultNameservers() []string {