mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
refactor: extract rule matching logic into internal/rulematcher package
Extract DNS policy rule matching logic from dns_proxy.go into a dedicated internal/rulematcher package to improve code organization and maintainability. The new package provides: - RuleMatcher interface for extensible rule matching - NetworkRuleMatcher for IP-based network rules - MacRuleMatcher for MAC address-based rules - DomainRuleMatcher for domain/wildcard rules - Comprehensive unit tests for all matchers This refactoring improves: - Separation of concerns between DNS proxy and rule matching - Testability with isolated rule matcher components - Reusability of rule matching logic across the codebase - Maintainability with focused, single-responsibility modules
This commit is contained in:
committed by
Cuong Manh Le
parent
ef7432df55
commit
3afdaef6e6
36
internal/rulematcher/domain.go
Normal file
36
internal/rulematcher/domain.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package rulematcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DomainRuleMatcher handles matching of domain-based rules
|
||||||
|
type DomainRuleMatcher struct{}
|
||||||
|
|
||||||
|
// Type returns the rule type for domain matcher
|
||||||
|
func (d *DomainRuleMatcher) Type() RuleType {
|
||||||
|
return RuleTypeDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match evaluates domain rules against the requested domain
|
||||||
|
func (d *DomainRuleMatcher) Match(ctx context.Context, req *MatchRequest) *MatchResult {
|
||||||
|
if req.Policy == nil || len(req.Policy.Rules) == 0 {
|
||||||
|
return &MatchResult{Matched: false, RuleType: RuleTypeDomain}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range req.Policy.Rules {
|
||||||
|
// There's only one entry per rule, config validation ensures this.
|
||||||
|
for source, targets := range rule {
|
||||||
|
if source == req.Domain || wildcardMatches(source, req.Domain) {
|
||||||
|
return &MatchResult{
|
||||||
|
Matched: true,
|
||||||
|
Targets: targets,
|
||||||
|
MatchedRule: source,
|
||||||
|
RuleType: RuleTypeDomain,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MatchResult{Matched: false, RuleType: RuleTypeDomain}
|
||||||
|
}
|
||||||
67
internal/rulematcher/mac.go
Normal file
67
internal/rulematcher/mac.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package rulematcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MacRuleMatcher handles matching of MAC address-based rules
|
||||||
|
type MacRuleMatcher struct{}
|
||||||
|
|
||||||
|
// Type returns the rule type for MAC matcher
|
||||||
|
func (m *MacRuleMatcher) Type() RuleType {
|
||||||
|
return RuleTypeMac
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match evaluates MAC address rules against the source MAC address
|
||||||
|
func (m *MacRuleMatcher) Match(ctx context.Context, req *MatchRequest) *MatchResult {
|
||||||
|
if req.Policy == nil || len(req.Policy.Macs) == 0 {
|
||||||
|
return &MatchResult{Matched: false, RuleType: RuleTypeMac}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range req.Policy.Macs {
|
||||||
|
for source, targets := range rule {
|
||||||
|
if source != "" && (strings.EqualFold(source, req.SourceMac) || wildcardMatches(strings.ToLower(source), strings.ToLower(req.SourceMac))) {
|
||||||
|
return &MatchResult{
|
||||||
|
Matched: true,
|
||||||
|
Targets: targets,
|
||||||
|
MatchedRule: source, // Return the original source from the rule
|
||||||
|
RuleType: RuleTypeMac,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MatchResult{Matched: false, RuleType: RuleTypeMac}
|
||||||
|
}
|
||||||
|
|
||||||
|
// wildcardMatches checks if a wildcard pattern matches a string
|
||||||
|
// This is copied from the original implementation to maintain compatibility
|
||||||
|
func wildcardMatches(wildcard, str string) bool {
|
||||||
|
if wildcard == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if wildcard == "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if !strings.Contains(wildcard, "*") {
|
||||||
|
return wildcard == str
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(wildcard, "*")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := parts[0]
|
||||||
|
suffix := parts[1]
|
||||||
|
|
||||||
|
if prefix != "" && !strings.HasPrefix(str, prefix) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if suffix != "" && !strings.HasSuffix(str, suffix) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
43
internal/rulematcher/network.go
Normal file
43
internal/rulematcher/network.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package rulematcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NetworkRuleMatcher handles matching of network-based rules
|
||||||
|
type NetworkRuleMatcher struct{}
|
||||||
|
|
||||||
|
// Type returns the rule type for network matcher
|
||||||
|
func (n *NetworkRuleMatcher) Type() RuleType {
|
||||||
|
return RuleTypeNetwork
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match evaluates network rules against the source IP address
|
||||||
|
func (n *NetworkRuleMatcher) Match(ctx context.Context, req *MatchRequest) *MatchResult {
|
||||||
|
if req.Policy == nil || len(req.Policy.Networks) == 0 {
|
||||||
|
return &MatchResult{Matched: false, RuleType: RuleTypeNetwork}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range req.Policy.Networks {
|
||||||
|
for source, targets := range rule {
|
||||||
|
networkNum := strings.TrimPrefix(source, "network.")
|
||||||
|
nc := req.Config.Network[networkNum]
|
||||||
|
if nc == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, ipNet := range nc.IPNets {
|
||||||
|
if ipNet.Contains(req.SourceIP) {
|
||||||
|
return &MatchResult{
|
||||||
|
Matched: true,
|
||||||
|
Targets: targets,
|
||||||
|
MatchedRule: source,
|
||||||
|
RuleType: RuleTypeNetwork,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MatchResult{Matched: false, RuleType: RuleTypeNetwork}
|
||||||
|
}
|
||||||
248
internal/rulematcher/rulematcher_test.go
Normal file
248
internal/rulematcher/rulematcher_test.go
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
package rulematcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/Control-D-Inc/ctrld"
|
||||||
|
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test NetworkRuleMatcher
|
||||||
|
func TestNetworkRuleMatcher(t *testing.T) {
|
||||||
|
cfg := testhelper.SampleConfig(t)
|
||||||
|
// Convert Cidrs to IPNets like in the original test
|
||||||
|
for _, nc := range cfg.Network {
|
||||||
|
for _, cidr := range nc.Cidrs {
|
||||||
|
_, ipNet, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
nc.IPNets = append(nc.IPNets, ipNet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
matcher := &NetworkRuleMatcher{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
request *MatchRequest
|
||||||
|
expected *MatchResult
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No policy",
|
||||||
|
request: &MatchRequest{
|
||||||
|
SourceIP: net.ParseIP("192.168.0.1"),
|
||||||
|
Policy: nil,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{Matched: false, RuleType: RuleTypeNetwork},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No network rules",
|
||||||
|
request: &MatchRequest{
|
||||||
|
SourceIP: net.ParseIP("192.168.0.1"),
|
||||||
|
Policy: &ctrld.ListenerPolicyConfig{},
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{Matched: false, RuleType: RuleTypeNetwork},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Match network rule",
|
||||||
|
request: &MatchRequest{
|
||||||
|
SourceIP: net.ParseIP("192.168.0.1"),
|
||||||
|
Policy: cfg.Listener["0"].Policy,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{
|
||||||
|
Matched: true,
|
||||||
|
Targets: []string{"upstream.1", "upstream.0"},
|
||||||
|
MatchedRule: "network.0",
|
||||||
|
RuleType: RuleTypeNetwork,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No match for IP",
|
||||||
|
request: &MatchRequest{
|
||||||
|
SourceIP: net.ParseIP("10.0.0.1"),
|
||||||
|
Policy: cfg.Listener["0"].Policy,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{Matched: false, RuleType: RuleTypeNetwork},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := matcher.Match(context.Background(), tc.request)
|
||||||
|
assert.Equal(t, tc.expected.Matched, result.Matched)
|
||||||
|
assert.Equal(t, tc.expected.RuleType, result.RuleType)
|
||||||
|
if tc.expected.Matched {
|
||||||
|
assert.Equal(t, tc.expected.Targets, result.Targets)
|
||||||
|
assert.Equal(t, tc.expected.MatchedRule, result.MatchedRule)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MacRuleMatcher
|
||||||
|
func TestMacRuleMatcher(t *testing.T) {
|
||||||
|
cfg := testhelper.SampleConfig(t)
|
||||||
|
matcher := &MacRuleMatcher{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
request *MatchRequest
|
||||||
|
expected *MatchResult
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No policy",
|
||||||
|
request: &MatchRequest{
|
||||||
|
SourceMac: "14:45:A0:67:83:0A",
|
||||||
|
Policy: nil,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{Matched: false, RuleType: RuleTypeMac},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No MAC rules",
|
||||||
|
request: &MatchRequest{
|
||||||
|
SourceMac: "14:45:A0:67:83:0A",
|
||||||
|
Policy: &ctrld.ListenerPolicyConfig{},
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{Matched: false, RuleType: RuleTypeMac},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Match MAC rule - exact",
|
||||||
|
request: &MatchRequest{
|
||||||
|
SourceMac: "14:45:A0:67:83:0A",
|
||||||
|
Policy: cfg.Listener["0"].Policy,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{
|
||||||
|
Matched: true,
|
||||||
|
Targets: []string{"upstream.2"},
|
||||||
|
MatchedRule: "14:45:a0:67:83:0a", // Config loading normalizes MAC addresses to lowercase
|
||||||
|
RuleType: RuleTypeMac,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Match MAC rule - case insensitive",
|
||||||
|
request: &MatchRequest{
|
||||||
|
SourceMac: "14:54:4a:8e:08:2d",
|
||||||
|
Policy: cfg.Listener["0"].Policy,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{
|
||||||
|
Matched: true,
|
||||||
|
Targets: []string{"upstream.2"},
|
||||||
|
MatchedRule: "14:54:4a:8e:08:2d",
|
||||||
|
RuleType: RuleTypeMac,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No match for MAC",
|
||||||
|
request: &MatchRequest{
|
||||||
|
SourceMac: "00:11:22:33:44:55",
|
||||||
|
Policy: cfg.Listener["0"].Policy,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{Matched: false, RuleType: RuleTypeMac},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := matcher.Match(context.Background(), tc.request)
|
||||||
|
assert.Equal(t, tc.expected.Matched, result.Matched)
|
||||||
|
assert.Equal(t, tc.expected.RuleType, result.RuleType)
|
||||||
|
if tc.expected.Matched {
|
||||||
|
assert.Equal(t, tc.expected.Targets, result.Targets)
|
||||||
|
assert.Equal(t, tc.expected.MatchedRule, result.MatchedRule)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test DomainRuleMatcher
|
||||||
|
func TestDomainRuleMatcher(t *testing.T) {
|
||||||
|
cfg := testhelper.SampleConfig(t)
|
||||||
|
matcher := &DomainRuleMatcher{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
request *MatchRequest
|
||||||
|
expected *MatchResult
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No policy",
|
||||||
|
request: &MatchRequest{
|
||||||
|
Domain: "example.com",
|
||||||
|
Policy: nil,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{Matched: false, RuleType: RuleTypeDomain},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No domain rules",
|
||||||
|
request: &MatchRequest{
|
||||||
|
Domain: "example.com",
|
||||||
|
Policy: &ctrld.ListenerPolicyConfig{},
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{Matched: false, RuleType: RuleTypeDomain},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Match domain rule - exact",
|
||||||
|
request: &MatchRequest{
|
||||||
|
Domain: "example.ru",
|
||||||
|
Policy: cfg.Listener["0"].Policy,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{
|
||||||
|
Matched: true,
|
||||||
|
Targets: []string{"upstream.1"},
|
||||||
|
MatchedRule: "*.ru",
|
||||||
|
RuleType: RuleTypeDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Match domain rule - wildcard",
|
||||||
|
request: &MatchRequest{
|
||||||
|
Domain: "test.ru",
|
||||||
|
Policy: cfg.Listener["0"].Policy,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{
|
||||||
|
Matched: true,
|
||||||
|
Targets: []string{"upstream.1"},
|
||||||
|
MatchedRule: "*.ru",
|
||||||
|
RuleType: RuleTypeDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No match for domain",
|
||||||
|
request: &MatchRequest{
|
||||||
|
Domain: "example.com",
|
||||||
|
Policy: cfg.Listener["0"].Policy,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
expected: &MatchResult{Matched: false, RuleType: RuleTypeDomain},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := matcher.Match(context.Background(), tc.request)
|
||||||
|
assert.Equal(t, tc.expected.Matched, result.Matched)
|
||||||
|
assert.Equal(t, tc.expected.RuleType, result.RuleType)
|
||||||
|
if tc.expected.Matched {
|
||||||
|
assert.Equal(t, tc.expected.Targets, result.Targets)
|
||||||
|
assert.Equal(t, tc.expected.MatchedRule, result.MatchedRule)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
40
internal/rulematcher/types.go
Normal file
40
internal/rulematcher/types.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package rulematcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/Control-D-Inc/ctrld"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RuleType represents the type of rule being matched
|
||||||
|
type RuleType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
RuleTypeNetwork RuleType = "network"
|
||||||
|
RuleTypeMac RuleType = "mac"
|
||||||
|
RuleTypeDomain RuleType = "domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RuleMatcher defines the interface for matching different types of rules
|
||||||
|
type RuleMatcher interface {
|
||||||
|
Match(ctx context.Context, request *MatchRequest) *MatchResult
|
||||||
|
Type() RuleType
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchRequest contains all the information needed for rule matching
|
||||||
|
type MatchRequest struct {
|
||||||
|
SourceIP net.IP
|
||||||
|
SourceMac string
|
||||||
|
Domain string
|
||||||
|
Policy *ctrld.ListenerPolicyConfig
|
||||||
|
Config *ctrld.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchResult represents the result of a rule matching operation
|
||||||
|
type MatchResult struct {
|
||||||
|
Matched bool
|
||||||
|
Targets []string
|
||||||
|
MatchedRule string
|
||||||
|
RuleType RuleType
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user