Mirror of https://github.com/Control-D-Inc/ctrld.git, synced 2026-02-03 22:18:39 +00:00
refactor: break down proxy method into smaller focused functions
Split the long proxy method into several smaller methods to improve maintainability and testability. Each new method has a single responsibility:

- initializeUpstreams: handles upstream configuration setup
- tryCache: manages cache lookup logic
- tryUpstreams: coordinates upstream query attempts
- processUpstream: handles individual upstream query processing
- handleAllUpstreamsFailure: manages failure scenarios
- checkCache: performs cache checks and retrieval
- serveStaleResponse: handles stale cache responses
- shouldContinueWithNextUpstream: determines if failover is needed
- prepareSuccessResponse: formats successful responses

This refactoring:

- Reduces cognitive complexity
- Improves code testability
- Makes the DNS proxy logic flow clearer
- Isolates error handling and edge cases
- Maintains existing functionality

No behavioral changes were made.
committed by Cuong Manh Le
parent f7fb555c89
commit 41282d0f51
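The refactored control flow, reduced to a dependency-free Go sketch (stand-in types and free functions only — the real versions, in the diff below, are methods on *prog operating on *dns.Msg from miekg/dns):

    package main

    import "fmt"

    // Stand-ins for the real *prog methods; each stage either resolves
    // the query or returns nil to fall through to the next stage.
    type response struct{ answer string }

    func tryCache(q string) *response         { return nil } // cache miss
    func tryUpstreams(q string) *response     { return &response{answer: "answer for " + q} }
    func handleAllFailure(q string) *response { return &response{answer: "SERVFAIL"} }

    // proxy mirrors the refactored method: initialization and special-case
    // handling come first in the real code, then cache, upstreams, failure.
    func proxy(q string) *response {
        if res := tryCache(q); res != nil {
            return res
        }
        if res := tryUpstreams(q); res != nil {
            return res
        }
        return handleAllFailure(q)
    }

    func main() {
        fmt.Println(proxy("example.com.").answer) // answer for example.com.
    }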
@@ -53,10 +53,13 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{

 // proxyRequest contains data for proxying a DNS query to upstream.
 type proxyRequest struct {
-	msg            *dns.Msg
-	ci             *ctrld.ClientInfo
-	failoverRcodes []int
-	ufr            *upstreamForResult
+	msg             *dns.Msg
+	ci              *ctrld.ClientInfo
+	failoverRcodes  []int
+	ufr             *upstreamForResult
+	staleAnswer     *dns.Msg
+	isLanOrPtrQuery bool
+	upstreamConfigs []*ctrld.UpstreamConfig
 }

 // proxyResponse contains data for proxying a DNS response from upstream.
@@ -409,6 +412,10 @@ macRules:
 	return
 }

+// proxyPrivatePtrLookup performs a private PTR DNS lookup based on the client info table for the given query.
+// It prevents DNS loops by locking the processing of the same domain name simultaneously.
+// If a valid IP-to-hostname mapping exists, it creates a PTR DNS record as the response.
+// Returns the DNS response if a hostname is found or nil otherwise.
 func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg {
 	cDomainName := msg.Question[0].Name
 	locked := p.ptrLoopGuard.TryLock(cDomainName)
@@ -440,6 +447,10 @@ func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg
 	return nil
 }

+// proxyLanHostnameQuery resolves LAN hostnames to their corresponding IP addresses based on the dns.Msg request.
+// It uses a loop guard mechanism to prevent DNS query loops and ensures a hostname is processed only once at a time.
+// This method queries the client info table for the hostname's IP address and logs relevant debug and client info.
+// If the hostname matches known IPs in the table, it generates an appropriate dns.Msg response; otherwise, it returns nil.
 func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg {
 	q := msg.Question[0]
 	hostname := strings.TrimSuffix(q.Name, ".")
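Both helpers above lean on a try-lock loop guard. ptrLoopGuard's implementation is not part of this diff; the sketch below is a guess at its semantics, only to illustrate how TryLock breaks query loops:

    package main

    import (
        "fmt"
        "sync"
    )

    // loopGuard sketches the TryLock pattern the PTR/hostname paths rely on:
    // if the same name is already in flight, bail out instead of recursing.
    type loopGuard struct {
        mu    sync.Mutex
        inUse map[string]bool
    }

    func newLoopGuard() *loopGuard { return &loopGuard{inUse: map[string]bool{}} }

    func (g *loopGuard) TryLock(name string) bool {
        g.mu.Lock()
        defer g.mu.Unlock()
        if g.inUse[name] {
            return false // already processing this name: a loop
        }
        g.inUse[name] = true
        return true
    }

    func (g *loopGuard) Unlock(name string) {
        g.mu.Lock()
        defer g.mu.Unlock()
        delete(g.inUse, name)
    }

    func main() {
        g := newLoopGuard()
        fmt.Println(g.TryLock("4.3.2.1.in-addr.arpa.")) // true: proceed with lookup
        fmt.Println(g.TryLock("4.3.2.1.in-addr.arpa.")) // false: loop detected, skip
        g.Unlock("4.3.2.1.in-addr.arpa.")
    }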
@@ -485,231 +496,324 @@ func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg
 	return nil
 }

-func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
-	var staleAnswer *dns.Msg
-	upstreams := req.ufr.upstreams
-	serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
-	upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
-
-	if len(upstreamConfigs) == 0 {
-		upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
-		upstreams = []string{upstreamOS}
-	}
-
-	res := &proxyResponse{}
-
-	// LAN/PTR lookup flow:
-	//
-	// 1. If there's matching rule, follow it.
-	// 2. Try from client info table.
-	// 3. Try private resolver.
-	// 4. Try remote upstream.
-	isLanOrPtrQuery := false
-	if req.ufr.matched {
-		ctrld.Log(ctx, p.Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
-	} else {
-		switch {
-		case isSrvLanLookup(req.msg):
-			upstreams = []string{upstreamOS}
-			upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
-			ctx = ctrld.LanQueryCtx(ctx)
-			ctrld.Log(ctx, p.Debug(), "SRV record lookup, using upstreams: %v", upstreams)
-		case isPrivatePtrLookup(req.msg):
-			isLanOrPtrQuery = true
-			if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil {
-				res.answer = answer
-				res.clientInfo = true
-				return res
-			}
-			upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs)
-			ctx = ctrld.LanQueryCtx(ctx)
-			ctrld.Log(ctx, p.Debug(), "private PTR lookup, using upstreams: %v", upstreams)
-		case isLanHostnameQuery(req.msg):
-			isLanOrPtrQuery = true
-			if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil {
-				res.answer = answer
-				res.clientInfo = true
-				return res
-			}
-			upstreams = []string{upstreamOS}
-			upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
-			ctx = ctrld.LanQueryCtx(ctx)
-			ctrld.Log(ctx, p.Debug(), "lan hostname lookup, using upstreams: %v", upstreams)
-		default:
-			ctrld.Log(ctx, p.Debug(), "no explicit policy matched, using default routing -> %v", upstreams)
-		}
-	}
-
-	// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
-	if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
-		for _, upstream := range upstreams {
-			cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream))
-			if cachedValue == nil {
-				continue
-			}
-			answer := cachedValue.Msg.Copy()
-			ctrld.SetCacheReply(answer, req.msg, answer.Rcode)
-			now := time.Now()
-			if cachedValue.Expire.After(now) {
-				ctrld.Log(ctx, p.Debug(), "hit cached response")
-				setCachedAnswerTTL(answer, now, cachedValue.Expire)
-				res.answer = answer
-				res.cached = true
-				return res
-			}
-			staleAnswer = answer
-		}
-	}
-	resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
-		ctrld.Log(ctx, p.Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name)
-		dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig)
-		if err != nil {
-			ctrld.Log(ctx, p.Error().Err(err), "failed to create resolver")
-			return nil, err
-		}
-		resolveCtx, cancel := upstreamConfig.Context(ctx)
-		defer cancel()
-		return dnsResolver.Resolve(resolveCtx, msg)
-	}
-	resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
-		if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
-			ctrld.Log(ctx, p.Debug(), "including client info with the request")
-			ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
-		}
-		answer, err := resolve1(upstream, upstreamConfig, msg)
-		// If we have an answer, we should reset the failure count.
-		// We don't use reset here since we don't want to prevent failure counts from being incremented.
-		if answer != nil {
-			p.um.mu.Lock()
-			p.um.failureReq[upstream] = 0
-			p.um.down[upstream] = false
-			p.um.mu.Unlock()
-			return answer
-		}
-
-		ctrld.Log(ctx, p.Error().Err(err), "failed to resolve query")
-
-		// Increase the failure count when there is no answer,
-		// regardless of what kind of error we get.
-		p.um.increaseFailureCount(upstream)
-
-		if err != nil {
-			// For timeout error (i.e. context deadline exceeded), force re-bootstrapping.
-			var e net.Error
-			if errors.As(err, &e) && e.Timeout() {
-				upstreamConfig.ReBootstrap(ctx)
-			}
-			// For network error, turn ipv6 off if enabled.
-			if ctrld.HasIPv6(ctx) && (errUrlNetworkError(err) || errNetworkError(err)) {
-				ctrld.DisableIPv6(ctx)
-			}
-		}
-		return nil
-	}
-	for n, upstreamConfig := range upstreamConfigs {
-		if upstreamConfig == nil {
-			continue
-		}
-		logger := p.Debug().
-			Str("upstream", upstreamConfig.String()).
-			Str("query", req.msg.Question[0].Name).
-			Bool("is_lan_query", isLanOrPtrQuery)
-
-		if p.isLoop(upstreamConfig) {
-			ctrld.Log(ctx, logger, "DNS loop detected")
-			continue
-		}
-		answer := resolve(upstreams[n], upstreamConfig, req.msg)
-		if answer == nil {
-			if serveStaleCache && staleAnswer != nil {
-				ctrld.Log(ctx, p.Debug(), "serving stale cached response")
-				now := time.Now()
-				setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL))
-				res.answer = staleAnswer
-				res.cached = true
-				return res
-			}
-			continue
-		}
-		// We are doing LAN/PTR lookup using private resolver, so always process next one.
-		// Except for the last, we want to send response instead of saying all upstream failed.
-		if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 {
-			ctrld.Log(ctx, p.Debug(), "no response from %s, process to next upstream", upstreams[n])
-			continue
-		}
-		if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) {
-			ctrld.Log(ctx, p.Debug(), "failover rcode matched, process to next upstream")
-			continue
-		}
-
-		// set compression, as it is not set by default when unpacking
-		answer.Compress = true
-
-		if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
-			ttl := ttlFromMsg(answer)
-			now := time.Now()
-			expired := now.Add(time.Duration(ttl) * time.Second)
-			if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 {
-				expired = now.Add(time.Duration(cachedTTL) * time.Second)
-			}
-			setCachedAnswerTTL(answer, now, expired)
-			p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired))
-			ctrld.Log(ctx, p.Debug(), "add cached response")
-		}
-		hostname := ""
-		if req.ci != nil {
-			hostname = req.ci.Hostname
-		}
-		ctrld.Log(ctx, p.Info(), "REPLY: %s -> %s (%s): %s", upstreams[n], req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode])
-		res.answer = answer
-		res.upstream = upstreamConfig.Endpoint
-		return res
-	}
-	ctrld.Log(ctx, p.Error(), "all %v endpoints failed", upstreams)
-	if p.leakOnUpstreamFailure() {
-		if p.um.countHealthy(upstreams) == 0 {
-			p.recoveryCancelMu.Lock()
-			if p.recoveryCancel == nil {
-				var reason RecoveryReason
-				if upstreams[0] == upstreamOS {
-					reason = RecoveryReasonOSFailure
-				} else {
-					reason = RecoveryReasonRegularFailure
-				}
-				p.Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
-				go p.handleRecovery(reason)
-			} else {
-				p.Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
-			}
-			p.recoveryCancelMu.Unlock()
-		} else {
-			p.Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
-		}
-
-		// attempt query to OS resolver as a retry catch-all;
-		// we don't want this to happen if leakOnUpstreamFailure is false
-		if upstreams[0] != upstreamOS {
-			ctrld.Log(ctx, p.Debug(), "attempting query to OS resolver as a retry catch all")
-			answer := resolve(upstreamOS, osUpstreamConfig, req.msg)
-			if answer != nil {
-				ctrld.Log(ctx, p.Debug(), "OS resolver retry query successful")
-				res.answer = answer
-				res.upstream = osUpstreamConfig.Endpoint
-				return res
-			}
-			ctrld.Log(ctx, p.Debug(), "OS resolver retry query failed")
-		}
-	}
-
-	answer := new(dns.Msg)
-	answer.SetRcode(req.msg, dns.RcodeServerFailure)
-	res.answer = answer
-	return res
-}
+// handleSpecialQueryTypes processes specific types of DNS queries such as SRV, PTR, and LAN hostname lookups.
+// It modifies upstreams and upstreamConfigs based on the query type and updates the query context accordingly.
+// Returns a proxyResponse if the query is resolved locally; otherwise, returns nil to proceed with upstream processing.
+func (p *prog) handleSpecialQueryTypes(ctx *context.Context, req *proxyRequest, upstreams *[]string, upstreamConfigs *[]*ctrld.UpstreamConfig) *proxyResponse {
+	if req.ufr.matched {
+		ctrld.Log(*ctx, p.Debug(), "%s, %s, %s -> %v",
+			req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, *upstreams)
+		return nil
+	}
+
+	switch {
+	case isSrvLanLookup(req.msg):
+		*upstreams = []string{upstreamOS}
+		*upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
+		*ctx = ctrld.LanQueryCtx(*ctx)
+		ctrld.Log(*ctx, p.Debug(), "SRV record lookup, using upstreams: %v", *upstreams)
+		return nil
+	case isPrivatePtrLookup(req.msg):
+		req.isLanOrPtrQuery = true
+		if answer := p.proxyPrivatePtrLookup(*ctx, req.msg); answer != nil {
+			return &proxyResponse{answer: answer, clientInfo: true}
+		}
+		*upstreams, *upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(*upstreams, *upstreamConfigs)
+		*ctx = ctrld.LanQueryCtx(*ctx)
+		ctrld.Log(*ctx, p.Debug(), "private PTR lookup, using upstreams: %v", *upstreams)
+		return nil
+	case isLanHostnameQuery(req.msg):
+		req.isLanOrPtrQuery = true
+		if answer := p.proxyLanHostnameQuery(*ctx, req.msg); answer != nil {
+			return &proxyResponse{answer: answer, clientInfo: true}
+		}
+		*upstreams = []string{upstreamOS}
+		*upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
+		*ctx = ctrld.LanQueryCtx(*ctx)
+		ctrld.Log(*ctx, p.Debug(), "lan hostname lookup, using upstreams: %v", *upstreams)
+		return nil
+	default:
+		ctrld.Log(*ctx, p.Debug(), "no explicit policy matched, using default routing -> %v", *upstreams)
+		return nil
+	}
+}
+
+// proxy handles DNS query proxying by selecting upstreams, attempting cache lookups, and querying configured resolvers.
+func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
+	upstreams, upstreamConfigs := p.initializeUpstreams(req)
+	if specialRes := p.handleSpecialQueryTypes(&ctx, req, &upstreams, &upstreamConfigs); specialRes != nil {
+		return specialRes
+	}
+
+	if cachedRes := p.tryCache(ctx, req, upstreams); cachedRes != nil {
+		return cachedRes
+	}
+
+	if res := p.tryUpstreams(ctx, req, upstreams, upstreamConfigs); res != nil {
+		return res
+	}
+
+	// if we have no healthy upstreams, trigger recovery flow
+	return p.handleAllUpstreamsFailure(ctx, req, upstreams)
+}
+
+// initializeUpstreams determines which upstreams and configurations to use for a given proxyRequest.
+// If no upstreams are configured, it defaults to the operating system's resolver configuration.
+// Returns a slice of upstream names and their corresponding configurations.
+func (p *prog) initializeUpstreams(req *proxyRequest) ([]string, []*ctrld.UpstreamConfig) {
+	upstreams := req.ufr.upstreams
+	upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
+	if len(upstreamConfigs) == 0 {
+		return []string{upstreamOS}, []*ctrld.UpstreamConfig{osUpstreamConfig}
+	}
+	return upstreams, upstreamConfigs
+}
+
+// tryCache attempts to retrieve a cached response for the given DNS request from specified upstreams.
+// Returns a proxyResponse if a cache hit occurs; otherwise, returns nil.
+// Skips cache checking if caching is disabled or the request is a PTR query.
+// Iterates through the provided upstreams to find a cached response using the checkCache method.
+func (p *prog) tryCache(ctx context.Context, req *proxyRequest, upstreams []string) *proxyResponse {
+	if p.cache == nil || req.msg.Question[0].Qtype == dns.TypePTR { // https://www.rfc-editor.org/rfc/rfc1035#section-7.4
+		return nil
+	}
+
+	for _, upstream := range upstreams {
+		if res := p.checkCache(ctx, req, upstream); res != nil {
+			return res
+		}
+	}
+	return nil
+}
+
+// checkCache checks if a cached DNS response exists for the given request and upstream.
+// Returns a proxyResponse with the cached response if found and valid, or nil otherwise.
+func (p *prog) checkCache(ctx context.Context, req *proxyRequest, upstream string) *proxyResponse {
+	cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream))
+	if cachedValue == nil {
+		return nil
+	}
+
+	answer := cachedValue.Msg.Copy()
+	ctrld.SetCacheReply(answer, req.msg, answer.Rcode)
+	now := time.Now()
+
+	if cachedValue.Expire.After(now) {
+		ctrld.Log(ctx, p.Debug(), "hit cached response")
+		setCachedAnswerTTL(answer, now, cachedValue.Expire)
+		return &proxyResponse{answer: answer, cached: true}
+	}
+	req.staleAnswer = answer
+	return nil
+}
+
+// updateCache updates the DNS response cache with the given request, response, TTL, and upstream information.
+func (p *prog) updateCache(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string) {
+	ttl := ttlFromMsg(answer)
+	now := time.Now()
+	expired := now.Add(time.Duration(ttl) * time.Second)
+	if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 {
+		expired = now.Add(time.Duration(cachedTTL) * time.Second)
+	}
+	setCachedAnswerTTL(answer, now, expired)
+	p.cache.Add(dnscache.NewKey(req.msg, upstream), dnscache.NewValue(answer, expired))
+	ctrld.Log(ctx, p.Debug(), "add cached response")
+}
+
+// serveStaleResponse serves a stale cached DNS response when an upstream query fails, updating TTL for cached records.
+func (p *prog) serveStaleResponse(ctx context.Context, staleAnswer *dns.Msg) *proxyResponse {
+	ctrld.Log(ctx, p.Debug(), "serving stale cached response")
+	now := time.Now()
+	setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL))
+	return &proxyResponse{answer: staleAnswer, cached: true}
+}
+
+// handleAllUpstreamsFailure handles the failure scenario when all upstream resolvers fail to respond or process the request.
+func (p *prog) handleAllUpstreamsFailure(ctx context.Context, req *proxyRequest, upstreams []string) *proxyResponse {
+	ctrld.Log(ctx, p.Error(), "all %v endpoints failed", upstreams)
+	if p.leakOnUpstreamFailure() {
+		if p.um.countHealthy(upstreams) == 0 {
+			p.triggerRecovery(upstreams[0] == upstreamOS)
+		} else {
+			p.Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
+		}
+
+		// attempt query to OS resolver as a retry catch-all;
+		// we don't want this to happen if leakOnUpstreamFailure is false
+		if upstreams[0] != upstreamOS {
+			if answer := p.tryOSResolver(ctx, req); answer != nil {
+				return answer
+			}
+		}
+	}
+
+	answer := new(dns.Msg)
+	answer.SetRcode(req.msg, dns.RcodeServerFailure)
+	return &proxyResponse{answer: answer}
+}
+
+// shouldContinueWithNextUpstream determines whether processing should continue with the next upstream based on response conditions.
+func (p *prog) shouldContinueWithNextUpstream(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string, lastUpstream bool) bool {
+	if answer.Rcode == dns.RcodeSuccess {
+		return false
+	}
+
+	// We are doing LAN/PTR lookup using private resolver, so always process the next one.
+	// Except for the last, we want to send a response instead of saying all upstream failed.
+	if req.isLanOrPtrQuery && !lastUpstream {
+		ctrld.Log(ctx, p.Debug(), "no response for LAN/PTR query from %s, process to next upstream", upstream)
+		return true
+	}
+
+	if len(req.upstreamConfigs) > 1 && slices.Contains(req.failoverRcodes, answer.Rcode) {
+		ctrld.Log(ctx, p.Debug(), "failover rcode matched, process to next upstream")
+		return true
+	}
+
+	return false
+}
+
+// prepareSuccessResponse prepares a successful DNS response for a given request, logs it, and updates the cache if applicable.
+func (p *prog) prepareSuccessResponse(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string, upstreamConfig *ctrld.UpstreamConfig) *proxyResponse {
+	// set compression, as it is not set by default when unpacking
+	answer.Compress = true
+
+	if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
+		p.updateCache(ctx, req, answer, upstream)
+	}
+
+	hostname := ""
+	if req.ci != nil {
+		hostname = req.ci.Hostname
+	}
+
+	ctrld.Log(ctx, p.Info(), "REPLY: %s -> %s (%s): %s",
+		upstream, req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode])
+
+	return &proxyResponse{
+		answer:   answer,
+		upstream: upstreamConfig.Endpoint,
+	}
+}
+
+// tryUpstreams attempts to proxy a DNS request through the provided upstreams and their configurations sequentially.
+// It returns a successful proxyResponse if any upstream processes the request successfully, or nil otherwise.
+// The function supports "serve stale" for cache by utilizing cached responses when upstreams fail.
+func (p *prog) tryUpstreams(ctx context.Context, req *proxyRequest, upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) *proxyResponse {
+	serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
+	req.upstreamConfigs = upstreamConfigs
+	for n, upstreamConfig := range upstreamConfigs {
+		last := n == len(upstreamConfigs)-1
+		if res := p.processUpstream(ctx, req, upstreams[n], upstreamConfig, serveStaleCache, last); res != nil {
+			return res
+		}
+	}
+	return nil
+}
+
+// processUpstream proxies a DNS query to a given upstream server and processes the response based on the provided configuration.
+// It supports serving stale cache when upstream queries fail, and checks if processing should continue to another upstream.
+// Returns a proxyResponse on success or nil if the upstream query fails or processing conditions are not met.
+func (p *prog) processUpstream(ctx context.Context, req *proxyRequest, upstream string, upstreamConfig *ctrld.UpstreamConfig, serveStaleCache, lastUpstream bool) *proxyResponse {
+	if upstreamConfig == nil {
+		return nil
+	}
+	if p.isLoop(upstreamConfig) {
+		logger := p.Debug().
+			Str("upstream", upstreamConfig.String()).
+			Str("query", req.msg.Question[0].Name).
+			Bool("is_lan_query", req.isLanOrPtrQuery)
+		ctrld.Log(ctx, logger, "DNS loop detected")
+		return nil
+	}
+
+	answer := p.queryUpstream(ctx, req, upstream, upstreamConfig)
+	if answer == nil {
+		if serveStaleCache && req.staleAnswer != nil {
+			return p.serveStaleResponse(ctx, req.staleAnswer)
+		}
+		return nil
+	}
+
+	if p.shouldContinueWithNextUpstream(ctx, req, answer, upstream, lastUpstream) {
+		return nil
+	}
+	return p.prepareSuccessResponse(ctx, req, answer, upstream, upstreamConfig)
+}
+
+// queryUpstream sends a DNS query to a specified upstream using its configuration and handles errors and retries.
+func (p *prog) queryUpstream(ctx context.Context, req *proxyRequest, upstream string, upstreamConfig *ctrld.UpstreamConfig) *dns.Msg {
+	if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
+		ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
+	}
+
+	ctrld.Log(ctx, p.Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name)
+	dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig)
+	if err != nil {
+		ctrld.Log(ctx, p.Error().Err(err), "failed to create resolver")
+		return nil
+	}
+
+	resolveCtx, cancel := upstreamConfig.Context(ctx)
+	defer cancel()
+
+	answer, err := dnsResolver.Resolve(resolveCtx, req.msg)
+	if answer != nil {
+		p.um.mu.Lock()
+		p.um.failureReq[upstream] = 0
+		p.um.down[upstream] = false
+		p.um.mu.Unlock()
+		return answer
+	}
+
+	ctrld.Log(ctx, p.Error().Err(err), "failed to resolve query")
+	// Increase the failure count when there is no answer, regardless of what kind of error we get.
+	p.um.increaseFailureCount(upstream)
+	if err != nil {
+		// For timeout error (i.e. context deadline exceeded), force re-bootstrapping.
+		var e net.Error
+		if errors.As(err, &e) && e.Timeout() {
+			upstreamConfig.ReBootstrap(ctx)
+		}
+		// For network error, turn ipv6 off if enabled.
+		if ctrld.HasIPv6(ctx) && (errUrlNetworkError(err) || errNetworkError(err)) {
+			ctrld.DisableIPv6(ctx)
+		}
+	}
+	return nil
+}
+
+// triggerRecovery attempts to initiate a recovery process if no healthy upstreams are detected.
+// If "isOSFailure" is true, the recovery will account for an operating system failure.
+// Logs are generated to indicate whether recovery is triggered or already in progress.
+func (p *prog) triggerRecovery(isOSFailure bool) {
+	p.recoveryCancelMu.Lock()
+	defer p.recoveryCancelMu.Unlock()
+
+	if p.recoveryCancel == nil {
+		var reason RecoveryReason
+		if isOSFailure {
+			reason = RecoveryReasonOSFailure
+		} else {
+			reason = RecoveryReasonRegularFailure
+		}
+		p.Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
+		go p.handleRecovery(reason)
+	} else {
+		p.Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
+	}
+}
+
+// tryOSResolver attempts to query the OS resolver as a fallback mechanism when other upstreams fail.
+// Logs success or failure of the query attempt and returns a proxyResponse or nil based on query result.
+func (p *prog) tryOSResolver(ctx context.Context, req *proxyRequest) *proxyResponse {
+	ctrld.Log(ctx, p.Debug(), "attempting query to OS resolver as a retry catch all")
+	answer := p.queryUpstream(ctx, req, upstreamOS, osUpstreamConfig)
+	if answer != nil {
+		ctrld.Log(ctx, p.Debug(), "OS resolver retry query successful")
+		return &proxyResponse{answer: answer, upstream: osUpstreamConfig.Endpoint}
+	}
+	ctrld.Log(ctx, p.Debug(), "OS resolver retry query failed")
+	return nil
+}
+
+// upstreamsAndUpstreamConfigForPtr returns the updated upstreams and upstreamConfigs for a private PTR lookup scenario.
 func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
 	if len(p.localUpstreams) > 0 {
 		tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
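A detail worth noting in the added code: handleSpecialQueryTypes receives *context.Context so it can replace the caller's context in place (installing LanQueryCtx). A dependency-free sketch of that pattern:

    package main

    import (
        "context"
        "fmt"
    )

    type ctxKey struct{}

    // retarget swaps the caller's context for one carrying a marker value,
    // the same way handleSpecialQueryTypes installs LanQueryCtx via *ctx.
    func retarget(ctx *context.Context) {
        *ctx = context.WithValue(*ctx, ctxKey{}, "lan-query")
    }

    func main() {
        ctx := context.Background()
        retarget(&ctx)
        fmt.Println(ctx.Value(ctxKey{})) // lan-query
    }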
@@ -720,6 +824,7 @@ func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConf
 	return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...)
 }

+// upstreamConfigsFromUpstreamNumbers converts a list of upstream names into their corresponding UpstreamConfig objects.
 func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
 	upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
 	for _, upstream := range upstreams {
@@ -765,10 +870,12 @@ func wildcardMatches(wildcard, str string) bool {
 	return false
 }

+// fmtRemoteToLocal formats a remote address to indicate its mapping to a local listener using listener number and hostname.
 func fmtRemoteToLocal(listenerNum, hostname, remote string) string {
 	return fmt.Sprintf("%s (%s) -> listener.%s", remote, hostname, listenerNum)
 }

+// requestID generates a random 6-character hexadecimal string to uniquely identify a request. It panics on error.
 func requestID() string {
 	b := make([]byte, 3) // 6 chars
 	if _, err := rand.Read(b); err != nil {
@@ -777,15 +884,7 @@ func requestID() string {
 	return hex.EncodeToString(b)
 }

-func containRcode(rcodes []int, rcode int) bool {
-	for i := range rcodes {
-		if rcodes[i] == rcode {
-			return true
-		}
-	}
-	return false
-}
-
+// setCachedAnswerTTL updates the TTL of each DNS record in the provided message based on the current and expiration times.
 func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) {
 	ttlSecs := expiredTime.Sub(now).Seconds()
 	if ttlSecs < 0 {
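With containRcode removed, its former call site (now in shouldContinueWithNextUpstream, earlier in the diff) uses the standard library's slices.Contains (Go 1.21+), which is behaviorally identical for this case:

    package main

    import (
        "fmt"
        "slices"
    )

    func main() {
        // Rcode values as plain ints to keep the sketch dependency-free
        // (dns.RcodeServerFailure == 2, dns.RcodeRefused == 5 in miekg/dns).
        failoverRcodes := []int{2, 5}
        fmt.Println(slices.Contains(failoverRcodes, 2)) // true
        fmt.Println(slices.Contains(failoverRcodes, 3)) // false
    }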
@@ -806,6 +905,8 @@ func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) {
 	}
 }

+// ttlFromMsg extracts and returns the TTL value from the first record in the Answer or Ns sections of a DNS message.
+// If no records exist in either section, the function returns 0.
 func ttlFromMsg(msg *dns.Msg) uint32 {
 	for _, rr := range msg.Answer {
 		return rr.Header().Ttl
@@ -816,6 +917,7 @@ func ttlFromMsg(msg *dns.Msg) uint32 {
 	return 0
 }

+// needLocalIPv6Listener checks if a local IPv6 listener is required on Windows by verifying IPv6 support and the OS type.
 func needLocalIPv6Listener() bool {
 	// On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can
 	// listen on ::1, then spawn a listener for receiving DNS requests.