cmd/cli: improve logging for new LAN/PTR flow

This commit is contained in:
Cuong Manh Le
2023-12-05 19:52:11 +07:00
committed by Cuong Manh Le
parent 8939debbc0
commit af2c1c87e0
2 changed files with 167 additions and 98 deletions

View File

@@ -45,6 +45,23 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{
Timeout: 2000,
}
// proxyRequest contains data for proxying a DNS query to upstream.
type proxyRequest struct {
msg *dns.Msg
ci *ctrld.ClientInfo
failoverRcodes []int
ufr *upstreamForResult
}
// upstreamForResult represents the result of processing rules for a request.
type upstreamForResult struct {
upstreams []string
matchedPolicy string
matchedNetwork string
matchedRule string
matched bool
}
func (p *prog) serveDNS(listenerNum string) error {
listenerConfig := p.cfg.Listener[listenerNum]
// make sure ip is allocated
@@ -74,9 +91,10 @@ func (p *prog) serveDNS(listenerNum string) error {
t := time.Now()
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
ctrld.Log(ctx, mainLog.Load().Debug(), "%s received query: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain)
res := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain)
var answer *dns.Msg
if !matched && listenerConfig.Restricted {
if !res.matched && listenerConfig.Restricted {
ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String())
answer = new(dns.Msg)
answer.SetRcode(m, dns.RcodeRefused)
} else {
@@ -84,7 +102,12 @@ func (p *prog) serveDNS(listenerNum string) error {
if listenerConfig.Policy != nil {
failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers
}
answer = p.proxy(ctx, upstreams, failoverRcode, m, ci, matched)
answer = p.proxy(ctx, &proxyRequest{
msg: m,
ci: ci,
failoverRcodes: failoverRcode,
ufr: res,
})
rtt := time.Since(t)
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
}
@@ -162,27 +185,24 @@ func (p *prog) serveDNS(listenerNum string) error {
// Though domain policy has higher priority than network policy, it is still
// processed later, because policy logging want to know whether a network rule
// is disregarded in favor of the domain level rule.
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) ([]string, bool) {
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) {
upstreams := []string{upstreamPrefix + defaultUpstreamNum}
matchedPolicy := "no policy"
matchedNetwork := "no network"
matchedRule := "no rule"
matched := false
res = &upstreamForResult{}
defer func() {
if !matched && lc.Restricted {
ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", addr.String())
return
}
if matched {
ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", matchedPolicy, matchedNetwork, matchedRule, upstreams)
} else {
ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams)
}
res.upstreams = upstreams
res.matched = matched
res.matchedPolicy = matchedPolicy
res.matchedNetwork = matchedNetwork
res.matchedRule = matchedRule
}()
if lc.Policy == nil {
return upstreams, false
return
}
do := func(policyUpstreams []string) {
@@ -242,7 +262,7 @@ macRules:
matchedRule = source
do(targets)
matched = true
return upstreams, matched
return
}
}
}
@@ -251,11 +271,77 @@ macRules:
do(networkTargets)
}
return upstreams, matched
return
}
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo, matched bool) *dns.Msg {
func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg {
ip := ipFromARPA(msg.Question[0].Name)
if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" {
answer := new(dns.Msg)
answer.SetReply(msg)
answer.Compress = true
answer.Answer = []dns.RR{&dns.PTR{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: dns.Fqdn(name),
}}
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table")
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
Mac: p.ciTable.LookupMac(ip.String()),
IP: ip.String(),
Hostname: name,
})
return answer
}
return nil
}
func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg {
q := msg.Question[0]
hostname := strings.TrimSuffix(q.Name, ".")
if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil {
answer := new(dns.Msg)
answer.SetReply(msg)
answer.Compress = true
switch {
case ip.Is4():
answer.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(localTTL.Seconds()),
},
A: ip.AsSlice(),
}}
case ip.Is6():
answer.Answer = []dns.RR{&dns.AAAA{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: uint32(localTTL.Seconds()),
},
AAAA: ip.AsSlice(),
}}
}
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table")
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
Mac: p.ciTable.LookupMac(ip.String()),
IP: ip.String(),
Hostname: hostname,
})
return answer
}
return nil
}
func (p *prog) proxy(ctx context.Context, req *proxyRequest) *dns.Msg {
var staleAnswer *dns.Msg
upstreams := req.ufr.upstreams
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
if len(upstreamConfigs) == 0 {
@@ -270,85 +356,38 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
// 3. Try private resolver.
// 4. Try remote upstream.
isLanOrPtrQuery := false
if !matched {
if req.ufr.matched {
ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
} else {
switch {
case isPrivatePtrLookup(msg):
case isPrivatePtrLookup(req.msg):
isLanOrPtrQuery = true
ip := ipFromARPA(msg.Question[0].Name)
if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" {
answer := new(dns.Msg)
answer.SetReply(msg)
answer.Compress = true
answer.Answer = []dns.RR{&dns.PTR{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: dns.Fqdn(name),
}}
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table")
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
Mac: p.ciTable.LookupMac(ip.String()),
IP: ip.String(),
Hostname: name,
})
if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil {
return answer
}
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using upstreams: %v", upstreams)
case isLanHostnameQuery(msg):
case isLanHostnameQuery(req.msg):
isLanOrPtrQuery = true
q := msg.Question[0]
hostname := strings.TrimSuffix(q.Name, ".")
if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil {
answer := new(dns.Msg)
answer.SetReply(msg)
answer.Compress = true
switch {
case ip.Is4():
answer.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(localTTL.Seconds()),
},
A: ip.AsSlice(),
}}
case ip.Is6():
answer.Answer = []dns.RR{&dns.AAAA{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: uint32(localTTL.Seconds()),
},
AAAA: ip.AsSlice(),
}}
}
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table")
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
Mac: p.ciTable.LookupMac(ip.String()),
IP: ip.String(),
Hostname: hostname,
})
if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil {
return answer
}
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using upstreams: %v", upstreams)
default:
ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams)
}
}
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
for _, upstream := range upstreams {
cachedValue := p.cache.Get(dnscache.NewKey(msg, upstream))
cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream))
if cachedValue == nil {
continue
}
answer := cachedValue.Msg.Copy()
answer.SetRcode(msg, answer.Rcode)
answer.SetRcode(req.msg, answer.Rcode)
now := time.Now()
if cachedValue.Expire.After(now) {
ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response")
@@ -375,9 +414,9 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
return dnsResolver.Resolve(resolveCtx, msg)
}
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
if upstreamConfig.UpstreamSendClientInfo() && ci != nil {
if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request")
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, ci)
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
}
answer, err := resolve1(n, upstreamConfig, msg)
if err != nil {
@@ -404,7 +443,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n])
continue
}
answer := resolve(n, upstreamConfig, msg)
answer := resolve(n, upstreamConfig, req.msg)
if answer == nil {
if serveStaleCache && staleAnswer != nil {
ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response")
@@ -417,9 +456,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
// We are doing LAN/PTR lookup using private resolver, so always process next one.
// Except for the last, we want to send response instead of saying all upstream failed.
if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 {
ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n])
continue
}
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) {
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) {
ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream")
continue
}
@@ -427,7 +467,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
// set compression, as it is not set by default when unpacking
answer.Compress = true
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
ttl := ttlFromMsg(answer)
now := time.Now()
expired := now.Add(time.Duration(ttl) * time.Second)
@@ -435,14 +475,14 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
expired = now.Add(time.Duration(cachedTTL) * time.Second)
}
setCachedAnswerTTL(answer, now, expired)
p.cache.Add(dnscache.NewKey(msg, upstreams[n]), dnscache.NewValue(answer, expired))
p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired))
ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response")
}
return answer
}
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
answer := new(dns.Msg)
answer.SetRcode(msg, dns.RcodeServerFailure)
answer.SetRcode(req.msg, dns.RcodeServerFailure)
return answer
}

View File

@@ -67,8 +67,9 @@ func Test_canonicalName(t *testing.T) {
func Test_prog_upstreamFor(t *testing.T) {
cfg := testhelper.SampleConfig(t)
prog := &prog{cfg: cfg}
for _, nc := range prog.cfg.Network {
p := &prog{cfg: cfg}
p.um = newUpstreamMonitor(p.cfg)
for _, nc := range p.cfg.Network {
for _, cidr := range nc.Cidrs {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
@@ -89,14 +90,14 @@ func Test_prog_upstreamFor(t *testing.T) {
matched bool
testLogMsg string
}{
{"Policy map matches", "192.168.0.1:0", "", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
{"Policy split matches", "192.168.0.1:0", "", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
{"Policy map for other network matches", "192.168.1.2:0", "", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
{"No policy map for listener", "192.168.1.2:0", "", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
{"unenforced loging", "192.168.1.2:0", "", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
{"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
{"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
{"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
{"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
{"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
}
for _, tc := range tests {
@@ -115,9 +116,13 @@ func Test_prog_upstreamFor(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, addr)
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
assert.Equal(t, tc.matched, matched)
assert.Equal(t, tc.upstreams, upstreams)
ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
p.proxy(ctx, &proxyRequest{
msg: newDnsMsgWithHostname("foo", dns.TypeA),
ufr: ufr,
})
assert.Equal(t, tc.matched, ufr.matched)
assert.Equal(t, tc.upstreams, ufr.upstreams)
if tc.testLogMsg != "" {
assert.Contains(t, logOutput.String(), tc.testLogMsg)
}
@@ -153,8 +158,32 @@ func TestCache(t *testing.T) {
answer2.SetRcode(msg, dns.RcodeRefused)
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil, false)
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil, false)
req1 := &proxyRequest{
msg: msg,
ci: nil,
failoverRcodes: nil,
ufr: &upstreamForResult{
upstreams: []string{"upstream.1"},
matchedPolicy: "",
matchedNetwork: "",
matchedRule: "",
matched: false,
},
}
req2 := &proxyRequest{
msg: msg,
ci: nil,
failoverRcodes: nil,
ufr: &upstreamForResult{
upstreams: []string{"upstream.0"},
matchedPolicy: "",
matchedNetwork: "",
matchedRule: "",
matched: false,
},
}
got1 := prog.proxy(context.Background(), req1)
got2 := prog.proxy(context.Background(), req2)
assert.NotSame(t, got1, got2)
assert.Equal(t, answer1.Rcode, got1.Rcode)
assert.Equal(t, answer2.Rcode, got2.Rcode)