mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Remove StopOnFirstMatch field that was defined but never used in the actual matching logic. The current implementation always evaluates all rule types and applies a fixed precedence (Domain > MAC > Network), making the StopOnFirstMatch field unnecessary. Changes: - Remove StopOnFirstMatch from MatchingConfig structs - Update DefaultMatchingConfig() function - Update all test cases and references - Simplify configuration to only include Order field This cleanup removes dead code and simplifies the configuration API without changing any functional behavior.
804 lines
23 KiB
Go
804 lines
23 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/Control-D-Inc/ctrld"
|
|
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
|
"github.com/Control-D-Inc/ctrld/testhelper"
|
|
)
|
|
|
|
func Test_wildcardMatches(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
wildcard string
|
|
domain string
|
|
match bool
|
|
}{
|
|
{"domain - prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
|
|
{"domain - prefix", "*.windscribe.com", "anything.windscribe.com", true},
|
|
{"domain - prefix not match other s", "*.windscribe.com", "example.com", false},
|
|
{"domain - prefix not match s in name", "*.windscribe.com", "wwindscribe.com", false},
|
|
{"domain - suffix", "suffix.*", "suffix.windscribe.com", true},
|
|
{"domain - suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
|
{"domain - both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
|
{"domain - both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
|
{"domain - case-insensitive", "*.WINDSCRIBE.com", "anything.windscribe.com", true},
|
|
{"mac - prefix", "*:98:05:b4:2b", "d4:67:98:05:b4:2b", true},
|
|
{"mac - prefix not match other s", "*:98:05:b4:2b", "0d:ba:54:09:94:2c", false},
|
|
{"mac - prefix not match s in name", "*:98:05:b4:2b", "e4:67:97:05:b4:2b", false},
|
|
{"mac - suffix", "d4:67:98:*", "d4:67:98:05:b4:2b", true},
|
|
{"mac - suffix not match other", "d4:67:98:*", "d4:67:97:15:b4:2b", false},
|
|
{"mac - both", "d4:67:98:*:b4:2b", "d4:67:98:05:b4:2b", true},
|
|
{"mac - both not match", "d4:67:98:*:b4:2b", "d4:67:97:05:c4:2b", false},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := wildcardMatches(tc.wildcard, tc.domain); got != tc.match {
|
|
t.Errorf("unexpected result, wildcard: %s, domain: %s, want: %v, got: %v", tc.wildcard, tc.domain, tc.match, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_canonicalName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
domain string
|
|
canonical string
|
|
}{
|
|
{"fqdn to canonical", "windscribe.com.", "windscribe.com"},
|
|
{"already canonical", "windscribe.com", "windscribe.com"},
|
|
{"case insensitive", "Windscribe.Com.", "windscribe.com"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := canonicalName(tc.domain); got != tc.canonical {
|
|
t.Errorf("unexpected result, want: %s, got: %s", tc.canonical, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_prog_upstreamFor(t *testing.T) {
|
|
cfg := testhelper.SampleConfig(t)
|
|
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
|
p := &prog{cfg: cfg}
|
|
p.logger.Store(mainLog.Load())
|
|
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
|
p.lanLoopGuard = newLoopGuard()
|
|
p.ptrLoopGuard = newLoopGuard()
|
|
for _, nc := range p.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
mac string
|
|
defaultUpstreamNum string
|
|
lc *ctrld.ListenerConfig
|
|
domain string
|
|
upstreams []string
|
|
matched bool
|
|
testLogMsg string
|
|
}{
|
|
{"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
|
{"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
|
{"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
|
{"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
|
{"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
|
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
|
|
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
|
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
for _, network := range []string{"udp", "tcp"} {
|
|
var (
|
|
addr net.Addr
|
|
err error
|
|
)
|
|
switch network {
|
|
case "udp":
|
|
addr, err = net.ResolveUDPAddr(network, tc.ip)
|
|
case "tcp":
|
|
addr, err = net.ResolveTCPAddr(network, tc.ip)
|
|
}
|
|
require.NoError(t, err)
|
|
require.NotNil(t, addr)
|
|
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
|
ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
|
|
p.proxy(ctx, &proxyRequest{
|
|
msg: newDnsMsgWithHostname("foo", dns.TypeA),
|
|
ufr: ufr,
|
|
})
|
|
assert.Equal(t, tc.matched, ufr.matched)
|
|
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
|
if tc.testLogMsg != "" {
|
|
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_prog_upstreamForWithCustomMatching(t *testing.T) {
|
|
cfg := testhelper.SampleConfig(t)
|
|
prog := &prog{cfg: cfg}
|
|
prog.logger.Store(mainLog.Load())
|
|
for _, nc := range prog.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
|
|
// Create a custom policy with domain-first matching order
|
|
customPolicy := &ctrld.ListenerPolicyConfig{
|
|
Name: "Custom Policy",
|
|
Networks: []ctrld.Rule{
|
|
{"network.0": []string{"upstream.1", "upstream.0"}},
|
|
},
|
|
Macs: []ctrld.Rule{
|
|
{"14:45:A0:67:83:0A": []string{"upstream.2"}},
|
|
},
|
|
Rules: []ctrld.Rule{
|
|
{"*.ru": []string{"upstream.1"}},
|
|
},
|
|
Matching: &ctrld.MatchingConfig{
|
|
Order: []string{"domain", "mac", "network"},
|
|
},
|
|
}
|
|
|
|
customListener := &ctrld.ListenerConfig{
|
|
Policy: customPolicy,
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
mac string
|
|
domain string
|
|
upstreams []string
|
|
matched bool
|
|
}{
|
|
{
|
|
name: "Domain rule should match first with custom order",
|
|
ip: "192.168.0.1:0",
|
|
mac: "14:45:A0:67:83:0A",
|
|
domain: "example.ru",
|
|
upstreams: []string{"upstream.1"},
|
|
matched: true,
|
|
},
|
|
{
|
|
name: "MAC rule should match when no domain rule",
|
|
ip: "192.168.0.1:0",
|
|
mac: "14:45:A0:67:83:0A",
|
|
domain: "example.com",
|
|
upstreams: []string{"upstream.2"},
|
|
matched: true,
|
|
},
|
|
{
|
|
name: "Network rule should match when no domain or MAC rule",
|
|
ip: "192.168.0.1:0",
|
|
mac: "00:11:22:33:44:55",
|
|
domain: "example.com",
|
|
upstreams: []string{"upstream.1", "upstream.0"},
|
|
matched: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
addr, err := net.ResolveUDPAddr("udp", tc.ip)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, addr)
|
|
|
|
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
|
ufr := prog.upstreamFor(ctx, "0", customListener, addr, tc.mac, tc.domain)
|
|
|
|
assert.Equal(t, tc.matched, ufr.matched)
|
|
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCache(t *testing.T) {
|
|
cfg := testhelper.SampleConfig(t)
|
|
prog := &prog{cfg: cfg}
|
|
prog.logger.Store(mainLog.Load())
|
|
for _, nc := range prog.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
cacher, err := dnscache.NewLRUCache(4096)
|
|
require.NoError(t, err)
|
|
prog.cache = cacher
|
|
|
|
msg := new(dns.Msg)
|
|
msg.SetQuestion("example.com", dns.TypeA)
|
|
msg.MsgHdr.RecursionDesired = true
|
|
answer1 := new(dns.Msg)
|
|
answer1.SetRcode(msg, dns.RcodeSuccess)
|
|
|
|
prog.cache.Add(dnscache.NewKey(msg, "upstream.1"), dnscache.NewValue(answer1, time.Now().Add(time.Minute)))
|
|
answer2 := new(dns.Msg)
|
|
answer2.SetRcode(msg, dns.RcodeRefused)
|
|
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
|
|
|
req1 := &proxyRequest{
|
|
msg: msg,
|
|
ci: nil,
|
|
failoverRcodes: nil,
|
|
ufr: &upstreamForResult{
|
|
upstreams: []string{"upstream.1"},
|
|
matchedPolicy: "",
|
|
matchedNetwork: "",
|
|
matchedRule: "",
|
|
matched: false,
|
|
},
|
|
}
|
|
req2 := &proxyRequest{
|
|
msg: msg,
|
|
ci: nil,
|
|
failoverRcodes: nil,
|
|
ufr: &upstreamForResult{
|
|
upstreams: []string{"upstream.0"},
|
|
matchedPolicy: "",
|
|
matchedNetwork: "",
|
|
matchedRule: "",
|
|
matched: false,
|
|
},
|
|
}
|
|
got1 := prog.proxy(context.Background(), req1)
|
|
got2 := prog.proxy(context.Background(), req2)
|
|
assert.NotSame(t, got1, got2)
|
|
assert.Equal(t, answer1.Rcode, got1.answer.Rcode)
|
|
assert.Equal(t, answer2.Rcode, got2.answer.Rcode)
|
|
}
|
|
|
|
func Test_ipAndMacFromMsg(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
wantIp bool
|
|
mac string
|
|
wantMac bool
|
|
}{
|
|
{"has ip v4 and mac", "1.2.3.4", true, "4c:20:b8:ab:87:1b", true},
|
|
{"has ip v6 and mac", "2606:1a40:3::1", true, "4c:20:b8:ab:87:1b", true},
|
|
{"no ip", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
|
{"no mac", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ip := net.ParseIP(tc.ip)
|
|
if ip == nil {
|
|
t.Fatal("missing IP")
|
|
}
|
|
hw, err := net.ParseMAC(tc.mac)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
m := new(dns.Msg)
|
|
m.SetQuestion("example.com.", dns.TypeA)
|
|
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
|
if tc.wantMac {
|
|
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
|
o.Option = append(o.Option, ec1)
|
|
}
|
|
if tc.wantIp {
|
|
ec2 := &dns.EDNS0_SUBNET{Address: ip}
|
|
o.Option = append(o.Option, ec2)
|
|
}
|
|
m.Extra = append(m.Extra, o)
|
|
gotIP, gotMac := ipAndMacFromMsg(m)
|
|
if tc.wantMac && gotMac != tc.mac {
|
|
t.Errorf("mismatch, want: %q, got: %q", tc.mac, gotMac)
|
|
}
|
|
if !tc.wantMac && gotMac != "" {
|
|
t.Errorf("unexpected mac: %q", gotMac)
|
|
}
|
|
if tc.wantIp && gotIP != tc.ip {
|
|
t.Errorf("mismatch, want: %q, got: %q", tc.ip, gotIP)
|
|
}
|
|
if !tc.wantIp && gotIP != "" {
|
|
t.Errorf("unexpected ip: %q", gotIP)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_remoteAddrFromMsg(t *testing.T) {
|
|
loopbackIP := net.ParseIP("127.0.0.1")
|
|
tests := []struct {
|
|
name string
|
|
addr net.Addr
|
|
ci *ctrld.ClientInfo
|
|
want string
|
|
}{
|
|
{"tcp", &net.TCPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.10"}, "192.168.1.10:12345"},
|
|
{"udp", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.11"}, "192.168.1.11:12345"},
|
|
{"nil client info", &net.UDPAddr{IP: loopbackIP, Port: 12345}, nil, "127.0.0.1:12345"},
|
|
{"empty ip", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{}, "127.0.0.1:12345"},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
addr := spoofRemoteAddr(tc.addr, tc.ci)
|
|
if addr.String() != tc.want {
|
|
t.Errorf("unexpected result, want: %q, got: %q", tc.want, addr.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_ipFromARPA(t *testing.T) {
|
|
tests := []struct {
|
|
IP string
|
|
ARPA string
|
|
}{
|
|
{"1.2.3.4", "4.3.2.1.in-addr.arpa."},
|
|
{"245.110.36.114", "114.36.110.245.in-addr.arpa."},
|
|
{"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."},
|
|
{"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."},
|
|
{"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."},
|
|
{"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."},
|
|
{"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."},
|
|
{"", "asd.in-addr.arpa."},
|
|
{"", "asd.ip6.arpa."},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.IP, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) {
|
|
t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func newDnsMsgWithClientIP(ip string) *dns.Msg {
|
|
m := new(dns.Msg)
|
|
m.SetQuestion("example.com.", dns.TypeA)
|
|
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
|
o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)})
|
|
m.Extra = append(m.Extra, o)
|
|
return m
|
|
}
|
|
|
|
func Test_stripClientSubnet(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
wantSubnet bool
|
|
}{
|
|
{"no edns0", new(dns.Msg), false},
|
|
{"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false},
|
|
{"loopback IP v6", newDnsMsgWithClientIP("::1"), false},
|
|
{"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false},
|
|
{"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false},
|
|
{"public IP", newDnsMsgWithClientIP("1.1.1.1"), true},
|
|
{"invalid IP", newDnsMsgWithClientIP(""), true},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
stripClientSubnet(tc.msg)
|
|
hasSubnet := false
|
|
if opt := tc.msg.IsEdns0(); opt != nil {
|
|
for _, s := range opt.Option {
|
|
if _, ok := s.(*dns.EDNS0_SUBNET); ok {
|
|
hasSubnet = true
|
|
}
|
|
}
|
|
}
|
|
if tc.wantSubnet != hasSubnet {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg {
|
|
m := new(dns.Msg)
|
|
m.SetQuestion(hostname, typ)
|
|
return m
|
|
}
|
|
|
|
func Test_isLanHostnameQuery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
isLanHostnameQuery bool
|
|
}{
|
|
{"A", newDnsMsgWithHostname("foo", dns.TypeA), true},
|
|
{"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true},
|
|
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
|
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
|
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
|
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
|
|
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
|
|
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg {
|
|
t.Helper()
|
|
m := new(dns.Msg)
|
|
ptr, err := dns.ReverseAddr(ip)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
m.SetQuestion(ptr, dns.TypePTR)
|
|
return m
|
|
}
|
|
|
|
func Test_isPrivatePtrLookup(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
isPrivatePtrLookup bool
|
|
}{
|
|
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
|
{"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true},
|
|
{"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true},
|
|
{"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true},
|
|
{"CGNAT", newDnsMsgPtr("100.66.27.28", t), true},
|
|
{"Loopback", newDnsMsgPtr("127.0.0.1", t), true},
|
|
{"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true},
|
|
{"Public IP", newDnsMsgPtr("8.8.8.8", t), false},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_isSrvLanLookup(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
isSrvLookup bool
|
|
}{
|
|
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
|
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
|
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_isWanClient(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
addr net.Addr
|
|
isWanClient bool
|
|
}{
|
|
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
|
{"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false},
|
|
{"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false},
|
|
{"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false},
|
|
{"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false},
|
|
{"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false},
|
|
{"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false},
|
|
{"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isWanClient(tc.addr); tc.isWanClient != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_shouldStartRecovery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
reason RecoveryReason
|
|
hasExistingRecovery bool
|
|
expectedResult bool
|
|
description string
|
|
}{
|
|
{
|
|
name: "network change with existing recovery",
|
|
reason: RecoveryReasonNetworkChange,
|
|
hasExistingRecovery: true,
|
|
expectedResult: true,
|
|
description: "should cancel existing recovery and start new one for network change",
|
|
},
|
|
{
|
|
name: "network change without existing recovery",
|
|
reason: RecoveryReasonNetworkChange,
|
|
hasExistingRecovery: false,
|
|
expectedResult: true,
|
|
description: "should start new recovery for network change",
|
|
},
|
|
{
|
|
name: "regular failure with existing recovery",
|
|
reason: RecoveryReasonRegularFailure,
|
|
hasExistingRecovery: true,
|
|
expectedResult: false,
|
|
description: "should skip duplicate recovery for regular failure",
|
|
},
|
|
{
|
|
name: "regular failure without existing recovery",
|
|
reason: RecoveryReasonRegularFailure,
|
|
hasExistingRecovery: false,
|
|
expectedResult: true,
|
|
description: "should start new recovery for regular failure",
|
|
},
|
|
{
|
|
name: "OS failure with existing recovery",
|
|
reason: RecoveryReasonOSFailure,
|
|
hasExistingRecovery: true,
|
|
expectedResult: false,
|
|
description: "should skip duplicate recovery for OS failure",
|
|
},
|
|
{
|
|
name: "OS failure without existing recovery",
|
|
reason: RecoveryReasonOSFailure,
|
|
hasExistingRecovery: false,
|
|
expectedResult: true,
|
|
description: "should start new recovery for OS failure",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
// Setup existing recovery if needed
|
|
if tc.hasExistingRecovery {
|
|
p.recoveryCancelMu.Lock()
|
|
p.recoveryCancel = func() {} // Mock cancel function
|
|
p.recoveryCancelMu.Unlock()
|
|
}
|
|
|
|
result := p.shouldStartRecovery(tc.reason)
|
|
assert.Equal(t, tc.expectedResult, result, tc.description)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_createRecoveryContext(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
ctx, cleanup := p.createRecoveryContext()
|
|
|
|
// Verify context is created
|
|
assert.NotNil(t, ctx)
|
|
assert.NotNil(t, cleanup)
|
|
|
|
// Verify recoveryCancel is set
|
|
p.recoveryCancelMu.Lock()
|
|
assert.NotNil(t, p.recoveryCancel)
|
|
p.recoveryCancelMu.Unlock()
|
|
|
|
// Test cleanup function
|
|
cleanup()
|
|
|
|
// Verify recoveryCancel is cleared
|
|
p.recoveryCancelMu.Lock()
|
|
assert.Nil(t, p.recoveryCancel)
|
|
p.recoveryCancelMu.Unlock()
|
|
}
|
|
|
|
func Test_prepareForRecovery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
reason RecoveryReason
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "regular failure",
|
|
reason: RecoveryReasonRegularFailure,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "network change",
|
|
reason: RecoveryReasonNetworkChange,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "OS failure",
|
|
reason: RecoveryReasonOSFailure,
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
err := p.prepareForRecovery(tc.reason)
|
|
|
|
if tc.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Verify recoveryRunning is set to true
|
|
assert.True(t, p.recoveryRunning.Load())
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_completeRecovery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
reason RecoveryReason
|
|
recovered string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "regular failure recovery",
|
|
reason: RecoveryReasonRegularFailure,
|
|
recovered: "upstream1",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "network change recovery",
|
|
reason: RecoveryReasonNetworkChange,
|
|
recovered: "upstream2",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "OS failure recovery",
|
|
reason: RecoveryReasonOSFailure,
|
|
recovered: "upstream3",
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
err := p.completeRecovery(tc.reason, tc.recovered)
|
|
|
|
if tc.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Verify recoveryRunning is set to false
|
|
assert.False(t, p.recoveryRunning.Load())
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_reinitializeOSResolver(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
err := p.reinitializeOSResolver("Test message")
|
|
|
|
// This function should not return an error under normal circumstances
|
|
// The actual behavior depends on the OS resolver implementation
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func Test_handleRecovery_Integration(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
reason RecoveryReason
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "network change recovery",
|
|
reason: RecoveryReasonNetworkChange,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "regular failure recovery",
|
|
reason: RecoveryReasonRegularFailure,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "OS failure recovery",
|
|
reason: RecoveryReasonOSFailure,
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
p := newTestProg(t)
|
|
|
|
// This is an integration test that exercises the full recovery flow
|
|
// In a real test environment, you would mock the dependencies
|
|
// For now, we're just testing that the method doesn't panic
|
|
// and that the recovery logic flows correctly
|
|
assert.NotPanics(t, func() {
|
|
// Test only the preparation phase to avoid actual upstream checking
|
|
if !p.shouldStartRecovery(tc.reason) {
|
|
return
|
|
}
|
|
|
|
_, cleanup := p.createRecoveryContext()
|
|
defer cleanup()
|
|
|
|
if err := p.prepareForRecovery(tc.reason); err != nil {
|
|
return
|
|
}
|
|
|
|
// Skip the actual upstream recovery check for this test
|
|
// as it requires properly configured upstreams
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
// newTestProg creates a properly initialized *prog for testing.
|
|
func newTestProg(t *testing.T) *prog {
|
|
p := &prog{cfg: testhelper.SampleConfig(t)}
|
|
p.logger.Store(mainLog.Load())
|
|
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
|
return p
|
|
}
|