diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 3caa3bb..5005925 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -349,7 +349,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if newLogPath := cfg.Service.LogPath; newLogPath != "" && oldLogPath != newLogPath { // After processCDFlags, log config may change, so reset mainLog and re-init logging. l := zerolog.New(io.Discard) - mainLog.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) // Copy logs written so far to new log file if possible. if buf, err := os.ReadFile(oldLogPath); err == nil { @@ -502,8 +502,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) } - nop := zerolog.Nop() - _, _ = tryUpdateListenerConfig(&cfg, &nop, func() {}, true) + _, _ = tryUpdateListenerConfig(&cfg, func() {}, true) addExtraSplitDnsRule(&cfg) if err := writeConfigFile(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) @@ -591,7 +590,8 @@ func processNoConfigFlags(noConfigStart bool) { Type: pType, Timeout: 5000, } - puc.Init() + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + puc.Init(loggerCtx) upstream := map[string]*ctrld.UpstreamConfig{"0": puc} if secondaryUpstream != "" { sEndpoint, sType := endpointAndTyp(secondaryUpstream) @@ -601,7 +601,7 @@ func processNoConfigFlags(noConfigStart bool) { Type: sType, Timeout: 5000, } - suc.Init() + suc.Init(loggerCtx) upstream["1"] = suc rules := make([]ctrld.Rule, 0, len(domains)) for _, domain := range domains { @@ -634,13 +634,13 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) bo.LogLongerThan = 30 * time.Second - ctx := context.Background() - resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + ctx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + resolverConfig, err := controld.FetchResolverConfig(ctx, cdUID, rootCmd.Version, cdDev) for { if errUrlNetworkError(err) { bo.BackOff(ctx, err) logger.Warn().Msg("could not fetch resolver using bootstrap DNS, retrying...") - resolverConfig, err = controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + resolverConfig, err = controld.FetchResolverConfig(ctx, cdUID, rootCmd.Version, cdDev) continue } break @@ -938,9 +938,10 @@ func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain stri bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", exErr)) } mainLog.Load().Debug().Msgf("self-check against %q failed", domain) + loggerCtx := ctrld.LoggerCtx(ctx, mainLog.Load()) // Ping all upstreams to provide better error message to users. for name, uc := range cfg.Upstream { - if err := uc.ErrorPing(); err != nil { + if err := uc.ErrorPing(loggerCtx); err != nil { mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint) } } @@ -1181,7 +1182,7 @@ func mobileListenerIp() string { // or defined but invalid to be used, e.g: using loopback address other // than 127.0.0.1 with systemd-resolved. func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool { - updated, _ := tryUpdateListenerConfig(cfg, nil, notifyToLogServerFunc, true) + updated, _ := tryUpdateListenerConfig(cfg, notifyToLogServerFunc, true) if addExtraSplitDnsRule(cfg) { updated = true } @@ -1191,7 +1192,7 @@ func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool // tryUpdateListenerConfig tries updating listener config with a working one. // If fatal is true, and there's listen address conflicted, the function do // fatal error. -func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, notifyFunc func(), fatal bool) (updated, ok bool) { +func tryUpdateListenerConfig(cfg *ctrld.Config, notifyFunc func(), fatal bool) (updated, ok bool) { ok = true lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" @@ -1235,9 +1236,6 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti } il := mainLog.Load() - if infoLogger != nil { - il = infoLogger - } if isMobile() { // On Mobile, only use first listener, ignore others. firstLn := cfg.FirstListener() @@ -1492,7 +1490,8 @@ func cdUIDFromProvToken() string { } req := &controld.UtilityOrgRequest{ProvToken: cdOrg, Hostname: customHostname} // Process provision token if provided. - resolverConfig, err := controld.FetchResolverUID(req, rootCmd.Version, cdDev) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + resolverConfig, err := controld.FetchResolverUID(loggerCtx, req, rootCmd.Version, cdDev) if err != nil { mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid with provision token: %s", cdOrg) } @@ -1819,7 +1818,8 @@ func runningIface(s service.Service) *ifaceResponse { // doValidateCdRemoteConfig fetches and validates custom config for cdUID. func doValidateCdRemoteConfig(cdUID string, fatal bool) error { - rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) if err != nil { logger := mainLog.Load().Fatal() if !fatal { diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 9281b90..428fe12 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -216,8 +216,9 @@ func (p *prog) registerControlServerHandler() { return } + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) // Re-fetch pin code from API. - if rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); rc != nil { + if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev); rc != nil { if rc.DeactivationPin != nil { cdDeactivationPin.Store(*rc.DeactivationPin) } else { @@ -321,7 +322,8 @@ func (p *prog) registerControlServerHandler() { } mainLog.Load().Debug().Msg("sending log file to ControlD server") resp := logSentResponse{Size: r.size} - if err := controld.SendLogs(req, cdDev); err != nil { + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil { mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err) resp.Error = err.Error() w.WriteHeader(http.StatusInternalServerError) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 33012fa..a3d9970 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -110,6 +110,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] reqId := requestID() ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) + ctx = ctrld.LoggerCtx(ctx, mainLog.Load()) if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) answer := new(dns.Msg) @@ -514,7 +515,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) - dnsResolver, err := ctrld.NewResolver(upstreamConfig) + dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig) if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") return nil, err @@ -549,11 +550,11 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // For timeout error (i.e: context deadline exceed), force re-bootstrapping. var e net.Error if errors.As(err, &e) && e.Timeout() { - upstreamConfig.ReBootstrap() + upstreamConfig.ReBootstrap(ctx) } // For network error, turn ipv6 off if enabled. - if ctrld.HasIPv6() && (errUrlNetworkError(err) || errNetworkError(err)) { - ctrld.DisableIPv6() + if ctrld.HasIPv6(ctx) && (errUrlNetworkError(err) || errNetworkError(err)) { + ctrld.DisableIPv6(ctx) } } @@ -960,7 +961,8 @@ func (p *prog) doSelfUninstall(answer *dns.Msg) { logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger() if p.refusedQueryCount > selfUninstallMaxQueries { p.checkingSelfUninstall = true - _, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + _, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) logger.Debug().Msg("maximum number of refused queries reached, checking device status") selfUninstallCheck(err, p, logger) @@ -1326,13 +1328,13 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // Only set the IPv4 default if selfIP is a valid IPv4 address. if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil { - ctrld.SetDefaultLocalIPv4(ip) + ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, mainLog.Load()), ip) if !isMobile() && p.ciTable != nil { p.ciTable.SetSelfIP(selfIP) } } if ip := net.ParseIP(ipv6); ip != nil { - ctrld.SetDefaultLocalIPv6(ip) + ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, mainLog.Load()), ip) } mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) @@ -1400,7 +1402,7 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool { func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error { mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream) - resolver, err := ctrld.NewResolver(uc) + resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), uc) if err != nil { mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream) return err @@ -1418,7 +1420,7 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - uc.ReBootstrap() + uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load())) mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) start := time.Now() @@ -1474,10 +1476,11 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // will be appended to nameservers from the saved interface values p.resetDNS(false, false) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) // For an OS failure, reinitialize OS resolver nameservers immediately. if reason == RecoveryReasonOSFailure { mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers") - ns := ctrld.InitializeOsResolver(true) + ns := ctrld.InitializeOsResolver(loggerCtx, true) if len(ns) == 0 { mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") } else { @@ -1504,7 +1507,7 @@ func (p *prog) handleRecovery(reason RecoveryReason) { // For network changes we also reinitialize the OS resolver. if reason == RecoveryReasonNetworkChange { - ns := ctrld.InitializeOsResolver(true) + ns := ctrld.InitializeOsResolver(loggerCtx, true) if len(ns) == 0 { mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values") } else { @@ -1564,7 +1567,7 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string // we should try to reinit the OS resolver to ensure we can recover if name == upstreamOS && attempts%3 == 0 { mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts) - ns := ctrld.InitializeOsResolver(true) + ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, mainLog.Load()), true) if len(ns) == 0 { mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") } else { @@ -1624,12 +1627,12 @@ func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) { // Check if the default IPv4 is still active. if currentIPv4 != nil && !activeIPs[currentIPv4.String()] { mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", currentIPv4) - ctrld.SetDefaultLocalIPv4(nil) + ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil) } // Check if the default IPv6 is still active. if currentIPv6 != nil && !activeIPs[currentIPv6.String()] { mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. Resetting.", currentIPv6) - ctrld.SetDefaultLocalIPv6(nil) + ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil) } } diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index ab6b855..0ba2c8c 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -137,8 +137,7 @@ func (p *prog) initInternalLogging(writers []io.Writer) { }) multi := zerolog.MultiLevelWriter(writers...) l := mainLog.Load().Output(multi).With().Logger() - mainLog.Store(&l) - ctrld.ProxyLogger.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) } // needInternalLogging reports whether prog needs to run internal logging. diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index 3504bc3..434a4a5 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -102,6 +102,7 @@ func (p *prog) checkDnsLoop() { } p.loopMu.Unlock() + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) for uid := range p.loop { msg := loopTestMsg(uid) uc := upstream[uid] @@ -109,7 +110,7 @@ func (p *prog) checkDnsLoop() { if uc == nil { continue } - resolver, err := ctrld.NewResolver(uc) + resolver, err := ctrld.NewResolver(loggerCtx, uc) if err != nil { mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) continue diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 6a8cb62..53b8309 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -40,7 +40,7 @@ var ( cleanup bool startOnly bool - mainLog atomic.Pointer[zerolog.Logger] + mainLog atomic.Pointer[ctrld.Logger] consoleWriter zerolog.ConsoleWriter noConfigStart bool ) @@ -54,7 +54,7 @@ const ( func init() { l := zerolog.New(io.Discard) - mainLog.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) } func Main() { @@ -87,16 +87,14 @@ func initConsoleLogging() { }) multi := zerolog.MultiLevelWriter(consoleWriter) l := mainLog.Load().Output(multi).With().Timestamp().Logger() - mainLog.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) switch { case silent: zerolog.SetGlobalLevel(zerolog.NoLevel) case verbose == 1: - ctrld.ProxyLogger.Store(&l) zerolog.SetGlobalLevel(zerolog.InfoLevel) case verbose > 1: - ctrld.ProxyLogger.Store(&l) zerolog.SetGlobalLevel(zerolog.DebugLevel) default: zerolog.SetGlobalLevel(zerolog.NoticeLevel) @@ -113,8 +111,6 @@ func initInteractiveLogging() { zerolog.TimeFieldFormat = time.RFC3339 + ".000" initLoggingWithBackup(false) cfg.Service.LogPath = old - l := zerolog.New(io.Discard) - ctrld.ProxyLogger.Store(&l) } // initLoggingWithBackup initializes log setup base on current config. @@ -153,9 +149,7 @@ func initLoggingWithBackup(doBackup bool) []io.Writer { writers = append(writers, consoleWriter) multi := zerolog.MultiLevelWriter(writers...) l := mainLog.Load().Output(multi).With().Logger() - mainLog.Store(&l) - // TODO: find a better way. - ctrld.ProxyLogger.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) zerolog.SetGlobalLevel(zerolog.NoticeLevel) logLevel := cfg.Service.LogLevel diff --git a/cmd/cli/main_test.go b/cmd/cli/main_test.go index 6ed26c7..c7b8b17 100644 --- a/cmd/cli/main_test.go +++ b/cmd/cli/main_test.go @@ -6,12 +6,14 @@ import ( "testing" "github.com/rs/zerolog" + + "github.com/Control-D-Inc/ctrld" ) var logOutput strings.Builder func TestMain(m *testing.M) { l := zerolog.New(&logOutput) - mainLog.Store(&l) + mainLog.Store(&ctrld.Logger{Logger: &l}) os.Exit(m.Run()) } diff --git a/cmd/cli/netlink_linux.go b/cmd/cli/netlink_linux.go index d757f8b..f4e9bda 100644 --- a/cmd/cli/netlink_linux.go +++ b/cmd/cli/netlink_linux.go @@ -5,6 +5,8 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/sys/unix" + + "github.com/Control-D-Inc/ctrld" ) func (p *prog) watchLinkState(ctx context.Context) { @@ -26,7 +28,7 @@ func (p *prog) watchLinkState(ctx context.Context) { if lu.Change&unix.IFF_UP != 0 { mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") for _, uc := range p.cfg.Upstream { - uc.ReBootstrap() + uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load())) } } } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 3b159ee..d85c371 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -286,7 +286,7 @@ func (p *prog) postRun() { mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt) } p.resetDNS(false, false) - ns := ctrld.InitializeOsResolver(false) + ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), false) mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) p.setDNS() p.csSetDnsDone <- struct{}{} @@ -319,7 +319,8 @@ func (p *prog) apiConfigReload() { } doReloadApiConfig := func(forced bool, logger zerolog.Logger) { - resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + resolverConfig, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev) selfUninstallCheck(err, p, logger) if err != nil { logger.Warn().Err(err).Msg("could not fetch resolver config") @@ -377,7 +378,7 @@ func (p *prog) apiConfigReload() { } if cfgErr != nil { logger.Warn().Err(err).Msg("skipping invalid custom config") - if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil { + if _, err := controld.UpdateCustomLastFailed(loggerCtx, cdUID, rootCmd.Version, cdDev, true); err != nil { logger.Error().Err(err).Msg("could not mark custom last update failed") } return @@ -404,22 +405,23 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) { localUpstreams := make([]string, 0, len(cfg.Upstream)) ptrNameservers := make([]string, 0, len(cfg.Upstream)) isControlDUpstream := false + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) for n := range cfg.Upstream { uc := cfg.Upstream[n] sdns := uc.Type == ctrld.ResolverTypeSDNS - uc.Init() + uc.Init(loggerCtx) if sdns { mainLog.Load().Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type) } isControlDUpstream = isControlDUpstream || uc.IsControlD() if uc.BootstrapIP == "" { - uc.SetupBootstrapIP() + uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), mainLog.Load())) mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) } else { mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) } uc.SetCertPool(rootCertPool) - go uc.Ping() + go uc.Ping(loggerCtx) if canBeLocalUpstream(uc.Domain) { localUpstreams = append(localUpstreams, upstreamPrefix+n) @@ -601,7 +603,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { // setupClientInfoDiscover performs necessary works for running client info discover. func (p *prog) setupClientInfoDiscover(selfIP string) { - p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers) + p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, mainLog.Load()) if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) diff --git a/config.go b/config.go index 96f6686..4aadff1 100644 --- a/config.go +++ b/config.go @@ -325,12 +325,13 @@ type ListenerPolicyConfig struct { type Rule map[string][]string // Init initialized necessary values for an UpstreamConfig. -func (uc *UpstreamConfig) Init() { +func (uc *UpstreamConfig) Init(ctx context.Context) { + logger := LoggerFromCtx(ctx) if err := uc.initDnsStamps(); err != nil { - ProxyLogger.Load().Fatal().Err(err).Msg("invalid DNS Stamps") + logger.Fatal().Err(err).Msg("invalid DNS Stamps") } uc.initDoHScheme() - uc.uid = upstreamUID() + uc.uid = upstreamUID(ctx) if u, err := url.Parse(uc.Endpoint); err == nil { uc.Domain = u.Hostname() switch uc.Type { @@ -434,12 +435,13 @@ func (uc *UpstreamConfig) UID() string { // - ControlD Bootstrap DNS 76.76.2.22 // // The setup process will block until there's usable IPs found. -func (uc *UpstreamConfig) SetupBootstrapIP() { +func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) { b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second) isControlD := uc.IsControlD() - nss := initDefaultOsResolver() + logger := LoggerFromCtx(ctx) + nss := initDefaultOsResolver(ctx) for { - uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss) + uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, nss) // For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses, // filtering them out here to prevent weird behavior. if isControlD { @@ -454,18 +456,18 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { uc.bootstrapIPs = uc.bootstrapIPs[:n] if len(uc.bootstrapIPs) == 0 { uc.bootstrapIPs = bootstrapIPsFromControlDDomain(uc.Domain) - ProxyLogger.Load().Warn().Msgf("no record found for %q, lookup from direct IP table", uc.Domain) + logger.Warn().Msgf("no record found for %q, lookup from direct IP table", uc.Domain) } } if len(uc.bootstrapIPs) == 0 { - ProxyLogger.Load().Warn().Msgf("no record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP) - uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")}) + logger.Warn().Msgf("no record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP) + uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")}) } if len(uc.bootstrapIPs) > 0 { break } - ProxyLogger.Load().Warn().Msg("could not resolve bootstrap IPs, retrying...") + logger.Warn().Msg("could not resolve bootstrap IPs, retrying...") b.BackOff(context.Background(), errors.New("no bootstrap IPs")) } for _, ip := range uc.bootstrapIPs { @@ -475,11 +477,11 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip) } } - ProxyLogger.Load().Debug().Msgf("bootstrap IPs: %v", uc.bootstrapIPs) + logger.Debug().Msgf("bootstrap IPs: %v", uc.bootstrapIPs) } // ReBootstrap re-setup the bootstrap IP and the transport. -func (uc *UpstreamConfig) ReBootstrap() { +func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: default: @@ -487,7 +489,8 @@ func (uc *UpstreamConfig) ReBootstrap() { } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { if uc.rebootstrap.CompareAndSwap(false, true) { - ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc) + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("re-bootstrapping upstream ip for %v", uc) } return true, nil }) @@ -495,35 +498,35 @@ func (uc *UpstreamConfig) ReBootstrap() { // SetupTransport initializes the network transport used to connect to upstream server. // For now, only DoH upstream is supported. -func (uc *UpstreamConfig) SetupTransport() { +func (uc *UpstreamConfig) SetupTransport(ctx context.Context) { switch uc.Type { case ResolverTypeDOH: - uc.setupDOHTransport() + uc.setupDOHTransport(ctx) case ResolverTypeDOH3: - uc.setupDOH3Transport() + uc.setupDOH3Transport(ctx) } } -func (uc *UpstreamConfig) setupDOHTransport() { +func (uc *UpstreamConfig) setupDOHTransport(ctx context.Context) { switch uc.IPStack { case IpStackBoth, "": - uc.transport = uc.newDOHTransport(uc.bootstrapIPs) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs) case IpStackV4: - uc.transport = uc.newDOHTransport(uc.bootstrapIPs4) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs4) case IpStackV6: - uc.transport = uc.newDOHTransport(uc.bootstrapIPs6) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs6) case IpStackSplit: - uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4) - if HasIPv6() { - uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6) + uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4) + if HasIPv6(ctx) { + uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6) } else { uc.transport6 = uc.transport4 } - uc.transport = uc.newDOHTransport(uc.bootstrapIPs) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs) } } -func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { +func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() transport.MaxIdleConnsPerHost = 100 transport.TLSClientConfig = &tls.Config{ @@ -543,12 +546,13 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { dialerTimeoutMs = uc.Timeout } dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond + logger := LoggerFromCtx(ctx) transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { _, port, _ := net.SplitHostPort(addr) if uc.BootstrapIP != "" { dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout} addr := net.JoinHostPort(uc.BootstrapIP, port) - Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", addr) + logger.Debug().Msgf("sending doh request to: %s", addr) return dialer.DialContext(ctx, network, addr) } pd := &ctrldnet.ParallelDialer{} @@ -558,11 +562,11 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { for i := range addrs { dialAddrs[i] = net.JoinHostPort(addrs[i], port) } - conn, err := pd.DialContext(ctx, network, dialAddrs, ProxyLogger.Load()) + conn, err := pd.DialContext(ctx, network, dialAddrs, logger.Logger) if err != nil { return nil, err } - Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", conn.RemoteAddr()) + logger.Debug().Msgf("sending doh request to: %s", conn.RemoteAddr()) return conn, nil } runtime.SetFinalizer(transport, func(transport *http.Transport) { @@ -572,19 +576,20 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { } // Ping warms up the connection to DoH/DoH3 upstream. -func (uc *UpstreamConfig) Ping() { - if err := uc.ping(); err != nil { - ProxyLogger.Load().Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint) - _ = uc.FallbackToDirectIP() +func (uc *UpstreamConfig) Ping(ctx context.Context) { + if err := uc.ping(ctx); err != nil { + logger := LoggerFromCtx(ctx) + logger.Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint) + _ = uc.FallbackToDirectIP(ctx) } } // ErrorPing is like Ping, but return an error if any. -func (uc *UpstreamConfig) ErrorPing() error { - return uc.ping() +func (uc *UpstreamConfig) ErrorPing(ctx context.Context) error { + return uc.ping(ctx) } -func (uc *UpstreamConfig) ping() error { +func (uc *UpstreamConfig) ping(ctx context.Context) error { switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: default: @@ -613,11 +618,11 @@ func (uc *UpstreamConfig) ping() error { for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} { switch uc.Type { case ResolverTypeDOH: - if err := ping(uc.dohTransport(typ)); err != nil { + if err := ping(uc.dohTransport(ctx, typ)); err != nil { return err } case ResolverTypeDOH3: - if err := ping(uc.doh3Transport(typ)); err != nil { + if err := ping(uc.doh3Transport(ctx, typ)); err != nil { return err } } @@ -652,12 +657,12 @@ func (uc *UpstreamConfig) isNextDNS() bool { return domain == "dns.nextdns.io" } -func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { +func (uc *UpstreamConfig) dohTransport(ctx context.Context, dnsType uint16) http.RoundTripper { uc.transportOnce.Do(func() { - uc.SetupTransport() + uc.SetupTransport(ctx) }) if uc.rebootstrap.CompareAndSwap(true, false) { - uc.SetupTransport() + uc.SetupTransport(ctx) } switch uc.IPStack { case IpStackBoth, IpStackV4, IpStackV6: @@ -673,7 +678,7 @@ func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { return uc.transport } -func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { +func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uint16) string { switch uc.IPStack { case IpStackBoth: return pick(uc.bootstrapIPs) @@ -686,7 +691,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { case dns.TypeA: return pick(uc.bootstrapIPs4) default: - if HasIPv6() { + if HasIPv6(ctx) { return pick(uc.bootstrapIPs6) } return pick(uc.bootstrapIPs4) @@ -695,7 +700,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { return pick(uc.bootstrapIPs) } -func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { +func (uc *UpstreamConfig) netForDNSType(ctx context.Context, dnsType uint16) (string, string) { switch uc.IPStack { case IpStackBoth: return "tcp-tls", "udp" @@ -708,7 +713,7 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { case dns.TypeA: return "tcp4-tls", "udp4" default: - if HasIPv6() { + if HasIPv6(ctx) { return "tcp6-tls", "udp6" } return "tcp4-tls", "udp4" @@ -789,7 +794,7 @@ func (uc *UpstreamConfig) Context(ctx context.Context) (context.Context, context } // FallbackToDirectIP changes ControlD upstream endpoint to use direct IP instead of domain. -func (uc *UpstreamConfig) FallbackToDirectIP() bool { +func (uc *UpstreamConfig) FallbackToDirectIP(ctx context.Context) bool { if !uc.IsControlD() { return false } @@ -808,7 +813,8 @@ func (uc *UpstreamConfig) FallbackToDirectIP() bool { default: return } - ProxyLogger.Load().Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip) + logger := LoggerFromCtx(ctx) + logger.Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip) uc.u.Host = ip done = true }) @@ -942,11 +948,12 @@ func pick(s []string) string { } // upstreamUID generates an unique identifier for an upstream. -func upstreamUID() string { +func upstreamUID(ctx context.Context) string { + logger := LoggerFromCtx(ctx) b := make([]byte, 4) for { if _, err := crand.Read(b); err != nil { - ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...") + logger.Warn().Err(err).Msg("could not generate uid for upstream, retrying...") continue } return hex.EncodeToString(b) diff --git a/config_internal_test.go b/config_internal_test.go index b37e982..0e7f3bb 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -1,6 +1,7 @@ package ctrld import ( + "context" "net/url" "testing" @@ -36,10 +37,10 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Enable parallel tests once https://github.com/microsoft/wmi/issues/165 fixed. // t.Parallel() - tc.uc.Init() - tc.uc.SetupBootstrapIP() + tc.uc.Init(context.Background()) + tc.uc.SetupBootstrapIP(context.Background()) if len(tc.uc.bootstrapIPs) == 0 { - t.Log(defaultNameservers()) + t.Log(defaultNameservers(context.Background())) t.Fatalf("could not bootstrap ip: %s", tc.uc.String()) } }) @@ -355,7 +356,7 @@ func TestUpstreamConfig_Init(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - tc.uc.Init() + tc.uc.Init(context.Background()) tc.uc.uid = "" // we don't care about the uid. assert.Equal(t, tc.expected, tc.uc) }) @@ -497,7 +498,7 @@ func TestUpstreamConfig_IsDiscoverable(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - tc.uc.Init() + tc.uc.Init(context.Background()) if got := tc.uc.IsDiscoverable(); got != tc.discoverable { t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got) } diff --git a/config_quic.go b/config_quic.go index cadcb6b..8f27bf3 100644 --- a/config_quic.go +++ b/config_quic.go @@ -14,34 +14,35 @@ import ( "github.com/quic-go/quic-go/http3" ) -func (uc *UpstreamConfig) setupDOH3Transport() { +func (uc *UpstreamConfig) setupDOH3Transport(ctx context.Context) { switch uc.IPStack { case IpStackBoth, "": - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs) case IpStackV4: - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs4) case IpStackV6: - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs6) case IpStackSplit: - uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4) - if HasIPv6() { - uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6) + uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4) + if HasIPv6(ctx) { + uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6) } else { uc.http3RoundTripper6 = uc.http3RoundTripper4 } - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs) } } -func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { +func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper { rt := &http3.Transport{} rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} + logger := LoggerFromCtx(ctx) rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { _, port, _ := net.SplitHostPort(addr) // if we have a bootstrap ip set, use it to avoid DNS lookup if uc.BootstrapIP != "" { addr = net.JoinHostPort(uc.BootstrapIP, port) - ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", addr) + logger.Debug().Msgf("sending doh3 request to: %s", addr) udpConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, err @@ -61,7 +62,7 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { if err != nil { return nil, err } - ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr()) + logger.Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr()) return conn, err } runtime.SetFinalizer(rt, func(rt *http3.Transport) { @@ -70,12 +71,12 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { return rt } -func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { +func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper { uc.transportOnce.Do(func() { - uc.SetupTransport() + uc.SetupTransport(ctx) }) if uc.rebootstrap.CompareAndSwap(true, false) { - uc.SetupTransport() + uc.SetupTransport(ctx) } switch uc.IPStack { case IpStackBoth, IpStackV4, IpStackV6: diff --git a/doh.go b/doh.go index 3459cb8..f93dc88 100644 --- a/doh.go +++ b/doh.go @@ -105,19 +105,20 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - c := http.Client{Transport: r.uc.dohTransport(dnsTyp)} + c := http.Client{Transport: r.uc.dohTransport(ctx, dnsTyp)} if r.isDoH3 { - transport := r.uc.doh3Transport(dnsTyp) + transport := r.uc.doh3Transport(ctx, dnsTyp) if transport == nil { return nil, errors.New("DoH3 is not supported") } c.Transport = transport } resp, err := c.Do(req) - if err != nil && r.uc.FallbackToDirectIP() { + if err != nil && r.uc.FallbackToDirectIP(ctx) { retryCtx, cancel := r.uc.Context(context.WithoutCancel(ctx)) defer cancel() - Log(ctx, ProxyLogger.Load().Warn().Err(err), "retrying request after fallback to direct ip") + logger := LoggerFromCtx(ctx) + logger.Warn().Err(err).Msg("retrying request after fallback to direct ip") resp, err = c.Do(req.Clone(retryCtx)) } if err != nil { @@ -163,7 +164,8 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) { } } if printed { - Log(ctx, ProxyLogger.Load().Debug(), "sending request header: %v", dohHeader) + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("sending request header: %v", dohHeader) } dohHeader.Set("Content-Type", headerApplicationDNS) dohHeader.Set("Accept", headerApplicationDNS) diff --git a/doh_test.go b/doh_test.go index 92fa79f..700b299 100644 --- a/doh_test.go +++ b/doh_test.go @@ -157,20 +157,21 @@ func Test_ClientCertificateVerificationError(t *testing.T) { }, } + ctx := context.Background() for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - tc.uc.Init() - tc.uc.SetupBootstrapIP() - r, err := NewResolver(tc.uc) + tc.uc.Init(ctx) + tc.uc.SetupBootstrapIP(ctx) + r, err := NewResolver(ctx, tc.uc) if err != nil { t.Fatal(err) } msg := new(dns.Msg) msg.SetQuestion("verify.controld.com.", dns.TypeA) msg.RecursionDesired = true - _, err = r.Resolve(context.Background(), msg) + _, err = r.Resolve(ctx, msg) // Verify the error contains the expected certificate information if err == nil { t.Fatal("expected certificate verification error, got nil") diff --git a/doq.go b/doq.go index 0903411..d341668 100644 --- a/doq.go +++ b/doq.go @@ -26,7 +26,7 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - ip = r.uc.bootstrapIPForDNSType(dnsTyp) + ip = r.uc.bootstrapIPForDNSType(ctx, dnsTyp) } tlsConfig.ServerName = r.uc.Domain _, port, _ := net.SplitHostPort(endpoint) diff --git a/dot.go b/dot.go index 295134c..03c08db 100644 --- a/dot.go +++ b/dot.go @@ -23,7 +23,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - tcpNet, _ := r.uc.netForDNSType(dnsTyp) + tcpNet, _ := r.uc.netForDNSType(ctx, dnsTyp) dnsClient := &dns.Client{ Net: tcpNet, Dialer: dialer, diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index f69b670..719e205 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -79,6 +79,7 @@ type Table struct { initOnce sync.Once stopOnce sync.Once refreshInterval int + logger *ctrld.Logger dhcp *dhcp merlin *merlinDiscover @@ -98,11 +99,14 @@ type Table struct { ptrNameservers []string } -func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { +func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string, logger *ctrld.Logger) *Table { refreshInterval := cfg.Service.DiscoverRefreshInterval if refreshInterval <= 0 { refreshInterval = 2 * 60 // 2 minutes } + if logger == nil { + logger = ctrld.NopLogger + } return &Table{ svcCfg: cfg.Service, quitCh: make(chan struct{}), @@ -111,6 +115,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { cdUID: cdUID, ptrNameservers: ns, refreshInterval: refreshInterval, + logger: logger, } } @@ -179,7 +184,7 @@ func (t *Table) SetSelfIP(ip string) { // initSelfDiscover initializes necessary client metadata for self query. func (t *Table) initSelfDiscover() { - t.dhcp = &dhcp{selfIP: t.selfIP} + t.dhcp = &dhcp{selfIP: t.selfIP, logger: t.logger} t.dhcp.addSelf() t.ipResolvers = append(t.ipResolvers, t.dhcp) t.macResolvers = append(t.macResolvers, t.dhcp) @@ -189,14 +194,14 @@ func (t *Table) initSelfDiscover() { func (t *Table) init() { // Custom client ID presents, use it as the only source. if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { - ctrld.ProxyLogger.Load().Debug().Msg("start self discovery with custom client id") + t.logger.Debug().Msg("start self discovery with custom client id") t.initSelfDiscover() return } // If we are running on platforms that should only do self discover, use it as the only source, too. if ctrld.SelfDiscover() { - ctrld.ProxyLogger.Load().Debug().Msg("start self discovery on desktop platforms") + t.logger.Debug().Msg("start self discovery on desktop platforms") t.initSelfDiscover() return } @@ -208,7 +213,7 @@ func (t *Table) init() { // - Merlin // - Ubios if t.discoverDHCP() || t.discoverARP() { - t.merlin = &merlinDiscover{} + t.merlin = &merlinDiscover{logger: t.logger} t.ubios = &ubiosDiscover{} discovers := map[string]interface { refresher @@ -219,7 +224,7 @@ func (t *Table) init() { } for platform, discover := range discovers { if err := discover.refresh(); err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to init %s discover", platform) + t.logger.Warn().Err(err).Msgf("failed to init %s discover", platform) } t.hostnameResolvers = append(t.hostnameResolvers, discover) t.refreshers = append(t.refreshers, discover) @@ -227,10 +232,10 @@ func (t *Table) init() { } // Hosts file mapping. if t.discoverHosts() { - t.hf = &hostsFile{} - ctrld.ProxyLogger.Load().Debug().Msg("start hosts file discovery") + t.hf = &hostsFile{logger: t.logger} + t.logger.Debug().Msg("start hosts file discovery") if err := t.hf.init(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init hosts file discover") + t.logger.Error().Err(err).Msg("could not init hosts file discover") } else { t.hostnameResolvers = append(t.hostnameResolvers, t.hf) t.refreshers = append(t.refreshers, t.hf) @@ -239,10 +244,10 @@ func (t *Table) init() { } // DHCP lease files. if t.discoverDHCP() { - t.dhcp = &dhcp{selfIP: t.selfIP} - ctrld.ProxyLogger.Load().Debug().Msg("start dhcp discovery") + t.dhcp = &dhcp{selfIP: t.selfIP, logger: t.logger} + t.logger.Debug().Msg("start dhcp discovery") if err := t.dhcp.init(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init DHCP discover") + t.logger.Error().Err(err).Msg("could not init DHCP discover") } else { t.ipResolvers = append(t.ipResolvers, t.dhcp) t.macResolvers = append(t.macResolvers, t.dhcp) @@ -253,8 +258,8 @@ func (t *Table) init() { // ARP/NDP table. if t.discoverARP() { t.arp = &arpDiscover{} - t.ndp = &ndpDiscover{} - ctrld.ProxyLogger.Load().Debug().Msg("start arp discovery") + t.ndp = &ndpDiscover{logger: t.logger} + t.logger.Debug().Msg("start arp discovery") discovers := map[string]interface { refresher IpResolver @@ -266,7 +271,7 @@ func (t *Table) init() { for protocol, discover := range discovers { if err := discover.refresh(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msgf("could not init %s discover", protocol) + t.logger.Error().Err(err).Msgf("could not init %s discover", protocol) } else { t.ipResolvers = append(t.ipResolvers, discover) t.macResolvers = append(t.macResolvers, discover) @@ -283,7 +288,10 @@ func (t *Table) init() { } // PTR lookup. if t.discoverPTR() { - t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} + t.ptr = &ptrDiscover{ + resolver: ctrld.NewPrivateResolver(context.Background()), + logger: t.logger, + } if len(t.ptrNameservers) > 0 { nss := make([]string, 0, len(t.ptrNameservers)) for _, ns := range t.ptrNameservers { @@ -295,18 +303,18 @@ func (t *Table) init() { if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { nss = append(nss, net.JoinHostPort(host, port)) } else { - ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) + t.logger.Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) } } if len(nss) > 0 { t.ptr.resolver = ctrld.NewResolverWithNameserver(nss) - ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss) + t.logger.Debug().Msgf("using nameservers %v for ptr discovery", nss) } } - ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") + t.logger.Debug().Msg("start ptr discovery") if err := t.ptr.refresh(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init PTR discover") + t.logger.Error().Err(err).Msg("could not init PTR discover") } else { t.hostnameResolvers = append(t.hostnameResolvers, t.ptr) t.refreshers = append(t.refreshers, t.ptr) @@ -314,10 +322,10 @@ func (t *Table) init() { } // mdns. if t.discoverMDNS() { - t.mdns = &mdns{} - ctrld.ProxyLogger.Load().Debug().Msg("start mdns discovery") + t.mdns = &mdns{logger: t.logger} + t.logger.Debug().Msg("start mdns discovery") if err := t.mdns.init(t.quitCh); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init mDNS discover") + t.logger.Error().Err(err).Msg("could not init mDNS discover") } else { t.hostnameResolvers = append(t.hostnameResolvers, t.mdns) } diff --git a/internal/clientinfo/client_info_test.go b/internal/clientinfo/client_info_test.go index b5bdfa5..7abb907 100644 --- a/internal/clientinfo/client_info_test.go +++ b/internal/clientinfo/client_info_test.go @@ -2,6 +2,8 @@ package clientinfo import ( "testing" + + "github.com/Control-D-Inc/ctrld" ) func Test_normalizeIP(t *testing.T) { @@ -28,8 +30,9 @@ func Test_normalizeIP(t *testing.T) { func TestTable_LookupRFC1918IPv4(t *testing.T) { table := &Table{ - dhcp: &dhcp{}, - arp: &arpDiscover{}, + dhcp: &dhcp{}, + arp: &arpDiscover{}, + logger: ctrld.NopLogger, } table.ipResolvers = append(table.ipResolvers, table.dhcp) diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index 5d11d5e..fbd7b08 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -13,9 +13,8 @@ import ( "strings" "sync" - "tailscale.com/net/netmon" - "github.com/fsnotify/fsnotify" + "tailscale.com/net/netmon" "tailscale.com/util/lineread" "github.com/Control-D-Inc/ctrld" @@ -30,6 +29,7 @@ type dhcp struct { watcher *fsnotify.Watcher selfIP string + logger *ctrld.Logger } func (d *dhcp) init() error { @@ -52,7 +52,7 @@ func (d *dhcp) watchChanges() { } if dir := router.LeaseFilesDir(); dir != "" { if err := d.watcher.Add(dir); err != nil { - ctrld.ProxyLogger.Load().Err(err).Str("dir", dir).Msg("could not watch lease dir") + d.logger.Err(err).Str("dir", dir).Msg("could not watch lease dir") } } for { @@ -64,7 +64,7 @@ func (d *dhcp) watchChanges() { if event.Has(fsnotify.Create) { if format, ok := clientInfoFiles[event.Name]; ok { if err := d.addLeaseFile(event.Name, format); err != nil { - ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("could not add lease file") + d.logger.Err(err).Str("file", event.Name).Msg("could not add lease file") } } continue @@ -72,14 +72,14 @@ func (d *dhcp) watchChanges() { if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) { format := clientInfoFiles[event.Name] if err := d.readLeaseFile(event.Name, format); err != nil && !os.IsNotExist(err) { - ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("leases file changed but failed to update client info") + d.logger.Err(err).Str("file", event.Name).Msg("leases file changed but failed to update client info") } } case err, ok := <-d.watcher.Errors: if !ok { return } - ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file") + d.logger.Err(err).Msg("could not watch client info file") } } @@ -222,7 +222,7 @@ func (d *dhcp) dnsmasqReadClientInfoReader(reader io.Reader) error { } ip := normalizeIP(string(fields[2])) if net.ParseIP(ip) == nil { - ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip) + d.logger.Warn().Msgf("invalid ip address entry: %q", ip) ip = "" } @@ -275,7 +275,7 @@ func (d *dhcp) iscDHCPReadClientInfoReader(reader io.Reader) error { case "lease": ip = normalizeIP(strings.ToLower(fields[1])) if net.ParseIP(ip) == nil { - ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip) + d.logger.Warn().Msgf("invalid ip address entry: %q", ip) ip = "" } case "hardware": @@ -328,7 +328,7 @@ func (d *dhcp) keaDhcp4ReadClientInfoReader(r io.Reader) error { } ip := normalizeIP(record[0]) if net.ParseIP(ip) == nil { - ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip) + d.logger.Warn().Msgf("invalid ip address entry: %q", ip) ip = "" } @@ -350,7 +350,7 @@ func (d *dhcp) keaDhcp4ReadClientInfoReader(r io.Reader) error { func (d *dhcp) addSelf() { hostname, err := os.Hostname() if err != nil { - ctrld.ProxyLogger.Load().Err(err).Msg("could not get hostname") + d.logger.Err(err).Msg("could not get hostname") return } hostname = normalizeHostname(hostname) diff --git a/internal/clientinfo/hostsfile.go b/internal/clientinfo/hostsfile.go index d96229d..4dc6f35 100644 --- a/internal/clientinfo/hostsfile.go +++ b/internal/clientinfo/hostsfile.go @@ -27,6 +27,7 @@ type hostsFile struct { watcher *fsnotify.Watcher mu sync.Mutex m map[string][]string + logger *ctrld.Logger } // init performs initialization works, which is necessary before hostsFile can be fully operated. @@ -55,7 +56,7 @@ func (hf *hostsFile) refresh() error { // override hosts file with host_entries.conf content if present. hem, err := parseHostEntriesConf(hostEntriesConfPath) if err != nil && !os.IsNotExist(err) { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not read host_entries.conf file") + hf.logger.Debug().Err(err).Msg("could not read host_entries.conf file") } for k, v := range hem { hf.m[k] = v @@ -77,14 +78,14 @@ func (hf *hostsFile) watchChanges() { } if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) { if err := hf.refresh(); err != nil && !os.IsNotExist(err) { - ctrld.ProxyLogger.Load().Err(err).Msg("hosts file changed but failed to update client info") + hf.logger.Err(err).Msg("hosts file changed but failed to update client info") } } case err, ok := <-hf.watcher.Errors: if !ok { return } - ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file") + hf.logger.Err(err).Msg("could not watch client info file") } } diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index e009e01..ebdfabc 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -34,7 +34,8 @@ var ( ) type mdns struct { - name sync.Map // ip => hostname + name sync.Map // ip => hostname + logger *ctrld.Logger } func (m *mdns) LookupHostnameByIP(ip string) string { @@ -93,9 +94,9 @@ func (m *mdns) init(quitCh chan struct{}) error { } // Check if IPv6 is available once and use the result for the rest of the function. - ctrld.ProxyLogger.Load().Debug().Msgf("checking for IPv6 availability in mdns init") + m.logger.Debug().Msgf("checking for IPv6 availability in mdns init") ipv6 := ctrldnet.IPv6Available(context.Background()) - ctrld.ProxyLogger.Load().Debug().Msgf("IPv6 is %v in mdns init", ipv6) + m.logger.Debug().Msgf("IPv6 is %v in mdns init", ipv6) v4ConnList := make([]*net.UDPConn, 0, len(ifaces)) v6ConnList := make([]*net.UDPConn, 0, len(ifaces)) @@ -129,11 +130,11 @@ func (m *mdns) probeLoop(conns []*net.UDPConn, remoteAddr net.Addr, quitCh chan for { err := m.probe(conns, remoteAddr) if shouldStopProbing(err) { - ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: %v", remoteAddr, err) + m.logger.Warn().Msgf("stop probing %q: %v", remoteAddr, err) break } if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("error while probing mdns") + m.logger.Warn().Err(err).Msg("error while probing mdns") bo.BackOff(context.Background(), errors.New("mdns probe backoff")) continue } @@ -161,7 +162,7 @@ func (m *mdns) readLoop(conn *net.UDPConn) { if errors.Is(err, net.ErrClosed) { return } - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("mdns readLoop error") + m.logger.Debug().Err(err).Msg("mdns readLoop error") return } @@ -184,11 +185,11 @@ func (m *mdns) readLoop(conn *net.UDPConn) { if ip != "" && name != "" { name = normalizeHostname(name) if val, loaded := m.name.LoadOrStore(ip, name); !loaded { - ctrld.ProxyLogger.Load().Debug().Msgf("found hostname: %q, ip: %q via mdns", name, ip) + m.logger.Debug().Msgf("found hostname: %q, ip: %q via mdns", name, ip) } else { old := val.(string) if old != name { - ctrld.ProxyLogger.Load().Debug().Msgf("update hostname: %q, ip: %q, old: %q via mdns", name, ip, old) + m.logger.Debug().Msgf("update hostname: %q, ip: %q, old: %q via mdns", name, ip, old) m.name.Store(ip, name) } } @@ -227,7 +228,7 @@ func (m *mdns) probe(conns []*net.UDPConn, remoteAddr net.Addr) error { // getDataFromAvahiDaemonCache reads entries from avahi-daemon cache to update mdns data. func (m *mdns) getDataFromAvahiDaemonCache() { if _, err := exec.LookPath("avahi-browse"); err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not find avahi-browse binary, skipping.") + m.logger.Debug().Err(err).Msg("could not find avahi-browse binary, skipping.") return } // Run avahi-browse to discover services from cache: @@ -237,7 +238,7 @@ func (m *mdns) getDataFromAvahiDaemonCache() { // - "-c" -> read from cache. out, err := exec.Command("avahi-browse", "-a", "-r", "-p", "-c").Output() if err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not browse services from avahi cache") + m.logger.Debug().Err(err).Msg("could not browse services from avahi cache") return } m.storeDataFromAvahiBrowseOutput(bytes.NewReader(out)) @@ -257,7 +258,7 @@ func (m *mdns) storeDataFromAvahiBrowseOutput(r io.Reader) { name := normalizeHostname(fields[6]) // Only using cache value if we don't have existed one. if _, loaded := m.name.LoadOrStore(ip, name); !loaded { - ctrld.ProxyLogger.Load().Debug().Msgf("found hostname: %q, ip: %q via avahi cache", name, ip) + m.logger.Debug().Msgf("found hostname: %q, ip: %q via avahi cache", name, ip) } } } diff --git a/internal/clientinfo/mdns_test.go b/internal/clientinfo/mdns_test.go index e6f8698..28c23d9 100644 --- a/internal/clientinfo/mdns_test.go +++ b/internal/clientinfo/mdns_test.go @@ -3,6 +3,8 @@ package clientinfo import ( "strings" "testing" + + "github.com/Control-D-Inc/ctrld" ) func Test_mdns_storeDataFromAvahiBrowseOutput(t *testing.T) { @@ -11,7 +13,7 @@ func Test_mdns_storeDataFromAvahiBrowseOutput(t *testing.T) { =;wlp0s20f3;IPv6;Foo\032\0402\041;_companion-link._tcp;local;Foo-2.local;192.168.1.123;64842;"rpBA=00:00:00:00:00:01" "rpHI=e6ae2cbbca0e" "rpAD=36566f4d850f" "rpVr=510.71.1" "rpHA=0ddc20fdddc8" "rpFl=0x30000" "rpHN=1d4a03afdefa" "rpMac=0" =;wlp0s20f3;IPv4;Foo\032\0402\041;_companion-link._tcp;local;Foo-2.local;192.168.1.123;64842;"rpBA=00:00:00:00:00:01" "rpHI=e6ae2cbbca0e" "rpAD=36566f4d850f" "rpVr=510.71.1" "rpHA=0ddc20fdddc8" "rpFl=0x30000" "rpHN=1d4a03afdefa" "rpMac=0" ` - m := &mdns{} + m := &mdns{logger: ctrld.NopLogger} m.storeDataFromAvahiBrowseOutput(strings.NewReader(content)) ip := "192.168.1.123" val, loaded := m.name.LoadOrStore(ip, "") diff --git a/internal/clientinfo/merlin.go b/internal/clientinfo/merlin.go index 8a39398..8ba6c5c 100644 --- a/internal/clientinfo/merlin.go +++ b/internal/clientinfo/merlin.go @@ -15,6 +15,7 @@ const merlinNvramCustomClientListKey = "custom_clientlist" type merlinDiscover struct { hostname sync.Map // mac => hostname + logger *ctrld.Logger } func (m *merlinDiscover) refresh() error { @@ -25,7 +26,7 @@ func (m *merlinDiscover) refresh() error { if err != nil { return err } - ctrld.ProxyLogger.Load().Debug().Msg("reading Merlin custom client list") + m.logger.Debug().Msg("reading Merlin custom client list") m.parseMerlinCustomClientList(out) return nil } diff --git a/internal/clientinfo/ndp.go b/internal/clientinfo/ndp.go index 9d9155d..87f86fe 100644 --- a/internal/clientinfo/ndp.go +++ b/internal/clientinfo/ndp.go @@ -20,8 +20,9 @@ import ( // ndpDiscover provides client discovery functionality using NDP protocol. type ndpDiscover struct { - mac sync.Map // ip => mac - ip sync.Map // mac => ip + mac sync.Map // ip => mac + ip sync.Map // mac => ip + logger *ctrld.Logger } // refresh re-scans the NDP table. @@ -97,7 +98,7 @@ func (nd *ndpDiscover) saveInfo(ip, mac string) { func (nd *ndpDiscover) listen(ctx context.Context) { ifis, err := allInterfacesWithV6LinkLocal() if err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("failed to find valid ipv6 interfaces") + nd.logger.Debug().Err(err).Msg("failed to find valid ipv6 interfaces") return } for _, ifi := range ifis { @@ -110,11 +111,11 @@ func (nd *ndpDiscover) listen(ctx context.Context) { func (nd *ndpDiscover) listenOnInterface(ctx context.Context, ifi *net.Interface) { c, ip, err := ndp.Listen(ifi, ndp.Unspecified) if err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("ndp listen failed") + nd.logger.Debug().Err(err).Msg("ndp listen failed") return } defer c.Close() - ctrld.ProxyLogger.Load().Debug().Msgf("listening ndp on: %s", ip.String()) + nd.logger.Debug().Msgf("listening ndp on: %s", ip.String()) for { select { case <-ctx.Done(): @@ -128,7 +129,7 @@ func (nd *ndpDiscover) listenOnInterface(ctx context.Context, ifi *net.Interface if errors.As(readErr, &opErr) && (opErr.Timeout() || opErr.Temporary()) { continue } - ctrld.ProxyLogger.Load().Debug().Err(readErr).Msg("ndp read loop error") + nd.logger.Debug().Err(readErr).Msg("ndp read loop error") return } diff --git a/internal/clientinfo/ndp_linux.go b/internal/clientinfo/ndp_linux.go index ebd416f..6658c78 100644 --- a/internal/clientinfo/ndp_linux.go +++ b/internal/clientinfo/ndp_linux.go @@ -5,15 +5,13 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/sys/unix" - - "github.com/Control-D-Inc/ctrld" ) // scan populates NDP table using information from system mappings. func (nd *ndpDiscover) scan() { neighs, err := netlink.NeighList(0, netlink.FAMILY_V6) if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not get neigh list") + nd.logger.Warn().Err(err).Msg("could not get neigh list") return } @@ -34,7 +32,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) { done := make(chan struct{}) defer close(done) if err := netlink.NeighSubscribe(ch, done); err != nil { - ctrld.ProxyLogger.Load().Err(err).Msg("could not perform neighbor subscribing") + nd.logger.Err(err).Msg("could not perform neighbor subscribing") return } for { @@ -47,7 +45,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) { } ip := normalizeIP(nu.IP.String()) if nu.Type == unix.RTM_DELNEIGH { - ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor: %s", ip) + nd.logger.Debug().Msgf("removing NDP neighbor: %s", ip) nd.mac.Delete(ip) continue } @@ -56,7 +54,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) { case netlink.NUD_REACHABLE: nd.saveInfo(ip, mac) case netlink.NUD_FAILED: - ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor with failed state: %s", ip) + nd.logger.Debug().Msgf("removing NDP neighbor with failed state: %s", ip) nd.mac.Delete(ip) } } diff --git a/internal/clientinfo/ndp_others.go b/internal/clientinfo/ndp_others.go index 007407b..33e95a5 100644 --- a/internal/clientinfo/ndp_others.go +++ b/internal/clientinfo/ndp_others.go @@ -7,8 +7,6 @@ import ( "context" "os/exec" "runtime" - - "github.com/Control-D-Inc/ctrld" ) // scan populates NDP table using information from system mappings. @@ -17,14 +15,14 @@ func (nd *ndpDiscover) scan() { case "windows": data, err := exec.Command("netsh", "interface", "ipv6", "show", "neighbors").Output() if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not query ndp table") + nd.logger.Warn().Err(err).Msg("could not query ndp table") return } nd.scanWindows(bytes.NewReader(data)) default: data, err := exec.Command("ndp", "-an").Output() if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not query ndp table") + nd.logger.Warn().Err(err).Msg("could not query ndp table") return } nd.scanUnix(bytes.NewReader(data)) diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index 8e6b3f7..b4783bd 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -17,6 +17,7 @@ type ptrDiscover struct { hostname sync.Map // ip => hostname resolver ctrld.Resolver serverDown atomic.Bool + logger *ctrld.Logger } func (p *ptrDiscover) refresh() error { @@ -73,14 +74,14 @@ func (p *ptrDiscover) lookupHostname(ip string) string { msg := new(dns.Msg) addr, err := dns.ReverseAddr(ip) if err != nil { - ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address") + p.logger.Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address") return "" } msg.SetQuestion(addr, dns.TypePTR) ans, err := p.resolver.Resolve(ctx, msg) if err != nil { if p.serverDown.CompareAndSwap(false, true) { - ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup") + p.logger.Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup") go p.checkServer() } return "" diff --git a/internal/controld/config.go b/internal/controld/config.go index 595e758..97ec8e2 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -88,18 +88,18 @@ type LogsRequest struct { } // FetchResolverConfig fetch Control D config for given uid. -func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) { +func FetchResolverConfig(ctx context.Context, rawUID, version string, cdDev bool) (*ResolverConfig, error) { uid, clientID := ParseRawUID(rawUID) req := utilityRequest{UID: uid} if clientID != "" { req.ClientID = clientID } body, _ := json.Marshal(req) - return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) + return postUtilityAPI(ctx, version, cdDev, false, bytes.NewReader(body)) } // FetchResolverUID fetch resolver uid from provision token. -func FetchResolverUID(req *UtilityOrgRequest, version string, cdDev bool) (*ResolverConfig, error) { +func FetchResolverUID(ctx context.Context, req *UtilityOrgRequest, version string, cdDev bool) (*ResolverConfig, error) { if req == nil { return nil, errors.New("invalid request") } @@ -108,21 +108,21 @@ func FetchResolverUID(req *UtilityOrgRequest, version string, cdDev bool) (*Reso hostname, _ = os.Hostname() } body, _ := json.Marshal(UtilityOrgRequest{ProvToken: req.ProvToken, Hostname: hostname}) - return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) + return postUtilityAPI(ctx, version, cdDev, false, bytes.NewReader(body)) } // UpdateCustomLastFailed calls API to mark custom config is bad. -func UpdateCustomLastFailed(rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) { +func UpdateCustomLastFailed(ctx context.Context, rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) { uid, clientID := ParseRawUID(rawUID) req := utilityRequest{UID: uid} if clientID != "" { req.ClientID = clientID } body, _ := json.Marshal(req) - return postUtilityAPI(version, cdDev, true, bytes.NewReader(body)) + return postUtilityAPI(ctx, version, cdDev, true, bytes.NewReader(body)) } -func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) { +func postUtilityAPI(ctx context.Context, version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) { apiUrl := resolverDataURLCom if cdDev { apiUrl = resolverDataURLDev @@ -139,12 +139,12 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade } req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/json") - transport := apiTransport(cdDev) + transport := apiTransport(ctx, cdDev) client := &http.Client{ Timeout: defaultTimeout, Transport: transport, } - resp, err := doWithFallback(client, req, apiServerIP(cdDev)) + resp, err := doWithFallback(ctx, client, req, apiServerIP(cdDev)) if err != nil { return nil, fmt.Errorf("postUtilityAPI client.Do: %w", err) } @@ -166,7 +166,7 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade } // SendLogs sends runtime log to ControlD API. -func SendLogs(lr *LogsRequest, cdDev bool) error { +func SendLogs(ctx context.Context, lr *LogsRequest, cdDev bool) error { defer lr.Data.Close() apiUrl := logURLCom if cdDev { @@ -180,12 +180,12 @@ func SendLogs(lr *LogsRequest, cdDev bool) error { q.Set("uid", lr.UID) req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - transport := apiTransport(cdDev) + transport := apiTransport(ctx, cdDev) client := &http.Client{ Timeout: sendLogTimeout, Transport: transport, } - resp, err := doWithFallback(client, req, apiServerIP(cdDev)) + resp, err := doWithFallback(ctx, client, req, apiServerIP(cdDev)) if err != nil { return fmt.Errorf("SendLogs client.Do: %w", err) } @@ -213,7 +213,7 @@ func ParseRawUID(rawUID string) (string, string) { } // apiTransport returns an HTTP transport for connecting to ControlD API endpoint. -func apiTransport(cdDev bool) *http.Transport { +func apiTransport(loggerCtx context.Context, cdDev bool) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { apiDomain := apiDomainCom @@ -227,9 +227,10 @@ func apiTransport(cdDev bool) *http.Transport { apiIPs = []string{apiDomainDevIPv4} } - ips := ctrld.LookupIP(apiDomain) + ips := ctrld.LookupIP(loggerCtx, apiDomain) if len(ips) == 0 { - ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, use direct IPs: %v", apiDomain, apiIPs) + logger := ctrld.LoggerFromCtx(loggerCtx) + logger.Warn().Msgf("No IPs found for %s, use direct IPs: %v", apiDomain, apiIPs) ips = apiIPs } @@ -245,7 +246,8 @@ func apiTransport(cdDev bool) *http.Transport { dial := func(ctx context.Context, network string, addrs []string) (net.Conn, error) { d := &ctrldnet.ParallelDialer{} - return d.DialContext(ctx, network, addrs, ctrld.ProxyLogger.Load()) + logger := ctrld.LoggerFromCtx(loggerCtx) + return d.DialContext(ctx, network, addrs, logger.Logger) } _, port, _ := net.SplitHostPort(addr) @@ -283,10 +285,11 @@ func addrsFromPort(ips []string, port string) []string { return addrs } -func doWithFallback(client *http.Client, req *http.Request, apiIp string) (*http.Response, error) { +func doWithFallback(ctx context.Context, client *http.Client, req *http.Request, apiIp string) (*http.Response, error) { resp, err := client.Do(req) if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to send request, fallback to direct IP: %s", apiIp) + logger := ctrld.LoggerFromCtx(ctx) + logger.Warn().Err(err).Msgf("failed to send request, fallback to direct IP: %s", apiIp) ipReq := req.Clone(req.Context()) ipReq.Host = apiIp ipReq.URL.Host = apiIp diff --git a/log.go b/log.go index 14c82e8..7b7037b 100644 --- a/log.go +++ b/log.go @@ -3,19 +3,37 @@ package ctrld import ( "context" "fmt" - "io" - "sync/atomic" "github.com/rs/zerolog" ) -// ProxyLog emits the log record for proxy operations. -// The caller should set it only once. -// DEPRECATED: use ProxyLogger instead. -var ProxyLog = zerolog.New(io.Discard) +// LoggerCtxKey is the context.Context key for a logger. +type LoggerCtxKey struct{} -// ProxyLogger emits the log record for proxy operations. -var ProxyLogger atomic.Pointer[zerolog.Logger] +// LoggerCtx returns a context.Context with LoggerCtxKey set. +func LoggerCtx(ctx context.Context, l *Logger) context.Context { + return context.WithValue(ctx, LoggerCtxKey{}, l) +} + +// A Logger provides fast, leveled, structured logging. +type Logger struct { + *zerolog.Logger +} + +var noOpZeroLogger = zerolog.Nop() + +// NopLogger returns a logger which all operation are no-op. +var NopLogger = &Logger{&noOpZeroLogger} + +// LoggerFromCtx returns the logger associated with given ctx. +// +// If there's no logger, a no-op logger will be returned. +func LoggerFromCtx(ctx context.Context) *Logger { + if logger, ok := ctx.Value(LoggerCtxKey{}).(*Logger); ok && logger != nil { + return logger + } + return NopLogger +} // ReqIdCtxKey is the context.Context key for a request id. type ReqIdCtxKey struct{} diff --git a/nameservers.go b/nameservers.go index 0aebf9e..07743ac 100644 --- a/nameservers.go +++ b/nameservers.go @@ -1,9 +1,11 @@ package ctrld -type dnsFn func() []string +import "context" + +type dnsFn func(ctx context.Context) []string // nameservers returns DNS nameservers from system settings. -func nameservers() []string { +func nameservers(ctx context.Context) []string { var dns []string seen := make(map[string]bool) ch := make(chan []string) @@ -11,7 +13,7 @@ func nameservers() []string { for _, fn := range fns { go func(fn dnsFn) { - ch <- fn() + ch <- fn(ctx) }(fn) } for range fns { diff --git a/nameservers_bsd.go b/nameservers_bsd.go index 09c9516..15c30c9 100644 --- a/nameservers_bsd.go +++ b/nameservers_bsd.go @@ -3,6 +3,7 @@ package ctrld import ( + "context" "net" "syscall" @@ -13,7 +14,7 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromResolvConf, dnsFromRIB} } -func dnsFromRIB() []string { +func dnsFromRIB(_ context.Context) []string { var dns []string rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) if err != nil { diff --git a/nameservers_darwin.go b/nameservers_darwin.go index c8fa78d..822893b 100644 --- a/nameservers_darwin.go +++ b/nameservers_darwin.go @@ -22,8 +22,8 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers} } -func getDNSFromScutil() []string { - logger := *ProxyLogger.Load() +func getDNSFromScutil(ctx context.Context) []string { + logger := LoggerFromCtx(ctx) const ( maxRetries = 10 @@ -109,8 +109,8 @@ func getDHCPNameservers(iface string) ([]string, error) { return nameservers, nil } -func getAllDHCPNameservers() []string { - logger := *ProxyLogger.Load() +func getAllDHCPNameservers(ctx context.Context) []string { + logger := LoggerFromCtx(ctx) interfaces, err := net.Interfaces() if err != nil { diff --git a/nameservers_linux.go b/nameservers_linux.go index 13a5507..8f877a6 100644 --- a/nameservers_linux.go +++ b/nameservers_linux.go @@ -3,6 +3,7 @@ package ctrld import ( "bufio" "bytes" + "context" "encoding/hex" "net" "os" @@ -20,7 +21,7 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromResolvConf, dns4, dns6, dnsFromSystemdResolver} } -func dns4() []string { +func dns4(_ context.Context) []string { f, err := os.Open(v4RouteFile) if err != nil { return nil @@ -60,7 +61,7 @@ func dns4() []string { return dns } -func dns6() []string { +func dns6(_ context.Context) []string { f, err := os.Open(v6RouteFile) if err != nil { return nil @@ -94,7 +95,7 @@ func dns6() []string { return dns } -func dnsFromSystemdResolver() []string { +func dnsFromSystemdResolver(_ context.Context) []string { c, err := resolvconffile.ParseFile("/run/systemd/resolve/resolv.conf") if err != nil { return nil diff --git a/nameservers_test.go b/nameservers_test.go index 166cced..e2e2bac 100644 --- a/nameservers_test.go +++ b/nameservers_test.go @@ -1,9 +1,12 @@ package ctrld -import "testing" +import ( + "context" + "testing" +) func TestNameservers(t *testing.T) { - ns := nameservers() + ns := nameservers(context.Background()) if len(ns) == 0 { t.Fatal("failed to get nameservers") } diff --git a/nameservers_unix.go b/nameservers_unix.go index d8e6035..8082c8a 100644 --- a/nameservers_unix.go +++ b/nameservers_unix.go @@ -3,6 +3,7 @@ package ctrld import ( + "context" "net" "slices" "time" @@ -20,7 +21,7 @@ func currentNameserversFromResolvconf() []string { // dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file. // A nameserver is usable if it's not one of current machine's IP addresses // and loopback IP addresses. -func dnsFromResolvConf() []string { +func dnsFromResolvConf(_ context.Context) []string { const ( maxRetries = 10 retryInterval = 100 * time.Millisecond diff --git a/nameservers_windows.go b/nameservers_windows.go index 4f6ca8e..bd8f564 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -55,28 +55,25 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromAdapter} } -func dnsFromAdapter() []string { +func dnsFromAdapter(ctx context.Context) []string { ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout) defer cancel() var ns []string var err error - logger := *ProxyLogger.Load() + logger := LoggerFromCtx(ctx) for i := 0; i < maxDNSAdapterRetries; i++ { if ctx.Err() != nil { - Log(context.Background(), logger.Debug(), - "dnsFromAdapter lookup cancelled or timed out, attempt %d", i) + logger.Debug().Msgf("dnsFromAdapter lookup cancelled or timed out, attempt %d", i) return nil } ns, err = getDNSServers(ctx) if err == nil && len(ns) >= minDNSServers { if i > 0 { - Log(context.Background(), logger.Debug(), - "Successfully got DNS servers after %d attempts, found %d servers", - i+1, len(ns)) + logger.Debug().Msgf("Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns)) } return ns } @@ -88,11 +85,9 @@ func dnsFromAdapter() []string { } if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get DNS servers, attempt %d: %v", i+1, err) + logger.Debug().Msgf("Failed to get DNS servers, attempt %d: %v", i+1, err) } else { - Log(context.Background(), logger.Debug(), - "Got insufficient DNS servers, retrying, found %d servers", len(ns)) + logger.Debug().Msgf("Got insufficient DNS servers, retrying, found %d servers", len(ns)) } select { @@ -102,14 +97,12 @@ func dnsFromAdapter() []string { } } - Log(context.Background(), logger.Debug(), - "Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries) + logger.Debug().Msgf("Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries) + return ns } func getDNSServers(ctx context.Context) ([]string, error) { - logger := *ProxyLogger.Load() - // Check context before making the call if ctx.Err() != nil { return nil, ctx.Err() @@ -124,17 +117,16 @@ func getDNSServers(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("getting adapters: %w", err) } - Log(context.Background(), logger.Debug(), - "Found network adapters, count=%d", len(aas)) + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("Found network adapters, count=%d", len(aas)) // Try to get domain controller info if domain-joined var dcServers []string - isDomain := checkDomainJoined() + isDomain := checkDomainJoined(ctx) if isDomain { domainName, err := getLocalADDomain() if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get local AD domain: %v", err) + logger.Debug().Msgf("Failed to get local AD domain: %v", err) } else { // Load netapi32.dll netapi32 := windows.NewLazySystemDLL("netapi32.dll") @@ -145,11 +137,9 @@ func getDNSServers(ctx context.Context) ([]string, error) { domainUTF16, err := windows.UTF16PtrFromString(domainName) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to convert domain name to UTF16: %v", err) + logger.Debug().Msgf("Failed to convert domain name to UTF16: %v", err) } else { - Log(context.Background(), logger.Debug(), - "Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) + logger.Debug().Msgf("Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) // Call DsGetDcNameW with domain name ret, _, err := dsDcName.Call( @@ -163,20 +153,15 @@ func getDNSServers(ctx context.Context) ([]string, error) { if ret != 0 { switch ret { case 1355: // ERROR_NO_SUCH_DOMAIN - Log(context.Background(), logger.Debug(), - "Domain not found: %s (%d)", domainName, ret) + logger.Debug().Msgf("Domain not found: %s (%d)", domainName, ret) case 1311: // ERROR_NO_LOGON_SERVERS - Log(context.Background(), logger.Debug(), - "No logon servers available for domain: %s (%d)", domainName, ret) + logger.Debug().Msgf("No logon servers available for domain: %s (%d)", domainName, ret) case 1004: // ERROR_DC_NOT_FOUND - Log(context.Background(), logger.Debug(), - "Domain controller not found for domain: %s (%d)", domainName, ret) + logger.Debug().Msgf("Domain controller not found for domain: %s (%d)", domainName, ret) case 1722: // RPC_S_SERVER_UNAVAILABLE - Log(context.Background(), logger.Debug(), - "RPC server unavailable for domain: %s (%d)", domainName, ret) + logger.Debug().Msgf("RPC server unavailable for domain: %s (%d)", domainName, ret) default: - Log(context.Background(), logger.Debug(), - "Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) + logger.Debug().Msgf("Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) } } else if info != nil { defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info))) @@ -184,17 +169,13 @@ func getDNSServers(ctx context.Context) ([]string, error) { if info.DomainControllerAddress != nil { dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress) dcAddr = strings.TrimPrefix(dcAddr, "\\\\") - Log(context.Background(), logger.Debug(), - "Found domain controller address: %s", dcAddr) - + logger.Debug().Msgf("Found domain controller address: %s", dcAddr) if ip := net.ParseIP(dcAddr); ip != nil { dcServers = append(dcServers, ip.String()) - Log(context.Background(), logger.Debug(), - "Added domain controller DNS servers: %v", dcServers) + logger.Debug().Msgf("Added domain controller DNS servers: %v", dcServers) } } else { - Log(context.Background(), logger.Debug(), - "No domain controller address found") + logger.Debug().Msg("No domain controller address found") } } } @@ -209,31 +190,27 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Collect all local IPs for _, aa := range aas { if aa.OperStatus != winipcfg.IfOperStatusUp { - Log(context.Background(), logger.Debug(), - "Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) + logger.Debug().Msgf("Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) continue } // Skip if software loopback or other non-physical types // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows if aa.IfType == winipcfg.IfTypeSoftwareLoopback { - Log(context.Background(), logger.Debug(), - "Skipping %s (software loopback)", aa.FriendlyName()) + logger.Debug().Msgf("Skipping %s (software loopback)", aa.FriendlyName()) continue } - Log(context.Background(), logger.Debug(), - "Processing adapter %s", aa.FriendlyName()) + logger.Debug().Msgf("Processing adapter %s", aa.FriendlyName()) for a := aa.FirstUnicastAddress; a != nil; a = a.Next { ip := a.Address.IP().String() addressMap[ip] = struct{}{} - Log(context.Background(), logger.Debug(), - "Added local IP %s from adapter %s", ip, aa.FriendlyName()) + logger.Debug().Msgf("Added local IP %s from adapter %s", ip, aa.FriendlyName()) } } - validInterfacesMap := validInterfaces() + validInterfacesMap := validInterfaces(ctx) // Collect DNS servers for _, aa := range aas { @@ -244,23 +221,20 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Skip if software loopback or other non-physical types // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows if aa.IfType == winipcfg.IfTypeSoftwareLoopback { - Log(context.Background(), logger.Debug(), - "Skipping %s (software loopback)", aa.FriendlyName()) + logger.Debug().Msgf("Skipping %s (software loopback)", aa.FriendlyName()) continue } // if not in the validInterfacesMap, skip if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok { - Log(context.Background(), logger.Debug(), - "Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) + logger.Debug().Msgf("Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) continue } for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { ip := dns.Address.IP() if ip == nil { - Log(context.Background(), logger.Debug(), - "Skipping nil IP from adapter %s", aa.FriendlyName()) + logger.Debug().Msgf("Skipping nil IP from adapter %s", aa.FriendlyName()) continue } @@ -293,28 +267,23 @@ func getDNSServers(ctx context.Context) ([]string, error) { if !seen[dcServer] { seen[dcServer] = true ns = append(ns, dcServer) - Log(context.Background(), logger.Debug(), - "Added additional domain controller DNS server: %s", dcServer) + logger.Debug().Msgf("Added additional domain controller DNS server: %s", dcServer) } } // if we have static DNS servers saved for the current default route, we should add them to the list drIfaceName, err := netmon.DefaultRouteInterface() if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get default route interface: %v", err) + logger.Debug().Msgf("Failed to get default route interface: %v", err) } else { drIface, err := net.InterfaceByName(drIfaceName) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get interface by name %s: %v", drIfaceName, err) + logger.Debug().Msgf("Failed to get interface by name %s: %v", drIfaceName, err) } else { staticNs, file := SavedStaticNameserversAndPath(drIface) - Log(context.Background(), logger.Debug(), - "static dns servers from %s: %v", file, staticNs) + logger.Debug().Msgf("static dns servers from %s: %v", file, staticNs) if len(staticNs) > 0 { - Log(context.Background(), logger.Debug(), - "Adding static DNS servers from %s: %v", drIfaceName, staticNs) + logger.Debug().Msgf("Adding static DNS servers from %s: %v", drIfaceName, staticNs) ns = append(ns, staticNs...) } } @@ -324,9 +293,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("no valid DNS servers found") } - Log(context.Background(), logger.Debug(), - "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", - len(ns), ns, len(dcServers)) + logger.Debug().Msgf("DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", len(ns), ns, len(dcServers)) return ns, nil } @@ -337,33 +304,35 @@ func currentNameserversFromResolvconf() []string { // checkDomainJoined checks if the machine is joined to an Active Directory domain // Returns whether it's domain joined and the domain name if available -func checkDomainJoined() bool { - logger := *ProxyLogger.Load() +func checkDomainJoined(ctx context.Context) bool { + logger := LoggerFromCtx(ctx) var domain *uint16 var status uint32 err := windows.NetGetJoinInformation(nil, &domain, &status) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get domain join status: %v", err) + logger.Debug().Msgf("Failed to get domain join status: %v", err) return false } defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) domainName := windows.UTF16PtrToString(domain) - Log(context.Background(), logger.Debug(), + logger.Debug().Msgf( "Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", - domainName, status) + domainName, + status, + ) // Consider domain or cloud domain as domain-joined isDomain := status == NetSetupDomain || status == NetSetupCloudDomain - Log(context.Background(), logger.Debug(), + logger.Debug().Msgf( "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", status, status == NetSetupDomain, status == NetSetupCloudDomain, - isDomain) + isDomain, + ) return isDomain } @@ -411,12 +380,12 @@ func getLocalADDomain() (string, error) { // validInterfaces returns a list of all physical interfaces. // this is a duplicate of what is in net_windows.go, we should // clean this up so there is only one version -func validInterfaces() map[string]struct{} { +func validInterfaces(ctx context.Context) map[string]struct{} { log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) //load the logger - logger := *ProxyLogger.Load() + logger := LoggerFromCtx(ctx) whost := host.NewWmiLocalHost() q := query.NewWmiQuery("MSFT_NetAdapter") @@ -425,23 +394,20 @@ func validInterfaces() map[string]struct{} { defer instances.Close() } if err != nil { - Log(context.Background(), logger.Warn(), - "failed to get wmi network adapter: %v", err) + logger.Warn().Msgf("failed to get wmi network adapter: %v", err) return nil } var adapters []string for _, i := range instances { adapter, err := netadapter.NewNetworkAdapter(i) if err != nil { - Log(context.Background(), logger.Warn(), - "failed to get network adapter: %v", err) + logger.Warn().Msgf("failed to get network adapter: %v", err) continue } name, err := adapter.GetPropertyName() if err != nil { - Log(context.Background(), logger.Warn(), - "failed to get interface name: %v", err) + logger.Warn().Msgf("failed to get interface name: %v", err) continue } @@ -451,13 +417,11 @@ func validInterfaces() map[string]struct{} { // if this is a physical adapter or FALSE if this is not a physical adapter." physical, err := adapter.GetPropertyConnectorPresent() if err != nil { - Log(context.Background(), logger.Debug(), - "failed to get network adapter connector present property: %v", err) + logger.Debug().Msgf("failed to get network adapter connector present property: %v", err) continue } if !physical { - Log(context.Background(), logger.Debug(), - "skipping non-physical adapter: %s", name) + logger.Debug().Msgf("skipping non-physical adapter: %s", name) continue } @@ -465,13 +429,11 @@ func validInterfaces() map[string]struct{} { // because some interfaces are not physical but have a connector. hardware, err := adapter.GetPropertyHardwareInterface() if err != nil { - Log(context.Background(), logger.Debug(), - "failed to get network adapter hardware interface property: %v", err) + logger.Debug().Msgf("failed to get network adapter hardware interface property: %v", err) continue } if !hardware { - Log(context.Background(), logger.Debug(), - "skipping non-hardware interface: %s", name) + logger.Debug().Msgf("skipping non-hardware interface: %s", name) continue } diff --git a/net.go b/net.go index 7bbf54b..0f556f4 100644 --- a/net.go +++ b/net.go @@ -17,26 +17,27 @@ var ( ) // HasIPv6 reports whether the current network stack has IPv6 available. -func HasIPv6() bool { +func HasIPv6(ctx context.Context) bool { hasIPv6Once.Do(func() { - ProxyLogger.Load().Debug().Msg("checking for IPv6 availability once") + logger := LoggerFromCtx(ctx) + logger.Debug().Msg("checking for IPv6 availability once") ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() val := ctrldnet.IPv6Available(ctx) ipv6Available.Store(val) - ProxyLogger.Load().Debug().Msgf("ipv6 availability: %v", val) + logger.Debug().Msgf("ipv6 availability: %v", val) mon, err := netmon.New(func(format string, args ...any) {}) if err != nil { - ProxyLogger.Load().Debug().Err(err).Msg("failed to monitor IPv6 state") + logger.Debug().Err(err).Msg("failed to monitor IPv6 state") return } mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { old := ipv6Available.Load() cur := delta.Monitor.InterfaceState().HaveV6 if old != cur { - ProxyLogger.Load().Warn().Msgf("ipv6 availability changed, old: %v, new: %v", old, cur) + logger.Warn().Msgf("ipv6 availability changed, old: %v, new: %v", old, cur) } else { - ProxyLogger.Load().Debug().Msg("ipv6 availability does not changed") + logger.Debug().Msg("ipv6 availability does not changed") } ipv6Available.Store(cur) }) @@ -46,8 +47,9 @@ func HasIPv6() bool { } // DisableIPv6 marks IPv6 as unavailable if enabled. -func DisableIPv6() { +func DisableIPv6(ctx context.Context) { if ipv6Available.CompareAndSwap(true, false) { - ProxyLogger.Load().Debug().Msg("turned off IPv6 availability") + logger := LoggerFromCtx(ctx) + logger.Debug().Msg("turned off IPv6 availability") } } diff --git a/resolver.go b/resolver.go index 27c0108..c88df1f 100644 --- a/resolver.go +++ b/resolver.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net" "net/netip" "runtime" @@ -15,7 +14,6 @@ import ( "time" "github.com/miekg/dns" - "github.com/rs/zerolog" "golang.org/x/sync/singleflight" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" @@ -50,10 +48,6 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") var localResolver Resolver func init() { - // Initializing ProxyLogger here, so other places don't have to do nil check. - l := zerolog.New(io.Discard) - ProxyLogger.Store(&l) - localResolver = newLocalResolver() } @@ -81,8 +75,8 @@ func LanQueryCtx(ctx context.Context) context.Context { } // defaultNameservers is like nameservers with each element formed "ip:53". -func defaultNameservers() []string { - ns := nameservers() +func defaultNameservers(ctx context.Context) []string { + ns := nameservers(ctx) nss := make([]string, len(ns)) for i := range ns { nss[i] = net.JoinHostPort(ns[i], "53") @@ -91,42 +85,36 @@ func defaultNameservers() []string { } // availableNameservers returns list of current available DNS servers of the system. -func availableNameservers() []string { +func availableNameservers(ctx context.Context) []string { var nss []string // Ignore local addresses to prevent loop. regularIPs, loopbackIPs, _ := netmon.LocalAddresses() machineIPsMap := make(map[string]struct{}, len(regularIPs)) - //load the logger - logger := *ProxyLogger.Load() - - Log(context.Background(), logger.Debug(), - "Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs) + // Load the logger. + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs) for _, v := range slices.Concat(regularIPs, loopbackIPs) { ipStr := v.String() machineIPsMap[ipStr] = struct{}{} - Log(context.Background(), logger.Debug(), - "Added local IP to OS resolverexclusion map: %s", ipStr) + logger.Debug().Msgf("Added local IP to OS resolverexclusion map: %s", ipStr) } - systemNameservers := nameservers() - Log(context.Background(), logger.Debug(), - "Got system nameservers: %v", systemNameservers) + systemNameservers := nameservers(ctx) + logger.Debug().Msgf("Got system nameservers: %v", systemNameservers) for _, ns := range systemNameservers { if _, ok := machineIPsMap[ns]; ok { - Log(context.Background(), logger.Debug(), - "Skipping local nameserver: %s", ns) + logger.Debug().Msgf("Skipping local nameserver: %s", ns) continue } nss = append(nss, ns) - Log(context.Background(), logger.Debug(), - "Added non-local nameserver: %s", ns) + logger.Debug().Msgf("Added non-local nameserver: %s", ns) } - Log(context.Background(), logger.Debug(), - "Final available nameservers: %v", nss) + logger.Debug().Msgf("Final available nameservers: %v", nss) + return nss } @@ -135,8 +123,8 @@ func availableNameservers() []string { // // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. -func InitializeOsResolver(guardAgainstNoNameservers bool) []string { - nameservers := availableNameservers() +func InitializeOsResolver(ctx context.Context, guardAgainstNoNameservers bool) []string { + nameservers := availableNameservers(ctx) // if no nameservers, return empty slice so we dont remove all nameservers if len(nameservers) == 0 && guardAgainstNoNameservers { return []string{} @@ -188,7 +176,7 @@ type Resolver interface { var errUnknownResolver = errors.New("unknown resolver") // NewResolver creates a Resolver based on the given upstream config. -func NewResolver(uc *UpstreamConfig) (Resolver, error) { +func NewResolver(ctx context.Context, uc *UpstreamConfig) (Resolver, error) { typ := uc.Type switch typ { case ResolverTypeDOH, ResolverTypeDOH3: @@ -200,15 +188,16 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { case ResolverTypeOS: resolverMutex.Lock() if or == nil { - ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver") - or = newResolverWithNameserver(defaultNameservers()) + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("Initialize new OS resolver") + or = newResolverWithNameserver(defaultNameservers(ctx)) } resolverMutex.Unlock() return or, nil case ResolverTypeLegacy: return &legacyResolver{uc: uc}, nil case ResolverTypePrivate: - return NewPrivateResolver(), nil + return NewPrivateResolver(ctx), nil case ResolverTypeLocal: return localResolver, nil } @@ -235,14 +224,16 @@ type publicResponse struct { } // SetDefaultLocalIPv4 updates the stored local IPv4. -func SetDefaultLocalIPv4(ip net.IP) { - Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip) +func SetDefaultLocalIPv4(ctx context.Context, ip net.IP) { + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("SetDefaultLocalIPv4: %s", ip) defaultLocalIPv4.Store(ip) } // SetDefaultLocalIPv6 updates the stored local IPv6. -func SetDefaultLocalIPv6(ip net.IP) { - Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip) +func SetDefaultLocalIPv6(ctx context.Context, ip net.IP) { + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("SetDefaultLocalIPv6: %s", ip) defaultLocalIPv6.Store(ip) } @@ -300,10 +291,11 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error // Unique key for the singleflight group. key := fmt.Sprintf("%s:%d:", domain, qtype) + logger := LoggerFromCtx(ctx) // Checking the cache first. if val, ok := o.cache.Load(key); ok { if val, ok := val.(*dns.Msg); ok { - Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) + Log(ctx, logger.Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) res := val.Copy() SetCacheReply(res, msg, val.Rcode) return res, nil @@ -338,7 +330,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error res := sharedMsg.Copy() SetCacheReply(res, msg, sharedMsg.Rcode) if shared { - Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) + Log(ctx, logger.Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) } return res, nil @@ -368,7 +360,8 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error if msg != nil && len(msg.Question) > 0 { question = msg.Question[0].Name } - Log(ctx, ProxyLogger.Load().Debug(), "os resolver query for %s with nameservers: %v public: %v", question, nss, publicServers) + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "os resolver query for %s with nameservers: %v public: %v", question, nss, publicServers) // New check: If no resolvers are available, return an error. if numServers == 0 { @@ -417,7 +410,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error // If splitting fails, fallback to the original server string host = server } - Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host) + Log(ctx, logger.Debug(), "got answer from nameserver: %s", host) } // try local nameservers @@ -444,7 +437,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error switch { case res.lan: // Always prefer LAN responses immediately - Log(ctx, ProxyLogger.Load().Debug(), "using LAN answer from: %s", res.server) + Log(ctx, logger.Debug(), "using LAN answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil @@ -454,7 +447,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error // if there are no LAN nameservers, we should not wait // just use the first response if len(nss) == 0 { - Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", res.server) + Log(ctx, logger.Debug(), "using public answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil @@ -465,12 +458,12 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error }) } case res.answer != nil: - Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d", + Log(ctx, logger.Debug(), "got non-success answer from: %s with code: %d", res.server, res.answer.Rcode) // When there are no LAN nameservers, we should not wait // for other nameservers to respond. if len(nss) == 0 { - Log(ctx, ProxyLogger.Load().Debug(), "no lan nameservers using public non success answer") + Log(ctx, logger.Debug(), "no lan nameservers using public non success answer") cancel() logAnswer(res.server) return res.answer, nil @@ -483,17 +476,17 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error if len(publicResponses) > 0 { resp := publicResponses[0] - Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", resp.server) + Log(ctx, logger.Debug(), "using public answer from: %s", resp.server) logAnswer(resp.server) return resp.answer, nil } if controldSuccessAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort) + Log(ctx, logger.Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort) logAnswer(controldPublicDnsWithPort) return controldSuccessAnswer, nil } if nonSuccessAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "using non-success answer from: %s", nonSuccessServer) + Log(ctx, logger.Debug(), "using non-success answer from: %s", nonSuccessServer) logAnswer(nonSuccessServer) return nonSuccessAnswer, nil } @@ -515,7 +508,7 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - _, udpNet := r.uc.netForDNSType(dnsTyp) + _, udpNet := r.uc.netForDNSType(ctx, dnsTyp) dnsClient := &dns.Client{ Net: udpNet, Dialer: dialer, @@ -541,39 +534,43 @@ func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, err // LookupIP looks up domain using current system nameservers settings. // It returns a slice of that host's IPv4 and IPv6 addresses. -func LookupIP(domain string) []string { - nss := initDefaultOsResolver() - return lookupIP(domain, -1, nss) +func LookupIP(ctx context.Context, domain string) []string { + nss := initDefaultOsResolver(ctx) + return lookupIP(ctx, domain, -1, nss) } // initDefaultOsResolver initializes the default OS resolver with system's default nameservers if it hasn't been initialized yet. // It returns the combined list of LAN and public nameservers currently held by the resolver. -func initDefaultOsResolver() []string { +func initDefaultOsResolver(ctx context.Context) []string { + logger := LoggerFromCtx(ctx) resolverMutex.Lock() defer resolverMutex.Unlock() if or == nil { - ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver with default nameservers") - or = newResolverWithNameserver(defaultNameservers()) + logger.Debug().Msgf("Initialize new OS resolver with default nameservers") + or = newResolverWithNameserver(defaultNameservers(ctx)) } nss := *or.lanServers.Load() nss = append(nss, *or.publicServers.Load()...) return nss + } // lookupIP looks up domain with given timeout and bootstrapDNS. // If the timeout is negative, default timeout 2000 ms will be used. // It returns nil if bootstrapDNS is nil or empty. -func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) { +func lookupIP(ctx context.Context, domain string, timeout int, bootstrapDNS []string) (ips []string) { if net.ParseIP(domain) != nil { return []string{domain} } + logger := LoggerFromCtx(ctx) if bootstrapDNS == nil { - ProxyLogger.Load().Debug().Msgf("empty bootstrap DNS") + logger.Debug().Msgf("empty bootstrap DNS") return nil } resolver := newResolverWithNameserver(bootstrapDNS) - ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, bootstrapDNS) + logger.Debug().Msgf("resolving %q using bootstrap DNS %q", domain, bootstrapDNS) + timeoutMs := 2000 if timeout > 0 && timeout < timeoutMs { timeoutMs = timeout @@ -616,15 +613,15 @@ func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) r, err := resolver.Resolve(ctx, m) if err != nil { - ProxyLogger.Load().Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain) + logger.Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain) return } if r.Rcode != dns.RcodeSuccess { - ProxyLogger.Load().Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode]) + logger.Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode]) return } if len(r.Answer) == 0 { - ProxyLogger.Load().Error().Msg("no answer from OS resolver") + logger.Error().Msg("no answer from OS resolver") return } target := targetDomain(r.Answer) @@ -641,22 +638,6 @@ func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) return ips } -// NewBootstrapResolver returns an OS resolver, which use following nameservers: -// -// - Gateway IP address (depends on OS). -// - Input servers. -func NewBootstrapResolver(servers ...string) Resolver { - logger := *ProxyLogger.Load() - - Log(context.Background(), logger.Debug(), "NewBootstrapResolver called with servers: %v", servers) - nss := defaultNameservers() - nss = append([]string{controldPublicDnsWithPort}, nss...) - for _, ns := range servers { - nss = append([]string{net.JoinHostPort(ns, "53")}, nss...) - } - return NewResolverWithNameserver(nss) -} - // NewPrivateResolver returns an OS resolver, which includes only private DNS servers, // excluding: // @@ -664,8 +645,8 @@ func NewBootstrapResolver(servers ...string) Resolver { // - Nameservers which is local RFC1918 addresses. // // This is useful for doing PTR lookup in LAN network. -func NewPrivateResolver() Resolver { - nss := initDefaultOsResolver() +func NewPrivateResolver(ctx context.Context) Resolver { + nss := initDefaultOsResolver(ctx) resolveConfNss := currentNameserversFromResolvconf() localRfc1918Addrs := Rfc1918Addresses() n := 0 diff --git a/resolver_test.go b/resolver_test.go index ebcad16..d5a76d6 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -132,7 +132,7 @@ func Test_osResolver_InitializationRace(t *testing.T) { for range n { go func() { defer wg.Done() - InitializeOsResolver(false) + InitializeOsResolver(LoggerCtx(context.Background(), nil), false) }() } wg.Wait()