refactor(dns): improve DNS proxy code structure and readability

Break down the large DNS handling function into smaller, focused functions
with clear responsibilities:

- Extract handleDNSQuery from serveDNS handler function
- Create dedicated startListeners function for listener management
- Add standardQueryRequest struct to encapsulate query parameters
- Split special domain handling into separate function
- Add descriptive comments for each new function
- Improve variable names for better clarity (e.g., startTime vs t)

This refactoring improves code maintainability and readability without
changing the core DNS proxy functionality.
This commit is contained in:
Cuong Manh Le
2025-06-19 20:09:07 +07:00
committed by Cuong Manh Le
parent a16b25ad1d
commit b18cd7ee83

View File

@@ -69,9 +69,10 @@ type proxyRequest struct {
// proxyResponse contains data for proxying a DNS response from upstream. // proxyResponse contains data for proxying a DNS response from upstream.
type proxyResponse struct { type proxyResponse struct {
answer *dns.Msg answer *dns.Msg
upstream string
cached bool cached bool
clientInfo bool clientInfo bool
upstream string refused bool
} }
// upstreamForResult represents the result of processing rules for a request. // upstreamForResult represents the result of processing rules for a request.
@@ -84,151 +85,57 @@ type upstreamForResult struct {
srcAddr string srcAddr string
} }
// serveDNS sets up and starts a DNS server on the specified listener, handling DNS queries and network monitoring.
func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
// Start network monitoring
if err := p.monitorNetworkChanges(mainCtx); err != nil { if err := p.monitorNetworkChanges(mainCtx); err != nil {
p.Error().Err(err).Msg("Failed to start network monitoring") p.Error().Err(err).Msg("Failed to start network monitoring")
// Don't return here as we still want DNS service to run // Don't return here as we still want DNS service to run
} }
listenerConfig := p.cfg.Listener[listenerNum] listenerConfig := p.cfg.Listener[listenerNum]
// make sure ip is allocated
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
return allocErr return allocErr
} }
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
p.sema.acquire() p.handleDNSQuery(w, m, listenerNum, listenerConfig)
defer p.sema.release()
if len(m.Question) == 0 {
answer := new(dns.Msg)
answer.SetRcode(m, dns.RcodeFormatError)
_ = w.WriteMsg(answer)
return
}
listenerConfig := p.cfg.Listener[listenerNum]
reqId := requestID()
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
ctx = ctrld.LoggerCtx(ctx, p.logger.Load())
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
ctrld.Log(ctx, p.Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
answer := new(dns.Msg)
answer.SetRcode(m, dns.RcodeRefused)
_ = w.WriteMsg(answer)
return
}
go p.detectLoop(m)
q := m.Question[0]
domain := canonicalName(q.Name)
switch {
case domain == "":
answer := new(dns.Msg)
answer.SetRcode(m, dns.RcodeFormatError)
_ = w.WriteMsg(answer)
return
case domain == selfCheckInternalTestDomain:
answer := resolveInternalDomainTestQuery(ctx, domain, m)
_ = w.WriteMsg(answer)
return
}
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
p.cache.Purge()
ctrld.Log(ctx, p.Debug(), "received query %q, local cache is purged", domain)
}
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
ci := p.getClientInfo(remoteIP, m)
ci.ClientIDPref = p.cfg.Service.ClientIDPref
stripClientSubnet(m)
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci)
fmtSrcToDest := fmtRemoteToLocal(listenerNum, ci.Hostname, remoteAddr.String())
t := time.Now()
ctrld.Log(ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
ur := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain)
labelValues := make([]string, 0, len(statsQueriesCountLabels))
labelValues = append(labelValues, net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)))
labelValues = append(labelValues, ci.IP)
labelValues = append(labelValues, ci.Mac)
labelValues = append(labelValues, ci.Hostname)
var answer *dns.Msg
if !ur.matched && listenerConfig.Restricted {
ctrld.Log(ctx, p.Info(), "query refused, %s does not match any network policy", remoteAddr.String())
answer = new(dns.Msg)
answer.SetRcode(m, dns.RcodeRefused)
labelValues = append(labelValues, "") // no upstream
} else {
var failoverRcode []int
if listenerConfig.Policy != nil {
failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers
}
pr := p.proxy(ctx, &proxyRequest{
msg: m,
ci: ci,
failoverRcodes: failoverRcode,
ufr: ur,
})
go p.doSelfUninstall(pr.answer)
answer = pr.answer
rtt := time.Since(t)
ctrld.Log(ctx, p.Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
upstream := pr.upstream
switch {
case pr.cached:
upstream = "cache"
case pr.clientInfo:
upstream = "client_info_table"
}
labelValues = append(labelValues, upstream)
}
labelValues = append(labelValues, dns.TypeToString[q.Qtype])
labelValues = append(labelValues, dns.RcodeToString[answer.Rcode])
go func() {
p.WithLabelValuesInc(statsQueriesCount, labelValues...)
p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...)
p.forceFetchingAPI(domain)
}()
if err := w.WriteMsg(answer); err != nil {
ctrld.Log(ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client")
}
}) })
g, ctx := errgroup.WithContext(context.Background()) return p.startListeners(mainCtx, listenerConfig, handler)
}
// startListeners starts DNS listeners on specified configurations, supporting UDP and TCP protocols.
// It handles local IPv6, RFC 1918, and specified IP listeners, reacting to stop signals or errors.
func (p *prog) startListeners(ctx context.Context, cfg *ctrld.ListenerConfig, handler dns.Handler) error {
g, gctx := errgroup.WithContext(ctx)
for _, proto := range []string{"udp", "tcp"} { for _, proto := range []string{"udp", "tcp"} {
proto := proto
if needLocalIPv6Listener() { if needLocalIPv6Listener() {
g.Go(func() error { g.Go(func() error {
s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), proto, handler) s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(cfg.Port)), proto, handler)
defer s.Shutdown() defer s.Shutdown()
select { select {
case <-p.stopCh: case <-p.stopCh:
case <-ctx.Done(): case <-gctx.Done():
case err := <-errCh: case err := <-errCh:
// Local ipv6 listener should not terminate ctrld.
// It's a workaround for a quirk on Windows.
p.Warn().Err(err).Msg("local ipv6 listener failed") p.Warn().Err(err).Msg("local ipv6 listener failed")
} }
return nil return nil
}) })
} }
// When we spawn a listener on 127.0.0.1, also spawn listeners on the RFC1918
// addresses of the machine. So ctrld could receive queries from LAN clients. if needRFC1918Listeners(cfg) {
if needRFC1918Listeners(listenerConfig) {
g.Go(func() error { g.Go(func() error {
for _, addr := range ctrld.Rfc1918Addresses() { for _, addr := range ctrld.Rfc1918Addresses() {
func() { func() {
listenAddr := net.JoinHostPort(addr, strconv.Itoa(listenerConfig.Port)) listenAddr := net.JoinHostPort(addr, strconv.Itoa(cfg.Port))
s, errCh := runDNSServer(listenAddr, proto, handler) s, errCh := runDNSServer(listenAddr, proto, handler)
defer s.Shutdown() defer s.Shutdown()
select { select {
case <-p.stopCh: case <-p.stopCh:
case <-ctx.Done(): case <-gctx.Done():
case err := <-errCh: case err := <-errCh:
// RFC1918 listener should not terminate ctrld.
// It's a workaround for a quirk on system with systemd-resolved.
p.Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr) p.Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr)
} }
}() }()
@@ -236,25 +143,183 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
return nil return nil
}) })
} }
g.Go(func() error { g.Go(func() error {
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) addr := net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port))
s, errCh := runDNSServer(addr, proto, handler) s, errCh := runDNSServer(addr, proto, handler)
defer s.Shutdown() defer s.Shutdown()
p.started <- struct{}{} p.started <- struct{}{}
select { select {
case <-p.stopCh: case <-p.stopCh:
case <-ctx.Done(): case <-gctx.Done():
case err := <-errCh: case err := <-errCh:
return err return err
} }
return nil return nil
}) })
} }
return g.Wait() return g.Wait()
} }
// handleDNSQuery processes incoming DNS queries, validates client access, and routes the query to appropriate handlers.
func (p *prog) handleDNSQuery(w dns.ResponseWriter, m *dns.Msg, listenerNum string, listenerConfig *ctrld.ListenerConfig) {
p.sema.acquire()
defer p.sema.release()
if len(m.Question) == 0 {
sendDNSResponse(w, m, dns.RcodeFormatError)
return
}
reqID := requestID()
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqID)
ctx = ctrld.LoggerCtx(ctx, p.logger.Load())
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
ctrld.Log(ctx, p.Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
sendDNSResponse(w, m, dns.RcodeRefused)
return
}
go p.detectLoop(m)
q := m.Question[0]
domain := canonicalName(q.Name)
if p.handleSpecialDomains(ctx, w, m, domain) {
return
}
p.processStandardQuery(&standardQueryRequest{
ctx: ctx,
writer: w,
msg: m,
listenerNum: listenerNum,
listenerConfig: listenerConfig,
domain: domain,
})
}
// handleSpecialDomains processes special domain queries, handles errors, purges cache if necessary, and returns a bool status.
func (p *prog) handleSpecialDomains(ctx context.Context, w dns.ResponseWriter, m *dns.Msg, domain string) bool {
switch {
case domain == "":
sendDNSResponse(w, m, dns.RcodeFormatError)
return true
case domain == selfCheckInternalTestDomain:
answer := resolveInternalDomainTestQuery(ctx, domain, m)
_ = w.WriteMsg(answer)
return true
}
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
p.cache.Purge()
ctrld.Log(ctx, p.Debug(), "received query %q, local cache is purged", domain)
}
return false
}
// standardQueryRequest represents a standard DNS query request with associated context and configuration.
type standardQueryRequest struct {
ctx context.Context
writer dns.ResponseWriter
msg *dns.Msg
listenerNum string
listenerConfig *ctrld.ListenerConfig
domain string
}
// processStandardQuery handles a standard DNS query by routing it through appropriate upstreams and writing a DNS response.
func (p *prog) processStandardQuery(req *standardQueryRequest) {
remoteIP, _, _ := net.SplitHostPort(req.writer.RemoteAddr().String())
ci := p.getClientInfo(remoteIP, req.msg)
ci.ClientIDPref = p.cfg.Service.ClientIDPref
stripClientSubnet(req.msg)
remoteAddr := spoofRemoteAddr(req.writer.RemoteAddr(), ci)
fmtSrcToDest := fmtRemoteToLocal(req.listenerNum, ci.Hostname, remoteAddr.String())
startTime := time.Now()
q := req.msg.Question[0]
ctrld.Log(req.ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], req.domain)
ur := p.upstreamFor(req.ctx, req.listenerNum, req.listenerConfig, remoteAddr, ci.Mac, req.domain)
var answer *dns.Msg
// Handle restricted listener case
if !ur.matched && req.listenerConfig.Restricted {
ctrld.Log(req.ctx, p.Debug(), "query refused, %s does not match any network policy", remoteAddr.String())
answer = new(dns.Msg)
answer.SetRcode(req.msg, dns.RcodeRefused)
// Process the refused query
go p.postProcessStandardQuery(ci, req.listenerConfig, q, &proxyResponse{answer: answer, refused: true})
} else {
// Process a normal query
pr := p.proxy(req.ctx, &proxyRequest{
msg: req.msg,
ci: ci,
failoverRcodes: p.getFailoverRcodes(req.listenerConfig),
ufr: ur,
})
rtt := time.Since(startTime)
ctrld.Log(req.ctx, p.Debug(), "received response of %d bytes in %s", pr.answer.Len(), rtt)
go p.postProcessStandardQuery(ci, req.listenerConfig, q, pr)
answer = pr.answer
}
if err := req.writer.WriteMsg(answer); err != nil {
ctrld.Log(req.ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client")
}
}
// postProcessStandardQuery performs additional actions after processing a standard DNS query, such as metrics recording,
// handling canonical name adjustments, and triggering specific post-query actions like uninstallation procedures.
func (p *prog) postProcessStandardQuery(ci *ctrld.ClientInfo, listenerConfig *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) {
p.doSelfUninstall(pr)
p.recordMetrics(ci, listenerConfig, q, pr)
p.forceFetchingAPI(canonicalName(q.Name))
}
// getFailoverRcodes retrieves the failover response codes from the provided ListenerConfig. Returns nil if no policy exists.
func (p *prog) getFailoverRcodes(cfg *ctrld.ListenerConfig) []int {
if cfg.Policy != nil {
return cfg.Policy.FailoverRcodeNumbers
}
return nil
}
// recordMetrics updates Prometheus metrics for DNS queries, including query count and client-specific query statistics.
func (p *prog) recordMetrics(ci *ctrld.ClientInfo, cfg *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) {
upstream := pr.upstream
switch {
case pr.cached:
upstream = "cache"
case pr.clientInfo:
upstream = "client_info_table"
}
labelValues := []string{
net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port)),
ci.IP,
ci.Mac,
ci.Hostname,
upstream,
dns.TypeToString[q.Qtype],
dns.RcodeToString[pr.answer.Rcode],
}
p.WithLabelValuesInc(statsQueriesCount, labelValues...)
p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...)
}
// sendDNSResponse sends a DNS response with the specified RCODE to the client using the provided ResponseWriter.
func sendDNSResponse(w dns.ResponseWriter, m *dns.Msg, rcode int) {
answer := new(dns.Msg)
answer.SetRcode(m, rcode)
_ = w.WriteMsg(answer)
}
// upstreamFor returns the list of upstreams for resolving the given domain, // upstreamFor returns the list of upstreams for resolving the given domain,
// matching by policies defined in the listener config. The second return value // matching by policies defined in the listener config. The second return value
// reports whether the domain matches the policy. // reports whether the domain matches the policy.
@@ -947,8 +1012,9 @@ func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
// - There is only 1 ControlD upstream in-use. // - There is only 1 ControlD upstream in-use.
// - Number of refused queries seen so far equals to selfUninstallMaxQueries. // - Number of refused queries seen so far equals to selfUninstallMaxQueries.
// - The cdUID is deleted. // - The cdUID is deleted.
func (p *prog) doSelfUninstall(answer *dns.Msg) { func (p *prog) doSelfUninstall(pr *proxyResponse) {
if !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused { answer := pr.answer
if pr.refused || !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused {
return return
} }