diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index af7628c..666bf50 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -71,6 +71,7 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) reqId := requestID() remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) ci := p.getClientInfo(remoteIP, m) + stripClientSubnet(m) remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String()) t := time.Now() @@ -498,6 +499,23 @@ func ipAndMacFromMsg(msg *dns.Msg) (string, string) { return ip, mac } +// stripClientSubnet removes EDNS0_SUBNET from DNS message if the IP is RFC1918 or loopback address, +// passing them to upstream is pointless, these cannot be used by anything on the WAN. +func stripClientSubnet(msg *dns.Msg) { + if opt := msg.IsEdns0(); opt != nil { + opts := make([]dns.EDNS0, 0, len(opt.Option)) + for _, s := range opt.Option { + if e, ok := s.(*dns.EDNS0_SUBNET); ok && (e.Address.IsPrivate() || e.Address.IsLoopback()) { + continue + } + opts = append(opts, s) + } + if len(opts) != len(opt.Option) { + opt.Option = opts + } + } +} + func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr { if ci != nil && ci.IP != "" { switch addr := addr.(type) { diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 8d18fa0..d0e5c74 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -265,3 +265,45 @@ func Test_ipFromARPA(t *testing.T) { }) } } + +func newDnsMsgWithClientIP(ip string) *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}} + o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)}) + m.Extra = append(m.Extra, o) + return m +} +func Test_stripClientSubnet(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + wantSubnet bool + }{ + {"no edns0", new(dns.Msg), false}, + {"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false}, + {"loopback IP v6", newDnsMsgWithClientIP("::1"), false}, + {"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false}, + {"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false}, + {"public IP", newDnsMsgWithClientIP("1.1.1.1"), true}, + {"invalid IP", newDnsMsgWithClientIP(""), true}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + stripClientSubnet(tc.msg) + hasSubnet := false + if opt := tc.msg.IsEdns0(); opt != nil { + for _, s := range opt.Option { + if _, ok := s.(*dns.EDNS0_SUBNET); ok { + hasSubnet = true + } + } + } + if tc.wantSubnet != hasSubnet { + t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet) + } + }) + } +}