From fc527dbdfb94e6c68e9da997583ec3785a1b517b Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 3 Apr 2025 21:17:02 +0700 Subject: [PATCH] all: eliminate usage of global ProxyLogger So setting up logging for ctrld binary and ctrld packages could be done more easily, decouple the required setup for interactive vs daemon running. This is the first step toward replacing rs/zerolog libary with a different logging library. --- cmd/cli/cli.go | 32 ++--- cmd/cli/control_server.go | 6 +- cmd/cli/dns_proxy.go | 31 ++--- cmd/cli/log_writer.go | 3 +- cmd/cli/loop.go | 3 +- cmd/cli/main.go | 14 +-- cmd/cli/main_test.go | 4 +- cmd/cli/netlink_linux.go | 4 +- cmd/cli/prog.go | 16 +-- config.go | 103 +++++++++-------- config_internal_test.go | 11 +- config_quic.go | 29 ++--- doh.go | 12 +- doh_test.go | 9 +- doq.go | 2 +- dot.go | 2 +- internal/clientinfo/client_info.go | 54 +++++---- internal/clientinfo/client_info_test.go | 7 +- internal/clientinfo/dhcp.go | 20 ++-- internal/clientinfo/hostsfile.go | 7 +- internal/clientinfo/mdns.go | 23 ++-- internal/clientinfo/mdns_test.go | 4 +- internal/clientinfo/merlin.go | 3 +- internal/clientinfo/ndp.go | 13 ++- internal/clientinfo/ndp_linux.go | 10 +- internal/clientinfo/ndp_others.go | 6 +- internal/clientinfo/ptr_lookup.go | 5 +- internal/controld/config.go | 39 ++++--- log.go | 34 ++++-- nameservers.go | 8 +- nameservers_bsd.go | 3 +- nameservers_darwin.go | 8 +- nameservers_linux.go | 7 +- nameservers_test.go | 7 +- nameservers_unix.go | 3 +- nameservers_windows.go | 148 +++++++++--------------- net.go | 18 +-- resolver.go | 135 ++++++++++----------- resolver_test.go | 2 +- 39 files changed, 425 insertions(+), 420 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 3caa3bb..5005925 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -349,7 +349,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if newLogPath := cfg.Service.LogPath; newLogPath != "" && oldLogPath != newLogPath { // After processCDFlags, log config may change, so reset mainLog and re-init logging. l := zerolog.New(io.Discard) - mainLog.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) // Copy logs written so far to new log file if possible. if buf, err := os.ReadFile(oldLogPath); err == nil { @@ -502,8 +502,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) } - nop := zerolog.Nop() - _, _ = tryUpdateListenerConfig(&cfg, &nop, func() {}, true) + _, _ = tryUpdateListenerConfig(&cfg, func() {}, true) addExtraSplitDnsRule(&cfg) if err := writeConfigFile(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) @@ -591,7 +590,8 @@ func processNoConfigFlags(noConfigStart bool) { Type: pType, Timeout: 5000, } - puc.Init() + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + puc.Init(loggerCtx) upstream := map[string]*ctrld.UpstreamConfig{"0": puc} if secondaryUpstream != "" { sEndpoint, sType := endpointAndTyp(secondaryUpstream) @@ -601,7 +601,7 @@ func processNoConfigFlags(noConfigStart bool) { Type: sType, Timeout: 5000, } - suc.Init() + suc.Init(loggerCtx) upstream["1"] = suc rules := make([]ctrld.Rule, 0, len(domains)) for _, domain := range domains { @@ -634,13 +634,13 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) bo.LogLongerThan = 30 * time.Second - ctx := context.Background() - resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + ctx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + resolverConfig, err := controld.FetchResolverConfig(ctx, cdUID, rootCmd.Version, cdDev) for { if errUrlNetworkError(err) { bo.BackOff(ctx, err) logger.Warn().Msg("could not fetch resolver using bootstrap DNS, retrying...") - resolverConfig, err = controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + resolverConfig, err = controld.FetchResolverConfig(ctx, cdUID, rootCmd.Version, cdDev) continue } break @@ -938,9 +938,10 @@ func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain stri bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", exErr)) } mainLog.Load().Debug().Msgf("self-check against %q failed", domain) + loggerCtx := ctrld.LoggerCtx(ctx, mainLog.Load()) // Ping all upstreams to provide better error message to users. for name, uc := range cfg.Upstream { - if err := uc.ErrorPing(); err != nil { + if err := uc.ErrorPing(loggerCtx); err != nil { mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint) } } @@ -1181,7 +1182,7 @@ func mobileListenerIp() string { // or defined but invalid to be used, e.g: using loopback address other // than 127.0.0.1 with systemd-resolved. func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool { - updated, _ := tryUpdateListenerConfig(cfg, nil, notifyToLogServerFunc, true) + updated, _ := tryUpdateListenerConfig(cfg, notifyToLogServerFunc, true) if addExtraSplitDnsRule(cfg) { updated = true } @@ -1191,7 +1192,7 @@ func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool // tryUpdateListenerConfig tries updating listener config with a working one. // If fatal is true, and there's listen address conflicted, the function do // fatal error. -func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, notifyFunc func(), fatal bool) (updated, ok bool) { +func tryUpdateListenerConfig(cfg *ctrld.Config, notifyFunc func(), fatal bool) (updated, ok bool) { ok = true lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" @@ -1235,9 +1236,6 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti } il := mainLog.Load() - if infoLogger != nil { - il = infoLogger - } if isMobile() { // On Mobile, only use first listener, ignore others. firstLn := cfg.FirstListener() @@ -1492,7 +1490,8 @@ func cdUIDFromProvToken() string { } req := &controld.UtilityOrgRequest{ProvToken: cdOrg, Hostname: customHostname} // Process provision token if provided. - resolverConfig, err := controld.FetchResolverUID(req, rootCmd.Version, cdDev) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + resolverConfig, err := controld.FetchResolverUID(loggerCtx, req, rootCmd.Version, cdDev) if err != nil { mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid with provision token: %s", cdOrg) } @@ -1819,7 +1818,8 @@ func runningIface(s service.Service) *ifaceResponse { // doValidateCdRemoteConfig fetches and validates custom config for cdUID. func doValidateCdRemoteConfig(cdUID string, fatal bool) error { - rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) if err != nil { logger := mainLog.Load().Fatal() if !fatal { diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 9281b90..428fe12 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -216,8 +216,9 @@ func (p *prog) registerControlServerHandler() { return } + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) // Re-fetch pin code from API. - if rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); rc != nil { + if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev); rc != nil { if rc.DeactivationPin != nil { cdDeactivationPin.Store(*rc.DeactivationPin) } else { @@ -321,7 +322,8 @@ func (p *prog) registerControlServerHandler() { } mainLog.Load().Debug().Msg("sending log file to ControlD server") resp := logSentResponse{Size: r.size} - if err := controld.SendLogs(req, cdDev); err != nil { + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil { mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err) resp.Error = err.Error() w.WriteHeader(http.StatusInternalServerError) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 33012fa..a3d9970 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -110,6 +110,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] reqId := requestID() ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) + ctx = ctrld.LoggerCtx(ctx, mainLog.Load()) if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) answer := new(dns.Msg) @@ -514,7 +515,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) - dnsResolver, err := ctrld.NewResolver(upstreamConfig) + dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig) if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") return nil, err @@ -549,11 +550,11 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // For timeout error (i.e: context deadline exceed), force re-bootstrapping. var e net.Error if errors.As(err, &e) && e.Timeout() { - upstreamConfig.ReBootstrap() + upstreamConfig.ReBootstrap(ctx) } // For network error, turn ipv6 off if enabled. - if ctrld.HasIPv6() && (errUrlNetworkError(err) || errNetworkError(err)) { - ctrld.DisableIPv6() + if ctrld.HasIPv6(ctx) && (errUrlNetworkError(err) || errNetworkError(err)) { + ctrld.DisableIPv6(ctx) } } @@ -960,7 +961,8 @@ func (p *prog) doSelfUninstall(answer *dns.Msg) { logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger() if p.refusedQueryCount > selfUninstallMaxQueries { p.checkingSelfUninstall = true - _, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + _, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) logger.Debug().Msg("maximum number of refused queries reached, checking device status") selfUninstallCheck(err, p, logger) @@ -1326,13 +1328,13 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // Only set the IPv4 default if selfIP is a valid IPv4 address. if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil { - ctrld.SetDefaultLocalIPv4(ip) + ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, mainLog.Load()), ip) if !isMobile() && p.ciTable != nil { p.ciTable.SetSelfIP(selfIP) } } if ip := net.ParseIP(ipv6); ip != nil { - ctrld.SetDefaultLocalIPv6(ip) + ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, mainLog.Load()), ip) } mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) @@ -1400,7 +1402,7 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool { func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error { mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream) - resolver, err := ctrld.NewResolver(uc) + resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), uc) if err != nil { mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream) return err @@ -1418,7 +1420,7 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - uc.ReBootstrap() + uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load())) mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) start := time.Now() @@ -1474,10 +1476,11 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // will be appended to nameservers from the saved interface values p.resetDNS(false, false) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) // For an OS failure, reinitialize OS resolver nameservers immediately. if reason == RecoveryReasonOSFailure { mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers") - ns := ctrld.InitializeOsResolver(true) + ns := ctrld.InitializeOsResolver(loggerCtx, true) if len(ns) == 0 { mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") } else { @@ -1504,7 +1507,7 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // For network changes we also reinitialize the OS resolver. if reason == RecoveryReasonNetworkChange { - ns := ctrld.InitializeOsResolver(true) + ns := ctrld.InitializeOsResolver(loggerCtx, true) if len(ns) == 0 { mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values") } else { @@ -1564,7 +1567,7 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string // we should try to reinit the OS resolver to ensure we can recover if name == upstreamOS && attempts%3 == 0 { mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts) - ns := ctrld.InitializeOsResolver(true) + ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, mainLog.Load()), true) if len(ns) == 0 { mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") } else { @@ -1624,12 +1627,12 @@ func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) { // Check if the default IPv4 is still active. if currentIPv4 != nil && !activeIPs[currentIPv4.String()] { mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", currentIPv4) - ctrld.SetDefaultLocalIPv4(nil) + ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil) } // Check if the default IPv6 is still active. if currentIPv6 != nil && !activeIPs[currentIPv6.String()] { mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. Resetting.", currentIPv6) - ctrld.SetDefaultLocalIPv6(nil) + ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil) } } diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index ab6b855..0ba2c8c 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -137,8 +137,7 @@ func (p *prog) initInternalLogging(writers []io.Writer) { }) multi := zerolog.MultiLevelWriter(writers...) l := mainLog.Load().Output(multi).With().Logger() - mainLog.Store(&l) - ctrld.ProxyLogger.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) } // needInternalLogging reports whether prog needs to run internal logging. diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index 3504bc3..434a4a5 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -102,6 +102,7 @@ func (p *prog) checkDnsLoop() { } p.loopMu.Unlock() + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) for uid := range p.loop { msg := loopTestMsg(uid) uc := upstream[uid] @@ -109,7 +110,7 @@ func (p *prog) checkDnsLoop() { if uc == nil { continue } - resolver, err := ctrld.NewResolver(uc) + resolver, err := ctrld.NewResolver(loggerCtx, uc) if err != nil { mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) continue diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 6a8cb62..53b8309 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -40,7 +40,7 @@ var ( cleanup bool startOnly bool - mainLog atomic.Pointer[zerolog.Logger] + mainLog atomic.Pointer[ctrld.Logger] consoleWriter zerolog.ConsoleWriter noConfigStart bool ) @@ -54,7 +54,7 @@ const ( func init() { l := zerolog.New(io.Discard) - mainLog.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) } func Main() { @@ -87,16 +87,14 @@ func initConsoleLogging() { }) multi := zerolog.MultiLevelWriter(consoleWriter) l := mainLog.Load().Output(multi).With().Timestamp().Logger() - mainLog.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) switch { case silent: zerolog.SetGlobalLevel(zerolog.NoLevel) case verbose == 1: - ctrld.ProxyLogger.Store(&l) zerolog.SetGlobalLevel(zerolog.InfoLevel) case verbose > 1: - ctrld.ProxyLogger.Store(&l) zerolog.SetGlobalLevel(zerolog.DebugLevel) default: zerolog.SetGlobalLevel(zerolog.NoticeLevel) @@ -113,8 +111,6 @@ func initInteractiveLogging() { zerolog.TimeFieldFormat = time.RFC3339 + ".000" initLoggingWithBackup(false) cfg.Service.LogPath = old - l := zerolog.New(io.Discard) - ctrld.ProxyLogger.Store(&l) } // initLoggingWithBackup initializes log setup base on current config. @@ -153,9 +149,7 @@ func initLoggingWithBackup(doBackup bool) []io.Writer { writers = append(writers, consoleWriter) multi := zerolog.MultiLevelWriter(writers...) l := mainLog.Load().Output(multi).With().Logger() - mainLog.Store(&l) - // TODO: find a better way. - ctrld.ProxyLogger.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) zerolog.SetGlobalLevel(zerolog.NoticeLevel) logLevel := cfg.Service.LogLevel diff --git a/cmd/cli/main_test.go b/cmd/cli/main_test.go index 6ed26c7..c7b8b17 100644 --- a/cmd/cli/main_test.go +++ b/cmd/cli/main_test.go @@ -6,12 +6,14 @@ import ( "testing" "github.com/rs/zerolog" + + "github.com/Control-D-Inc/ctrld" ) var logOutput strings.Builder func TestMain(m *testing.M) { l := zerolog.New(&logOutput) - mainLog.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) os.Exit(m.Run()) } diff --git a/cmd/cli/netlink_linux.go b/cmd/cli/netlink_linux.go index d757f8b..f4e9bda 100644 --- a/cmd/cli/netlink_linux.go +++ b/cmd/cli/netlink_linux.go @@ -5,6 +5,8 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/sys/unix" + + "github.com/Control-D-Inc/ctrld" ) func (p *prog) watchLinkState(ctx context.Context) { @@ -26,7 +28,7 @@ func (p *prog) watchLinkState(ctx context.Context) { if lu.Change&unix.IFF_UP != 0 { mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") for _, uc := range p.cfg.Upstream { - uc.ReBootstrap() + uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load())) } } } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 3b159ee..d85c371 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -286,7 +286,7 @@ func (p *prog) postRun() { mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt) } p.resetDNS(false, false) - ns := ctrld.InitializeOsResolver(false) + ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), false) mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) p.setDNS() p.csSetDnsDone <- struct{}{} @@ -319,7 +319,8 @@ func (p *prog) apiConfigReload() { } doReloadApiConfig := func(forced bool, logger zerolog.Logger) { - resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + resolverConfig, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) selfUninstallCheck(err, p, logger) if err != nil { logger.Warn().Err(err).Msg("could not fetch resolver config") @@ -377,7 +378,7 @@ func (p *prog) apiConfigReload() { } if cfgErr != nil { logger.Warn().Err(err).Msg("skipping invalid custom config") - if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil { + if _, err := controld.UpdateCustomLastFailed(loggerCtx, cdUID, rootCmd.Version, cdDev, true); err != nil { logger.Error().Err(err).Msg("could not mark custom last update failed") } return @@ -404,22 +405,23 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) { localUpstreams := make([]string, 0, len(cfg.Upstream)) ptrNameservers := make([]string, 0, len(cfg.Upstream)) isControlDUpstream := false + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) for n := range cfg.Upstream { uc := cfg.Upstream[n] sdns := uc.Type == ctrld.ResolverTypeSDNS - uc.Init() + uc.Init(loggerCtx) if sdns { mainLog.Load().Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type) } isControlDUpstream = isControlDUpstream || uc.IsControlD() if uc.BootstrapIP == "" { - uc.SetupBootstrapIP() + uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), mainLog.Load())) mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) } else { mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) } uc.SetCertPool(rootCertPool) - go uc.Ping() + go uc.Ping(loggerCtx) if canBeLocalUpstream(uc.Domain) { localUpstreams = append(localUpstreams, upstreamPrefix+n) @@ -601,7 +603,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { // setupClientInfoDiscover performs necessary works for running client info discover. func (p *prog) setupClientInfoDiscover(selfIP string) { - p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers) + p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, mainLog.Load()) if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) diff --git a/config.go b/config.go index 96f6686..4aadff1 100644 --- a/config.go +++ b/config.go @@ -325,12 +325,13 @@ type ListenerPolicyConfig struct { type Rule map[string][]string // Init initialized necessary values for an UpstreamConfig. -func (uc *UpstreamConfig) Init() { +func (uc *UpstreamConfig) Init(ctx context.Context) { + logger := LoggerFromCtx(ctx) if err := uc.initDnsStamps(); err != nil { - ProxyLogger.Load().Fatal().Err(err).Msg("invalid DNS Stamps") + logger.Fatal().Err(err).Msg("invalid DNS Stamps") } uc.initDoHScheme() - uc.uid = upstreamUID() + uc.uid = upstreamUID(ctx) if u, err := url.Parse(uc.Endpoint); err == nil { uc.Domain = u.Hostname() switch uc.Type { @@ -434,12 +435,13 @@ func (uc *UpstreamConfig) UID() string { // - ControlD Bootstrap DNS 76.76.2.22 // // The setup process will block until there's usable IPs found. -func (uc *UpstreamConfig) SetupBootstrapIP() { +func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) { b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second) isControlD := uc.IsControlD() - nss := initDefaultOsResolver() + logger := LoggerFromCtx(ctx) + nss := initDefaultOsResolver(ctx) for { - uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss) + uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, nss) // For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses, // filtering them out here to prevent weird behavior. if isControlD { @@ -454,18 +456,18 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { uc.bootstrapIPs = uc.bootstrapIPs[:n] if len(uc.bootstrapIPs) == 0 { uc.bootstrapIPs = bootstrapIPsFromControlDDomain(uc.Domain) - ProxyLogger.Load().Warn().Msgf("no record found for %q, lookup from direct IP table", uc.Domain) + logger.Warn().Msgf("no record found for %q, lookup from direct IP table", uc.Domain) } } if len(uc.bootstrapIPs) == 0 { - ProxyLogger.Load().Warn().Msgf("no record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP) - uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")}) + logger.Warn().Msgf("no record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP) + uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")}) } if len(uc.bootstrapIPs) > 0 { break } - ProxyLogger.Load().Warn().Msg("could not resolve bootstrap IPs, retrying...") + logger.Warn().Msg("could not resolve bootstrap IPs, retrying...") b.BackOff(context.Background(), errors.New("no bootstrap IPs")) } for _, ip := range uc.bootstrapIPs { @@ -475,11 +477,11 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip) } } - ProxyLogger.Load().Debug().Msgf("bootstrap IPs: %v", uc.bootstrapIPs) + logger.Debug().Msgf("bootstrap IPs: %v", uc.bootstrapIPs) } // ReBootstrap re-setup the bootstrap IP and the transport. -func (uc *UpstreamConfig) ReBootstrap() { +func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: default: @@ -487,7 +489,8 @@ func (uc *UpstreamConfig) ReBootstrap() { } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { if uc.rebootstrap.CompareAndSwap(false, true) { - ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc) + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("re-bootstrapping upstream ip for %v", uc) } return true, nil }) @@ -495,35 +498,35 @@ func (uc *UpstreamConfig) ReBootstrap() { // SetupTransport initializes the network transport used to connect to upstream server. // For now, only DoH upstream is supported. -func (uc *UpstreamConfig) SetupTransport() { +func (uc *UpstreamConfig) SetupTransport(ctx context.Context) { switch uc.Type { case ResolverTypeDOH: - uc.setupDOHTransport() + uc.setupDOHTransport(ctx) case ResolverTypeDOH3: - uc.setupDOH3Transport() + uc.setupDOH3Transport(ctx) } } -func (uc *UpstreamConfig) setupDOHTransport() { +func (uc *UpstreamConfig) setupDOHTransport(ctx context.Context) { switch uc.IPStack { case IpStackBoth, "": - uc.transport = uc.newDOHTransport(uc.bootstrapIPs) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs) case IpStackV4: - uc.transport = uc.newDOHTransport(uc.bootstrapIPs4) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs4) case IpStackV6: - uc.transport = uc.newDOHTransport(uc.bootstrapIPs6) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs6) case IpStackSplit: - uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4) - if HasIPv6() { - uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6) + uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4) + if HasIPv6(ctx) { + uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6) } else { uc.transport6 = uc.transport4 } - uc.transport = uc.newDOHTransport(uc.bootstrapIPs) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs) } } -func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { +func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() transport.MaxIdleConnsPerHost = 100 transport.TLSClientConfig = &tls.Config{ @@ -543,12 +546,13 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { dialerTimeoutMs = uc.Timeout } dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond + logger := LoggerFromCtx(ctx) transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { _, port, _ := net.SplitHostPort(addr) if uc.BootstrapIP != "" { dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout} addr := net.JoinHostPort(uc.BootstrapIP, port) - Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", addr) + logger.Debug().Msgf("sending doh request to: %s", addr) return dialer.DialContext(ctx, network, addr) } pd := &ctrldnet.ParallelDialer{} @@ -558,11 +562,11 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { for i := range addrs { dialAddrs[i] = net.JoinHostPort(addrs[i], port) } - conn, err := pd.DialContext(ctx, network, dialAddrs, ProxyLogger.Load()) + conn, err := pd.DialContext(ctx, network, dialAddrs, logger.Logger) if err != nil { return nil, err } - Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", conn.RemoteAddr()) + logger.Debug().Msgf("sending doh request to: %s", conn.RemoteAddr()) return conn, nil } runtime.SetFinalizer(transport, func(transport *http.Transport) { @@ -572,19 +576,20 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { } // Ping warms up the connection to DoH/DoH3 upstream. -func (uc *UpstreamConfig) Ping() { - if err := uc.ping(); err != nil { - ProxyLogger.Load().Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint) - _ = uc.FallbackToDirectIP() +func (uc *UpstreamConfig) Ping(ctx context.Context) { + if err := uc.ping(ctx); err != nil { + logger := LoggerFromCtx(ctx) + logger.Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint) + _ = uc.FallbackToDirectIP(ctx) } } // ErrorPing is like Ping, but return an error if any. -func (uc *UpstreamConfig) ErrorPing() error { - return uc.ping() +func (uc *UpstreamConfig) ErrorPing(ctx context.Context) error { + return uc.ping(ctx) } -func (uc *UpstreamConfig) ping() error { +func (uc *UpstreamConfig) ping(ctx context.Context) error { switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: default: @@ -613,11 +618,11 @@ func (uc *UpstreamConfig) ping() error { for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} { switch uc.Type { case ResolverTypeDOH: - if err := ping(uc.dohTransport(typ)); err != nil { + if err := ping(uc.dohTransport(ctx, typ)); err != nil { return err } case ResolverTypeDOH3: - if err := ping(uc.doh3Transport(typ)); err != nil { + if err := ping(uc.doh3Transport(ctx, typ)); err != nil { return err } } @@ -652,12 +657,12 @@ func (uc *UpstreamConfig) isNextDNS() bool { return domain == "dns.nextdns.io" } -func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { +func (uc *UpstreamConfig) dohTransport(ctx context.Context, dnsType uint16) http.RoundTripper { uc.transportOnce.Do(func() { - uc.SetupTransport() + uc.SetupTransport(ctx) }) if uc.rebootstrap.CompareAndSwap(true, false) { - uc.SetupTransport() + uc.SetupTransport(ctx) } switch uc.IPStack { case IpStackBoth, IpStackV4, IpStackV6: @@ -673,7 +678,7 @@ func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { return uc.transport } -func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { +func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uint16) string { switch uc.IPStack { case IpStackBoth: return pick(uc.bootstrapIPs) @@ -686,7 +691,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { case dns.TypeA: return pick(uc.bootstrapIPs4) default: - if HasIPv6() { + if HasIPv6(ctx) { return pick(uc.bootstrapIPs6) } return pick(uc.bootstrapIPs4) @@ -695,7 +700,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { return pick(uc.bootstrapIPs) } -func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { +func (uc *UpstreamConfig) netForDNSType(ctx context.Context, dnsType uint16) (string, string) { switch uc.IPStack { case IpStackBoth: return "tcp-tls", "udp" @@ -708,7 +713,7 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { case dns.TypeA: return "tcp4-tls", "udp4" default: - if HasIPv6() { + if HasIPv6(ctx) { return "tcp6-tls", "udp6" } return "tcp4-tls", "udp4" @@ -789,7 +794,7 @@ func (uc *UpstreamConfig) Context(ctx context.Context) (context.Context, context } // FallbackToDirectIP changes ControlD upstream endpoint to use direct IP instead of domain. -func (uc *UpstreamConfig) FallbackToDirectIP() bool { +func (uc *UpstreamConfig) FallbackToDirectIP(ctx context.Context) bool { if !uc.IsControlD() { return false } @@ -808,7 +813,8 @@ func (uc *UpstreamConfig) FallbackToDirectIP() bool { default: return } - ProxyLogger.Load().Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip) + logger := LoggerFromCtx(ctx) + logger.Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip) uc.u.Host = ip done = true }) @@ -942,11 +948,12 @@ func pick(s []string) string { } // upstreamUID generates an unique identifier for an upstream. -func upstreamUID() string { +func upstreamUID(ctx context.Context) string { + logger := LoggerFromCtx(ctx) b := make([]byte, 4) for { if _, err := crand.Read(b); err != nil { - ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...") + logger.Warn().Err(err).Msg("could not generate uid for upstream, retrying...") continue } return hex.EncodeToString(b) diff --git a/config_internal_test.go b/config_internal_test.go index b37e982..0e7f3bb 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -1,6 +1,7 @@ package ctrld import ( + "context" "net/url" "testing" @@ -36,10 +37,10 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Enable parallel tests once https://github.com/microsoft/wmi/issues/165 fixed. // t.Parallel() - tc.uc.Init() - tc.uc.SetupBootstrapIP() + tc.uc.Init(context.Background()) + tc.uc.SetupBootstrapIP(context.Background()) if len(tc.uc.bootstrapIPs) == 0 { - t.Log(defaultNameservers()) + t.Log(defaultNameservers(context.Background())) t.Fatalf("could not bootstrap ip: %s", tc.uc.String()) } }) @@ -355,7 +356,7 @@ func TestUpstreamConfig_Init(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - tc.uc.Init() + tc.uc.Init(context.Background()) tc.uc.uid = "" // we don't care about the uid. assert.Equal(t, tc.expected, tc.uc) }) @@ -497,7 +498,7 @@ func TestUpstreamConfig_IsDiscoverable(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - tc.uc.Init() + tc.uc.Init(context.Background()) if got := tc.uc.IsDiscoverable(); got != tc.discoverable { t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got) } diff --git a/config_quic.go b/config_quic.go index cadcb6b..8f27bf3 100644 --- a/config_quic.go +++ b/config_quic.go @@ -14,34 +14,35 @@ import ( "github.com/quic-go/quic-go/http3" ) -func (uc *UpstreamConfig) setupDOH3Transport() { +func (uc *UpstreamConfig) setupDOH3Transport(ctx context.Context) { switch uc.IPStack { case IpStackBoth, "": - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs) case IpStackV4: - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs4) case IpStackV6: - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs6) case IpStackSplit: - uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4) - if HasIPv6() { - uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6) + uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4) + if HasIPv6(ctx) { + uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6) } else { uc.http3RoundTripper6 = uc.http3RoundTripper4 } - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs) } } -func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { +func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper { rt := &http3.Transport{} rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} + logger := LoggerFromCtx(ctx) rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { _, port, _ := net.SplitHostPort(addr) // if we have a bootstrap ip set, use it to avoid DNS lookup if uc.BootstrapIP != "" { addr = net.JoinHostPort(uc.BootstrapIP, port) - ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", addr) + logger.Debug().Msgf("sending doh3 request to: %s", addr) udpConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, err @@ -61,7 +62,7 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { if err != nil { return nil, err } - ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr()) + logger.Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr()) return conn, err } runtime.SetFinalizer(rt, func(rt *http3.Transport) { @@ -70,12 +71,12 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { return rt } -func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { +func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper { uc.transportOnce.Do(func() { - uc.SetupTransport() + uc.SetupTransport(ctx) }) if uc.rebootstrap.CompareAndSwap(true, false) { - uc.SetupTransport() + uc.SetupTransport(ctx) } switch uc.IPStack { case IpStackBoth, IpStackV4, IpStackV6: diff --git a/doh.go b/doh.go index 3459cb8..f93dc88 100644 --- a/doh.go +++ b/doh.go @@ -105,19 +105,20 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - c := http.Client{Transport: r.uc.dohTransport(dnsTyp)} + c := http.Client{Transport: r.uc.dohTransport(ctx, dnsTyp)} if r.isDoH3 { - transport := r.uc.doh3Transport(dnsTyp) + transport := r.uc.doh3Transport(ctx, dnsTyp) if transport == nil { return nil, errors.New("DoH3 is not supported") } c.Transport = transport } resp, err := c.Do(req) - if err != nil && r.uc.FallbackToDirectIP() { + if err != nil && r.uc.FallbackToDirectIP(ctx) { retryCtx, cancel := r.uc.Context(context.WithoutCancel(ctx)) defer cancel() - Log(ctx, ProxyLogger.Load().Warn().Err(err), "retrying request after fallback to direct ip") + logger := LoggerFromCtx(ctx) + logger.Warn().Err(err).Msg("retrying request after fallback to direct ip") resp, err = c.Do(req.Clone(retryCtx)) } if err != nil { @@ -163,7 +164,8 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) { } } if printed { - Log(ctx, ProxyLogger.Load().Debug(), "sending request header: %v", dohHeader) + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("sending request header: %v", dohHeader) } dohHeader.Set("Content-Type", headerApplicationDNS) dohHeader.Set("Accept", headerApplicationDNS) diff --git a/doh_test.go b/doh_test.go index 92fa79f..700b299 100644 --- a/doh_test.go +++ b/doh_test.go @@ -157,20 +157,21 @@ func Test_ClientCertificateVerificationError(t *testing.T) { }, } + ctx := context.Background() for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - tc.uc.Init() - tc.uc.SetupBootstrapIP() - r, err := NewResolver(tc.uc) + tc.uc.Init(ctx) + tc.uc.SetupBootstrapIP(ctx) + r, err := NewResolver(ctx, tc.uc) if err != nil { t.Fatal(err) } msg := new(dns.Msg) msg.SetQuestion("verify.controld.com.", dns.TypeA) msg.RecursionDesired = true - _, err = r.Resolve(context.Background(), msg) + _, err = r.Resolve(ctx, msg) // Verify the error contains the expected certificate information if err == nil { t.Fatal("expected certificate verification error, got nil") diff --git a/doq.go b/doq.go index 0903411..d341668 100644 --- a/doq.go +++ b/doq.go @@ -26,7 +26,7 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - ip = r.uc.bootstrapIPForDNSType(dnsTyp) + ip = r.uc.bootstrapIPForDNSType(ctx, dnsTyp) } tlsConfig.ServerName = r.uc.Domain _, port, _ := net.SplitHostPort(endpoint) diff --git a/dot.go b/dot.go index 295134c..03c08db 100644 --- a/dot.go +++ b/dot.go @@ -23,7 +23,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - tcpNet, _ := r.uc.netForDNSType(dnsTyp) + tcpNet, _ := r.uc.netForDNSType(ctx, dnsTyp) dnsClient := &dns.Client{ Net: tcpNet, Dialer: dialer, diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index f69b670..719e205 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -79,6 +79,7 @@ type Table struct { initOnce sync.Once stopOnce sync.Once refreshInterval int + logger *ctrld.Logger dhcp *dhcp merlin *merlinDiscover @@ -98,11 +99,14 @@ type Table struct { ptrNameservers []string } -func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { +func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string, logger *ctrld.Logger) *Table { refreshInterval := cfg.Service.DiscoverRefreshInterval if refreshInterval <= 0 { refreshInterval = 2 * 60 // 2 minutes } + if logger == nil { + logger = ctrld.NopLogger + } return &Table{ svcCfg: cfg.Service, quitCh: make(chan struct{}), @@ -111,6 +115,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { cdUID: cdUID, ptrNameservers: ns, refreshInterval: refreshInterval, + logger: logger, } } @@ -179,7 +184,7 @@ func (t *Table) SetSelfIP(ip string) { // initSelfDiscover initializes necessary client metadata for self query. func (t *Table) initSelfDiscover() { - t.dhcp = &dhcp{selfIP: t.selfIP} + t.dhcp = &dhcp{selfIP: t.selfIP, logger: t.logger} t.dhcp.addSelf() t.ipResolvers = append(t.ipResolvers, t.dhcp) t.macResolvers = append(t.macResolvers, t.dhcp) @@ -189,14 +194,14 @@ func (t *Table) initSelfDiscover() { func (t *Table) init() { // Custom client ID presents, use it as the only source. if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { - ctrld.ProxyLogger.Load().Debug().Msg("start self discovery with custom client id") + t.logger.Debug().Msg("start self discovery with custom client id") t.initSelfDiscover() return } // If we are running on platforms that should only do self discover, use it as the only source, too. if ctrld.SelfDiscover() { - ctrld.ProxyLogger.Load().Debug().Msg("start self discovery on desktop platforms") + t.logger.Debug().Msg("start self discovery on desktop platforms") t.initSelfDiscover() return } @@ -208,7 +213,7 @@ func (t *Table) init() { // - Merlin // - Ubios if t.discoverDHCP() || t.discoverARP() { - t.merlin = &merlinDiscover{} + t.merlin = &merlinDiscover{logger: t.logger} t.ubios = &ubiosDiscover{} discovers := map[string]interface { refresher @@ -219,7 +224,7 @@ func (t *Table) init() { } for platform, discover := range discovers { if err := discover.refresh(); err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to init %s discover", platform) + t.logger.Warn().Err(err).Msgf("failed to init %s discover", platform) } t.hostnameResolvers = append(t.hostnameResolvers, discover) t.refreshers = append(t.refreshers, discover) @@ -227,10 +232,10 @@ func (t *Table) init() { } // Hosts file mapping. if t.discoverHosts() { - t.hf = &hostsFile{} - ctrld.ProxyLogger.Load().Debug().Msg("start hosts file discovery") + t.hf = &hostsFile{logger: t.logger} + t.logger.Debug().Msg("start hosts file discovery") if err := t.hf.init(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init hosts file discover") + t.logger.Error().Err(err).Msg("could not init hosts file discover") } else { t.hostnameResolvers = append(t.hostnameResolvers, t.hf) t.refreshers = append(t.refreshers, t.hf) @@ -239,10 +244,10 @@ func (t *Table) init() { } // DHCP lease files. if t.discoverDHCP() { - t.dhcp = &dhcp{selfIP: t.selfIP} - ctrld.ProxyLogger.Load().Debug().Msg("start dhcp discovery") + t.dhcp = &dhcp{selfIP: t.selfIP, logger: t.logger} + t.logger.Debug().Msg("start dhcp discovery") if err := t.dhcp.init(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init DHCP discover") + t.logger.Error().Err(err).Msg("could not init DHCP discover") } else { t.ipResolvers = append(t.ipResolvers, t.dhcp) t.macResolvers = append(t.macResolvers, t.dhcp) @@ -253,8 +258,8 @@ func (t *Table) init() { // ARP/NDP table. if t.discoverARP() { t.arp = &arpDiscover{} - t.ndp = &ndpDiscover{} - ctrld.ProxyLogger.Load().Debug().Msg("start arp discovery") + t.ndp = &ndpDiscover{logger: t.logger} + t.logger.Debug().Msg("start arp discovery") discovers := map[string]interface { refresher IpResolver @@ -266,7 +271,7 @@ func (t *Table) init() { for protocol, discover := range discovers { if err := discover.refresh(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msgf("could not init %s discover", protocol) + t.logger.Error().Err(err).Msgf("could not init %s discover", protocol) } else { t.ipResolvers = append(t.ipResolvers, discover) t.macResolvers = append(t.macResolvers, discover) @@ -283,7 +288,10 @@ func (t *Table) init() { } // PTR lookup. if t.discoverPTR() { - t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} + t.ptr = &ptrDiscover{ + resolver: ctrld.NewPrivateResolver(context.Background()), + logger: t.logger, + } if len(t.ptrNameservers) > 0 { nss := make([]string, 0, len(t.ptrNameservers)) for _, ns := range t.ptrNameservers { @@ -295,18 +303,18 @@ func (t *Table) init() { if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { nss = append(nss, net.JoinHostPort(host, port)) } else { - ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) + t.logger.Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) } } if len(nss) > 0 { t.ptr.resolver = ctrld.NewResolverWithNameserver(nss) - ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss) + t.logger.Debug().Msgf("using nameservers %v for ptr discovery", nss) } } - ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") + t.logger.Debug().Msg("start ptr discovery") if err := t.ptr.refresh(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init PTR discover") + t.logger.Error().Err(err).Msg("could not init PTR discover") } else { t.hostnameResolvers = append(t.hostnameResolvers, t.ptr) t.refreshers = append(t.refreshers, t.ptr) @@ -314,10 +322,10 @@ func (t *Table) init() { } // mdns. if t.discoverMDNS() { - t.mdns = &mdns{} - ctrld.ProxyLogger.Load().Debug().Msg("start mdns discovery") + t.mdns = &mdns{logger: t.logger} + t.logger.Debug().Msg("start mdns discovery") if err := t.mdns.init(t.quitCh); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init mDNS discover") + t.logger.Error().Err(err).Msg("could not init mDNS discover") } else { t.hostnameResolvers = append(t.hostnameResolvers, t.mdns) } diff --git a/internal/clientinfo/client_info_test.go b/internal/clientinfo/client_info_test.go index b5bdfa5..7abb907 100644 --- a/internal/clientinfo/client_info_test.go +++ b/internal/clientinfo/client_info_test.go @@ -2,6 +2,8 @@ package clientinfo import ( "testing" + + "github.com/Control-D-Inc/ctrld" ) func Test_normalizeIP(t *testing.T) { @@ -28,8 +30,9 @@ func Test_normalizeIP(t *testing.T) { func TestTable_LookupRFC1918IPv4(t *testing.T) { table := &Table{ - dhcp: &dhcp{}, - arp: &arpDiscover{}, + dhcp: &dhcp{}, + arp: &arpDiscover{}, + logger: ctrld.NopLogger, } table.ipResolvers = append(table.ipResolvers, table.dhcp) diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index 5d11d5e..fbd7b08 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -13,9 +13,8 @@ import ( "strings" "sync" - "tailscale.com/net/netmon" - "github.com/fsnotify/fsnotify" + "tailscale.com/net/netmon" "tailscale.com/util/lineread" "github.com/Control-D-Inc/ctrld" @@ -30,6 +29,7 @@ type dhcp struct { watcher *fsnotify.Watcher selfIP string + logger *ctrld.Logger } func (d *dhcp) init() error { @@ -52,7 +52,7 @@ func (d *dhcp) watchChanges() { } if dir := router.LeaseFilesDir(); dir != "" { if err := d.watcher.Add(dir); err != nil { - ctrld.ProxyLogger.Load().Err(err).Str("dir", dir).Msg("could not watch lease dir") + d.logger.Err(err).Str("dir", dir).Msg("could not watch lease dir") } } for { @@ -64,7 +64,7 @@ func (d *dhcp) watchChanges() { if event.Has(fsnotify.Create) { if format, ok := clientInfoFiles[event.Name]; ok { if err := d.addLeaseFile(event.Name, format); err != nil { - ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("could not add lease file") + d.logger.Err(err).Str("file", event.Name).Msg("could not add lease file") } } continue @@ -72,14 +72,14 @@ func (d *dhcp) watchChanges() { if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) { format := clientInfoFiles[event.Name] if err := d.readLeaseFile(event.Name, format); err != nil && !os.IsNotExist(err) { - ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("leases file changed but failed to update client info") + d.logger.Err(err).Str("file", event.Name).Msg("leases file changed but failed to update client info") } } case err, ok := <-d.watcher.Errors: if !ok { return } - ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file") + d.logger.Err(err).Msg("could not watch client info file") } } @@ -222,7 +222,7 @@ func (d *dhcp) dnsmasqReadClientInfoReader(reader io.Reader) error { } ip := normalizeIP(string(fields[2])) if net.ParseIP(ip) == nil { - ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip) + d.logger.Warn().Msgf("invalid ip address entry: %q", ip) ip = "" } @@ -275,7 +275,7 @@ func (d *dhcp) iscDHCPReadClientInfoReader(reader io.Reader) error { case "lease": ip = normalizeIP(strings.ToLower(fields[1])) if net.ParseIP(ip) == nil { - ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip) + d.logger.Warn().Msgf("invalid ip address entry: %q", ip) ip = "" } case "hardware": @@ -328,7 +328,7 @@ func (d *dhcp) keaDhcp4ReadClientInfoReader(r io.Reader) error { } ip := normalizeIP(record[0]) if net.ParseIP(ip) == nil { - ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip) + d.logger.Warn().Msgf("invalid ip address entry: %q", ip) ip = "" } @@ -350,7 +350,7 @@ func (d *dhcp) keaDhcp4ReadClientInfoReader(r io.Reader) error { func (d *dhcp) addSelf() { hostname, err := os.Hostname() if err != nil { - ctrld.ProxyLogger.Load().Err(err).Msg("could not get hostname") + d.logger.Err(err).Msg("could not get hostname") return } hostname = normalizeHostname(hostname) diff --git a/internal/clientinfo/hostsfile.go b/internal/clientinfo/hostsfile.go index d96229d..4dc6f35 100644 --- a/internal/clientinfo/hostsfile.go +++ b/internal/clientinfo/hostsfile.go @@ -27,6 +27,7 @@ type hostsFile struct { watcher *fsnotify.Watcher mu sync.Mutex m map[string][]string + logger *ctrld.Logger } // init performs initialization works, which is necessary before hostsFile can be fully operated. @@ -55,7 +56,7 @@ func (hf *hostsFile) refresh() error { // override hosts file with host_entries.conf content if present. hem, err := parseHostEntriesConf(hostEntriesConfPath) if err != nil && !os.IsNotExist(err) { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not read host_entries.conf file") + hf.logger.Debug().Err(err).Msg("could not read host_entries.conf file") } for k, v := range hem { hf.m[k] = v @@ -77,14 +78,14 @@ func (hf *hostsFile) watchChanges() { } if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) { if err := hf.refresh(); err != nil && !os.IsNotExist(err) { - ctrld.ProxyLogger.Load().Err(err).Msg("hosts file changed but failed to update client info") + hf.logger.Err(err).Msg("hosts file changed but failed to update client info") } } case err, ok := <-hf.watcher.Errors: if !ok { return } - ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file") + hf.logger.Err(err).Msg("could not watch client info file") } } diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index e009e01..ebdfabc 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -34,7 +34,8 @@ var ( ) type mdns struct { - name sync.Map // ip => hostname + name sync.Map // ip => hostname + logger *ctrld.Logger } func (m *mdns) LookupHostnameByIP(ip string) string { @@ -93,9 +94,9 @@ func (m *mdns) init(quitCh chan struct{}) error { } // Check if IPv6 is available once and use the result for the rest of the function. - ctrld.ProxyLogger.Load().Debug().Msgf("checking for IPv6 availability in mdns init") + m.logger.Debug().Msgf("checking for IPv6 availability in mdns init") ipv6 := ctrldnet.IPv6Available(context.Background()) - ctrld.ProxyLogger.Load().Debug().Msgf("IPv6 is %v in mdns init", ipv6) + m.logger.Debug().Msgf("IPv6 is %v in mdns init", ipv6) v4ConnList := make([]*net.UDPConn, 0, len(ifaces)) v6ConnList := make([]*net.UDPConn, 0, len(ifaces)) @@ -129,11 +130,11 @@ func (m *mdns) probeLoop(conns []*net.UDPConn, remoteAddr net.Addr, quitCh chan for { err := m.probe(conns, remoteAddr) if shouldStopProbing(err) { - ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: %v", remoteAddr, err) + m.logger.Warn().Msgf("stop probing %q: %v", remoteAddr, err) break } if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("error while probing mdns") + m.logger.Warn().Err(err).Msg("error while probing mdns") bo.BackOff(context.Background(), errors.New("mdns probe backoff")) continue } @@ -161,7 +162,7 @@ func (m *mdns) readLoop(conn *net.UDPConn) { if errors.Is(err, net.ErrClosed) { return } - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("mdns readLoop error") + m.logger.Debug().Err(err).Msg("mdns readLoop error") return } @@ -184,11 +185,11 @@ func (m *mdns) readLoop(conn *net.UDPConn) { if ip != "" && name != "" { name = normalizeHostname(name) if val, loaded := m.name.LoadOrStore(ip, name); !loaded { - ctrld.ProxyLogger.Load().Debug().Msgf("found hostname: %q, ip: %q via mdns", name, ip) + m.logger.Debug().Msgf("found hostname: %q, ip: %q via mdns", name, ip) } else { old := val.(string) if old != name { - ctrld.ProxyLogger.Load().Debug().Msgf("update hostname: %q, ip: %q, old: %q via mdns", name, ip, old) + m.logger.Debug().Msgf("update hostname: %q, ip: %q, old: %q via mdns", name, ip, old) m.name.Store(ip, name) } } @@ -227,7 +228,7 @@ func (m *mdns) probe(conns []*net.UDPConn, remoteAddr net.Addr) error { // getDataFromAvahiDaemonCache reads entries from avahi-daemon cache to update mdns data. func (m *mdns) getDataFromAvahiDaemonCache() { if _, err := exec.LookPath("avahi-browse"); err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not find avahi-browse binary, skipping.") + m.logger.Debug().Err(err).Msg("could not find avahi-browse binary, skipping.") return } // Run avahi-browse to discover services from cache: @@ -237,7 +238,7 @@ func (m *mdns) getDataFromAvahiDaemonCache() { // - "-c" -> read from cache. out, err := exec.Command("avahi-browse", "-a", "-r", "-p", "-c").Output() if err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not browse services from avahi cache") + m.logger.Debug().Err(err).Msg("could not browse services from avahi cache") return } m.storeDataFromAvahiBrowseOutput(bytes.NewReader(out)) @@ -257,7 +258,7 @@ func (m *mdns) storeDataFromAvahiBrowseOutput(r io.Reader) { name := normalizeHostname(fields[6]) // Only using cache value if we don't have existed one. if _, loaded := m.name.LoadOrStore(ip, name); !loaded { - ctrld.ProxyLogger.Load().Debug().Msgf("found hostname: %q, ip: %q via avahi cache", name, ip) + m.logger.Debug().Msgf("found hostname: %q, ip: %q via avahi cache", name, ip) } } } diff --git a/internal/clientinfo/mdns_test.go b/internal/clientinfo/mdns_test.go index e6f8698..28c23d9 100644 --- a/internal/clientinfo/mdns_test.go +++ b/internal/clientinfo/mdns_test.go @@ -3,6 +3,8 @@ package clientinfo import ( "strings" "testing" + + "github.com/Control-D-Inc/ctrld" ) func Test_mdns_storeDataFromAvahiBrowseOutput(t *testing.T) { @@ -11,7 +13,7 @@ func Test_mdns_storeDataFromAvahiBrowseOutput(t *testing.T) { =;wlp0s20f3;IPv6;Foo\032\0402\041;_companion-link._tcp;local;Foo-2.local;192.168.1.123;64842;"rpBA=00:00:00:00:00:01" "rpHI=e6ae2cbbca0e" "rpAD=36566f4d850f" "rpVr=510.71.1" "rpHA=0ddc20fdddc8" "rpFl=0x30000" "rpHN=1d4a03afdefa" "rpMac=0" =;wlp0s20f3;IPv4;Foo\032\0402\041;_companion-link._tcp;local;Foo-2.local;192.168.1.123;64842;"rpBA=00:00:00:00:00:01" "rpHI=e6ae2cbbca0e" "rpAD=36566f4d850f" "rpVr=510.71.1" "rpHA=0ddc20fdddc8" "rpFl=0x30000" "rpHN=1d4a03afdefa" "rpMac=0" ` - m := &mdns{} + m := &mdns{logger: ctrld.NopLogger} m.storeDataFromAvahiBrowseOutput(strings.NewReader(content)) ip := "192.168.1.123" val, loaded := m.name.LoadOrStore(ip, "") diff --git a/internal/clientinfo/merlin.go b/internal/clientinfo/merlin.go index 8a39398..8ba6c5c 100644 --- a/internal/clientinfo/merlin.go +++ b/internal/clientinfo/merlin.go @@ -15,6 +15,7 @@ const merlinNvramCustomClientListKey = "custom_clientlist" type merlinDiscover struct { hostname sync.Map // mac => hostname + logger *ctrld.Logger } func (m *merlinDiscover) refresh() error { @@ -25,7 +26,7 @@ func (m *merlinDiscover) refresh() error { if err != nil { return err } - ctrld.ProxyLogger.Load().Debug().Msg("reading Merlin custom client list") + m.logger.Debug().Msg("reading Merlin custom client list") m.parseMerlinCustomClientList(out) return nil } diff --git a/internal/clientinfo/ndp.go b/internal/clientinfo/ndp.go index 9d9155d..87f86fe 100644 --- a/internal/clientinfo/ndp.go +++ b/internal/clientinfo/ndp.go @@ -20,8 +20,9 @@ import ( // ndpDiscover provides client discovery functionality using NDP protocol. type ndpDiscover struct { - mac sync.Map // ip => mac - ip sync.Map // mac => ip + mac sync.Map // ip => mac + ip sync.Map // mac => ip + logger *ctrld.Logger } // refresh re-scans the NDP table. @@ -97,7 +98,7 @@ func (nd *ndpDiscover) saveInfo(ip, mac string) { func (nd *ndpDiscover) listen(ctx context.Context) { ifis, err := allInterfacesWithV6LinkLocal() if err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("failed to find valid ipv6 interfaces") + nd.logger.Debug().Err(err).Msg("failed to find valid ipv6 interfaces") return } for _, ifi := range ifis { @@ -110,11 +111,11 @@ func (nd *ndpDiscover) listen(ctx context.Context) { func (nd *ndpDiscover) listenOnInterface(ctx context.Context, ifi *net.Interface) { c, ip, err := ndp.Listen(ifi, ndp.Unspecified) if err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("ndp listen failed") + nd.logger.Debug().Err(err).Msg("ndp listen failed") return } defer c.Close() - ctrld.ProxyLogger.Load().Debug().Msgf("listening ndp on: %s", ip.String()) + nd.logger.Debug().Msgf("listening ndp on: %s", ip.String()) for { select { case <-ctx.Done(): @@ -128,7 +129,7 @@ func (nd *ndpDiscover) listenOnInterface(ctx context.Context, ifi *net.Interface if errors.As(readErr, &opErr) && (opErr.Timeout() || opErr.Temporary()) { continue } - ctrld.ProxyLogger.Load().Debug().Err(readErr).Msg("ndp read loop error") + nd.logger.Debug().Err(readErr).Msg("ndp read loop error") return } diff --git a/internal/clientinfo/ndp_linux.go b/internal/clientinfo/ndp_linux.go index ebd416f..6658c78 100644 --- a/internal/clientinfo/ndp_linux.go +++ b/internal/clientinfo/ndp_linux.go @@ -5,15 +5,13 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/sys/unix" - - "github.com/Control-D-Inc/ctrld" ) // scan populates NDP table using information from system mappings. func (nd *ndpDiscover) scan() { neighs, err := netlink.NeighList(0, netlink.FAMILY_V6) if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not get neigh list") + nd.logger.Warn().Err(err).Msg("could not get neigh list") return } @@ -34,7 +32,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) { done := make(chan struct{}) defer close(done) if err := netlink.NeighSubscribe(ch, done); err != nil { - ctrld.ProxyLogger.Load().Err(err).Msg("could not perform neighbor subscribing") + nd.logger.Err(err).Msg("could not perform neighbor subscribing") return } for { @@ -47,7 +45,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) { } ip := normalizeIP(nu.IP.String()) if nu.Type == unix.RTM_DELNEIGH { - ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor: %s", ip) + nd.logger.Debug().Msgf("removing NDP neighbor: %s", ip) nd.mac.Delete(ip) continue } @@ -56,7 +54,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) { case netlink.NUD_REACHABLE: nd.saveInfo(ip, mac) case netlink.NUD_FAILED: - ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor with failed state: %s", ip) + nd.logger.Debug().Msgf("removing NDP neighbor with failed state: %s", ip) nd.mac.Delete(ip) } } diff --git a/internal/clientinfo/ndp_others.go b/internal/clientinfo/ndp_others.go index 007407b..33e95a5 100644 --- a/internal/clientinfo/ndp_others.go +++ b/internal/clientinfo/ndp_others.go @@ -7,8 +7,6 @@ import ( "context" "os/exec" "runtime" - - "github.com/Control-D-Inc/ctrld" ) // scan populates NDP table using information from system mappings. @@ -17,14 +15,14 @@ func (nd *ndpDiscover) scan() { case "windows": data, err := exec.Command("netsh", "interface", "ipv6", "show", "neighbors").Output() if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not query ndp table") + nd.logger.Warn().Err(err).Msg("could not query ndp table") return } nd.scanWindows(bytes.NewReader(data)) default: data, err := exec.Command("ndp", "-an").Output() if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not query ndp table") + nd.logger.Warn().Err(err).Msg("could not query ndp table") return } nd.scanUnix(bytes.NewReader(data)) diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index 8e6b3f7..b4783bd 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -17,6 +17,7 @@ type ptrDiscover struct { hostname sync.Map // ip => hostname resolver ctrld.Resolver serverDown atomic.Bool + logger *ctrld.Logger } func (p *ptrDiscover) refresh() error { @@ -73,14 +74,14 @@ func (p *ptrDiscover) lookupHostname(ip string) string { msg := new(dns.Msg) addr, err := dns.ReverseAddr(ip) if err != nil { - ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address") + p.logger.Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address") return "" } msg.SetQuestion(addr, dns.TypePTR) ans, err := p.resolver.Resolve(ctx, msg) if err != nil { if p.serverDown.CompareAndSwap(false, true) { - ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup") + p.logger.Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup") go p.checkServer() } return "" diff --git a/internal/controld/config.go b/internal/controld/config.go index 595e758..97ec8e2 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -88,18 +88,18 @@ type LogsRequest struct { } // FetchResolverConfig fetch Control D config for given uid. -func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) { +func FetchResolverConfig(ctx context.Context, rawUID, version string, cdDev bool) (*ResolverConfig, error) { uid, clientID := ParseRawUID(rawUID) req := utilityRequest{UID: uid} if clientID != "" { req.ClientID = clientID } body, _ := json.Marshal(req) - return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) + return postUtilityAPI(ctx, version, cdDev, false, bytes.NewReader(body)) } // FetchResolverUID fetch resolver uid from provision token. -func FetchResolverUID(req *UtilityOrgRequest, version string, cdDev bool) (*ResolverConfig, error) { +func FetchResolverUID(ctx context.Context, req *UtilityOrgRequest, version string, cdDev bool) (*ResolverConfig, error) { if req == nil { return nil, errors.New("invalid request") } @@ -108,21 +108,21 @@ func FetchResolverUID(req *UtilityOrgRequest, version string, cdDev bool) (*Reso hostname, _ = os.Hostname() } body, _ := json.Marshal(UtilityOrgRequest{ProvToken: req.ProvToken, Hostname: hostname}) - return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) + return postUtilityAPI(ctx, version, cdDev, false, bytes.NewReader(body)) } // UpdateCustomLastFailed calls API to mark custom config is bad. -func UpdateCustomLastFailed(rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) { +func UpdateCustomLastFailed(ctx context.Context, rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) { uid, clientID := ParseRawUID(rawUID) req := utilityRequest{UID: uid} if clientID != "" { req.ClientID = clientID } body, _ := json.Marshal(req) - return postUtilityAPI(version, cdDev, true, bytes.NewReader(body)) + return postUtilityAPI(ctx, version, cdDev, true, bytes.NewReader(body)) } -func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) { +func postUtilityAPI(ctx context.Context, version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) { apiUrl := resolverDataURLCom if cdDev { apiUrl = resolverDataURLDev @@ -139,12 +139,12 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade } req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/json") - transport := apiTransport(cdDev) + transport := apiTransport(ctx, cdDev) client := &http.Client{ Timeout: defaultTimeout, Transport: transport, } - resp, err := doWithFallback(client, req, apiServerIP(cdDev)) + resp, err := doWithFallback(ctx, client, req, apiServerIP(cdDev)) if err != nil { return nil, fmt.Errorf("postUtilityAPI client.Do: %w", err) } @@ -166,7 +166,7 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade } // SendLogs sends runtime log to ControlD API. -func SendLogs(lr *LogsRequest, cdDev bool) error { +func SendLogs(ctx context.Context, lr *LogsRequest, cdDev bool) error { defer lr.Data.Close() apiUrl := logURLCom if cdDev { @@ -180,12 +180,12 @@ func SendLogs(lr *LogsRequest, cdDev bool) error { q.Set("uid", lr.UID) req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - transport := apiTransport(cdDev) + transport := apiTransport(ctx, cdDev) client := &http.Client{ Timeout: sendLogTimeout, Transport: transport, } - resp, err := doWithFallback(client, req, apiServerIP(cdDev)) + resp, err := doWithFallback(ctx, client, req, apiServerIP(cdDev)) if err != nil { return fmt.Errorf("SendLogs client.Do: %w", err) } @@ -213,7 +213,7 @@ func ParseRawUID(rawUID string) (string, string) { } // apiTransport returns an HTTP transport for connecting to ControlD API endpoint. -func apiTransport(cdDev bool) *http.Transport { +func apiTransport(loggerCtx context.Context, cdDev bool) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { apiDomain := apiDomainCom @@ -227,9 +227,10 @@ func apiTransport(cdDev bool) *http.Transport { apiIPs = []string{apiDomainDevIPv4} } - ips := ctrld.LookupIP(apiDomain) + ips := ctrld.LookupIP(loggerCtx, apiDomain) if len(ips) == 0 { - ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, use direct IPs: %v", apiDomain, apiIPs) + logger := ctrld.LoggerFromCtx(loggerCtx) + logger.Warn().Msgf("No IPs found for %s, use direct IPs: %v", apiDomain, apiIPs) ips = apiIPs } @@ -245,7 +246,8 @@ func apiTransport(cdDev bool) *http.Transport { dial := func(ctx context.Context, network string, addrs []string) (net.Conn, error) { d := &ctrldnet.ParallelDialer{} - return d.DialContext(ctx, network, addrs, ctrld.ProxyLogger.Load()) + logger := ctrld.LoggerFromCtx(loggerCtx) + return d.DialContext(ctx, network, addrs, logger.Logger) } _, port, _ := net.SplitHostPort(addr) @@ -283,10 +285,11 @@ func addrsFromPort(ips []string, port string) []string { return addrs } -func doWithFallback(client *http.Client, req *http.Request, apiIp string) (*http.Response, error) { +func doWithFallback(ctx context.Context, client *http.Client, req *http.Request, apiIp string) (*http.Response, error) { resp, err := client.Do(req) if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to send request, fallback to direct IP: %s", apiIp) + logger := ctrld.LoggerFromCtx(ctx) + logger.Warn().Err(err).Msgf("failed to send request, fallback to direct IP: %s", apiIp) ipReq := req.Clone(req.Context()) ipReq.Host = apiIp ipReq.URL.Host = apiIp diff --git a/log.go b/log.go index 14c82e8..7b7037b 100644 --- a/log.go +++ b/log.go @@ -3,19 +3,37 @@ package ctrld import ( "context" "fmt" - "io" - "sync/atomic" "github.com/rs/zerolog" ) -// ProxyLog emits the log record for proxy operations. -// The caller should set it only once. -// DEPRECATED: use ProxyLogger instead. -var ProxyLog = zerolog.New(io.Discard) +// LoggerCtxKey is the context.Context key for a logger. +type LoggerCtxKey struct{} -// ProxyLogger emits the log record for proxy operations. -var ProxyLogger atomic.Pointer[zerolog.Logger] +// LoggerCtx returns a context.Context with LoggerCtxKey set. +func LoggerCtx(ctx context.Context, l *Logger) context.Context { + return context.WithValue(ctx, LoggerCtxKey{}, l) +} + +// A Logger provides fast, leveled, structured logging. +type Logger struct { + *zerolog.Logger +} + +var noOpZeroLogger = zerolog.Nop() + +// NopLogger returns a logger which all operation are no-op. +var NopLogger = &Logger{&noOpZeroLogger} + +// LoggerFromCtx returns the logger associated with given ctx. +// +// If there's no logger, a no-op logger will be returned. +func LoggerFromCtx(ctx context.Context) *Logger { + if logger, ok := ctx.Value(LoggerCtxKey{}).(*Logger); ok && logger != nil { + return logger + } + return NopLogger +} // ReqIdCtxKey is the context.Context key for a request id. type ReqIdCtxKey struct{} diff --git a/nameservers.go b/nameservers.go index 0aebf9e..07743ac 100644 --- a/nameservers.go +++ b/nameservers.go @@ -1,9 +1,11 @@ package ctrld -type dnsFn func() []string +import "context" + +type dnsFn func(ctx context.Context) []string // nameservers returns DNS nameservers from system settings. -func nameservers() []string { +func nameservers(ctx context.Context) []string { var dns []string seen := make(map[string]bool) ch := make(chan []string) @@ -11,7 +13,7 @@ func nameservers() []string { for _, fn := range fns { go func(fn dnsFn) { - ch <- fn() + ch <- fn(ctx) }(fn) } for range fns { diff --git a/nameservers_bsd.go b/nameservers_bsd.go index 09c9516..15c30c9 100644 --- a/nameservers_bsd.go +++ b/nameservers_bsd.go @@ -3,6 +3,7 @@ package ctrld import ( + "context" "net" "syscall" @@ -13,7 +14,7 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromResolvConf, dnsFromRIB} } -func dnsFromRIB() []string { +func dnsFromRIB(_ context.Context) []string { var dns []string rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) if err != nil { diff --git a/nameservers_darwin.go b/nameservers_darwin.go index c8fa78d..822893b 100644 --- a/nameservers_darwin.go +++ b/nameservers_darwin.go @@ -22,8 +22,8 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers} } -func getDNSFromScutil() []string { - logger := *ProxyLogger.Load() +func getDNSFromScutil(ctx context.Context) []string { + logger := LoggerFromCtx(ctx) const ( maxRetries = 10 @@ -109,8 +109,8 @@ func getDHCPNameservers(iface string) ([]string, error) { return nameservers, nil } -func getAllDHCPNameservers() []string { - logger := *ProxyLogger.Load() +func getAllDHCPNameservers(ctx context.Context) []string { + logger := LoggerFromCtx(ctx) interfaces, err := net.Interfaces() if err != nil { diff --git a/nameservers_linux.go b/nameservers_linux.go index 13a5507..8f877a6 100644 --- a/nameservers_linux.go +++ b/nameservers_linux.go @@ -3,6 +3,7 @@ package ctrld import ( "bufio" "bytes" + "context" "encoding/hex" "net" "os" @@ -20,7 +21,7 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromResolvConf, dns4, dns6, dnsFromSystemdResolver} } -func dns4() []string { +func dns4(_ context.Context) []string { f, err := os.Open(v4RouteFile) if err != nil { return nil @@ -60,7 +61,7 @@ func dns4() []string { return dns } -func dns6() []string { +func dns6(_ context.Context) []string { f, err := os.Open(v6RouteFile) if err != nil { return nil @@ -94,7 +95,7 @@ func dns6() []string { return dns } -func dnsFromSystemdResolver() []string { +func dnsFromSystemdResolver(_ context.Context) []string { c, err := resolvconffile.ParseFile("/run/systemd/resolve/resolv.conf") if err != nil { return nil diff --git a/nameservers_test.go b/nameservers_test.go index 166cced..e2e2bac 100644 --- a/nameservers_test.go +++ b/nameservers_test.go @@ -1,9 +1,12 @@ package ctrld -import "testing" +import ( + "context" + "testing" +) func TestNameservers(t *testing.T) { - ns := nameservers() + ns := nameservers(context.Background()) if len(ns) == 0 { t.Fatal("failed to get nameservers") } diff --git a/nameservers_unix.go b/nameservers_unix.go index d8e6035..8082c8a 100644 --- a/nameservers_unix.go +++ b/nameservers_unix.go @@ -3,6 +3,7 @@ package ctrld import ( + "context" "net" "slices" "time" @@ -20,7 +21,7 @@ func currentNameserversFromResolvconf() []string { // dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file. // A nameserver is usable if it's not one of current machine's IP addresses // and loopback IP addresses. -func dnsFromResolvConf() []string { +func dnsFromResolvConf(_ context.Context) []string { const ( maxRetries = 10 retryInterval = 100 * time.Millisecond diff --git a/nameservers_windows.go b/nameservers_windows.go index 4f6ca8e..bd8f564 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -55,28 +55,25 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromAdapter} } -func dnsFromAdapter() []string { +func dnsFromAdapter(ctx context.Context) []string { ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout) defer cancel() var ns []string var err error - logger := *ProxyLogger.Load() + logger := LoggerFromCtx(ctx) for i := 0; i < maxDNSAdapterRetries; i++ { if ctx.Err() != nil { - Log(context.Background(), logger.Debug(), - "dnsFromAdapter lookup cancelled or timed out, attempt %d", i) + logger.Debug().Msgf("dnsFromAdapter lookup cancelled or timed out, attempt %d", i) return nil } ns, err = getDNSServers(ctx) if err == nil && len(ns) >= minDNSServers { if i > 0 { - Log(context.Background(), logger.Debug(), - "Successfully got DNS servers after %d attempts, found %d servers", - i+1, len(ns)) + logger.Debug().Msgf("Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns)) } return ns } @@ -88,11 +85,9 @@ func dnsFromAdapter() []string { } if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get DNS servers, attempt %d: %v", i+1, err) + logger.Debug().Msgf("Failed to get DNS servers, attempt %d: %v", i+1, err) } else { - Log(context.Background(), logger.Debug(), - "Got insufficient DNS servers, retrying, found %d servers", len(ns)) + logger.Debug().Msgf("Got insufficient DNS servers, retrying, found %d servers", len(ns)) } select { @@ -102,14 +97,12 @@ func dnsFromAdapter() []string { } } - Log(context.Background(), logger.Debug(), - "Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries) + logger.Debug().Msgf("Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries) + return ns } func getDNSServers(ctx context.Context) ([]string, error) { - logger := *ProxyLogger.Load() - // Check context before making the call if ctx.Err() != nil { return nil, ctx.Err() @@ -124,17 +117,16 @@ func getDNSServers(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("getting adapters: %w", err) } - Log(context.Background(), logger.Debug(), - "Found network adapters, count=%d", len(aas)) + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("Found network adapters, count=%d", len(aas)) // Try to get domain controller info if domain-joined var dcServers []string - isDomain := checkDomainJoined() + isDomain := checkDomainJoined(ctx) if isDomain { domainName, err := getLocalADDomain() if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get local AD domain: %v", err) + logger.Debug().Msgf("Failed to get local AD domain: %v", err) } else { // Load netapi32.dll netapi32 := windows.NewLazySystemDLL("netapi32.dll") @@ -145,11 +137,9 @@ func getDNSServers(ctx context.Context) ([]string, error) { domainUTF16, err := windows.UTF16PtrFromString(domainName) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to convert domain name to UTF16: %v", err) + logger.Debug().Msgf("Failed to convert domain name to UTF16: %v", err) } else { - Log(context.Background(), logger.Debug(), - "Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) + logger.Debug().Msgf("Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) // Call DsGetDcNameW with domain name ret, _, err := dsDcName.Call( @@ -163,20 +153,15 @@ func getDNSServers(ctx context.Context) ([]string, error) { if ret != 0 { switch ret { case 1355: // ERROR_NO_SUCH_DOMAIN - Log(context.Background(), logger.Debug(), - "Domain not found: %s (%d)", domainName, ret) + logger.Debug().Msgf("Domain not found: %s (%d)", domainName, ret) case 1311: // ERROR_NO_LOGON_SERVERS - Log(context.Background(), logger.Debug(), - "No logon servers available for domain: %s (%d)", domainName, ret) + logger.Debug().Msgf("No logon servers available for domain: %s (%d)", domainName, ret) case 1004: // ERROR_DC_NOT_FOUND - Log(context.Background(), logger.Debug(), - "Domain controller not found for domain: %s (%d)", domainName, ret) + logger.Debug().Msgf("Domain controller not found for domain: %s (%d)", domainName, ret) case 1722: // RPC_S_SERVER_UNAVAILABLE - Log(context.Background(), logger.Debug(), - "RPC server unavailable for domain: %s (%d)", domainName, ret) + logger.Debug().Msgf("RPC server unavailable for domain: %s (%d)", domainName, ret) default: - Log(context.Background(), logger.Debug(), - "Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) + logger.Debug().Msgf("Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) } } else if info != nil { defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info))) @@ -184,17 +169,13 @@ func getDNSServers(ctx context.Context) ([]string, error) { if info.DomainControllerAddress != nil { dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress) dcAddr = strings.TrimPrefix(dcAddr, "\\\\") - Log(context.Background(), logger.Debug(), - "Found domain controller address: %s", dcAddr) - + logger.Debug().Msgf("Found domain controller address: %s", dcAddr) if ip := net.ParseIP(dcAddr); ip != nil { dcServers = append(dcServers, ip.String()) - Log(context.Background(), logger.Debug(), - "Added domain controller DNS servers: %v", dcServers) + logger.Debug().Msgf("Added domain controller DNS servers: %v", dcServers) } } else { - Log(context.Background(), logger.Debug(), - "No domain controller address found") + logger.Debug().Msg("No domain controller address found") } } } @@ -209,31 +190,27 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Collect all local IPs for _, aa := range aas { if aa.OperStatus != winipcfg.IfOperStatusUp { - Log(context.Background(), logger.Debug(), - "Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) + logger.Debug().Msgf("Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) continue } // Skip if software loopback or other non-physical types // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows if aa.IfType == winipcfg.IfTypeSoftwareLoopback { - Log(context.Background(), logger.Debug(), - "Skipping %s (software loopback)", aa.FriendlyName()) + logger.Debug().Msgf("Skipping %s (software loopback)", aa.FriendlyName()) continue } - Log(context.Background(), logger.Debug(), - "Processing adapter %s", aa.FriendlyName()) + logger.Debug().Msgf("Processing adapter %s", aa.FriendlyName()) for a := aa.FirstUnicastAddress; a != nil; a = a.Next { ip := a.Address.IP().String() addressMap[ip] = struct{}{} - Log(context.Background(), logger.Debug(), - "Added local IP %s from adapter %s", ip, aa.FriendlyName()) + logger.Debug().Msgf("Added local IP %s from adapter %s", ip, aa.FriendlyName()) } } - validInterfacesMap := validInterfaces() + validInterfacesMap := validInterfaces(ctx) // Collect DNS servers for _, aa := range aas { @@ -244,23 +221,20 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Skip if software loopback or other non-physical types // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows if aa.IfType == winipcfg.IfTypeSoftwareLoopback { - Log(context.Background(), logger.Debug(), - "Skipping %s (software loopback)", aa.FriendlyName()) + logger.Debug().Msgf("Skipping %s (software loopback)", aa.FriendlyName()) continue } // if not in the validInterfacesMap, skip if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok { - Log(context.Background(), logger.Debug(), - "Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) + logger.Debug().Msgf("Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) continue } for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { ip := dns.Address.IP() if ip == nil { - Log(context.Background(), logger.Debug(), - "Skipping nil IP from adapter %s", aa.FriendlyName()) + logger.Debug().Msgf("Skipping nil IP from adapter %s", aa.FriendlyName()) continue } @@ -293,28 +267,23 @@ func getDNSServers(ctx context.Context) ([]string, error) { if !seen[dcServer] { seen[dcServer] = true ns = append(ns, dcServer) - Log(context.Background(), logger.Debug(), - "Added additional domain controller DNS server: %s", dcServer) + logger.Debug().Msgf("Added additional domain controller DNS server: %s", dcServer) } } // if we have static DNS servers saved for the current default route, we should add them to the list drIfaceName, err := netmon.DefaultRouteInterface() if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get default route interface: %v", err) + logger.Debug().Msgf("Failed to get default route interface: %v", err) } else { drIface, err := net.InterfaceByName(drIfaceName) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get interface by name %s: %v", drIfaceName, err) + logger.Debug().Msgf("Failed to get interface by name %s: %v", drIfaceName, err) } else { staticNs, file := SavedStaticNameserversAndPath(drIface) - Log(context.Background(), logger.Debug(), - "static dns servers from %s: %v", file, staticNs) + logger.Debug().Msgf("static dns servers from %s: %v", file, staticNs) if len(staticNs) > 0 { - Log(context.Background(), logger.Debug(), - "Adding static DNS servers from %s: %v", drIfaceName, staticNs) + logger.Debug().Msgf("Adding static DNS servers from %s: %v", drIfaceName, staticNs) ns = append(ns, staticNs...) } } @@ -324,9 +293,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("no valid DNS servers found") } - Log(context.Background(), logger.Debug(), - "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", - len(ns), ns, len(dcServers)) + logger.Debug().Msgf("DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", len(ns), ns, len(dcServers)) return ns, nil } @@ -337,33 +304,35 @@ func currentNameserversFromResolvconf() []string { // checkDomainJoined checks if the machine is joined to an Active Directory domain // Returns whether it's domain joined and the domain name if available -func checkDomainJoined() bool { - logger := *ProxyLogger.Load() +func checkDomainJoined(ctx context.Context) bool { + logger := LoggerFromCtx(ctx) var domain *uint16 var status uint32 err := windows.NetGetJoinInformation(nil, &domain, &status) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get domain join status: %v", err) + logger.Debug().Msgf("Failed to get domain join status: %v", err) return false } defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) domainName := windows.UTF16PtrToString(domain) - Log(context.Background(), logger.Debug(), + logger.Debug().Msgf( "Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", - domainName, status) + domainName, + status, + ) // Consider domain or cloud domain as domain-joined isDomain := status == NetSetupDomain || status == NetSetupCloudDomain - Log(context.Background(), logger.Debug(), + logger.Debug().Msgf( "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", status, status == NetSetupDomain, status == NetSetupCloudDomain, - isDomain) + isDomain, + ) return isDomain } @@ -411,12 +380,12 @@ func getLocalADDomain() (string, error) { // validInterfaces returns a list of all physical interfaces. // this is a duplicate of what is in net_windows.go, we should // clean this up so there is only one version -func validInterfaces() map[string]struct{} { +func validInterfaces(ctx context.Context) map[string]struct{} { log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) //load the logger - logger := *ProxyLogger.Load() + logger := LoggerFromCtx(ctx) whost := host.NewWmiLocalHost() q := query.NewWmiQuery("MSFT_NetAdapter") @@ -425,23 +394,20 @@ func validInterfaces() map[string]struct{} { defer instances.Close() } if err != nil { - Log(context.Background(), logger.Warn(), - "failed to get wmi network adapter: %v", err) + logger.Warn().Msgf("failed to get wmi network adapter: %v", err) return nil } var adapters []string for _, i := range instances { adapter, err := netadapter.NewNetworkAdapter(i) if err != nil { - Log(context.Background(), logger.Warn(), - "failed to get network adapter: %v", err) + logger.Warn().Msgf("failed to get network adapter: %v", err) continue } name, err := adapter.GetPropertyName() if err != nil { - Log(context.Background(), logger.Warn(), - "failed to get interface name: %v", err) + logger.Warn().Msgf("failed to get interface name: %v", err) continue } @@ -451,13 +417,11 @@ func validInterfaces() map[string]struct{} { // if this is a physical adapter or FALSE if this is not a physical adapter." physical, err := adapter.GetPropertyConnectorPresent() if err != nil { - Log(context.Background(), logger.Debug(), - "failed to get network adapter connector present property: %v", err) + logger.Debug().Msgf("failed to get network adapter connector present property: %v", err) continue } if !physical { - Log(context.Background(), logger.Debug(), - "skipping non-physical adapter: %s", name) + logger.Debug().Msgf("skipping non-physical adapter: %s", name) continue } @@ -465,13 +429,11 @@ func validInterfaces() map[string]struct{} { // because some interfaces are not physical but have a connector. hardware, err := adapter.GetPropertyHardwareInterface() if err != nil { - Log(context.Background(), logger.Debug(), - "failed to get network adapter hardware interface property: %v", err) + logger.Debug().Msgf("failed to get network adapter hardware interface property: %v", err) continue } if !hardware { - Log(context.Background(), logger.Debug(), - "skipping non-hardware interface: %s", name) + logger.Debug().Msgf("skipping non-hardware interface: %s", name) continue } diff --git a/net.go b/net.go index 7bbf54b..0f556f4 100644 --- a/net.go +++ b/net.go @@ -17,26 +17,27 @@ var ( ) // HasIPv6 reports whether the current network stack has IPv6 available. -func HasIPv6() bool { +func HasIPv6(ctx context.Context) bool { hasIPv6Once.Do(func() { - ProxyLogger.Load().Debug().Msg("checking for IPv6 availability once") + logger := LoggerFromCtx(ctx) + logger.Debug().Msg("checking for IPv6 availability once") ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() val := ctrldnet.IPv6Available(ctx) ipv6Available.Store(val) - ProxyLogger.Load().Debug().Msgf("ipv6 availability: %v", val) + logger.Debug().Msgf("ipv6 availability: %v", val) mon, err := netmon.New(func(format string, args ...any) {}) if err != nil { - ProxyLogger.Load().Debug().Err(err).Msg("failed to monitor IPv6 state") + logger.Debug().Err(err).Msg("failed to monitor IPv6 state") return } mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { old := ipv6Available.Load() cur := delta.Monitor.InterfaceState().HaveV6 if old != cur { - ProxyLogger.Load().Warn().Msgf("ipv6 availability changed, old: %v, new: %v", old, cur) + logger.Warn().Msgf("ipv6 availability changed, old: %v, new: %v", old, cur) } else { - ProxyLogger.Load().Debug().Msg("ipv6 availability does not changed") + logger.Debug().Msg("ipv6 availability does not changed") } ipv6Available.Store(cur) }) @@ -46,8 +47,9 @@ func HasIPv6() bool { } // DisableIPv6 marks IPv6 as unavailable if enabled. -func DisableIPv6() { +func DisableIPv6(ctx context.Context) { if ipv6Available.CompareAndSwap(true, false) { - ProxyLogger.Load().Debug().Msg("turned off IPv6 availability") + logger := LoggerFromCtx(ctx) + logger.Debug().Msg("turned off IPv6 availability") } } diff --git a/resolver.go b/resolver.go index 27c0108..c88df1f 100644 --- a/resolver.go +++ b/resolver.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net" "net/netip" "runtime" @@ -15,7 +14,6 @@ import ( "time" "github.com/miekg/dns" - "github.com/rs/zerolog" "golang.org/x/sync/singleflight" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" @@ -50,10 +48,6 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") var localResolver Resolver func init() { - // Initializing ProxyLogger here, so other places don't have to do nil check. - l := zerolog.New(io.Discard) - ProxyLogger.Store(&l) - localResolver = newLocalResolver() } @@ -81,8 +75,8 @@ func LanQueryCtx(ctx context.Context) context.Context { } // defaultNameservers is like nameservers with each element formed "ip:53". -func defaultNameservers() []string { - ns := nameservers() +func defaultNameservers(ctx context.Context) []string { + ns := nameservers(ctx) nss := make([]string, len(ns)) for i := range ns { nss[i] = net.JoinHostPort(ns[i], "53") @@ -91,42 +85,36 @@ func defaultNameservers() []string { } // availableNameservers returns list of current available DNS servers of the system. -func availableNameservers() []string { +func availableNameservers(ctx context.Context) []string { var nss []string // Ignore local addresses to prevent loop. regularIPs, loopbackIPs, _ := netmon.LocalAddresses() machineIPsMap := make(map[string]struct{}, len(regularIPs)) - //load the logger - logger := *ProxyLogger.Load() - - Log(context.Background(), logger.Debug(), - "Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs) + // Load the logger. + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs) for _, v := range slices.Concat(regularIPs, loopbackIPs) { ipStr := v.String() machineIPsMap[ipStr] = struct{}{} - Log(context.Background(), logger.Debug(), - "Added local IP to OS resolverexclusion map: %s", ipStr) + logger.Debug().Msgf("Added local IP to OS resolverexclusion map: %s", ipStr) } - systemNameservers := nameservers() - Log(context.Background(), logger.Debug(), - "Got system nameservers: %v", systemNameservers) + systemNameservers := nameservers(ctx) + logger.Debug().Msgf("Got system nameservers: %v", systemNameservers) for _, ns := range systemNameservers { if _, ok := machineIPsMap[ns]; ok { - Log(context.Background(), logger.Debug(), - "Skipping local nameserver: %s", ns) + logger.Debug().Msgf("Skipping local nameserver: %s", ns) continue } nss = append(nss, ns) - Log(context.Background(), logger.Debug(), - "Added non-local nameserver: %s", ns) + logger.Debug().Msgf("Added non-local nameserver: %s", ns) } - Log(context.Background(), logger.Debug(), - "Final available nameservers: %v", nss) + logger.Debug().Msgf("Final available nameservers: %v", nss) + return nss } @@ -135,8 +123,8 @@ func availableNameservers() []string { // // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. -func InitializeOsResolver(guardAgainstNoNameservers bool) []string { - nameservers := availableNameservers() +func InitializeOsResolver(ctx context.Context, guardAgainstNoNameservers bool) []string { + nameservers := availableNameservers(ctx) // if no nameservers, return empty slice so we dont remove all nameservers if len(nameservers) == 0 && guardAgainstNoNameservers { return []string{} @@ -188,7 +176,7 @@ type Resolver interface { var errUnknownResolver = errors.New("unknown resolver") // NewResolver creates a Resolver based on the given upstream config. -func NewResolver(uc *UpstreamConfig) (Resolver, error) { +func NewResolver(ctx context.Context, uc *UpstreamConfig) (Resolver, error) { typ := uc.Type switch typ { case ResolverTypeDOH, ResolverTypeDOH3: @@ -200,15 +188,16 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { case ResolverTypeOS: resolverMutex.Lock() if or == nil { - ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver") - or = newResolverWithNameserver(defaultNameservers()) + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("Initialize new OS resolver") + or = newResolverWithNameserver(defaultNameservers(ctx)) } resolverMutex.Unlock() return or, nil case ResolverTypeLegacy: return &legacyResolver{uc: uc}, nil case ResolverTypePrivate: - return NewPrivateResolver(), nil + return NewPrivateResolver(ctx), nil case ResolverTypeLocal: return localResolver, nil } @@ -235,14 +224,16 @@ type publicResponse struct { } // SetDefaultLocalIPv4 updates the stored local IPv4. -func SetDefaultLocalIPv4(ip net.IP) { - Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip) +func SetDefaultLocalIPv4(ctx context.Context, ip net.IP) { + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("SetDefaultLocalIPv4: %s", ip) defaultLocalIPv4.Store(ip) } // SetDefaultLocalIPv6 updates the stored local IPv6. -func SetDefaultLocalIPv6(ip net.IP) { - Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip) +func SetDefaultLocalIPv6(ctx context.Context, ip net.IP) { + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("SetDefaultLocalIPv6: %s", ip) defaultLocalIPv6.Store(ip) } @@ -300,10 +291,11 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error // Unique key for the singleflight group. key := fmt.Sprintf("%s:%d:", domain, qtype) + logger := LoggerFromCtx(ctx) // Checking the cache first. if val, ok := o.cache.Load(key); ok { if val, ok := val.(*dns.Msg); ok { - Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) + Log(ctx, logger.Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) res := val.Copy() SetCacheReply(res, msg, val.Rcode) return res, nil @@ -338,7 +330,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error res := sharedMsg.Copy() SetCacheReply(res, msg, sharedMsg.Rcode) if shared { - Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) + Log(ctx, logger.Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) } return res, nil @@ -368,7 +360,8 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error if msg != nil && len(msg.Question) > 0 { question = msg.Question[0].Name } - Log(ctx, ProxyLogger.Load().Debug(), "os resolver query for %s with nameservers: %v public: %v", question, nss, publicServers) + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "os resolver query for %s with nameservers: %v public: %v", question, nss, publicServers) // New check: If no resolvers are available, return an error. if numServers == 0 { @@ -417,7 +410,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error // If splitting fails, fallback to the original server string host = server } - Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host) + Log(ctx, logger.Debug(), "got answer from nameserver: %s", host) } // try local nameservers @@ -444,7 +437,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error switch { case res.lan: // Always prefer LAN responses immediately - Log(ctx, ProxyLogger.Load().Debug(), "using LAN answer from: %s", res.server) + Log(ctx, logger.Debug(), "using LAN answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil @@ -454,7 +447,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error // if there are no LAN nameservers, we should not wait // just use the first response if len(nss) == 0 { - Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", res.server) + Log(ctx, logger.Debug(), "using public answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil @@ -465,12 +458,12 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error }) } case res.answer != nil: - Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d", + Log(ctx, logger.Debug(), "got non-success answer from: %s with code: %d", res.server, res.answer.Rcode) // When there are no LAN nameservers, we should not wait // for other nameservers to respond. if len(nss) == 0 { - Log(ctx, ProxyLogger.Load().Debug(), "no lan nameservers using public non success answer") + Log(ctx, logger.Debug(), "no lan nameservers using public non success answer") cancel() logAnswer(res.server) return res.answer, nil @@ -483,17 +476,17 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error if len(publicResponses) > 0 { resp := publicResponses[0] - Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", resp.server) + Log(ctx, logger.Debug(), "using public answer from: %s", resp.server) logAnswer(resp.server) return resp.answer, nil } if controldSuccessAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort) + Log(ctx, logger.Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort) logAnswer(controldPublicDnsWithPort) return controldSuccessAnswer, nil } if nonSuccessAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "using non-success answer from: %s", nonSuccessServer) + Log(ctx, logger.Debug(), "using non-success answer from: %s", nonSuccessServer) logAnswer(nonSuccessServer) return nonSuccessAnswer, nil } @@ -515,7 +508,7 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - _, udpNet := r.uc.netForDNSType(dnsTyp) + _, udpNet := r.uc.netForDNSType(ctx, dnsTyp) dnsClient := &dns.Client{ Net: udpNet, Dialer: dialer, @@ -541,39 +534,43 @@ func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, err // LookupIP looks up domain using current system nameservers settings. // It returns a slice of that host's IPv4 and IPv6 addresses. -func LookupIP(domain string) []string { - nss := initDefaultOsResolver() - return lookupIP(domain, -1, nss) +func LookupIP(ctx context.Context, domain string) []string { + nss := initDefaultOsResolver(ctx) + return lookupIP(ctx, domain, -1, nss) } // initDefaultOsResolver initializes the default OS resolver with system's default nameservers if it hasn't been initialized yet. // It returns the combined list of LAN and public nameservers currently held by the resolver. -func initDefaultOsResolver() []string { +func initDefaultOsResolver(ctx context.Context) []string { + logger := LoggerFromCtx(ctx) resolverMutex.Lock() defer resolverMutex.Unlock() if or == nil { - ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver with default nameservers") - or = newResolverWithNameserver(defaultNameservers()) + logger.Debug().Msgf("Initialize new OS resolver with default nameservers") + or = newResolverWithNameserver(defaultNameservers(ctx)) } nss := *or.lanServers.Load() nss = append(nss, *or.publicServers.Load()...) return nss + } // lookupIP looks up domain with given timeout and bootstrapDNS. // If the timeout is negative, default timeout 2000 ms will be used. // It returns nil if bootstrapDNS is nil or empty. -func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) { +func lookupIP(ctx context.Context, domain string, timeout int, bootstrapDNS []string) (ips []string) { if net.ParseIP(domain) != nil { return []string{domain} } + logger := LoggerFromCtx(ctx) if bootstrapDNS == nil { - ProxyLogger.Load().Debug().Msgf("empty bootstrap DNS") + logger.Debug().Msgf("empty bootstrap DNS") return nil } resolver := newResolverWithNameserver(bootstrapDNS) - ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, bootstrapDNS) + logger.Debug().Msgf("resolving %q using bootstrap DNS %q", domain, bootstrapDNS) + timeoutMs := 2000 if timeout > 0 && timeout < timeoutMs { timeoutMs = timeout @@ -616,15 +613,15 @@ func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) r, err := resolver.Resolve(ctx, m) if err != nil { - ProxyLogger.Load().Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain) + logger.Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain) return } if r.Rcode != dns.RcodeSuccess { - ProxyLogger.Load().Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode]) + logger.Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode]) return } if len(r.Answer) == 0 { - ProxyLogger.Load().Error().Msg("no answer from OS resolver") + logger.Error().Msg("no answer from OS resolver") return } target := targetDomain(r.Answer) @@ -641,22 +638,6 @@ func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) return ips } -// NewBootstrapResolver returns an OS resolver, which use following nameservers: -// -// - Gateway IP address (depends on OS). -// - Input servers. -func NewBootstrapResolver(servers ...string) Resolver { - logger := *ProxyLogger.Load() - - Log(context.Background(), logger.Debug(), "NewBootstrapResolver called with servers: %v", servers) - nss := defaultNameservers() - nss = append([]string{controldPublicDnsWithPort}, nss...) - for _, ns := range servers { - nss = append([]string{net.JoinHostPort(ns, "53")}, nss...) - } - return NewResolverWithNameserver(nss) -} - // NewPrivateResolver returns an OS resolver, which includes only private DNS servers, // excluding: // @@ -664,8 +645,8 @@ func NewBootstrapResolver(servers ...string) Resolver { // - Nameservers which is local RFC1918 addresses. // // This is useful for doing PTR lookup in LAN network. -func NewPrivateResolver() Resolver { - nss := initDefaultOsResolver() +func NewPrivateResolver(ctx context.Context) Resolver { + nss := initDefaultOsResolver(ctx) resolveConfNss := currentNameserversFromResolvconf() localRfc1918Addrs := Rfc1918Addresses() n := 0 diff --git a/resolver_test.go b/resolver_test.go index ebcad16..d5a76d6 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -132,7 +132,7 @@ func Test_osResolver_InitializationRace(t *testing.T) { for range n { go func() { defer wg.Done() - InitializeOsResolver(false) + InitializeOsResolver(LoggerCtx(context.Background(), nil), false) }() } wg.Wait()