mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
cmd/cli: avoid accessing mainLog when possible
By adding a logger field to the "prog" struct, and using this field inside its methods instead of always accessing the global mainLog variable. This at least ensures more consistent usage of the logger during ctrld prog runtime, and also helps with refactoring the code more easily in the future (like replacing the logger library).
This commit is contained in:
committed by
Cuong Manh Le
parent
fc527dbdfb
commit
b9b9cfcade
@@ -211,6 +211,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
cfg: &cfg,
|
||||
appCallback: appCallback,
|
||||
}
|
||||
p.logger.Store(mainLog.Load())
|
||||
if homedir == "" {
|
||||
if dir, err := userHomeDir(); err == nil {
|
||||
homedir = dir
|
||||
@@ -228,11 +229,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
p.logConn = lc
|
||||
} else {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
mainLog.Load().Warn().Err(err).Msg("unable to create log ipc connection")
|
||||
p.Warn().Err(err).Msg("unable to create log ipc connection")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath)
|
||||
p.Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath)
|
||||
}
|
||||
notifyExitToLogServer := func() {
|
||||
if p.logConn != nil {
|
||||
@@ -241,7 +242,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
}
|
||||
|
||||
if daemon && runtime.GOOS == "windows" {
|
||||
mainLog.Load().Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.")
|
||||
p.Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.")
|
||||
}
|
||||
|
||||
if !daemon {
|
||||
@@ -250,10 +251,10 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
go func() {
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed create new service")
|
||||
p.Fatal().Err(err).Msg("failed create new service")
|
||||
}
|
||||
if err := s.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to start service")
|
||||
p.Error().Err(err).Msg("failed to start service")
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -261,7 +262,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
|
||||
if err := readBase64Config(configBase64); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to read base64 config")
|
||||
p.Fatal().Err(err).Msg("failed to read base64 config")
|
||||
}
|
||||
processNoConfigFlags(noConfigStart)
|
||||
|
||||
@@ -270,7 +271,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
p.mu.Lock()
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
notifyExitToLogServer()
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
p.Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
@@ -280,19 +281,19 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
// so it's able to log information in processCDFlags.
|
||||
p.initLogging(true)
|
||||
|
||||
mainLog.Load().Info().Msgf("starting ctrld %s", curVersion())
|
||||
mainLog.Load().Info().Msgf("os: %s", osVersion())
|
||||
p.Info().Msgf("starting ctrld %s", curVersion())
|
||||
p.Info().Msgf("os: %s", osVersion())
|
||||
|
||||
// Wait for network up.
|
||||
if !ctrldnet.Up() {
|
||||
notifyExitToLogServer()
|
||||
mainLog.Load().Fatal().Msg("network is not up yet")
|
||||
p.Fatal().Msg("network is not up yet")
|
||||
}
|
||||
|
||||
p.router = router.New(&cfg, cdUID != "")
|
||||
cs, err := newControlServer(filepath.Join(sockDir, ControlSocketName()))
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create control server")
|
||||
p.Warn().Err(err).Msg("could not create control server")
|
||||
}
|
||||
p.cs = cs
|
||||
|
||||
@@ -301,7 +302,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
// to set the current time, so this check must happen before processCDFlags.
|
||||
if err := p.router.PreRun(); err != nil {
|
||||
notifyExitToLogServer()
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to perform router pre-run check")
|
||||
p.Fatal().Err(err).Msg("failed to perform router pre-run check")
|
||||
}
|
||||
|
||||
oldLogPath := cfg.Service.LogPath
|
||||
@@ -316,7 +317,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
return
|
||||
}
|
||||
|
||||
cdLogger := mainLog.Load().With().Str("mode", "cd").Logger()
|
||||
cdLogger := p.logger.Load().With().Str("mode", "cd").Logger()
|
||||
// Performs self-uninstallation if the ControlD device does not exist.
|
||||
var uer *controld.ErrorResponse
|
||||
if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
@@ -340,9 +341,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
if updated {
|
||||
if err := writeConfigFile(&cfg); err != nil {
|
||||
notifyExitToLogServer()
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to write config file")
|
||||
p.Fatal().Err(err).Msg("failed to write config file")
|
||||
} else {
|
||||
mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile)
|
||||
p.Info().Msg("writing config file to: " + defaultConfigFile)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,10 +355,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
// Copy logs written so far to new log file if possible.
|
||||
if buf, err := os.ReadFile(oldLogPath); err == nil {
|
||||
if err := os.WriteFile(newLogPath, buf, os.FileMode(0o600)); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not copy old log file")
|
||||
p.Warn().Err(err).Msg("could not copy old log file")
|
||||
}
|
||||
}
|
||||
initLoggingWithBackup(false)
|
||||
p.logger.Store(mainLog.Load())
|
||||
}
|
||||
|
||||
if err := validateConfig(&cfg); err != nil {
|
||||
@@ -369,13 +371,13 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
if daemon {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to find the binary")
|
||||
p.Error().Err(err).Msg("failed to find the binary")
|
||||
notifyExitToLogServer()
|
||||
os.Exit(1)
|
||||
}
|
||||
curDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to get current working directory")
|
||||
p.Error().Err(err).Msg("failed to get current working directory")
|
||||
notifyExitToLogServer()
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -383,11 +385,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...)
|
||||
cmd.Dir = curDir
|
||||
if err := cmd.Start(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to start process as daemon")
|
||||
p.Error().Err(err).Msg("failed to start process as daemon")
|
||||
notifyExitToLogServer()
|
||||
os.Exit(1)
|
||||
}
|
||||
mainLog.Load().Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started")
|
||||
p.Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -395,7 +397,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
for _, lc := range p.cfg.Listener {
|
||||
if shouldAllocateLoopbackIP(lc.IP) {
|
||||
if err := allocateIP(lc.IP); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("could not allocate IP: %s", lc.IP)
|
||||
p.Error().Err(err).Msgf("could not allocate IP: %s", lc.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -406,7 +408,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
for _, lc := range p.cfg.Listener {
|
||||
if shouldAllocateLoopbackIP(lc.IP) {
|
||||
if err := deAllocateIP(lc.IP); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP)
|
||||
p.Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -417,15 +419,15 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
}
|
||||
if iface != "" {
|
||||
p.onStarted = append(p.onStarted, func() {
|
||||
mainLog.Load().Debug().Msg("router setup on start")
|
||||
p.Debug().Msg("router setup on start")
|
||||
if err := p.router.Setup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not configure router")
|
||||
p.Error().Err(err).Msg("could not configure router")
|
||||
}
|
||||
})
|
||||
p.onStopped = append(p.onStopped, func() {
|
||||
mainLog.Load().Debug().Msg("router cleanup on stop")
|
||||
p.Debug().Msg("router cleanup on stop")
|
||||
if err := p.router.Cleanup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not cleanup router")
|
||||
p.Error().Err(err).Msg("could not cleanup router")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -438,9 +440,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
file := ctrld.SavedStaticDnsSettingsFilePath(i)
|
||||
if _, err := os.Stat(file); err == nil {
|
||||
if err := restoreDNS(i); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name)
|
||||
p.Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name)
|
||||
p.Debug().Msgf("Restored static DNS on interface %s successfully", i.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -79,21 +79,21 @@ func (s *controlServer) register(pattern string, handler http.Handler) {
|
||||
|
||||
func (p *prog) registerControlServerHandler() {
|
||||
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
mainLog.Load().Debug().Msg("handling list clients request")
|
||||
p.Debug().Msg("handling list clients request")
|
||||
|
||||
clients := p.ciTable.ListClients()
|
||||
mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list")
|
||||
p.Debug().Int("client_count", len(clients)).Msg("retrieved clients list")
|
||||
|
||||
sort.Slice(clients, func(i, j int) bool {
|
||||
return clients[i].IP.Less(clients[j].IP)
|
||||
})
|
||||
mainLog.Load().Debug().Msg("sorted clients by IP address")
|
||||
p.Debug().Msg("sorted clients by IP address")
|
||||
|
||||
if p.metricsQueryStats.Load() {
|
||||
mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts")
|
||||
p.Debug().Msg("metrics query stats enabled, collecting query counts")
|
||||
|
||||
for idx, client := range clients {
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Int("index", idx).
|
||||
Str("ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
@@ -104,7 +104,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
dm := &dto.Metric{}
|
||||
|
||||
if statsClientQueriesCount.MetricVec == nil {
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("skipping metrics collection: MetricVec is nil")
|
||||
continue
|
||||
@@ -116,7 +116,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
client.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
@@ -127,23 +127,23 @@ func (p *prog) registerControlServerHandler() {
|
||||
|
||||
if err := m.Write(dm); err == nil && dm.Counter != nil {
|
||||
client.QueryCount = int64(dm.Counter.GetValue())
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Int64("query_count", client.QueryCount).
|
||||
Msg("successfully collected query count")
|
||||
} else if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("failed to write metric")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts")
|
||||
p.Debug().Msg("metrics query stats disabled, skipping query counts")
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&clients); err != nil {
|
||||
mainLog.Load().Error().
|
||||
p.Error().
|
||||
Err(err).
|
||||
Int("client_count", len(clients)).
|
||||
Msg("failed to encode clients response")
|
||||
@@ -151,7 +151,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
return
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Int("client_count", len(clients)).
|
||||
Msg("successfully sent clients list response")
|
||||
}))
|
||||
@@ -175,7 +175,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
oldSvc := p.cfg.Service
|
||||
p.mu.Unlock()
|
||||
if err := p.sendReloadSignal(); err != nil {
|
||||
mainLog.Load().Err(err).Msg("could not send reload signal")
|
||||
p.Error().Err(err).Msg("could not send reload signal")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -216,7 +216,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
return
|
||||
}
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
// Re-fetch pin code from API.
|
||||
if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev); rc != nil {
|
||||
if rc.DeactivationPin != nil {
|
||||
@@ -225,7 +225,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
cdDeactivationPin.Store(defaultDeactivationPin)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not re-fetch deactivation pin code")
|
||||
p.Warn().Err(err).Msg("could not re-fetch deactivation pin code")
|
||||
}
|
||||
|
||||
// If pin code not set, allowing deactivation.
|
||||
@@ -237,7 +237,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
var req deactivationRequest
|
||||
if err := json.NewDecoder(request.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusPreconditionFailed)
|
||||
mainLog.Load().Err(err).Msg("invalid deactivation request")
|
||||
p.Error().Err(err).Msg("invalid deactivation request")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -320,15 +320,15 @@ func (p *prog) registerControlServerHandler() {
|
||||
UID: cdUID,
|
||||
Data: r.r,
|
||||
}
|
||||
mainLog.Load().Debug().Msg("sending log file to ControlD server")
|
||||
p.Debug().Msg("sending log file to ControlD server")
|
||||
resp := logSentResponse{Size: r.size}
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil {
|
||||
mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err)
|
||||
p.Error().Msgf("could not send log file to ControlD server: %v", err)
|
||||
resp.Error = err.Error()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("sending log file successfully")
|
||||
p.Debug().Msg("sending log file successfully")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
|
||||
@@ -87,14 +87,14 @@ type upstreamForResult struct {
|
||||
func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
// Start network monitoring
|
||||
if err := p.monitorNetworkChanges(mainCtx); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring")
|
||||
p.Error().Err(err).Msg("Failed to start network monitoring")
|
||||
// Don't return here as we still want DNS service to run
|
||||
}
|
||||
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
// make sure ip is allocated
|
||||
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
|
||||
mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
|
||||
p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
|
||||
return allocErr
|
||||
}
|
||||
|
||||
@@ -110,9 +110,9 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
reqId := requestID()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
ctx = ctrld.LoggerCtx(ctx, mainLog.Load())
|
||||
ctx = ctrld.LoggerCtx(ctx, p.logger.Load())
|
||||
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
|
||||
ctrld.Log(ctx, p.Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeRefused)
|
||||
_ = w.WriteMsg(answer)
|
||||
@@ -135,7 +135,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
|
||||
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
|
||||
p.cache.Purge()
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain)
|
||||
ctrld.Log(ctx, p.Debug(), "received query %q, local cache is purged", domain)
|
||||
}
|
||||
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
|
||||
ci := p.getClientInfo(remoteIP, m)
|
||||
@@ -144,7 +144,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci)
|
||||
fmtSrcToDest := fmtRemoteToLocal(listenerNum, ci.Hostname, remoteAddr.String())
|
||||
t := time.Now()
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
|
||||
ctrld.Log(ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
|
||||
ur := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain)
|
||||
|
||||
labelValues := make([]string, 0, len(statsQueriesCountLabels))
|
||||
@@ -155,7 +155,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
|
||||
var answer *dns.Msg
|
||||
if !ur.matched && listenerConfig.Restricted {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String())
|
||||
ctrld.Log(ctx, p.Info(), "query refused, %s does not match any network policy", remoteAddr.String())
|
||||
answer = new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeRefused)
|
||||
labelValues = append(labelValues, "") // no upstream
|
||||
@@ -174,7 +174,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
|
||||
answer = pr.answer
|
||||
rtt := time.Since(t)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
||||
ctrld.Log(ctx, p.Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
||||
upstream := pr.upstream
|
||||
switch {
|
||||
case pr.cached:
|
||||
@@ -192,7 +192,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
p.forceFetchingAPI(domain)
|
||||
}()
|
||||
if err := w.WriteMsg(answer); err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client")
|
||||
ctrld.Log(ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -209,7 +209,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
case err := <-errCh:
|
||||
// Local ipv6 listener should not terminate ctrld.
|
||||
// It's a workaround for a quirk on Windows.
|
||||
mainLog.Load().Warn().Err(err).Msg("local ipv6 listener failed")
|
||||
p.Warn().Err(err).Msg("local ipv6 listener failed")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@@ -229,7 +229,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
case err := <-errCh:
|
||||
// RFC1918 listener should not terminate ctrld.
|
||||
// It's a workaround for a quirk on system with systemd-resolved.
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr)
|
||||
p.Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -371,8 +371,8 @@ func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg
|
||||
},
|
||||
Ptr: dns.Fqdn(name),
|
||||
}}
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table")
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
|
||||
ctrld.Log(ctx, p.Info(), "private PTR lookup, using client info table")
|
||||
ctrld.Log(ctx, p.Debug(), "client info: %v", ctrld.ClientInfo{
|
||||
Mac: p.ciTable.LookupMac(ip.String()),
|
||||
IP: ip.String(),
|
||||
Hostname: name,
|
||||
@@ -416,8 +416,8 @@ func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg
|
||||
AAAA: ip.AsSlice(),
|
||||
}}
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table")
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
|
||||
ctrld.Log(ctx, p.Info(), "lan hostname lookup, using client info table")
|
||||
ctrld.Log(ctx, p.Debug(), "client info: %v", ctrld.ClientInfo{
|
||||
Mac: p.ciTable.LookupMac(ip.String()),
|
||||
IP: ip.String(),
|
||||
Hostname: hostname,
|
||||
@@ -441,7 +441,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
// running and listening on local addresses, these local addresses must be used
|
||||
// as nameservers, so queries for ADDC could be resolved as expected.
|
||||
if p.isAdDomainQuery(req.msg) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(),
|
||||
ctrld.Log(ctx, p.Debug(),
|
||||
"AD domain query detected for %s in domain %s",
|
||||
req.msg.Question[0].Name, p.adDomain)
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig}
|
||||
@@ -459,14 +459,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
// 4. Try remote upstream.
|
||||
isLanOrPtrQuery := false
|
||||
if req.ufr.matched {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
ctrld.Log(ctx, p.Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
} else {
|
||||
switch {
|
||||
case isSrvLanLookup(req.msg):
|
||||
upstreams = []string{upstreamOS}
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams)
|
||||
ctrld.Log(ctx, p.Debug(), "SRV record lookup, using upstreams: %v", upstreams)
|
||||
case isPrivatePtrLookup(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil {
|
||||
@@ -476,7 +476,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs)
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams)
|
||||
ctrld.Log(ctx, p.Debug(), "private PTR lookup, using upstreams: %v", upstreams)
|
||||
case isLanHostnameQuery(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil {
|
||||
@@ -487,9 +487,9 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
upstreams = []string{upstreamOS}
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams)
|
||||
ctrld.Log(ctx, p.Debug(), "lan hostname lookup, using upstreams: %v", upstreams)
|
||||
default:
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams)
|
||||
ctrld.Log(ctx, p.Debug(), "no explicit policy matched, using default routing -> %v", upstreams)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -504,7 +504,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
ctrld.SetCacheReply(answer, req.msg, answer.Rcode)
|
||||
now := time.Now()
|
||||
if cachedValue.Expire.After(now) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response")
|
||||
ctrld.Log(ctx, p.Debug(), "hit cached response")
|
||||
setCachedAnswerTTL(answer, now, cachedValue.Expire)
|
||||
res.answer = answer
|
||||
res.cached = true
|
||||
@@ -514,10 +514,10 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
}
|
||||
}
|
||||
resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name)
|
||||
ctrld.Log(ctx, p.Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name)
|
||||
dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig)
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver")
|
||||
ctrld.Log(ctx, p.Error().Err(err), "failed to create resolver")
|
||||
return nil, err
|
||||
}
|
||||
resolveCtx, cancel := upstreamConfig.Context(ctx)
|
||||
@@ -526,7 +526,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
}
|
||||
resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request")
|
||||
ctrld.Log(ctx, p.Debug(), "including client info with the request")
|
||||
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
|
||||
}
|
||||
answer, err := resolve1(upstream, upstreamConfig, msg)
|
||||
@@ -540,7 +540,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
return answer
|
||||
}
|
||||
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
|
||||
ctrld.Log(ctx, p.Error().Err(err), "failed to resolve query")
|
||||
|
||||
// increase failure count when there is no answer
|
||||
// rehardless of what kind of error we get
|
||||
@@ -564,7 +564,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
if upstreamConfig == nil {
|
||||
continue
|
||||
}
|
||||
logger := mainLog.Load().Debug().
|
||||
logger := p.Debug().
|
||||
Str("upstream", upstreamConfig.String()).
|
||||
Str("query", req.msg.Question[0].Name).
|
||||
Bool("is_ad_query", p.isAdDomainQuery(req.msg)).
|
||||
@@ -577,7 +577,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
answer := resolve(upstreams[n], upstreamConfig, req.msg)
|
||||
if answer == nil {
|
||||
if serveStaleCache && staleAnswer != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response")
|
||||
ctrld.Log(ctx, p.Debug(), "serving stale cached response")
|
||||
now := time.Now()
|
||||
setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL))
|
||||
res.answer = staleAnswer
|
||||
@@ -589,11 +589,11 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
// We are doing LAN/PTR lookup using private resolver, so always process next one.
|
||||
// Except for the last, we want to send response instead of saying all upstream failed.
|
||||
if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n])
|
||||
ctrld.Log(ctx, p.Debug(), "no response from %s, process to next upstream", upstreams[n])
|
||||
continue
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream")
|
||||
ctrld.Log(ctx, p.Debug(), "failover rcode matched, process to next upstream")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -609,18 +609,18 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
}
|
||||
setCachedAnswerTTL(answer, now, expired)
|
||||
p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired))
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response")
|
||||
ctrld.Log(ctx, p.Debug(), "add cached response")
|
||||
}
|
||||
hostname := ""
|
||||
if req.ci != nil {
|
||||
hostname = req.ci.Hostname
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "REPLY: %s -> %s (%s): %s", upstreams[n], req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode])
|
||||
ctrld.Log(ctx, p.Info(), "REPLY: %s -> %s (%s): %s", upstreams[n], req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode])
|
||||
res.answer = answer
|
||||
res.upstream = upstreamConfig.Endpoint
|
||||
return res
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
|
||||
ctrld.Log(ctx, p.Error(), "all %v endpoints failed", upstreams)
|
||||
|
||||
// if we have no healthy upstreams, trigger recovery flow
|
||||
if p.leakOnUpstreamFailure() {
|
||||
@@ -633,28 +633,28 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
} else {
|
||||
reason = RecoveryReasonRegularFailure
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
|
||||
p.Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
|
||||
go p.handleRecovery(reason)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
|
||||
p.Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
|
||||
}
|
||||
p.recoveryCancelMu.Unlock()
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
|
||||
p.Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
|
||||
}
|
||||
|
||||
// attempt query to OS resolver while as a retry catch all
|
||||
// we dont want this to happen if leakOnUpstreamFailure is false
|
||||
if upstreams[0] != upstreamOS {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all")
|
||||
ctrld.Log(ctx, p.Debug(), "attempting query to OS resolver as a retry catch all")
|
||||
answer := resolve(upstreamOS, osUpstreamConfig, req.msg)
|
||||
if answer != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful")
|
||||
ctrld.Log(ctx, p.Debug(), "OS resolver retry query successful")
|
||||
res.answer = answer
|
||||
res.upstream = osUpstreamConfig.Endpoint
|
||||
return res
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed")
|
||||
ctrld.Log(ctx, p.Debug(), "OS resolver retry query failed")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -958,10 +958,10 @@ func (p *prog) doSelfUninstall(answer *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger()
|
||||
logger := p.logger.Load().With().Str("mode", "self-uninstall").Logger()
|
||||
if p.refusedQueryCount > selfUninstallMaxQueries {
|
||||
p.checkingSelfUninstall = true
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
_, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev)
|
||||
logger.Debug().Msg("maximum number of refused queries reached, checking device status")
|
||||
selfUninstallCheck(err, p, logger)
|
||||
@@ -1031,7 +1031,7 @@ func (p *prog) queryFromSelf(ip string) bool {
|
||||
netIP := netip.MustParseAddr(ip)
|
||||
regularIPs, loopbackIPs, err := netmon.LocalAddresses()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not get local addresses")
|
||||
p.Warn().Err(err).Msg("could not get local addresses")
|
||||
return false
|
||||
}
|
||||
for _, localIP := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
@@ -1151,7 +1151,8 @@ func isWanClient(na net.Addr) bool {
|
||||
|
||||
// resolveInternalDomainTestQuery resolves internal test domain query, returning the answer to the caller.
|
||||
func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.Msg) *dns.Msg {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "internal domain test query")
|
||||
logger := ctrld.LoggerFromCtx(ctx)
|
||||
ctrld.Log(ctx, logger.Debug(), "internal domain test query")
|
||||
|
||||
q := m.Question[0]
|
||||
answer := new(dns.Msg)
|
||||
@@ -1192,7 +1193,7 @@ func FlushDNSCache() error {
|
||||
func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
mon, err := netmon.New(func(format string, args ...any) {
|
||||
// Always fetch the latest logger (and inject the prefix)
|
||||
mainLog.Load().Printf("netmon: "+format, args...)
|
||||
p.logger.Load().Printf("netmon: "+format, args...)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating network monitor: %w", err)
|
||||
@@ -1204,7 +1205,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
|
||||
isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New)
|
||||
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Interface("old_state", delta.Old).
|
||||
Interface("new_state", delta.New).
|
||||
Bool("is_major_change", isMajorChange).
|
||||
@@ -1232,7 +1233,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
||||
changed = true
|
||||
changeIPs = usableNewIPs
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Str("interface", ifaceName).
|
||||
Interface("new_ips", usableNewIPs).
|
||||
Msg("Interface newly appeared (was not present in old state)")
|
||||
@@ -1254,7 +1255,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
||||
changed = true
|
||||
changeIPs = usableNewIPs
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Str("interface", ifaceName).
|
||||
Interface("old_ips", oldIPs).
|
||||
Interface("new_ips", usableNewIPs).
|
||||
@@ -1267,39 +1268,39 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
// if the default route changed, set changed to true
|
||||
if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface {
|
||||
changed = true
|
||||
mainLog.Load().Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface)
|
||||
p.Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface)
|
||||
}
|
||||
|
||||
if !changed {
|
||||
mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected")
|
||||
p.Debug().Msg("Ignoring interface change - no valid interfaces affected")
|
||||
// check if the default IPs are still on an interface that is up
|
||||
ValidateDefaultLocalIPsFromDelta(delta.New)
|
||||
return
|
||||
}
|
||||
|
||||
if !activeInterfaceExists {
|
||||
mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization")
|
||||
p.Debug().Msg("No active interfaces found, skipping reinitialization")
|
||||
return
|
||||
}
|
||||
|
||||
// Get IPs from default route interface in new state
|
||||
selfIP := defaultRouteIP()
|
||||
selfIP := p.defaultRouteIP()
|
||||
|
||||
// Ensure that selfIP is an IPv4 address.
|
||||
// If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it
|
||||
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil {
|
||||
mainLog.Load().Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP)
|
||||
p.Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP)
|
||||
selfIP = ""
|
||||
}
|
||||
var ipv6 string
|
||||
|
||||
if delta.New.DefaultRouteInterface != "" {
|
||||
mainLog.Load().Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface])
|
||||
p.Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface])
|
||||
for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] {
|
||||
ipAddr, _ := netip.ParsePrefix(ip.String())
|
||||
addr := ipAddr.Addr()
|
||||
if selfIP == "" && addr.Is4() {
|
||||
mainLog.Load().Debug().Msgf("checking IP: %s", addr.String())
|
||||
p.Debug().Msgf("checking IP: %s", addr.String())
|
||||
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
||||
selfIP = addr.String()
|
||||
}
|
||||
@@ -1310,12 +1311,12 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
}
|
||||
} else {
|
||||
// If no default route interface is set yet, use the changed IPs
|
||||
mainLog.Load().Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs)
|
||||
p.Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs)
|
||||
for _, ip := range changeIPs {
|
||||
ipAddr, _ := netip.ParsePrefix(ip.String())
|
||||
addr := ipAddr.Addr()
|
||||
if selfIP == "" && addr.Is4() {
|
||||
mainLog.Load().Debug().Msgf("checking IP: %s", addr.String())
|
||||
p.Debug().Msgf("checking IP: %s", addr.String())
|
||||
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
||||
selfIP = addr.String()
|
||||
}
|
||||
@@ -1328,15 +1329,15 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
|
||||
// Only set the IPv4 default if selfIP is a valid IPv4 address.
|
||||
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil {
|
||||
ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, mainLog.Load()), ip)
|
||||
ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, p.logger.Load()), ip)
|
||||
if !isMobile() && p.ciTable != nil {
|
||||
p.ciTable.SetSelfIP(selfIP)
|
||||
}
|
||||
}
|
||||
if ip := net.ParseIP(ipv6); ip != nil {
|
||||
ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, mainLog.Load()), ip)
|
||||
ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, p.logger.Load()), ip)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
||||
p.Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
||||
|
||||
// we only trigger recovery flow for network changes on non router devices
|
||||
if router.Name() == "" {
|
||||
@@ -1345,7 +1346,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
})
|
||||
|
||||
mon.Start()
|
||||
mainLog.Load().Debug().Msg("Network monitor started")
|
||||
p.Debug().Msg("Network monitor started")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1400,11 +1401,11 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool {
|
||||
// checkUpstreamOnce sends a test query to the specified upstream.
|
||||
// Returns nil if the upstream responds successfully.
|
||||
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
|
||||
mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream)
|
||||
p.Debug().Msgf("Starting check for upstream: %s", upstream)
|
||||
|
||||
resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), uc)
|
||||
resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
|
||||
p.Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1415,22 +1416,22 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro
|
||||
if uc.Timeout > 0 {
|
||||
timeout = time.Millisecond * time.Duration(uc.Timeout)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout)
|
||||
p.Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load()))
|
||||
mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)
|
||||
uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load()))
|
||||
p.Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)
|
||||
|
||||
start := time.Now()
|
||||
_, err = resolver.Resolve(ctx, msg)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration)
|
||||
p.Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration)
|
||||
p.Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -1440,13 +1441,13 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro
|
||||
// upstream failure recoveries, waiting for recovery to complete (using a cancellable context without timeout),
|
||||
// and then re-applying the DNS settings.
|
||||
func (p *prog) handleRecovery(reason RecoveryReason) {
|
||||
mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings")
|
||||
p.Debug().Msg("Starting recovery process: removing DNS settings")
|
||||
|
||||
// For network changes, cancel any existing recovery check because the network state has changed.
|
||||
if reason == RecoveryReasonNetworkChange {
|
||||
p.recoveryCancelMu.Lock()
|
||||
if p.recoveryCancel != nil {
|
||||
mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)")
|
||||
p.Debug().Msg("Cancelling existing recovery check (network change)")
|
||||
p.recoveryCancel()
|
||||
p.recoveryCancel = nil
|
||||
}
|
||||
@@ -1455,7 +1456,7 @@ func (p *prog) handleRecovery(reason RecoveryReason) {
|
||||
// For upstream failures, if a recovery is already in progress, do nothing new.
|
||||
p.recoveryCancelMu.Lock()
|
||||
if p.recoveryCancel != nil {
|
||||
mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
|
||||
p.Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
|
||||
p.recoveryCancelMu.Unlock()
|
||||
return
|
||||
}
|
||||
@@ -1476,15 +1477,15 @@ func (p *prog) handleRecovery(reason RecoveryReason) {
|
||||
// will be appended to nameservers from the saved interface values
|
||||
p.resetDNS(false, false)
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
// For an OS failure, reinitialize OS resolver nameservers immediately.
|
||||
if reason == RecoveryReasonOSFailure {
|
||||
mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers")
|
||||
p.Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers")
|
||||
ns := ctrld.InitializeOsResolver(loggerCtx, true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
|
||||
p.Warn().Msg("No nameservers found for OS resolver; using existing values")
|
||||
} else {
|
||||
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1494,13 +1495,13 @@ func (p *prog) handleRecovery(reason RecoveryReason) {
|
||||
// Wait indefinitely until one of the upstreams recovers.
|
||||
recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed")
|
||||
p.Error().Err(err).Msg("Recovery canceled; DNS settings remain removed")
|
||||
p.recoveryCancelMu.Lock()
|
||||
p.recoveryCancel = nil
|
||||
p.recoveryCancelMu.Unlock()
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered)
|
||||
p.Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered)
|
||||
|
||||
// reset the upstream failure count and down state
|
||||
p.um.reset(recovered)
|
||||
@@ -1509,9 +1510,9 @@ func (p *prog) handleRecovery(reason RecoveryReason) {
|
||||
if reason == RecoveryReasonNetworkChange {
|
||||
ns := ctrld.InitializeOsResolver(loggerCtx, true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values")
|
||||
p.Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values")
|
||||
} else {
|
||||
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1534,44 +1535,44 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string
|
||||
recoveredCh := make(chan string, 1)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams))
|
||||
p.Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams))
|
||||
|
||||
for name, uc := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(name string, uc *ctrld.UpstreamConfig) {
|
||||
defer wg.Done()
|
||||
mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name)
|
||||
p.Debug().Msgf("Starting recovery check loop for upstream: %s", name)
|
||||
attempts := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name)
|
||||
p.Debug().Msgf("Context canceled for upstream %s", name)
|
||||
return
|
||||
default:
|
||||
attempts++
|
||||
// checkUpstreamOnce will reset any failure counters on success.
|
||||
if err := p.checkUpstreamOnce(name, uc); err == nil {
|
||||
mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name)
|
||||
p.Debug().Msgf("Upstream %s recovered successfully", name)
|
||||
select {
|
||||
case recoveredCh <- name:
|
||||
mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name)
|
||||
p.Debug().Msgf("Sent recovery notification for upstream %s", name)
|
||||
default:
|
||||
mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered")
|
||||
p.Debug().Msg("Recovery channel full, another upstream already recovered")
|
||||
}
|
||||
return
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
|
||||
p.Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
|
||||
// if this is the upstreamOS and it's the 3rd attempt (or multiple of 3),
|
||||
// we should try to reinit the OS resolver to ensure we can recover
|
||||
if name == upstreamOS && attempts%3 == 0 {
|
||||
mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts)
|
||||
ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, mainLog.Load()), true)
|
||||
p.Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts)
|
||||
ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, p.logger.Load()), true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
|
||||
p.Warn().Msg("No nameservers found for OS resolver; using existing values")
|
||||
} else {
|
||||
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +77,8 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
||||
p := &prog{cfg: cfg}
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
p.logger.Store(mainLog.Load())
|
||||
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
for _, nc := range p.cfg.Network {
|
||||
@@ -145,6 +146,7 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
func TestCache(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
prog.logger.Store(mainLog.Load())
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
|
||||
@@ -100,6 +100,7 @@ func (p *prog) initLogging(backup bool) {
|
||||
|
||||
// Initializing internal logging after global logging.
|
||||
p.initInternalLogging(logWriters)
|
||||
p.logger.Store(mainLog.Load())
|
||||
}
|
||||
|
||||
// initInternalLogging performs internal logging if there's no log enabled.
|
||||
@@ -108,7 +109,7 @@ func (p *prog) initInternalLogging(writers []io.Writer) {
|
||||
return
|
||||
}
|
||||
p.initInternalLogWriterOnce.Do(func() {
|
||||
mainLog.Load().Notice().Msg("internal logging enabled")
|
||||
p.Notice().Msg("internal logging enabled")
|
||||
p.internalLogWriter = newLogWriter()
|
||||
p.internalLogSent = time.Now().Add(-logWriterSentInterval)
|
||||
p.internalWarnLogWriter = newSmallLogWriter()
|
||||
|
||||
@@ -84,7 +84,7 @@ func (p *prog) detectLoop(msg *dns.Msg) {
|
||||
//
|
||||
// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
|
||||
func (p *prog) checkDnsLoop() {
|
||||
mainLog.Load().Debug().Msg("start checking DNS loop")
|
||||
p.Debug().Msg("start checking DNS loop")
|
||||
upstream := make(map[string]*ctrld.UpstreamConfig)
|
||||
p.loopMu.Lock()
|
||||
for n, uc := range p.cfg.Upstream {
|
||||
@@ -93,7 +93,7 @@ func (p *prog) checkDnsLoop() {
|
||||
}
|
||||
// Do not send test query to external upstream.
|
||||
if !canBeLocalUpstream(uc.Domain) {
|
||||
mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n)
|
||||
p.Debug().Msgf("skipping external: upstream.%s", n)
|
||||
continue
|
||||
}
|
||||
uid := uc.UID()
|
||||
@@ -102,7 +102,7 @@ func (p *prog) checkDnsLoop() {
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
for uid := range p.loop {
|
||||
msg := loopTestMsg(uid)
|
||||
uc := upstream[uid]
|
||||
@@ -112,14 +112,14 @@ func (p *prog) checkDnsLoop() {
|
||||
}
|
||||
resolver, err := ctrld.NewResolver(loggerCtx, uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
p.Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
continue
|
||||
}
|
||||
if _, err := resolver.Resolve(context.Background(), msg); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
p.Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msg("end checking DNS loop")
|
||||
p.Debug().Msg("end checking DNS loop")
|
||||
}
|
||||
|
||||
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
|
||||
|
||||
@@ -14,7 +14,7 @@ func (p *prog) watchLinkState(ctx context.Context) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
if err := netlink.LinkSubscribe(ch, done); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not subscribe link")
|
||||
p.Warn().Err(err).Msg("could not subscribe link")
|
||||
return
|
||||
}
|
||||
for {
|
||||
@@ -26,9 +26,9 @@ func (p *prog) watchLinkState(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
p.Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load()))
|
||||
uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,61 +28,61 @@ func hasNetworkManager() bool {
|
||||
return exe != ""
|
||||
}
|
||||
|
||||
func setupNetworkManager() error {
|
||||
func (p *prog) setupNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent {
|
||||
mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
p.Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
return nil
|
||||
}
|
||||
err := os.WriteFile(networkManagerCtrldConfFile, []byte(nmCtrldConfContent), os.FileMode(0644))
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
p.Debug().Msg("NetworkManager is not available")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not write NetworkManager ctrld config file")
|
||||
p.Debug().Err(err).Msg("could not write NetworkManager ctrld config file")
|
||||
return err
|
||||
}
|
||||
|
||||
reloadNetworkManager()
|
||||
mainLog.Load().Debug().Msg("setup NetworkManager done")
|
||||
p.reloadNetworkManager()
|
||||
p.Debug().Msg("setup NetworkManager done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreNetworkManager() error {
|
||||
func (p *prog) restoreNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
err := os.Remove(networkManagerCtrldConfFile)
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
p.Debug().Msg("NetworkManager is not available")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not remove NetworkManager ctrld config file")
|
||||
p.Debug().Err(err).Msg("could not remove NetworkManager ctrld config file")
|
||||
return err
|
||||
}
|
||||
|
||||
reloadNetworkManager()
|
||||
mainLog.Load().Debug().Msg("restore NetworkManager done")
|
||||
p.reloadNetworkManager()
|
||||
p.Debug().Msg("restore NetworkManager done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func reloadNetworkManager() {
|
||||
func (p *prog) reloadNetworkManager() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
conn, err := dbus.NewSystemConnectionContext(ctx)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not create new system connection")
|
||||
p.Error().Err(err).Msg("could not create new system connection")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
waitCh := make(chan string)
|
||||
if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager")
|
||||
p.Debug().Err(err).Msg("could not reload NetworkManager")
|
||||
return
|
||||
}
|
||||
<-waitCh
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
package cli
|
||||
|
||||
func setupNetworkManager() error {
|
||||
reloadNetworkManager()
|
||||
func (p *prog) setupNetworkManager() error {
|
||||
p.reloadNetworkManager()
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreNetworkManager() error {
|
||||
reloadNetworkManager()
|
||||
func (p *prog) restoreNetworkManager() error {
|
||||
p.reloadNetworkManager()
|
||||
return nil
|
||||
}
|
||||
|
||||
func reloadNetworkManager() {}
|
||||
func (p *prog) reloadNetworkManager() {}
|
||||
|
||||
177
cmd/cli/prog.go
177
cmd/cli/prog.go
@@ -102,6 +102,7 @@ type prog struct {
|
||||
apiForceReloadGroup singleflight.Group
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
logger atomic.Pointer[ctrld.Logger]
|
||||
csSetDnsDone chan struct{}
|
||||
csSetDnsOk bool
|
||||
dnsWg sync.WaitGroup
|
||||
@@ -150,7 +151,7 @@ type prog struct {
|
||||
onStopped []func()
|
||||
}
|
||||
|
||||
func (p *prog) Start(s service.Service) error {
|
||||
func (p *prog) Start(_ service.Service) error {
|
||||
go p.runWait()
|
||||
return nil
|
||||
}
|
||||
@@ -164,7 +165,6 @@ func (p *prog) runWait() {
|
||||
notifyReloadSigCh(reloadSigCh)
|
||||
|
||||
reload := false
|
||||
logger := mainLog.Load()
|
||||
for {
|
||||
reloadCh := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
@@ -177,9 +177,9 @@ func (p *prog) runWait() {
|
||||
var newCfg *ctrld.Config
|
||||
select {
|
||||
case sig := <-reloadSigCh:
|
||||
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
|
||||
p.Notice().Msgf("got signal: %s, reloading...", sig.String())
|
||||
case <-p.reloadCh:
|
||||
logger.Notice().Msg("reloading...")
|
||||
p.Notice().Msg("reloading...")
|
||||
case apiCfg := <-p.apiReloadCh:
|
||||
newCfg = apiCfg
|
||||
case <-p.stopCh:
|
||||
@@ -202,18 +202,18 @@ func (p *prog) runWait() {
|
||||
}
|
||||
v.SetConfigFile(confFile)
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
logger.Err(err).Msg("could not read new config")
|
||||
p.Error().Err(err).Msg("could not read new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if err := v.Unmarshal(&newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not unmarshal new config")
|
||||
p.Error().Err(err).Msg("could not unmarshal new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if cdUID != "" {
|
||||
if rc, err := processCDFlags(newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not fetch ControlD config")
|
||||
p.Error().Err(err).Msg("could not fetch ControlD config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
} else {
|
||||
@@ -243,25 +243,25 @@ func (p *prog) runWait() {
|
||||
}
|
||||
}
|
||||
if err := validateConfig(newCfg); err != nil {
|
||||
logger.Err(err).Msg("invalid config")
|
||||
p.Error().Err(err).Msg("invalid config")
|
||||
continue
|
||||
}
|
||||
|
||||
addExtraSplitDnsRule(newCfg)
|
||||
if err := writeConfigFile(newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not write new config")
|
||||
p.Error().Err(err).Msg("could not write new config")
|
||||
}
|
||||
|
||||
// This needs to be done here, otherwise, the DNS handler may observe an invalid
|
||||
// upstream config because its initialization function have not been called yet.
|
||||
mainLog.Load().Debug().Msg("setup upstream with new config")
|
||||
p.Debug().Msg("setup upstream with new config")
|
||||
p.setupUpstream(newCfg)
|
||||
|
||||
p.mu.Lock()
|
||||
*p.cfg = *newCfg
|
||||
p.mu.Unlock()
|
||||
|
||||
logger.Notice().Msg("reloading config successfully")
|
||||
p.Notice().Msg("reloading config successfully")
|
||||
|
||||
select {
|
||||
case p.reloadDoneCh <- struct{}{}:
|
||||
@@ -276,6 +276,7 @@ func (p *prog) preRun() {
|
||||
p.requiredMultiNICsConfig = requiredMultiNICsConfig()
|
||||
}
|
||||
p.runningIface = iface
|
||||
p.logger.Store(mainLog.Load())
|
||||
}
|
||||
|
||||
func (p *prog) postRun() {
|
||||
@@ -283,11 +284,11 @@ func (p *prog) postRun() {
|
||||
if runtime.GOOS == "windows" {
|
||||
isDC, roleInt := isRunningOnDomainController()
|
||||
p.runningOnDomainController = isDC
|
||||
mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt)
|
||||
p.Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt)
|
||||
}
|
||||
p.resetDNS(false, false)
|
||||
ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), false)
|
||||
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
|
||||
ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), false)
|
||||
p.Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
|
||||
p.setDNS()
|
||||
p.csSetDnsDone <- struct{}{}
|
||||
close(p.csSetDnsDone)
|
||||
@@ -304,14 +305,14 @@ func (p *prog) apiConfigReload() {
|
||||
ticker := time.NewTicker(timeDurationOrDefault(p.cfg.Service.RefetchTime, 3600) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
logger := mainLog.Load().With().Str("mode", "api-reload").Logger()
|
||||
logger := p.logger.Load().With().Str("mode", "api-reload").Logger()
|
||||
logger.Debug().Msg("starting custom config reload timer")
|
||||
lastUpdated := time.Now().Unix()
|
||||
curVerStr := curVersion()
|
||||
curVer, err := semver.NewVersion(curVerStr)
|
||||
isStable := curVer != nil && curVer.Prerelease() == ""
|
||||
if err != nil || !isStable {
|
||||
l := mainLog.Load().Warn()
|
||||
l := p.Warn()
|
||||
if err != nil {
|
||||
l = l.Err(err)
|
||||
}
|
||||
@@ -319,7 +320,7 @@ func (p *prog) apiConfigReload() {
|
||||
}
|
||||
|
||||
doReloadApiConfig := func(forced bool, logger zerolog.Logger) {
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
resolverConfig, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev)
|
||||
selfUninstallCheck(err, p, logger)
|
||||
if err != nil {
|
||||
@@ -405,20 +406,20 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
||||
localUpstreams := make([]string, 0, len(cfg.Upstream))
|
||||
ptrNameservers := make([]string, 0, len(cfg.Upstream))
|
||||
isControlDUpstream := false
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
for n := range cfg.Upstream {
|
||||
uc := cfg.Upstream[n]
|
||||
sdns := uc.Type == ctrld.ResolverTypeSDNS
|
||||
uc.Init(loggerCtx)
|
||||
if sdns {
|
||||
mainLog.Load().Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type)
|
||||
p.Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type)
|
||||
}
|
||||
isControlDUpstream = isControlDUpstream || uc.IsControlD()
|
||||
if uc.BootstrapIP == "" {
|
||||
uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), mainLog.Load()))
|
||||
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
||||
uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), p.logger.Load()))
|
||||
p.Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
||||
} else {
|
||||
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
|
||||
p.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
|
||||
}
|
||||
uc.SetCertPool(rootCertPool)
|
||||
go uc.Ping(loggerCtx)
|
||||
@@ -459,9 +460,9 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
p.csSetDnsDone = make(chan struct{}, 1)
|
||||
p.registerControlServerHandler()
|
||||
if err := p.cs.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
||||
p.Warn().Err(err).Msg("could not start control server")
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("control server started: %s", p.cs.addr)
|
||||
p.Debug().Msgf("control server started: %s", p.cs.addr)
|
||||
}
|
||||
}
|
||||
p.onStartedDone = make(chan struct{})
|
||||
@@ -473,7 +474,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
if p.cfg.Service.CacheEnable {
|
||||
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled")
|
||||
p.Error().Err(err).Msg("failed to create cacher, caching is disabled")
|
||||
} else {
|
||||
p.cache = cacher
|
||||
p.cacheFlushDomainsMap = make(map[string]struct{}, 256)
|
||||
@@ -483,7 +484,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
}
|
||||
}
|
||||
if domain, err := getActiveDirectoryDomain(); err == nil && domain != "" && hasLocalDnsServerRunning() {
|
||||
mainLog.Load().Debug().Msgf("active directory domain: %s", domain)
|
||||
p.Debug().Msgf("active directory domain: %s", domain)
|
||||
p.adDomain = domain
|
||||
}
|
||||
|
||||
@@ -494,14 +495,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr")
|
||||
p.Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr")
|
||||
continue
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
p.um = newUpstreamMonitor(p.cfg, p.logger.Load())
|
||||
|
||||
if !reload {
|
||||
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
|
||||
@@ -514,7 +515,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
}
|
||||
}
|
||||
p.setupUpstream(p.cfg)
|
||||
p.setupClientInfoDiscover(defaultRouteIP())
|
||||
p.setupClientInfoDiscover()
|
||||
}
|
||||
|
||||
// context for managing spawn goroutines.
|
||||
@@ -538,14 +539,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
upstreamConfig := p.cfg.Upstream[listenerNum]
|
||||
if upstreamConfig == nil {
|
||||
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
|
||||
p.Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
|
||||
}
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
|
||||
p.Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
|
||||
if err := p.serveDNS(ctx, listenerNum); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
||||
p.Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr)
|
||||
p.Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr)
|
||||
}(listenerNum)
|
||||
}
|
||||
go func() {
|
||||
@@ -602,10 +603,11 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
}
|
||||
|
||||
// setupClientInfoDiscover performs necessary works for running client info discover.
|
||||
func (p *prog) setupClientInfoDiscover(selfIP string) {
|
||||
p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, mainLog.Load())
|
||||
func (p *prog) setupClientInfoDiscover() {
|
||||
selfIP := p.defaultRouteIP()
|
||||
p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, p.logger.Load())
|
||||
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
||||
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
||||
p.Debug().Msgf("watching custom lease file: %s", leaseFile)
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
p.ciTable.AddLeaseFile(leaseFile, format)
|
||||
}
|
||||
@@ -622,18 +624,18 @@ func (p *prog) metricsEnabled() bool {
|
||||
return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != ""
|
||||
}
|
||||
|
||||
func (p *prog) Stop(s service.Service) error {
|
||||
func (p *prog) Stop(_ service.Service) error {
|
||||
p.stopDnsWatchers()
|
||||
mainLog.Load().Debug().Msg("dns watchers stopped")
|
||||
p.Debug().Msg("dns watchers stopped")
|
||||
for _, f := range p.onStopped {
|
||||
f()
|
||||
}
|
||||
mainLog.Load().Debug().Msg("finish running onStopped functions")
|
||||
p.Debug().Msg("finish running onStopped functions")
|
||||
defer func() {
|
||||
mainLog.Load().Info().Msg("Service stopped")
|
||||
p.Info().Msg("Service stopped")
|
||||
}()
|
||||
if err := p.deAllocateIP(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("de-allocate ip failed")
|
||||
p.Error().Err(err).Msg("de-allocate ip failed")
|
||||
return err
|
||||
}
|
||||
if deactivationPinSet() {
|
||||
@@ -645,16 +647,16 @@ func (p *prog) Stop(s service.Service) error {
|
||||
// No valid pin code was checked, that mean we are stopping
|
||||
// because of OS signal sent directly from someone else.
|
||||
// In this case, restarting ctrld service by ourselves.
|
||||
mainLog.Load().Debug().Msgf("receiving stopping signal without valid pin code")
|
||||
mainLog.Load().Debug().Msgf("self restarting ctrld service")
|
||||
p.Debug().Msgf("receiving stopping signal without valid pin code")
|
||||
p.Debug().Msgf("self restarting ctrld service")
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
cmd := exec.Command(exe, "restart")
|
||||
cmd.SysProcAttr = sysProcAttrForDetachedChildProcess()
|
||||
if err := cmd.Start(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to run self restart command")
|
||||
p.Error().Err(err).Msg("failed to run self restart command")
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to self restart ctrld service")
|
||||
p.Error().Err(err).Msg("failed to self restart ctrld service")
|
||||
}
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
@@ -755,7 +757,7 @@ func (p *prog) setDNS() {
|
||||
p.dnsWg.Add(1)
|
||||
go func() {
|
||||
defer p.dnsWg.Done()
|
||||
p.watchResolvConf(netIface, servers, setResolvConf)
|
||||
p.watchResolvConf(netIface, servers, p.setResolvConf)
|
||||
}()
|
||||
}
|
||||
if p.dnsWatchdogEnabled() {
|
||||
@@ -772,7 +774,7 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In
|
||||
return
|
||||
}
|
||||
|
||||
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
|
||||
logger := p.logger.Load().With().Str("iface", p.runningIface).Logger()
|
||||
|
||||
const maxDNSRetryAttempts = 3
|
||||
const retryDelay = 1 * time.Second
|
||||
@@ -785,10 +787,10 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In
|
||||
}
|
||||
if attempt < maxDNSRetryAttempts {
|
||||
// Try to find a different working interface
|
||||
newIface := findWorkingInterface(p.runningIface)
|
||||
newIface := p.findWorkingInterface()
|
||||
if newIface != p.runningIface {
|
||||
p.runningIface = newIface
|
||||
logger = mainLog.Load().With().Str("iface", p.runningIface).Logger()
|
||||
logger = p.logger.Load().With().Str("iface", p.runningIface).Logger()
|
||||
logger.Info().Msg("switched to new interface")
|
||||
continue
|
||||
}
|
||||
@@ -800,7 +802,7 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In
|
||||
logger.Error().Err(err).Msg("could not get interface after all attempts")
|
||||
return
|
||||
}
|
||||
if err := setupNetworkManager(); err != nil {
|
||||
if err := p.setupNetworkManager(); err != nil {
|
||||
logger.Error().Err(err).Msg("could not patch NetworkManager")
|
||||
return
|
||||
}
|
||||
@@ -840,7 +842,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
|
||||
return
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Msg("start DNS settings watchdog")
|
||||
p.Debug().Msg("start DNS settings watchdog")
|
||||
|
||||
ns := nameservers
|
||||
slices.Sort(ns)
|
||||
@@ -851,19 +853,19 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-p.stopCh:
|
||||
mainLog.Load().Debug().Msg("stop dns watchdog")
|
||||
p.Debug().Msg("stop dns watchdog")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if p.recoveryRunning.Load() {
|
||||
return
|
||||
}
|
||||
if dnsChanged(iface, ns) {
|
||||
mainLog.Load().Debug().Msg("DNS settings were changed, re-applying settings")
|
||||
if p.dnsChanged(iface, ns) {
|
||||
p.Debug().Msg("DNS settings were changed, re-applying settings")
|
||||
// Check if the interface already has static DNS servers configured.
|
||||
// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
|
||||
staticDNS, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name)
|
||||
p.Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name)
|
||||
} else if len(staticDNS) > 0 {
|
||||
//filter out loopback addresses
|
||||
staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
|
||||
@@ -873,12 +875,12 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
|
||||
if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(iface)) == 0 {
|
||||
// Save these static DNS values so that they can be restored later.
|
||||
if err := saveCurrentStaticDNS(iface); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name)
|
||||
p.Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := setDNS(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
|
||||
p.Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
|
||||
}
|
||||
}
|
||||
if p.requiredMultiNICsConfig {
|
||||
@@ -887,13 +889,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
|
||||
ifaceName = iface.Name
|
||||
}
|
||||
withEachPhysicalInterfaces(ifaceName, "", func(i *net.Interface) error {
|
||||
if dnsChanged(i, ns) {
|
||||
if p.dnsChanged(i, ns) {
|
||||
|
||||
// Check if the interface already has static DNS servers configured.
|
||||
// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
|
||||
staticDNS, err := currentStaticDNS(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name)
|
||||
p.Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name)
|
||||
} else if len(staticDNS) > 0 {
|
||||
//filter out loopback addresses
|
||||
staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
|
||||
@@ -903,15 +905,15 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
|
||||
if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(i)) == 0 {
|
||||
// Save these static DNS values so that they can be restored later.
|
||||
if err := saveCurrentStaticDNS(i); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name)
|
||||
p.Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
|
||||
p.Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
|
||||
p.Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -941,17 +943,17 @@ func (p *prog) resetDNS(isStart bool, restoreStatic bool) {
|
||||
// Otherwise, we restore the saved configuration (if any) or reset to DHCP.
|
||||
func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runningIface *net.Interface) {
|
||||
if p.runningIface == "" {
|
||||
mainLog.Load().Debug().Msg("no running interface, skipping resetDNS")
|
||||
p.Debug().Msg("no running interface, skipping resetDNS")
|
||||
return
|
||||
}
|
||||
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
|
||||
logger := p.logger.Load().With().Str("iface", p.runningIface).Logger()
|
||||
netIface, err := netInterface(p.runningIface)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("could not get interface")
|
||||
return
|
||||
}
|
||||
runningIface = netIface
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
if err := p.restoreNetworkManager(); err != nil {
|
||||
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return
|
||||
}
|
||||
@@ -999,16 +1001,16 @@ func (p *prog) logInterfacesState() {
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
addrs, err := i.Addrs()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses")
|
||||
p.Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses")
|
||||
}
|
||||
nss, err := currentStaticDNS(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS")
|
||||
p.Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS")
|
||||
}
|
||||
if len(nss) == 0 {
|
||||
nss = currentDNS(i)
|
||||
}
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Any("addrs", addrs).
|
||||
Strs("nameservers", nss).
|
||||
Int("index", i.Index).
|
||||
@@ -1018,7 +1020,8 @@ func (p *prog) logInterfacesState() {
|
||||
}
|
||||
|
||||
// findWorkingInterface looks for a network interface with a valid IP configuration
|
||||
func findWorkingInterface(currentIface string) string {
|
||||
func (p *prog) findWorkingInterface() string {
|
||||
currentIface := p.runningIface
|
||||
// Helper to check if IP is valid (not link-local)
|
||||
isValidIP := func(ip net.IP) bool {
|
||||
return ip != nil &&
|
||||
@@ -1036,7 +1039,7 @@ func findWorkingInterface(currentIface string) string {
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Str("interface", iface.Name).
|
||||
Err(err).
|
||||
Msg("failed to get interface addresses")
|
||||
@@ -1057,11 +1060,11 @@ func findWorkingInterface(currentIface string) string {
|
||||
// Get default route interface
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Err(err).
|
||||
Msg("failed to get default route")
|
||||
} else {
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Str("default_route_iface", defaultRoute.InterfaceName).
|
||||
Msg("found default route")
|
||||
}
|
||||
@@ -1069,7 +1072,7 @@ func findWorkingInterface(currentIface string) string {
|
||||
// Get all interfaces
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to list network interfaces")
|
||||
p.Error().Err(err).Msg("failed to list network interfaces")
|
||||
return currentIface // Return current interface as fallback
|
||||
}
|
||||
|
||||
@@ -1099,7 +1102,7 @@ func findWorkingInterface(currentIface string) string {
|
||||
// Found working physical interface
|
||||
if err == nil && defaultRoute.InterfaceName == iface.Name {
|
||||
// Found interface with default route - use it immediately
|
||||
mainLog.Load().Info().
|
||||
p.Info().
|
||||
Str("old_iface", currentIface).
|
||||
Str("new_iface", iface.Name).
|
||||
Msg("switching to interface with default route")
|
||||
@@ -1120,7 +1123,7 @@ func findWorkingInterface(currentIface string) string {
|
||||
// Return interfaces in order of preference:
|
||||
// 1. Current interface if it's still valid
|
||||
if currentIfaceValid {
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Str("interface", currentIface).
|
||||
Msg("keeping current interface")
|
||||
return currentIface
|
||||
@@ -1128,7 +1131,7 @@ func findWorkingInterface(currentIface string) string {
|
||||
|
||||
// 2. First working interface found
|
||||
if firstWorkingIface != "" {
|
||||
mainLog.Load().Info().
|
||||
p.Info().
|
||||
Str("old_iface", currentIface).
|
||||
Str("new_iface", firstWorkingIface).
|
||||
Msg("switching to first working physical interface")
|
||||
@@ -1136,7 +1139,7 @@ func findWorkingInterface(currentIface string) string {
|
||||
}
|
||||
|
||||
// 3. Fall back to current interface if nothing else works
|
||||
mainLog.Load().Warn().
|
||||
p.Warn().
|
||||
Str("current_iface", currentIface).
|
||||
Msg("no working physical interface found, keeping current")
|
||||
return currentIface
|
||||
@@ -1258,7 +1261,7 @@ func ifaceFirstPrivateIP(iface *net.Interface) string {
|
||||
}
|
||||
|
||||
// defaultRouteIP returns private IP string of the default route if present, prefer IPv4 over IPv6.
|
||||
func defaultRouteIP() string {
|
||||
func (p *prog) defaultRouteIP() string {
|
||||
dr, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return ""
|
||||
@@ -1267,9 +1270,9 @@ func defaultRouteIP() string {
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
mainLog.Load().Debug().Str("iface", drNetIface.Name).Msg("checking default route interface")
|
||||
p.Debug().Str("iface", drNetIface.Name).Msg("checking default route interface")
|
||||
if ip := ifaceFirstPrivateIP(drNetIface); ip != "" {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("found ip with default route interface")
|
||||
p.Debug().Str("ip", ip).Msg("found ip with default route interface")
|
||||
return ip
|
||||
}
|
||||
|
||||
@@ -1294,7 +1297,7 @@ func defaultRouteIP() string {
|
||||
})
|
||||
|
||||
if len(addrs) == 0 {
|
||||
mainLog.Load().Warn().Msg("no default route IP found")
|
||||
p.Warn().Msg("no default route IP found")
|
||||
return ""
|
||||
}
|
||||
sort.Slice(addrs, func(i, j int) bool {
|
||||
@@ -1302,7 +1305,7 @@ func defaultRouteIP() string {
|
||||
})
|
||||
|
||||
ip := addrs[0].String()
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP")
|
||||
p.Debug().Str("ip", ip).Msg("found LAN interface IP")
|
||||
return ip
|
||||
}
|
||||
|
||||
@@ -1413,14 +1416,14 @@ func saveCurrentStaticDNS(iface *net.Interface) error {
|
||||
// It returns false for a nil iface.
|
||||
//
|
||||
// The caller must sort the nameservers before calling this function.
|
||||
func dnsChanged(iface *net.Interface, nameservers []string) bool {
|
||||
func (p *prog) dnsChanged(iface *net.Interface, nameservers []string) bool {
|
||||
if iface == nil {
|
||||
return false
|
||||
}
|
||||
curNameservers, _ := currentStaticDNS(iface)
|
||||
slices.Sort(curNameservers)
|
||||
if !slices.Equal(curNameservers, nameservers) {
|
||||
mainLog.Load().Debug().Msgf("interface %q current DNS settings: %v, expected: %v", iface.Name, curNameservers, nameservers)
|
||||
p.Debug().Msgf("interface %q current DNS settings: %v, expected: %v", iface.Name, curNameservers, nameservers)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -1465,16 +1468,16 @@ func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) {
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to get executable path, skipped self-upgrade")
|
||||
logger.Error().Err(err).Msg("failed to get executable path, skipped self-upgrade")
|
||||
return
|
||||
}
|
||||
cmd := exec.Command(exe, "upgrade", "prod", "-vv")
|
||||
cmd.SysProcAttr = sysProcAttrForDetachedChildProcess()
|
||||
if err := cmd.Start(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade")
|
||||
logger.Error().Err(err).Msg("failed to start self-upgrade")
|
||||
return
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vts)
|
||||
logger.Debug().Msgf("self-upgrade triggered, version target: %s", vts)
|
||||
}
|
||||
|
||||
// leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow
|
||||
|
||||
33
cmd/cli/prog_log.go
Normal file
33
cmd/cli/prog_log.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cli
|
||||
|
||||
import "github.com/rs/zerolog"
|
||||
|
||||
// Debug starts a new message with debug level.
|
||||
func (p *prog) Debug() *zerolog.Event {
|
||||
return p.logger.Load().Debug()
|
||||
}
|
||||
|
||||
// Warn starts a new message with warn level.
|
||||
func (p *prog) Warn() *zerolog.Event {
|
||||
return p.logger.Load().Warn()
|
||||
}
|
||||
|
||||
// Info starts a new message with info level.
|
||||
func (p *prog) Info() *zerolog.Event {
|
||||
return p.logger.Load().Info()
|
||||
}
|
||||
|
||||
// Fatal starts a new message with fatal level.
|
||||
func (p *prog) Fatal() *zerolog.Event {
|
||||
return p.logger.Load().Fatal()
|
||||
}
|
||||
|
||||
// Error starts a new message with error level.
|
||||
func (p *prog) Error() *zerolog.Event {
|
||||
return p.logger.Load().Error()
|
||||
}
|
||||
|
||||
// Notice starts a new message with notice level.
|
||||
func (p *prog) Notice() *zerolog.Event {
|
||||
return p.logger.Load().Notice()
|
||||
}
|
||||
@@ -43,10 +43,10 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
|
||||
resolvConfPath = rp
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath)
|
||||
p.Debug().Msgf("start watching %s file", resolvConfPath)
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
||||
p.Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
@@ -55,7 +55,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
||||
watchDir := filepath.Dir(resolvConfPath)
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir)
|
||||
p.Warn().Err(err).Msgf("could not add %s to watcher list", watchDir)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-p.stopCh:
|
||||
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
|
||||
p.Debug().Msgf("stopping watcher for %s", resolvConfPath)
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if p.recoveryRunning.Load() {
|
||||
@@ -77,7 +77,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
p.Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
|
||||
// Convert expected nameservers to strings for comparison
|
||||
expectedNS := make([]string, len(ns))
|
||||
@@ -92,7 +92,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
for retry := 0; retry < maxRetries; retry++ {
|
||||
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content")
|
||||
p.Error().Err(err).Msg("failed to read resolv.conf content")
|
||||
break
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
|
||||
// Only retry if we found no nameservers
|
||||
if retry < maxRetries-1 {
|
||||
mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
p.Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
@@ -113,7 +113,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries")
|
||||
p.Debug().Msg("resolv.conf remained empty after all retries")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
}
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
p.Debug().
|
||||
Strs("found", foundNS).
|
||||
Strs("expected", expectedNS).
|
||||
Bool("matches", matches).
|
||||
@@ -139,16 +139,16 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
// Only revert if the nameservers don't match
|
||||
if !matches {
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
p.Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
p.Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
p.Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -158,7 +158,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf")
|
||||
p.Error().Err(err).Msg("could not get event for /etc/resolv.conf")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
servers := make([]string, len(ns))
|
||||
for i := range ns {
|
||||
servers[i] = ns[i].String()
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of the resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -27,7 +27,7 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
oc.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file")
|
||||
p.Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file")
|
||||
}
|
||||
return r.SetDNS(oc)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(_ *net.Interface, _ []netip.Addr) error {
|
||||
func (p *prog) setResolvConf(_ *net.Interface, _ []netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package cli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
@@ -16,7 +17,8 @@ const (
|
||||
|
||||
// upstreamMonitor performs monitoring upstreams health.
|
||||
type upstreamMonitor struct {
|
||||
cfg *ctrld.Config
|
||||
cfg *ctrld.Config
|
||||
logger atomic.Pointer[ctrld.Logger]
|
||||
|
||||
mu sync.RWMutex
|
||||
checking map[string]bool
|
||||
@@ -28,7 +30,7 @@ type upstreamMonitor struct {
|
||||
failureTimerActive map[string]bool
|
||||
}
|
||||
|
||||
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonitor {
|
||||
um := &upstreamMonitor{
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
@@ -37,6 +39,7 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
recovered: make(map[string]bool),
|
||||
failureTimerActive: make(map[string]bool),
|
||||
}
|
||||
um.logger.Store(logger)
|
||||
for n := range cfg.Upstream {
|
||||
upstream := upstreamPrefix + n
|
||||
um.reset(upstream)
|
||||
@@ -53,7 +56,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
defer um.mu.Unlock()
|
||||
|
||||
if um.recovered[upstream] {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream)
|
||||
um.logger.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,7 +64,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
failedCount := um.failureReq[upstream]
|
||||
|
||||
// Log the updated failure count.
|
||||
mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
|
||||
um.logger.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
|
||||
|
||||
// If this is the first failure and no timer is running, start a 10-second timer.
|
||||
if failedCount == 1 && !um.failureTimerActive[upstream] {
|
||||
@@ -74,7 +77,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
// and the upstream is not in a recovered state, mark it as down.
|
||||
if um.failureReq[upstream] > 0 && !um.recovered[upstream] {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
|
||||
um.logger.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
|
||||
}
|
||||
// Reset the timer flag so that a new timer can be spawned if needed.
|
||||
um.failureTimerActive[upstream] = false
|
||||
@@ -84,7 +87,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
// If the failure count quickly reaches the threshold, mark the upstream as down immediately.
|
||||
if failedCount >= maxFailureRequest {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
|
||||
um.logger.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user