cmd/cli: add split route AD top level domain on Windows

The sub-domains are matched using wildcard domain rule, but this rule
won't match top level domain, causing requests are forwarded to ControlD
upstreams.

To fix this, add the split route for top level domain explicitly.
This commit is contained in:
Cuong Manh Le
2024-11-27 16:00:38 +07:00
committed by Cuong Manh Le
parent 6837176ec7
commit 8360bdc50a
2 changed files with 48 additions and 5 deletions

View File

@@ -24,19 +24,26 @@ func addExtraSplitDnsRule(cfg *ctrld.Config) bool {
// Network rules are lowercase during toml config marshaling,
// lowercase the domain here too for consistency.
domain = strings.ToLower(domain)
domainRuleAdded := addSplitDnsRule(cfg, domain)
wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, "."))
return domainRuleAdded || wildcardDomainRuleRuleAdded
}
// addSplitDnsRule adds split-rule for given domain if there's no existed rule.
// The return value indicates whether the split-rule was added or not.
func addSplitDnsRule(cfg *ctrld.Config, domain string) bool {
for n, lc := range cfg.Listener {
if lc.Policy == nil {
lc.Policy = &ctrld.ListenerPolicyConfig{}
}
domainRule := "*." + strings.TrimPrefix(domain, ".")
for _, rule := range lc.Policy.Rules {
if _, ok := rule[domainRule]; ok {
mainLog.Load().Debug().Msgf("domain rule already exist for listener.%s", n)
if _, ok := rule[domain]; ok {
mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n)
return false
}
}
mainLog.Load().Debug().Msgf("adding active directory domain for listener.%s", n)
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domainRule: []string{}})
mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n)
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
}
return true
}

View File

@@ -4,6 +4,10 @@ import (
"fmt"
"testing"
"time"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/testhelper"
"github.com/stretchr/testify/assert"
)
func Test_getActiveDirectoryDomain(t *testing.T) {
@@ -34,3 +38,35 @@ func getActiveDirectoryDomainPowershell() (string, error) {
}
return string(output), nil
}
func Test_addSplitDnsRule(t *testing.T) {
newCfg := func(domains ...string) *ctrld.Config {
cfg := testhelper.SampleConfig(t)
lc := cfg.Listener["0"]
for _, domain := range domains {
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
}
return cfg
}
tests := []struct {
name string
cfg *ctrld.Config
domain string
added bool
}{
{"added", newCfg(), "example.com", true},
{"TLD existed", newCfg("example.com"), "*.example.com", true},
{"wildcard existed", newCfg("*.example.com"), "example.com", true},
{"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false},
{"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
added := addSplitDnsRule(tc.cfg, tc.domain)
assert.Equal(t, tc.added, added)
})
}
}