From b18cd7ee83a8aa7c985db68fa3e8b598fb167415 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 19 Jun 2025 20:09:07 +0700 Subject: [PATCH] refactor(dns): improve DNS proxy code structure and readability Break down the large DNS handling function into smaller, focused functions with clear responsibilities: - Extract handleDNSQuery from serveDNS handler function - Create dedicated startListeners function for listener management - Add standardQueryRequest struct to encapsulate query parameters - Split special domain handling into separate function - Add descriptive comments for each new function - Improve variable names for better clarity (e.g., startTime vs t) This refactoring improves code maintainability and readability without changing the core DNS proxy functionality. --- cmd/cli/dns_proxy.go | 300 ++++++++++++++++++++++++++----------------- 1 file changed, 183 insertions(+), 117 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 4491160..a5bbd0b 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -69,9 +69,10 @@ type proxyRequest struct { // proxyResponse contains data for proxying a DNS response from upstream. type proxyResponse struct { answer *dns.Msg + upstream string cached bool clientInfo bool - upstream string + refused bool } // upstreamForResult represents the result of processing rules for a request. @@ -84,151 +85,57 @@ type upstreamForResult struct { srcAddr string } +// serveDNS sets up and starts a DNS server on the specified listener, handling DNS queries and network monitoring. func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { - // Start network monitoring if err := p.monitorNetworkChanges(mainCtx); err != nil { p.Error().Err(err).Msg("Failed to start network monitoring") // Don't return here as we still want DNS service to run } listenerConfig := p.cfg.Listener[listenerNum] - // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") return allocErr } handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { - p.sema.acquire() - defer p.sema.release() - if len(m.Question) == 0 { - answer := new(dns.Msg) - answer.SetRcode(m, dns.RcodeFormatError) - _ = w.WriteMsg(answer) - return - } - listenerConfig := p.cfg.Listener[listenerNum] - reqId := requestID() - ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) - ctx = ctrld.LoggerCtx(ctx, p.logger.Load()) - if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { - ctrld.Log(ctx, p.Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) - answer := new(dns.Msg) - answer.SetRcode(m, dns.RcodeRefused) - _ = w.WriteMsg(answer) - return - } - go p.detectLoop(m) - q := m.Question[0] - domain := canonicalName(q.Name) - switch { - case domain == "": - answer := new(dns.Msg) - answer.SetRcode(m, dns.RcodeFormatError) - _ = w.WriteMsg(answer) - return - case domain == selfCheckInternalTestDomain: - answer := resolveInternalDomainTestQuery(ctx, domain, m) - _ = w.WriteMsg(answer) - return - } - - if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil { - p.cache.Purge() - ctrld.Log(ctx, p.Debug(), "received query %q, local cache is purged", domain) - } - remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - ci := p.getClientInfo(remoteIP, m) - ci.ClientIDPref = p.cfg.Service.ClientIDPref - stripClientSubnet(m) - remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) - fmtSrcToDest := fmtRemoteToLocal(listenerNum, ci.Hostname, remoteAddr.String()) - t := time.Now() - ctrld.Log(ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain) - ur := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain) - - labelValues := make([]string, 0, len(statsQueriesCountLabels)) - labelValues = append(labelValues, net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))) - labelValues = append(labelValues, ci.IP) - labelValues = append(labelValues, ci.Mac) - labelValues = append(labelValues, ci.Hostname) - - var answer *dns.Msg - if !ur.matched && listenerConfig.Restricted { - ctrld.Log(ctx, p.Info(), "query refused, %s does not match any network policy", remoteAddr.String()) - answer = new(dns.Msg) - answer.SetRcode(m, dns.RcodeRefused) - labelValues = append(labelValues, "") // no upstream - } else { - var failoverRcode []int - if listenerConfig.Policy != nil { - failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers - } - pr := p.proxy(ctx, &proxyRequest{ - msg: m, - ci: ci, - failoverRcodes: failoverRcode, - ufr: ur, - }) - go p.doSelfUninstall(pr.answer) - - answer = pr.answer - rtt := time.Since(t) - ctrld.Log(ctx, p.Debug(), "received response of %d bytes in %s", answer.Len(), rtt) - upstream := pr.upstream - switch { - case pr.cached: - upstream = "cache" - case pr.clientInfo: - upstream = "client_info_table" - } - labelValues = append(labelValues, upstream) - } - labelValues = append(labelValues, dns.TypeToString[q.Qtype]) - labelValues = append(labelValues, dns.RcodeToString[answer.Rcode]) - go func() { - p.WithLabelValuesInc(statsQueriesCount, labelValues...) - p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...) - p.forceFetchingAPI(domain) - }() - if err := w.WriteMsg(answer); err != nil { - ctrld.Log(ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client") - } + p.handleDNSQuery(w, m, listenerNum, listenerConfig) }) - g, ctx := errgroup.WithContext(context.Background()) + return p.startListeners(mainCtx, listenerConfig, handler) +} + +// startListeners starts DNS listeners on specified configurations, supporting UDP and TCP protocols. +// It handles local IPv6, RFC 1918, and specified IP listeners, reacting to stop signals or errors. +func (p *prog) startListeners(ctx context.Context, cfg *ctrld.ListenerConfig, handler dns.Handler) error { + g, gctx := errgroup.WithContext(ctx) + for _, proto := range []string{"udp", "tcp"} { - proto := proto if needLocalIPv6Listener() { g.Go(func() error { - s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), proto, handler) + s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(cfg.Port)), proto, handler) defer s.Shutdown() select { case <-p.stopCh: - case <-ctx.Done(): + case <-gctx.Done(): case err := <-errCh: - // Local ipv6 listener should not terminate ctrld. - // It's a workaround for a quirk on Windows. p.Warn().Err(err).Msg("local ipv6 listener failed") } return nil }) } - // When we spawn a listener on 127.0.0.1, also spawn listeners on the RFC1918 - // addresses of the machine. So ctrld could receive queries from LAN clients. - if needRFC1918Listeners(listenerConfig) { + + if needRFC1918Listeners(cfg) { g.Go(func() error { for _, addr := range ctrld.Rfc1918Addresses() { func() { - listenAddr := net.JoinHostPort(addr, strconv.Itoa(listenerConfig.Port)) + listenAddr := net.JoinHostPort(addr, strconv.Itoa(cfg.Port)) s, errCh := runDNSServer(listenAddr, proto, handler) defer s.Shutdown() select { case <-p.stopCh: - case <-ctx.Done(): + case <-gctx.Done(): case err := <-errCh: - // RFC1918 listener should not terminate ctrld. - // It's a workaround for a quirk on system with systemd-resolved. p.Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr) } }() @@ -236,25 +143,183 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { return nil }) } + g.Go(func() error { - addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) + addr := net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port)) s, errCh := runDNSServer(addr, proto, handler) defer s.Shutdown() - p.started <- struct{}{} - select { case <-p.stopCh: - case <-ctx.Done(): + case <-gctx.Done(): case err := <-errCh: return err } return nil }) } + return g.Wait() } +// handleDNSQuery processes incoming DNS queries, validates client access, and routes the query to appropriate handlers. +func (p *prog) handleDNSQuery(w dns.ResponseWriter, m *dns.Msg, listenerNum string, listenerConfig *ctrld.ListenerConfig) { + p.sema.acquire() + defer p.sema.release() + + if len(m.Question) == 0 { + sendDNSResponse(w, m, dns.RcodeFormatError) + return + } + + reqID := requestID() + ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqID) + ctx = ctrld.LoggerCtx(ctx, p.logger.Load()) + + if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { + ctrld.Log(ctx, p.Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) + sendDNSResponse(w, m, dns.RcodeRefused) + return + } + + go p.detectLoop(m) + + q := m.Question[0] + domain := canonicalName(q.Name) + + if p.handleSpecialDomains(ctx, w, m, domain) { + return + } + p.processStandardQuery(&standardQueryRequest{ + ctx: ctx, + writer: w, + msg: m, + listenerNum: listenerNum, + listenerConfig: listenerConfig, + domain: domain, + }) +} + +// handleSpecialDomains processes special domain queries, handles errors, purges cache if necessary, and returns a bool status. +func (p *prog) handleSpecialDomains(ctx context.Context, w dns.ResponseWriter, m *dns.Msg, domain string) bool { + switch { + case domain == "": + sendDNSResponse(w, m, dns.RcodeFormatError) + return true + case domain == selfCheckInternalTestDomain: + answer := resolveInternalDomainTestQuery(ctx, domain, m) + _ = w.WriteMsg(answer) + return true + } + + if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil { + p.cache.Purge() + ctrld.Log(ctx, p.Debug(), "received query %q, local cache is purged", domain) + } + + return false +} + +// standardQueryRequest represents a standard DNS query request with associated context and configuration. +type standardQueryRequest struct { + ctx context.Context + writer dns.ResponseWriter + msg *dns.Msg + listenerNum string + listenerConfig *ctrld.ListenerConfig + domain string +} + +// processStandardQuery handles a standard DNS query by routing it through appropriate upstreams and writing a DNS response. +func (p *prog) processStandardQuery(req *standardQueryRequest) { + remoteIP, _, _ := net.SplitHostPort(req.writer.RemoteAddr().String()) + ci := p.getClientInfo(remoteIP, req.msg) + ci.ClientIDPref = p.cfg.Service.ClientIDPref + + stripClientSubnet(req.msg) + remoteAddr := spoofRemoteAddr(req.writer.RemoteAddr(), ci) + fmtSrcToDest := fmtRemoteToLocal(req.listenerNum, ci.Hostname, remoteAddr.String()) + + startTime := time.Now() + q := req.msg.Question[0] + ctrld.Log(req.ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], req.domain) + + ur := p.upstreamFor(req.ctx, req.listenerNum, req.listenerConfig, remoteAddr, ci.Mac, req.domain) + + var answer *dns.Msg + // Handle restricted listener case + if !ur.matched && req.listenerConfig.Restricted { + ctrld.Log(req.ctx, p.Debug(), "query refused, %s does not match any network policy", remoteAddr.String()) + answer = new(dns.Msg) + answer.SetRcode(req.msg, dns.RcodeRefused) + // Process the refused query + go p.postProcessStandardQuery(ci, req.listenerConfig, q, &proxyResponse{answer: answer, refused: true}) + } else { + // Process a normal query + pr := p.proxy(req.ctx, &proxyRequest{ + msg: req.msg, + ci: ci, + failoverRcodes: p.getFailoverRcodes(req.listenerConfig), + ufr: ur, + }) + + rtt := time.Since(startTime) + ctrld.Log(req.ctx, p.Debug(), "received response of %d bytes in %s", pr.answer.Len(), rtt) + + go p.postProcessStandardQuery(ci, req.listenerConfig, q, pr) + answer = pr.answer + } + + if err := req.writer.WriteMsg(answer); err != nil { + ctrld.Log(req.ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client") + } +} + +// postProcessStandardQuery performs additional actions after processing a standard DNS query, such as metrics recording, +// handling canonical name adjustments, and triggering specific post-query actions like uninstallation procedures. +func (p *prog) postProcessStandardQuery(ci *ctrld.ClientInfo, listenerConfig *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) { + p.doSelfUninstall(pr) + p.recordMetrics(ci, listenerConfig, q, pr) + p.forceFetchingAPI(canonicalName(q.Name)) +} + +// getFailoverRcodes retrieves the failover response codes from the provided ListenerConfig. Returns nil if no policy exists. +func (p *prog) getFailoverRcodes(cfg *ctrld.ListenerConfig) []int { + if cfg.Policy != nil { + return cfg.Policy.FailoverRcodeNumbers + } + return nil +} + +// recordMetrics updates Prometheus metrics for DNS queries, including query count and client-specific query statistics. +func (p *prog) recordMetrics(ci *ctrld.ClientInfo, cfg *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) { + upstream := pr.upstream + switch { + case pr.cached: + upstream = "cache" + case pr.clientInfo: + upstream = "client_info_table" + } + labelValues := []string{ + net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port)), + ci.IP, + ci.Mac, + ci.Hostname, + upstream, + dns.TypeToString[q.Qtype], + dns.RcodeToString[pr.answer.Rcode], + } + p.WithLabelValuesInc(statsQueriesCount, labelValues...) + p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...) +} + +// sendDNSResponse sends a DNS response with the specified RCODE to the client using the provided ResponseWriter. +func sendDNSResponse(w dns.ResponseWriter, m *dns.Msg, rcode int) { + answer := new(dns.Msg) + answer.SetRcode(m, rcode) + _ = w.WriteMsg(answer) +} + // upstreamFor returns the list of upstreams for resolving the given domain, // matching by policies defined in the listener config. The second return value // reports whether the domain matches the policy. @@ -947,8 +1012,9 @@ func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) { // - There is only 1 ControlD upstream in-use. // - Number of refused queries seen so far equals to selfUninstallMaxQueries. // - The cdUID is deleted. -func (p *prog) doSelfUninstall(answer *dns.Msg) { - if !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused { +func (p *prog) doSelfUninstall(pr *proxyResponse) { + answer := pr.answer + if pr.refused || !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused { return }