From 41282d0f512f8b2980e6af781085cbade5fd9526 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 7 Jul 2025 16:45:42 +0700 Subject: [PATCH] refactor: break down proxy method into smaller focused functions Split the long proxy method into several smaller methods to improve maintainability and testability. Each new method has a single responsibility: - initializeUpstreams: handles upstream configuration setup - tryCache: manages cache lookup logic - tryUpstreams: coordinates upstream query attempts - processUpstream: handles individual upstream query processing - handleAllUpstreamsFailure: manages failure scenarios - checkCache: performs cache checks and retrieval - serveStaleResponse: handles stale cache responses - shouldContinueWithNextUpstream: determines if failover is needed - prepareSuccessResponse: formats successful responses This refactoring: - Reduces cognitive complexity - Improves code testability - Makes the DNS proxy logic flow clearer - Isolates error handling and edge cases - Maintains existing functionality No behavioral changes were made. --- cmd/cli/dns_proxy.go | 526 ++++++++++++++++++++++++++----------------- 1 file changed, 314 insertions(+), 212 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 030cc02..8053a89 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -53,10 +53,13 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{ // proxyRequest contains data for proxying a DNS query to upstream. type proxyRequest struct { - msg *dns.Msg - ci *ctrld.ClientInfo - failoverRcodes []int - ufr *upstreamForResult + msg *dns.Msg + ci *ctrld.ClientInfo + failoverRcodes []int + ufr *upstreamForResult + staleAnswer *dns.Msg + isLanOrPtrQuery bool + upstreamConfigs []*ctrld.UpstreamConfig } // proxyResponse contains data for proxying a DNS response from upstream. @@ -409,6 +412,10 @@ macRules: return } +// proxyPrivatePtrLookup performs a private PTR DNS lookup based on the client info table for the given query. +// It prevents DNS loops by locking the processing of the same domain name simultaneously. +// If a valid IP-to-hostname mapping exists, it creates a PTR DNS record as the response. +// Returns the DNS response if a hostname is found or nil otherwise. func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg { cDomainName := msg.Question[0].Name locked := p.ptrLoopGuard.TryLock(cDomainName) @@ -440,6 +447,10 @@ func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg return nil } +// proxyLanHostnameQuery resolves LAN hostnames to their corresponding IP addresses based on the dns.Msg request. +// It uses a loop guard mechanism to prevent DNS query loops and ensures a hostname is processed only once at a time. +// This method queries the client info table for the hostname's IP address and logs relevant debug and client info. +// If the hostname matches known IPs in the table, it generates an appropriate dns.Msg response; otherwise, it returns nil. func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg { q := msg.Question[0] hostname := strings.TrimSuffix(q.Name, ".") @@ -485,231 +496,324 @@ func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg return nil } -func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { - var staleAnswer *dns.Msg - upstreams := req.ufr.upstreams - serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale - upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) - - if len(upstreamConfigs) == 0 { - upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - upstreams = []string{upstreamOS} - } - - res := &proxyResponse{} - - // LAN/PTR lookup flow: - // - // 1. If there's matching rule, follow it. - // 2. Try from client info table. - // 3. Try private resolver. - // 4. Try remote upstream. - isLanOrPtrQuery := false +// handleSpecialQueryTypes processes specific types of DNS queries such as SRV, PTR, and LAN hostname lookups. +// It modifies upstreams and upstreamConfigs based on the query type and updates the query context accordingly. +// Returns a proxyResponse if the query is resolved locally; otherwise, returns nil to proceed with upstream processing. +func (p *prog) handleSpecialQueryTypes(ctx *context.Context, req *proxyRequest, upstreams *[]string, upstreamConfigs *[]*ctrld.UpstreamConfig) *proxyResponse { if req.ufr.matched { - ctrld.Log(ctx, p.Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) - } else { - switch { - case isSrvLanLookup(req.msg): - upstreams = []string{upstreamOS} - upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, p.Debug(), "SRV record lookup, using upstreams: %v", upstreams) - case isPrivatePtrLookup(req.msg): - isLanOrPtrQuery = true - if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil { - res.answer = answer - res.clientInfo = true - return res - } - upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs) - ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, p.Debug(), "private PTR lookup, using upstreams: %v", upstreams) - case isLanHostnameQuery(req.msg): - isLanOrPtrQuery = true - if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil { - res.answer = answer - res.clientInfo = true - return res - } - upstreams = []string{upstreamOS} - upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, p.Debug(), "lan hostname lookup, using upstreams: %v", upstreams) - default: - ctrld.Log(ctx, p.Debug(), "no explicit policy matched, using default routing -> %v", upstreams) - } - } - - // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 - if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR { - for _, upstream := range upstreams { - cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream)) - if cachedValue == nil { - continue - } - answer := cachedValue.Msg.Copy() - ctrld.SetCacheReply(answer, req.msg, answer.Rcode) - now := time.Now() - if cachedValue.Expire.After(now) { - ctrld.Log(ctx, p.Debug(), "hit cached response") - setCachedAnswerTTL(answer, now, cachedValue.Expire) - res.answer = answer - res.cached = true - return res - } - staleAnswer = answer - } - } - resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { - ctrld.Log(ctx, p.Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) - dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig) - if err != nil { - ctrld.Log(ctx, p.Error().Err(err), "failed to create resolver") - return nil, err - } - resolveCtx, cancel := upstreamConfig.Context(ctx) - defer cancel() - return dnsResolver.Resolve(resolveCtx, msg) - } - resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { - if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil { - ctrld.Log(ctx, p.Debug(), "including client info with the request") - ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) - } - answer, err := resolve1(upstream, upstreamConfig, msg) - // if we have an answer, we should reset the failure count - // we dont use reset here since we dont want to prevent failure counts from being incremented - if answer != nil { - p.um.mu.Lock() - p.um.failureReq[upstream] = 0 - p.um.down[upstream] = false - p.um.mu.Unlock() - return answer - } - - ctrld.Log(ctx, p.Error().Err(err), "failed to resolve query") - - // increase failure count when there is no answer - // rehardless of what kind of error we get - p.um.increaseFailureCount(upstream) - - if err != nil { - // For timeout error (i.e: context deadline exceed), force re-bootstrapping. - var e net.Error - if errors.As(err, &e) && e.Timeout() { - upstreamConfig.ReBootstrap(ctx) - } - // For network error, turn ipv6 off if enabled. - if ctrld.HasIPv6(ctx) && (errUrlNetworkError(err) || errNetworkError(err)) { - ctrld.DisableIPv6(ctx) - } - } - + ctrld.Log(*ctx, p.Debug(), "%s, %s, %s -> %v", + req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, *upstreams) return nil } - for n, upstreamConfig := range upstreamConfigs { - if upstreamConfig == nil { - continue - } - logger := p.Debug(). - Str("upstream", upstreamConfig.String()). - Str("query", req.msg.Question[0].Name). - Bool("is_lan_query", isLanOrPtrQuery) - if p.isLoop(upstreamConfig) { - ctrld.Log(ctx, logger, "DNS loop detected") - continue + switch { + case isSrvLanLookup(req.msg): + *upstreams = []string{upstreamOS} + *upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + *ctx = ctrld.LanQueryCtx(*ctx) + ctrld.Log(*ctx, p.Debug(), "SRV record lookup, using upstreams: %v", *upstreams) + return nil + case isPrivatePtrLookup(req.msg): + req.isLanOrPtrQuery = true + if answer := p.proxyPrivatePtrLookup(*ctx, req.msg); answer != nil { + return &proxyResponse{answer: answer, clientInfo: true} } - answer := resolve(upstreams[n], upstreamConfig, req.msg) - if answer == nil { - if serveStaleCache && staleAnswer != nil { - ctrld.Log(ctx, p.Debug(), "serving stale cached response") - now := time.Now() - setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL)) - res.answer = staleAnswer - res.cached = true - return res - } - continue - } - // We are doing LAN/PTR lookup using private resolver, so always process next one. - // Except for the last, we want to send response instead of saying all upstream failed. - if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 { - ctrld.Log(ctx, p.Debug(), "no response from %s, process to next upstream", upstreams[n]) - continue - } - if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) { - ctrld.Log(ctx, p.Debug(), "failover rcode matched, process to next upstream") - continue + *upstreams, *upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(*upstreams, *upstreamConfigs) + *ctx = ctrld.LanQueryCtx(*ctx) + ctrld.Log(*ctx, p.Debug(), "private PTR lookup, using upstreams: %v", *upstreams) + return nil + case isLanHostnameQuery(req.msg): + req.isLanOrPtrQuery = true + if answer := p.proxyLanHostnameQuery(*ctx, req.msg); answer != nil { + return &proxyResponse{answer: answer, clientInfo: true} } + *upstreams = []string{upstreamOS} + *upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + *ctx = ctrld.LanQueryCtx(*ctx) + ctrld.Log(*ctx, p.Debug(), "lan hostname lookup, using upstreams: %v", *upstreams) + return nil + default: + ctrld.Log(*ctx, p.Debug(), "no explicit policy matched, using default routing -> %v", *upstreams) + return nil + } +} - // set compression, as it is not set by default when unpacking - answer.Compress = true +// proxy handles DNS query proxying by selecting upstreams, attempting cache lookups, and querying configured resolvers. +func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { + upstreams, upstreamConfigs := p.initializeUpstreams(req) + if specialRes := p.handleSpecialQueryTypes(&ctx, req, &upstreams, &upstreamConfigs); specialRes != nil { + return specialRes + } - if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR { - ttl := ttlFromMsg(answer) - now := time.Now() - expired := now.Add(time.Duration(ttl) * time.Second) - if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 { - expired = now.Add(time.Duration(cachedTTL) * time.Second) - } - setCachedAnswerTTL(answer, now, expired) - p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired)) - ctrld.Log(ctx, p.Debug(), "add cached response") - } - hostname := "" - if req.ci != nil { - hostname = req.ci.Hostname - } - ctrld.Log(ctx, p.Info(), "REPLY: %s -> %s (%s): %s", upstreams[n], req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode]) - res.answer = answer - res.upstream = upstreamConfig.Endpoint + if cachedRes := p.tryCache(ctx, req, upstreams); cachedRes != nil { + return cachedRes + } + + if res := p.tryUpstreams(ctx, req, upstreams, upstreamConfigs); res != nil { return res } - ctrld.Log(ctx, p.Error(), "all %v endpoints failed", upstreams) - // if we have no healthy upstreams, trigger recovery flow + return p.handleAllUpstreamsFailure(ctx, req, upstreams) +} + +// initializeUpstreams determines which upstreams and configurations to use for a given proxyRequest. +// If no upstreams are configured, it defaults to the operating system's resolver configuration. +// Returns a slice of upstream names and their corresponding configurations. +func (p *prog) initializeUpstreams(req *proxyRequest) ([]string, []*ctrld.UpstreamConfig) { + upstreams := req.ufr.upstreams + upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) + if len(upstreamConfigs) == 0 { + return []string{upstreamOS}, []*ctrld.UpstreamConfig{osUpstreamConfig} + } + return upstreams, upstreamConfigs +} + +// tryCache attempts to retrieve a cached response for the given DNS request from specified upstreams. +// Returns a proxyResponse if a cache hit occurs; otherwise, returns nil. +// Skips cache checking if caching is disabled or the request is a PTR query. +// Iterates through the provided upstreams to find a cached response using the checkCache method. +func (p *prog) tryCache(ctx context.Context, req *proxyRequest, upstreams []string) *proxyResponse { + if p.cache == nil || req.msg.Question[0].Qtype == dns.TypePTR { // https://www.rfc-editor.org/rfc/rfc1035#section-7.4 + return nil + } + + for _, upstream := range upstreams { + if res := p.checkCache(ctx, req, upstream); res != nil { + return res + } + } + return nil +} + +// checkCache checks if a cached DNS response exists for the given request and upstream. +// Returns a proxyResponse with the cached response if found and valid, or nil otherwise. +func (p *prog) checkCache(ctx context.Context, req *proxyRequest, upstream string) *proxyResponse { + cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream)) + if cachedValue == nil { + return nil + } + + answer := cachedValue.Msg.Copy() + ctrld.SetCacheReply(answer, req.msg, answer.Rcode) + now := time.Now() + + if cachedValue.Expire.After(now) { + ctrld.Log(ctx, p.Debug(), "hit cached response") + setCachedAnswerTTL(answer, now, cachedValue.Expire) + return &proxyResponse{answer: answer, cached: true} + } + req.staleAnswer = answer + return nil +} + +// updateCache updates the DNS response cache with the given request, response, TTL, and upstream information. +func (p *prog) updateCache(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string) { + ttl := ttlFromMsg(answer) + now := time.Now() + expired := now.Add(time.Duration(ttl) * time.Second) + if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 { + expired = now.Add(time.Duration(cachedTTL) * time.Second) + } + setCachedAnswerTTL(answer, now, expired) + p.cache.Add(dnscache.NewKey(req.msg, upstream), dnscache.NewValue(answer, expired)) + ctrld.Log(ctx, p.Debug(), "add cached response") +} + +// serveStaleResponse serves a stale cached DNS response when an upstream query fails, updating TTL for cached records. +func (p *prog) serveStaleResponse(ctx context.Context, staleAnswer *dns.Msg) *proxyResponse { + ctrld.Log(ctx, p.Debug(), "serving stale cached response") + now := time.Now() + setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL)) + return &proxyResponse{answer: staleAnswer, cached: true} +} + +// handleAllUpstreamsFailure handles the failure scenario when all upstream resolvers fail to respond or process the request. +func (p *prog) handleAllUpstreamsFailure(ctx context.Context, req *proxyRequest, upstreams []string) *proxyResponse { + ctrld.Log(ctx, p.Error(), "all %v endpoints failed", upstreams) if p.leakOnUpstreamFailure() { if p.um.countHealthy(upstreams) == 0 { - p.recoveryCancelMu.Lock() - if p.recoveryCancel == nil { - var reason RecoveryReason - if upstreams[0] == upstreamOS { - reason = RecoveryReasonOSFailure - } else { - reason = RecoveryReasonRegularFailure - } - p.Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) - go p.handleRecovery(reason) - } else { - p.Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection") - } - p.recoveryCancelMu.Unlock() + p.triggerRecovery(upstreams[0] == upstreamOS) } else { p.Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger") } - // attempt query to OS resolver while as a retry catch all - // we dont want this to happen if leakOnUpstreamFailure is false if upstreams[0] != upstreamOS { - ctrld.Log(ctx, p.Debug(), "attempting query to OS resolver as a retry catch all") - answer := resolve(upstreamOS, osUpstreamConfig, req.msg) - if answer != nil { - ctrld.Log(ctx, p.Debug(), "OS resolver retry query successful") - res.answer = answer - res.upstream = osUpstreamConfig.Endpoint - return res + if answer := p.tryOSResolver(ctx, req); answer != nil { + return answer } - ctrld.Log(ctx, p.Debug(), "OS resolver retry query failed") } } answer := new(dns.Msg) answer.SetRcode(req.msg, dns.RcodeServerFailure) - res.answer = answer - return res + return &proxyResponse{answer: answer} } +// shouldContinueWithNextUpstream determines whether processing should continue with the next upstream based on response conditions. +func (p *prog) shouldContinueWithNextUpstream(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string, lastUpstream bool) bool { + if answer.Rcode == dns.RcodeSuccess { + return false + } + + // We are doing LAN/PTR lookup using private resolver, so always process the next one. + // Except for the last, we want to send a response instead of saying all upstream failed. + if req.isLanOrPtrQuery && !lastUpstream { + ctrld.Log(ctx, p.Debug(), "no response for LAN/PTR query from %s, process to next upstream", upstream) + return true + } + + if len(req.upstreamConfigs) > 1 && slices.Contains(req.failoverRcodes, answer.Rcode) { + ctrld.Log(ctx, p.Debug(), "failover rcode matched, process to next upstream") + return true + } + + return false +} + +// prepareSuccessResponse prepares a successful DNS response for a given request, logs it, and updates the cache if applicable. +func (p *prog) prepareSuccessResponse(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string, upstreamConfig *ctrld.UpstreamConfig) *proxyResponse { + answer.Compress = true + + if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR { + p.updateCache(ctx, req, answer, upstream) + } + + hostname := "" + if req.ci != nil { + hostname = req.ci.Hostname + } + + ctrld.Log(ctx, p.Info(), "REPLY: %s -> %s (%s): %s", + upstream, req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode]) + + return &proxyResponse{ + answer: answer, + upstream: upstreamConfig.Endpoint, + } +} + +// tryUpstreams attempts to proxy a DNS request through the provided upstreams and their configurations sequentially. +// It returns a successful proxyResponse if any upstream processes the request successfully, or nil otherwise. +// The function supports "serve stale" for cache by utilizing cached responses when upstreams fail. +func (p *prog) tryUpstreams(ctx context.Context, req *proxyRequest, upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) *proxyResponse { + serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale + req.upstreamConfigs = upstreamConfigs + for n, upstreamConfig := range upstreamConfigs { + last := n == len(upstreamConfigs)-1 + if res := p.processUpstream(ctx, req, upstreams[n], upstreamConfig, serveStaleCache, last); res != nil { + return res + } + } + return nil +} + +// processUpstream proxies a DNS query to a given upstream server and processes the response based on the provided configuration. +// It supports serving stale cache when upstream queries fail, and checks if processing should continue to another upstream. +// Returns a proxyResponse on success or nil if the upstream query fails or processing conditions are not met. +func (p *prog) processUpstream(ctx context.Context, req *proxyRequest, upstream string, upstreamConfig *ctrld.UpstreamConfig, serveStaleCache, lastUpstream bool) *proxyResponse { + if upstreamConfig == nil { + return nil + } + if p.isLoop(upstreamConfig) { + logger := p.Debug(). + Str("upstream", upstreamConfig.String()). + Str("query", req.msg.Question[0].Name). + Bool("is_lan_query", req.isLanOrPtrQuery) + ctrld.Log(ctx, logger, "DNS loop detected") + return nil + } + + answer := p.queryUpstream(ctx, req, upstream, upstreamConfig) + if answer == nil { + if serveStaleCache && req.staleAnswer != nil { + return p.serveStaleResponse(ctx, req.staleAnswer) + } + return nil + } + + if p.shouldContinueWithNextUpstream(ctx, req, answer, upstream, lastUpstream) { + return nil + } + return p.prepareSuccessResponse(ctx, req, answer, upstream, upstreamConfig) +} + +// queryUpstream sends a DNS query to a specified upstream using its configuration and handles errors and retries. +func (p *prog) queryUpstream(ctx context.Context, req *proxyRequest, upstream string, upstreamConfig *ctrld.UpstreamConfig) *dns.Msg { + if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil { + ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) + } + + ctrld.Log(ctx, p.Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) + dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig) + if err != nil { + ctrld.Log(ctx, p.Error().Err(err), "failed to create resolver") + return nil + } + + resolveCtx, cancel := upstreamConfig.Context(ctx) + defer cancel() + + answer, err := dnsResolver.Resolve(resolveCtx, req.msg) + if answer != nil { + p.um.mu.Lock() + p.um.failureReq[upstream] = 0 + p.um.down[upstream] = false + p.um.mu.Unlock() + return answer + } + + ctrld.Log(ctx, p.Error().Err(err), "failed to resolve query") + // Increasing the failure count when there is no answer regardless of what kind of error we get + p.um.increaseFailureCount(upstream) + if err != nil { + // For timeout error (i.e: context deadline exceed), force re-bootstrapping. + var e net.Error + if errors.As(err, &e) && e.Timeout() { + upstreamConfig.ReBootstrap(ctx) + } + // For network error, turn ipv6 off if enabled. + if ctrld.HasIPv6(ctx) && (errUrlNetworkError(err) || errNetworkError(err)) { + ctrld.DisableIPv6(ctx) + } + } + return nil +} + +// triggerRecovery attempts to initiate a recovery process if no healthy upstreams are detected. +// If "isOSFailure" is true, the recovery will account for an operating system failure. +// Logs are generated to indicate whether recovery is triggered or already in progress. +func (p *prog) triggerRecovery(isOSFailure bool) { + p.recoveryCancelMu.Lock() + defer p.recoveryCancelMu.Unlock() + + if p.recoveryCancel == nil { + var reason RecoveryReason + if isOSFailure { + reason = RecoveryReasonOSFailure + } else { + reason = RecoveryReasonRegularFailure + } + p.Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) + go p.handleRecovery(reason) + } else { + p.Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection") + } +} + +// tryOSResolver attempts to query the OS resolver as a fallback mechanism when other upstreams fail. +// Logs success or failure of the query attempt and returns a proxyResponse or nil based on query result. +func (p *prog) tryOSResolver(ctx context.Context, req *proxyRequest) *proxyResponse { + ctrld.Log(ctx, p.Debug(), "attempting query to OS resolver as a retry catch all") + answer := p.queryUpstream(ctx, req, upstreamOS, osUpstreamConfig) + if answer != nil { + ctrld.Log(ctx, p.Debug(), "OS resolver retry query successful") + return &proxyResponse{answer: answer, upstream: osUpstreamConfig.Endpoint} + } + ctrld.Log(ctx, p.Debug(), "OS resolver retry query failed") + return nil +} + +// upstreamsAndUpstreamConfigForPtr returns the updated upstreams and upstreamConfigs for a private PTR lookup scenario. func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) { if len(p.localUpstreams) > 0 { tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams)) @@ -720,6 +824,7 @@ func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConf return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...) } +// upstreamConfigsFromUpstreamNumbers converts a list of upstream names into their corresponding UpstreamConfig objects. func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig { upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams)) for _, upstream := range upstreams { @@ -765,10 +870,12 @@ func wildcardMatches(wildcard, str string) bool { return false } +// fmtRemoteToLocal formats a remote address to indicate its mapping to a local listener using listener number and hostname. func fmtRemoteToLocal(listenerNum, hostname, remote string) string { return fmt.Sprintf("%s (%s) -> listener.%s", remote, hostname, listenerNum) } +// requestID generates a random 6-character hexadecimal string to uniquely identify a request. It panics on error. func requestID() string { b := make([]byte, 3) // 6 chars if _, err := rand.Read(b); err != nil { @@ -777,15 +884,7 @@ func requestID() string { return hex.EncodeToString(b) } -func containRcode(rcodes []int, rcode int) bool { - for i := range rcodes { - if rcodes[i] == rcode { - return true - } - } - return false -} - +// setCachedAnswerTTL updates the TTL of each DNS record in the provided message based on the current and expiration times. func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) { ttlSecs := expiredTime.Sub(now).Seconds() if ttlSecs < 0 { @@ -806,6 +905,8 @@ func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) { } } +// ttlFromMsg extracts and returns the TTL value from the first record in the Answer or Ns sections of a DNS message. +// If no records exist in either section, the function returns 0. func ttlFromMsg(msg *dns.Msg) uint32 { for _, rr := range msg.Answer { return rr.Header().Ttl @@ -816,6 +917,7 @@ func ttlFromMsg(msg *dns.Msg) uint32 { return 0 } +// needLocalIPv6Listener checks if a local IPv6 listener is required on Windows by verifying IPv6 support and the OS type. func needLocalIPv6Listener() bool { // On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can // listen on ::1, then spawn a listener for receiving DNS requests.