From 3afdaef6e6bc7af490e20ea2e28c620f41dc8276 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 16 Sep 2025 18:33:05 +0700 Subject: [PATCH] refactor: extract rule matching logic into internal/rulematcher package Extract DNS policy rule matching logic from dns_proxy.go into a dedicated internal/rulematcher package to improve code organization and maintainability. The new package provides: - RuleMatcher interface for extensible rule matching - NetworkRuleMatcher for IP-based network rules - MacRuleMatcher for MAC address-based rules - DomainRuleMatcher for domain/wildcard rules - Comprehensive unit tests for all matchers This refactoring improves: - Separation of concerns between DNS proxy and rule matching - Testability with isolated rule matcher components - Reusability of rule matching logic across the codebase - Maintainability with focused, single-responsibility modules --- internal/rulematcher/domain.go | 36 ++++ internal/rulematcher/mac.go | 67 ++++++ internal/rulematcher/network.go | 43 ++++ internal/rulematcher/rulematcher_test.go | 248 +++++++++++++++++++++++ internal/rulematcher/types.go | 40 ++++ 5 files changed, 434 insertions(+) create mode 100644 internal/rulematcher/domain.go create mode 100644 internal/rulematcher/mac.go create mode 100644 internal/rulematcher/network.go create mode 100644 internal/rulematcher/rulematcher_test.go create mode 100644 internal/rulematcher/types.go diff --git a/internal/rulematcher/domain.go b/internal/rulematcher/domain.go new file mode 100644 index 0000000..72ee291 --- /dev/null +++ b/internal/rulematcher/domain.go @@ -0,0 +1,36 @@ +package rulematcher + +import ( + "context" +) + +// DomainRuleMatcher handles matching of domain-based rules +type DomainRuleMatcher struct{} + +// Type returns the rule type for domain matcher +func (d *DomainRuleMatcher) Type() RuleType { + return RuleTypeDomain +} + +// Match evaluates domain rules against the requested domain +func (d *DomainRuleMatcher) Match(ctx context.Context, req *MatchRequest) *MatchResult { + if req.Policy == nil || len(req.Policy.Rules) == 0 { + return &MatchResult{Matched: false, RuleType: RuleTypeDomain} + } + + for _, rule := range req.Policy.Rules { + // There's only one entry per rule, config validation ensures this. + for source, targets := range rule { + if source == req.Domain || wildcardMatches(source, req.Domain) { + return &MatchResult{ + Matched: true, + Targets: targets, + MatchedRule: source, + RuleType: RuleTypeDomain, + } + } + } + } + + return &MatchResult{Matched: false, RuleType: RuleTypeDomain} +} diff --git a/internal/rulematcher/mac.go b/internal/rulematcher/mac.go new file mode 100644 index 0000000..d0b1412 --- /dev/null +++ b/internal/rulematcher/mac.go @@ -0,0 +1,67 @@ +package rulematcher + +import ( + "context" + "strings" +) + +// MacRuleMatcher handles matching of MAC address-based rules +type MacRuleMatcher struct{} + +// Type returns the rule type for MAC matcher +func (m *MacRuleMatcher) Type() RuleType { + return RuleTypeMac +} + +// Match evaluates MAC address rules against the source MAC address +func (m *MacRuleMatcher) Match(ctx context.Context, req *MatchRequest) *MatchResult { + if req.Policy == nil || len(req.Policy.Macs) == 0 { + return &MatchResult{Matched: false, RuleType: RuleTypeMac} + } + + for _, rule := range req.Policy.Macs { + for source, targets := range rule { + if source != "" && (strings.EqualFold(source, req.SourceMac) || wildcardMatches(strings.ToLower(source), strings.ToLower(req.SourceMac))) { + return &MatchResult{ + Matched: true, + Targets: targets, + MatchedRule: source, // Return the original source from the rule + RuleType: RuleTypeMac, + } + } + } + } + + return &MatchResult{Matched: false, RuleType: RuleTypeMac} +} + +// wildcardMatches checks if a wildcard pattern matches a string +// This is copied from the original implementation to maintain compatibility +func wildcardMatches(wildcard, str string) bool { + if wildcard == "" { + return false + } + if wildcard == "*" { + return true + } + if !strings.Contains(wildcard, "*") { + return wildcard == str + } + + parts := strings.Split(wildcard, "*") + if len(parts) != 2 { + return false + } + + prefix := parts[0] + suffix := parts[1] + + if prefix != "" && !strings.HasPrefix(str, prefix) { + return false + } + if suffix != "" && !strings.HasSuffix(str, suffix) { + return false + } + + return true +} diff --git a/internal/rulematcher/network.go b/internal/rulematcher/network.go new file mode 100644 index 0000000..8114fe1 --- /dev/null +++ b/internal/rulematcher/network.go @@ -0,0 +1,43 @@ +package rulematcher + +import ( + "context" + "strings" +) + +// NetworkRuleMatcher handles matching of network-based rules +type NetworkRuleMatcher struct{} + +// Type returns the rule type for network matcher +func (n *NetworkRuleMatcher) Type() RuleType { + return RuleTypeNetwork +} + +// Match evaluates network rules against the source IP address +func (n *NetworkRuleMatcher) Match(ctx context.Context, req *MatchRequest) *MatchResult { + if req.Policy == nil || len(req.Policy.Networks) == 0 { + return &MatchResult{Matched: false, RuleType: RuleTypeNetwork} + } + + for _, rule := range req.Policy.Networks { + for source, targets := range rule { + networkNum := strings.TrimPrefix(source, "network.") + nc := req.Config.Network[networkNum] + if nc == nil { + continue + } + for _, ipNet := range nc.IPNets { + if ipNet.Contains(req.SourceIP) { + return &MatchResult{ + Matched: true, + Targets: targets, + MatchedRule: source, + RuleType: RuleTypeNetwork, + } + } + } + } + } + + return &MatchResult{Matched: false, RuleType: RuleTypeNetwork} +} diff --git a/internal/rulematcher/rulematcher_test.go b/internal/rulematcher/rulematcher_test.go new file mode 100644 index 0000000..d4eb235 --- /dev/null +++ b/internal/rulematcher/rulematcher_test.go @@ -0,0 +1,248 @@ +package rulematcher + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/testhelper" +) + +// Test NetworkRuleMatcher +func TestNetworkRuleMatcher(t *testing.T) { + cfg := testhelper.SampleConfig(t) + // Convert Cidrs to IPNets like in the original test + for _, nc := range cfg.Network { + for _, cidr := range nc.Cidrs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + t.Fatal(err) + } + nc.IPNets = append(nc.IPNets, ipNet) + } + } + matcher := &NetworkRuleMatcher{} + + tests := []struct { + name string + request *MatchRequest + expected *MatchResult + }{ + { + name: "No policy", + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + Policy: nil, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeNetwork}, + }, + { + name: "No network rules", + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + Policy: &ctrld.ListenerPolicyConfig{}, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeNetwork}, + }, + { + name: "Match network rule", + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{ + Matched: true, + Targets: []string{"upstream.1", "upstream.0"}, + MatchedRule: "network.0", + RuleType: RuleTypeNetwork, + }, + }, + { + name: "No match for IP", + request: &MatchRequest{ + SourceIP: net.ParseIP("10.0.0.1"), + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeNetwork}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := matcher.Match(context.Background(), tc.request) + assert.Equal(t, tc.expected.Matched, result.Matched) + assert.Equal(t, tc.expected.RuleType, result.RuleType) + if tc.expected.Matched { + assert.Equal(t, tc.expected.Targets, result.Targets) + assert.Equal(t, tc.expected.MatchedRule, result.MatchedRule) + } + }) + } +} + +// Test MacRuleMatcher +func TestMacRuleMatcher(t *testing.T) { + cfg := testhelper.SampleConfig(t) + matcher := &MacRuleMatcher{} + + tests := []struct { + name string + request *MatchRequest + expected *MatchResult + }{ + { + name: "No policy", + request: &MatchRequest{ + SourceMac: "14:45:A0:67:83:0A", + Policy: nil, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeMac}, + }, + { + name: "No MAC rules", + request: &MatchRequest{ + SourceMac: "14:45:A0:67:83:0A", + Policy: &ctrld.ListenerPolicyConfig{}, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeMac}, + }, + { + name: "Match MAC rule - exact", + request: &MatchRequest{ + SourceMac: "14:45:A0:67:83:0A", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{ + Matched: true, + Targets: []string{"upstream.2"}, + MatchedRule: "14:45:a0:67:83:0a", // Config loading normalizes MAC addresses to lowercase + RuleType: RuleTypeMac, + }, + }, + { + name: "Match MAC rule - case insensitive", + request: &MatchRequest{ + SourceMac: "14:54:4a:8e:08:2d", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{ + Matched: true, + Targets: []string{"upstream.2"}, + MatchedRule: "14:54:4a:8e:08:2d", + RuleType: RuleTypeMac, + }, + }, + { + name: "No match for MAC", + request: &MatchRequest{ + SourceMac: "00:11:22:33:44:55", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeMac}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := matcher.Match(context.Background(), tc.request) + assert.Equal(t, tc.expected.Matched, result.Matched) + assert.Equal(t, tc.expected.RuleType, result.RuleType) + if tc.expected.Matched { + assert.Equal(t, tc.expected.Targets, result.Targets) + assert.Equal(t, tc.expected.MatchedRule, result.MatchedRule) + } + }) + } +} + +// Test DomainRuleMatcher +func TestDomainRuleMatcher(t *testing.T) { + cfg := testhelper.SampleConfig(t) + matcher := &DomainRuleMatcher{} + + tests := []struct { + name string + request *MatchRequest + expected *MatchResult + }{ + { + name: "No policy", + request: &MatchRequest{ + Domain: "example.com", + Policy: nil, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeDomain}, + }, + { + name: "No domain rules", + request: &MatchRequest{ + Domain: "example.com", + Policy: &ctrld.ListenerPolicyConfig{}, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeDomain}, + }, + { + name: "Match domain rule - exact", + request: &MatchRequest{ + Domain: "example.ru", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{ + Matched: true, + Targets: []string{"upstream.1"}, + MatchedRule: "*.ru", + RuleType: RuleTypeDomain, + }, + }, + { + name: "Match domain rule - wildcard", + request: &MatchRequest{ + Domain: "test.ru", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{ + Matched: true, + Targets: []string{"upstream.1"}, + MatchedRule: "*.ru", + RuleType: RuleTypeDomain, + }, + }, + { + name: "No match for domain", + request: &MatchRequest{ + Domain: "example.com", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeDomain}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := matcher.Match(context.Background(), tc.request) + assert.Equal(t, tc.expected.Matched, result.Matched) + assert.Equal(t, tc.expected.RuleType, result.RuleType) + if tc.expected.Matched { + assert.Equal(t, tc.expected.Targets, result.Targets) + assert.Equal(t, tc.expected.MatchedRule, result.MatchedRule) + } + }) + } +} diff --git a/internal/rulematcher/types.go b/internal/rulematcher/types.go new file mode 100644 index 0000000..c3499e4 --- /dev/null +++ b/internal/rulematcher/types.go @@ -0,0 +1,40 @@ +package rulematcher + +import ( + "context" + "net" + + "github.com/Control-D-Inc/ctrld" +) + +// RuleType represents the type of rule being matched +type RuleType string + +const ( + RuleTypeNetwork RuleType = "network" + RuleTypeMac RuleType = "mac" + RuleTypeDomain RuleType = "domain" +) + +// RuleMatcher defines the interface for matching different types of rules +type RuleMatcher interface { + Match(ctx context.Context, request *MatchRequest) *MatchResult + Type() RuleType +} + +// MatchRequest contains all the information needed for rule matching +type MatchRequest struct { + SourceIP net.IP + SourceMac string + Domain string + Policy *ctrld.ListenerPolicyConfig + Config *ctrld.Config +} + +// MatchResult represents the result of a rule matching operation +type MatchResult struct { + Matched bool + Targets []string + MatchedRule string + RuleType RuleType +}