diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 2311260..33012fa 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -500,7 +500,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { continue } answer := cachedValue.Msg.Copy() - answer.SetRcode(req.msg, answer.Rcode) + ctrld.SetCacheReply(answer, req.msg, answer.Rcode) now := time.Now() if cachedValue.Expire.After(now) { ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response") diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..f2b71a5 --- /dev/null +++ b/dns.go @@ -0,0 +1,30 @@ +package ctrld + +import ( + "github.com/miekg/dns" +) + +// SetCacheReply extracts and stores the necessary data from the message for a cached answer. +func SetCacheReply(answer, msg *dns.Msg, code int) { + answer.SetRcode(msg, code) + cCookie := getEdns0Cookie(msg.IsEdns0()) + sCookie := getEdns0Cookie(answer.IsEdns0()) + if cCookie != nil && sCookie != nil { + // Client cookie is fixed size 8 bytes, Server cookie is variable size 8 -> 32 bytes. + // See https://datatracker.ietf.org/doc/html/rfc7873#section-4 + sCookie.Cookie = cCookie.Cookie[:16] + sCookie.Cookie[16:] + } +} + +// getEdns0Cookie returns Edns0 cookie from *dns.OPT if present. +func getEdns0Cookie(opt *dns.OPT) *dns.EDNS0_COOKIE { + if opt == nil { + return nil + } + for _, o := range opt.Option { + if e, ok := o.(*dns.EDNS0_COOKIE); ok { + return e + } + } + return nil +} diff --git a/resolver.go b/resolver.go index 52515f9..c20f1f5 100644 --- a/resolver.go +++ b/resolver.go @@ -305,7 +305,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error if val, ok := val.(*dns.Msg); ok { Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) res := val.Copy() - res.SetRcode(msg, val.Rcode) + SetCacheReply(res, msg, val.Rcode) return res, nil } } @@ -336,7 +336,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error return nil, fmt.Errorf("invalid answer for key: %s", key) } res := sharedMsg.Copy() - res.SetRcode(msg, sharedMsg.Rcode) + SetCacheReply(res, msg, sharedMsg.Rcode) if shared { Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) } diff --git a/resolver_test.go b/resolver_test.go index a75e748..ebcad16 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -2,6 +2,8 @@ package ctrld import ( "context" + "crypto/rand" + "encoding/hex" "net" "sync" "sync/atomic" @@ -232,6 +234,54 @@ func Test_osResolver_HotCache(t *testing.T) { } } +func Test_Edns0_CacheReply(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + call := &atomic.Int64{} + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + or := newResolverWithNameserver([]string{lanAddr}) + domain := "controld.com" + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + + do := func() *dns.Msg { + msg := m.Copy() + msg.SetEdns0(4096, true) + cookieOption := new(dns.EDNS0_COOKIE) + cookieOption.Code = dns.EDNS0COOKIE + cookieOption.Cookie = generateEdns0ClientCookie() + msg.IsEdns0().Option = append(msg.IsEdns0().Option, cookieOption) + + answer, err := or.Resolve(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return answer + } + answer1 := do() + answer2 := do() + // Ensure the cache was hit, so we can check that edns0 cookie must be modified. + if call.Load() != 1 { + t.Fatalf("cache not hit, server was called: %d", call.Load()) + } + cookie1 := getEdns0Cookie(answer1.IsEdns0()) + cookie2 := getEdns0Cookie(answer2.IsEdns0()) + if cookie1 == nil || cookie2 == nil { + t.Fatalf("unexpected nil cookie value (cookie1: %v, cookie2: %v)", cookie1, cookie2) + } + if cookie1.Cookie == cookie2.Cookie { + t.Fatalf("edns0 cookie is not modified: %v", cookie1) + } +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string @@ -306,7 +356,32 @@ func countHandler(call *atomic.Int64) dns.HandlerFunc { return func(w dns.ResponseWriter, msg *dns.Msg) { m := new(dns.Msg) m.SetRcode(msg, dns.RcodeSuccess) + if cookie := getEdns0Cookie(msg.IsEdns0()); cookie != nil { + if m.IsEdns0() == nil { + m.SetEdns0(4096, false) + } + cookieOption := new(dns.EDNS0_COOKIE) + cookieOption.Code = dns.EDNS0COOKIE + cookieOption.Cookie = generateEdns0ServerCookie(cookie.Cookie) + m.IsEdns0().Option = append(m.IsEdns0().Option, cookieOption) + } w.WriteMsg(m) call.Add(1) } } + +func generateEdns0ClientCookie() string { + cookie := make([]byte, 8) + if _, err := rand.Read(cookie); err != nil { + panic(err) + } + return hex.EncodeToString(cookie) +} + +func generateEdns0ServerCookie(clientCookie string) string { + cookie := make([]byte, 32) + if _, err := rand.Read(cookie); err != nil { + panic(err) + } + return clientCookie + hex.EncodeToString(cookie) +}