diff --git a/cmd/cli/ad_others.go b/cmd/cli/ad_others.go new file mode 100644 index 0000000..1249033 --- /dev/null +++ b/cmd/cli/ad_others.go @@ -0,0 +1,10 @@ +//go:build !windows + +package cli + +import ( + "github.com/Control-D-Inc/ctrld" +) + +// addExtraSplitDnsRule adds split DNS rule if present. +func addExtraSplitDnsRule(_ *ctrld.ListenerConfig) {} diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go new file mode 100644 index 0000000..ef9e2fc --- /dev/null +++ b/cmd/cli/ad_windows.go @@ -0,0 +1,43 @@ +package cli + +import ( + "fmt" + "strings" + + "github.com/Control-D-Inc/ctrld" +) + +// addExtraSplitDnsRule adds split DNS rule for domain if it's part of active directory. +func addExtraSplitDnsRule(lc *ctrld.ListenerConfig) { + if lc.Policy == nil { + lc.Policy = &ctrld.ListenerPolicyConfig{} + } + domain, err := getActiveDirectoryDomain() + if err != nil { + mainLog.Load().Debug().Msgf("unable to get active directory domain: %v", err) + return + } + if domain == "" { + mainLog.Load().Debug().Msg("no active directory domain found") + return + } + domainRule := "*." + strings.TrimPrefix(domain, ".") + for _, rule := range lc.Policy.Rules { + if _, ok := rule[domainRule]; ok { + mainLog.Load().Debug().Msg("domain rule already exist") + return + } + } + mainLog.Load().Debug().Msg("adding active directory domain") + lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domainRule: []string{}}) +} + +// getActiveDirectoryDomain returns AD domain name of this computer. +func getActiveDirectoryDomain() (string, error) { + cmd := "$obj = GetWmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }" + output, err := powershell(cmd) + if err != nil { + return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output)) + } + return string(output), nil +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index bfe32e0..34f050d 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -435,6 +435,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { for listenerNum := range p.cfg.Listener { p.cfg.Listener[listenerNum].Init() + addExtraSplitDnsRule(p.cfg.Listener[listenerNum]) if !reload { go func(listenerNum string) { listenerConfig := p.cfg.Listener[listenerNum]