From b4faf82f76a577c56cc83a954cf6fe9b01560587 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 26 May 2025 20:49:03 +0700 Subject: [PATCH] all: set edns0 cookie for shared message For cached or singleflight messages, the edns0 cookie is currently shared among all of them, causing mismatch cookie warning from clients. The ctrld proxy should re-set client cookies for each request separately, even though they use the same shared answer. --- cmd/cli/dns_proxy.go | 2 +- dns.go | 30 ++++++++++++++++++ resolver.go | 4 +-- resolver_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 3 deletions(-) create mode 100644 dns.go diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 2311260..33012fa 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -500,7 +500,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { continue } answer := cachedValue.Msg.Copy() - answer.SetRcode(req.msg, answer.Rcode) + ctrld.SetCacheReply(answer, req.msg, answer.Rcode) now := time.Now() if cachedValue.Expire.After(now) { ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response") diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..f2b71a5 --- /dev/null +++ b/dns.go @@ -0,0 +1,30 @@ +package ctrld + +import ( + "github.com/miekg/dns" +) + +// SetCacheReply extracts and stores the necessary data from the message for a cached answer. +func SetCacheReply(answer, msg *dns.Msg, code int) { + answer.SetRcode(msg, code) + cCookie := getEdns0Cookie(msg.IsEdns0()) + sCookie := getEdns0Cookie(answer.IsEdns0()) + if cCookie != nil && sCookie != nil { + // Client cookie is fixed size 8 bytes, Server cookie is variable size 8 -> 32 bytes. + // See https://datatracker.ietf.org/doc/html/rfc7873#section-4 + sCookie.Cookie = cCookie.Cookie[:16] + sCookie.Cookie[16:] + } +} + +// getEdns0Cookie returns Edns0 cookie from *dns.OPT if present. +func getEdns0Cookie(opt *dns.OPT) *dns.EDNS0_COOKIE { + if opt == nil { + return nil + } + for _, o := range opt.Option { + if e, ok := o.(*dns.EDNS0_COOKIE); ok { + return e + } + } + return nil +} diff --git a/resolver.go b/resolver.go index 52515f9..c20f1f5 100644 --- a/resolver.go +++ b/resolver.go @@ -305,7 +305,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error if val, ok := val.(*dns.Msg); ok { Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) res := val.Copy() - res.SetRcode(msg, val.Rcode) + SetCacheReply(res, msg, val.Rcode) return res, nil } } @@ -336,7 +336,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error return nil, fmt.Errorf("invalid answer for key: %s", key) } res := sharedMsg.Copy() - res.SetRcode(msg, sharedMsg.Rcode) + SetCacheReply(res, msg, sharedMsg.Rcode) if shared { Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) } diff --git a/resolver_test.go b/resolver_test.go index a75e748..ebcad16 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -2,6 +2,8 @@ package ctrld import ( "context" + "crypto/rand" + "encoding/hex" "net" "sync" "sync/atomic" @@ -232,6 +234,54 @@ func Test_osResolver_HotCache(t *testing.T) { } } +func Test_Edns0_CacheReply(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + call := &atomic.Int64{} + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + or := newResolverWithNameserver([]string{lanAddr}) + domain := "controld.com" + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + + do := func() *dns.Msg { + msg := m.Copy() + msg.SetEdns0(4096, true) + cookieOption := new(dns.EDNS0_COOKIE) + cookieOption.Code = dns.EDNS0COOKIE + cookieOption.Cookie = generateEdns0ClientCookie() + msg.IsEdns0().Option = append(msg.IsEdns0().Option, cookieOption) + + answer, err := or.Resolve(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return answer + } + answer1 := do() + answer2 := do() + // Ensure the cache was hit, so we can check that edns0 cookie must be modified. + if call.Load() != 1 { + t.Fatalf("cache not hit, server was called: %d", call.Load()) + } + cookie1 := getEdns0Cookie(answer1.IsEdns0()) + cookie2 := getEdns0Cookie(answer2.IsEdns0()) + if cookie1 == nil || cookie2 == nil { + t.Fatalf("unexpected nil cookie value (cookie1: %v, cookie2: %v)", cookie1, cookie2) + } + if cookie1.Cookie == cookie2.Cookie { + t.Fatalf("edns0 cookie is not modified: %v", cookie1) + } +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string @@ -306,7 +356,32 @@ func countHandler(call *atomic.Int64) dns.HandlerFunc { return func(w dns.ResponseWriter, msg *dns.Msg) { m := new(dns.Msg) m.SetRcode(msg, dns.RcodeSuccess) + if cookie := getEdns0Cookie(msg.IsEdns0()); cookie != nil { + if m.IsEdns0() == nil { + m.SetEdns0(4096, false) + } + cookieOption := new(dns.EDNS0_COOKIE) + cookieOption.Code = dns.EDNS0COOKIE + cookieOption.Cookie = generateEdns0ServerCookie(cookie.Cookie) + m.IsEdns0().Option = append(m.IsEdns0().Option, cookieOption) + } w.WriteMsg(m) call.Add(1) } } + +func generateEdns0ClientCookie() string { + cookie := make([]byte, 8) + if _, err := rand.Read(cookie); err != nil { + panic(err) + } + return hex.EncodeToString(cookie) +} + +func generateEdns0ServerCookie(clientCookie string) string { + cookie := make([]byte, 32) + if _, err := rand.Read(cookie); err != nil { + panic(err) + } + return clientCookie + hex.EncodeToString(cookie) +}