diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go index 316414d..475ba09 100644 --- a/cmd/cli/ad_windows.go +++ b/cmd/cli/ad_windows.go @@ -24,19 +24,26 @@ func addExtraSplitDnsRule(cfg *ctrld.Config) bool { // Network rules are lowercase during toml config marshaling, // lowercase the domain here too for consistency. domain = strings.ToLower(domain) + domainRuleAdded := addSplitDnsRule(cfg, domain) + wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, ".")) + return domainRuleAdded || wildcardDomainRuleRuleAdded +} + +// addSplitDnsRule adds split-rule for given domain if there's no existed rule. +// The return value indicates whether the split-rule was added or not. +func addSplitDnsRule(cfg *ctrld.Config, domain string) bool { for n, lc := range cfg.Listener { if lc.Policy == nil { lc.Policy = &ctrld.ListenerPolicyConfig{} } - domainRule := "*." + strings.TrimPrefix(domain, ".") for _, rule := range lc.Policy.Rules { - if _, ok := rule[domainRule]; ok { - mainLog.Load().Debug().Msgf("domain rule already exist for listener.%s", n) + if _, ok := rule[domain]; ok { + mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n) return false } } - mainLog.Load().Debug().Msgf("adding active directory domain for listener.%s", n) - lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domainRule: []string{}}) + mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n) + lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}}) } return true } diff --git a/cmd/cli/ad_windows_test.go b/cmd/cli/ad_windows_test.go index 392abbd..6fe7f41 100644 --- a/cmd/cli/ad_windows_test.go +++ b/cmd/cli/ad_windows_test.go @@ -4,6 +4,10 @@ import ( "fmt" "testing" "time" + + "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/testhelper" + "github.com/stretchr/testify/assert" ) func Test_getActiveDirectoryDomain(t *testing.T) { @@ -34,3 +38,35 @@ func getActiveDirectoryDomainPowershell() (string, error) { } return string(output), nil } + +func Test_addSplitDnsRule(t *testing.T) { + newCfg := func(domains ...string) *ctrld.Config { + cfg := testhelper.SampleConfig(t) + lc := cfg.Listener["0"] + for _, domain := range domains { + lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}}) + } + return cfg + } + tests := []struct { + name string + cfg *ctrld.Config + domain string + added bool + }{ + {"added", newCfg(), "example.com", true}, + {"TLD existed", newCfg("example.com"), "*.example.com", true}, + {"wildcard existed", newCfg("*.example.com"), "example.com", true}, + {"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false}, + {"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + added := addSplitDnsRule(tc.cfg, tc.domain) + assert.Equal(t, tc.added, added) + }) + } +}