diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 5005925..cc5d1fe 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -211,6 +211,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cfg: &cfg, appCallback: appCallback, } + p.logger.Store(mainLog.Load()) if homedir == "" { if dir, err := userHomeDir(); err == nil { homedir = dir @@ -228,11 +229,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { p.logConn = lc } else { if !errors.Is(err, os.ErrNotExist) { - mainLog.Load().Warn().Err(err).Msg("unable to create log ipc connection") + p.Warn().Err(err).Msg("unable to create log ipc connection") } } } else { - mainLog.Load().Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath) + p.Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath) } notifyExitToLogServer := func() { if p.logConn != nil { @@ -241,7 +242,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } if daemon && runtime.GOOS == "windows" { - mainLog.Load().Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.") + p.Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.") } if !daemon { @@ -250,10 +251,10 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { go func() { s, err := newService(p, svcConfig) if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed create new service") + p.Fatal().Err(err).Msg("failed create new service") } if err := s.Run(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start service") + p.Error().Err(err).Msg("failed to start service") } }() } @@ -261,7 +262,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { tryReadingConfig(writeDefaultConfig) if err := readBase64Config(configBase64); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to read base64 config") + p.Fatal().Err(err).Msg("failed to read base64 config") } processNoConfigFlags(noConfigStart) @@ -270,7 +271,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { p.mu.Lock() if err := v.Unmarshal(&cfg); err != nil { notifyExitToLogServer() - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) + p.Fatal().Msgf("failed to unmarshal config: %v", err) } p.mu.Unlock() @@ -280,19 +281,19 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // so it's able to log information in processCDFlags. p.initLogging(true) - mainLog.Load().Info().Msgf("starting ctrld %s", curVersion()) - mainLog.Load().Info().Msgf("os: %s", osVersion()) + p.Info().Msgf("starting ctrld %s", curVersion()) + p.Info().Msgf("os: %s", osVersion()) // Wait for network up. if !ctrldnet.Up() { notifyExitToLogServer() - mainLog.Load().Fatal().Msg("network is not up yet") + p.Fatal().Msg("network is not up yet") } p.router = router.New(&cfg, cdUID != "") cs, err := newControlServer(filepath.Join(sockDir, ControlSocketName())) if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not create control server") + p.Warn().Err(err).Msg("could not create control server") } p.cs = cs @@ -301,7 +302,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // to set the current time, so this check must happen before processCDFlags. if err := p.router.PreRun(); err != nil { notifyExitToLogServer() - mainLog.Load().Fatal().Err(err).Msg("failed to perform router pre-run check") + p.Fatal().Err(err).Msg("failed to perform router pre-run check") } oldLogPath := cfg.Service.LogPath @@ -316,7 +317,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { return } - cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() + cdLogger := p.logger.Load().With().Str("mode", "cd").Logger() // Performs self-uninstallation if the ControlD device does not exist. var uer *controld.ErrorResponse if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { @@ -340,9 +341,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if updated { if err := writeConfigFile(&cfg); err != nil { notifyExitToLogServer() - mainLog.Load().Fatal().Err(err).Msg("failed to write config file") + p.Fatal().Err(err).Msg("failed to write config file") } else { - mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile) + p.Info().Msg("writing config file to: " + defaultConfigFile) } } @@ -354,10 +355,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // Copy logs written so far to new log file if possible. if buf, err := os.ReadFile(oldLogPath); err == nil { if err := os.WriteFile(newLogPath, buf, os.FileMode(0o600)); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not copy old log file") + p.Warn().Err(err).Msg("could not copy old log file") } } initLoggingWithBackup(false) + p.logger.Store(mainLog.Load()) } if err := validateConfig(&cfg); err != nil { @@ -369,13 +371,13 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if daemon { exe, err := os.Executable() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to find the binary") + p.Error().Err(err).Msg("failed to find the binary") notifyExitToLogServer() os.Exit(1) } curDir, err := os.Getwd() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to get current working directory") + p.Error().Err(err).Msg("failed to get current working directory") notifyExitToLogServer() os.Exit(1) } @@ -383,11 +385,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...) cmd.Dir = curDir if err := cmd.Start(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start process as daemon") + p.Error().Err(err).Msg("failed to start process as daemon") notifyExitToLogServer() os.Exit(1) } - mainLog.Load().Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started") + p.Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started") os.Exit(0) } @@ -395,7 +397,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { for _, lc := range p.cfg.Listener { if shouldAllocateLoopbackIP(lc.IP) { if err := allocateIP(lc.IP); err != nil { - mainLog.Load().Error().Err(err).Msgf("could not allocate IP: %s", lc.IP) + p.Error().Err(err).Msgf("could not allocate IP: %s", lc.IP) } } } @@ -406,7 +408,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { for _, lc := range p.cfg.Listener { if shouldAllocateLoopbackIP(lc.IP) { if err := deAllocateIP(lc.IP); err != nil { - mainLog.Load().Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP) + p.Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP) } } } @@ -417,15 +419,15 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } if iface != "" { p.onStarted = append(p.onStarted, func() { - mainLog.Load().Debug().Msg("router setup on start") + p.Debug().Msg("router setup on start") if err := p.router.Setup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not configure router") + p.Error().Err(err).Msg("could not configure router") } }) p.onStopped = append(p.onStopped, func() { - mainLog.Load().Debug().Msg("router cleanup on stop") + p.Debug().Msg("router cleanup on stop") if err := p.router.Cleanup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not cleanup router") + p.Error().Err(err).Msg("could not cleanup router") } }) } @@ -438,9 +440,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { file := ctrld.SavedStaticDnsSettingsFilePath(i) if _, err := os.Stat(file); err == nil { if err := restoreDNS(i); err != nil { - mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) + p.Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) } else { - mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) + p.Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) } } return nil diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 428fe12..de3a27a 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -79,21 +79,21 @@ func (s *controlServer) register(pattern string, handler http.Handler) { func (p *prog) registerControlServerHandler() { p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { - mainLog.Load().Debug().Msg("handling list clients request") + p.Debug().Msg("handling list clients request") clients := p.ciTable.ListClients() - mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list") + p.Debug().Int("client_count", len(clients)).Msg("retrieved clients list") sort.Slice(clients, func(i, j int) bool { return clients[i].IP.Less(clients[j].IP) }) - mainLog.Load().Debug().Msg("sorted clients by IP address") + p.Debug().Msg("sorted clients by IP address") if p.metricsQueryStats.Load() { - mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts") + p.Debug().Msg("metrics query stats enabled, collecting query counts") for idx, client := range clients { - mainLog.Load().Debug(). + p.Debug(). Int("index", idx). Str("ip", client.IP.String()). Str("mac", client.Mac). @@ -104,7 +104,7 @@ func (p *prog) registerControlServerHandler() { dm := &dto.Metric{} if statsClientQueriesCount.MetricVec == nil { - mainLog.Load().Debug(). + p.Debug(). Str("client_ip", client.IP.String()). Msg("skipping metrics collection: MetricVec is nil") continue @@ -116,7 +116,7 @@ func (p *prog) registerControlServerHandler() { client.Hostname, ) if err != nil { - mainLog.Load().Debug(). + p.Debug(). Err(err). Str("client_ip", client.IP.String()). Str("mac", client.Mac). @@ -127,23 +127,23 @@ func (p *prog) registerControlServerHandler() { if err := m.Write(dm); err == nil && dm.Counter != nil { client.QueryCount = int64(dm.Counter.GetValue()) - mainLog.Load().Debug(). + p.Debug(). Str("client_ip", client.IP.String()). Int64("query_count", client.QueryCount). Msg("successfully collected query count") } else if err != nil { - mainLog.Load().Debug(). + p.Debug(). Err(err). Str("client_ip", client.IP.String()). Msg("failed to write metric") } } } else { - mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts") + p.Debug().Msg("metrics query stats disabled, skipping query counts") } if err := json.NewEncoder(w).Encode(&clients); err != nil { - mainLog.Load().Error(). + p.Error(). Err(err). Int("client_count", len(clients)). Msg("failed to encode clients response") @@ -151,7 +151,7 @@ func (p *prog) registerControlServerHandler() { return } - mainLog.Load().Debug(). + p.Debug(). Int("client_count", len(clients)). Msg("successfully sent clients list response") })) @@ -175,7 +175,7 @@ func (p *prog) registerControlServerHandler() { oldSvc := p.cfg.Service p.mu.Unlock() if err := p.sendReloadSignal(); err != nil { - mainLog.Load().Err(err).Msg("could not send reload signal") + p.Error().Err(err).Msg("could not send reload signal") http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -216,7 +216,7 @@ func (p *prog) registerControlServerHandler() { return } - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) // Re-fetch pin code from API. if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev); rc != nil { if rc.DeactivationPin != nil { @@ -225,7 +225,7 @@ func (p *prog) registerControlServerHandler() { cdDeactivationPin.Store(defaultDeactivationPin) } } else { - mainLog.Load().Warn().Err(err).Msg("could not re-fetch deactivation pin code") + p.Warn().Err(err).Msg("could not re-fetch deactivation pin code") } // If pin code not set, allowing deactivation. @@ -237,7 +237,7 @@ func (p *prog) registerControlServerHandler() { var req deactivationRequest if err := json.NewDecoder(request.Body).Decode(&req); err != nil { w.WriteHeader(http.StatusPreconditionFailed) - mainLog.Load().Err(err).Msg("invalid deactivation request") + p.Error().Err(err).Msg("invalid deactivation request") return } @@ -320,15 +320,15 @@ func (p *prog) registerControlServerHandler() { UID: cdUID, Data: r.r, } - mainLog.Load().Debug().Msg("sending log file to ControlD server") + p.Debug().Msg("sending log file to ControlD server") resp := logSentResponse{Size: r.size} - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil { - mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err) + p.Error().Msgf("could not send log file to ControlD server: %v", err) resp.Error = err.Error() w.WriteHeader(http.StatusInternalServerError) } else { - mainLog.Load().Debug().Msg("sending log file successfully") + p.Debug().Msg("sending log file successfully") w.WriteHeader(http.StatusOK) } if err := json.NewEncoder(w).Encode(&resp); err != nil { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index a3d9970..c09e11d 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -87,14 +87,14 @@ type upstreamForResult struct { func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { // Start network monitoring if err := p.monitorNetworkChanges(mainCtx); err != nil { - mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring") + p.Error().Err(err).Msg("Failed to start network monitoring") // Don't return here as we still want DNS service to run } listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { - mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") + p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") return allocErr } @@ -110,9 +110,9 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] reqId := requestID() ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) - ctx = ctrld.LoggerCtx(ctx, mainLog.Load()) + ctx = ctrld.LoggerCtx(ctx, p.logger.Load()) if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { - ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) + ctrld.Log(ctx, p.Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) answer := new(dns.Msg) answer.SetRcode(m, dns.RcodeRefused) _ = w.WriteMsg(answer) @@ -135,7 +135,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil { p.cache.Purge() - ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain) + ctrld.Log(ctx, p.Debug(), "received query %q, local cache is purged", domain) } remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) ci := p.getClientInfo(remoteIP, m) @@ -144,7 +144,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) fmtSrcToDest := fmtRemoteToLocal(listenerNum, ci.Hostname, remoteAddr.String()) t := time.Now() - ctrld.Log(ctx, mainLog.Load().Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain) + ctrld.Log(ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain) ur := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain) labelValues := make([]string, 0, len(statsQueriesCountLabels)) @@ -155,7 +155,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { var answer *dns.Msg if !ur.matched && listenerConfig.Restricted { - ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String()) + ctrld.Log(ctx, p.Info(), "query refused, %s does not match any network policy", remoteAddr.String()) answer = new(dns.Msg) answer.SetRcode(m, dns.RcodeRefused) labelValues = append(labelValues, "") // no upstream @@ -174,7 +174,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { answer = pr.answer rtt := time.Since(t) - ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) + ctrld.Log(ctx, p.Debug(), "received response of %d bytes in %s", answer.Len(), rtt) upstream := pr.upstream switch { case pr.cached: @@ -192,7 +192,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { p.forceFetchingAPI(domain) }() if err := w.WriteMsg(answer); err != nil { - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client") + ctrld.Log(ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client") } }) @@ -209,7 +209,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { case err := <-errCh: // Local ipv6 listener should not terminate ctrld. // It's a workaround for a quirk on Windows. - mainLog.Load().Warn().Err(err).Msg("local ipv6 listener failed") + p.Warn().Err(err).Msg("local ipv6 listener failed") } return nil }) @@ -229,7 +229,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { case err := <-errCh: // RFC1918 listener should not terminate ctrld. // It's a workaround for a quirk on system with systemd-resolved. - mainLog.Load().Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr) + p.Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr) } }() } @@ -371,8 +371,8 @@ func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg }, Ptr: dns.Fqdn(name), }} - ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table") - ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + ctrld.Log(ctx, p.Info(), "private PTR lookup, using client info table") + ctrld.Log(ctx, p.Debug(), "client info: %v", ctrld.ClientInfo{ Mac: p.ciTable.LookupMac(ip.String()), IP: ip.String(), Hostname: name, @@ -416,8 +416,8 @@ func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg AAAA: ip.AsSlice(), }} } - ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table") - ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + ctrld.Log(ctx, p.Info(), "lan hostname lookup, using client info table") + ctrld.Log(ctx, p.Debug(), "client info: %v", ctrld.ClientInfo{ Mac: p.ciTable.LookupMac(ip.String()), IP: ip.String(), Hostname: hostname, @@ -441,7 +441,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // running and listening on local addresses, these local addresses must be used // as nameservers, so queries for ADDC could be resolved as expected. if p.isAdDomainQuery(req.msg) { - ctrld.Log(ctx, mainLog.Load().Debug(), + ctrld.Log(ctx, p.Debug(), "AD domain query detected for %s in domain %s", req.msg.Question[0].Name, p.adDomain) upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} @@ -459,14 +459,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // 4. Try remote upstream. isLanOrPtrQuery := false if req.ufr.matched { - ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) + ctrld.Log(ctx, p.Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) } else { switch { case isSrvLanLookup(req.msg): upstreams = []string{upstreamOS} upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams) + ctrld.Log(ctx, p.Debug(), "SRV record lookup, using upstreams: %v", upstreams) case isPrivatePtrLookup(req.msg): isLanOrPtrQuery = true if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil { @@ -476,7 +476,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs) ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams) + ctrld.Log(ctx, p.Debug(), "private PTR lookup, using upstreams: %v", upstreams) case isLanHostnameQuery(req.msg): isLanOrPtrQuery = true if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil { @@ -487,9 +487,9 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { upstreams = []string{upstreamOS} upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams) + ctrld.Log(ctx, p.Debug(), "lan hostname lookup, using upstreams: %v", upstreams) default: - ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams) + ctrld.Log(ctx, p.Debug(), "no explicit policy matched, using default routing -> %v", upstreams) } } @@ -504,7 +504,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { ctrld.SetCacheReply(answer, req.msg, answer.Rcode) now := time.Now() if cachedValue.Expire.After(now) { - ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response") + ctrld.Log(ctx, p.Debug(), "hit cached response") setCachedAnswerTTL(answer, now, cachedValue.Expire) res.answer = answer res.cached = true @@ -514,10 +514,10 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } } resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { - ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) + ctrld.Log(ctx, p.Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig) if err != nil { - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") + ctrld.Log(ctx, p.Error().Err(err), "failed to create resolver") return nil, err } resolveCtx, cancel := upstreamConfig.Context(ctx) @@ -526,7 +526,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil { - ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request") + ctrld.Log(ctx, p.Debug(), "including client info with the request") ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) } answer, err := resolve1(upstream, upstreamConfig, msg) @@ -540,7 +540,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { return answer } - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") + ctrld.Log(ctx, p.Error().Err(err), "failed to resolve query") // increase failure count when there is no answer // rehardless of what kind of error we get @@ -564,7 +564,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { if upstreamConfig == nil { continue } - logger := mainLog.Load().Debug(). + logger := p.Debug(). Str("upstream", upstreamConfig.String()). Str("query", req.msg.Question[0].Name). Bool("is_ad_query", p.isAdDomainQuery(req.msg)). @@ -577,7 +577,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { answer := resolve(upstreams[n], upstreamConfig, req.msg) if answer == nil { if serveStaleCache && staleAnswer != nil { - ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response") + ctrld.Log(ctx, p.Debug(), "serving stale cached response") now := time.Now() setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL)) res.answer = staleAnswer @@ -589,11 +589,11 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // We are doing LAN/PTR lookup using private resolver, so always process next one. // Except for the last, we want to send response instead of saying all upstream failed. if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 { - ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n]) + ctrld.Log(ctx, p.Debug(), "no response from %s, process to next upstream", upstreams[n]) continue } if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) { - ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream") + ctrld.Log(ctx, p.Debug(), "failover rcode matched, process to next upstream") continue } @@ -609,18 +609,18 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } setCachedAnswerTTL(answer, now, expired) p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired)) - ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response") + ctrld.Log(ctx, p.Debug(), "add cached response") } hostname := "" if req.ci != nil { hostname = req.ci.Hostname } - ctrld.Log(ctx, mainLog.Load().Info(), "REPLY: %s -> %s (%s): %s", upstreams[n], req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode]) + ctrld.Log(ctx, p.Info(), "REPLY: %s -> %s (%s): %s", upstreams[n], req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode]) res.answer = answer res.upstream = upstreamConfig.Endpoint return res } - ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) + ctrld.Log(ctx, p.Error(), "all %v endpoints failed", upstreams) // if we have no healthy upstreams, trigger recovery flow if p.leakOnUpstreamFailure() { @@ -633,28 +633,28 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } else { reason = RecoveryReasonRegularFailure } - mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) + p.Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) go p.handleRecovery(reason) } else { - mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection") + p.Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection") } p.recoveryCancelMu.Unlock() } else { - mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger") + p.Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger") } // attempt query to OS resolver while as a retry catch all // we dont want this to happen if leakOnUpstreamFailure is false if upstreams[0] != upstreamOS { - ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all") + ctrld.Log(ctx, p.Debug(), "attempting query to OS resolver as a retry catch all") answer := resolve(upstreamOS, osUpstreamConfig, req.msg) if answer != nil { - ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful") + ctrld.Log(ctx, p.Debug(), "OS resolver retry query successful") res.answer = answer res.upstream = osUpstreamConfig.Endpoint return res } - ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed") + ctrld.Log(ctx, p.Debug(), "OS resolver retry query failed") } } @@ -958,10 +958,10 @@ func (p *prog) doSelfUninstall(answer *dns.Msg) { return } - logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger() + logger := p.logger.Load().With().Str("mode", "self-uninstall").Logger() if p.refusedQueryCount > selfUninstallMaxQueries { p.checkingSelfUninstall = true - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) _, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) logger.Debug().Msg("maximum number of refused queries reached, checking device status") selfUninstallCheck(err, p, logger) @@ -1031,7 +1031,7 @@ func (p *prog) queryFromSelf(ip string) bool { netIP := netip.MustParseAddr(ip) regularIPs, loopbackIPs, err := netmon.LocalAddresses() if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not get local addresses") + p.Warn().Err(err).Msg("could not get local addresses") return false } for _, localIP := range slices.Concat(regularIPs, loopbackIPs) { @@ -1151,7 +1151,8 @@ func isWanClient(na net.Addr) bool { // resolveInternalDomainTestQuery resolves internal test domain query, returning the answer to the caller. func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.Msg) *dns.Msg { - ctrld.Log(ctx, mainLog.Load().Debug(), "internal domain test query") + logger := ctrld.LoggerFromCtx(ctx) + ctrld.Log(ctx, logger.Debug(), "internal domain test query") q := m.Question[0] answer := new(dns.Msg) @@ -1192,7 +1193,7 @@ func FlushDNSCache() error { func (p *prog) monitorNetworkChanges(ctx context.Context) error { mon, err := netmon.New(func(format string, args ...any) { // Always fetch the latest logger (and inject the prefix) - mainLog.Load().Printf("netmon: "+format, args...) + p.logger.Load().Printf("netmon: "+format, args...) }) if err != nil { return fmt.Errorf("creating network monitor: %w", err) @@ -1204,7 +1205,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New) - mainLog.Load().Debug(). + p.Debug(). Interface("old_state", delta.Old). Interface("new_state", delta.New). Bool("is_major_change", isMajorChange). @@ -1232,7 +1233,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { if newIface.IsUp() && len(usableNewIPs) > 0 { changed = true changeIPs = usableNewIPs - mainLog.Load().Debug(). + p.Debug(). Str("interface", ifaceName). Interface("new_ips", usableNewIPs). Msg("Interface newly appeared (was not present in old state)") @@ -1254,7 +1255,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { if newIface.IsUp() && len(usableNewIPs) > 0 { changed = true changeIPs = usableNewIPs - mainLog.Load().Debug(). + p.Debug(). Str("interface", ifaceName). Interface("old_ips", oldIPs). Interface("new_ips", usableNewIPs). @@ -1267,39 +1268,39 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // if the default route changed, set changed to true if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface { changed = true - mainLog.Load().Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface) + p.Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface) } if !changed { - mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected") + p.Debug().Msg("Ignoring interface change - no valid interfaces affected") // check if the default IPs are still on an interface that is up ValidateDefaultLocalIPsFromDelta(delta.New) return } if !activeInterfaceExists { - mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization") + p.Debug().Msg("No active interfaces found, skipping reinitialization") return } // Get IPs from default route interface in new state - selfIP := defaultRouteIP() + selfIP := p.defaultRouteIP() // Ensure that selfIP is an IPv4 address. // If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil { - mainLog.Load().Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP) + p.Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP) selfIP = "" } var ipv6 string if delta.New.DefaultRouteInterface != "" { - mainLog.Load().Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface]) + p.Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface]) for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] { ipAddr, _ := netip.ParsePrefix(ip.String()) addr := ipAddr.Addr() if selfIP == "" && addr.Is4() { - mainLog.Load().Debug().Msgf("checking IP: %s", addr.String()) + p.Debug().Msgf("checking IP: %s", addr.String()) if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { selfIP = addr.String() } @@ -1310,12 +1311,12 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { } } else { // If no default route interface is set yet, use the changed IPs - mainLog.Load().Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs) + p.Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs) for _, ip := range changeIPs { ipAddr, _ := netip.ParsePrefix(ip.String()) addr := ipAddr.Addr() if selfIP == "" && addr.Is4() { - mainLog.Load().Debug().Msgf("checking IP: %s", addr.String()) + p.Debug().Msgf("checking IP: %s", addr.String()) if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { selfIP = addr.String() } @@ -1328,15 +1329,15 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // Only set the IPv4 default if selfIP is a valid IPv4 address. if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil { - ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, mainLog.Load()), ip) + ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, p.logger.Load()), ip) if !isMobile() && p.ciTable != nil { p.ciTable.SetSelfIP(selfIP) } } if ip := net.ParseIP(ipv6); ip != nil { - ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, mainLog.Load()), ip) + ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, p.logger.Load()), ip) } - mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) + p.Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) // we only trigger recovery flow for network changes on non router devices if router.Name() == "" { @@ -1345,7 +1346,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { }) mon.Start() - mainLog.Load().Debug().Msg("Network monitor started") + p.Debug().Msg("Network monitor started") return nil } @@ -1400,11 +1401,11 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool { // checkUpstreamOnce sends a test query to the specified upstream. // Returns nil if the upstream responds successfully. func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error { - mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream) + p.Debug().Msgf("Starting check for upstream: %s", upstream) - resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), uc) + resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), uc) if err != nil { - mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream) + p.Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream) return err } @@ -1415,22 +1416,22 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro if uc.Timeout > 0 { timeout = time.Millisecond * time.Duration(uc.Timeout) } - mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout) + p.Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load())) - mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) + uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load())) + p.Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) start := time.Now() _, err = resolver.Resolve(ctx, msg) duration := time.Since(start) if err != nil { - mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration) + p.Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration) } else { - mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration) + p.Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration) } return err } @@ -1440,13 +1441,13 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro // upstream failure recoveries, waiting for recovery to complete (using a cancellable context without timeout), // and then re-applying the DNS settings. func (p *prog) handleRecovery(reason RecoveryReason) { - mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings") + p.Debug().Msg("Starting recovery process: removing DNS settings") // For network changes, cancel any existing recovery check because the network state has changed. if reason == RecoveryReasonNetworkChange { p.recoveryCancelMu.Lock() if p.recoveryCancel != nil { - mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)") + p.Debug().Msg("Cancelling existing recovery check (network change)") p.recoveryCancel() p.recoveryCancel = nil } @@ -1455,7 +1456,7 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // For upstream failures, if a recovery is already in progress, do nothing new. p.recoveryCancelMu.Lock() if p.recoveryCancel != nil { - mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger") + p.Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger") p.recoveryCancelMu.Unlock() return } @@ -1476,15 +1477,15 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // will be appended to nameservers from the saved interface values p.resetDNS(false, false) - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) // For an OS failure, reinitialize OS resolver nameservers immediately. if reason == RecoveryReasonOSFailure { - mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers") + p.Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers") ns := ctrld.InitializeOsResolver(loggerCtx, true) if len(ns) == 0 { - mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") + p.Warn().Msg("No nameservers found for OS resolver; using existing values") } else { - mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) } } @@ -1494,13 +1495,13 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // Wait indefinitely until one of the upstreams recovers. recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams) if err != nil { - mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed") + p.Error().Err(err).Msg("Recovery canceled; DNS settings remain removed") p.recoveryCancelMu.Lock() p.recoveryCancel = nil p.recoveryCancelMu.Unlock() return } - mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered) + p.Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered) // reset the upstream failure count and down state p.um.reset(recovered) @@ -1509,9 +1510,9 @@ func (p *prog) handleRecovery(reason RecoveryReason) { if reason == RecoveryReasonNetworkChange { ns := ctrld.InitializeOsResolver(loggerCtx, true) if len(ns) == 0 { - mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values") + p.Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values") } else { - mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) } } @@ -1534,44 +1535,44 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string recoveredCh := make(chan string, 1) var wg sync.WaitGroup - mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams)) + p.Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams)) for name, uc := range upstreams { wg.Add(1) go func(name string, uc *ctrld.UpstreamConfig) { defer wg.Done() - mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name) + p.Debug().Msgf("Starting recovery check loop for upstream: %s", name) attempts := 0 for { select { case <-ctx.Done(): - mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name) + p.Debug().Msgf("Context canceled for upstream %s", name) return default: attempts++ // checkUpstreamOnce will reset any failure counters on success. if err := p.checkUpstreamOnce(name, uc); err == nil { - mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name) + p.Debug().Msgf("Upstream %s recovered successfully", name) select { case recoveredCh <- name: - mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name) + p.Debug().Msgf("Sent recovery notification for upstream %s", name) default: - mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered") + p.Debug().Msg("Recovery channel full, another upstream already recovered") } return } - mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name) + p.Debug().Msgf("Upstream %s check failed, sleeping before retry", name) time.Sleep(checkUpstreamBackoffSleep) // if this is the upstreamOS and it's the 3rd attempt (or multiple of 3), // we should try to reinit the OS resolver to ensure we can recover if name == upstreamOS && attempts%3 == 0 { - mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts) - ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, mainLog.Load()), true) + p.Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts) + ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, p.logger.Load()), true) if len(ns) == 0 { - mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") + p.Warn().Msg("No nameservers found for OS resolver; using existing values") } else { - mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) } } } diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 4a4e5b4..615ce40 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -77,7 +77,8 @@ func Test_prog_upstreamFor(t *testing.T) { cfg := testhelper.SampleConfig(t) cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false) p := &prog{cfg: cfg} - p.um = newUpstreamMonitor(p.cfg) + p.logger.Store(mainLog.Load()) + p.um = newUpstreamMonitor(p.cfg, mainLog.Load()) p.lanLoopGuard = newLoopGuard() p.ptrLoopGuard = newLoopGuard() for _, nc := range p.cfg.Network { @@ -145,6 +146,7 @@ func Test_prog_upstreamFor(t *testing.T) { func TestCache(t *testing.T) { cfg := testhelper.SampleConfig(t) prog := &prog{cfg: cfg} + prog.logger.Store(mainLog.Load()) for _, nc := range prog.cfg.Network { for _, cidr := range nc.Cidrs { _, ipNet, err := net.ParseCIDR(cidr) diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 0ba2c8c..c2880c0 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -100,6 +100,7 @@ func (p *prog) initLogging(backup bool) { // Initializing internal logging after global logging. p.initInternalLogging(logWriters) + p.logger.Store(mainLog.Load()) } // initInternalLogging performs internal logging if there's no log enabled. @@ -108,7 +109,7 @@ func (p *prog) initInternalLogging(writers []io.Writer) { return } p.initInternalLogWriterOnce.Do(func() { - mainLog.Load().Notice().Msg("internal logging enabled") + p.Notice().Msg("internal logging enabled") p.internalLogWriter = newLogWriter() p.internalLogSent = time.Now().Add(-logWriterSentInterval) p.internalWarnLogWriter = newSmallLogWriter() diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index 434a4a5..fce6ce1 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -84,7 +84,7 @@ func (p *prog) detectLoop(msg *dns.Msg) { // // See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html func (p *prog) checkDnsLoop() { - mainLog.Load().Debug().Msg("start checking DNS loop") + p.Debug().Msg("start checking DNS loop") upstream := make(map[string]*ctrld.UpstreamConfig) p.loopMu.Lock() for n, uc := range p.cfg.Upstream { @@ -93,7 +93,7 @@ func (p *prog) checkDnsLoop() { } // Do not send test query to external upstream. if !canBeLocalUpstream(uc.Domain) { - mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n) + p.Debug().Msgf("skipping external: upstream.%s", n) continue } uid := uc.UID() @@ -102,7 +102,7 @@ func (p *prog) checkDnsLoop() { } p.loopMu.Unlock() - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) for uid := range p.loop { msg := loopTestMsg(uid) uc := upstream[uid] @@ -112,14 +112,14 @@ func (p *prog) checkDnsLoop() { } resolver, err := ctrld.NewResolver(loggerCtx, uc) if err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) + p.Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) continue } if _, err := resolver.Resolve(context.Background(), msg); err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) + p.Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) } } - mainLog.Load().Debug().Msg("end checking DNS loop") + p.Debug().Msg("end checking DNS loop") } // checkDnsLoopTicker performs p.checkDnsLoop every minute. diff --git a/cmd/cli/netlink_linux.go b/cmd/cli/netlink_linux.go index f4e9bda..2115c5b 100644 --- a/cmd/cli/netlink_linux.go +++ b/cmd/cli/netlink_linux.go @@ -14,7 +14,7 @@ func (p *prog) watchLinkState(ctx context.Context) { done := make(chan struct{}) defer close(done) if err := netlink.LinkSubscribe(ch, done); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not subscribe link") + p.Warn().Err(err).Msg("could not subscribe link") return } for { @@ -26,9 +26,9 @@ func (p *prog) watchLinkState(ctx context.Context) { continue } if lu.Change&unix.IFF_UP != 0 { - mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") + p.Debug().Msgf("link state changed, re-bootstrapping") for _, uc := range p.cfg.Upstream { - uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load())) + uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load())) } } } diff --git a/cmd/cli/network_manager_linux.go b/cmd/cli/network_manager_linux.go index 1a8c22b..bfd2775 100644 --- a/cmd/cli/network_manager_linux.go +++ b/cmd/cli/network_manager_linux.go @@ -28,61 +28,61 @@ func hasNetworkManager() bool { return exe != "" } -func setupNetworkManager() error { +func (p *prog) setupNetworkManager() error { if !hasNetworkManager() { return nil } if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent { - mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do") + p.Debug().Msg("NetworkManager already setup, nothing to do") return nil } err := os.WriteFile(networkManagerCtrldConfFile, []byte(nmCtrldConfContent), os.FileMode(0644)) if os.IsNotExist(err) { - mainLog.Load().Debug().Msg("NetworkManager is not available") + p.Debug().Msg("NetworkManager is not available") return nil } if err != nil { - mainLog.Load().Debug().Err(err).Msg("could not write NetworkManager ctrld config file") + p.Debug().Err(err).Msg("could not write NetworkManager ctrld config file") return err } - reloadNetworkManager() - mainLog.Load().Debug().Msg("setup NetworkManager done") + p.reloadNetworkManager() + p.Debug().Msg("setup NetworkManager done") return nil } -func restoreNetworkManager() error { +func (p *prog) restoreNetworkManager() error { if !hasNetworkManager() { return nil } err := os.Remove(networkManagerCtrldConfFile) if os.IsNotExist(err) { - mainLog.Load().Debug().Msg("NetworkManager is not available") + p.Debug().Msg("NetworkManager is not available") return nil } if err != nil { - mainLog.Load().Debug().Err(err).Msg("could not remove NetworkManager ctrld config file") + p.Debug().Err(err).Msg("could not remove NetworkManager ctrld config file") return err } - reloadNetworkManager() - mainLog.Load().Debug().Msg("restore NetworkManager done") + p.reloadNetworkManager() + p.Debug().Msg("restore NetworkManager done") return nil } -func reloadNetworkManager() { +func (p *prog) reloadNetworkManager() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() conn, err := dbus.NewSystemConnectionContext(ctx) if err != nil { - mainLog.Load().Error().Err(err).Msg("could not create new system connection") + p.Error().Err(err).Msg("could not create new system connection") return } defer conn.Close() waitCh := make(chan string) if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil { - mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager") + p.Debug().Err(err).Msg("could not reload NetworkManager") return } <-waitCh diff --git a/cmd/cli/network_manager_others.go b/cmd/cli/network_manager_others.go index 323d2f2..e6e5f68 100644 --- a/cmd/cli/network_manager_others.go +++ b/cmd/cli/network_manager_others.go @@ -2,14 +2,14 @@ package cli -func setupNetworkManager() error { - reloadNetworkManager() +func (p *prog) setupNetworkManager() error { + p.reloadNetworkManager() return nil } -func restoreNetworkManager() error { - reloadNetworkManager() +func (p *prog) restoreNetworkManager() error { + p.reloadNetworkManager() return nil } -func reloadNetworkManager() {} +func (p *prog) reloadNetworkManager() {} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index d85c371..0cfd3b9 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -102,6 +102,7 @@ type prog struct { apiForceReloadGroup singleflight.Group logConn net.Conn cs *controlServer + logger atomic.Pointer[ctrld.Logger] csSetDnsDone chan struct{} csSetDnsOk bool dnsWg sync.WaitGroup @@ -150,7 +151,7 @@ type prog struct { onStopped []func() } -func (p *prog) Start(s service.Service) error { +func (p *prog) Start(_ service.Service) error { go p.runWait() return nil } @@ -164,7 +165,6 @@ func (p *prog) runWait() { notifyReloadSigCh(reloadSigCh) reload := false - logger := mainLog.Load() for { reloadCh := make(chan struct{}) done := make(chan struct{}) @@ -177,9 +177,9 @@ func (p *prog) runWait() { var newCfg *ctrld.Config select { case sig := <-reloadSigCh: - logger.Notice().Msgf("got signal: %s, reloading...", sig.String()) + p.Notice().Msgf("got signal: %s, reloading...", sig.String()) case <-p.reloadCh: - logger.Notice().Msg("reloading...") + p.Notice().Msg("reloading...") case apiCfg := <-p.apiReloadCh: newCfg = apiCfg case <-p.stopCh: @@ -202,18 +202,18 @@ func (p *prog) runWait() { } v.SetConfigFile(confFile) if err := v.ReadInConfig(); err != nil { - logger.Err(err).Msg("could not read new config") + p.Error().Err(err).Msg("could not read new config") waitOldRunDone() continue } if err := v.Unmarshal(&newCfg); err != nil { - logger.Err(err).Msg("could not unmarshal new config") + p.Error().Err(err).Msg("could not unmarshal new config") waitOldRunDone() continue } if cdUID != "" { if rc, err := processCDFlags(newCfg); err != nil { - logger.Err(err).Msg("could not fetch ControlD config") + p.Error().Err(err).Msg("could not fetch ControlD config") waitOldRunDone() continue } else { @@ -243,25 +243,25 @@ func (p *prog) runWait() { } } if err := validateConfig(newCfg); err != nil { - logger.Err(err).Msg("invalid config") + p.Error().Err(err).Msg("invalid config") continue } addExtraSplitDnsRule(newCfg) if err := writeConfigFile(newCfg); err != nil { - logger.Err(err).Msg("could not write new config") + p.Error().Err(err).Msg("could not write new config") } // This needs to be done here, otherwise, the DNS handler may observe an invalid // upstream config because its initialization function have not been called yet. - mainLog.Load().Debug().Msg("setup upstream with new config") + p.Debug().Msg("setup upstream with new config") p.setupUpstream(newCfg) p.mu.Lock() *p.cfg = *newCfg p.mu.Unlock() - logger.Notice().Msg("reloading config successfully") + p.Notice().Msg("reloading config successfully") select { case p.reloadDoneCh <- struct{}{}: @@ -276,6 +276,7 @@ func (p *prog) preRun() { p.requiredMultiNICsConfig = requiredMultiNICsConfig() } p.runningIface = iface + p.logger.Store(mainLog.Load()) } func (p *prog) postRun() { @@ -283,11 +284,11 @@ func (p *prog) postRun() { if runtime.GOOS == "windows" { isDC, roleInt := isRunningOnDomainController() p.runningOnDomainController = isDC - mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt) + p.Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt) } p.resetDNS(false, false) - ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), false) - mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) + ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), false) + p.Debug().Msgf("initialized OS resolver with nameservers: %v", ns) p.setDNS() p.csSetDnsDone <- struct{}{} close(p.csSetDnsDone) @@ -304,14 +305,14 @@ func (p *prog) apiConfigReload() { ticker := time.NewTicker(timeDurationOrDefault(p.cfg.Service.RefetchTime, 3600) * time.Second) defer ticker.Stop() - logger := mainLog.Load().With().Str("mode", "api-reload").Logger() + logger := p.logger.Load().With().Str("mode", "api-reload").Logger() logger.Debug().Msg("starting custom config reload timer") lastUpdated := time.Now().Unix() curVerStr := curVersion() curVer, err := semver.NewVersion(curVerStr) isStable := curVer != nil && curVer.Prerelease() == "" if err != nil || !isStable { - l := mainLog.Load().Warn() + l := p.Warn() if err != nil { l = l.Err(err) } @@ -319,7 +320,7 @@ func (p *prog) apiConfigReload() { } doReloadApiConfig := func(forced bool, logger zerolog.Logger) { - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) resolverConfig, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) selfUninstallCheck(err, p, logger) if err != nil { @@ -405,20 +406,20 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) { localUpstreams := make([]string, 0, len(cfg.Upstream)) ptrNameservers := make([]string, 0, len(cfg.Upstream)) isControlDUpstream := false - loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) for n := range cfg.Upstream { uc := cfg.Upstream[n] sdns := uc.Type == ctrld.ResolverTypeSDNS uc.Init(loggerCtx) if sdns { - mainLog.Load().Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type) + p.Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type) } isControlDUpstream = isControlDUpstream || uc.IsControlD() if uc.BootstrapIP == "" { - uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), mainLog.Load())) - mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) + uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), p.logger.Load())) + p.Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) } else { - mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) + p.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) } uc.SetCertPool(rootCertPool) go uc.Ping(loggerCtx) @@ -459,9 +460,9 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.csSetDnsDone = make(chan struct{}, 1) p.registerControlServerHandler() if err := p.cs.start(); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not start control server") + p.Warn().Err(err).Msg("could not start control server") } - mainLog.Load().Debug().Msgf("control server started: %s", p.cs.addr) + p.Debug().Msgf("control server started: %s", p.cs.addr) } } p.onStartedDone = make(chan struct{}) @@ -473,7 +474,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled") + p.Error().Err(err).Msg("failed to create cacher, caching is disabled") } else { p.cache = cacher p.cacheFlushDomainsMap = make(map[string]struct{}, 256) @@ -483,7 +484,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } if domain, err := getActiveDirectoryDomain(); err == nil && domain != "" && hasLocalDnsServerRunning() { - mainLog.Load().Debug().Msgf("active directory domain: %s", domain) + p.Debug().Msgf("active directory domain: %s", domain) p.adDomain = domain } @@ -494,14 +495,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { for _, cidr := range nc.Cidrs { _, ipNet, err := net.ParseCIDR(cidr) if err != nil { - mainLog.Load().Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr") + p.Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr") continue } nc.IPNets = append(nc.IPNets, ipNet) } } - p.um = newUpstreamMonitor(p.cfg) + p.um = newUpstreamMonitor(p.cfg, p.logger.Load()) if !reload { p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)} @@ -514,7 +515,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } p.setupUpstream(p.cfg) - p.setupClientInfoDiscover(defaultRouteIP()) + p.setupClientInfoDiscover() } // context for managing spawn goroutines. @@ -538,14 +539,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { listenerConfig := p.cfg.Listener[listenerNum] upstreamConfig := p.cfg.Upstream[listenerNum] if upstreamConfig == nil { - mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) + p.Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) } addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) - mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) + p.Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) if err := p.serveDNS(ctx, listenerNum); err != nil { - mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) + p.Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) } - mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr) + p.Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr) }(listenerNum) } go func() { @@ -602,10 +603,11 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } // setupClientInfoDiscover performs necessary works for running client info discover. -func (p *prog) setupClientInfoDiscover(selfIP string) { - p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, mainLog.Load()) +func (p *prog) setupClientInfoDiscover() { + selfIP := p.defaultRouteIP() + p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, p.logger.Load()) if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { - mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) + p.Debug().Msgf("watching custom lease file: %s", leaseFile) format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) p.ciTable.AddLeaseFile(leaseFile, format) } @@ -622,18 +624,18 @@ func (p *prog) metricsEnabled() bool { return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != "" } -func (p *prog) Stop(s service.Service) error { +func (p *prog) Stop(_ service.Service) error { p.stopDnsWatchers() - mainLog.Load().Debug().Msg("dns watchers stopped") + p.Debug().Msg("dns watchers stopped") for _, f := range p.onStopped { f() } - mainLog.Load().Debug().Msg("finish running onStopped functions") + p.Debug().Msg("finish running onStopped functions") defer func() { - mainLog.Load().Info().Msg("Service stopped") + p.Info().Msg("Service stopped") }() if err := p.deAllocateIP(); err != nil { - mainLog.Load().Error().Err(err).Msg("de-allocate ip failed") + p.Error().Err(err).Msg("de-allocate ip failed") return err } if deactivationPinSet() { @@ -645,16 +647,16 @@ func (p *prog) Stop(s service.Service) error { // No valid pin code was checked, that mean we are stopping // because of OS signal sent directly from someone else. // In this case, restarting ctrld service by ourselves. - mainLog.Load().Debug().Msgf("receiving stopping signal without valid pin code") - mainLog.Load().Debug().Msgf("self restarting ctrld service") + p.Debug().Msgf("receiving stopping signal without valid pin code") + p.Debug().Msgf("self restarting ctrld service") if exe, err := os.Executable(); err == nil { cmd := exec.Command(exe, "restart") cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() if err := cmd.Start(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to run self restart command") + p.Error().Err(err).Msg("failed to run self restart command") } } else { - mainLog.Load().Error().Err(err).Msg("failed to self restart ctrld service") + p.Error().Err(err).Msg("failed to self restart ctrld service") } os.Exit(deactivationPinInvalidExitCode) } @@ -755,7 +757,7 @@ func (p *prog) setDNS() { p.dnsWg.Add(1) go func() { defer p.dnsWg.Done() - p.watchResolvConf(netIface, servers, setResolvConf) + p.watchResolvConf(netIface, servers, p.setResolvConf) }() } if p.dnsWatchdogEnabled() { @@ -772,7 +774,7 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In return } - logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger := p.logger.Load().With().Str("iface", p.runningIface).Logger() const maxDNSRetryAttempts = 3 const retryDelay = 1 * time.Second @@ -785,10 +787,10 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In } if attempt < maxDNSRetryAttempts { // Try to find a different working interface - newIface := findWorkingInterface(p.runningIface) + newIface := p.findWorkingInterface() if newIface != p.runningIface { p.runningIface = newIface - logger = mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger = p.logger.Load().With().Str("iface", p.runningIface).Logger() logger.Info().Msg("switched to new interface") continue } @@ -800,7 +802,7 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In logger.Error().Err(err).Msg("could not get interface after all attempts") return } - if err := setupNetworkManager(); err != nil { + if err := p.setupNetworkManager(); err != nil { logger.Error().Err(err).Msg("could not patch NetworkManager") return } @@ -840,7 +842,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { return } - mainLog.Load().Debug().Msg("start DNS settings watchdog") + p.Debug().Msg("start DNS settings watchdog") ns := nameservers slices.Sort(ns) @@ -851,19 +853,19 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { case <-p.dnsWatcherStopCh: return case <-p.stopCh: - mainLog.Load().Debug().Msg("stop dns watchdog") + p.Debug().Msg("stop dns watchdog") return case <-ticker.C: if p.recoveryRunning.Load() { return } - if dnsChanged(iface, ns) { - mainLog.Load().Debug().Msg("DNS settings were changed, re-applying settings") + if p.dnsChanged(iface, ns) { + p.Debug().Msg("DNS settings were changed, re-applying settings") // Check if the interface already has static DNS servers configured. // currentStaticDNS is an OS-dependent helper that returns the current static DNS. staticDNS, err := currentStaticDNS(iface) if err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name) + p.Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name) } else if len(staticDNS) > 0 { //filter out loopback addresses staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool { @@ -873,12 +875,12 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(iface)) == 0 { // Save these static DNS values so that they can be restored later. if err := saveCurrentStaticDNS(iface); err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name) + p.Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name) } } } if err := setDNS(iface, ns); err != nil { - mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings") + p.Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings") } } if p.requiredMultiNICsConfig { @@ -887,13 +889,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { ifaceName = iface.Name } withEachPhysicalInterfaces(ifaceName, "", func(i *net.Interface) error { - if dnsChanged(i, ns) { + if p.dnsChanged(i, ns) { // Check if the interface already has static DNS servers configured. // currentStaticDNS is an OS-dependent helper that returns the current static DNS. staticDNS, err := currentStaticDNS(i) if err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name) + p.Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name) } else if len(staticDNS) > 0 { //filter out loopback addresses staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool { @@ -903,15 +905,15 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(i)) == 0 { // Save these static DNS values so that they can be restored later. if err := saveCurrentStaticDNS(i); err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name) + p.Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name) } } } if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil { - mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings") + p.Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings") } else { - mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name) + p.Debug().Msgf("re-applying DNS for interface %q successfully", i.Name) } } return nil @@ -941,17 +943,17 @@ func (p *prog) resetDNS(isStart bool, restoreStatic bool) { // Otherwise, we restore the saved configuration (if any) or reset to DHCP. func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runningIface *net.Interface) { if p.runningIface == "" { - mainLog.Load().Debug().Msg("no running interface, skipping resetDNS") + p.Debug().Msg("no running interface, skipping resetDNS") return } - logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger := p.logger.Load().With().Str("iface", p.runningIface).Logger() netIface, err := netInterface(p.runningIface) if err != nil { logger.Error().Err(err).Msg("could not get interface") return } runningIface = netIface - if err := restoreNetworkManager(); err != nil { + if err := p.restoreNetworkManager(); err != nil { logger.Error().Err(err).Msg("could not restore NetworkManager") return } @@ -999,16 +1001,16 @@ func (p *prog) logInterfacesState() { withEachPhysicalInterfaces("", "", func(i *net.Interface) error { addrs, err := i.Addrs() if err != nil { - mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses") + p.Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses") } nss, err := currentStaticDNS(i) if err != nil { - mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS") + p.Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS") } if len(nss) == 0 { nss = currentDNS(i) } - mainLog.Load().Debug(). + p.Debug(). Any("addrs", addrs). Strs("nameservers", nss). Int("index", i.Index). @@ -1018,7 +1020,8 @@ func (p *prog) logInterfacesState() { } // findWorkingInterface looks for a network interface with a valid IP configuration -func findWorkingInterface(currentIface string) string { +func (p *prog) findWorkingInterface() string { + currentIface := p.runningIface // Helper to check if IP is valid (not link-local) isValidIP := func(ip net.IP) bool { return ip != nil && @@ -1036,7 +1039,7 @@ func findWorkingInterface(currentIface string) string { addrs, err := iface.Addrs() if err != nil { - mainLog.Load().Debug(). + p.Debug(). Str("interface", iface.Name). Err(err). Msg("failed to get interface addresses") @@ -1057,11 +1060,11 @@ func findWorkingInterface(currentIface string) string { // Get default route interface defaultRoute, err := netmon.DefaultRoute() if err != nil { - mainLog.Load().Debug(). + p.Debug(). Err(err). Msg("failed to get default route") } else { - mainLog.Load().Debug(). + p.Debug(). Str("default_route_iface", defaultRoute.InterfaceName). Msg("found default route") } @@ -1069,7 +1072,7 @@ func findWorkingInterface(currentIface string) string { // Get all interfaces ifaces, err := net.Interfaces() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to list network interfaces") + p.Error().Err(err).Msg("failed to list network interfaces") return currentIface // Return current interface as fallback } @@ -1099,7 +1102,7 @@ func findWorkingInterface(currentIface string) string { // Found working physical interface if err == nil && defaultRoute.InterfaceName == iface.Name { // Found interface with default route - use it immediately - mainLog.Load().Info(). + p.Info(). Str("old_iface", currentIface). Str("new_iface", iface.Name). Msg("switching to interface with default route") @@ -1120,7 +1123,7 @@ func findWorkingInterface(currentIface string) string { // Return interfaces in order of preference: // 1. Current interface if it's still valid if currentIfaceValid { - mainLog.Load().Debug(). + p.Debug(). Str("interface", currentIface). Msg("keeping current interface") return currentIface @@ -1128,7 +1131,7 @@ func findWorkingInterface(currentIface string) string { // 2. First working interface found if firstWorkingIface != "" { - mainLog.Load().Info(). + p.Info(). Str("old_iface", currentIface). Str("new_iface", firstWorkingIface). Msg("switching to first working physical interface") @@ -1136,7 +1139,7 @@ func findWorkingInterface(currentIface string) string { } // 3. Fall back to current interface if nothing else works - mainLog.Load().Warn(). + p.Warn(). Str("current_iface", currentIface). Msg("no working physical interface found, keeping current") return currentIface @@ -1258,7 +1261,7 @@ func ifaceFirstPrivateIP(iface *net.Interface) string { } // defaultRouteIP returns private IP string of the default route if present, prefer IPv4 over IPv6. -func defaultRouteIP() string { +func (p *prog) defaultRouteIP() string { dr, err := netmon.DefaultRoute() if err != nil { return "" @@ -1267,9 +1270,9 @@ func defaultRouteIP() string { if err != nil { return "" } - mainLog.Load().Debug().Str("iface", drNetIface.Name).Msg("checking default route interface") + p.Debug().Str("iface", drNetIface.Name).Msg("checking default route interface") if ip := ifaceFirstPrivateIP(drNetIface); ip != "" { - mainLog.Load().Debug().Str("ip", ip).Msg("found ip with default route interface") + p.Debug().Str("ip", ip).Msg("found ip with default route interface") return ip } @@ -1294,7 +1297,7 @@ func defaultRouteIP() string { }) if len(addrs) == 0 { - mainLog.Load().Warn().Msg("no default route IP found") + p.Warn().Msg("no default route IP found") return "" } sort.Slice(addrs, func(i, j int) bool { @@ -1302,7 +1305,7 @@ func defaultRouteIP() string { }) ip := addrs[0].String() - mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP") + p.Debug().Str("ip", ip).Msg("found LAN interface IP") return ip } @@ -1413,14 +1416,14 @@ func saveCurrentStaticDNS(iface *net.Interface) error { // It returns false for a nil iface. // // The caller must sort the nameservers before calling this function. -func dnsChanged(iface *net.Interface, nameservers []string) bool { +func (p *prog) dnsChanged(iface *net.Interface, nameservers []string) bool { if iface == nil { return false } curNameservers, _ := currentStaticDNS(iface) slices.Sort(curNameservers) if !slices.Equal(curNameservers, nameservers) { - mainLog.Load().Debug().Msgf("interface %q current DNS settings: %v, expected: %v", iface.Name, curNameservers, nameservers) + p.Debug().Msgf("interface %q current DNS settings: %v, expected: %v", iface.Name, curNameservers, nameservers) return true } return false @@ -1465,16 +1468,16 @@ func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) { exe, err := os.Executable() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to get executable path, skipped self-upgrade") + logger.Error().Err(err).Msg("failed to get executable path, skipped self-upgrade") return } cmd := exec.Command(exe, "upgrade", "prod", "-vv") cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() if err := cmd.Start(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade") + logger.Error().Err(err).Msg("failed to start self-upgrade") return } - mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vts) + logger.Debug().Msgf("self-upgrade triggered, version target: %s", vts) } // leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow diff --git a/cmd/cli/prog_log.go b/cmd/cli/prog_log.go new file mode 100644 index 0000000..dec20e9 --- /dev/null +++ b/cmd/cli/prog_log.go @@ -0,0 +1,33 @@ +package cli + +import "github.com/rs/zerolog" + +// Debug starts a new message with debug level. +func (p *prog) Debug() *zerolog.Event { + return p.logger.Load().Debug() +} + +// Warn starts a new message with warn level. +func (p *prog) Warn() *zerolog.Event { + return p.logger.Load().Warn() +} + +// Info starts a new message with info level. +func (p *prog) Info() *zerolog.Event { + return p.logger.Load().Info() +} + +// Fatal starts a new message with fatal level. +func (p *prog) Fatal() *zerolog.Event { + return p.logger.Load().Fatal() +} + +// Error starts a new message with error level. +func (p *prog) Error() *zerolog.Event { + return p.logger.Load().Error() +} + +// Notice starts a new message with notice level. +func (p *prog) Notice() *zerolog.Event { + return p.logger.Load().Notice() +} diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 0f3f731..587841d 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -43,10 +43,10 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" { resolvConfPath = rp } - mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath) + p.Debug().Msgf("start watching %s file", resolvConfPath) watcher, err := fsnotify.NewWatcher() if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf") + p.Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf") return } defer watcher.Close() @@ -55,7 +55,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f // see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well watchDir := filepath.Dir(resolvConfPath) if err := watcher.Add(watchDir); err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir) + p.Warn().Err(err).Msgf("could not add %s to watcher list", watchDir) return } @@ -64,7 +64,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f case <-p.dnsWatcherStopCh: return case <-p.stopCh: - mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) + p.Debug().Msgf("stopping watcher for %s", resolvConfPath) return case event, ok := <-watcher.Events: if p.recoveryRunning.Load() { @@ -77,7 +77,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f continue } if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { - mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...") + p.Debug().Msgf("/etc/resolv.conf changes detected, reading changes...") // Convert expected nameservers to strings for comparison expectedNS := make([]string, len(ns)) @@ -92,7 +92,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f for retry := 0; retry < maxRetries; retry++ { foundNS, err = p.parseResolvConfNameservers(resolvConfPath) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content") + p.Error().Err(err).Msg("failed to read resolv.conf content") break } @@ -103,7 +103,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f // Only retry if we found no nameservers if retry < maxRetries-1 { - mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries) + p.Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries) select { case <-p.stopCh: return @@ -113,7 +113,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f continue } } else { - mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries") + p.Debug().Msg("resolv.conf remained empty after all retries") } } @@ -130,7 +130,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f } } - mainLog.Load().Debug(). + p.Debug(). Strs("found", foundNS). Strs("expected", expectedNS). Bool("matches", matches). @@ -139,16 +139,16 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f // Only revert if the nameservers don't match if !matches { if err := watcher.Remove(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to pause watcher") + p.Error().Err(err).Msg("failed to pause watcher") continue } if err := setDnsFn(iface, ns); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + p.Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") } if err := watcher.Add(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") + p.Error().Err(err).Msg("failed to continue running watcher") return } } @@ -158,7 +158,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f if !ok { return } - mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf") + p.Error().Err(err).Msg("could not get event for /etc/resolv.conf") } } } diff --git a/cmd/cli/resolvconf_darwin.go b/cmd/cli/resolvconf_darwin.go index eb70eed..05c7017 100644 --- a/cmd/cli/resolvconf_darwin.go +++ b/cmd/cli/resolvconf_darwin.go @@ -12,7 +12,7 @@ import ( const resolvConfPath = "/etc/resolv.conf" // setResolvConf sets the content of resolv.conf file using the given nameservers list. -func setResolvConf(iface *net.Interface, ns []netip.Addr) error { +func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error { servers := make([]string, len(ns)) for i := range ns { servers[i] = ns[i].String() diff --git a/cmd/cli/resolvconf_not_darwin_unix.go b/cmd/cli/resolvconf_not_darwin_unix.go index af33572..8838dc2 100644 --- a/cmd/cli/resolvconf_not_darwin_unix.go +++ b/cmd/cli/resolvconf_not_darwin_unix.go @@ -14,7 +14,7 @@ import ( ) // setResolvConf sets the content of the resolv.conf file using the given nameservers list. -func setResolvConf(iface *net.Interface, ns []netip.Addr) error { +func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error { r, err := newLoopbackOSConfigurator() if err != nil { return err @@ -27,7 +27,7 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error { if sds, err := searchDomains(); err == nil { oc.SearchDomains = sds } else { - mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file") + p.Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file") } return r.SetDNS(oc) } diff --git a/cmd/cli/resolvconf_windows.go b/cmd/cli/resolvconf_windows.go index 3e4ba1c..20a984f 100644 --- a/cmd/cli/resolvconf_windows.go +++ b/cmd/cli/resolvconf_windows.go @@ -6,7 +6,7 @@ import ( ) // setResolvConf sets the content of resolv.conf file using the given nameservers list. -func setResolvConf(_ *net.Interface, _ []netip.Addr) error { +func (p *prog) setResolvConf(_ *net.Interface, _ []netip.Addr) error { return nil } diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 6e19e38..426886e 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -2,6 +2,7 @@ package cli import ( "sync" + "sync/atomic" "time" "github.com/Control-D-Inc/ctrld" @@ -16,7 +17,8 @@ const ( // upstreamMonitor performs monitoring upstreams health. type upstreamMonitor struct { - cfg *ctrld.Config + cfg *ctrld.Config + logger atomic.Pointer[ctrld.Logger] mu sync.RWMutex checking map[string]bool @@ -28,7 +30,7 @@ type upstreamMonitor struct { failureTimerActive map[string]bool } -func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { +func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonitor { um := &upstreamMonitor{ cfg: cfg, checking: make(map[string]bool), @@ -37,6 +39,7 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { recovered: make(map[string]bool), failureTimerActive: make(map[string]bool), } + um.logger.Store(logger) for n := range cfg.Upstream { upstream := upstreamPrefix + n um.reset(upstream) @@ -53,7 +56,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { defer um.mu.Unlock() if um.recovered[upstream] { - mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream) + um.logger.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream) return } @@ -61,7 +64,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { failedCount := um.failureReq[upstream] // Log the updated failure count. - mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount) + um.logger.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount) // If this is the first failure and no timer is running, start a 10-second timer. if failedCount == 1 && !um.failureTimerActive[upstream] { @@ -74,7 +77,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { // and the upstream is not in a recovered state, mark it as down. if um.failureReq[upstream] > 0 && !um.recovered[upstream] { um.down[upstream] = true - mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream]) + um.logger.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream]) } // Reset the timer flag so that a new timer can be spawned if needed. um.failureTimerActive[upstream] = false @@ -84,7 +87,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { // If the failure count quickly reaches the threshold, mark the upstream as down immediately. if failedCount >= maxFailureRequest { um.down[upstream] = true - mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount) + um.logger.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount) } }