mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
refactor(dns): improve DNS proxy code structure and readability
Break down the large DNS handling function into smaller, focused functions with clear responsibilities:

- Extract handleDNSQuery from the serveDNS handler function
- Create a dedicated startListeners function for listener management
- Add a standardQueryRequest struct to encapsulate query parameters
- Split special domain handling into a separate function
- Add descriptive comments for each new function
- Improve variable names for better clarity (e.g., startTime vs t)

This refactoring improves code maintainability and readability without changing the core DNS proxy functionality.
This commit is contained in:
committed by
Cuong Manh Le
parent
a16b25ad1d
commit
b18cd7ee83
@@ -69,9 +69,10 @@ type proxyRequest struct {
|
||||
// proxyResponse contains data for proxying a DNS response from upstream.
|
||||
type proxyResponse struct {
|
||||
answer *dns.Msg
|
||||
upstream string
|
||||
cached bool
|
||||
clientInfo bool
|
||||
upstream string
|
||||
refused bool
|
||||
}
|
||||
|
||||
// upstreamForResult represents the result of processing rules for a request.
|
||||
@@ -84,151 +85,57 @@ type upstreamForResult struct {
|
||||
srcAddr string
|
||||
}
|
||||
|
||||
// serveDNS sets up and starts a DNS server on the specified listener, handling DNS queries and network monitoring.
|
||||
func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
// Start network monitoring
|
||||
if err := p.monitorNetworkChanges(mainCtx); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to start network monitoring")
|
||||
// Don't return here as we still want DNS service to run
|
||||
}
|
||||
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
// make sure ip is allocated
|
||||
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
|
||||
p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
|
||||
return allocErr
|
||||
}
|
||||
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
p.sema.acquire()
|
||||
defer p.sema.release()
|
||||
if len(m.Question) == 0 {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeFormatError)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
}
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
reqId := requestID()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
ctx = ctrld.LoggerCtx(ctx, p.logger.Load())
|
||||
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
|
||||
ctrld.Log(ctx, p.Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeRefused)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
}
|
||||
go p.detectLoop(m)
|
||||
q := m.Question[0]
|
||||
domain := canonicalName(q.Name)
|
||||
switch {
|
||||
case domain == "":
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeFormatError)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
case domain == selfCheckInternalTestDomain:
|
||||
answer := resolveInternalDomainTestQuery(ctx, domain, m)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
|
||||
p.cache.Purge()
|
||||
ctrld.Log(ctx, p.Debug(), "received query %q, local cache is purged", domain)
|
||||
}
|
||||
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
|
||||
ci := p.getClientInfo(remoteIP, m)
|
||||
ci.ClientIDPref = p.cfg.Service.ClientIDPref
|
||||
stripClientSubnet(m)
|
||||
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci)
|
||||
fmtSrcToDest := fmtRemoteToLocal(listenerNum, ci.Hostname, remoteAddr.String())
|
||||
t := time.Now()
|
||||
ctrld.Log(ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
|
||||
ur := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain)
|
||||
|
||||
labelValues := make([]string, 0, len(statsQueriesCountLabels))
|
||||
labelValues = append(labelValues, net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)))
|
||||
labelValues = append(labelValues, ci.IP)
|
||||
labelValues = append(labelValues, ci.Mac)
|
||||
labelValues = append(labelValues, ci.Hostname)
|
||||
|
||||
var answer *dns.Msg
|
||||
if !ur.matched && listenerConfig.Restricted {
|
||||
ctrld.Log(ctx, p.Info(), "query refused, %s does not match any network policy", remoteAddr.String())
|
||||
answer = new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeRefused)
|
||||
labelValues = append(labelValues, "") // no upstream
|
||||
} else {
|
||||
var failoverRcode []int
|
||||
if listenerConfig.Policy != nil {
|
||||
failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers
|
||||
}
|
||||
pr := p.proxy(ctx, &proxyRequest{
|
||||
msg: m,
|
||||
ci: ci,
|
||||
failoverRcodes: failoverRcode,
|
||||
ufr: ur,
|
||||
})
|
||||
go p.doSelfUninstall(pr.answer)
|
||||
|
||||
answer = pr.answer
|
||||
rtt := time.Since(t)
|
||||
ctrld.Log(ctx, p.Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
||||
upstream := pr.upstream
|
||||
switch {
|
||||
case pr.cached:
|
||||
upstream = "cache"
|
||||
case pr.clientInfo:
|
||||
upstream = "client_info_table"
|
||||
}
|
||||
labelValues = append(labelValues, upstream)
|
||||
}
|
||||
labelValues = append(labelValues, dns.TypeToString[q.Qtype])
|
||||
labelValues = append(labelValues, dns.RcodeToString[answer.Rcode])
|
||||
go func() {
|
||||
p.WithLabelValuesInc(statsQueriesCount, labelValues...)
|
||||
p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...)
|
||||
p.forceFetchingAPI(domain)
|
||||
}()
|
||||
if err := w.WriteMsg(answer); err != nil {
|
||||
ctrld.Log(ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client")
|
||||
}
|
||||
p.handleDNSQuery(w, m, listenerNum, listenerConfig)
|
||||
})
|
||||
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
return p.startListeners(mainCtx, listenerConfig, handler)
|
||||
}
|
||||
|
||||
// startListeners starts DNS listeners on specified configurations, supporting UDP and TCP protocols.
|
||||
// It handles local IPv6, RFC 1918, and specified IP listeners, reacting to stop signals or errors.
|
||||
func (p *prog) startListeners(ctx context.Context, cfg *ctrld.ListenerConfig, handler dns.Handler) error {
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
|
||||
for _, proto := range []string{"udp", "tcp"} {
|
||||
proto := proto
|
||||
if needLocalIPv6Listener() {
|
||||
g.Go(func() error {
|
||||
s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), proto, handler)
|
||||
s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(cfg.Port)), proto, handler)
|
||||
defer s.Shutdown()
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
case <-ctx.Done():
|
||||
case <-gctx.Done():
|
||||
case err := <-errCh:
|
||||
// Local ipv6 listener should not terminate ctrld.
|
||||
// It's a workaround for a quirk on Windows.
|
||||
p.Warn().Err(err).Msg("local ipv6 listener failed")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
// When we spawn a listener on 127.0.0.1, also spawn listeners on the RFC1918
|
||||
// addresses of the machine. So ctrld could receive queries from LAN clients.
|
||||
if needRFC1918Listeners(listenerConfig) {
|
||||
|
||||
if needRFC1918Listeners(cfg) {
|
||||
g.Go(func() error {
|
||||
for _, addr := range ctrld.Rfc1918Addresses() {
|
||||
func() {
|
||||
listenAddr := net.JoinHostPort(addr, strconv.Itoa(listenerConfig.Port))
|
||||
listenAddr := net.JoinHostPort(addr, strconv.Itoa(cfg.Port))
|
||||
s, errCh := runDNSServer(listenAddr, proto, handler)
|
||||
defer s.Shutdown()
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
case <-ctx.Done():
|
||||
case <-gctx.Done():
|
||||
case err := <-errCh:
|
||||
// RFC1918 listener should not terminate ctrld.
|
||||
// It's a workaround for a quirk on system with systemd-resolved.
|
||||
p.Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr)
|
||||
}
|
||||
}()
|
||||
@@ -236,25 +143,183 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
g.Go(func() error {
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
addr := net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port))
|
||||
s, errCh := runDNSServer(addr, proto, handler)
|
||||
defer s.Shutdown()
|
||||
|
||||
p.started <- struct{}{}
|
||||
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
case <-ctx.Done():
|
||||
case <-gctx.Done():
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// handleDNSQuery processes incoming DNS queries, validates client access, and routes the query to appropriate handlers.
|
||||
func (p *prog) handleDNSQuery(w dns.ResponseWriter, m *dns.Msg, listenerNum string, listenerConfig *ctrld.ListenerConfig) {
|
||||
p.sema.acquire()
|
||||
defer p.sema.release()
|
||||
|
||||
if len(m.Question) == 0 {
|
||||
sendDNSResponse(w, m, dns.RcodeFormatError)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := requestID()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqID)
|
||||
ctx = ctrld.LoggerCtx(ctx, p.logger.Load())
|
||||
|
||||
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
|
||||
ctrld.Log(ctx, p.Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
|
||||
sendDNSResponse(w, m, dns.RcodeRefused)
|
||||
return
|
||||
}
|
||||
|
||||
go p.detectLoop(m)
|
||||
|
||||
q := m.Question[0]
|
||||
domain := canonicalName(q.Name)
|
||||
|
||||
if p.handleSpecialDomains(ctx, w, m, domain) {
|
||||
return
|
||||
}
|
||||
p.processStandardQuery(&standardQueryRequest{
|
||||
ctx: ctx,
|
||||
writer: w,
|
||||
msg: m,
|
||||
listenerNum: listenerNum,
|
||||
listenerConfig: listenerConfig,
|
||||
domain: domain,
|
||||
})
|
||||
}
|
||||
|
||||
// handleSpecialDomains processes special domain queries, handles errors, purges cache if necessary, and returns a bool status.
|
||||
func (p *prog) handleSpecialDomains(ctx context.Context, w dns.ResponseWriter, m *dns.Msg, domain string) bool {
|
||||
switch {
|
||||
case domain == "":
|
||||
sendDNSResponse(w, m, dns.RcodeFormatError)
|
||||
return true
|
||||
case domain == selfCheckInternalTestDomain:
|
||||
answer := resolveInternalDomainTestQuery(ctx, domain, m)
|
||||
_ = w.WriteMsg(answer)
|
||||
return true
|
||||
}
|
||||
|
||||
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
|
||||
p.cache.Purge()
|
||||
ctrld.Log(ctx, p.Debug(), "received query %q, local cache is purged", domain)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// standardQueryRequest represents a standard DNS query request with associated context and configuration.
|
||||
type standardQueryRequest struct {
|
||||
ctx context.Context
|
||||
writer dns.ResponseWriter
|
||||
msg *dns.Msg
|
||||
listenerNum string
|
||||
listenerConfig *ctrld.ListenerConfig
|
||||
domain string
|
||||
}
|
||||
|
||||
// processStandardQuery handles a standard DNS query by routing it through appropriate upstreams and writing a DNS response.
|
||||
func (p *prog) processStandardQuery(req *standardQueryRequest) {
|
||||
remoteIP, _, _ := net.SplitHostPort(req.writer.RemoteAddr().String())
|
||||
ci := p.getClientInfo(remoteIP, req.msg)
|
||||
ci.ClientIDPref = p.cfg.Service.ClientIDPref
|
||||
|
||||
stripClientSubnet(req.msg)
|
||||
remoteAddr := spoofRemoteAddr(req.writer.RemoteAddr(), ci)
|
||||
fmtSrcToDest := fmtRemoteToLocal(req.listenerNum, ci.Hostname, remoteAddr.String())
|
||||
|
||||
startTime := time.Now()
|
||||
q := req.msg.Question[0]
|
||||
ctrld.Log(req.ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], req.domain)
|
||||
|
||||
ur := p.upstreamFor(req.ctx, req.listenerNum, req.listenerConfig, remoteAddr, ci.Mac, req.domain)
|
||||
|
||||
var answer *dns.Msg
|
||||
// Handle restricted listener case
|
||||
if !ur.matched && req.listenerConfig.Restricted {
|
||||
ctrld.Log(req.ctx, p.Debug(), "query refused, %s does not match any network policy", remoteAddr.String())
|
||||
answer = new(dns.Msg)
|
||||
answer.SetRcode(req.msg, dns.RcodeRefused)
|
||||
// Process the refused query
|
||||
go p.postProcessStandardQuery(ci, req.listenerConfig, q, &proxyResponse{answer: answer, refused: true})
|
||||
} else {
|
||||
// Process a normal query
|
||||
pr := p.proxy(req.ctx, &proxyRequest{
|
||||
msg: req.msg,
|
||||
ci: ci,
|
||||
failoverRcodes: p.getFailoverRcodes(req.listenerConfig),
|
||||
ufr: ur,
|
||||
})
|
||||
|
||||
rtt := time.Since(startTime)
|
||||
ctrld.Log(req.ctx, p.Debug(), "received response of %d bytes in %s", pr.answer.Len(), rtt)
|
||||
|
||||
go p.postProcessStandardQuery(ci, req.listenerConfig, q, pr)
|
||||
answer = pr.answer
|
||||
}
|
||||
|
||||
if err := req.writer.WriteMsg(answer); err != nil {
|
||||
ctrld.Log(req.ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client")
|
||||
}
|
||||
}
|
||||
|
||||
// postProcessStandardQuery performs additional actions after processing a standard DNS query, such as metrics recording,
|
||||
// handling canonical name adjustments, and triggering specific post-query actions like uninstallation procedures.
|
||||
func (p *prog) postProcessStandardQuery(ci *ctrld.ClientInfo, listenerConfig *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) {
|
||||
p.doSelfUninstall(pr)
|
||||
p.recordMetrics(ci, listenerConfig, q, pr)
|
||||
p.forceFetchingAPI(canonicalName(q.Name))
|
||||
}
|
||||
|
||||
// getFailoverRcodes retrieves the failover response codes from the provided ListenerConfig. Returns nil if no policy exists.
|
||||
func (p *prog) getFailoverRcodes(cfg *ctrld.ListenerConfig) []int {
|
||||
if cfg.Policy != nil {
|
||||
return cfg.Policy.FailoverRcodeNumbers
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// recordMetrics updates Prometheus metrics for DNS queries, including query count and client-specific query statistics.
|
||||
func (p *prog) recordMetrics(ci *ctrld.ClientInfo, cfg *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) {
|
||||
upstream := pr.upstream
|
||||
switch {
|
||||
case pr.cached:
|
||||
upstream = "cache"
|
||||
case pr.clientInfo:
|
||||
upstream = "client_info_table"
|
||||
}
|
||||
labelValues := []string{
|
||||
net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port)),
|
||||
ci.IP,
|
||||
ci.Mac,
|
||||
ci.Hostname,
|
||||
upstream,
|
||||
dns.TypeToString[q.Qtype],
|
||||
dns.RcodeToString[pr.answer.Rcode],
|
||||
}
|
||||
p.WithLabelValuesInc(statsQueriesCount, labelValues...)
|
||||
p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...)
|
||||
}
|
||||
|
||||
// sendDNSResponse sends a DNS response with the specified RCODE to the client using the provided ResponseWriter.
|
||||
func sendDNSResponse(w dns.ResponseWriter, m *dns.Msg, rcode int) {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, rcode)
|
||||
_ = w.WriteMsg(answer)
|
||||
}
|
||||
|
||||
// upstreamFor returns the list of upstreams for resolving the given domain,
|
||||
// matching by policies defined in the listener config. The second return value
|
||||
// reports whether the domain matches the policy.
|
||||
@@ -947,8 +1012,9 @@ func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
|
||||
// - There is only 1 ControlD upstream in-use.
|
||||
// - Number of refused queries seen so far equals to selfUninstallMaxQueries.
|
||||
// - The cdUID is deleted.
|
||||
func (p *prog) doSelfUninstall(answer *dns.Msg) {
|
||||
if !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused {
|
||||
func (p *prog) doSelfUninstall(pr *proxyResponse) {
|
||||
answer := pr.answer
|
||||
if pr.refused || !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user