diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index a69f5b5..031b362 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -448,6 +448,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { case isSrvLookup(req.msg): upstreams = []string{upstreamOS} upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + ctx = ctrld.LanQueryCtx(ctx) ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams) case isPrivatePtrLookup(req.msg): isLanOrPtrQuery = true @@ -457,6 +458,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { return res } upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs) + ctx = ctrld.LanQueryCtx(ctx) ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams) case isLanHostnameQuery(req.msg): isLanOrPtrQuery = true @@ -467,6 +469,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } upstreams = []string{upstreamOS} upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + ctx = ctrld.LanQueryCtx(ctx) ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams) default: ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams) diff --git a/resolver.go b/resolver.go index f3b7a10..e3d319b 100644 --- a/resolver.go +++ b/resolver.go @@ -47,6 +47,14 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") // or is the Resolver used for ResolverTypeOS. var or = newResolverWithNameserver(defaultNameservers()) +// LanQueryCtxKey is the context.Context key to indicate that the request is for LAN network. +type LanQueryCtxKey struct{} + +// LanQueryCtx returns a context.Context with LanQueryCtxKey set. +func LanQueryCtx(ctx context.Context) context.Context { + return context.WithValue(ctx, LanQueryCtxKey{}, true) +} + // defaultNameservers is like nameservers with each element formed "ip:53". func defaultNameservers() []string { ns := nameservers() @@ -191,6 +199,11 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error nss = append(nss, (*p)...) } numServers := len(nss) + len(publicServers) + // If this is a LAN query, skip public DNS. + lan, ok := ctx.Value(LanQueryCtxKey{}).(bool) + if ok && lan { + numServers -= len(publicServers) + } if numServers == 0 { return nil, errors.New("no nameservers available") } @@ -216,7 +229,9 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } } do(nss, true) - do(publicServers, false) + if !lan { + do(publicServers, false) + } logAnswer := func(server string) { if before, _, found := strings.Cut(server, ":"); found { diff --git a/resolver_test.go b/resolver_test.go index e0b5508..5fb8434 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -34,6 +34,44 @@ func Test_osResolver_Resolve(t *testing.T) { } } +func Test_osResolver_ResolveLanHostname(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + reqId := "req-id" + ctx = context.WithValue(ctx, ReqIdCtxKey{}, reqId) + ctx = LanQueryCtx(ctx) + + go func(ctx context.Context) { + defer cancel() + id, ok := ctx.Value(ReqIdCtxKey{}).(string) + if !ok || id != reqId { + t.Error("missing request id") + return + } + lan, ok := ctx.Value(LanQueryCtxKey{}).(bool) + if !ok || !lan { + t.Error("not a LAN query") + return + } + resolver := &osResolver{} + resolver.publicServers.Store(&[]string{"76.76.2.0:53"}) + m := new(dns.Msg) + m.SetQuestion("controld.com.", dns.TypeA) + m.RecursionDesired = true + _, err := resolver.Resolve(ctx, m) + if err == nil { + t.Error("os resolver succeeded unexpectedly") + return + } + }(ctx) + + select { + case <-time.After(10 * time.Second): + t.Error("os resolver hangs") + case <-ctx.Done(): + } +} + func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { ns := make([]string, 0, 2) servers := make([]*dns.Server, 0, 2)