mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Implement a configurable DNS policy rule-matching order and refactor the upstreamFor method for better maintainability.

New features:
- Add MatchingConfig to ListenerPolicyConfig for rule-order configuration
- Support a custom rule evaluation order (network, mac, domain)
- Add a stop_on_first_match configuration option
- Hidden from config files (mapstructure:"-" toml:"-") until a future release

Code improvements:
- Create an upstreamForRequest struct to reduce the method parameter count
- Refactor upstreamForWithConfig to take a single struct parameter
- Improve code readability and maintainability
- Maintain full backward compatibility

Technical details:
- String-based configuration is converted to a RuleType enum internally
- Default behavior is preserved (network → mac → domain order)
- Domain rules still override MAC/network rules regardless of order
- Comprehensive test coverage for configuration integration

The matching configuration is programmatically accessible but hidden from user configuration files until it is ready for public release.
805 lines
23 KiB
Go
805 lines
23 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/Control-D-Inc/ctrld"
|
|
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
|
"github.com/Control-D-Inc/ctrld/testhelper"
|
|
)
|
|
|
|
func Test_wildcardMatches(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
wildcard string
|
|
domain string
|
|
match bool
|
|
}{
|
|
{"domain - prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
|
|
{"domain - prefix", "*.windscribe.com", "anything.windscribe.com", true},
|
|
{"domain - prefix not match other s", "*.windscribe.com", "example.com", false},
|
|
{"domain - prefix not match s in name", "*.windscribe.com", "wwindscribe.com", false},
|
|
{"domain - suffix", "suffix.*", "suffix.windscribe.com", true},
|
|
{"domain - suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
|
{"domain - both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
|
{"domain - both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
|
{"domain - case-insensitive", "*.WINDSCRIBE.com", "anything.windscribe.com", true},
|
|
{"mac - prefix", "*:98:05:b4:2b", "d4:67:98:05:b4:2b", true},
|
|
{"mac - prefix not match other s", "*:98:05:b4:2b", "0d:ba:54:09:94:2c", false},
|
|
{"mac - prefix not match s in name", "*:98:05:b4:2b", "e4:67:97:05:b4:2b", false},
|
|
{"mac - suffix", "d4:67:98:*", "d4:67:98:05:b4:2b", true},
|
|
{"mac - suffix not match other", "d4:67:98:*", "d4:67:97:15:b4:2b", false},
|
|
{"mac - both", "d4:67:98:*:b4:2b", "d4:67:98:05:b4:2b", true},
|
|
{"mac - both not match", "d4:67:98:*:b4:2b", "d4:67:97:05:c4:2b", false},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := wildcardMatches(tc.wildcard, tc.domain); got != tc.match {
|
|
t.Errorf("unexpected result, wildcard: %s, domain: %s, want: %v, got: %v", tc.wildcard, tc.domain, tc.match, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_canonicalName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
domain string
|
|
canonical string
|
|
}{
|
|
{"fqdn to canonical", "windscribe.com.", "windscribe.com"},
|
|
{"already canonical", "windscribe.com", "windscribe.com"},
|
|
{"case insensitive", "Windscribe.Com.", "windscribe.com"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := canonicalName(tc.domain); got != tc.canonical {
|
|
t.Errorf("unexpected result, want: %s, got: %s", tc.canonical, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_prog_upstreamFor(t *testing.T) {
|
|
cfg := testhelper.SampleConfig(t)
|
|
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
|
p := &prog{cfg: cfg}
|
|
p.logger.Store(mainLog.Load())
|
|
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
|
p.lanLoopGuard = newLoopGuard()
|
|
p.ptrLoopGuard = newLoopGuard()
|
|
for _, nc := range p.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
mac string
|
|
defaultUpstreamNum string
|
|
lc *ctrld.ListenerConfig
|
|
domain string
|
|
upstreams []string
|
|
matched bool
|
|
testLogMsg string
|
|
}{
|
|
{"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
|
{"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
|
{"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
|
{"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
|
{"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
|
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
|
|
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
|
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
for _, network := range []string{"udp", "tcp"} {
|
|
var (
|
|
addr net.Addr
|
|
err error
|
|
)
|
|
switch network {
|
|
case "udp":
|
|
addr, err = net.ResolveUDPAddr(network, tc.ip)
|
|
case "tcp":
|
|
addr, err = net.ResolveTCPAddr(network, tc.ip)
|
|
}
|
|
require.NoError(t, err)
|
|
require.NotNil(t, addr)
|
|
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
|
ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
|
|
p.proxy(ctx, &proxyRequest{
|
|
msg: newDnsMsgWithHostname("foo", dns.TypeA),
|
|
ufr: ufr,
|
|
})
|
|
assert.Equal(t, tc.matched, ufr.matched)
|
|
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
|
if tc.testLogMsg != "" {
|
|
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_prog_upstreamForWithCustomMatching(t *testing.T) {
|
|
cfg := testhelper.SampleConfig(t)
|
|
prog := &prog{cfg: cfg}
|
|
prog.logger.Store(mainLog.Load())
|
|
for _, nc := range prog.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
|
|
// Create a custom policy with domain-first matching order
|
|
customPolicy := &ctrld.ListenerPolicyConfig{
|
|
Name: "Custom Policy",
|
|
Networks: []ctrld.Rule{
|
|
{"network.0": []string{"upstream.1", "upstream.0"}},
|
|
},
|
|
Macs: []ctrld.Rule{
|
|
{"14:45:A0:67:83:0A": []string{"upstream.2"}},
|
|
},
|
|
Rules: []ctrld.Rule{
|
|
{"*.ru": []string{"upstream.1"}},
|
|
},
|
|
Matching: &ctrld.MatchingConfig{
|
|
Order: []string{"domain", "mac", "network"},
|
|
StopOnFirstMatch: true,
|
|
},
|
|
}
|
|
|
|
customListener := &ctrld.ListenerConfig{
|
|
Policy: customPolicy,
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
mac string
|
|
domain string
|
|
upstreams []string
|
|
matched bool
|
|
}{
|
|
{
|
|
name: "Domain rule should match first with custom order",
|
|
ip: "192.168.0.1:0",
|
|
mac: "14:45:A0:67:83:0A",
|
|
domain: "example.ru",
|
|
upstreams: []string{"upstream.1"},
|
|
matched: true,
|
|
},
|
|
{
|
|
name: "MAC rule should match when no domain rule",
|
|
ip: "192.168.0.1:0",
|
|
mac: "14:45:A0:67:83:0A",
|
|
domain: "example.com",
|
|
upstreams: []string{"upstream.2"},
|
|
matched: true,
|
|
},
|
|
{
|
|
name: "Network rule should match when no domain or MAC rule",
|
|
ip: "192.168.0.1:0",
|
|
mac: "00:11:22:33:44:55",
|
|
domain: "example.com",
|
|
upstreams: []string{"upstream.1", "upstream.0"},
|
|
matched: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
addr, err := net.ResolveUDPAddr("udp", tc.ip)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, addr)
|
|
|
|
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
|
ufr := prog.upstreamFor(ctx, "0", customListener, addr, tc.mac, tc.domain)
|
|
|
|
assert.Equal(t, tc.matched, ufr.matched)
|
|
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCache(t *testing.T) {
|
|
cfg := testhelper.SampleConfig(t)
|
|
prog := &prog{cfg: cfg}
|
|
prog.logger.Store(mainLog.Load())
|
|
for _, nc := range prog.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
cacher, err := dnscache.NewLRUCache(4096)
|
|
require.NoError(t, err)
|
|
prog.cache = cacher
|
|
|
|
msg := new(dns.Msg)
|
|
msg.SetQuestion("example.com", dns.TypeA)
|
|
msg.MsgHdr.RecursionDesired = true
|
|
answer1 := new(dns.Msg)
|
|
answer1.SetRcode(msg, dns.RcodeSuccess)
|
|
|
|
prog.cache.Add(dnscache.NewKey(msg, "upstream.1"), dnscache.NewValue(answer1, time.Now().Add(time.Minute)))
|
|
answer2 := new(dns.Msg)
|
|
answer2.SetRcode(msg, dns.RcodeRefused)
|
|
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
|
|
|
req1 := &proxyRequest{
|
|
msg: msg,
|
|
ci: nil,
|
|
failoverRcodes: nil,
|
|
ufr: &upstreamForResult{
|
|
upstreams: []string{"upstream.1"},
|
|
matchedPolicy: "",
|
|
matchedNetwork: "",
|
|
matchedRule: "",
|
|
matched: false,
|
|
},
|
|
}
|
|
req2 := &proxyRequest{
|
|
msg: msg,
|
|
ci: nil,
|
|
failoverRcodes: nil,
|
|
ufr: &upstreamForResult{
|
|
upstreams: []string{"upstream.0"},
|
|
matchedPolicy: "",
|
|
matchedNetwork: "",
|
|
matchedRule: "",
|
|
matched: false,
|
|
},
|
|
}
|
|
got1 := prog.proxy(context.Background(), req1)
|
|
got2 := prog.proxy(context.Background(), req2)
|
|
assert.NotSame(t, got1, got2)
|
|
assert.Equal(t, answer1.Rcode, got1.answer.Rcode)
|
|
assert.Equal(t, answer2.Rcode, got2.answer.Rcode)
|
|
}
|
|
|
|
func Test_ipAndMacFromMsg(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
wantIp bool
|
|
mac string
|
|
wantMac bool
|
|
}{
|
|
{"has ip v4 and mac", "1.2.3.4", true, "4c:20:b8:ab:87:1b", true},
|
|
{"has ip v6 and mac", "2606:1a40:3::1", true, "4c:20:b8:ab:87:1b", true},
|
|
{"no ip", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
|
{"no mac", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ip := net.ParseIP(tc.ip)
|
|
if ip == nil {
|
|
t.Fatal("missing IP")
|
|
}
|
|
hw, err := net.ParseMAC(tc.mac)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
m := new(dns.Msg)
|
|
m.SetQuestion("example.com.", dns.TypeA)
|
|
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
|
if tc.wantMac {
|
|
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
|
o.Option = append(o.Option, ec1)
|
|
}
|
|
if tc.wantIp {
|
|
ec2 := &dns.EDNS0_SUBNET{Address: ip}
|
|
o.Option = append(o.Option, ec2)
|
|
}
|
|
m.Extra = append(m.Extra, o)
|
|
gotIP, gotMac := ipAndMacFromMsg(m)
|
|
if tc.wantMac && gotMac != tc.mac {
|
|
t.Errorf("mismatch, want: %q, got: %q", tc.mac, gotMac)
|
|
}
|
|
if !tc.wantMac && gotMac != "" {
|
|
t.Errorf("unexpected mac: %q", gotMac)
|
|
}
|
|
if tc.wantIp && gotIP != tc.ip {
|
|
t.Errorf("mismatch, want: %q, got: %q", tc.ip, gotIP)
|
|
}
|
|
if !tc.wantIp && gotIP != "" {
|
|
t.Errorf("unexpected ip: %q", gotIP)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_remoteAddrFromMsg(t *testing.T) {
|
|
loopbackIP := net.ParseIP("127.0.0.1")
|
|
tests := []struct {
|
|
name string
|
|
addr net.Addr
|
|
ci *ctrld.ClientInfo
|
|
want string
|
|
}{
|
|
{"tcp", &net.TCPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.10"}, "192.168.1.10:12345"},
|
|
{"udp", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.11"}, "192.168.1.11:12345"},
|
|
{"nil client info", &net.UDPAddr{IP: loopbackIP, Port: 12345}, nil, "127.0.0.1:12345"},
|
|
{"empty ip", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{}, "127.0.0.1:12345"},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
addr := spoofRemoteAddr(tc.addr, tc.ci)
|
|
if addr.String() != tc.want {
|
|
t.Errorf("unexpected result, want: %q, got: %q", tc.want, addr.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_ipFromARPA(t *testing.T) {
|
|
tests := []struct {
|
|
IP string
|
|
ARPA string
|
|
}{
|
|
{"1.2.3.4", "4.3.2.1.in-addr.arpa."},
|
|
{"245.110.36.114", "114.36.110.245.in-addr.arpa."},
|
|
{"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."},
|
|
{"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."},
|
|
{"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."},
|
|
{"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."},
|
|
{"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."},
|
|
{"", "asd.in-addr.arpa."},
|
|
{"", "asd.ip6.arpa."},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.IP, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) {
|
|
t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func newDnsMsgWithClientIP(ip string) *dns.Msg {
|
|
m := new(dns.Msg)
|
|
m.SetQuestion("example.com.", dns.TypeA)
|
|
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
|
o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)})
|
|
m.Extra = append(m.Extra, o)
|
|
return m
|
|
}
|
|
|
|
func Test_stripClientSubnet(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
wantSubnet bool
|
|
}{
|
|
{"no edns0", new(dns.Msg), false},
|
|
{"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false},
|
|
{"loopback IP v6", newDnsMsgWithClientIP("::1"), false},
|
|
{"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false},
|
|
{"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false},
|
|
{"public IP", newDnsMsgWithClientIP("1.1.1.1"), true},
|
|
{"invalid IP", newDnsMsgWithClientIP(""), true},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
stripClientSubnet(tc.msg)
|
|
hasSubnet := false
|
|
if opt := tc.msg.IsEdns0(); opt != nil {
|
|
for _, s := range opt.Option {
|
|
if _, ok := s.(*dns.EDNS0_SUBNET); ok {
|
|
hasSubnet = true
|
|
}
|
|
}
|
|
}
|
|
if tc.wantSubnet != hasSubnet {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg {
|
|
m := new(dns.Msg)
|
|
m.SetQuestion(hostname, typ)
|
|
return m
|
|
}
|
|
|
|
func Test_isLanHostnameQuery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
isLanHostnameQuery bool
|
|
}{
|
|
{"A", newDnsMsgWithHostname("foo", dns.TypeA), true},
|
|
{"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true},
|
|
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
|
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
|
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
|
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
|
|
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
|
|
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg {
|
|
t.Helper()
|
|
m := new(dns.Msg)
|
|
ptr, err := dns.ReverseAddr(ip)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
m.SetQuestion(ptr, dns.TypePTR)
|
|
return m
|
|
}
|
|
|
|
func Test_isPrivatePtrLookup(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
isPrivatePtrLookup bool
|
|
}{
|
|
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
|
{"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true},
|
|
{"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true},
|
|
{"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true},
|
|
{"CGNAT", newDnsMsgPtr("100.66.27.28", t), true},
|
|
{"Loopback", newDnsMsgPtr("127.0.0.1", t), true},
|
|
{"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true},
|
|
{"Public IP", newDnsMsgPtr("8.8.8.8", t), false},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_isSrvLanLookup(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
isSrvLookup bool
|
|
}{
|
|
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
|
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
|
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_isWanClient(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
addr net.Addr
|
|
isWanClient bool
|
|
}{
|
|
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
|
{"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false},
|
|
{"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false},
|
|
{"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false},
|
|
{"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false},
|
|
{"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false},
|
|
{"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false},
|
|
{"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isWanClient(tc.addr); tc.isWanClient != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_shouldStartRecovery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
reason RecoveryReason
|
|
hasExistingRecovery bool
|
|
expectedResult bool
|
|
description string
|
|
}{
|
|
{
|
|
name: "network change with existing recovery",
|
|
reason: RecoveryReasonNetworkChange,
|
|
hasExistingRecovery: true,
|
|
expectedResult: true,
|
|
description: "should cancel existing recovery and start new one for network change",
|
|
},
|
|
{
|
|
name: "network change without existing recovery",
|
|
reason: RecoveryReasonNetworkChange,
|
|
hasExistingRecovery: false,
|
|
expectedResult: true,
|
|
description: "should start new recovery for network change",
|
|
},
|
|
{
|
|
name: "regular failure with existing recovery",
|
|
reason: RecoveryReasonRegularFailure,
|
|
hasExistingRecovery: true,
|
|
expectedResult: false,
|
|
description: "should skip duplicate recovery for regular failure",
|
|
},
|
|
{
|
|
name: "regular failure without existing recovery",
|
|
reason: RecoveryReasonRegularFailure,
|
|
hasExistingRecovery: false,
|
|
expectedResult: true,
|
|
description: "should start new recovery for regular failure",
|
|
},
|
|
{
|
|
name: "OS failure with existing recovery",
|
|
reason: RecoveryReasonOSFailure,
|
|
hasExistingRecovery: true,
|
|
expectedResult: false,
|
|
description: "should skip duplicate recovery for OS failure",
|
|
},
|
|
{
|
|
name: "OS failure without existing recovery",
|
|
reason: RecoveryReasonOSFailure,
|
|
hasExistingRecovery: false,
|
|
expectedResult: true,
|
|
description: "should start new recovery for OS failure",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
// Setup existing recovery if needed
|
|
if tc.hasExistingRecovery {
|
|
p.recoveryCancelMu.Lock()
|
|
p.recoveryCancel = func() {} // Mock cancel function
|
|
p.recoveryCancelMu.Unlock()
|
|
}
|
|
|
|
result := p.shouldStartRecovery(tc.reason)
|
|
assert.Equal(t, tc.expectedResult, result, tc.description)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_createRecoveryContext(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
ctx, cleanup := p.createRecoveryContext()
|
|
|
|
// Verify context is created
|
|
assert.NotNil(t, ctx)
|
|
assert.NotNil(t, cleanup)
|
|
|
|
// Verify recoveryCancel is set
|
|
p.recoveryCancelMu.Lock()
|
|
assert.NotNil(t, p.recoveryCancel)
|
|
p.recoveryCancelMu.Unlock()
|
|
|
|
// Test cleanup function
|
|
cleanup()
|
|
|
|
// Verify recoveryCancel is cleared
|
|
p.recoveryCancelMu.Lock()
|
|
assert.Nil(t, p.recoveryCancel)
|
|
p.recoveryCancelMu.Unlock()
|
|
}
|
|
|
|
func Test_prepareForRecovery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
reason RecoveryReason
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "regular failure",
|
|
reason: RecoveryReasonRegularFailure,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "network change",
|
|
reason: RecoveryReasonNetworkChange,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "OS failure",
|
|
reason: RecoveryReasonOSFailure,
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
err := p.prepareForRecovery(tc.reason)
|
|
|
|
if tc.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Verify recoveryRunning is set to true
|
|
assert.True(t, p.recoveryRunning.Load())
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_completeRecovery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
reason RecoveryReason
|
|
recovered string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "regular failure recovery",
|
|
reason: RecoveryReasonRegularFailure,
|
|
recovered: "upstream1",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "network change recovery",
|
|
reason: RecoveryReasonNetworkChange,
|
|
recovered: "upstream2",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "OS failure recovery",
|
|
reason: RecoveryReasonOSFailure,
|
|
recovered: "upstream3",
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
err := p.completeRecovery(tc.reason, tc.recovered)
|
|
|
|
if tc.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Verify recoveryRunning is set to false
|
|
assert.False(t, p.recoveryRunning.Load())
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_reinitializeOSResolver(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
err := p.reinitializeOSResolver("Test message")
|
|
|
|
// This function should not return an error under normal circumstances
|
|
// The actual behavior depends on the OS resolver implementation
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func Test_handleRecovery_Integration(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
reason RecoveryReason
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "network change recovery",
|
|
reason: RecoveryReasonNetworkChange,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "regular failure recovery",
|
|
reason: RecoveryReasonRegularFailure,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "OS failure recovery",
|
|
reason: RecoveryReasonOSFailure,
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
// This is an integration test that exercises the full recovery flow
|
|
// In a real test environment, you would mock the dependencies
|
|
// For now, we're just testing that the method doesn't panic
|
|
// and that the recovery logic flows correctly
|
|
assert.NotPanics(t, func() {
|
|
// Test only the preparation phase to avoid actual upstream checking
|
|
if !p.shouldStartRecovery(tc.reason) {
|
|
return
|
|
}
|
|
|
|
_, cleanup := p.createRecoveryContext()
|
|
defer cleanup()
|
|
|
|
if err := p.prepareForRecovery(tc.reason); err != nil {
|
|
return
|
|
}
|
|
|
|
// Skip the actual upstream recovery check for this test
|
|
// as it requires properly configured upstreams
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
// newTestProg creates a properly initialized *prog for testing.
|
|
func newTestProg(t *testing.T) *prog {
|
|
p := &prog{cfg: testhelper.SampleConfig(t)}
|
|
p.logger.Store(mainLog.Load())
|
|
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
|
return p
|
|
}
|