Skip public DNS for LAN query

So we don't blindly send requests to public DNS even though they can not
handle these queries.
This commit is contained in:
Cuong Manh Le
2024-12-12 18:36:39 +07:00
committed by Cuong Manh Le
parent 8a96b8bec4
commit 37d41bd215
3 changed files with 57 additions and 1 deletions

View File

@@ -448,6 +448,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
case isSrvLookup(req.msg):
upstreams = []string{upstreamOS}
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
ctx = ctrld.LanQueryCtx(ctx)
ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams)
case isPrivatePtrLookup(req.msg):
isLanOrPtrQuery = true
@@ -457,6 +458,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
return res
}
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs)
ctx = ctrld.LanQueryCtx(ctx)
ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams)
case isLanHostnameQuery(req.msg):
isLanOrPtrQuery = true
@@ -467,6 +469,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
}
upstreams = []string{upstreamOS}
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
ctx = ctrld.LanQueryCtx(ctx)
ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams)
default:
ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams)

View File

@@ -47,6 +47,14 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
// or is the Resolver used for ResolverTypeOS.
var or = newResolverWithNameserver(defaultNameservers())
// LanQueryCtxKey is the context.Context key to indicate that the request is for LAN network.
type LanQueryCtxKey struct{}
// LanQueryCtx returns a context.Context with LanQueryCtxKey set.
func LanQueryCtx(ctx context.Context) context.Context {
return context.WithValue(ctx, LanQueryCtxKey{}, true)
}
// defaultNameservers is like nameservers with each element formed "ip:53".
func defaultNameservers() []string {
ns := nameservers()
@@ -191,6 +199,11 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
nss = append(nss, (*p)...)
}
numServers := len(nss) + len(publicServers)
// If this is a LAN query, skip public DNS.
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
if ok && lan {
numServers -= len(publicServers)
}
if numServers == 0 {
return nil, errors.New("no nameservers available")
}
@@ -216,7 +229,9 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
}
}
do(nss, true)
do(publicServers, false)
if !lan {
do(publicServers, false)
}
logAnswer := func(server string) {
if before, _, found := strings.Cut(server, ":"); found {

View File

@@ -34,6 +34,44 @@ func Test_osResolver_Resolve(t *testing.T) {
}
}
func Test_osResolver_ResolveLanHostname(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
reqId := "req-id"
ctx = context.WithValue(ctx, ReqIdCtxKey{}, reqId)
ctx = LanQueryCtx(ctx)
go func(ctx context.Context) {
defer cancel()
id, ok := ctx.Value(ReqIdCtxKey{}).(string)
if !ok || id != reqId {
t.Error("missing request id")
return
}
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
if !ok || !lan {
t.Error("not a LAN query")
return
}
resolver := &osResolver{}
resolver.publicServers.Store(&[]string{"76.76.2.0:53"})
m := new(dns.Msg)
m.SetQuestion("controld.com.", dns.TypeA)
m.RecursionDesired = true
_, err := resolver.Resolve(ctx, m)
if err == nil {
t.Error("os resolver succeeded unexpectedly")
return
}
}(ctx)
select {
case <-time.After(10 * time.Second):
t.Error("os resolver hangs")
case <-ctx.Done():
}
}
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
ns := make([]string, 0, 2)
servers := make([]*dns.Server, 0, 2)