mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Skip public DNS for LAN query
So we don't blindly send requests to public DNS even though they can not handle these queries.
This commit is contained in:
committed by
Cuong Manh Le
parent
8a96b8bec4
commit
37d41bd215
@@ -448,6 +448,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
case isSrvLookup(req.msg):
|
||||
upstreams = []string{upstreamOS}
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams)
|
||||
case isPrivatePtrLookup(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
@@ -457,6 +458,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
return res
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs)
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams)
|
||||
case isLanHostnameQuery(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
@@ -467,6 +469,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
}
|
||||
upstreams = []string{upstreamOS}
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams)
|
||||
default:
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams)
|
||||
|
||||
17
resolver.go
17
resolver.go
@@ -47,6 +47,14 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||
// or is the Resolver used for ResolverTypeOS.
|
||||
var or = newResolverWithNameserver(defaultNameservers())
|
||||
|
||||
// LanQueryCtxKey is the context.Context key to indicate that the request is for LAN network.
|
||||
type LanQueryCtxKey struct{}
|
||||
|
||||
// LanQueryCtx returns a context.Context with LanQueryCtxKey set.
|
||||
func LanQueryCtx(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, LanQueryCtxKey{}, true)
|
||||
}
|
||||
|
||||
// defaultNameservers is like nameservers with each element formed "ip:53".
|
||||
func defaultNameservers() []string {
|
||||
ns := nameservers()
|
||||
@@ -191,6 +199,11 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
nss = append(nss, (*p)...)
|
||||
}
|
||||
numServers := len(nss) + len(publicServers)
|
||||
// If this is a LAN query, skip public DNS.
|
||||
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
|
||||
if ok && lan {
|
||||
numServers -= len(publicServers)
|
||||
}
|
||||
if numServers == 0 {
|
||||
return nil, errors.New("no nameservers available")
|
||||
}
|
||||
@@ -216,7 +229,9 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
}
|
||||
}
|
||||
do(nss, true)
|
||||
do(publicServers, false)
|
||||
if !lan {
|
||||
do(publicServers, false)
|
||||
}
|
||||
|
||||
logAnswer := func(server string) {
|
||||
if before, _, found := strings.Cut(server, ":"); found {
|
||||
|
||||
@@ -34,6 +34,44 @@ func Test_osResolver_Resolve(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_osResolver_ResolveLanHostname(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
reqId := "req-id"
|
||||
ctx = context.WithValue(ctx, ReqIdCtxKey{}, reqId)
|
||||
ctx = LanQueryCtx(ctx)
|
||||
|
||||
go func(ctx context.Context) {
|
||||
defer cancel()
|
||||
id, ok := ctx.Value(ReqIdCtxKey{}).(string)
|
||||
if !ok || id != reqId {
|
||||
t.Error("missing request id")
|
||||
return
|
||||
}
|
||||
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
|
||||
if !ok || !lan {
|
||||
t.Error("not a LAN query")
|
||||
return
|
||||
}
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServers.Store(&[]string{"76.76.2.0:53"})
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("controld.com.", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
_, err := resolver.Resolve(ctx, m)
|
||||
if err == nil {
|
||||
t.Error("os resolver succeeded unexpectedly")
|
||||
return
|
||||
}
|
||||
}(ctx)
|
||||
|
||||
select {
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Error("os resolver hangs")
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
ns := make([]string, 0, 2)
|
||||
servers := make([]*dns.Server, 0, 2)
|
||||
|
||||
Reference in New Issue
Block a user