diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 1be818f..3c63782 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -806,6 +806,7 @@ func isPrivatePtrLookup(m *dns.Msg) bool { return false } +// isLanHostnameQuery reports whether DNS message is an A/AAAA query with LAN hostname. func isLanHostnameQuery(m *dns.Msg) bool { if m == nil || len(m.Question) == 0 { return false @@ -816,7 +817,8 @@ func isLanHostnameQuery(m *dns.Msg) bool { default: return false } - return !strings.Contains(q.Name, ".") || - strings.HasSuffix(q.Name, ".domain") || - strings.HasSuffix(q.Name, ".lan") + name := strings.TrimSuffix(q.Name, ".") + return !strings.Contains(name, ".") || + strings.HasSuffix(name, ".domain") || + strings.HasSuffix(name, ".lan") } diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 70197ad..118914a 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -274,6 +274,7 @@ func newDnsMsgWithClientIP(ip string) *dns.Msg { m.Extra = append(m.Extra, o) return m } + func Test_stripClientSubnet(t *testing.T) { tests := []struct { name string @@ -307,3 +308,69 @@ func Test_stripClientSubnet(t *testing.T) { }) } } + +func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg { + m := new(dns.Msg) + m.SetQuestion(hostname, typ) + return m +} + +func Test_isLanHostnameQuery(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + isLanHostnameQuery bool + }{ + {"A", newDnsMsgWithHostname("foo", dns.TypeA), true}, + {"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true}, + {"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false}, + {"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false}, + {"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got { + t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got) + } + }) + } +} + +func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg { + t.Helper() + m := new(dns.Msg) + ptr, err := dns.ReverseAddr(ip) + if err != nil { + t.Fatal(err) + } + m.SetQuestion(ptr, dns.TypePTR) + return m +} + +func Test_isPrivatePtrLookup(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + isPrivatePtrLookup bool + }{ + // RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as + {"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true}, + {"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true}, + {"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true}, + {"CGNAT", newDnsMsgPtr("100.66.27.28", t), true}, + {"Loopback", newDnsMsgPtr("127.0.0.1", t), true}, + {"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true}, + {"Public IP", newDnsMsgPtr("8.8.8.8", t), false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got { + t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got) + } + }) + } +} diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index e036638..a103263 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -143,6 +143,9 @@ func (d *dhcp) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() + if addr.IsLoopback() { // Continue searching if this is loopback address. + return true + } return false } } diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index 59e6e9c..f89e13f 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -69,6 +69,9 @@ func (m *mdns) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() + if addr.IsLoopback() { // Continue searching if this is loopback address. + return true + } return false } } diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index b6204d5..1439752 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -104,6 +104,9 @@ func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() + if addr.IsLoopback() { // Continue searching if this is loopback address. + return true + } return false } }