mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-05-27 12:52:27 +02:00
all: eliminate usage of global ProxyLogger
So setting up logging for ctrld binary and ctrld packages could be done more easily, decouple the required setup for interactive vs daemon running. This is the first step toward replacing rs/zerolog libary with a different logging library.
This commit is contained in:
committed by
Cuong Manh Le
parent
5641aab5bd
commit
fc527dbdfb
+16
-16
@@ -349,7 +349,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
if newLogPath := cfg.Service.LogPath; newLogPath != "" && oldLogPath != newLogPath {
|
||||
// After processCDFlags, log config may change, so reset mainLog and re-init logging.
|
||||
l := zerolog.New(io.Discard)
|
||||
mainLog.Store(&l)
|
||||
mainLog.Store(&ctrld.Logger{Logger: &l})
|
||||
|
||||
// Copy logs written so far to new log file if possible.
|
||||
if buf, err := os.ReadFile(oldLogPath); err == nil {
|
||||
@@ -502,8 +502,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
|
||||
}
|
||||
nop := zerolog.Nop()
|
||||
_, _ = tryUpdateListenerConfig(&cfg, &nop, func() {}, true)
|
||||
_, _ = tryUpdateListenerConfig(&cfg, func() {}, true)
|
||||
addExtraSplitDnsRule(&cfg)
|
||||
if err := writeConfigFile(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
|
||||
@@ -591,7 +590,8 @@ func processNoConfigFlags(noConfigStart bool) {
|
||||
Type: pType,
|
||||
Timeout: 5000,
|
||||
}
|
||||
puc.Init()
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
puc.Init(loggerCtx)
|
||||
upstream := map[string]*ctrld.UpstreamConfig{"0": puc}
|
||||
if secondaryUpstream != "" {
|
||||
sEndpoint, sType := endpointAndTyp(secondaryUpstream)
|
||||
@@ -601,7 +601,7 @@ func processNoConfigFlags(noConfigStart bool) {
|
||||
Type: sType,
|
||||
Timeout: 5000,
|
||||
}
|
||||
suc.Init()
|
||||
suc.Init(loggerCtx)
|
||||
upstream["1"] = suc
|
||||
rules := make([]ctrld.Rule, 0, len(domains))
|
||||
for _, domain := range domains {
|
||||
@@ -634,13 +634,13 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) {
|
||||
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
|
||||
bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second)
|
||||
bo.LogLongerThan = 30 * time.Second
|
||||
ctx := context.Background()
|
||||
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
ctx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
resolverConfig, err := controld.FetchResolverConfig(ctx, cdUID, rootCmd.Version, cdDev)
|
||||
for {
|
||||
if errUrlNetworkError(err) {
|
||||
bo.BackOff(ctx, err)
|
||||
logger.Warn().Msg("could not fetch resolver using bootstrap DNS, retrying...")
|
||||
resolverConfig, err = controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
resolverConfig, err = controld.FetchResolverConfig(ctx, cdUID, rootCmd.Version, cdDev)
|
||||
continue
|
||||
}
|
||||
break
|
||||
@@ -938,9 +938,10 @@ func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain stri
|
||||
bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", exErr))
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("self-check against %q failed", domain)
|
||||
loggerCtx := ctrld.LoggerCtx(ctx, mainLog.Load())
|
||||
// Ping all upstreams to provide better error message to users.
|
||||
for name, uc := range cfg.Upstream {
|
||||
if err := uc.ErrorPing(); err != nil {
|
||||
if err := uc.ErrorPing(loggerCtx); err != nil {
|
||||
mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
@@ -1181,7 +1182,7 @@ func mobileListenerIp() string {
|
||||
// or defined but invalid to be used, e.g: using loopback address other
|
||||
// than 127.0.0.1 with systemd-resolved.
|
||||
func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool {
|
||||
updated, _ := tryUpdateListenerConfig(cfg, nil, notifyToLogServerFunc, true)
|
||||
updated, _ := tryUpdateListenerConfig(cfg, notifyToLogServerFunc, true)
|
||||
if addExtraSplitDnsRule(cfg) {
|
||||
updated = true
|
||||
}
|
||||
@@ -1191,7 +1192,7 @@ func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool
|
||||
// tryUpdateListenerConfig tries updating listener config with a working one.
|
||||
// If fatal is true, and there's listen address conflicted, the function do
|
||||
// fatal error.
|
||||
func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, notifyFunc func(), fatal bool) (updated, ok bool) {
|
||||
func tryUpdateListenerConfig(cfg *ctrld.Config, notifyFunc func(), fatal bool) (updated, ok bool) {
|
||||
ok = true
|
||||
lcc := make(map[string]*listenerConfigCheck)
|
||||
cdMode := cdUID != ""
|
||||
@@ -1235,9 +1236,6 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti
|
||||
}
|
||||
|
||||
il := mainLog.Load()
|
||||
if infoLogger != nil {
|
||||
il = infoLogger
|
||||
}
|
||||
if isMobile() {
|
||||
// On Mobile, only use first listener, ignore others.
|
||||
firstLn := cfg.FirstListener()
|
||||
@@ -1492,7 +1490,8 @@ func cdUIDFromProvToken() string {
|
||||
}
|
||||
req := &controld.UtilityOrgRequest{ProvToken: cdOrg, Hostname: customHostname}
|
||||
// Process provision token if provided.
|
||||
resolverConfig, err := controld.FetchResolverUID(req, rootCmd.Version, cdDev)
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
resolverConfig, err := controld.FetchResolverUID(loggerCtx, req, rootCmd.Version, cdDev)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid with provision token: %s", cdOrg)
|
||||
}
|
||||
@@ -1819,7 +1818,8 @@ func runningIface(s service.Service) *ifaceResponse {
|
||||
|
||||
// doValidateCdRemoteConfig fetches and validates custom config for cdUID.
|
||||
func doValidateCdRemoteConfig(cdUID string, fatal bool) error {
|
||||
rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev)
|
||||
if err != nil {
|
||||
logger := mainLog.Load().Fatal()
|
||||
if !fatal {
|
||||
|
||||
@@ -216,8 +216,9 @@ func (p *prog) registerControlServerHandler() {
|
||||
return
|
||||
}
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
// Re-fetch pin code from API.
|
||||
if rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); rc != nil {
|
||||
if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev); rc != nil {
|
||||
if rc.DeactivationPin != nil {
|
||||
cdDeactivationPin.Store(*rc.DeactivationPin)
|
||||
} else {
|
||||
@@ -321,7 +322,8 @@ func (p *prog) registerControlServerHandler() {
|
||||
}
|
||||
mainLog.Load().Debug().Msg("sending log file to ControlD server")
|
||||
resp := logSentResponse{Size: r.size}
|
||||
if err := controld.SendLogs(req, cdDev); err != nil {
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil {
|
||||
mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err)
|
||||
resp.Error = err.Error()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
+17
-14
@@ -110,6 +110,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
reqId := requestID()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
ctx = ctrld.LoggerCtx(ctx, mainLog.Load())
|
||||
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
|
||||
answer := new(dns.Msg)
|
||||
@@ -514,7 +515,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
}
|
||||
resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name)
|
||||
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
|
||||
dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig)
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver")
|
||||
return nil, err
|
||||
@@ -549,11 +550,11 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
// For timeout error (i.e: context deadline exceed), force re-bootstrapping.
|
||||
var e net.Error
|
||||
if errors.As(err, &e) && e.Timeout() {
|
||||
upstreamConfig.ReBootstrap()
|
||||
upstreamConfig.ReBootstrap(ctx)
|
||||
}
|
||||
// For network error, turn ipv6 off if enabled.
|
||||
if ctrld.HasIPv6() && (errUrlNetworkError(err) || errNetworkError(err)) {
|
||||
ctrld.DisableIPv6()
|
||||
if ctrld.HasIPv6(ctx) && (errUrlNetworkError(err) || errNetworkError(err)) {
|
||||
ctrld.DisableIPv6(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -960,7 +961,8 @@ func (p *prog) doSelfUninstall(answer *dns.Msg) {
|
||||
logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger()
|
||||
if p.refusedQueryCount > selfUninstallMaxQueries {
|
||||
p.checkingSelfUninstall = true
|
||||
_, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
_, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev)
|
||||
logger.Debug().Msg("maximum number of refused queries reached, checking device status")
|
||||
selfUninstallCheck(err, p, logger)
|
||||
|
||||
@@ -1326,13 +1328,13 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
|
||||
// Only set the IPv4 default if selfIP is a valid IPv4 address.
|
||||
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil {
|
||||
ctrld.SetDefaultLocalIPv4(ip)
|
||||
ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, mainLog.Load()), ip)
|
||||
if !isMobile() && p.ciTable != nil {
|
||||
p.ciTable.SetSelfIP(selfIP)
|
||||
}
|
||||
}
|
||||
if ip := net.ParseIP(ipv6); ip != nil {
|
||||
ctrld.SetDefaultLocalIPv6(ip)
|
||||
ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, mainLog.Load()), ip)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
||||
|
||||
@@ -1400,7 +1402,7 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool {
|
||||
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
|
||||
mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream)
|
||||
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
|
||||
return err
|
||||
@@ -1418,7 +1420,7 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
uc.ReBootstrap()
|
||||
uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load()))
|
||||
mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)
|
||||
|
||||
start := time.Now()
|
||||
@@ -1474,10 +1476,11 @@ func (p *prog) handleRecovery(reason RecoveryReason) {
|
||||
// will be appended to nameservers from the saved interface values
|
||||
p.resetDNS(false, false)
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
// For an OS failure, reinitialize OS resolver nameservers immediately.
|
||||
if reason == RecoveryReasonOSFailure {
|
||||
mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers")
|
||||
ns := ctrld.InitializeOsResolver(true)
|
||||
ns := ctrld.InitializeOsResolver(loggerCtx, true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
|
||||
} else {
|
||||
@@ -1504,7 +1507,7 @@ func (p *prog) handleRecovery(reason RecoveryReason) {
|
||||
|
||||
// For network changes we also reinitialize the OS resolver.
|
||||
if reason == RecoveryReasonNetworkChange {
|
||||
ns := ctrld.InitializeOsResolver(true)
|
||||
ns := ctrld.InitializeOsResolver(loggerCtx, true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values")
|
||||
} else {
|
||||
@@ -1564,7 +1567,7 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string
|
||||
// we should try to reinit the OS resolver to ensure we can recover
|
||||
if name == upstreamOS && attempts%3 == 0 {
|
||||
mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts)
|
||||
ns := ctrld.InitializeOsResolver(true)
|
||||
ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, mainLog.Load()), true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
|
||||
} else {
|
||||
@@ -1624,12 +1627,12 @@ func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) {
|
||||
// Check if the default IPv4 is still active.
|
||||
if currentIPv4 != nil && !activeIPs[currentIPv4.String()] {
|
||||
mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", currentIPv4)
|
||||
ctrld.SetDefaultLocalIPv4(nil)
|
||||
ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil)
|
||||
}
|
||||
|
||||
// Check if the default IPv6 is still active.
|
||||
if currentIPv6 != nil && !activeIPs[currentIPv6.String()] {
|
||||
mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. Resetting.", currentIPv6)
|
||||
ctrld.SetDefaultLocalIPv6(nil)
|
||||
ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,8 +137,7 @@ func (p *prog) initInternalLogging(writers []io.Writer) {
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
mainLog.Store(&ctrld.Logger{Logger: &l})
|
||||
}
|
||||
|
||||
// needInternalLogging reports whether prog needs to run internal logging.
|
||||
|
||||
+2
-1
@@ -102,6 +102,7 @@ func (p *prog) checkDnsLoop() {
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
for uid := range p.loop {
|
||||
msg := loopTestMsg(uid)
|
||||
uc := upstream[uid]
|
||||
@@ -109,7 +110,7 @@ func (p *prog) checkDnsLoop() {
|
||||
if uc == nil {
|
||||
continue
|
||||
}
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
resolver, err := ctrld.NewResolver(loggerCtx, uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
continue
|
||||
|
||||
+4
-10
@@ -40,7 +40,7 @@ var (
|
||||
cleanup bool
|
||||
startOnly bool
|
||||
|
||||
mainLog atomic.Pointer[zerolog.Logger]
|
||||
mainLog atomic.Pointer[ctrld.Logger]
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
noConfigStart bool
|
||||
)
|
||||
@@ -54,7 +54,7 @@ const (
|
||||
|
||||
func init() {
|
||||
l := zerolog.New(io.Discard)
|
||||
mainLog.Store(&l)
|
||||
mainLog.Store(&ctrld.Logger{Logger: &l})
|
||||
}
|
||||
|
||||
func Main() {
|
||||
@@ -87,16 +87,14 @@ func initConsoleLogging() {
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(consoleWriter)
|
||||
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
|
||||
mainLog.Store(&l)
|
||||
mainLog.Store(&ctrld.Logger{Logger: &l})
|
||||
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
case verbose == 1:
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
case verbose > 1:
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
default:
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
@@ -113,8 +111,6 @@ func initInteractiveLogging() {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
initLoggingWithBackup(false)
|
||||
cfg.Service.LogPath = old
|
||||
l := zerolog.New(io.Discard)
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// initLoggingWithBackup initializes log setup base on current config.
|
||||
@@ -153,9 +149,7 @@ func initLoggingWithBackup(doBackup bool) []io.Writer {
|
||||
writers = append(writers, consoleWriter)
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
// TODO: find a better way.
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
mainLog.Store(&ctrld.Logger{Logger: &l})
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
logLevel := cfg.Service.LogLevel
|
||||
|
||||
@@ -6,12 +6,14 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
var logOutput strings.Builder
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
l := zerolog.New(&logOutput)
|
||||
mainLog.Store(&l)
|
||||
mainLog.Store(&ctrld.Logger{Logger: &l})
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func (p *prog) watchLinkState(ctx context.Context) {
|
||||
@@ -26,7 +28,7 @@ func (p *prog) watchLinkState(ctx context.Context) {
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
uc.ReBootstrap(ctrld.LoggerCtx(ctx, mainLog.Load()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+9
-7
@@ -286,7 +286,7 @@ func (p *prog) postRun() {
|
||||
mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt)
|
||||
}
|
||||
p.resetDNS(false, false)
|
||||
ns := ctrld.InitializeOsResolver(false)
|
||||
ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), mainLog.Load()), false)
|
||||
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
|
||||
p.setDNS()
|
||||
p.csSetDnsDone <- struct{}{}
|
||||
@@ -319,7 +319,8 @@ func (p *prog) apiConfigReload() {
|
||||
}
|
||||
|
||||
doReloadApiConfig := func(forced bool, logger zerolog.Logger) {
|
||||
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
resolverConfig, err := controld.FetchResolverConfig(loggerCtx, cdUID, rootCmd.Version, cdDev)
|
||||
selfUninstallCheck(err, p, logger)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
||||
@@ -377,7 +378,7 @@ func (p *prog) apiConfigReload() {
|
||||
}
|
||||
if cfgErr != nil {
|
||||
logger.Warn().Err(err).Msg("skipping invalid custom config")
|
||||
if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil {
|
||||
if _, err := controld.UpdateCustomLastFailed(loggerCtx, cdUID, rootCmd.Version, cdDev, true); err != nil {
|
||||
logger.Error().Err(err).Msg("could not mark custom last update failed")
|
||||
}
|
||||
return
|
||||
@@ -404,22 +405,23 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
||||
localUpstreams := make([]string, 0, len(cfg.Upstream))
|
||||
ptrNameservers := make([]string, 0, len(cfg.Upstream))
|
||||
isControlDUpstream := false
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load())
|
||||
for n := range cfg.Upstream {
|
||||
uc := cfg.Upstream[n]
|
||||
sdns := uc.Type == ctrld.ResolverTypeSDNS
|
||||
uc.Init()
|
||||
uc.Init(loggerCtx)
|
||||
if sdns {
|
||||
mainLog.Load().Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type)
|
||||
}
|
||||
isControlDUpstream = isControlDUpstream || uc.IsControlD()
|
||||
if uc.BootstrapIP == "" {
|
||||
uc.SetupBootstrapIP()
|
||||
uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), mainLog.Load()))
|
||||
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
||||
} else {
|
||||
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
|
||||
}
|
||||
uc.SetCertPool(rootCertPool)
|
||||
go uc.Ping()
|
||||
go uc.Ping(loggerCtx)
|
||||
|
||||
if canBeLocalUpstream(uc.Domain) {
|
||||
localUpstreams = append(localUpstreams, upstreamPrefix+n)
|
||||
@@ -601,7 +603,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
|
||||
// setupClientInfoDiscover performs necessary works for running client info discover.
|
||||
func (p *prog) setupClientInfoDiscover(selfIP string) {
|
||||
p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers)
|
||||
p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, mainLog.Load())
|
||||
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
||||
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
|
||||
@@ -325,12 +325,13 @@ type ListenerPolicyConfig struct {
|
||||
type Rule map[string][]string
|
||||
|
||||
// Init initialized necessary values for an UpstreamConfig.
|
||||
func (uc *UpstreamConfig) Init() {
|
||||
func (uc *UpstreamConfig) Init(ctx context.Context) {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
if err := uc.initDnsStamps(); err != nil {
|
||||
ProxyLogger.Load().Fatal().Err(err).Msg("invalid DNS Stamps")
|
||||
logger.Fatal().Err(err).Msg("invalid DNS Stamps")
|
||||
}
|
||||
uc.initDoHScheme()
|
||||
uc.uid = upstreamUID()
|
||||
uc.uid = upstreamUID(ctx)
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
uc.Domain = u.Hostname()
|
||||
switch uc.Type {
|
||||
@@ -434,12 +435,13 @@ func (uc *UpstreamConfig) UID() string {
|
||||
// - ControlD Bootstrap DNS 76.76.2.22
|
||||
//
|
||||
// The setup process will block until there's usable IPs found.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) {
|
||||
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
|
||||
isControlD := uc.IsControlD()
|
||||
nss := initDefaultOsResolver()
|
||||
logger := LoggerFromCtx(ctx)
|
||||
nss := initDefaultOsResolver(ctx)
|
||||
for {
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss)
|
||||
uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, nss)
|
||||
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
|
||||
// filtering them out here to prevent weird behavior.
|
||||
if isControlD {
|
||||
@@ -454,18 +456,18 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
uc.bootstrapIPs = uc.bootstrapIPs[:n]
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
uc.bootstrapIPs = bootstrapIPsFromControlDDomain(uc.Domain)
|
||||
ProxyLogger.Load().Warn().Msgf("no record found for %q, lookup from direct IP table", uc.Domain)
|
||||
logger.Warn().Msgf("no record found for %q, lookup from direct IP table", uc.Domain)
|
||||
}
|
||||
}
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
ProxyLogger.Load().Warn().Msgf("no record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP)
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")})
|
||||
logger.Warn().Msgf("no record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP)
|
||||
uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")})
|
||||
|
||||
}
|
||||
if len(uc.bootstrapIPs) > 0 {
|
||||
break
|
||||
}
|
||||
ProxyLogger.Load().Warn().Msg("could not resolve bootstrap IPs, retrying...")
|
||||
logger.Warn().Msg("could not resolve bootstrap IPs, retrying...")
|
||||
b.BackOff(context.Background(), errors.New("no bootstrap IPs"))
|
||||
}
|
||||
for _, ip := range uc.bootstrapIPs {
|
||||
@@ -475,11 +477,11 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip)
|
||||
}
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("bootstrap IPs: %v", uc.bootstrapIPs)
|
||||
logger.Debug().Msgf("bootstrap IPs: %v", uc.bootstrapIPs)
|
||||
}
|
||||
|
||||
// ReBootstrap re-setup the bootstrap IP and the transport.
|
||||
func (uc *UpstreamConfig) ReBootstrap() {
|
||||
func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
default:
|
||||
@@ -487,7 +489,8 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
}
|
||||
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
|
||||
if uc.rebootstrap.CompareAndSwap(false, true) {
|
||||
ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc)
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Msgf("re-bootstrapping upstream ip for %v", uc)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
@@ -495,35 +498,35 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
|
||||
// SetupTransport initializes the network transport used to connect to upstream server.
|
||||
// For now, only DoH upstream is supported.
|
||||
func (uc *UpstreamConfig) SetupTransport() {
|
||||
func (uc *UpstreamConfig) SetupTransport(ctx context.Context) {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH:
|
||||
uc.setupDOHTransport()
|
||||
uc.setupDOHTransport(ctx)
|
||||
case ResolverTypeDOH3:
|
||||
uc.setupDOH3Transport()
|
||||
uc.setupDOH3Transport(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) setupDOHTransport() {
|
||||
func (uc *UpstreamConfig) setupDOHTransport(ctx context.Context) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
if HasIPv6() {
|
||||
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
|
||||
if HasIPv6(ctx) {
|
||||
uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.transport6 = uc.transport4
|
||||
}
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *http.Transport {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.MaxIdleConnsPerHost = 100
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
@@ -543,12 +546,13 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
dialerTimeoutMs = uc.Timeout
|
||||
}
|
||||
dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond
|
||||
logger := LoggerFromCtx(ctx)
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
if uc.BootstrapIP != "" {
|
||||
dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout}
|
||||
addr := net.JoinHostPort(uc.BootstrapIP, port)
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", addr)
|
||||
logger.Debug().Msgf("sending doh request to: %s", addr)
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
pd := &ctrldnet.ParallelDialer{}
|
||||
@@ -558,11 +562,11 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs, ProxyLogger.Load())
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs, logger.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", conn.RemoteAddr())
|
||||
logger.Debug().Msgf("sending doh request to: %s", conn.RemoteAddr())
|
||||
return conn, nil
|
||||
}
|
||||
runtime.SetFinalizer(transport, func(transport *http.Transport) {
|
||||
@@ -572,19 +576,20 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
}
|
||||
|
||||
// Ping warms up the connection to DoH/DoH3 upstream.
|
||||
func (uc *UpstreamConfig) Ping() {
|
||||
if err := uc.ping(); err != nil {
|
||||
ProxyLogger.Load().Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint)
|
||||
_ = uc.FallbackToDirectIP()
|
||||
func (uc *UpstreamConfig) Ping(ctx context.Context) {
|
||||
if err := uc.ping(ctx); err != nil {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint)
|
||||
_ = uc.FallbackToDirectIP(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorPing is like Ping, but return an error if any.
|
||||
func (uc *UpstreamConfig) ErrorPing() error {
|
||||
return uc.ping()
|
||||
func (uc *UpstreamConfig) ErrorPing(ctx context.Context) error {
|
||||
return uc.ping(ctx)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) ping() error {
|
||||
func (uc *UpstreamConfig) ping(ctx context.Context) error {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
default:
|
||||
@@ -613,11 +618,11 @@ func (uc *UpstreamConfig) ping() error {
|
||||
for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH:
|
||||
if err := ping(uc.dohTransport(typ)); err != nil {
|
||||
if err := ping(uc.dohTransport(ctx, typ)); err != nil {
|
||||
return err
|
||||
}
|
||||
case ResolverTypeDOH3:
|
||||
if err := ping(uc.doh3Transport(typ)); err != nil {
|
||||
if err := ping(uc.doh3Transport(ctx, typ)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -652,12 +657,12 @@ func (uc *UpstreamConfig) isNextDNS() bool {
|
||||
return domain == "dns.nextdns.io"
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
|
||||
func (uc *UpstreamConfig) dohTransport(ctx context.Context, dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
})
|
||||
if uc.rebootstrap.CompareAndSwap(true, false) {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
}
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
@@ -673,7 +678,7 @@ func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
|
||||
return uc.transport
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uint16) string {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth:
|
||||
return pick(uc.bootstrapIPs)
|
||||
@@ -686,7 +691,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
case dns.TypeA:
|
||||
return pick(uc.bootstrapIPs4)
|
||||
default:
|
||||
if HasIPv6() {
|
||||
if HasIPv6(ctx) {
|
||||
return pick(uc.bootstrapIPs6)
|
||||
}
|
||||
return pick(uc.bootstrapIPs4)
|
||||
@@ -695,7 +700,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
return pick(uc.bootstrapIPs)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
func (uc *UpstreamConfig) netForDNSType(ctx context.Context, dnsType uint16) (string, string) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth:
|
||||
return "tcp-tls", "udp"
|
||||
@@ -708,7 +713,7 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
case dns.TypeA:
|
||||
return "tcp4-tls", "udp4"
|
||||
default:
|
||||
if HasIPv6() {
|
||||
if HasIPv6(ctx) {
|
||||
return "tcp6-tls", "udp6"
|
||||
}
|
||||
return "tcp4-tls", "udp4"
|
||||
@@ -789,7 +794,7 @@ func (uc *UpstreamConfig) Context(ctx context.Context) (context.Context, context
|
||||
}
|
||||
|
||||
// FallbackToDirectIP changes ControlD upstream endpoint to use direct IP instead of domain.
|
||||
func (uc *UpstreamConfig) FallbackToDirectIP() bool {
|
||||
func (uc *UpstreamConfig) FallbackToDirectIP(ctx context.Context) bool {
|
||||
if !uc.IsControlD() {
|
||||
return false
|
||||
}
|
||||
@@ -808,7 +813,8 @@ func (uc *UpstreamConfig) FallbackToDirectIP() bool {
|
||||
default:
|
||||
return
|
||||
}
|
||||
ProxyLogger.Load().Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip)
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip)
|
||||
uc.u.Host = ip
|
||||
done = true
|
||||
})
|
||||
@@ -942,11 +948,12 @@ func pick(s []string) string {
|
||||
}
|
||||
|
||||
// upstreamUID generates an unique identifier for an upstream.
|
||||
func upstreamUID() string {
|
||||
func upstreamUID(ctx context.Context) string {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
b := make([]byte, 4)
|
||||
for {
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...")
|
||||
logger.Warn().Err(err).Msg("could not generate uid for upstream, retrying...")
|
||||
continue
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
@@ -36,10 +37,10 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Enable parallel tests once https://github.com/microsoft/wmi/issues/165 fixed.
|
||||
// t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.SetupBootstrapIP()
|
||||
tc.uc.Init(context.Background())
|
||||
tc.uc.SetupBootstrapIP(context.Background())
|
||||
if len(tc.uc.bootstrapIPs) == 0 {
|
||||
t.Log(defaultNameservers())
|
||||
t.Log(defaultNameservers(context.Background()))
|
||||
t.Fatalf("could not bootstrap ip: %s", tc.uc.String())
|
||||
}
|
||||
})
|
||||
@@ -355,7 +356,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.Init(context.Background())
|
||||
tc.uc.uid = "" // we don't care about the uid.
|
||||
assert.Equal(t, tc.expected, tc.uc)
|
||||
})
|
||||
@@ -497,7 +498,7 @@ func TestUpstreamConfig_IsDiscoverable(t *testing.T) {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.Init(context.Background())
|
||||
if got := tc.uc.IsDiscoverable(); got != tc.discoverable {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got)
|
||||
}
|
||||
|
||||
+15
-14
@@ -14,34 +14,35 @@ import (
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
func (uc *UpstreamConfig) setupDOH3Transport(ctx context.Context) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
if HasIPv6() {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
|
||||
if HasIPv6(ctx) {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
}
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper {
|
||||
rt := &http3.Transport{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
logger := LoggerFromCtx(ctx)
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
||||
if uc.BootstrapIP != "" {
|
||||
addr = net.JoinHostPort(uc.BootstrapIP, port)
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", addr)
|
||||
logger.Debug().Msgf("sending doh3 request to: %s", addr)
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -61,7 +62,7 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
|
||||
logger.Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
|
||||
return conn, err
|
||||
}
|
||||
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
|
||||
@@ -70,12 +71,12 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
return rt
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
|
||||
func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
})
|
||||
if uc.rebootstrap.CompareAndSwap(true, false) {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
}
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
|
||||
@@ -105,19 +105,20 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
if len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
c := http.Client{Transport: r.uc.dohTransport(dnsTyp)}
|
||||
c := http.Client{Transport: r.uc.dohTransport(ctx, dnsTyp)}
|
||||
if r.isDoH3 {
|
||||
transport := r.uc.doh3Transport(dnsTyp)
|
||||
transport := r.uc.doh3Transport(ctx, dnsTyp)
|
||||
if transport == nil {
|
||||
return nil, errors.New("DoH3 is not supported")
|
||||
}
|
||||
c.Transport = transport
|
||||
}
|
||||
resp, err := c.Do(req)
|
||||
if err != nil && r.uc.FallbackToDirectIP() {
|
||||
if err != nil && r.uc.FallbackToDirectIP(ctx) {
|
||||
retryCtx, cancel := r.uc.Context(context.WithoutCancel(ctx))
|
||||
defer cancel()
|
||||
Log(ctx, ProxyLogger.Load().Warn().Err(err), "retrying request after fallback to direct ip")
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Warn().Err(err).Msg("retrying request after fallback to direct ip")
|
||||
resp, err = c.Do(req.Clone(retryCtx))
|
||||
}
|
||||
if err != nil {
|
||||
@@ -163,7 +164,8 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
|
||||
}
|
||||
}
|
||||
if printed {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending request header: %v", dohHeader)
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Msgf("sending request header: %v", dohHeader)
|
||||
}
|
||||
dohHeader.Set("Content-Type", headerApplicationDNS)
|
||||
dohHeader.Set("Accept", headerApplicationDNS)
|
||||
|
||||
+5
-4
@@ -157,20 +157,21 @@ func Test_ClientCertificateVerificationError(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.SetupBootstrapIP()
|
||||
r, err := NewResolver(tc.uc)
|
||||
tc.uc.Init(ctx)
|
||||
tc.uc.SetupBootstrapIP(ctx)
|
||||
r, err := NewResolver(ctx, tc.uc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("verify.controld.com.", dns.TypeA)
|
||||
msg.RecursionDesired = true
|
||||
_, err = r.Resolve(context.Background(), msg)
|
||||
_, err = r.Resolve(ctx, msg)
|
||||
// Verify the error contains the expected certificate information
|
||||
if err == nil {
|
||||
t.Fatal("expected certificate verification error, got nil")
|
||||
|
||||
@@ -26,7 +26,7 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
ip = r.uc.bootstrapIPForDNSType(dnsTyp)
|
||||
ip = r.uc.bootstrapIPForDNSType(ctx, dnsTyp)
|
||||
}
|
||||
tlsConfig.ServerName = r.uc.Domain
|
||||
_, port, _ := net.SplitHostPort(endpoint)
|
||||
|
||||
@@ -23,7 +23,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
|
||||
tcpNet, _ := r.uc.netForDNSType(ctx, dnsTyp)
|
||||
dnsClient := &dns.Client{
|
||||
Net: tcpNet,
|
||||
Dialer: dialer,
|
||||
|
||||
@@ -79,6 +79,7 @@ type Table struct {
|
||||
initOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
refreshInterval int
|
||||
logger *ctrld.Logger
|
||||
|
||||
dhcp *dhcp
|
||||
merlin *merlinDiscover
|
||||
@@ -98,11 +99,14 @@ type Table struct {
|
||||
ptrNameservers []string
|
||||
}
|
||||
|
||||
func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
|
||||
func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string, logger *ctrld.Logger) *Table {
|
||||
refreshInterval := cfg.Service.DiscoverRefreshInterval
|
||||
if refreshInterval <= 0 {
|
||||
refreshInterval = 2 * 60 // 2 minutes
|
||||
}
|
||||
if logger == nil {
|
||||
logger = ctrld.NopLogger
|
||||
}
|
||||
return &Table{
|
||||
svcCfg: cfg.Service,
|
||||
quitCh: make(chan struct{}),
|
||||
@@ -111,6 +115,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
|
||||
cdUID: cdUID,
|
||||
ptrNameservers: ns,
|
||||
refreshInterval: refreshInterval,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,7 +184,7 @@ func (t *Table) SetSelfIP(ip string) {
|
||||
|
||||
// initSelfDiscover initializes necessary client metadata for self query.
|
||||
func (t *Table) initSelfDiscover() {
|
||||
t.dhcp = &dhcp{selfIP: t.selfIP}
|
||||
t.dhcp = &dhcp{selfIP: t.selfIP, logger: t.logger}
|
||||
t.dhcp.addSelf()
|
||||
t.ipResolvers = append(t.ipResolvers, t.dhcp)
|
||||
t.macResolvers = append(t.macResolvers, t.dhcp)
|
||||
@@ -189,14 +194,14 @@ func (t *Table) initSelfDiscover() {
|
||||
func (t *Table) init() {
|
||||
// Custom client ID presents, use it as the only source.
|
||||
if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" {
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start self discovery with custom client id")
|
||||
t.logger.Debug().Msg("start self discovery with custom client id")
|
||||
t.initSelfDiscover()
|
||||
return
|
||||
}
|
||||
|
||||
// If we are running on platforms that should only do self discover, use it as the only source, too.
|
||||
if ctrld.SelfDiscover() {
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start self discovery on desktop platforms")
|
||||
t.logger.Debug().Msg("start self discovery on desktop platforms")
|
||||
t.initSelfDiscover()
|
||||
return
|
||||
}
|
||||
@@ -208,7 +213,7 @@ func (t *Table) init() {
|
||||
// - Merlin
|
||||
// - Ubios
|
||||
if t.discoverDHCP() || t.discoverARP() {
|
||||
t.merlin = &merlinDiscover{}
|
||||
t.merlin = &merlinDiscover{logger: t.logger}
|
||||
t.ubios = &ubiosDiscover{}
|
||||
discovers := map[string]interface {
|
||||
refresher
|
||||
@@ -219,7 +224,7 @@ func (t *Table) init() {
|
||||
}
|
||||
for platform, discover := range discovers {
|
||||
if err := discover.refresh(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to init %s discover", platform)
|
||||
t.logger.Warn().Err(err).Msgf("failed to init %s discover", platform)
|
||||
}
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, discover)
|
||||
t.refreshers = append(t.refreshers, discover)
|
||||
@@ -227,10 +232,10 @@ func (t *Table) init() {
|
||||
}
|
||||
// Hosts file mapping.
|
||||
if t.discoverHosts() {
|
||||
t.hf = &hostsFile{}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start hosts file discovery")
|
||||
t.hf = &hostsFile{logger: t.logger}
|
||||
t.logger.Debug().Msg("start hosts file discovery")
|
||||
if err := t.hf.init(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init hosts file discover")
|
||||
t.logger.Error().Err(err).Msg("could not init hosts file discover")
|
||||
} else {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.hf)
|
||||
t.refreshers = append(t.refreshers, t.hf)
|
||||
@@ -239,10 +244,10 @@ func (t *Table) init() {
|
||||
}
|
||||
// DHCP lease files.
|
||||
if t.discoverDHCP() {
|
||||
t.dhcp = &dhcp{selfIP: t.selfIP}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start dhcp discovery")
|
||||
t.dhcp = &dhcp{selfIP: t.selfIP, logger: t.logger}
|
||||
t.logger.Debug().Msg("start dhcp discovery")
|
||||
if err := t.dhcp.init(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init DHCP discover")
|
||||
t.logger.Error().Err(err).Msg("could not init DHCP discover")
|
||||
} else {
|
||||
t.ipResolvers = append(t.ipResolvers, t.dhcp)
|
||||
t.macResolvers = append(t.macResolvers, t.dhcp)
|
||||
@@ -253,8 +258,8 @@ func (t *Table) init() {
|
||||
// ARP/NDP table.
|
||||
if t.discoverARP() {
|
||||
t.arp = &arpDiscover{}
|
||||
t.ndp = &ndpDiscover{}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start arp discovery")
|
||||
t.ndp = &ndpDiscover{logger: t.logger}
|
||||
t.logger.Debug().Msg("start arp discovery")
|
||||
discovers := map[string]interface {
|
||||
refresher
|
||||
IpResolver
|
||||
@@ -266,7 +271,7 @@ func (t *Table) init() {
|
||||
|
||||
for protocol, discover := range discovers {
|
||||
if err := discover.refresh(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msgf("could not init %s discover", protocol)
|
||||
t.logger.Error().Err(err).Msgf("could not init %s discover", protocol)
|
||||
} else {
|
||||
t.ipResolvers = append(t.ipResolvers, discover)
|
||||
t.macResolvers = append(t.macResolvers, discover)
|
||||
@@ -283,7 +288,10 @@ func (t *Table) init() {
|
||||
}
|
||||
// PTR lookup.
|
||||
if t.discoverPTR() {
|
||||
t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()}
|
||||
t.ptr = &ptrDiscover{
|
||||
resolver: ctrld.NewPrivateResolver(context.Background()),
|
||||
logger: t.logger,
|
||||
}
|
||||
if len(t.ptrNameservers) > 0 {
|
||||
nss := make([]string, 0, len(t.ptrNameservers))
|
||||
for _, ns := range t.ptrNameservers {
|
||||
@@ -295,18 +303,18 @@ func (t *Table) init() {
|
||||
if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil {
|
||||
nss = append(nss, net.JoinHostPort(host, port))
|
||||
} else {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns)
|
||||
t.logger.Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns)
|
||||
}
|
||||
}
|
||||
if len(nss) > 0 {
|
||||
t.ptr.resolver = ctrld.NewResolverWithNameserver(nss)
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss)
|
||||
t.logger.Debug().Msgf("using nameservers %v for ptr discovery", nss)
|
||||
}
|
||||
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery")
|
||||
t.logger.Debug().Msg("start ptr discovery")
|
||||
if err := t.ptr.refresh(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init PTR discover")
|
||||
t.logger.Error().Err(err).Msg("could not init PTR discover")
|
||||
} else {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.ptr)
|
||||
t.refreshers = append(t.refreshers, t.ptr)
|
||||
@@ -314,10 +322,10 @@ func (t *Table) init() {
|
||||
}
|
||||
// mdns.
|
||||
if t.discoverMDNS() {
|
||||
t.mdns = &mdns{}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start mdns discovery")
|
||||
t.mdns = &mdns{logger: t.logger}
|
||||
t.logger.Debug().Msg("start mdns discovery")
|
||||
if err := t.mdns.init(t.quitCh); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init mDNS discover")
|
||||
t.logger.Error().Err(err).Msg("could not init mDNS discover")
|
||||
} else {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.mdns)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package clientinfo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_normalizeIP(t *testing.T) {
|
||||
@@ -28,8 +30,9 @@ func Test_normalizeIP(t *testing.T) {
|
||||
|
||||
func TestTable_LookupRFC1918IPv4(t *testing.T) {
|
||||
table := &Table{
|
||||
dhcp: &dhcp{},
|
||||
arp: &arpDiscover{},
|
||||
dhcp: &dhcp{},
|
||||
arp: &arpDiscover{},
|
||||
logger: ctrld.NopLogger,
|
||||
}
|
||||
|
||||
table.ipResolvers = append(table.ipResolvers, table.dhcp)
|
||||
|
||||
+10
-10
@@ -13,9 +13,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/util/lineread"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
@@ -30,6 +29,7 @@ type dhcp struct {
|
||||
|
||||
watcher *fsnotify.Watcher
|
||||
selfIP string
|
||||
logger *ctrld.Logger
|
||||
}
|
||||
|
||||
func (d *dhcp) init() error {
|
||||
@@ -52,7 +52,7 @@ func (d *dhcp) watchChanges() {
|
||||
}
|
||||
if dir := router.LeaseFilesDir(); dir != "" {
|
||||
if err := d.watcher.Add(dir); err != nil {
|
||||
ctrld.ProxyLogger.Load().Err(err).Str("dir", dir).Msg("could not watch lease dir")
|
||||
d.logger.Err(err).Str("dir", dir).Msg("could not watch lease dir")
|
||||
}
|
||||
}
|
||||
for {
|
||||
@@ -64,7 +64,7 @@ func (d *dhcp) watchChanges() {
|
||||
if event.Has(fsnotify.Create) {
|
||||
if format, ok := clientInfoFiles[event.Name]; ok {
|
||||
if err := d.addLeaseFile(event.Name, format); err != nil {
|
||||
ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("could not add lease file")
|
||||
d.logger.Err(err).Str("file", event.Name).Msg("could not add lease file")
|
||||
}
|
||||
}
|
||||
continue
|
||||
@@ -72,14 +72,14 @@ func (d *dhcp) watchChanges() {
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) {
|
||||
format := clientInfoFiles[event.Name]
|
||||
if err := d.readLeaseFile(event.Name, format); err != nil && !os.IsNotExist(err) {
|
||||
ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("leases file changed but failed to update client info")
|
||||
d.logger.Err(err).Str("file", event.Name).Msg("leases file changed but failed to update client info")
|
||||
}
|
||||
}
|
||||
case err, ok := <-d.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file")
|
||||
d.logger.Err(err).Msg("could not watch client info file")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,7 +222,7 @@ func (d *dhcp) dnsmasqReadClientInfoReader(reader io.Reader) error {
|
||||
}
|
||||
ip := normalizeIP(string(fields[2]))
|
||||
if net.ParseIP(ip) == nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip)
|
||||
d.logger.Warn().Msgf("invalid ip address entry: %q", ip)
|
||||
ip = ""
|
||||
}
|
||||
|
||||
@@ -275,7 +275,7 @@ func (d *dhcp) iscDHCPReadClientInfoReader(reader io.Reader) error {
|
||||
case "lease":
|
||||
ip = normalizeIP(strings.ToLower(fields[1]))
|
||||
if net.ParseIP(ip) == nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip)
|
||||
d.logger.Warn().Msgf("invalid ip address entry: %q", ip)
|
||||
ip = ""
|
||||
}
|
||||
case "hardware":
|
||||
@@ -328,7 +328,7 @@ func (d *dhcp) keaDhcp4ReadClientInfoReader(r io.Reader) error {
|
||||
}
|
||||
ip := normalizeIP(record[0])
|
||||
if net.ParseIP(ip) == nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip)
|
||||
d.logger.Warn().Msgf("invalid ip address entry: %q", ip)
|
||||
ip = ""
|
||||
}
|
||||
|
||||
@@ -350,7 +350,7 @@ func (d *dhcp) keaDhcp4ReadClientInfoReader(r io.Reader) error {
|
||||
func (d *dhcp) addSelf() {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Err(err).Msg("could not get hostname")
|
||||
d.logger.Err(err).Msg("could not get hostname")
|
||||
return
|
||||
}
|
||||
hostname = normalizeHostname(hostname)
|
||||
|
||||
@@ -27,6 +27,7 @@ type hostsFile struct {
|
||||
watcher *fsnotify.Watcher
|
||||
mu sync.Mutex
|
||||
m map[string][]string
|
||||
logger *ctrld.Logger
|
||||
}
|
||||
|
||||
// init performs initialization works, which is necessary before hostsFile can be fully operated.
|
||||
@@ -55,7 +56,7 @@ func (hf *hostsFile) refresh() error {
|
||||
// override hosts file with host_entries.conf content if present.
|
||||
hem, err := parseHostEntriesConf(hostEntriesConfPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not read host_entries.conf file")
|
||||
hf.logger.Debug().Err(err).Msg("could not read host_entries.conf file")
|
||||
}
|
||||
for k, v := range hem {
|
||||
hf.m[k] = v
|
||||
@@ -77,14 +78,14 @@ func (hf *hostsFile) watchChanges() {
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) {
|
||||
if err := hf.refresh(); err != nil && !os.IsNotExist(err) {
|
||||
ctrld.ProxyLogger.Load().Err(err).Msg("hosts file changed but failed to update client info")
|
||||
hf.logger.Err(err).Msg("hosts file changed but failed to update client info")
|
||||
}
|
||||
}
|
||||
case err, ok := <-hf.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file")
|
||||
hf.logger.Err(err).Msg("could not watch client info file")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+12
-11
@@ -34,7 +34,8 @@ var (
|
||||
)
|
||||
|
||||
type mdns struct {
|
||||
name sync.Map // ip => hostname
|
||||
name sync.Map // ip => hostname
|
||||
logger *ctrld.Logger
|
||||
}
|
||||
|
||||
func (m *mdns) LookupHostnameByIP(ip string) string {
|
||||
@@ -93,9 +94,9 @@ func (m *mdns) init(quitCh chan struct{}) error {
|
||||
}
|
||||
|
||||
// Check if IPv6 is available once and use the result for the rest of the function.
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("checking for IPv6 availability in mdns init")
|
||||
m.logger.Debug().Msgf("checking for IPv6 availability in mdns init")
|
||||
ipv6 := ctrldnet.IPv6Available(context.Background())
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("IPv6 is %v in mdns init", ipv6)
|
||||
m.logger.Debug().Msgf("IPv6 is %v in mdns init", ipv6)
|
||||
|
||||
v4ConnList := make([]*net.UDPConn, 0, len(ifaces))
|
||||
v6ConnList := make([]*net.UDPConn, 0, len(ifaces))
|
||||
@@ -129,11 +130,11 @@ func (m *mdns) probeLoop(conns []*net.UDPConn, remoteAddr net.Addr, quitCh chan
|
||||
for {
|
||||
err := m.probe(conns, remoteAddr)
|
||||
if shouldStopProbing(err) {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: %v", remoteAddr, err)
|
||||
m.logger.Warn().Msgf("stop probing %q: %v", remoteAddr, err)
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Err(err).Msg("error while probing mdns")
|
||||
m.logger.Warn().Err(err).Msg("error while probing mdns")
|
||||
bo.BackOff(context.Background(), errors.New("mdns probe backoff"))
|
||||
continue
|
||||
}
|
||||
@@ -161,7 +162,7 @@ func (m *mdns) readLoop(conn *net.UDPConn) {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("mdns readLoop error")
|
||||
m.logger.Debug().Err(err).Msg("mdns readLoop error")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -184,11 +185,11 @@ func (m *mdns) readLoop(conn *net.UDPConn) {
|
||||
if ip != "" && name != "" {
|
||||
name = normalizeHostname(name)
|
||||
if val, loaded := m.name.LoadOrStore(ip, name); !loaded {
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("found hostname: %q, ip: %q via mdns", name, ip)
|
||||
m.logger.Debug().Msgf("found hostname: %q, ip: %q via mdns", name, ip)
|
||||
} else {
|
||||
old := val.(string)
|
||||
if old != name {
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("update hostname: %q, ip: %q, old: %q via mdns", name, ip, old)
|
||||
m.logger.Debug().Msgf("update hostname: %q, ip: %q, old: %q via mdns", name, ip, old)
|
||||
m.name.Store(ip, name)
|
||||
}
|
||||
}
|
||||
@@ -227,7 +228,7 @@ func (m *mdns) probe(conns []*net.UDPConn, remoteAddr net.Addr) error {
|
||||
// getDataFromAvahiDaemonCache reads entries from avahi-daemon cache to update mdns data.
|
||||
func (m *mdns) getDataFromAvahiDaemonCache() {
|
||||
if _, err := exec.LookPath("avahi-browse"); err != nil {
|
||||
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not find avahi-browse binary, skipping.")
|
||||
m.logger.Debug().Err(err).Msg("could not find avahi-browse binary, skipping.")
|
||||
return
|
||||
}
|
||||
// Run avahi-browse to discover services from cache:
|
||||
@@ -237,7 +238,7 @@ func (m *mdns) getDataFromAvahiDaemonCache() {
|
||||
// - "-c" -> read from cache.
|
||||
out, err := exec.Command("avahi-browse", "-a", "-r", "-p", "-c").Output()
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not browse services from avahi cache")
|
||||
m.logger.Debug().Err(err).Msg("could not browse services from avahi cache")
|
||||
return
|
||||
}
|
||||
m.storeDataFromAvahiBrowseOutput(bytes.NewReader(out))
|
||||
@@ -257,7 +258,7 @@ func (m *mdns) storeDataFromAvahiBrowseOutput(r io.Reader) {
|
||||
name := normalizeHostname(fields[6])
|
||||
// Only using cache value if we don't have existed one.
|
||||
if _, loaded := m.name.LoadOrStore(ip, name); !loaded {
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("found hostname: %q, ip: %q via avahi cache", name, ip)
|
||||
m.logger.Debug().Msgf("found hostname: %q, ip: %q via avahi cache", name, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package clientinfo
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_mdns_storeDataFromAvahiBrowseOutput(t *testing.T) {
|
||||
@@ -11,7 +13,7 @@ func Test_mdns_storeDataFromAvahiBrowseOutput(t *testing.T) {
|
||||
=;wlp0s20f3;IPv6;Foo\032\0402\041;_companion-link._tcp;local;Foo-2.local;192.168.1.123;64842;"rpBA=00:00:00:00:00:01" "rpHI=e6ae2cbbca0e" "rpAD=36566f4d850f" "rpVr=510.71.1" "rpHA=0ddc20fdddc8" "rpFl=0x30000" "rpHN=1d4a03afdefa" "rpMac=0"
|
||||
=;wlp0s20f3;IPv4;Foo\032\0402\041;_companion-link._tcp;local;Foo-2.local;192.168.1.123;64842;"rpBA=00:00:00:00:00:01" "rpHI=e6ae2cbbca0e" "rpAD=36566f4d850f" "rpVr=510.71.1" "rpHA=0ddc20fdddc8" "rpFl=0x30000" "rpHN=1d4a03afdefa" "rpMac=0"
|
||||
`
|
||||
m := &mdns{}
|
||||
m := &mdns{logger: ctrld.NopLogger}
|
||||
m.storeDataFromAvahiBrowseOutput(strings.NewReader(content))
|
||||
ip := "192.168.1.123"
|
||||
val, loaded := m.name.LoadOrStore(ip, "")
|
||||
|
||||
@@ -15,6 +15,7 @@ const merlinNvramCustomClientListKey = "custom_clientlist"
|
||||
|
||||
type merlinDiscover struct {
|
||||
hostname sync.Map // mac => hostname
|
||||
logger *ctrld.Logger
|
||||
}
|
||||
|
||||
func (m *merlinDiscover) refresh() error {
|
||||
@@ -25,7 +26,7 @@ func (m *merlinDiscover) refresh() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("reading Merlin custom client list")
|
||||
m.logger.Debug().Msg("reading Merlin custom client list")
|
||||
m.parseMerlinCustomClientList(out)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,8 +20,9 @@ import (
|
||||
|
||||
// ndpDiscover provides client discovery functionality using NDP protocol.
|
||||
type ndpDiscover struct {
|
||||
mac sync.Map // ip => mac
|
||||
ip sync.Map // mac => ip
|
||||
mac sync.Map // ip => mac
|
||||
ip sync.Map // mac => ip
|
||||
logger *ctrld.Logger
|
||||
}
|
||||
|
||||
// refresh re-scans the NDP table.
|
||||
@@ -97,7 +98,7 @@ func (nd *ndpDiscover) saveInfo(ip, mac string) {
|
||||
func (nd *ndpDiscover) listen(ctx context.Context) {
|
||||
ifis, err := allInterfacesWithV6LinkLocal()
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("failed to find valid ipv6 interfaces")
|
||||
nd.logger.Debug().Err(err).Msg("failed to find valid ipv6 interfaces")
|
||||
return
|
||||
}
|
||||
for _, ifi := range ifis {
|
||||
@@ -110,11 +111,11 @@ func (nd *ndpDiscover) listen(ctx context.Context) {
|
||||
func (nd *ndpDiscover) listenOnInterface(ctx context.Context, ifi *net.Interface) {
|
||||
c, ip, err := ndp.Listen(ifi, ndp.Unspecified)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("ndp listen failed")
|
||||
nd.logger.Debug().Err(err).Msg("ndp listen failed")
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("listening ndp on: %s", ip.String())
|
||||
nd.logger.Debug().Msgf("listening ndp on: %s", ip.String())
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -128,7 +129,7 @@ func (nd *ndpDiscover) listenOnInterface(ctx context.Context, ifi *net.Interface
|
||||
if errors.As(readErr, &opErr) && (opErr.Timeout() || opErr.Temporary()) {
|
||||
continue
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Err(readErr).Msg("ndp read loop error")
|
||||
nd.logger.Debug().Err(readErr).Msg("ndp read loop error")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -5,15 +5,13 @@ import (
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// scan populates NDP table using information from system mappings.
|
||||
func (nd *ndpDiscover) scan() {
|
||||
neighs, err := netlink.NeighList(0, netlink.FAMILY_V6)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not get neigh list")
|
||||
nd.logger.Warn().Err(err).Msg("could not get neigh list")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -34,7 +32,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
if err := netlink.NeighSubscribe(ch, done); err != nil {
|
||||
ctrld.ProxyLogger.Load().Err(err).Msg("could not perform neighbor subscribing")
|
||||
nd.logger.Err(err).Msg("could not perform neighbor subscribing")
|
||||
return
|
||||
}
|
||||
for {
|
||||
@@ -47,7 +45,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) {
|
||||
}
|
||||
ip := normalizeIP(nu.IP.String())
|
||||
if nu.Type == unix.RTM_DELNEIGH {
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor: %s", ip)
|
||||
nd.logger.Debug().Msgf("removing NDP neighbor: %s", ip)
|
||||
nd.mac.Delete(ip)
|
||||
continue
|
||||
}
|
||||
@@ -56,7 +54,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) {
|
||||
case netlink.NUD_REACHABLE:
|
||||
nd.saveInfo(ip, mac)
|
||||
case netlink.NUD_FAILED:
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor with failed state: %s", ip)
|
||||
nd.logger.Debug().Msgf("removing NDP neighbor with failed state: %s", ip)
|
||||
nd.mac.Delete(ip)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@ import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// scan populates NDP table using information from system mappings.
|
||||
@@ -17,14 +15,14 @@ func (nd *ndpDiscover) scan() {
|
||||
case "windows":
|
||||
data, err := exec.Command("netsh", "interface", "ipv6", "show", "neighbors").Output()
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not query ndp table")
|
||||
nd.logger.Warn().Err(err).Msg("could not query ndp table")
|
||||
return
|
||||
}
|
||||
nd.scanWindows(bytes.NewReader(data))
|
||||
default:
|
||||
data, err := exec.Command("ndp", "-an").Output()
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not query ndp table")
|
||||
nd.logger.Warn().Err(err).Msg("could not query ndp table")
|
||||
return
|
||||
}
|
||||
nd.scanUnix(bytes.NewReader(data))
|
||||
|
||||
@@ -17,6 +17,7 @@ type ptrDiscover struct {
|
||||
hostname sync.Map // ip => hostname
|
||||
resolver ctrld.Resolver
|
||||
serverDown atomic.Bool
|
||||
logger *ctrld.Logger
|
||||
}
|
||||
|
||||
func (p *ptrDiscover) refresh() error {
|
||||
@@ -73,14 +74,14 @@ func (p *ptrDiscover) lookupHostname(ip string) string {
|
||||
msg := new(dns.Msg)
|
||||
addr, err := dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address")
|
||||
p.logger.Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address")
|
||||
return ""
|
||||
}
|
||||
msg.SetQuestion(addr, dns.TypePTR)
|
||||
ans, err := p.resolver.Resolve(ctx, msg)
|
||||
if err != nil {
|
||||
if p.serverDown.CompareAndSwap(false, true) {
|
||||
ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup")
|
||||
p.logger.Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup")
|
||||
go p.checkServer()
|
||||
}
|
||||
return ""
|
||||
|
||||
+21
-18
@@ -88,18 +88,18 @@ type LogsRequest struct {
|
||||
}
|
||||
|
||||
// FetchResolverConfig fetch Control D config for given uid.
|
||||
func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) {
|
||||
func FetchResolverConfig(ctx context.Context, rawUID, version string, cdDev bool) (*ResolverConfig, error) {
|
||||
uid, clientID := ParseRawUID(rawUID)
|
||||
req := utilityRequest{UID: uid}
|
||||
if clientID != "" {
|
||||
req.ClientID = clientID
|
||||
}
|
||||
body, _ := json.Marshal(req)
|
||||
return postUtilityAPI(version, cdDev, false, bytes.NewReader(body))
|
||||
return postUtilityAPI(ctx, version, cdDev, false, bytes.NewReader(body))
|
||||
}
|
||||
|
||||
// FetchResolverUID fetch resolver uid from provision token.
|
||||
func FetchResolverUID(req *UtilityOrgRequest, version string, cdDev bool) (*ResolverConfig, error) {
|
||||
func FetchResolverUID(ctx context.Context, req *UtilityOrgRequest, version string, cdDev bool) (*ResolverConfig, error) {
|
||||
if req == nil {
|
||||
return nil, errors.New("invalid request")
|
||||
}
|
||||
@@ -108,21 +108,21 @@ func FetchResolverUID(req *UtilityOrgRequest, version string, cdDev bool) (*Reso
|
||||
hostname, _ = os.Hostname()
|
||||
}
|
||||
body, _ := json.Marshal(UtilityOrgRequest{ProvToken: req.ProvToken, Hostname: hostname})
|
||||
return postUtilityAPI(version, cdDev, false, bytes.NewReader(body))
|
||||
return postUtilityAPI(ctx, version, cdDev, false, bytes.NewReader(body))
|
||||
}
|
||||
|
||||
// UpdateCustomLastFailed calls API to mark custom config is bad.
|
||||
func UpdateCustomLastFailed(rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) {
|
||||
func UpdateCustomLastFailed(ctx context.Context, rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) {
|
||||
uid, clientID := ParseRawUID(rawUID)
|
||||
req := utilityRequest{UID: uid}
|
||||
if clientID != "" {
|
||||
req.ClientID = clientID
|
||||
}
|
||||
body, _ := json.Marshal(req)
|
||||
return postUtilityAPI(version, cdDev, true, bytes.NewReader(body))
|
||||
return postUtilityAPI(ctx, version, cdDev, true, bytes.NewReader(body))
|
||||
}
|
||||
|
||||
func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) {
|
||||
func postUtilityAPI(ctx context.Context, version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) {
|
||||
apiUrl := resolverDataURLCom
|
||||
if cdDev {
|
||||
apiUrl = resolverDataURLDev
|
||||
@@ -139,12 +139,12 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
transport := apiTransport(cdDev)
|
||||
transport := apiTransport(ctx, cdDev)
|
||||
client := &http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
resp, err := doWithFallback(client, req, apiServerIP(cdDev))
|
||||
resp, err := doWithFallback(ctx, client, req, apiServerIP(cdDev))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postUtilityAPI client.Do: %w", err)
|
||||
}
|
||||
@@ -166,7 +166,7 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade
|
||||
}
|
||||
|
||||
// SendLogs sends runtime log to ControlD API.
|
||||
func SendLogs(lr *LogsRequest, cdDev bool) error {
|
||||
func SendLogs(ctx context.Context, lr *LogsRequest, cdDev bool) error {
|
||||
defer lr.Data.Close()
|
||||
apiUrl := logURLCom
|
||||
if cdDev {
|
||||
@@ -180,12 +180,12 @@ func SendLogs(lr *LogsRequest, cdDev bool) error {
|
||||
q.Set("uid", lr.UID)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
transport := apiTransport(cdDev)
|
||||
transport := apiTransport(ctx, cdDev)
|
||||
client := &http.Client{
|
||||
Timeout: sendLogTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
resp, err := doWithFallback(client, req, apiServerIP(cdDev))
|
||||
resp, err := doWithFallback(ctx, client, req, apiServerIP(cdDev))
|
||||
if err != nil {
|
||||
return fmt.Errorf("SendLogs client.Do: %w", err)
|
||||
}
|
||||
@@ -213,7 +213,7 @@ func ParseRawUID(rawUID string) (string, string) {
|
||||
}
|
||||
|
||||
// apiTransport returns an HTTP transport for connecting to ControlD API endpoint.
|
||||
func apiTransport(cdDev bool) *http.Transport {
|
||||
func apiTransport(loggerCtx context.Context, cdDev bool) *http.Transport {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
apiDomain := apiDomainCom
|
||||
@@ -227,9 +227,10 @@ func apiTransport(cdDev bool) *http.Transport {
|
||||
apiIPs = []string{apiDomainDevIPv4}
|
||||
}
|
||||
|
||||
ips := ctrld.LookupIP(apiDomain)
|
||||
ips := ctrld.LookupIP(loggerCtx, apiDomain)
|
||||
if len(ips) == 0 {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, use direct IPs: %v", apiDomain, apiIPs)
|
||||
logger := ctrld.LoggerFromCtx(loggerCtx)
|
||||
logger.Warn().Msgf("No IPs found for %s, use direct IPs: %v", apiDomain, apiIPs)
|
||||
ips = apiIPs
|
||||
}
|
||||
|
||||
@@ -245,7 +246,8 @@ func apiTransport(cdDev bool) *http.Transport {
|
||||
|
||||
dial := func(ctx context.Context, network string, addrs []string) (net.Conn, error) {
|
||||
d := &ctrldnet.ParallelDialer{}
|
||||
return d.DialContext(ctx, network, addrs, ctrld.ProxyLogger.Load())
|
||||
logger := ctrld.LoggerFromCtx(loggerCtx)
|
||||
return d.DialContext(ctx, network, addrs, logger.Logger)
|
||||
}
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
|
||||
@@ -283,10 +285,11 @@ func addrsFromPort(ips []string, port string) []string {
|
||||
return addrs
|
||||
}
|
||||
|
||||
func doWithFallback(client *http.Client, req *http.Request, apiIp string) (*http.Response, error) {
|
||||
func doWithFallback(ctx context.Context, client *http.Client, req *http.Request, apiIp string) (*http.Response, error) {
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to send request, fallback to direct IP: %s", apiIp)
|
||||
logger := ctrld.LoggerFromCtx(ctx)
|
||||
logger.Warn().Err(err).Msgf("failed to send request, fallback to direct IP: %s", apiIp)
|
||||
ipReq := req.Clone(req.Context())
|
||||
ipReq.Host = apiIp
|
||||
ipReq.URL.Host = apiIp
|
||||
|
||||
@@ -3,19 +3,37 @@ package ctrld
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// ProxyLog emits the log record for proxy operations.
|
||||
// The caller should set it only once.
|
||||
// DEPRECATED: use ProxyLogger instead.
|
||||
var ProxyLog = zerolog.New(io.Discard)
|
||||
// LoggerCtxKey is the context.Context key for a logger.
|
||||
type LoggerCtxKey struct{}
|
||||
|
||||
// ProxyLogger emits the log record for proxy operations.
|
||||
var ProxyLogger atomic.Pointer[zerolog.Logger]
|
||||
// LoggerCtx returns a context.Context with LoggerCtxKey set.
|
||||
func LoggerCtx(ctx context.Context, l *Logger) context.Context {
|
||||
return context.WithValue(ctx, LoggerCtxKey{}, l)
|
||||
}
|
||||
|
||||
// A Logger provides fast, leveled, structured logging.
|
||||
type Logger struct {
|
||||
*zerolog.Logger
|
||||
}
|
||||
|
||||
var noOpZeroLogger = zerolog.Nop()
|
||||
|
||||
// NopLogger returns a logger which all operation are no-op.
|
||||
var NopLogger = &Logger{&noOpZeroLogger}
|
||||
|
||||
// LoggerFromCtx returns the logger associated with given ctx.
|
||||
//
|
||||
// If there's no logger, a no-op logger will be returned.
|
||||
func LoggerFromCtx(ctx context.Context) *Logger {
|
||||
if logger, ok := ctx.Value(LoggerCtxKey{}).(*Logger); ok && logger != nil {
|
||||
return logger
|
||||
}
|
||||
return NopLogger
|
||||
}
|
||||
|
||||
// ReqIdCtxKey is the context.Context key for a request id.
|
||||
type ReqIdCtxKey struct{}
|
||||
|
||||
+5
-3
@@ -1,9 +1,11 @@
|
||||
package ctrld
|
||||
|
||||
type dnsFn func() []string
|
||||
import "context"
|
||||
|
||||
type dnsFn func(ctx context.Context) []string
|
||||
|
||||
// nameservers returns DNS nameservers from system settings.
|
||||
func nameservers() []string {
|
||||
func nameservers(ctx context.Context) []string {
|
||||
var dns []string
|
||||
seen := make(map[string]bool)
|
||||
ch := make(chan []string)
|
||||
@@ -11,7 +13,7 @@ func nameservers() []string {
|
||||
|
||||
for _, fn := range fns {
|
||||
go func(fn dnsFn) {
|
||||
ch <- fn()
|
||||
ch <- fn(ctx)
|
||||
}(fn)
|
||||
}
|
||||
for range fns {
|
||||
|
||||
+2
-1
@@ -3,6 +3,7 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
@@ -13,7 +14,7 @@ func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromResolvConf, dnsFromRIB}
|
||||
}
|
||||
|
||||
func dnsFromRIB() []string {
|
||||
func dnsFromRIB(_ context.Context) []string {
|
||||
var dns []string
|
||||
rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
|
||||
if err != nil {
|
||||
|
||||
@@ -22,8 +22,8 @@ func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers}
|
||||
}
|
||||
|
||||
func getDNSFromScutil() []string {
|
||||
logger := *ProxyLogger.Load()
|
||||
func getDNSFromScutil(ctx context.Context) []string {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
|
||||
const (
|
||||
maxRetries = 10
|
||||
@@ -109,8 +109,8 @@ func getDHCPNameservers(iface string) ([]string, error) {
|
||||
return nameservers, nil
|
||||
}
|
||||
|
||||
func getAllDHCPNameservers() []string {
|
||||
logger := *ProxyLogger.Load()
|
||||
func getAllDHCPNameservers(ctx context.Context) []string {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package ctrld
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"os"
|
||||
@@ -20,7 +21,7 @@ func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromResolvConf, dns4, dns6, dnsFromSystemdResolver}
|
||||
}
|
||||
|
||||
func dns4() []string {
|
||||
func dns4(_ context.Context) []string {
|
||||
f, err := os.Open(v4RouteFile)
|
||||
if err != nil {
|
||||
return nil
|
||||
@@ -60,7 +61,7 @@ func dns4() []string {
|
||||
return dns
|
||||
}
|
||||
|
||||
func dns6() []string {
|
||||
func dns6(_ context.Context) []string {
|
||||
f, err := os.Open(v6RouteFile)
|
||||
if err != nil {
|
||||
return nil
|
||||
@@ -94,7 +95,7 @@ func dns6() []string {
|
||||
return dns
|
||||
}
|
||||
|
||||
func dnsFromSystemdResolver() []string {
|
||||
func dnsFromSystemdResolver(_ context.Context) []string {
|
||||
c, err := resolvconffile.ParseFile("/run/systemd/resolve/resolv.conf")
|
||||
if err != nil {
|
||||
return nil
|
||||
|
||||
+5
-2
@@ -1,9 +1,12 @@
|
||||
package ctrld
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNameservers(t *testing.T) {
|
||||
ns := nameservers()
|
||||
ns := nameservers(context.Background())
|
||||
if len(ns) == 0 {
|
||||
t.Fatal("failed to get nameservers")
|
||||
}
|
||||
|
||||
+2
-1
@@ -3,6 +3,7 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"slices"
|
||||
"time"
|
||||
@@ -20,7 +21,7 @@ func currentNameserversFromResolvconf() []string {
|
||||
// dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file.
|
||||
// A nameserver is usable if it's not one of current machine's IP addresses
|
||||
// and loopback IP addresses.
|
||||
func dnsFromResolvConf() []string {
|
||||
func dnsFromResolvConf(_ context.Context) []string {
|
||||
const (
|
||||
maxRetries = 10
|
||||
retryInterval = 100 * time.Millisecond
|
||||
|
||||
+55
-93
@@ -55,28 +55,25 @@ func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromAdapter}
|
||||
}
|
||||
|
||||
func dnsFromAdapter() []string {
|
||||
func dnsFromAdapter(ctx context.Context) []string {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout)
|
||||
defer cancel()
|
||||
|
||||
var ns []string
|
||||
var err error
|
||||
|
||||
logger := *ProxyLogger.Load()
|
||||
logger := LoggerFromCtx(ctx)
|
||||
|
||||
for i := 0; i < maxDNSAdapterRetries; i++ {
|
||||
if ctx.Err() != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
|
||||
logger.Debug().Msgf("dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
|
||||
return nil
|
||||
}
|
||||
|
||||
ns, err = getDNSServers(ctx)
|
||||
if err == nil && len(ns) >= minDNSServers {
|
||||
if i > 0 {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Successfully got DNS servers after %d attempts, found %d servers",
|
||||
i+1, len(ns))
|
||||
logger.Debug().Msgf("Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns))
|
||||
}
|
||||
return ns
|
||||
}
|
||||
@@ -88,11 +85,9 @@ func dnsFromAdapter() []string {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get DNS servers, attempt %d: %v", i+1, err)
|
||||
logger.Debug().Msgf("Failed to get DNS servers, attempt %d: %v", i+1, err)
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Got insufficient DNS servers, retrying, found %d servers", len(ns))
|
||||
logger.Debug().Msgf("Got insufficient DNS servers, retrying, found %d servers", len(ns))
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -102,14 +97,12 @@ func dnsFromAdapter() []string {
|
||||
}
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries)
|
||||
logger.Debug().Msgf("Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries)
|
||||
|
||||
return ns
|
||||
}
|
||||
|
||||
func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
// Check context before making the call
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
@@ -124,17 +117,16 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
return nil, fmt.Errorf("getting adapters: %w", err)
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Found network adapters, count=%d", len(aas))
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Msgf("Found network adapters, count=%d", len(aas))
|
||||
|
||||
// Try to get domain controller info if domain-joined
|
||||
var dcServers []string
|
||||
isDomain := checkDomainJoined()
|
||||
isDomain := checkDomainJoined(ctx)
|
||||
if isDomain {
|
||||
domainName, err := getLocalADDomain()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get local AD domain: %v", err)
|
||||
logger.Debug().Msgf("Failed to get local AD domain: %v", err)
|
||||
} else {
|
||||
// Load netapi32.dll
|
||||
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
|
||||
@@ -145,11 +137,9 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
|
||||
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to convert domain name to UTF16: %v", err)
|
||||
logger.Debug().Msgf("Failed to convert domain name to UTF16: %v", err)
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
|
||||
logger.Debug().Msgf("Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
|
||||
|
||||
// Call DsGetDcNameW with domain name
|
||||
ret, _, err := dsDcName.Call(
|
||||
@@ -163,20 +153,15 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
if ret != 0 {
|
||||
switch ret {
|
||||
case 1355: // ERROR_NO_SUCH_DOMAIN
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Domain not found: %s (%d)", domainName, ret)
|
||||
logger.Debug().Msgf("Domain not found: %s (%d)", domainName, ret)
|
||||
case 1311: // ERROR_NO_LOGON_SERVERS
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"No logon servers available for domain: %s (%d)", domainName, ret)
|
||||
logger.Debug().Msgf("No logon servers available for domain: %s (%d)", domainName, ret)
|
||||
case 1004: // ERROR_DC_NOT_FOUND
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Domain controller not found for domain: %s (%d)", domainName, ret)
|
||||
logger.Debug().Msgf("Domain controller not found for domain: %s (%d)", domainName, ret)
|
||||
case 1722: // RPC_S_SERVER_UNAVAILABLE
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"RPC server unavailable for domain: %s (%d)", domainName, ret)
|
||||
logger.Debug().Msgf("RPC server unavailable for domain: %s (%d)", domainName, ret)
|
||||
default:
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
|
||||
logger.Debug().Msgf("Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
|
||||
}
|
||||
} else if info != nil {
|
||||
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
|
||||
@@ -184,17 +169,13 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
if info.DomainControllerAddress != nil {
|
||||
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
|
||||
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Found domain controller address: %s", dcAddr)
|
||||
|
||||
logger.Debug().Msgf("Found domain controller address: %s", dcAddr)
|
||||
if ip := net.ParseIP(dcAddr); ip != nil {
|
||||
dcServers = append(dcServers, ip.String())
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added domain controller DNS servers: %v", dcServers)
|
||||
logger.Debug().Msgf("Added domain controller DNS servers: %v", dcServers)
|
||||
}
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"No domain controller address found")
|
||||
logger.Debug().Msg("No domain controller address found")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -209,31 +190,27 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
// Collect all local IPs
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
|
||||
logger.Debug().Msgf("Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (software loopback)", aa.FriendlyName())
|
||||
logger.Debug().Msgf("Skipping %s (software loopback)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Processing adapter %s", aa.FriendlyName())
|
||||
logger.Debug().Msgf("Processing adapter %s", aa.FriendlyName())
|
||||
|
||||
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
|
||||
ip := a.Address.IP().String()
|
||||
addressMap[ip] = struct{}{}
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added local IP %s from adapter %s", ip, aa.FriendlyName())
|
||||
logger.Debug().Msgf("Added local IP %s from adapter %s", ip, aa.FriendlyName())
|
||||
}
|
||||
}
|
||||
|
||||
validInterfacesMap := validInterfaces()
|
||||
validInterfacesMap := validInterfaces(ctx)
|
||||
|
||||
// Collect DNS servers
|
||||
for _, aa := range aas {
|
||||
@@ -244,23 +221,20 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (software loopback)", aa.FriendlyName())
|
||||
logger.Debug().Msgf("Skipping %s (software loopback)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
// if not in the validInterfacesMap, skip
|
||||
if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
|
||||
logger.Debug().Msgf("Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
||||
ip := dns.Address.IP()
|
||||
if ip == nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping nil IP from adapter %s", aa.FriendlyName())
|
||||
logger.Debug().Msgf("Skipping nil IP from adapter %s", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -293,28 +267,23 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
if !seen[dcServer] {
|
||||
seen[dcServer] = true
|
||||
ns = append(ns, dcServer)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added additional domain controller DNS server: %s", dcServer)
|
||||
logger.Debug().Msgf("Added additional domain controller DNS server: %s", dcServer)
|
||||
}
|
||||
}
|
||||
|
||||
// if we have static DNS servers saved for the current default route, we should add them to the list
|
||||
drIfaceName, err := netmon.DefaultRouteInterface()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get default route interface: %v", err)
|
||||
logger.Debug().Msgf("Failed to get default route interface: %v", err)
|
||||
} else {
|
||||
drIface, err := net.InterfaceByName(drIfaceName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get interface by name %s: %v", drIfaceName, err)
|
||||
logger.Debug().Msgf("Failed to get interface by name %s: %v", drIfaceName, err)
|
||||
} else {
|
||||
staticNs, file := SavedStaticNameserversAndPath(drIface)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"static dns servers from %s: %v", file, staticNs)
|
||||
logger.Debug().Msgf("static dns servers from %s: %v", file, staticNs)
|
||||
if len(staticNs) > 0 {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Adding static DNS servers from %s: %v", drIfaceName, staticNs)
|
||||
logger.Debug().Msgf("Adding static DNS servers from %s: %v", drIfaceName, staticNs)
|
||||
ns = append(ns, staticNs...)
|
||||
}
|
||||
}
|
||||
@@ -324,9 +293,7 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
return nil, fmt.Errorf("no valid DNS servers found")
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
|
||||
len(ns), ns, len(dcServers))
|
||||
logger.Debug().Msgf("DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", len(ns), ns, len(dcServers))
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
@@ -337,33 +304,35 @@ func currentNameserversFromResolvconf() []string {
|
||||
|
||||
// checkDomainJoined checks if the machine is joined to an Active Directory domain
|
||||
// Returns whether it's domain joined and the domain name if available
|
||||
func checkDomainJoined() bool {
|
||||
logger := *ProxyLogger.Load()
|
||||
func checkDomainJoined(ctx context.Context) bool {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
|
||||
var domain *uint16
|
||||
var status uint32
|
||||
|
||||
err := windows.NetGetJoinInformation(nil, &domain, &status)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get domain join status: %v", err)
|
||||
logger.Debug().Msgf("Failed to get domain join status: %v", err)
|
||||
return false
|
||||
}
|
||||
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain)))
|
||||
|
||||
domainName := windows.UTF16PtrToString(domain)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
logger.Debug().Msgf(
|
||||
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)",
|
||||
domainName, status)
|
||||
domainName,
|
||||
status,
|
||||
)
|
||||
|
||||
// Consider domain or cloud domain as domain-joined
|
||||
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
|
||||
Log(context.Background(), logger.Debug(),
|
||||
logger.Debug().Msgf(
|
||||
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
|
||||
status,
|
||||
status == NetSetupDomain,
|
||||
status == NetSetupCloudDomain,
|
||||
isDomain)
|
||||
isDomain,
|
||||
)
|
||||
|
||||
return isDomain
|
||||
}
|
||||
@@ -411,12 +380,12 @@ func getLocalADDomain() (string, error) {
|
||||
// validInterfaces returns a list of all physical interfaces.
|
||||
// this is a duplicate of what is in net_windows.go, we should
|
||||
// clean this up so there is only one version
|
||||
func validInterfaces() map[string]struct{} {
|
||||
func validInterfaces(ctx context.Context) map[string]struct{} {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
//load the logger
|
||||
logger := *ProxyLogger.Load()
|
||||
logger := LoggerFromCtx(ctx)
|
||||
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||
@@ -425,23 +394,20 @@ func validInterfaces() map[string]struct{} {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Warn(),
|
||||
"failed to get wmi network adapter: %v", err)
|
||||
logger.Warn().Msgf("failed to get wmi network adapter: %v", err)
|
||||
return nil
|
||||
}
|
||||
var adapters []string
|
||||
for _, i := range instances {
|
||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Warn(),
|
||||
"failed to get network adapter: %v", err)
|
||||
logger.Warn().Msgf("failed to get network adapter: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
name, err := adapter.GetPropertyName()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Warn(),
|
||||
"failed to get interface name: %v", err)
|
||||
logger.Warn().Msgf("failed to get interface name: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -451,13 +417,11 @@ func validInterfaces() map[string]struct{} {
|
||||
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
||||
physical, err := adapter.GetPropertyConnectorPresent()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"failed to get network adapter connector present property: %v", err)
|
||||
logger.Debug().Msgf("failed to get network adapter connector present property: %v", err)
|
||||
continue
|
||||
}
|
||||
if !physical {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"skipping non-physical adapter: %s", name)
|
||||
logger.Debug().Msgf("skipping non-physical adapter: %s", name)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -465,13 +429,11 @@ func validInterfaces() map[string]struct{} {
|
||||
// because some interfaces are not physical but have a connector.
|
||||
hardware, err := adapter.GetPropertyHardwareInterface()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"failed to get network adapter hardware interface property: %v", err)
|
||||
logger.Debug().Msgf("failed to get network adapter hardware interface property: %v", err)
|
||||
continue
|
||||
}
|
||||
if !hardware {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"skipping non-hardware interface: %s", name)
|
||||
logger.Debug().Msgf("skipping non-hardware interface: %s", name)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -17,26 +17,27 @@ var (
|
||||
)
|
||||
|
||||
// HasIPv6 reports whether the current network stack has IPv6 available.
|
||||
func HasIPv6() bool {
|
||||
func HasIPv6(ctx context.Context) bool {
|
||||
hasIPv6Once.Do(func() {
|
||||
ProxyLogger.Load().Debug().Msg("checking for IPv6 availability once")
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Msg("checking for IPv6 availability once")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
val := ctrldnet.IPv6Available(ctx)
|
||||
ipv6Available.Store(val)
|
||||
ProxyLogger.Load().Debug().Msgf("ipv6 availability: %v", val)
|
||||
logger.Debug().Msgf("ipv6 availability: %v", val)
|
||||
mon, err := netmon.New(func(format string, args ...any) {})
|
||||
if err != nil {
|
||||
ProxyLogger.Load().Debug().Err(err).Msg("failed to monitor IPv6 state")
|
||||
logger.Debug().Err(err).Msg("failed to monitor IPv6 state")
|
||||
return
|
||||
}
|
||||
mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
|
||||
old := ipv6Available.Load()
|
||||
cur := delta.Monitor.InterfaceState().HaveV6
|
||||
if old != cur {
|
||||
ProxyLogger.Load().Warn().Msgf("ipv6 availability changed, old: %v, new: %v", old, cur)
|
||||
logger.Warn().Msgf("ipv6 availability changed, old: %v, new: %v", old, cur)
|
||||
} else {
|
||||
ProxyLogger.Load().Debug().Msg("ipv6 availability does not changed")
|
||||
logger.Debug().Msg("ipv6 availability does not changed")
|
||||
}
|
||||
ipv6Available.Store(cur)
|
||||
})
|
||||
@@ -46,8 +47,9 @@ func HasIPv6() bool {
|
||||
}
|
||||
|
||||
// DisableIPv6 marks IPv6 as unavailable if enabled.
|
||||
func DisableIPv6() {
|
||||
func DisableIPv6(ctx context.Context) {
|
||||
if ipv6Available.CompareAndSwap(true, false) {
|
||||
ProxyLogger.Load().Debug().Msg("turned off IPv6 availability")
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Msg("turned off IPv6 availability")
|
||||
}
|
||||
}
|
||||
|
||||
+58
-77
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
@@ -15,7 +14,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
@@ -50,10 +48,6 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||
var localResolver Resolver
|
||||
|
||||
func init() {
|
||||
// Initializing ProxyLogger here, so other places don't have to do nil check.
|
||||
l := zerolog.New(io.Discard)
|
||||
ProxyLogger.Store(&l)
|
||||
|
||||
localResolver = newLocalResolver()
|
||||
}
|
||||
|
||||
@@ -81,8 +75,8 @@ func LanQueryCtx(ctx context.Context) context.Context {
|
||||
}
|
||||
|
||||
// defaultNameservers is like nameservers with each element formed "ip:53".
|
||||
func defaultNameservers() []string {
|
||||
ns := nameservers()
|
||||
func defaultNameservers(ctx context.Context) []string {
|
||||
ns := nameservers(ctx)
|
||||
nss := make([]string, len(ns))
|
||||
for i := range ns {
|
||||
nss[i] = net.JoinHostPort(ns[i], "53")
|
||||
@@ -91,42 +85,36 @@ func defaultNameservers() []string {
|
||||
}
|
||||
|
||||
// availableNameservers returns list of current available DNS servers of the system.
|
||||
func availableNameservers() []string {
|
||||
func availableNameservers(ctx context.Context) []string {
|
||||
var nss []string
|
||||
// Ignore local addresses to prevent loop.
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
machineIPsMap := make(map[string]struct{}, len(regularIPs))
|
||||
|
||||
//load the logger
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs)
|
||||
// Load the logger.
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Msgf("Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs)
|
||||
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
ipStr := v.String()
|
||||
machineIPsMap[ipStr] = struct{}{}
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added local IP to OS resolverexclusion map: %s", ipStr)
|
||||
logger.Debug().Msgf("Added local IP to OS resolverexclusion map: %s", ipStr)
|
||||
}
|
||||
|
||||
systemNameservers := nameservers()
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Got system nameservers: %v", systemNameservers)
|
||||
systemNameservers := nameservers(ctx)
|
||||
logger.Debug().Msgf("Got system nameservers: %v", systemNameservers)
|
||||
|
||||
for _, ns := range systemNameservers {
|
||||
if _, ok := machineIPsMap[ns]; ok {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping local nameserver: %s", ns)
|
||||
logger.Debug().Msgf("Skipping local nameserver: %s", ns)
|
||||
continue
|
||||
}
|
||||
nss = append(nss, ns)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added non-local nameserver: %s", ns)
|
||||
logger.Debug().Msgf("Added non-local nameserver: %s", ns)
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Final available nameservers: %v", nss)
|
||||
logger.Debug().Msgf("Final available nameservers: %v", nss)
|
||||
|
||||
return nss
|
||||
}
|
||||
|
||||
@@ -135,8 +123,8 @@ func availableNameservers() []string {
|
||||
//
|
||||
// It's the caller's responsibility to ensure the system DNS is in a clean state before
|
||||
// calling this function.
|
||||
func InitializeOsResolver(guardAgainstNoNameservers bool) []string {
|
||||
nameservers := availableNameservers()
|
||||
func InitializeOsResolver(ctx context.Context, guardAgainstNoNameservers bool) []string {
|
||||
nameservers := availableNameservers(ctx)
|
||||
// if no nameservers, return empty slice so we dont remove all nameservers
|
||||
if len(nameservers) == 0 && guardAgainstNoNameservers {
|
||||
return []string{}
|
||||
@@ -188,7 +176,7 @@ type Resolver interface {
|
||||
var errUnknownResolver = errors.New("unknown resolver")
|
||||
|
||||
// NewResolver creates a Resolver based on the given upstream config.
|
||||
func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
func NewResolver(ctx context.Context, uc *UpstreamConfig) (Resolver, error) {
|
||||
typ := uc.Type
|
||||
switch typ {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
@@ -200,15 +188,16 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
case ResolverTypeOS:
|
||||
resolverMutex.Lock()
|
||||
if or == nil {
|
||||
ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver")
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Msgf("Initialize new OS resolver")
|
||||
or = newResolverWithNameserver(defaultNameservers(ctx))
|
||||
}
|
||||
resolverMutex.Unlock()
|
||||
return or, nil
|
||||
case ResolverTypeLegacy:
|
||||
return &legacyResolver{uc: uc}, nil
|
||||
case ResolverTypePrivate:
|
||||
return NewPrivateResolver(), nil
|
||||
return NewPrivateResolver(ctx), nil
|
||||
case ResolverTypeLocal:
|
||||
return localResolver, nil
|
||||
}
|
||||
@@ -235,14 +224,16 @@ type publicResponse struct {
|
||||
}
|
||||
|
||||
// SetDefaultLocalIPv4 updates the stored local IPv4.
|
||||
func SetDefaultLocalIPv4(ip net.IP) {
|
||||
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip)
|
||||
func SetDefaultLocalIPv4(ctx context.Context, ip net.IP) {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Msgf("SetDefaultLocalIPv4: %s", ip)
|
||||
defaultLocalIPv4.Store(ip)
|
||||
}
|
||||
|
||||
// SetDefaultLocalIPv6 updates the stored local IPv6.
|
||||
func SetDefaultLocalIPv6(ip net.IP) {
|
||||
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip)
|
||||
func SetDefaultLocalIPv6(ctx context.Context, ip net.IP) {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Msgf("SetDefaultLocalIPv6: %s", ip)
|
||||
defaultLocalIPv6.Store(ip)
|
||||
}
|
||||
|
||||
@@ -300,10 +291,11 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
// Unique key for the singleflight group.
|
||||
key := fmt.Sprintf("%s:%d:", domain, qtype)
|
||||
|
||||
logger := LoggerFromCtx(ctx)
|
||||
// Checking the cache first.
|
||||
if val, ok := o.cache.Load(key); ok {
|
||||
if val, ok := val.(*dns.Msg); ok {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype])
|
||||
Log(ctx, logger.Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype])
|
||||
res := val.Copy()
|
||||
SetCacheReply(res, msg, val.Rcode)
|
||||
return res, nil
|
||||
@@ -338,7 +330,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
res := sharedMsg.Copy()
|
||||
SetCacheReply(res, msg, sharedMsg.Rcode)
|
||||
if shared {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype])
|
||||
Log(ctx, logger.Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype])
|
||||
}
|
||||
|
||||
return res, nil
|
||||
@@ -368,7 +360,8 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
question = msg.Question[0].Name
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "os resolver query for %s with nameservers: %v public: %v", question, nss, publicServers)
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Debug(), "os resolver query for %s with nameservers: %v public: %v", question, nss, publicServers)
|
||||
|
||||
// New check: If no resolvers are available, return an error.
|
||||
if numServers == 0 {
|
||||
@@ -417,7 +410,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
// If splitting fails, fallback to the original server string
|
||||
host = server
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host)
|
||||
Log(ctx, logger.Debug(), "got answer from nameserver: %s", host)
|
||||
}
|
||||
|
||||
// try local nameservers
|
||||
@@ -444,7 +437,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
switch {
|
||||
case res.lan:
|
||||
// Always prefer LAN responses immediately
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using LAN answer from: %s", res.server)
|
||||
Log(ctx, logger.Debug(), "using LAN answer from: %s", res.server)
|
||||
cancel()
|
||||
logAnswer(res.server)
|
||||
return res.answer, nil
|
||||
@@ -454,7 +447,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
// if there are no LAN nameservers, we should not wait
|
||||
// just use the first response
|
||||
if len(nss) == 0 {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", res.server)
|
||||
Log(ctx, logger.Debug(), "using public answer from: %s", res.server)
|
||||
cancel()
|
||||
logAnswer(res.server)
|
||||
return res.answer, nil
|
||||
@@ -465,12 +458,12 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
})
|
||||
}
|
||||
case res.answer != nil:
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d",
|
||||
Log(ctx, logger.Debug(), "got non-success answer from: %s with code: %d",
|
||||
res.server, res.answer.Rcode)
|
||||
// When there are no LAN nameservers, we should not wait
|
||||
// for other nameservers to respond.
|
||||
if len(nss) == 0 {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "no lan nameservers using public non success answer")
|
||||
Log(ctx, logger.Debug(), "no lan nameservers using public non success answer")
|
||||
cancel()
|
||||
logAnswer(res.server)
|
||||
return res.answer, nil
|
||||
@@ -483,17 +476,17 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
|
||||
if len(publicResponses) > 0 {
|
||||
resp := publicResponses[0]
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", resp.server)
|
||||
Log(ctx, logger.Debug(), "using public answer from: %s", resp.server)
|
||||
logAnswer(resp.server)
|
||||
return resp.answer, nil
|
||||
}
|
||||
if controldSuccessAnswer != nil {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort)
|
||||
Log(ctx, logger.Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort)
|
||||
logAnswer(controldPublicDnsWithPort)
|
||||
return controldSuccessAnswer, nil
|
||||
}
|
||||
if nonSuccessAnswer != nil {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using non-success answer from: %s", nonSuccessServer)
|
||||
Log(ctx, logger.Debug(), "using non-success answer from: %s", nonSuccessServer)
|
||||
logAnswer(nonSuccessServer)
|
||||
return nonSuccessAnswer, nil
|
||||
}
|
||||
@@ -515,7 +508,7 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
_, udpNet := r.uc.netForDNSType(dnsTyp)
|
||||
_, udpNet := r.uc.netForDNSType(ctx, dnsTyp)
|
||||
dnsClient := &dns.Client{
|
||||
Net: udpNet,
|
||||
Dialer: dialer,
|
||||
@@ -541,39 +534,43 @@ func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, err
|
||||
|
||||
// LookupIP looks up domain using current system nameservers settings.
|
||||
// It returns a slice of that host's IPv4 and IPv6 addresses.
|
||||
func LookupIP(domain string) []string {
|
||||
nss := initDefaultOsResolver()
|
||||
return lookupIP(domain, -1, nss)
|
||||
func LookupIP(ctx context.Context, domain string) []string {
|
||||
nss := initDefaultOsResolver(ctx)
|
||||
return lookupIP(ctx, domain, -1, nss)
|
||||
}
|
||||
|
||||
// initDefaultOsResolver initializes the default OS resolver with system's default nameservers if it hasn't been initialized yet.
|
||||
// It returns the combined list of LAN and public nameservers currently held by the resolver.
|
||||
func initDefaultOsResolver() []string {
|
||||
func initDefaultOsResolver(ctx context.Context) []string {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
resolverMutex.Lock()
|
||||
defer resolverMutex.Unlock()
|
||||
if or == nil {
|
||||
ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver with default nameservers")
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
logger.Debug().Msgf("Initialize new OS resolver with default nameservers")
|
||||
or = newResolverWithNameserver(defaultNameservers(ctx))
|
||||
}
|
||||
nss := *or.lanServers.Load()
|
||||
nss = append(nss, *or.publicServers.Load()...)
|
||||
return nss
|
||||
|
||||
}
|
||||
|
||||
// lookupIP looks up domain with given timeout and bootstrapDNS.
|
||||
// If the timeout is negative, default timeout 2000 ms will be used.
|
||||
// It returns nil if bootstrapDNS is nil or empty.
|
||||
func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) {
|
||||
func lookupIP(ctx context.Context, domain string, timeout int, bootstrapDNS []string) (ips []string) {
|
||||
if net.ParseIP(domain) != nil {
|
||||
return []string{domain}
|
||||
}
|
||||
logger := LoggerFromCtx(ctx)
|
||||
if bootstrapDNS == nil {
|
||||
ProxyLogger.Load().Debug().Msgf("empty bootstrap DNS")
|
||||
logger.Debug().Msgf("empty bootstrap DNS")
|
||||
return nil
|
||||
}
|
||||
|
||||
resolver := newResolverWithNameserver(bootstrapDNS)
|
||||
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, bootstrapDNS)
|
||||
logger.Debug().Msgf("resolving %q using bootstrap DNS %q", domain, bootstrapDNS)
|
||||
|
||||
timeoutMs := 2000
|
||||
if timeout > 0 && timeout < timeoutMs {
|
||||
timeoutMs = timeout
|
||||
@@ -616,15 +613,15 @@ func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string)
|
||||
|
||||
r, err := resolver.Resolve(ctx, m)
|
||||
if err != nil {
|
||||
ProxyLogger.Load().Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain)
|
||||
logger.Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain)
|
||||
return
|
||||
}
|
||||
if r.Rcode != dns.RcodeSuccess {
|
||||
ProxyLogger.Load().Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode])
|
||||
logger.Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode])
|
||||
return
|
||||
}
|
||||
if len(r.Answer) == 0 {
|
||||
ProxyLogger.Load().Error().Msg("no answer from OS resolver")
|
||||
logger.Error().Msg("no answer from OS resolver")
|
||||
return
|
||||
}
|
||||
target := targetDomain(r.Answer)
|
||||
@@ -641,22 +638,6 @@ func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string)
|
||||
return ips
|
||||
}
|
||||
|
||||
// NewBootstrapResolver returns an OS resolver, which use following nameservers:
|
||||
//
|
||||
// - Gateway IP address (depends on OS).
|
||||
// - Input servers.
|
||||
func NewBootstrapResolver(servers ...string) Resolver {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
Log(context.Background(), logger.Debug(), "NewBootstrapResolver called with servers: %v", servers)
|
||||
nss := defaultNameservers()
|
||||
nss = append([]string{controldPublicDnsWithPort}, nss...)
|
||||
for _, ns := range servers {
|
||||
nss = append([]string{net.JoinHostPort(ns, "53")}, nss...)
|
||||
}
|
||||
return NewResolverWithNameserver(nss)
|
||||
}
|
||||
|
||||
// NewPrivateResolver returns an OS resolver, which includes only private DNS servers,
|
||||
// excluding:
|
||||
//
|
||||
@@ -664,8 +645,8 @@ func NewBootstrapResolver(servers ...string) Resolver {
|
||||
// - Nameservers which is local RFC1918 addresses.
|
||||
//
|
||||
// This is useful for doing PTR lookup in LAN network.
|
||||
func NewPrivateResolver() Resolver {
|
||||
nss := initDefaultOsResolver()
|
||||
func NewPrivateResolver(ctx context.Context) Resolver {
|
||||
nss := initDefaultOsResolver(ctx)
|
||||
resolveConfNss := currentNameserversFromResolvconf()
|
||||
localRfc1918Addrs := Rfc1918Addresses()
|
||||
n := 0
|
||||
|
||||
+1
-1
@@ -132,7 +132,7 @@ func Test_osResolver_InitializationRace(t *testing.T) {
|
||||
for range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
InitializeOsResolver(false)
|
||||
InitializeOsResolver(LoggerCtx(context.Background(), nil), false)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
Reference in New Issue
Block a user