Tweak log message for policy logging

This commit is contained in:
Cuong Manh Le
2023-04-17 23:48:49 +07:00
committed by Cuong Manh Le
parent f5ef9b917e
commit ccdb2a3f70
3 changed files with 66 additions and 31 deletions

View File

@@ -105,6 +105,13 @@ func (p *prog) serveDNS(listenerNum string) error {
return g.Wait()
}
// upstreamFor returns the list of upstreams for resolving the given domain,
// matching by policies defined in the listener config. The second return value
// reports whether the domain matches the policy.
//
// Though domain policy has higher priority than network policy, it is still
// processed later, because policy logging want to know whether a network rule
// is disregarded in favor of the domain level rule.
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
upstreams := []string{"upstream." + defaultUpstreamNum}
matchedPolicy := "no policy"
@@ -128,11 +135,43 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c
upstreams = append([]string(nil), policyUpstreams...)
}
var networkTargets []string
var sourceIP net.IP
switch addr := addr.(type) {
case *net.UDPAddr:
sourceIP = addr.IP
case *net.TCPAddr:
sourceIP = addr.IP
}
networkRules:
for _, rule := range lc.Policy.Networks {
for source, targets := range rule {
networkNum := strings.TrimPrefix(source, "network.")
nc := p.cfg.Network[networkNum]
if nc == nil {
continue
}
for _, ipNet := range nc.IPNets {
if ipNet.Contains(sourceIP) {
matchedPolicy = lc.Policy.Name
matchedNetwork = source
networkTargets = targets
matched = true
break networkRules
}
}
}
}
for _, rule := range lc.Policy.Rules {
// There's only one entry per rule, config validation ensures this.
for source, targets := range rule {
if source == domain || wildcardMatches(source, domain) {
matchedPolicy = lc.Policy.Name
if len(networkTargets) > 0 {
matchedNetwork += " (unenforced)"
}
matchedRule = source
do(targets)
matched = true
@@ -141,31 +180,8 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c
}
}
var sourceIP net.IP
switch addr := addr.(type) {
case *net.UDPAddr:
sourceIP = addr.IP
case *net.TCPAddr:
sourceIP = addr.IP
}
for _, rule := range lc.Policy.Networks {
for source, targets := range rule {
networkNum := strings.TrimPrefix(source, "network.")
nc := p.cfg.Network[networkNum]
if nc == nil {
continue
}
for _, ipNet := range nc.IPNets {
if ipNet.Contains(sourceIP) {
matchedPolicy = lc.Policy.Name
matchedNetwork = source
do(targets)
matched = true
return upstreams, matched
}
}
}
if matched {
do(networkTargets)
}
return upstreams, matched

View File

@@ -86,17 +86,17 @@ func Test_prog_upstreamFor(t *testing.T) {
domain string
upstreams []string
matched bool
testLogMsg string
}{
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true},
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true},
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true},
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false},
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
{"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
for _, network := range []string{"udp", "tcp"} {
var (
addr net.Addr
@@ -114,6 +114,9 @@ func Test_prog_upstreamFor(t *testing.T) {
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
assert.Equal(t, tc.matched, matched)
assert.Equal(t, tc.upstreams, upstreams)
if tc.testLogMsg != "" {
assert.Contains(t, logOutput.String(), tc.testLogMsg)
}
}
})
}

16
cmd/ctrld/main_test.go Normal file
View File

@@ -0,0 +1,16 @@
package main
import (
"os"
"strings"
"testing"
"github.com/rs/zerolog"
)
var logOutput strings.Builder
func TestMain(m *testing.M) {
mainLog = zerolog.New(&logOutput)
os.Exit(m.Run())
}