From 8360bdc50ada11c9df1456d11bbf330af87c160d Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 27 Nov 2024 16:00:38 +0700 Subject: [PATCH] cmd/cli: add split route AD top level domain on Windows The sub-domains are matched using wildcard domain rule, but this rule won't match top level domain, causing requests are forwarded to ControlD upstreams. To fix this, add the split route for top level domain explicitly. --- cmd/cli/ad_windows.go | 17 ++++++++++++----- cmd/cli/ad_windows_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go index 316414d..475ba09 100644 --- a/cmd/cli/ad_windows.go +++ b/cmd/cli/ad_windows.go @@ -24,19 +24,26 @@ func addExtraSplitDnsRule(cfg *ctrld.Config) bool { // Network rules are lowercase during toml config marshaling, // lowercase the domain here too for consistency. domain = strings.ToLower(domain) + domainRuleAdded := addSplitDnsRule(cfg, domain) + wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, ".")) + return domainRuleAdded || wildcardDomainRuleRuleAdded +} + +// addSplitDnsRule adds split-rule for given domain if there's no existed rule. +// The return value indicates whether the split-rule was added or not. +func addSplitDnsRule(cfg *ctrld.Config, domain string) bool { for n, lc := range cfg.Listener { if lc.Policy == nil { lc.Policy = &ctrld.ListenerPolicyConfig{} } - domainRule := "*." + strings.TrimPrefix(domain, ".") for _, rule := range lc.Policy.Rules { - if _, ok := rule[domainRule]; ok { - mainLog.Load().Debug().Msgf("domain rule already exist for listener.%s", n) + if _, ok := rule[domain]; ok { + mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n) return false } } - mainLog.Load().Debug().Msgf("adding active directory domain for listener.%s", n) - lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domainRule: []string{}}) + mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n) + lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}}) } return true } diff --git a/cmd/cli/ad_windows_test.go b/cmd/cli/ad_windows_test.go index 392abbd..6fe7f41 100644 --- a/cmd/cli/ad_windows_test.go +++ b/cmd/cli/ad_windows_test.go @@ -4,6 +4,10 @@ import ( "fmt" "testing" "time" + + "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/testhelper" + "github.com/stretchr/testify/assert" ) func Test_getActiveDirectoryDomain(t *testing.T) { @@ -34,3 +38,35 @@ func getActiveDirectoryDomainPowershell() (string, error) { } return string(output), nil } + +func Test_addSplitDnsRule(t *testing.T) { + newCfg := func(domains ...string) *ctrld.Config { + cfg := testhelper.SampleConfig(t) + lc := cfg.Listener["0"] + for _, domain := range domains { + lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}}) + } + return cfg + } + tests := []struct { + name string + cfg *ctrld.Config + domain string + added bool + }{ + {"added", newCfg(), "example.com", true}, + {"TLD existed", newCfg("example.com"), "*.example.com", true}, + {"wildcard existed", newCfg("*.example.com"), "example.com", true}, + {"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false}, + {"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + added := addSplitDnsRule(tc.cfg, tc.domain) + assert.Equal(t, tc.added, added) + }) + } +}