From 0f3e8c7ada837fb2741126d5f4b0ea9479fb6d01 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 7 Sep 2023 11:09:53 +0000 Subject: [PATCH] all: include client IP if ctrld is dnsmasq upstream So ctrld can record the raw/original client IP instead of looking up from MAC to IP, which may not the right choice in some network setup like using wireguard/vpn on Merlin router. --- cmd/cli/dns_proxy.go | 56 +++++++++++++++++++++--------- cmd/cli/dns_proxy_test.go | 36 ++++++++++++++----- internal/clientinfo/arp.go | 3 ++ internal/clientinfo/client_info.go | 28 ++++++++++++++- internal/clientinfo/dhcp.go | 3 ++ internal/clientinfo/mdns.go | 3 ++ internal/clientinfo/ptr_lookup.go | 3 ++ internal/clientinfo/vpn.go | 43 +++++++++++++++++++++++ internal/router/dnsmasq/dnsmasq.go | 4 +++ 9 files changed, 153 insertions(+), 26 deletions(-) create mode 100644 internal/clientinfo/vpn.go diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 98ccafc..ca9a4d0 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -54,8 +54,7 @@ func (p *prog) serveDNS(listenerNum string) error { domain := canonicalName(q.Name) reqId := requestID() remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - mac := macFromMsg(m) - ci := p.getClientInfo(remoteIP, mac) + ci := p.getClientInfo(remoteIP, m) remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String()) t := time.Now() @@ -419,18 +418,24 @@ func needLocalIPv6Listener() bool { return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows" } -func macFromMsg(msg *dns.Msg) string { +// ipAndMacFromMsg extracts IP and MAC information included in a DNS message, if any. +func ipAndMacFromMsg(msg *dns.Msg) (string, string) { + ip, mac := "", "" if opt := msg.IsEdns0(); opt != nil { for _, s := range opt.Option { switch e := s.(type) { case *dns.EDNS0_LOCAL: if e.Code == EDNS0_OPTION_MAC { - return net.HardwareAddr(e.Data).String() + mac = net.HardwareAddr(e.Data).String() + } + case *dns.EDNS0_SUBNET: + if len(e.Address) > 0 && !e.Address.IsLoopback() { + ip = e.Address.String() } } } } - return "" + return ip, mac } func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr { @@ -484,19 +489,38 @@ func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-cha return s, errCh } -func (p *prog) getClientInfo(ip, mac string) *ctrld.ClientInfo { +func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo { ci := &ctrld.ClientInfo{} - if mac != "" { - ci.Mac = mac - ci.IP = p.ciTable.LookupIP(mac) - } else { - ci.IP = ip - ci.Mac = p.ciTable.LookupMac(ip) - if ip == "127.0.0.1" || ip == "::1" { - ci.IP = p.ciTable.LookupIP(ci.Mac) - } + ci.IP, ci.Mac = ipAndMacFromMsg(msg) + switch { + case ci.IP != "" && ci.Mac != "": + // Nothing to do. + case ci.IP == "" && ci.Mac != "": + // Have MAC, no IP. + ci.IP = p.ciTable.LookupIP(ci.Mac) + case ci.IP == "" && ci.Mac == "": + // Have nothing, use remote IP then lookup MAC. + ci.IP = remoteIP + fallthrough + case ci.IP != "" && ci.Mac == "": + // Have IP, no MAC. + ci.Mac = p.ciTable.LookupMac(ci.IP) + } + + // If MAC is still empty here, that mean the requests are made from virtual interface, + // like VPN/Wireguard clients, so we use whatever MAC address associated with remoteIP + // (most likely 127.0.0.1), and ci.IP as hostname, so we can distinguish those clients. + if ci.Mac == "" { + ci.Mac = p.ciTable.LookupMac(remoteIP) + if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" { + ci.Hostname = hostname + } else { + ci.Hostname = ci.IP + p.ciTable.StoreVPNClient(ci) + } + } else { + ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac) } - ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac) return ci } diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index b7b0dbd..674d486 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -156,19 +156,27 @@ func TestCache(t *testing.T) { assert.Equal(t, answer2.Rcode, got2.Rcode) } -func Test_macFromMsg(t *testing.T) { +func Test_ipAndMacFromMsg(t *testing.T) { tests := []struct { name string + ip string + wantIp bool mac string wantMac bool }{ - {"has mac", "4c:20:b8:ab:87:1b", true}, - {"no mac", "4c:20:b8:ab:87:1b", false}, + {"has ip v4 and mac", "1.2.3.4", true, "4c:20:b8:ab:87:1b", true}, + {"has ip v6 and mac", "2606:1a40:3::1", true, "4c:20:b8:ab:87:1b", true}, + {"no ip", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false}, + {"no mac", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false}, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() + ip := net.ParseIP(tc.ip) + if ip == nil { + t.Fatal("missing IP") + } hw, err := net.ParseMAC(tc.mac) if err != nil { t.Fatal(err) @@ -180,13 +188,23 @@ func Test_macFromMsg(t *testing.T) { ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw} o.Option = append(o.Option, ec1) } - m.Extra = append(m.Extra, o) - got := macFromMsg(m) - if tc.wantMac && got != tc.mac { - t.Errorf("mismatch, want: %q, got: %q", tc.mac, got) + if tc.wantIp { + ec2 := &dns.EDNS0_SUBNET{Address: ip} + o.Option = append(o.Option, ec2) } - if !tc.wantMac && got != "" { - t.Errorf("unexpected mac: %q", got) + m.Extra = append(m.Extra, o) + gotIP, gotMac := ipAndMacFromMsg(m) + if tc.wantMac && gotMac != tc.mac { + t.Errorf("mismatch, want: %q, got: %q", tc.mac, gotMac) + } + if !tc.wantMac && gotMac != "" { + t.Errorf("unexpected mac: %q", gotMac) + } + if tc.wantIp && gotIP != tc.ip { + t.Errorf("mismatch, want: %q, got: %q", tc.ip, gotIP) + } + if !tc.wantIp && gotIP != "" { + t.Errorf("unexpected ip: %q", gotIP) } }) } diff --git a/internal/clientinfo/arp.go b/internal/clientinfo/arp.go index 8429b56..f99f783 100644 --- a/internal/clientinfo/arp.go +++ b/internal/clientinfo/arp.go @@ -33,6 +33,9 @@ func (a *arpDiscover) String() string { } func (a *arpDiscover) List() []string { + if a == nil { + return nil + } var ips []string a.ip.Range(func(key, value any) bool { ips = append(ips, value.(string)) diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index a371a19..ea7ed05 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -74,6 +74,7 @@ type Table struct { ptr *ptrDiscover mdns *mdns hf *hostsFile + vpn *vpn cfg *ctrld.Config quitCh chan struct{} selfIP string @@ -117,6 +118,7 @@ func (t *Table) Init() { } func (t *Table) init() { + // Custom client ID presents, use it as the only source. if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { ctrld.ProxyLogger.Load().Debug().Msg("start self discovery") t.dhcp = &dhcp{selfIP: t.selfIP} @@ -126,6 +128,11 @@ func (t *Table) init() { t.hostnameResolvers = append(t.hostnameResolvers, t.dhcp) return } + + // Otherwise, process all possible sources in order, that means + // the first result of IP/MAC/Hostname lookup will be used. + // + // Merlin custom clients. if t.discoverDHCP() || t.discoverARP() { t.merlin = &merlinDiscover{} if err := t.merlin.refresh(); err != nil { @@ -135,6 +142,7 @@ func (t *Table) init() { t.refreshers = append(t.refreshers, t.merlin) } } + // Hosts file mapping. if t.discoverHosts() { t.hf = &hostsFile{} ctrld.ProxyLogger.Load().Debug().Msg("start hosts file discovery") @@ -146,6 +154,7 @@ func (t *Table) init() { } go t.hf.watchChanges() } + // DHCP lease files. if t.discoverDHCP() { t.dhcp = &dhcp{selfIP: t.selfIP} ctrld.ProxyLogger.Load().Debug().Msg("start dhcp discovery") @@ -158,6 +167,7 @@ func (t *Table) init() { } go t.dhcp.watchChanges() } + // ARP table. if t.discoverARP() { t.arp = &arpDiscover{} ctrld.ProxyLogger.Load().Debug().Msg("start arp discovery") @@ -169,6 +179,7 @@ func (t *Table) init() { t.refreshers = append(t.refreshers, t.arp) } } + // PTR lookup. if t.discoverPTR() { t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") @@ -179,6 +190,7 @@ func (t *Table) init() { t.refreshers = append(t.refreshers, t.ptr) } } + // mdns. if t.discoverMDNS() { t.mdns = &mdns{} ctrld.ProxyLogger.Load().Debug().Msg("start mdns discovery") @@ -188,6 +200,11 @@ func (t *Table) init() { t.hostnameResolvers = append(t.hostnameResolvers, t.mdns) } } + // VPN clients. + if t.discoverDHCP() || t.discoverARP() { + t.vpn = &vpn{} + t.hostnameResolvers = append(t.hostnameResolvers, t.vpn) + } } func (t *Table) LookupIP(mac string) string { @@ -271,7 +288,7 @@ func (t *Table) ListClients() []*Client { _ = r.refresh() } ipMap := make(map[string]*Client) - il := []ipLister{t.dhcp, t.arp, t.ptr, t.mdns} + il := []ipLister{t.dhcp, t.arp, t.ptr, t.mdns, t.vpn} for _, ir := range il { for _, ip := range ir.List() { c, ok := ipMap[ip] @@ -312,6 +329,15 @@ func (t *Table) ListClients() []*Client { return clients } +// StoreVPNClient stores client info for VPN clients. +func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) { + if ci == nil || t.vpn == nil { + return + } + t.vpn.mac.Store(ci.IP, ci.Mac) + t.vpn.ip2name.Store(ci.IP, ci.Hostname) +} + func (t *Table) discoverDHCP() bool { if t.cfg.Service.DiscoverDHCP == nil { return true diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index a5b6a57..7c1b2cf 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -119,6 +119,9 @@ func (d *dhcp) String() string { } func (d *dhcp) List() []string { + if d == nil { + return nil + } var ips []string d.ip.Range(func(key, value any) bool { ips = append(ips, value.(string)) diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index c9d97e5..5875b69 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -48,6 +48,9 @@ func (m *mdns) String() string { } func (m *mdns) List() []string { + if m == nil { + return nil + } var ips []string m.name.Range(func(key, value any) bool { ips = append(ips, key.(string)) diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index 6bd7bc7..0a8867b 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -41,6 +41,9 @@ func (p *ptrDiscover) String() string { } func (p *ptrDiscover) List() []string { + if p == nil { + return nil + } var ips []string p.hostname.Range(func(key, value any) bool { ips = append(ips, key.(string)) diff --git a/internal/clientinfo/vpn.go b/internal/clientinfo/vpn.go new file mode 100644 index 0000000..fe62bcb --- /dev/null +++ b/internal/clientinfo/vpn.go @@ -0,0 +1,43 @@ +package clientinfo + +import ( + "sync" +) + +// vpn is the manager for VPN clients info. +type vpn struct { + ip2name sync.Map // ip => name + mac sync.Map // ip => mac +} + +// LookupHostnameByIP returns hostname of the given VPN client ip. +func (v *vpn) LookupHostnameByIP(ip string) string { + val, ok := v.ip2name.Load(ip) + if !ok { + return "" + } + return val.(string) +} + +// LookupHostnameByMac always returns empty string. +func (v *vpn) LookupHostnameByMac(mac string) string { + return "" +} + +// String returns the string representation of vpn struct. +func (v *vpn) String() string { + return "vpn" +} + +// List lists all known VPN clients IP. +func (v *vpn) List() []string { + if v == nil { + return nil + } + var ips []string + v.mac.Range(func(key, value any) bool { + ips = append(ips, key.(string)) + return true + }) + return ips +} diff --git a/internal/router/dnsmasq/dnsmasq.go b/internal/router/dnsmasq/dnsmasq.go index a25f564..54ba8fd 100644 --- a/internal/router/dnsmasq/dnsmasq.go +++ b/internal/router/dnsmasq/dnsmasq.go @@ -17,6 +17,7 @@ server={{ .IP }}#{{ .Port }} {{- end}} {{- if .SendClientInfo}} add-mac +add-subnet=32,128 {{- end}} ` @@ -39,7 +40,10 @@ if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then pc_append "server={{ .IP }}#{{ .Port }}" "$config_file" {{- end}} {{- if .SendClientInfo}} + pc_delete "add-mac" "$config_file" + pc_delete "add-subnet" "$config_file" pc_append "add-mac" "$config_file" # add client mac + pc_append "add-subnet=32,128" "$config_file" # add client ip {{- end}} pc_delete "dnssec" "$config_file" # disable DNSSEC pc_delete "trust-anchor=" "$config_file" # disable DNSSEC