mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-05-27 12:52:27 +02:00
cmd: allow import/running ctrld as library
This commit is contained in:
committed by
Cuong Manh Le
parent
4896563e3c
commit
829e93c079
@@ -0,0 +1,218 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
)
|
||||
|
||||
func Test_wildcardMatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wildcard string
|
||||
domain string
|
||||
match bool
|
||||
}{
|
||||
{"prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
|
||||
{"prefix", "*.windscribe.com", "anything.windscribe.com", true},
|
||||
{"prefix not match other domain", "*.windscribe.com", "example.com", false},
|
||||
{"prefix not match domain in name", "*.windscribe.com", "wwindscribe.com", false},
|
||||
{"suffix", "suffix.*", "suffix.windscribe.com", true},
|
||||
{"suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
||||
{"both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
||||
{"both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := wildcardMatches(tc.wildcard, tc.domain); got != tc.match {
|
||||
t.Errorf("unexpected result, wildcard: %s, domain: %s, want: %v, got: %v", tc.wildcard, tc.domain, tc.match, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_canonicalName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
canonical string
|
||||
}{
|
||||
{"fqdn to canonical", "windscribe.com.", "windscribe.com"},
|
||||
{"already canonical", "windscribe.com", "windscribe.com"},
|
||||
{"case insensitive", "Windscribe.Com.", "windscribe.com"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := canonicalName(tc.domain); got != tc.canonical {
|
||||
t.Errorf("unexpected result, want: %s, got: %s", tc.canonical, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
defaultUpstreamNum string
|
||||
lc *ctrld.ListenerConfig
|
||||
domain string
|
||||
upstreams []string
|
||||
matched bool
|
||||
testLogMsg string
|
||||
}{
|
||||
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, network := range []string{"udp", "tcp"} {
|
||||
var (
|
||||
addr net.Addr
|
||||
err error
|
||||
)
|
||||
switch network {
|
||||
case "udp":
|
||||
addr, err = net.ResolveUDPAddr(network, tc.ip)
|
||||
case "tcp":
|
||||
addr, err = net.ResolveTCPAddr(network, tc.ip)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
|
||||
assert.Equal(t, tc.matched, matched)
|
||||
assert.Equal(t, tc.upstreams, upstreams)
|
||||
if tc.testLogMsg != "" {
|
||||
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
cacher, err := dnscache.NewLRUCache(4096)
|
||||
require.NoError(t, err)
|
||||
prog.cache = cacher
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com", dns.TypeA)
|
||||
msg.MsgHdr.RecursionDesired = true
|
||||
answer1 := new(dns.Msg)
|
||||
answer1.SetRcode(msg, dns.RcodeSuccess)
|
||||
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.1"), dnscache.NewValue(answer1, time.Now().Add(time.Minute)))
|
||||
answer2 := new(dns.Msg)
|
||||
answer2.SetRcode(msg, dns.RcodeRefused)
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
||||
|
||||
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil)
|
||||
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil)
|
||||
assert.NotSame(t, got1, got2)
|
||||
assert.Equal(t, answer1.Rcode, got1.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.Rcode)
|
||||
}
|
||||
|
||||
func Test_macFromMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mac string
|
||||
wantMac bool
|
||||
}{
|
||||
{"has mac", "4c:20:b8:ab:87:1b", true},
|
||||
{"no mac", "4c:20:b8:ab:87:1b", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hw, err := net.ParseMAC(tc.mac)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
if tc.wantMac {
|
||||
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
||||
o.Option = append(o.Option, ec1)
|
||||
}
|
||||
m.Extra = append(m.Extra, o)
|
||||
got := macFromMsg(m)
|
||||
if tc.wantMac && got != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, got)
|
||||
}
|
||||
if !tc.wantMac && got != "" {
|
||||
t.Errorf("unexpected mac: %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_remoteAddrFromMsg(t *testing.T) {
|
||||
loopbackIP := net.ParseIP("127.0.0.1")
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
ci *ctrld.ClientInfo
|
||||
want string
|
||||
}{
|
||||
{"tcp", &net.TCPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.10"}, "192.168.1.10:12345"},
|
||||
{"udp", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.11"}, "192.168.1.11:12345"},
|
||||
{"nil client info", &net.UDPAddr{IP: loopbackIP, Port: 12345}, nil, "127.0.0.1:12345"},
|
||||
{"empty ip", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{}, "127.0.0.1:12345"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
addr := spoofRemoteAddr(tc.addr, tc.ci)
|
||||
if addr.String() != tc.want {
|
||||
t.Errorf("unexpected result, want: %q, got: %q", tc.want, addr.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user