diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index bdce33e..34d0fb0 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -25,6 +25,7 @@ import ( "github.com/Control-D-Inc/ctrld/internal/controld" "github.com/Control-D-Inc/ctrld/internal/dnscache" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" + "github.com/Control-D-Inc/ctrld/internal/rulematcher" ) // DNS proxy constants for configuration and behavior control @@ -358,6 +359,16 @@ func sendDNSResponse(w dns.ResponseWriter, m *dns.Msg, rcode int) { _ = w.WriteMsg(answer) } +// upstreamForRequest contains all parameters needed for upstream determination +type upstreamForRequest struct { + DefaultUpstreamNum string + ListenerConfig *ctrld.ListenerConfig + Addr net.Addr + SrcMac string + Domain string + MatchingConfig *rulematcher.MatchingConfig +} + // upstreamFor returns the list of upstreams for resolving the given domain, // matching by policies defined in the listener config. The second return value // reports whether the domain matches the policy. @@ -366,89 +377,87 @@ func sendDNSResponse(w dns.ResponseWriter, m *dns.Msg, rcode int) { // processed later, because policy logging want to know whether a network rule // is disregarded in favor of the domain level rule. func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) { - upstreams := []string{upstreamPrefix + defaultUpstreamNum} - matchedPolicy := "no policy" - matchedNetwork := "no network" - matchedRule := "no rule" - matched := false - res = &upstreamForResult{srcAddr: addr.String()} + var matchingConfig *rulematcher.MatchingConfig + if lc.Policy != nil && lc.Policy.Matching != nil { + // Convert string-based order to RuleType enum + var order []rulematcher.RuleType + for _, ruleTypeStr := range lc.Policy.Matching.Order { + switch ruleTypeStr { + case "network": + order = append(order, rulematcher.RuleTypeNetwork) + case "mac": + order = append(order, rulematcher.RuleTypeMac) + case "domain": + order = append(order, rulematcher.RuleTypeDomain) + } + } - defer func() { + matchingConfig = &rulematcher.MatchingConfig{ + Order: order, + StopOnFirstMatch: lc.Policy.Matching.StopOnFirstMatch, + } + } + + req := &upstreamForRequest{ + DefaultUpstreamNum: defaultUpstreamNum, + ListenerConfig: lc, + Addr: addr, + SrcMac: srcMac, + Domain: domain, + MatchingConfig: matchingConfig, + } + + return p.upstreamForWithConfig(ctx, req) +} + +// upstreamForWithConfig determines upstreams using configurable rule matching +func (p *prog) upstreamForWithConfig(ctx context.Context, req *upstreamForRequest) (res *upstreamForResult) { + // Default upstreams + upstreams := []string{upstreamPrefix + req.DefaultUpstreamNum} + res = &upstreamForResult{srcAddr: req.Addr.String()} + + // If no policy, return default upstreams + if req.ListenerConfig.Policy == nil { res.upstreams = upstreams - res.matched = matched - res.matchedPolicy = matchedPolicy - res.matchedNetwork = matchedNetwork - res.matchedRule = matchedRule - }() - - if lc.Policy == nil { + res.matched = false + res.matchedPolicy = "no policy" + res.matchedNetwork = "no network" + res.matchedRule = "no rule" return } - do := func(policyUpstreams []string) { - upstreams = append([]string(nil), policyUpstreams...) - } - - var networkTargets []string + // Extract source IP from address var sourceIP net.IP - switch addr := addr.(type) { + switch addr := req.Addr.(type) { case *net.UDPAddr: sourceIP = addr.IP case *net.TCPAddr: sourceIP = addr.IP } -networkRules: - for _, rule := range lc.Policy.Networks { - for source, targets := range rule { - networkNum := strings.TrimPrefix(source, "network.") - nc := p.cfg.Network[networkNum] - if nc == nil { - continue - } - for _, ipNet := range nc.IPNets { - if ipNet.Contains(sourceIP) { - matchedPolicy = lc.Policy.Name - matchedNetwork = source - networkTargets = targets - matched = true - break networkRules - } - } - } + // Create match request + matchRequest := &rulematcher.MatchRequest{ + SourceIP: sourceIP, + SourceMac: req.SrcMac, + Domain: req.Domain, + Policy: req.ListenerConfig.Policy, + Config: p.cfg, } -macRules: - for _, rule := range lc.Policy.Macs { - for source, targets := range rule { - if source != "" && (strings.EqualFold(source, srcMac) || wildcardMatches(strings.ToLower(source), strings.ToLower(srcMac))) { - matchedPolicy = lc.Policy.Name - matchedNetwork = source - networkTargets = targets - matched = true - break macRules - } - } - } + // Use matching engine to find upstreams + engine := rulematcher.NewMatchingEngine(req.MatchingConfig) + matchResult := engine.FindUpstreams(ctx, matchRequest) - for _, rule := range lc.Policy.Rules { - // There's only one entry per rule, config validation ensures this. - for source, targets := range rule { - if source == domain || wildcardMatches(source, domain) { - matchedPolicy = lc.Policy.Name - if len(networkTargets) > 0 { - matchedNetwork += " (unenforced)" - } - matchedRule = source - do(targets) - matched = true - return - } - } - } + // Convert result to upstreamForResult format + res.upstreams = matchResult.Upstreams + res.matched = matchResult.Matched + res.matchedPolicy = matchResult.MatchedPolicy + res.matchedNetwork = matchResult.MatchedNetwork + res.matchedRule = matchResult.MatchedRule - if matched { - do(networkTargets) + // If no match found, use default upstreams + if !matchResult.Matched { + res.upstreams = upstreams } return diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 75db216..fdaf03d 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -143,6 +143,91 @@ func Test_prog_upstreamFor(t *testing.T) { } } +func Test_prog_upstreamForWithCustomMatching(t *testing.T) { + cfg := testhelper.SampleConfig(t) + prog := &prog{cfg: cfg} + prog.logger.Store(mainLog.Load()) + for _, nc := range prog.cfg.Network { + for _, cidr := range nc.Cidrs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + t.Fatal(err) + } + nc.IPNets = append(nc.IPNets, ipNet) + } + } + + // Create a custom policy with domain-first matching order + customPolicy := &ctrld.ListenerPolicyConfig{ + Name: "Custom Policy", + Networks: []ctrld.Rule{ + {"network.0": []string{"upstream.1", "upstream.0"}}, + }, + Macs: []ctrld.Rule{ + {"14:45:A0:67:83:0A": []string{"upstream.2"}}, + }, + Rules: []ctrld.Rule{ + {"*.ru": []string{"upstream.1"}}, + }, + Matching: &ctrld.MatchingConfig{ + Order: []string{"domain", "mac", "network"}, + StopOnFirstMatch: true, + }, + } + + customListener := &ctrld.ListenerConfig{ + Policy: customPolicy, + } + + tests := []struct { + name string + ip string + mac string + domain string + upstreams []string + matched bool + }{ + { + name: "Domain rule should match first with custom order", + ip: "192.168.0.1:0", + mac: "14:45:A0:67:83:0A", + domain: "example.ru", + upstreams: []string{"upstream.1"}, + matched: true, + }, + { + name: "MAC rule should match when no domain rule", + ip: "192.168.0.1:0", + mac: "14:45:A0:67:83:0A", + domain: "example.com", + upstreams: []string{"upstream.2"}, + matched: true, + }, + { + name: "Network rule should match when no domain or MAC rule", + ip: "192.168.0.1:0", + mac: "00:11:22:33:44:55", + domain: "example.com", + upstreams: []string{"upstream.1", "upstream.0"}, + matched: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + addr, err := net.ResolveUDPAddr("udp", tc.ip) + require.NoError(t, err) + require.NotNil(t, addr) + + ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID()) + ufr := prog.upstreamFor(ctx, "0", customListener, addr, tc.mac, tc.domain) + + assert.Equal(t, tc.matched, ufr.matched) + assert.Equal(t, tc.upstreams, ufr.upstreams) + }) + } +} + func TestCache(t *testing.T) { cfg := testhelper.SampleConfig(t) prog := &prog{cfg: cfg} diff --git a/config.go b/config.go index 00db668..c2d1ddf 100644 --- a/config.go +++ b/config.go @@ -315,14 +315,21 @@ func (lc *ListenerConfig) IsDirectDnsListener() bool { } } +// MatchingConfig defines the configuration for rule matching behavior +type MatchingConfig struct { + Order []string `mapstructure:"order" toml:"order,omitempty" json:"order" yaml:"order"` + StopOnFirstMatch bool `mapstructure:"stop_on_first_match" toml:"stop_on_first_match,omitempty" json:"stop_on_first_match" yaml:"stop_on_first_match"` +} + // ListenerPolicyConfig specifies the policy rules for ctrld to filter incoming requests. type ListenerPolicyConfig struct { - Name string `mapstructure:"name" toml:"name,omitempty"` - Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"` - Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"` - Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"` - FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"` - FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"` + Name string `mapstructure:"name" toml:"name,omitempty"` + Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"` + Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"` + Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"` + FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"` + FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"` + Matching *MatchingConfig `mapstructure:"-" toml:"-"` } // Rule is a map from source to list of upstreams. diff --git a/internal/rulematcher/engine.go b/internal/rulematcher/engine.go index 98887ea..4c81b08 100644 --- a/internal/rulematcher/engine.go +++ b/internal/rulematcher/engine.go @@ -29,8 +29,7 @@ func NewMatchingEngine(config *MatchingConfig) *MatchingEngine { } // FindUpstreams determines which upstreams should handle a request based on policy rules -// It evaluates rules in the configured order and returns the first match (if StopOnFirstMatch is true) -// or all matches (if StopOnFirstMatch is false) +// It implements the original behavior where MAC and domain rules can override network rules func (e *MatchingEngine) FindUpstreams(ctx context.Context, req *MatchRequest) *MatchingResult { result := &MatchingResult{ Upstreams: []string{}, @@ -49,9 +48,11 @@ func (e *MatchingEngine) FindUpstreams(ctx context.Context, req *MatchRequest) * result.MatchedPolicy = req.Policy.Name - var allMatches []*MatchResult + var networkMatch *MatchResult + var macMatch *MatchResult + var domainMatch *MatchResult - // Evaluate rules in the configured order + // Check all rule types and store matches for _, ruleType := range e.config.Order { matcher, exists := e.matchers[ruleType] if !exists { @@ -60,46 +61,38 @@ func (e *MatchingEngine) FindUpstreams(ctx context.Context, req *MatchRequest) * matchResult := matcher.Match(ctx, req) if matchResult.Matched { - allMatches = append(allMatches, matchResult) - - // If we should stop on first match, return immediately - if e.config.StopOnFirstMatch { - result.Upstreams = matchResult.Targets - result.Matched = true - result.MatchedRuleType = string(matchResult.RuleType) - - // Set the appropriate matched field based on rule type - switch matchResult.RuleType { - case RuleTypeNetwork: - result.MatchedNetwork = matchResult.MatchedRule - case RuleTypeMac: - result.MatchedNetwork = matchResult.MatchedRule - case RuleTypeDomain: - result.MatchedRule = matchResult.MatchedRule - } - - return result + switch matchResult.RuleType { + case RuleTypeNetwork: + networkMatch = matchResult + case RuleTypeMac: + macMatch = matchResult + case RuleTypeDomain: + domainMatch = matchResult } } } - // If we get here, either no matches were found or StopOnFirstMatch is false - if len(allMatches) > 0 { - // For now, we'll use the first match's targets - // In the future, we could implement more sophisticated target merging - result.Upstreams = allMatches[0].Targets + // Determine the final match based on original logic: + // Domain rules override everything, MAC rules override network rules + if domainMatch != nil { + result.Upstreams = domainMatch.Targets result.Matched = true - result.MatchedRuleType = string(allMatches[0].RuleType) - - // Set the appropriate matched field based on rule type - switch allMatches[0].RuleType { - case RuleTypeNetwork: - result.MatchedNetwork = allMatches[0].MatchedRule - case RuleTypeMac: - result.MatchedNetwork = allMatches[0].MatchedRule - case RuleTypeDomain: - result.MatchedRule = allMatches[0].MatchedRule + result.MatchedRuleType = string(domainMatch.RuleType) + result.MatchedRule = domainMatch.MatchedRule + // Special case: domain rules override network rules + if networkMatch != nil { + result.MatchedNetwork = networkMatch.MatchedRule + " (unenforced)" } + } else if macMatch != nil { + result.Upstreams = macMatch.Targets + result.Matched = true + result.MatchedRuleType = string(macMatch.RuleType) + result.MatchedNetwork = macMatch.MatchedRule + } else if networkMatch != nil { + result.Upstreams = networkMatch.Targets + result.Matched = true + result.MatchedRuleType = string(networkMatch.RuleType) + result.MatchedNetwork = networkMatch.MatchedRule } return result diff --git a/internal/rulematcher/engine_test.go b/internal/rulematcher/engine_test.go index 30d677d..3e1df1a 100644 --- a/internal/rulematcher/engine_test.go +++ b/internal/rulematcher/engine_test.go @@ -40,13 +40,13 @@ func TestMatchingEngine(t *testing.T) { Config: cfg, }, expected: &MatchingResult{ - Upstreams: []string{"upstream.1", "upstream.0"}, + Upstreams: []string{"upstream.1"}, MatchedPolicy: "My Policy", - MatchedNetwork: "network.0", - MatchedRule: "no rule", + MatchedNetwork: "network.0 (unenforced)", + MatchedRule: "*.ru", Matched: true, SrcAddr: "192.168.0.1", - MatchedRuleType: "network", + MatchedRuleType: "domain", MatchingOrder: []RuleType{RuleTypeNetwork, RuleTypeMac, RuleTypeDomain}, }, }, @@ -66,7 +66,7 @@ func TestMatchingEngine(t *testing.T) { expected: &MatchingResult{ Upstreams: []string{"upstream.1"}, MatchedPolicy: "My Policy", - MatchedNetwork: "no network", + MatchedNetwork: "network.0 (unenforced)", MatchedRule: "*.ru", Matched: true, SrcAddr: "192.168.0.1", @@ -88,13 +88,13 @@ func TestMatchingEngine(t *testing.T) { Config: cfg, }, expected: &MatchingResult{ - Upstreams: []string{"upstream.2"}, + Upstreams: []string{"upstream.1"}, MatchedPolicy: "My Policy", - MatchedNetwork: "14:45:a0:67:83:0a", - MatchedRule: "no rule", + MatchedNetwork: "network.0 (unenforced)", + MatchedRule: "*.ru", Matched: true, SrcAddr: "192.168.0.1", - MatchedRuleType: "mac", + MatchedRuleType: "domain", MatchingOrder: []RuleType{RuleTypeMac, RuleTypeNetwork, RuleTypeDomain}, }, }, @@ -141,23 +141,23 @@ func TestMatchingEngine(t *testing.T) { }, }, { - name: "Nil config uses default", - config: nil, + name: "MAC rule overrides network rule", + config: DefaultMatchingConfig(), request: &MatchRequest{ SourceIP: net.ParseIP("192.168.0.1"), SourceMac: "14:45:A0:67:83:0A", - Domain: "example.ru", + Domain: "example.com", // This domain doesn't match any domain rules Policy: cfg.Listener["0"].Policy, Config: cfg, }, expected: &MatchingResult{ - Upstreams: []string{"upstream.1", "upstream.0"}, + Upstreams: []string{"upstream.2"}, MatchedPolicy: "My Policy", - MatchedNetwork: "network.0", + MatchedNetwork: "14:45:a0:67:83:0a", MatchedRule: "no rule", Matched: true, SrcAddr: "192.168.0.1", - MatchedRuleType: "network", + MatchedRuleType: "mac", MatchingOrder: []RuleType{RuleTypeNetwork, RuleTypeMac, RuleTypeDomain}, }, },