From b9b9cfcadec1ee32e699b5ed8770dcc1960139d7 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 17 Jun 2025 19:20:37 +0700 Subject: [PATCH] cmd/cli: avoid accessing mainLog when possible Add a logger field to the "prog" struct, and use this field inside its methods instead of always accessing the global mainLog variable. This at least ensures more consistent usage of the logger during ctrld prog runtime, and also helps refactor the code more easily in the future (like replacing the logger library). --- cmd/cli/cli.go | 58 +++++---- cmd/cli/control_server.go | 40 +++--- cmd/cli/dns_proxy.go | 181 +++++++++++++------------- cmd/cli/dns_proxy_test.go | 4 +- cmd/cli/log_writer.go | 3 +- cmd/cli/loop.go | 12 +- cmd/cli/netlink_linux.go | 6 +- cmd/cli/network_manager_linux.go | 28 ++-- cmd/cli/network_manager_others.go | 10 +- cmd/cli/prog.go | 177 ++++++++++++------------- cmd/cli/prog_log.go | 33 +++++ cmd/cli/resolvconf.go | 26 ++-- cmd/cli/resolvconf_darwin.go | 2 +- cmd/cli/resolvconf_not_darwin_unix.go | 4 +- cmd/cli/resolvconf_windows.go | 2 +- cmd/cli/upstream_monitor.go | 15 ++- 16 files changed, 323 insertions(+), 278 deletions(-) create mode 100644 cmd/cli/prog_log.go diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 5005925..cc5d1fe 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -211,6 +211,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cfg: &cfg, appCallback: appCallback, } + p.logger.Store(mainLog.Load()) if homedir == "" { if dir, err := userHomeDir(); err == nil { homedir = dir @@ -228,11 +229,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { p.logConn = lc } else { if !errors.Is(err, os.ErrNotExist) { - mainLog.Load().Warn().Err(err).Msg("unable to create log ipc connection") + p.Warn().Err(err).Msg("unable to create log ipc connection") } } } else { - mainLog.Load().Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath) + p.Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath) } 
notifyExitToLogServer := func() { if p.logConn != nil { @@ -241,7 +242,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } if daemon && runtime.GOOS == "windows" { - mainLog.Load().Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.") + p.Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.") } if !daemon { @@ -250,10 +251,10 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { go func() { s, err := newService(p, svcConfig) if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed create new service") + p.Fatal().Err(err).Msg("failed create new service") } if err := s.Run(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start service") + p.Error().Err(err).Msg("failed to start service") } }() } @@ -261,7 +262,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { tryReadingConfig(writeDefaultConfig) if err := readBase64Config(configBase64); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to read base64 config") + p.Fatal().Err(err).Msg("failed to read base64 config") } processNoConfigFlags(noConfigStart) @@ -270,7 +271,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { p.mu.Lock() if err := v.Unmarshal(&cfg); err != nil { notifyExitToLogServer() - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) + p.Fatal().Msgf("failed to unmarshal config: %v", err) } p.mu.Unlock() @@ -280,19 +281,19 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // so it's able to log information in processCDFlags. p.initLogging(true) - mainLog.Load().Info().Msgf("starting ctrld %s", curVersion()) - mainLog.Load().Info().Msgf("os: %s", osVersion()) + p.Info().Msgf("starting ctrld %s", curVersion()) + p.Info().Msgf("os: %s", osVersion()) // Wait for network up. 
if !ctrldnet.Up() { notifyExitToLogServer() - mainLog.Load().Fatal().Msg("network is not up yet") + p.Fatal().Msg("network is not up yet") } p.router = router.New(&cfg, cdUID != "") cs, err := newControlServer(filepath.Join(sockDir, ControlSocketName())) if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not create control server") + p.Warn().Err(err).Msg("could not create control server") } p.cs = cs @@ -301,7 +302,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // to set the current time, so this check must happen before processCDFlags. if err := p.router.PreRun(); err != nil { notifyExitToLogServer() - mainLog.Load().Fatal().Err(err).Msg("failed to perform router pre-run check") + p.Fatal().Err(err).Msg("failed to perform router pre-run check") } oldLogPath := cfg.Service.LogPath @@ -316,7 +317,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { return } - cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() + cdLogger := p.logger.Load().With().Str("mode", "cd").Logger() // Performs self-uninstallation if the ControlD device does not exist. var uer *controld.ErrorResponse if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { @@ -340,9 +341,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if updated { if err := writeConfigFile(&cfg); err != nil { notifyExitToLogServer() - mainLog.Load().Fatal().Err(err).Msg("failed to write config file") + p.Fatal().Err(err).Msg("failed to write config file") } else { - mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile) + p.Info().Msg("writing config file to: " + defaultConfigFile) } } @@ -354,10 +355,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // Copy logs written so far to new log file if possible. 
if buf, err := os.ReadFile(oldLogPath); err == nil { if err := os.WriteFile(newLogPath, buf, os.FileMode(0o600)); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not copy old log file") + p.Warn().Err(err).Msg("could not copy old log file") } } initLoggingWithBackup(false) + p.logger.Store(mainLog.Load()) } if err := validateConfig(&cfg); err != nil { @@ -369,13 +371,13 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if daemon { exe, err := os.Executable() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to find the binary") + p.Error().Err(err).Msg("failed to find the binary") notifyExitToLogServer() os.Exit(1) } curDir, err := os.Getwd() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to get current working directory") + p.Error().Err(err).Msg("failed to get current working directory") notifyExitToLogServer() os.Exit(1) } @@ -383,11 +385,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...) 
cmd.Dir = curDir if err := cmd.Start(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start process as daemon") + p.Error().Err(err).Msg("failed to start process as daemon") notifyExitToLogServer() os.Exit(1) } - mainLog.Load().Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started") + p.Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started") os.Exit(0) } @@ -395,7 +397,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { for _, lc := range p.cfg.Listener { if shouldAllocateLoopbackIP(lc.IP) { if err := allocateIP(lc.IP); err != nil { - mainLog.Load().Error().Err(err).Msgf("could not allocate IP: %s", lc.IP) + p.Error().Err(err).Msgf("could not allocate IP: %s", lc.IP) } } } @@ -406,7 +408,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { for _, lc := range p.cfg.Listener { if shouldAllocateLoopbackIP(lc.IP) { if err := deAllocateIP(lc.IP); err != nil { - mainLog.Load().Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP) + p.Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP) } } } @@ -417,15 +419,15 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } if iface != "" { p.onStarted = append(p.onStarted, func() { - mainLog.Load().Debug().Msg("router setup on start") + p.Debug().Msg("router setup on start") if err := p.router.Setup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not configure router") + p.Error().Err(err).Msg("could not configure router") } }) p.onStopped = append(p.onStopped, func() { - mainLog.Load().Debug().Msg("router cleanup on stop") + p.Debug().Msg("router cleanup on stop") if err := p.router.Cleanup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not cleanup router") + p.Error().Err(err).Msg("could not cleanup router") } }) } @@ -438,9 +440,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { file := ctrld.SavedStaticDnsSettingsFilePath(i) if _, err := os.Stat(file); err == nil { if err := restoreDNS(i); err != nil { - 
mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) + p.Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) } else { - mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) + p.Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) } } return nil diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 428fe12..de3a27a 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -79,21 +79,21 @@ func (s *controlServer) register(pattern string, handler http.Handler) { func (p *prog) registerControlServerHandler() { p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { - mainLog.Load().Debug().Msg("handling list clients request") + p.Debug().Msg("handling list clients request") clients := p.ciTable.ListClients() - mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list") + p.Debug().Int("client_count", len(clients)).Msg("retrieved clients list") sort.Slice(clients, func(i, j int) bool { return clients[i].IP.Less(clients[j].IP) }) - mainLog.Load().Debug().Msg("sorted clients by IP address") + p.Debug().Msg("sorted clients by IP address") if p.metricsQueryStats.Load() { - mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts") + p.Debug().Msg("metrics query stats enabled, collecting query counts") for idx, client := range clients { - mainLog.Load().Debug(). + p.Debug(). Int("index", idx). Str("ip", client.IP.String()). Str("mac", client.Mac). @@ -104,7 +104,7 @@ func (p *prog) registerControlServerHandler() { dm := &dto.Metric{} if statsClientQueriesCount.MetricVec == nil { - mainLog.Load().Debug(). + p.Debug(). Str("client_ip", client.IP.String()). 
Msg("skipping metrics collection: MetricVec is nil") continue @@ -116,7 +116,7 @@ func (p *prog) registerControlServerHandler() { client.Hostname, ) if err != nil { - mainLog.Load().Debug(). + p.Debug(). Err(err). Str("client_ip", client.IP.String()). Str("mac", client.Mac). @@ -127,23 +127,23 @@ func (p *prog) registerControlServerHandler() { if err := m.Write(dm); err == nil && dm.Counter != nil { client.QueryCount = int64(dm.Counter.GetValue()) - mainLog.Load().Debug(). + p.Debug(). Str("client_ip", client.IP.String()). Int64("query_count", client.QueryCount). Msg("successfully collected query count") } else if err != nil { - mainLog.Load().Debug(). + p.Debug(). Err(err). Str("client_ip", client.IP.String()). Msg("failed to write metric") } } } else { - mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts") + p.Debug().Msg("metrics query stats disabled, skipping query counts") } if err := json.NewEncoder(w).Encode(&clients); err != nil { - mainLog.Load().Error(). + p.Error(). Err(err). Int("client_count", len(clients)). Msg("failed to encode clients response") @@ -151,7 +151,7 @@ func (p *prog) registerControlServerHandler() { return } - mainLog.Load().Debug(). + p.Debug(). Int("client_count", len(clients)). Msg("successfully sent clients list response") })) @@ -175,7 +175,7 @@ func (p *prog) registerControlServerHandler() { oldSvc := p.cfg.Service p.mu.Unlock() if err := p.sendReloadSignal(); err != nil { - mainLog.Load().Err(err).Msg("could not send reload signal") + p.Error().Err(err).Msg("could not send reload signal") http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -216,7 +216,7 @@ func (p *prog) registerControlServerHandler() { return } - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) // Re-fetch pin code from API. 
if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev); rc != nil { if rc.DeactivationPin != nil { @@ -225,7 +225,7 @@ func (p *prog) registerControlServerHandler() { cdDeactivationPin.Store(defaultDeactivationPin) } } else { - mainLog.Load().Warn().Err(err).Msg("could not re-fetch deactivation pin code") + p.Warn().Err(err).Msg("could not re-fetch deactivation pin code") } // If pin code not set, allowing deactivation. @@ -237,7 +237,7 @@ func (p *prog) registerControlServerHandler() { var req deactivationRequest if err := json.NewDecoder(request.Body).Decode(&req); err != nil { w.WriteHeader(http.StatusPreconditionFailed) - mainLog.Load().Err(err).Msg("invalid deactivation request") + p.Error().Err(err).Msg("invalid deactivation request") return } @@ -320,15 +320,15 @@ func (p *prog) registerControlServerHandler() { UID: cdUID, Data: r.r, } - mainLog.Load().Debug().Msg("sending log file to ControlD server") + p.Debug().Msg("sending log file to ControlD server") resp := logSentResponse{Size: r.size} - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil { - mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err) + p.Error().Msgf("could not send log file to ControlD server: %v", err) resp.Error = err.Error() w.WriteHeader(http.StatusInternalServerError) } else { - mainLog.Load().Debug().Msg("sending log file successfully") + p.Debug().Msg("sending log file successfully") w.WriteHeader(http.StatusOK) } if err := json.NewEncoder(w).Encode(&resp); err != nil { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index a3d9970..c09e11d 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -87,14 +87,14 @@ type upstreamForResult struct { func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { // Start network monitoring if err := 
p.monitorNetworkChanges(mainCtx); err != nil { - mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring") + p.Error().Err(err).Msg("Failed to start network monitoring") // Don't return here as we still want DNS service to run } listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { - mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") + p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") return allocErr } @@ -110,9 +110,9 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] reqId := requestID() ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) - ctx = ctrld.LoggerCtx(ctx, mainLog.Load()) + ctx = ctrld.LoggerCtx(ctx, p.logger.Load()) if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { - ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) + ctrld.Log(ctx, p.Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) answer := new(dns.Msg) answer.SetRcode(m, dns.RcodeRefused) _ = w.WriteMsg(answer) @@ -135,7 +135,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil { p.cache.Purge() - ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain) + ctrld.Log(ctx, p.Debug(), "received query %q, local cache is purged", domain) } remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) ci := p.getClientInfo(remoteIP, m) @@ -144,7 +144,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) fmtSrcToDest := fmtRemoteToLocal(listenerNum, ci.Hostname, 
remoteAddr.String()) t := time.Now() - ctrld.Log(ctx, mainLog.Load().Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain) + ctrld.Log(ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain) ur := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain) labelValues := make([]string, 0, len(statsQueriesCountLabels)) @@ -155,7 +155,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { var answer *dns.Msg if !ur.matched && listenerConfig.Restricted { - ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String()) + ctrld.Log(ctx, p.Info(), "query refused, %s does not match any network policy", remoteAddr.String()) answer = new(dns.Msg) answer.SetRcode(m, dns.RcodeRefused) labelValues = append(labelValues, "") // no upstream @@ -174,7 +174,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { answer = pr.answer rtt := time.Since(t) - ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) + ctrld.Log(ctx, p.Debug(), "received response of %d bytes in %s", answer.Len(), rtt) upstream := pr.upstream switch { case pr.cached: @@ -192,7 +192,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { p.forceFetchingAPI(domain) }() if err := w.WriteMsg(answer); err != nil { - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client") + ctrld.Log(ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client") } }) @@ -209,7 +209,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { case err := <-errCh: // Local ipv6 listener should not terminate ctrld. // It's a workaround for a quirk on Windows. 
- mainLog.Load().Warn().Err(err).Msg("local ipv6 listener failed") + p.Warn().Err(err).Msg("local ipv6 listener failed") } return nil }) @@ -229,7 +229,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { case err := <-errCh: // RFC1918 listener should not terminate ctrld. // It's a workaround for a quirk on system with systemd-resolved. - mainLog.Load().Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr) + p.Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr) } }() } @@ -371,8 +371,8 @@ func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg }, Ptr: dns.Fqdn(name), }} - ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table") - ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + ctrld.Log(ctx, p.Info(), "private PTR lookup, using client info table") + ctrld.Log(ctx, p.Debug(), "client info: %v", ctrld.ClientInfo{ Mac: p.ciTable.LookupMac(ip.String()), IP: ip.String(), Hostname: name, @@ -416,8 +416,8 @@ func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg AAAA: ip.AsSlice(), }} } - ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table") - ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + ctrld.Log(ctx, p.Info(), "lan hostname lookup, using client info table") + ctrld.Log(ctx, p.Debug(), "client info: %v", ctrld.ClientInfo{ Mac: p.ciTable.LookupMac(ip.String()), IP: ip.String(), Hostname: hostname, @@ -441,7 +441,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // running and listening on local addresses, these local addresses must be used // as nameservers, so queries for ADDC could be resolved as expected. 
if p.isAdDomainQuery(req.msg) { - ctrld.Log(ctx, mainLog.Load().Debug(), + ctrld.Log(ctx, p.Debug(), "AD domain query detected for %s in domain %s", req.msg.Question[0].Name, p.adDomain) upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} @@ -459,14 +459,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // 4. Try remote upstream. isLanOrPtrQuery := false if req.ufr.matched { - ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) + ctrld.Log(ctx, p.Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) } else { switch { case isSrvLanLookup(req.msg): upstreams = []string{upstreamOS} upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams) + ctrld.Log(ctx, p.Debug(), "SRV record lookup, using upstreams: %v", upstreams) case isPrivatePtrLookup(req.msg): isLanOrPtrQuery = true if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil { @@ -476,7 +476,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs) ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams) + ctrld.Log(ctx, p.Debug(), "private PTR lookup, using upstreams: %v", upstreams) case isLanHostnameQuery(req.msg): isLanOrPtrQuery = true if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil { @@ -487,9 +487,9 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { upstreams = []string{upstreamOS} upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams) 
+ ctrld.Log(ctx, p.Debug(), "lan hostname lookup, using upstreams: %v", upstreams) default: - ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams) + ctrld.Log(ctx, p.Debug(), "no explicit policy matched, using default routing -> %v", upstreams) } } @@ -504,7 +504,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { ctrld.SetCacheReply(answer, req.msg, answer.Rcode) now := time.Now() if cachedValue.Expire.After(now) { - ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response") + ctrld.Log(ctx, p.Debug(), "hit cached response") setCachedAnswerTTL(answer, now, cachedValue.Expire) res.answer = answer res.cached = true @@ -514,10 +514,10 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } } resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { - ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) + ctrld.Log(ctx, p.Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig) if err != nil { - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") + ctrld.Log(ctx, p.Error().Err(err), "failed to create resolver") return nil, err } resolveCtx, cancel := upstreamConfig.Context(ctx) @@ -526,7 +526,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil { - ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request") + ctrld.Log(ctx, p.Debug(), "including client info with the request") ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) } answer, err := resolve1(upstream, upstreamConfig, msg) @@ -540,7 +540,7 @@ func (p *prog) proxy(ctx context.Context, req 
*proxyRequest) *proxyResponse { return answer } - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") + ctrld.Log(ctx, p.Error().Err(err), "failed to resolve query") // increase failure count when there is no answer // rehardless of what kind of error we get @@ -564,7 +564,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { if upstreamConfig == nil { continue } - logger := mainLog.Load().Debug(). + logger := p.Debug(). Str("upstream", upstreamConfig.String()). Str("query", req.msg.Question[0].Name). Bool("is_ad_query", p.isAdDomainQuery(req.msg)). @@ -577,7 +577,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { answer := resolve(upstreams[n], upstreamConfig, req.msg) if answer == nil { if serveStaleCache && staleAnswer != nil { - ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response") + ctrld.Log(ctx, p.Debug(), "serving stale cached response") now := time.Now() setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL)) res.answer = staleAnswer @@ -589,11 +589,11 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // We are doing LAN/PTR lookup using private resolver, so always process next one. // Except for the last, we want to send response instead of saying all upstream failed. 
if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 { - ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n]) + ctrld.Log(ctx, p.Debug(), "no response from %s, process to next upstream", upstreams[n]) continue } if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) { - ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream") + ctrld.Log(ctx, p.Debug(), "failover rcode matched, process to next upstream") continue } @@ -609,18 +609,18 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } setCachedAnswerTTL(answer, now, expired) p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired)) - ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response") + ctrld.Log(ctx, p.Debug(), "add cached response") } hostname := "" if req.ci != nil { hostname = req.ci.Hostname } - ctrld.Log(ctx, mainLog.Load().Info(), "REPLY: %s -> %s (%s): %s", upstreams[n], req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode]) + ctrld.Log(ctx, p.Info(), "REPLY: %s -> %s (%s): %s", upstreams[n], req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode]) res.answer = answer res.upstream = upstreamConfig.Endpoint return res } - ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) + ctrld.Log(ctx, p.Error(), "all %v endpoints failed", upstreams) // if we have no healthy upstreams, trigger recovery flow if p.leakOnUpstreamFailure() { @@ -633,28 +633,28 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } else { reason = RecoveryReasonRegularFailure } - mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) + p.Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) go p.handleRecovery(reason) } else { - mainLog.Load().Debug().Msg("Recovery already 
in progress; skipping duplicate trigger from down detection") + p.Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection") } p.recoveryCancelMu.Unlock() } else { - mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger") + p.Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger") } // attempt query to OS resolver while as a retry catch all // we dont want this to happen if leakOnUpstreamFailure is false if upstreams[0] != upstreamOS { - ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all") + ctrld.Log(ctx, p.Debug(), "attempting query to OS resolver as a retry catch all") answer := resolve(upstreamOS, osUpstreamConfig, req.msg) if answer != nil { - ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful") + ctrld.Log(ctx, p.Debug(), "OS resolver retry query successful") res.answer = answer res.upstream = osUpstreamConfig.Endpoint return res } - ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed") + ctrld.Log(ctx, p.Debug(), "OS resolver retry query failed") } } @@ -958,10 +958,10 @@ func (p *prog) doSelfUninstall(answer *dns.Msg) { return } - logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger() + logger := p.logger.Load().With().Str("mode", "self-uninstall").Logger() if p.refusedQueryCount > selfUninstallMaxQueries { p.checkingSelfUninstall = true - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) _, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) logger.Debug().Msg("maximum number of refused queries reached, checking device status") selfUninstallCheck(err, p, logger) @@ -1031,7 +1031,7 @@ func (p *prog) queryFromSelf(ip string) bool { netIP := netip.MustParseAddr(ip) regularIPs, loopbackIPs, err := netmon.LocalAddresses() if err != nil 
{ - mainLog.Load().Warn().Err(err).Msg("could not get local addresses") + p.Warn().Err(err).Msg("could not get local addresses") return false } for _, localIP := range slices.Concat(regularIPs, loopbackIPs) { @@ -1151,7 +1151,8 @@ func isWanClient(na net.Addr) bool { // resolveInternalDomainTestQuery resolves internal test domain query, returning the answer to the caller. func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.Msg) *dns.Msg { - ctrld.Log(ctx, mainLog.Load().Debug(), "internal domain test query") + logger := ctrld.LoggerFromCtx(ctx) + ctrld.Log(ctx, logger.Debug(), "internal domain test query") q := m.Question[0] answer := new(dns.Msg) @@ -1192,7 +1193,7 @@ func FlushDNSCache() error { func (p *prog) monitorNetworkChanges(ctx context.Context) error { mon, err := netmon.New(func(format string, args ...any) { // Always fetch the latest logger (and inject the prefix) - mainLog.Load().Printf("netmon: "+format, args...) + p.logger.Load().Printf("netmon: "+format, args...) }) if err != nil { return fmt.Errorf("creating network monitor: %w", err) @@ -1204,7 +1205,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New) - mainLog.Load().Debug(). + p.Debug(). Interface("old_state", delta.Old). Interface("new_state", delta.New). Bool("is_major_change", isMajorChange). @@ -1232,7 +1233,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { if newIface.IsUp() && len(usableNewIPs) > 0 { changed = true changeIPs = usableNewIPs - mainLog.Load().Debug(). + p.Debug(). Str("interface", ifaceName). Interface("new_ips", usableNewIPs). Msg("Interface newly appeared (was not present in old state)") @@ -1254,7 +1255,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { if newIface.IsUp() && len(usableNewIPs) > 0 { changed = true changeIPs = usableNewIPs - mainLog.Load().Debug(). + p.Debug(). Str("interface", ifaceName). 
Interface("old_ips", oldIPs). Interface("new_ips", usableNewIPs). @@ -1267,39 +1268,39 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // if the default route changed, set changed to true if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface { changed = true - mainLog.Load().Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface) + p.Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface) } if !changed { - mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected") + p.Debug().Msg("Ignoring interface change - no valid interfaces affected") // check if the default IPs are still on an interface that is up ValidateDefaultLocalIPsFromDelta(delta.New) return } if !activeInterfaceExists { - mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization") + p.Debug().Msg("No active interfaces found, skipping reinitialization") return } // Get IPs from default route interface in new state - selfIP := defaultRouteIP() + selfIP := p.defaultRouteIP() // Ensure that selfIP is an IPv4 address. 
// If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil { - mainLog.Load().Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP) + p.Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP) selfIP = "" } var ipv6 string if delta.New.DefaultRouteInterface != "" { - mainLog.Load().Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface]) + p.Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface]) for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] { ipAddr, _ := netip.ParsePrefix(ip.String()) addr := ipAddr.Addr() if selfIP == "" && addr.Is4() { - mainLog.Load().Debug().Msgf("checking IP: %s", addr.String()) + p.Debug().Msgf("checking IP: %s", addr.String()) if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { selfIP = addr.String() } @@ -1310,12 +1311,12 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { } } else { // If no default route interface is set yet, use the changed IPs - mainLog.Load().Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs) + p.Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs) for _, ip := range changeIPs { ipAddr, _ := netip.ParsePrefix(ip.String()) addr := ipAddr.Addr() if selfIP == "" && addr.Is4() { - mainLog.Load().Debug().Msgf("checking IP: %s", addr.String()) + p.Debug().Msgf("checking IP: %s", addr.String()) if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { selfIP = addr.String() } @@ -1328,15 +1329,15 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // Only set the IPv4 default if selfIP is a valid IPv4 address. 
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil { - ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, mainLog.Load()), ip) + ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, p.logger.Load()), ip) if !isMobile() && p.ciTable != nil { p.ciTable.SetSelfIP(selfIP) } } if ip := net.ParseIP(ipv6); ip != nil { - ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, mainLog.Load()), ip) + ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, p.logger.Load()), ip) } - mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) + p.Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) // we only trigger recovery flow for network changes on non router devices if router.Name() == "" { @@ -1345,7 +1346,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { }) mon.Start() - mainLog.Load().Debug().Msg("Network monitor started") + p.Debug().Msg("Network monitor started") return nil } @@ -1400,11 +1401,11 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool { // checkUpstreamOnce sends a test query to the specified upstream. // Returns nil if the upstream responds successfully. 
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error { - mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream) + p.Debug().Msgf("Starting check for upstream: %s", upstream) - resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), uc) + resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), uc) if err != nil { - mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream) + p.Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream) return err } @@ -1415,22 +1416,22 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro if uc.Timeout > 0 { timeout = time.Millisecond * time.Duration(uc.Timeout) } - mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout) + p.Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load())) - mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) + uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load())) + p.Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) start := time.Now() _, err = resolver.Resolve(ctx, msg) duration := time.Since(start) if err != nil { - mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration) + p.Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration) } else { - mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration) + p.Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration) } return err } @@ -1440,13 +1441,13 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro // upstream failure recoveries, waiting for recovery to complete (using a cancellable context without timeout), // and 
then re-applying the DNS settings. func (p *prog) handleRecovery(reason RecoveryReason) { - mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings") + p.Debug().Msg("Starting recovery process: removing DNS settings") // For network changes, cancel any existing recovery check because the network state has changed. if reason == RecoveryReasonNetworkChange { p.recoveryCancelMu.Lock() if p.recoveryCancel != nil { - mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)") + p.Debug().Msg("Cancelling existing recovery check (network change)") p.recoveryCancel() p.recoveryCancel = nil } @@ -1455,7 +1456,7 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // For upstream failures, if a recovery is already in progress, do nothing new. p.recoveryCancelMu.Lock() if p.recoveryCancel != nil { - mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger") + p.Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger") p.recoveryCancelMu.Unlock() return } @@ -1476,15 +1477,15 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // will be appended to nameservers from the saved interface values p.resetDNS(false, false) - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) // For an OS failure, reinitialize OS resolver nameservers immediately. 
if reason == RecoveryReasonOSFailure { - mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers") + p.Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers") ns := ctrld.InitializeOsResolver(loggerCtx, true) if len(ns) == 0 { - mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") + p.Warn().Msg("No nameservers found for OS resolver; using existing values") } else { - mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) } } @@ -1494,13 +1495,13 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // Wait indefinitely until one of the upstreams recovers. recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams) if err != nil { - mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed") + p.Error().Err(err).Msg("Recovery canceled; DNS settings remain removed") p.recoveryCancelMu.Lock() p.recoveryCancel = nil p.recoveryCancelMu.Unlock() return } - mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered) + p.Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered) // reset the upstream failure count and down state p.um.reset(recovered) @@ -1509,9 +1510,9 @@ func (p *prog) handleRecovery(reason RecoveryReason) { if reason == RecoveryReasonNetworkChange { ns := ctrld.InitializeOsResolver(loggerCtx, true) if len(ns) == 0 { - mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values") + p.Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values") } else { - mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) } } @@ -1534,44 +1535,44 @@ func (p *prog) waitForUpstreamRecovery(ctx 
context.Context, upstreams map[string recoveredCh := make(chan string, 1) var wg sync.WaitGroup - mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams)) + p.Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams)) for name, uc := range upstreams { wg.Add(1) go func(name string, uc *ctrld.UpstreamConfig) { defer wg.Done() - mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name) + p.Debug().Msgf("Starting recovery check loop for upstream: %s", name) attempts := 0 for { select { case <-ctx.Done(): - mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name) + p.Debug().Msgf("Context canceled for upstream %s", name) return default: attempts++ // checkUpstreamOnce will reset any failure counters on success. if err := p.checkUpstreamOnce(name, uc); err == nil { - mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name) + p.Debug().Msgf("Upstream %s recovered successfully", name) select { case recoveredCh <- name: - mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name) + p.Debug().Msgf("Sent recovery notification for upstream %s", name) default: - mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered") + p.Debug().Msg("Recovery channel full, another upstream already recovered") } return } - mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name) + p.Debug().Msgf("Upstream %s check failed, sleeping before retry", name) time.Sleep(checkUpstreamBackoffSleep) // if this is the upstreamOS and it's the 3rd attempt (or multiple of 3), // we should try to reinit the OS resolver to ensure we can recover if name == upstreamOS && attempts%3 == 0 { - mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts) - ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, mainLog.Load()), true) + p.Debug().Msgf("UpstreamOS check failed on 
attempt %d, reinitializing OS resolver", attempts) + ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, p.logger.Load()), true) if len(ns) == 0 { - mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") + p.Warn().Msg("No nameservers found for OS resolver; using existing values") } else { - mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) } } } diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 4a4e5b4..615ce40 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -77,7 +77,8 @@ func Test_prog_upstreamFor(t *testing.T) { cfg := testhelper.SampleConfig(t) cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false) p := &prog{cfg: cfg} - p.um = newUpstreamMonitor(p.cfg) + p.logger.Store(mainLog.Load()) + p.um = newUpstreamMonitor(p.cfg, mainLog.Load()) p.lanLoopGuard = newLoopGuard() p.ptrLoopGuard = newLoopGuard() for _, nc := range p.cfg.Network { @@ -145,6 +146,7 @@ func Test_prog_upstreamFor(t *testing.T) { func TestCache(t *testing.T) { cfg := testhelper.SampleConfig(t) prog := &prog{cfg: cfg} + prog.logger.Store(mainLog.Load()) for _, nc := range prog.cfg.Network { for _, cidr := range nc.Cidrs { _, ipNet, err := net.ParseCIDR(cidr) diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 0ba2c8c..c2880c0 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -100,6 +100,7 @@ func (p *prog) initLogging(backup bool) { // Initializing internal logging after global logging. p.initInternalLogging(logWriters) + p.logger.Store(mainLog.Load()) } // initInternalLogging performs internal logging if there's no log enabled. 
@@ -108,7 +109,7 @@ func (p *prog) initInternalLogging(writers []io.Writer) { return } p.initInternalLogWriterOnce.Do(func() { - mainLog.Load().Notice().Msg("internal logging enabled") + p.Notice().Msg("internal logging enabled") p.internalLogWriter = newLogWriter() p.internalLogSent = time.Now().Add(-logWriterSentInterval) p.internalWarnLogWriter = newSmallLogWriter() diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index 434a4a5..fce6ce1 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -84,7 +84,7 @@ func (p *prog) detectLoop(msg *dns.Msg) { // // See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html func (p *prog) checkDnsLoop() { - mainLog.Load().Debug().Msg("start checking DNS loop") + p.Debug().Msg("start checking DNS loop") upstream := make(map[string]*ctrld.UpstreamConfig) p.loopMu.Lock() for n, uc := range p.cfg.Upstream { @@ -93,7 +93,7 @@ func (p *prog) checkDnsLoop() { } // Do not send test query to external upstream. if !canBeLocalUpstream(uc.Domain) { - mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n) + p.Debug().Msgf("skipping external: upstream.%s", n) continue } uid := uc.UID() @@ -102,7 +102,7 @@ func (p *prog) checkDnsLoop() { } p.loopMu.Unlock() - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) for uid := range p.loop { msg := loopTestMsg(uid) uc := upstream[uid] @@ -112,14 +112,14 @@ func (p *prog) checkDnsLoop() { } resolver, err := ctrld.NewResolver(loggerCtx, uc) if err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) + p.Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) continue } if _, err := resolver.Resolve(context.Background(), msg); err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) + 
p.Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) } } - mainLog.Load().Debug().Msg("end checking DNS loop") + p.Debug().Msg("end checking DNS loop") } // checkDnsLoopTicker performs p.checkDnsLoop every minute. diff --git a/cmd/cli/netlink_linux.go b/cmd/cli/netlink_linux.go index f4e9bda..2115c5b 100644 --- a/cmd/cli/netlink_linux.go +++ b/cmd/cli/netlink_linux.go @@ -14,7 +14,7 @@ func (p *prog) watchLinkState(ctx context.Context) { done := make(chan struct{}) defer close(done) if err := netlink.LinkSubscribe(ch, done); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not subscribe link") + p.Warn().Err(err).Msg("could not subscribe link") return } for { @@ -26,9 +26,9 @@ func (p *prog) watchLinkState(ctx context.Context) { continue } if lu.Change&unix.IFF_UP != 0 { - mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") + p.Debug().Msgf("link state changed, re-bootstrapping") for _, uc := range p.cfg.Upstream { - uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load())) + uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load())) } } } diff --git a/cmd/cli/network_manager_linux.go b/cmd/cli/network_manager_linux.go index 1a8c22b..bfd2775 100644 --- a/cmd/cli/network_manager_linux.go +++ b/cmd/cli/network_manager_linux.go @@ -28,61 +28,61 @@ func hasNetworkManager() bool { return exe != "" } -func setupNetworkManager() error { +func (p *prog) setupNetworkManager() error { if !hasNetworkManager() { return nil } if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent { - mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do") + p.Debug().Msg("NetworkManager already setup, nothing to do") return nil } err := os.WriteFile(networkManagerCtrldConfFile, []byte(nmCtrldConfContent), os.FileMode(0644)) if os.IsNotExist(err) { - mainLog.Load().Debug().Msg("NetworkManager is not available") + p.Debug().Msg("NetworkManager is not available") return 
nil } if err != nil { - mainLog.Load().Debug().Err(err).Msg("could not write NetworkManager ctrld config file") + p.Debug().Err(err).Msg("could not write NetworkManager ctrld config file") return err } - reloadNetworkManager() - mainLog.Load().Debug().Msg("setup NetworkManager done") + p.reloadNetworkManager() + p.Debug().Msg("setup NetworkManager done") return nil } -func restoreNetworkManager() error { +func (p *prog) restoreNetworkManager() error { if !hasNetworkManager() { return nil } err := os.Remove(networkManagerCtrldConfFile) if os.IsNotExist(err) { - mainLog.Load().Debug().Msg("NetworkManager is not available") + p.Debug().Msg("NetworkManager is not available") return nil } if err != nil { - mainLog.Load().Debug().Err(err).Msg("could not remove NetworkManager ctrld config file") + p.Debug().Err(err).Msg("could not remove NetworkManager ctrld config file") return err } - reloadNetworkManager() - mainLog.Load().Debug().Msg("restore NetworkManager done") + p.reloadNetworkManager() + p.Debug().Msg("restore NetworkManager done") return nil } -func reloadNetworkManager() { +func (p *prog) reloadNetworkManager() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() conn, err := dbus.NewSystemConnectionContext(ctx) if err != nil { - mainLog.Load().Error().Err(err).Msg("could not create new system connection") + p.Error().Err(err).Msg("could not create new system connection") return } defer conn.Close() waitCh := make(chan string) if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil { - mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager") + p.Debug().Err(err).Msg("could not reload NetworkManager") return } <-waitCh diff --git a/cmd/cli/network_manager_others.go b/cmd/cli/network_manager_others.go index 323d2f2..e6e5f68 100644 --- a/cmd/cli/network_manager_others.go +++ b/cmd/cli/network_manager_others.go @@ -2,14 +2,14 @@ package cli -func setupNetworkManager() 
error { - reloadNetworkManager() +func (p *prog) setupNetworkManager() error { + p.reloadNetworkManager() return nil } -func restoreNetworkManager() error { - reloadNetworkManager() +func (p *prog) restoreNetworkManager() error { + p.reloadNetworkManager() return nil } -func reloadNetworkManager() {} +func (p *prog) reloadNetworkManager() {} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index d85c371..0cfd3b9 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -102,6 +102,7 @@ type prog struct { apiForceReloadGroup singleflight.Group logConn net.Conn cs *controlServer + logger atomic.Pointer[ctrld.Logger] csSetDnsDone chan struct{} csSetDnsOk bool dnsWg sync.WaitGroup @@ -150,7 +151,7 @@ type prog struct { onStopped []func() } -func (p *prog) Start(s service.Service) error { +func (p *prog) Start(_ service.Service) error { go p.runWait() return nil } @@ -164,7 +165,6 @@ func (p *prog) runWait() { notifyReloadSigCh(reloadSigCh) reload := false - logger := mainLog.Load() for { reloadCh := make(chan struct{}) done := make(chan struct{}) @@ -177,9 +177,9 @@ func (p *prog) runWait() { var newCfg *ctrld.Config select { case sig := <-reloadSigCh: - logger.Notice().Msgf("got signal: %s, reloading...", sig.String()) + p.Notice().Msgf("got signal: %s, reloading...", sig.String()) case <-p.reloadCh: - logger.Notice().Msg("reloading...") + p.Notice().Msg("reloading...") case apiCfg := <-p.apiReloadCh: newCfg = apiCfg case <-p.stopCh: @@ -202,18 +202,18 @@ func (p *prog) runWait() { } v.SetConfigFile(confFile) if err := v.ReadInConfig(); err != nil { - logger.Err(err).Msg("could not read new config") + p.Error().Err(err).Msg("could not read new config") waitOldRunDone() continue } if err := v.Unmarshal(&newCfg); err != nil { - logger.Err(err).Msg("could not unmarshal new config") + p.Error().Err(err).Msg("could not unmarshal new config") waitOldRunDone() continue } if cdUID != "" { if rc, err := processCDFlags(newCfg); err != nil { - logger.Err(err).Msg("could not fetch 
ControlD config") + p.Error().Err(err).Msg("could not fetch ControlD config") waitOldRunDone() continue } else { @@ -243,25 +243,25 @@ func (p *prog) runWait() { } } if err := validateConfig(newCfg); err != nil { - logger.Err(err).Msg("invalid config") + p.Error().Err(err).Msg("invalid config") continue } addExtraSplitDnsRule(newCfg) if err := writeConfigFile(newCfg); err != nil { - logger.Err(err).Msg("could not write new config") + p.Error().Err(err).Msg("could not write new config") } // This needs to be done here, otherwise, the DNS handler may observe an invalid // upstream config because its initialization function have not been called yet. - mainLog.Load().Debug().Msg("setup upstream with new config") + p.Debug().Msg("setup upstream with new config") p.setupUpstream(newCfg) p.mu.Lock() *p.cfg = *newCfg p.mu.Unlock() - logger.Notice().Msg("reloading config successfully") + p.Notice().Msg("reloading config successfully") select { case p.reloadDoneCh <- struct{}{}: @@ -276,6 +276,7 @@ func (p *prog) preRun() { p.requiredMultiNICsConfig = requiredMultiNICsConfig() } p.runningIface = iface + p.logger.Store(mainLog.Load()) } func (p *prog) postRun() { @@ -283,11 +284,11 @@ func (p *prog) postRun() { if runtime.GOOS == "windows" { isDC, roleInt := isRunningOnDomainController() p.runningOnDomainController = isDC - mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt) + p.Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt) } p.resetDNS(false, false) - ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), false) - mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) + ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), false) + p.Debug().Msgf("initialized OS resolver with nameservers: %v", ns) p.setDNS() p.csSetDnsDone <- struct{}{} close(p.csSetDnsDone) @@ -304,14 
+305,14 @@ func (p *prog) apiConfigReload() { ticker := time.NewTicker(timeDurationOrDefault(p.cfg.Service.RefetchTime, 3600) * time.Second) defer ticker.Stop() - logger := mainLog.Load().With().Str("mode", "api-reload").Logger() + logger := p.logger.Load().With().Str("mode", "api-reload").Logger() logger.Debug().Msg("starting custom config reload timer") lastUpdated := time.Now().Unix() curVerStr := curVersion() curVer, err := semver.NewVersion(curVerStr) isStable := curVer != nil && curVer.Prerelease() == "" if err != nil || !isStable { - l := mainLog.Load().Warn() + l := p.Warn() if err != nil { l = l.Err(err) } @@ -319,7 +320,7 @@ func (p *prog) apiConfigReload() { } doReloadApiConfig := func(forced bool, logger zerolog.Logger) { - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) resolverConfig, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) selfUninstallCheck(err, p, logger) if err != nil { @@ -405,20 +406,20 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) { localUpstreams := make([]string, 0, len(cfg.Upstream)) ptrNameservers := make([]string, 0, len(cfg.Upstream)) isControlDUpstream := false - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) for n := range cfg.Upstream { uc := cfg.Upstream[n] sdns := uc.Type == ctrld.ResolverTypeSDNS uc.Init(loggerCtx) if sdns { - mainLog.Load().Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type) + p.Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type) } isControlDUpstream = isControlDUpstream || uc.IsControlD() if uc.BootstrapIP == "" { - uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), mainLog.Load())) - mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) + 
uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), p.logger.Load())) + p.Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) } else { - mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) + p.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) } uc.SetCertPool(rootCertPool) go uc.Ping(loggerCtx) @@ -459,9 +460,9 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.csSetDnsDone = make(chan struct{}, 1) p.registerControlServerHandler() if err := p.cs.start(); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not start control server") + p.Warn().Err(err).Msg("could not start control server") } - mainLog.Load().Debug().Msgf("control server started: %s", p.cs.addr) + p.Debug().Msgf("control server started: %s", p.cs.addr) } } p.onStartedDone = make(chan struct{}) @@ -473,7 +474,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled") + p.Error().Err(err).Msg("failed to create cacher, caching is disabled") } else { p.cache = cacher p.cacheFlushDomainsMap = make(map[string]struct{}, 256) @@ -483,7 +484,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } if domain, err := getActiveDirectoryDomain(); err == nil && domain != "" && hasLocalDnsServerRunning() { - mainLog.Load().Debug().Msgf("active directory domain: %s", domain) + p.Debug().Msgf("active directory domain: %s", domain) p.adDomain = domain } @@ -494,14 +495,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { for _, cidr := range nc.Cidrs { _, ipNet, err := net.ParseCIDR(cidr) if err != nil { - mainLog.Load().Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr") + p.Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid 
cidr") continue } nc.IPNets = append(nc.IPNets, ipNet) } } - p.um = newUpstreamMonitor(p.cfg) + p.um = newUpstreamMonitor(p.cfg, p.logger.Load()) if !reload { p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)} @@ -514,7 +515,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } p.setupUpstream(p.cfg) - p.setupClientInfoDiscover(defaultRouteIP()) + p.setupClientInfoDiscover() } // context for managing spawn goroutines. @@ -538,14 +539,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { listenerConfig := p.cfg.Listener[listenerNum] upstreamConfig := p.cfg.Upstream[listenerNum] if upstreamConfig == nil { - mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) + p.Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) } addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) - mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) + p.Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) if err := p.serveDNS(ctx, listenerNum); err != nil { - mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) + p.Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) } - mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr) + p.Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr) }(listenerNum) } go func() { @@ -602,10 +603,11 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } // setupClientInfoDiscover performs necessary works for running client info discover. 
-func (p *prog) setupClientInfoDiscover(selfIP string) { - p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, mainLog.Load()) +func (p *prog) setupClientInfoDiscover() { + selfIP := p.defaultRouteIP() + p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, p.logger.Load()) if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { - mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) + p.Debug().Msgf("watching custom lease file: %s", leaseFile) format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) p.ciTable.AddLeaseFile(leaseFile, format) } @@ -622,18 +624,18 @@ func (p *prog) metricsEnabled() bool { return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != "" } -func (p *prog) Stop(s service.Service) error { +func (p *prog) Stop(_ service.Service) error { p.stopDnsWatchers() - mainLog.Load().Debug().Msg("dns watchers stopped") + p.Debug().Msg("dns watchers stopped") for _, f := range p.onStopped { f() } - mainLog.Load().Debug().Msg("finish running onStopped functions") + p.Debug().Msg("finish running onStopped functions") defer func() { - mainLog.Load().Info().Msg("Service stopped") + p.Info().Msg("Service stopped") }() if err := p.deAllocateIP(); err != nil { - mainLog.Load().Error().Err(err).Msg("de-allocate ip failed") + p.Error().Err(err).Msg("de-allocate ip failed") return err } if deactivationPinSet() { @@ -645,16 +647,16 @@ func (p *prog) Stop(s service.Service) error { // No valid pin code was checked, that mean we are stopping // because of OS signal sent directly from someone else. // In this case, restarting ctrld service by ourselves. 
- mainLog.Load().Debug().Msgf("receiving stopping signal without valid pin code") - mainLog.Load().Debug().Msgf("self restarting ctrld service") + p.Debug().Msgf("receiving stopping signal without valid pin code") + p.Debug().Msgf("self restarting ctrld service") if exe, err := os.Executable(); err == nil { cmd := exec.Command(exe, "restart") cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() if err := cmd.Start(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to run self restart command") + p.Error().Err(err).Msg("failed to run self restart command") } } else { - mainLog.Load().Error().Err(err).Msg("failed to self restart ctrld service") + p.Error().Err(err).Msg("failed to self restart ctrld service") } os.Exit(deactivationPinInvalidExitCode) } @@ -755,7 +757,7 @@ func (p *prog) setDNS() { p.dnsWg.Add(1) go func() { defer p.dnsWg.Done() - p.watchResolvConf(netIface, servers, setResolvConf) + p.watchResolvConf(netIface, servers, p.setResolvConf) }() } if p.dnsWatchdogEnabled() { @@ -772,7 +774,7 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In return } - logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger := p.logger.Load().With().Str("iface", p.runningIface).Logger() const maxDNSRetryAttempts = 3 const retryDelay = 1 * time.Second @@ -785,10 +787,10 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In } if attempt < maxDNSRetryAttempts { // Try to find a different working interface - newIface := findWorkingInterface(p.runningIface) + newIface := p.findWorkingInterface() if newIface != p.runningIface { p.runningIface = newIface - logger = mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger = p.logger.Load().With().Str("iface", p.runningIface).Logger() logger.Info().Msg("switched to new interface") continue } @@ -800,7 +802,7 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In 
logger.Error().Err(err).Msg("could not get interface after all attempts") return } - if err := setupNetworkManager(); err != nil { + if err := p.setupNetworkManager(); err != nil { logger.Error().Err(err).Msg("could not patch NetworkManager") return } @@ -840,7 +842,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { return } - mainLog.Load().Debug().Msg("start DNS settings watchdog") + p.Debug().Msg("start DNS settings watchdog") ns := nameservers slices.Sort(ns) @@ -851,19 +853,19 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { case <-p.dnsWatcherStopCh: return case <-p.stopCh: - mainLog.Load().Debug().Msg("stop dns watchdog") + p.Debug().Msg("stop dns watchdog") return case <-ticker.C: if p.recoveryRunning.Load() { return } - if dnsChanged(iface, ns) { - mainLog.Load().Debug().Msg("DNS settings were changed, re-applying settings") + if p.dnsChanged(iface, ns) { + p.Debug().Msg("DNS settings were changed, re-applying settings") // Check if the interface already has static DNS servers configured. // currentStaticDNS is an OS-dependent helper that returns the current static DNS. staticDNS, err := currentStaticDNS(iface) if err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name) + p.Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name) } else if len(staticDNS) > 0 { //filter out loopback addresses staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool { @@ -873,12 +875,12 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(iface)) == 0 { // Save these static DNS values so that they can be restored later. 
if err := saveCurrentStaticDNS(iface); err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name) + p.Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name) } } } if err := setDNS(iface, ns); err != nil { - mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings") + p.Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings") } } if p.requiredMultiNICsConfig { @@ -887,13 +889,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { ifaceName = iface.Name } withEachPhysicalInterfaces(ifaceName, "", func(i *net.Interface) error { - if dnsChanged(i, ns) { + if p.dnsChanged(i, ns) { // Check if the interface already has static DNS servers configured. // currentStaticDNS is an OS-dependent helper that returns the current static DNS. staticDNS, err := currentStaticDNS(i) if err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name) + p.Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name) } else if len(staticDNS) > 0 { //filter out loopback addresses staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool { @@ -903,15 +905,15 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(i)) == 0 { // Save these static DNS values so that they can be restored later. 
if err := saveCurrentStaticDNS(i); err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name) + p.Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name) } } } if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil { - mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings") + p.Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings") } else { - mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name) + p.Debug().Msgf("re-applying DNS for interface %q successfully", i.Name) } } return nil @@ -941,17 +943,17 @@ func (p *prog) resetDNS(isStart bool, restoreStatic bool) { // Otherwise, we restore the saved configuration (if any) or reset to DHCP. func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runningIface *net.Interface) { if p.runningIface == "" { - mainLog.Load().Debug().Msg("no running interface, skipping resetDNS") + p.Debug().Msg("no running interface, skipping resetDNS") return } - logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger := p.logger.Load().With().Str("iface", p.runningIface).Logger() netIface, err := netInterface(p.runningIface) if err != nil { logger.Error().Err(err).Msg("could not get interface") return } runningIface = netIface - if err := restoreNetworkManager(); err != nil { + if err := p.restoreNetworkManager(); err != nil { logger.Error().Err(err).Msg("could not restore NetworkManager") return } @@ -999,16 +1001,16 @@ func (p *prog) logInterfacesState() { withEachPhysicalInterfaces("", "", func(i *net.Interface) error { addrs, err := i.Addrs() if err != nil { - mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses") + p.Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses") } nss, err := currentStaticDNS(i) if err != nil { - mainLog.Load().Warn().Str("interface", 
i.Name).Err(err).Msg("failed to get DNS") + p.Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS") } if len(nss) == 0 { nss = currentDNS(i) } - mainLog.Load().Debug(). + p.Debug(). Any("addrs", addrs). Strs("nameservers", nss). Int("index", i.Index). @@ -1018,7 +1020,8 @@ func (p *prog) logInterfacesState() { } // findWorkingInterface looks for a network interface with a valid IP configuration -func findWorkingInterface(currentIface string) string { +func (p *prog) findWorkingInterface() string { + currentIface := p.runningIface // Helper to check if IP is valid (not link-local) isValidIP := func(ip net.IP) bool { return ip != nil && @@ -1036,7 +1039,7 @@ func findWorkingInterface(currentIface string) string { addrs, err := iface.Addrs() if err != nil { - mainLog.Load().Debug(). + p.Debug(). Str("interface", iface.Name). Err(err). Msg("failed to get interface addresses") @@ -1057,11 +1060,11 @@ func findWorkingInterface(currentIface string) string { // Get default route interface defaultRoute, err := netmon.DefaultRoute() if err != nil { - mainLog.Load().Debug(). + p.Debug(). Err(err). Msg("failed to get default route") } else { - mainLog.Load().Debug(). + p.Debug(). Str("default_route_iface", defaultRoute.InterfaceName). Msg("found default route") } @@ -1069,7 +1072,7 @@ func findWorkingInterface(currentIface string) string { // Get all interfaces ifaces, err := net.Interfaces() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to list network interfaces") + p.Error().Err(err).Msg("failed to list network interfaces") return currentIface // Return current interface as fallback } @@ -1099,7 +1102,7 @@ func findWorkingInterface(currentIface string) string { // Found working physical interface if err == nil && defaultRoute.InterfaceName == iface.Name { // Found interface with default route - use it immediately - mainLog.Load().Info(). + p.Info(). Str("old_iface", currentIface). Str("new_iface", iface.Name). 
Msg("switching to interface with default route") @@ -1120,7 +1123,7 @@ func findWorkingInterface(currentIface string) string { // Return interfaces in order of preference: // 1. Current interface if it's still valid if currentIfaceValid { - mainLog.Load().Debug(). + p.Debug(). Str("interface", currentIface). Msg("keeping current interface") return currentIface @@ -1128,7 +1131,7 @@ func findWorkingInterface(currentIface string) string { // 2. First working interface found if firstWorkingIface != "" { - mainLog.Load().Info(). + p.Info(). Str("old_iface", currentIface). Str("new_iface", firstWorkingIface). Msg("switching to first working physical interface") @@ -1136,7 +1139,7 @@ func findWorkingInterface(currentIface string) string { } // 3. Fall back to current interface if nothing else works - mainLog.Load().Warn(). + p.Warn(). Str("current_iface", currentIface). Msg("no working physical interface found, keeping current") return currentIface @@ -1258,7 +1261,7 @@ func ifaceFirstPrivateIP(iface *net.Interface) string { } // defaultRouteIP returns private IP string of the default route if present, prefer IPv4 over IPv6. 
-func defaultRouteIP() string { +func (p *prog) defaultRouteIP() string { dr, err := netmon.DefaultRoute() if err != nil { return "" @@ -1267,9 +1270,9 @@ func defaultRouteIP() string { if err != nil { return "" } - mainLog.Load().Debug().Str("iface", drNetIface.Name).Msg("checking default route interface") + p.Debug().Str("iface", drNetIface.Name).Msg("checking default route interface") if ip := ifaceFirstPrivateIP(drNetIface); ip != "" { - mainLog.Load().Debug().Str("ip", ip).Msg("found ip with default route interface") + p.Debug().Str("ip", ip).Msg("found ip with default route interface") return ip } @@ -1294,7 +1297,7 @@ func defaultRouteIP() string { }) if len(addrs) == 0 { - mainLog.Load().Warn().Msg("no default route IP found") + p.Warn().Msg("no default route IP found") return "" } sort.Slice(addrs, func(i, j int) bool { @@ -1302,7 +1305,7 @@ func defaultRouteIP() string { }) ip := addrs[0].String() - mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP") + p.Debug().Str("ip", ip).Msg("found LAN interface IP") return ip } @@ -1413,14 +1416,14 @@ func saveCurrentStaticDNS(iface *net.Interface) error { // It returns false for a nil iface. // // The caller must sort the nameservers before calling this function. 
-func dnsChanged(iface *net.Interface, nameservers []string) bool { +func (p *prog) dnsChanged(iface *net.Interface, nameservers []string) bool { if iface == nil { return false } curNameservers, _ := currentStaticDNS(iface) slices.Sort(curNameservers) if !slices.Equal(curNameservers, nameservers) { - mainLog.Load().Debug().Msgf("interface %q current DNS settings: %v, expected: %v", iface.Name, curNameservers, nameservers) + p.Debug().Msgf("interface %q current DNS settings: %v, expected: %v", iface.Name, curNameservers, nameservers) return true } return false @@ -1465,16 +1468,16 @@ func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) { exe, err := os.Executable() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to get executable path, skipped self-upgrade") + logger.Error().Err(err).Msg("failed to get executable path, skipped self-upgrade") return } cmd := exec.Command(exe, "upgrade", "prod", "-vv") cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() if err := cmd.Start(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade") + logger.Error().Err(err).Msg("failed to start self-upgrade") return } - mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vts) + logger.Debug().Msgf("self-upgrade triggered, version target: %s", vts) } // leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow diff --git a/cmd/cli/prog_log.go b/cmd/cli/prog_log.go new file mode 100644 index 0000000..dec20e9 --- /dev/null +++ b/cmd/cli/prog_log.go @@ -0,0 +1,33 @@ +package cli + +import "github.com/rs/zerolog" + +// Debug starts a new message with debug level. +func (p *prog) Debug() *zerolog.Event { + return p.logger.Load().Debug() +} + +// Warn starts a new message with warn level. +func (p *prog) Warn() *zerolog.Event { + return p.logger.Load().Warn() +} + +// Info starts a new message with info level. 
+func (p *prog) Info() *zerolog.Event { + return p.logger.Load().Info() +} + +// Fatal starts a new message with fatal level. +func (p *prog) Fatal() *zerolog.Event { + return p.logger.Load().Fatal() +} + +// Error starts a new message with error level. +func (p *prog) Error() *zerolog.Event { + return p.logger.Load().Error() +} + +// Notice starts a new message with notice level. +func (p *prog) Notice() *zerolog.Event { + return p.logger.Load().Notice() +} diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 0f3f731..587841d 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -43,10 +43,10 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" { resolvConfPath = rp } - mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath) + p.Debug().Msgf("start watching %s file", resolvConfPath) watcher, err := fsnotify.NewWatcher() if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf") + p.Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf") return } defer watcher.Close() @@ -55,7 +55,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f // see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well watchDir := filepath.Dir(resolvConfPath) if err := watcher.Add(watchDir); err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir) + p.Warn().Err(err).Msgf("could not add %s to watcher list", watchDir) return } @@ -64,7 +64,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f case <-p.dnsWatcherStopCh: return case <-p.stopCh: - mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) + p.Debug().Msgf("stopping watcher for %s", resolvConfPath) return case event, ok := <-watcher.Events: if p.recoveryRunning.Load() { @@ -77,7 +77,7 @@ func (p *prog) 
watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f continue } if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { - mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...") + p.Debug().Msgf("/etc/resolv.conf changes detected, reading changes...") // Convert expected nameservers to strings for comparison expectedNS := make([]string, len(ns)) @@ -92,7 +92,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f for retry := 0; retry < maxRetries; retry++ { foundNS, err = p.parseResolvConfNameservers(resolvConfPath) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content") + p.Error().Err(err).Msg("failed to read resolv.conf content") break } @@ -103,7 +103,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f // Only retry if we found no nameservers if retry < maxRetries-1 { - mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries) + p.Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries) select { case <-p.stopCh: return @@ -113,7 +113,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f continue } } else { - mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries") + p.Debug().Msg("resolv.conf remained empty after all retries") } } @@ -130,7 +130,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f } } - mainLog.Load().Debug(). + p.Debug(). Strs("found", foundNS). Strs("expected", expectedNS). Bool("matches", matches). 
@@ -139,16 +139,16 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f // Only revert if the nameservers don't match if !matches { if err := watcher.Remove(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to pause watcher") + p.Error().Err(err).Msg("failed to pause watcher") continue } if err := setDnsFn(iface, ns); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + p.Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") } if err := watcher.Add(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") + p.Error().Err(err).Msg("failed to continue running watcher") return } } @@ -158,7 +158,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f if !ok { return } - mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf") + p.Error().Err(err).Msg("could not get event for /etc/resolv.conf") } } } diff --git a/cmd/cli/resolvconf_darwin.go b/cmd/cli/resolvconf_darwin.go index eb70eed..05c7017 100644 --- a/cmd/cli/resolvconf_darwin.go +++ b/cmd/cli/resolvconf_darwin.go @@ -12,7 +12,7 @@ import ( const resolvConfPath = "/etc/resolv.conf" // setResolvConf sets the content of resolv.conf file using the given nameservers list. -func setResolvConf(iface *net.Interface, ns []netip.Addr) error { +func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error { servers := make([]string, len(ns)) for i := range ns { servers[i] = ns[i].String() diff --git a/cmd/cli/resolvconf_not_darwin_unix.go b/cmd/cli/resolvconf_not_darwin_unix.go index af33572..8838dc2 100644 --- a/cmd/cli/resolvconf_not_darwin_unix.go +++ b/cmd/cli/resolvconf_not_darwin_unix.go @@ -14,7 +14,7 @@ import ( ) // setResolvConf sets the content of the resolv.conf file using the given nameservers list. 
-func setResolvConf(iface *net.Interface, ns []netip.Addr) error { +func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error { r, err := newLoopbackOSConfigurator() if err != nil { return err @@ -27,7 +27,7 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error { if sds, err := searchDomains(); err == nil { oc.SearchDomains = sds } else { - mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file") + p.Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file") } return r.SetDNS(oc) } diff --git a/cmd/cli/resolvconf_windows.go b/cmd/cli/resolvconf_windows.go index 3e4ba1c..20a984f 100644 --- a/cmd/cli/resolvconf_windows.go +++ b/cmd/cli/resolvconf_windows.go @@ -6,7 +6,7 @@ import ( ) // setResolvConf sets the content of resolv.conf file using the given nameservers list. -func setResolvConf(_ *net.Interface, _ []netip.Addr) error { +func (p *prog) setResolvConf(_ *net.Interface, _ []netip.Addr) error { return nil } diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 6e19e38..426886e 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -2,6 +2,7 @@ package cli import ( "sync" + "sync/atomic" "time" "github.com/Control-D-Inc/ctrld" @@ -16,7 +17,8 @@ const ( // upstreamMonitor performs monitoring upstreams health. 
type upstreamMonitor struct { - cfg *ctrld.Config + cfg *ctrld.Config + logger atomic.Pointer[ctrld.Logger] mu sync.RWMutex checking map[string]bool @@ -28,7 +30,7 @@ type upstreamMonitor struct { failureTimerActive map[string]bool } -func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { +func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonitor { um := &upstreamMonitor{ cfg: cfg, checking: make(map[string]bool), @@ -37,6 +39,7 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { recovered: make(map[string]bool), failureTimerActive: make(map[string]bool), } + um.logger.Store(logger) for n := range cfg.Upstream { upstream := upstreamPrefix + n um.reset(upstream) @@ -53,7 +56,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { defer um.mu.Unlock() if um.recovered[upstream] { - mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream) + um.logger.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream) return } @@ -61,7 +64,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { failedCount := um.failureReq[upstream] // Log the updated failure count. - mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount) + um.logger.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount) // If this is the first failure and no timer is running, start a 10-second timer. if failedCount == 1 && !um.failureTimerActive[upstream] { @@ -74,7 +77,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { // and the upstream is not in a recovered state, mark it as down. 
if um.failureReq[upstream] > 0 && !um.recovered[upstream] { um.down[upstream] = true - mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream]) + um.logger.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream]) } // Reset the timer flag so that a new timer can be spawned if needed. um.failureTimerActive[upstream] = false @@ -84,7 +87,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { // If the failure count quickly reaches the threshold, mark the upstream as down immediately. if failedCount >= maxFailureRequest { um.down[upstream] = true - mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount) + um.logger.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount) } }