mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Add a logger field to the "prog" struct and use this field inside its methods instead of always accessing the global mainLog variable. This at least ensures more consistent usage of the logger during the ctrld prog runtime, and also makes future refactoring easier (e.g. replacing the logger library).
469 lines
15 KiB
Go
469 lines
15 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/Control-D-Inc/ctrld"
|
|
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
|
"github.com/Control-D-Inc/ctrld/testhelper"
|
|
)
|
|
|
|
func Test_wildcardMatches(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
wildcard string
|
|
domain string
|
|
match bool
|
|
}{
|
|
{"domain - prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
|
|
{"domain - prefix", "*.windscribe.com", "anything.windscribe.com", true},
|
|
{"domain - prefix not match other s", "*.windscribe.com", "example.com", false},
|
|
{"domain - prefix not match s in name", "*.windscribe.com", "wwindscribe.com", false},
|
|
{"domain - suffix", "suffix.*", "suffix.windscribe.com", true},
|
|
{"domain - suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
|
{"domain - both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
|
{"domain - both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
|
{"domain - case-insensitive", "*.WINDSCRIBE.com", "anything.windscribe.com", true},
|
|
{"mac - prefix", "*:98:05:b4:2b", "d4:67:98:05:b4:2b", true},
|
|
{"mac - prefix not match other s", "*:98:05:b4:2b", "0d:ba:54:09:94:2c", false},
|
|
{"mac - prefix not match s in name", "*:98:05:b4:2b", "e4:67:97:05:b4:2b", false},
|
|
{"mac - suffix", "d4:67:98:*", "d4:67:98:05:b4:2b", true},
|
|
{"mac - suffix not match other", "d4:67:98:*", "d4:67:97:15:b4:2b", false},
|
|
{"mac - both", "d4:67:98:*:b4:2b", "d4:67:98:05:b4:2b", true},
|
|
{"mac - both not match", "d4:67:98:*:b4:2b", "d4:67:97:05:c4:2b", false},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := wildcardMatches(tc.wildcard, tc.domain); got != tc.match {
|
|
t.Errorf("unexpected result, wildcard: %s, domain: %s, want: %v, got: %v", tc.wildcard, tc.domain, tc.match, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_canonicalName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
domain string
|
|
canonical string
|
|
}{
|
|
{"fqdn to canonical", "windscribe.com.", "windscribe.com"},
|
|
{"already canonical", "windscribe.com", "windscribe.com"},
|
|
{"case insensitive", "Windscribe.Com.", "windscribe.com"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := canonicalName(tc.domain); got != tc.canonical {
|
|
t.Errorf("unexpected result, want: %s, got: %s", tc.canonical, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_prog_upstreamFor(t *testing.T) {
|
|
cfg := testhelper.SampleConfig(t)
|
|
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
|
p := &prog{cfg: cfg}
|
|
p.logger.Store(mainLog.Load())
|
|
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
|
p.lanLoopGuard = newLoopGuard()
|
|
p.ptrLoopGuard = newLoopGuard()
|
|
for _, nc := range p.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
mac string
|
|
defaultUpstreamNum string
|
|
lc *ctrld.ListenerConfig
|
|
domain string
|
|
upstreams []string
|
|
matched bool
|
|
testLogMsg string
|
|
}{
|
|
{"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
|
{"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
|
{"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
|
{"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
|
{"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
|
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
|
|
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
|
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
for _, network := range []string{"udp", "tcp"} {
|
|
var (
|
|
addr net.Addr
|
|
err error
|
|
)
|
|
switch network {
|
|
case "udp":
|
|
addr, err = net.ResolveUDPAddr(network, tc.ip)
|
|
case "tcp":
|
|
addr, err = net.ResolveTCPAddr(network, tc.ip)
|
|
}
|
|
require.NoError(t, err)
|
|
require.NotNil(t, addr)
|
|
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
|
ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
|
|
p.proxy(ctx, &proxyRequest{
|
|
msg: newDnsMsgWithHostname("foo", dns.TypeA),
|
|
ufr: ufr,
|
|
})
|
|
assert.Equal(t, tc.matched, ufr.matched)
|
|
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
|
if tc.testLogMsg != "" {
|
|
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCache(t *testing.T) {
|
|
cfg := testhelper.SampleConfig(t)
|
|
prog := &prog{cfg: cfg}
|
|
prog.logger.Store(mainLog.Load())
|
|
for _, nc := range prog.cfg.Network {
|
|
for _, cidr := range nc.Cidrs {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
nc.IPNets = append(nc.IPNets, ipNet)
|
|
}
|
|
}
|
|
cacher, err := dnscache.NewLRUCache(4096)
|
|
require.NoError(t, err)
|
|
prog.cache = cacher
|
|
|
|
msg := new(dns.Msg)
|
|
msg.SetQuestion("example.com", dns.TypeA)
|
|
msg.MsgHdr.RecursionDesired = true
|
|
answer1 := new(dns.Msg)
|
|
answer1.SetRcode(msg, dns.RcodeSuccess)
|
|
|
|
prog.cache.Add(dnscache.NewKey(msg, "upstream.1"), dnscache.NewValue(answer1, time.Now().Add(time.Minute)))
|
|
answer2 := new(dns.Msg)
|
|
answer2.SetRcode(msg, dns.RcodeRefused)
|
|
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
|
|
|
req1 := &proxyRequest{
|
|
msg: msg,
|
|
ci: nil,
|
|
failoverRcodes: nil,
|
|
ufr: &upstreamForResult{
|
|
upstreams: []string{"upstream.1"},
|
|
matchedPolicy: "",
|
|
matchedNetwork: "",
|
|
matchedRule: "",
|
|
matched: false,
|
|
},
|
|
}
|
|
req2 := &proxyRequest{
|
|
msg: msg,
|
|
ci: nil,
|
|
failoverRcodes: nil,
|
|
ufr: &upstreamForResult{
|
|
upstreams: []string{"upstream.0"},
|
|
matchedPolicy: "",
|
|
matchedNetwork: "",
|
|
matchedRule: "",
|
|
matched: false,
|
|
},
|
|
}
|
|
got1 := prog.proxy(context.Background(), req1)
|
|
got2 := prog.proxy(context.Background(), req2)
|
|
assert.NotSame(t, got1, got2)
|
|
assert.Equal(t, answer1.Rcode, got1.answer.Rcode)
|
|
assert.Equal(t, answer2.Rcode, got2.answer.Rcode)
|
|
}
|
|
|
|
func Test_ipAndMacFromMsg(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
wantIp bool
|
|
mac string
|
|
wantMac bool
|
|
}{
|
|
{"has ip v4 and mac", "1.2.3.4", true, "4c:20:b8:ab:87:1b", true},
|
|
{"has ip v6 and mac", "2606:1a40:3::1", true, "4c:20:b8:ab:87:1b", true},
|
|
{"no ip", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
|
{"no mac", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ip := net.ParseIP(tc.ip)
|
|
if ip == nil {
|
|
t.Fatal("missing IP")
|
|
}
|
|
hw, err := net.ParseMAC(tc.mac)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
m := new(dns.Msg)
|
|
m.SetQuestion("example.com.", dns.TypeA)
|
|
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
|
if tc.wantMac {
|
|
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
|
o.Option = append(o.Option, ec1)
|
|
}
|
|
if tc.wantIp {
|
|
ec2 := &dns.EDNS0_SUBNET{Address: ip}
|
|
o.Option = append(o.Option, ec2)
|
|
}
|
|
m.Extra = append(m.Extra, o)
|
|
gotIP, gotMac := ipAndMacFromMsg(m)
|
|
if tc.wantMac && gotMac != tc.mac {
|
|
t.Errorf("mismatch, want: %q, got: %q", tc.mac, gotMac)
|
|
}
|
|
if !tc.wantMac && gotMac != "" {
|
|
t.Errorf("unexpected mac: %q", gotMac)
|
|
}
|
|
if tc.wantIp && gotIP != tc.ip {
|
|
t.Errorf("mismatch, want: %q, got: %q", tc.ip, gotIP)
|
|
}
|
|
if !tc.wantIp && gotIP != "" {
|
|
t.Errorf("unexpected ip: %q", gotIP)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_remoteAddrFromMsg(t *testing.T) {
|
|
loopbackIP := net.ParseIP("127.0.0.1")
|
|
tests := []struct {
|
|
name string
|
|
addr net.Addr
|
|
ci *ctrld.ClientInfo
|
|
want string
|
|
}{
|
|
{"tcp", &net.TCPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.10"}, "192.168.1.10:12345"},
|
|
{"udp", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.11"}, "192.168.1.11:12345"},
|
|
{"nil client info", &net.UDPAddr{IP: loopbackIP, Port: 12345}, nil, "127.0.0.1:12345"},
|
|
{"empty ip", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{}, "127.0.0.1:12345"},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
addr := spoofRemoteAddr(tc.addr, tc.ci)
|
|
if addr.String() != tc.want {
|
|
t.Errorf("unexpected result, want: %q, got: %q", tc.want, addr.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_ipFromARPA(t *testing.T) {
|
|
tests := []struct {
|
|
IP string
|
|
ARPA string
|
|
}{
|
|
{"1.2.3.4", "4.3.2.1.in-addr.arpa."},
|
|
{"245.110.36.114", "114.36.110.245.in-addr.arpa."},
|
|
{"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."},
|
|
{"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."},
|
|
{"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."},
|
|
{"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."},
|
|
{"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."},
|
|
{"", "asd.in-addr.arpa."},
|
|
{"", "asd.ip6.arpa."},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.IP, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) {
|
|
t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func newDnsMsgWithClientIP(ip string) *dns.Msg {
|
|
m := new(dns.Msg)
|
|
m.SetQuestion("example.com.", dns.TypeA)
|
|
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
|
o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)})
|
|
m.Extra = append(m.Extra, o)
|
|
return m
|
|
}
|
|
|
|
func Test_stripClientSubnet(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
wantSubnet bool
|
|
}{
|
|
{"no edns0", new(dns.Msg), false},
|
|
{"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false},
|
|
{"loopback IP v6", newDnsMsgWithClientIP("::1"), false},
|
|
{"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false},
|
|
{"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false},
|
|
{"public IP", newDnsMsgWithClientIP("1.1.1.1"), true},
|
|
{"invalid IP", newDnsMsgWithClientIP(""), true},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
stripClientSubnet(tc.msg)
|
|
hasSubnet := false
|
|
if opt := tc.msg.IsEdns0(); opt != nil {
|
|
for _, s := range opt.Option {
|
|
if _, ok := s.(*dns.EDNS0_SUBNET); ok {
|
|
hasSubnet = true
|
|
}
|
|
}
|
|
}
|
|
if tc.wantSubnet != hasSubnet {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg {
|
|
m := new(dns.Msg)
|
|
m.SetQuestion(hostname, typ)
|
|
return m
|
|
}
|
|
|
|
func Test_isLanHostnameQuery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
isLanHostnameQuery bool
|
|
}{
|
|
{"A", newDnsMsgWithHostname("foo", dns.TypeA), true},
|
|
{"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true},
|
|
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
|
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
|
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
|
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
|
|
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
|
|
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg {
|
|
t.Helper()
|
|
m := new(dns.Msg)
|
|
ptr, err := dns.ReverseAddr(ip)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
m.SetQuestion(ptr, dns.TypePTR)
|
|
return m
|
|
}
|
|
|
|
func Test_isPrivatePtrLookup(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
isPrivatePtrLookup bool
|
|
}{
|
|
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
|
{"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true},
|
|
{"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true},
|
|
{"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true},
|
|
{"CGNAT", newDnsMsgPtr("100.66.27.28", t), true},
|
|
{"Loopback", newDnsMsgPtr("127.0.0.1", t), true},
|
|
{"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true},
|
|
{"Public IP", newDnsMsgPtr("8.8.8.8", t), false},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_isSrvLanLookup(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
msg *dns.Msg
|
|
isSrvLookup bool
|
|
}{
|
|
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
|
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
|
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_isWanClient(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
addr net.Addr
|
|
isWanClient bool
|
|
}{
|
|
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
|
{"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false},
|
|
{"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false},
|
|
{"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false},
|
|
{"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false},
|
|
{"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false},
|
|
{"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false},
|
|
{"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true},
|
|
}
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := isWanClient(tc.addr); tc.isWanClient != got {
|
|
t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got)
|
|
}
|
|
})
|
|
}
|
|
}
|