diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 694131d..e3dbc26 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -435,14 +435,17 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { if len(upstreamConfigs) == 0 { upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} upstreams = []string{upstreamOS} - } - - if p.isAdDomainQuery(req.msg) { - ctrld.Log(ctx, mainLog.Load().Debug(), - "AD domain query detected for %s in domain %s", - req.msg.Question[0].Name, p.adDomain) - upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} - upstreams = []string{upstreamOS} + // For OS resolver, local addresses are ignored to prevent possible looping. + // However, on Active Directory Domain Controller, where it has local DNS server + // running and listening on local addresses, these local addresses must be used + // as nameservers, so queries for ADDC could be resolved as expected. + if p.isAdDomainQuery(req.msg) { + ctrld.Log(ctx, mainLog.Load().Debug(), + "AD domain query detected for %s in domain %s", + req.msg.Question[0].Name, p.adDomain) + upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} + upstreams = []string{upstreamOSLocal} + } } res := &proxyResponse{} @@ -458,7 +461,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) } else { switch { - case isSrvLookup(req.msg): + case isSrvLanLookup(req.msg): upstreams = []string{upstreamOS} upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} ctx = ctrld.LanQueryCtx(ctx) @@ -1109,21 +1112,27 @@ func isLanHostnameQuery(m *dns.Msg) bool { default: return false } - name := strings.TrimSuffix(q.Name, ".") + return isLanHostname(q.Name) +} + +// isSrvLanLookup reports whether DNS message is an SRV query of a LAN hostname. +func isSrvLanLookup(m *dns.Msg) bool { + if m == nil || len(m.Question) == 0 { + return false + } + q := m.Question[0] + return q.Qtype == dns.TypeSRV && isLanHostname(q.Name) +} + +// isLanHostname reports whether name is a LAN hostname. +func isLanHostname(name string) bool { + name = strings.TrimSuffix(name, ".") return !strings.Contains(name, ".") || strings.HasSuffix(name, ".domain") || strings.HasSuffix(name, ".lan") || strings.HasSuffix(name, ".local") } -// isSrvLookup reports whether DNS message is a SRV query. -func isSrvLookup(m *dns.Msg) bool { - if m == nil || len(m.Question) == 0 { - return false - } - return m.Question[0].Qtype == dns.TypeSRV -} - // isWanClient reports whether the input is a WAN address. func isWanClient(na net.Addr) bool { var ip netip.Addr diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index eae3dfa..4a4e5b4 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -418,20 +418,21 @@ func Test_isPrivatePtrLookup(t *testing.T) { } } -func Test_isSrvLookup(t *testing.T) { +func Test_isSrvLanLookup(t *testing.T) { tests := []struct { name string msg *dns.Msg isSrvLookup bool }{ - {"SRV", newDnsMsgWithHostname("foo", dns.TypeSRV), true}, + {"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true}, {"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false}, + {"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false}, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - if got := isSrvLookup(tc.msg); tc.isSrvLookup != got { + if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got { t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got) } }) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index f6387f2..b119423 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -43,7 +43,7 @@ const ( ctrldControlUnixSockMobile = "cd.sock" upstreamPrefix = "upstream." upstreamOS = upstreamPrefix + "os" - upstreamPrivate = upstreamPrefix + "private" + upstreamOSLocal = upstreamOS + ".local" dnsWatchdogDefaultInterval = 20 * time.Second ctrldServiceName = "ctrld" )