From 7b360288ed94488e5ec2f3c93ee3bb117fa6203b Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 11 May 2026 18:08:12 +0700 Subject: [PATCH] doq: validate DNS-over-QUIC response framing DoQ responses are length-prefixed per RFC 9250. The resolver previously assumed the stream always contained at least two bytes and unpacked from buf[2:], which could panic on truncated or malicious replies. Validate the prefix against the bytes read, return a clear error, and retire the connection from the pool on framing failure. Unpack only the slice declared by the prefix so a short read cannot be misinterpreted as a full message. Add regression coverage with a small test server that returns malformed raw payloads (empty, one byte, prefix-only, prefix larger than payload). --- doq.go | 30 +++++++--- doq_test.go | 165 +++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 187 insertions(+), 8 deletions(-) diff --git a/doq.go b/doq.go index 3aee246..eabb3a4 100644 --- a/doq.go +++ b/doq.go @@ -6,6 +6,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "net" "runtime" @@ -181,22 +182,37 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er buf, err := io.ReadAll(stream) stream.Close() - // Return connection to pool (mark as potentially bad if error occurred) - isGood := err == nil && len(buf) > 0 - p.putConn(conn, isGood) - if err != nil { + p.putConn(conn, false) return nil, err } - // io.ReadAll hides io.EOF error, so check for empty buffer + // io.ReadAll hides io.EOF error, so check for empty buffer. if len(buf) == 0 { + p.putConn(conn, false) return nil, io.EOF } - // Unpack DNS response (skip 2-byte length prefix) + // RFC 9250: each DoQ DNS message is encoded as a 2-octet length field + // followed by the DNS message. Reject responses that are shorter than + // the prefix or whose prefix declares more bytes than were received, + // and retire the misbehaving connection. 
Without this guard, buf[2:] + // would panic when len(buf) < 2. + if len(buf) < 2 { + p.putConn(conn, false) + return nil, fmt.Errorf("malformed DoQ response: %d byte(s), need >= 2 for length prefix", len(buf)) + } + respLen := int(buf[0])<<8 | int(buf[1]) + if 2+respLen > len(buf) { + p.putConn(conn, false) + return nil, fmt.Errorf("malformed DoQ response: length prefix %d exceeds payload %d", respLen, len(buf)-2) + } + + p.putConn(conn, true) + + // Unpack DNS response (skip 2-byte length prefix). answer := new(dns.Msg) - if err := answer.Unpack(buf[2:]); err != nil { + if err := answer.Unpack(buf[2 : 2+respLen]); err != nil { return nil, err } answer.SetReply(msg) diff --git a/doq_test.go b/doq_test.go index a6a1c54..806e229 100644 --- a/doq_test.go +++ b/doq_test.go @@ -1,4 +1,3 @@ -// test_helpers.go package ctrld import ( @@ -8,6 +7,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "io" "math/big" "net" "strings" @@ -222,3 +222,166 @@ func (s *testQUICServer) handleStream(t *testing.T, stream *quic.Stream) { return } } + +// malformedDoQServer is a test QUIC server that drains the client's DoQ +// request and writes caller-supplied raw bytes back. The bytes are not +// required to be a well-framed DoQ response, which is what lets the +// regression tests exercise malformed-response handling. 
+type malformedDoQServer struct {
+	listener *quic.Listener    // accepts incoming QUIC connections
+	cert     *x509.Certificate // certificate clients must trust to connect
+	addr     string            // address the listener is bound to
+	response []byte            // raw bytes written back on every stream
+}
+
+// newMalformedDoQServer starts a QUIC listener on a loopback port that replies
+// to every stream with response; the listener is closed on test cleanup.
+func newMalformedDoQServer(t *testing.T, response []byte) *malformedDoQServer {
+	t.Helper()
+
+	testCert := generateTestCertificate(t)
+	tlsConfig := &tls.Config{
+		Certificates: []tls.Certificate{testCert.tlsCert},
+		NextProtos:   []string{"doq"},
+	}
+
+	listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil)
+	if err != nil {
+		t.Fatalf("failed to create QUIC listener: %v", err)
+	}
+
+	s := &malformedDoQServer{
+		listener: listener,
+		cert:     testCert.cert,
+		addr:     listener.Addr().String(),
+		response: response,
+	}
+
+	go s.serve(t)
+	t.Cleanup(func() { _ = listener.Close() })
+	return s
+}
+
+// serve accepts connections until the listener fails. Every Accept error ends
+// the loop, so the error text is not inspected.
+func (s *malformedDoQServer) serve(t *testing.T) {
+	for {
+		conn, err := s.listener.Accept(context.Background())
+		if err != nil {
+			return
+		}
+		go s.handleConn(t, conn)
+	}
+}
+
+// handleConn serves every stream accepted on conn until AcceptStream fails.
+func (s *malformedDoQServer) handleConn(t *testing.T, conn *quic.Conn) {
+	for {
+		stream, err := conn.AcceptStream(context.Background())
+		if err != nil {
+			return
+		}
+		go s.handleStream(t, stream)
+	}
+}
+
+// handleStream consumes one framed request, then writes s.response verbatim.
+func (s *malformedDoQServer) handleStream(t *testing.T, stream *quic.Stream) {
+	defer stream.Close()
+
+	// Drain the client's DoQ-framed request so the client's writes complete
+	// cleanly before we reply with our attacker-controlled bytes. ReadFull and
+	// CopyN are used because a single Read on a QUIC stream may return short.
+	lenBuf := make([]byte, 2)
+	if _, err := io.ReadFull(stream, lenBuf); err != nil {
+		return
+	}
+	msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1])
+	if _, err := io.CopyN(io.Discard, stream, int64(msgLen)); err != nil {
+		return
+	}
+
+	if len(s.response) > 0 {
+		_, _ = stream.Write(s.response) // best effort: the client may already be gone
+	}
+}
+
+// newMalformedDoQUpstream builds an UpstreamConfig wired to a local
+// malformed test server with the test certificate trusted via a custom
+// cert pool. We bypass SetupBootstrapIP by setting BootstrapIP directly,
+// so the pool dials 127.0.0.1 without any DNS lookup.
+func newMalformedDoQUpstream(t *testing.T, cert *x509.Certificate, addr string) *UpstreamConfig {
+	t.Helper()
+
+	pool := x509.NewCertPool()
+	pool.AddCert(cert)
+
+	host, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		t.Fatalf("split host/port %q: %v", addr, err)
+	}
+
+	uc := &UpstreamConfig{
+		Name:        "doq-malformed",
+		Type:        ResolverTypeDOQ,
+		Endpoint:    addr,
+		Domain:      host,
+		BootstrapIP: host,
+		Timeout:     2000,
+	}
+	uc.SetCertPool(pool)
+	return uc
+}
+
+// TestDoQResolve_MalformedResponse verifies that DoQ upstream
+// responses violating RFC 9250 framing — fewer than 2 bytes, or a
+// length prefix declaring more payload than was received — return a
+// handled error instead of panicking on the length-prefix slice.
+func TestDoQResolve_MalformedResponse(t *testing.T) {
+	tests := []struct {
+		name     string
+		response []byte
+	}{
+		// Empty stream is already handled via io.EOF; locked in so a
+		// future change that drops that branch is caught.
+		{"empty response", nil},
+
+		// One byte: too short to hold the 2-octet length prefix.
+		{"single byte response", []byte{0x00}},
+
+		// Length prefix declares 16 bytes; payload is absent.
+		{"length prefix only", []byte{0x00, 0x10}},
+
+		// Length prefix declares 65535 bytes; only 1 byte of payload
+		// arrived.
+		{"length prefix larger than payload", []byte{0xFF, 0xFF, 0x00}},
+	}
+
+	for _, tt := range tests {
+		tt := tt // capture per iteration for the parallel subtest
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			server := newMalformedDoQServer(t, tt.response)
+			uc := newMalformedDoQUpstream(t, server.cert, server.addr)
+
+			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+			defer cancel()
+
+			pool := newDOQConnPool(ctx, uc, []string{"127.0.0.1"})
+			t.Cleanup(pool.CloseIdleConnections)
+
+			msg := new(dns.Msg)
+			msg.SetQuestion("example.com.", dns.TypeA)
+			msg.RecursionDesired = true
+
+			answer, err := pool.Resolve(ctx, msg)
+			if err == nil {
+				t.Fatalf("Resolve unexpectedly succeeded for malformed response %v; answer=%v", tt.response, answer)
+			}
+			if answer != nil {
+				t.Fatalf("Resolve returned non-nil answer alongside error: answer=%v err=%v", answer, err)
+			}
+		})
+	}
+}