doq: validate DNS-over-QUIC response framing

DoQ responses are length-prefixed per RFC 9250. The resolver previously
assumed the stream always contained at least two bytes and unpacked from
buf[2:], which could panic on truncated or malicious replies.

Validate the prefix against the bytes read, return a clear error, and
retire the connection from the pool on framing failure. Unpack only the
slice declared by the prefix so a short read cannot be misinterpreted as
a full message.

Add regression coverage with a small test server that returns malformed
raw payloads (empty, one byte, prefix-only, prefix larger than payload).
This commit is contained in:
Cuong Manh Le
2026-05-11 18:08:12 +07:00
committed by Cuong Manh Le
parent 65d3d468f7
commit 7b360288ed
2 changed files with 187 additions and 8 deletions
+23 -7
View File
@@ -6,6 +6,7 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"runtime"
@@ -181,22 +182,37 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er
buf, err := io.ReadAll(stream)
stream.Close()
// Return connection to pool (mark as potentially bad if error occurred)
isGood := err == nil && len(buf) > 0
p.putConn(conn, isGood)
if err != nil {
p.putConn(conn, false)
return nil, err
}
// io.ReadAll hides io.EOF error, so check for empty buffer
// io.ReadAll hides io.EOF error, so check for empty buffer.
if len(buf) == 0 {
p.putConn(conn, false)
return nil, io.EOF
}
// Unpack DNS response (skip 2-byte length prefix)
// RFC 9250: each DoQ DNS message is encoded as a 2-octet length field
// followed by the DNS message. Reject responses that are shorter than
// the prefix or whose prefix declares more bytes than were received,
// and retire the misbehaving connection. Without this guard, buf[2:]
// would panic when len(buf) < 2.
if len(buf) < 2 {
p.putConn(conn, false)
return nil, fmt.Errorf("malformed DoQ response: %d byte(s), need >= 2 for length prefix", len(buf))
}
respLen := int(buf[0])<<8 | int(buf[1])
if 2+respLen > len(buf) {
p.putConn(conn, false)
return nil, fmt.Errorf("malformed DoQ response: length prefix %d exceeds payload %d", respLen, len(buf)-2)
}
p.putConn(conn, true)
// Unpack DNS response (skip 2-byte length prefix).
answer := new(dns.Msg)
if err := answer.Unpack(buf[2:]); err != nil {
if err := answer.Unpack(buf[2 : 2+respLen]); err != nil {
return nil, err
}
answer.SetReply(msg)
+164 -1
View File
@@ -1,4 +1,3 @@
// test_helpers.go
package ctrld
import (
@@ -8,6 +7,7 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"io"
"math/big"
"net"
"strings"
@@ -222,3 +222,166 @@ func (s *testQUICServer) handleStream(t *testing.T, stream *quic.Stream) {
return
}
}
// malformedDoQServer is a test QUIC server that drains the client's DoQ
// request and writes caller-supplied raw bytes back. The bytes are not
// required to be a well-framed DoQ response, which is what lets the
// regression tests exercise malformed-response handling.
type malformedDoQServer struct {
listener *quic.Listener
cert *x509.Certificate
addr string
response []byte
}
func newMalformedDoQServer(t *testing.T, response []byte) *malformedDoQServer {
t.Helper()
testCert := generateTestCertificate(t)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{testCert.tlsCert},
NextProtos: []string{"doq"},
}
listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil)
if err != nil {
t.Fatalf("failed to create QUIC listener: %v", err)
}
s := &malformedDoQServer{
listener: listener,
cert: testCert.cert,
addr: listener.Addr().String(),
response: response,
}
go s.serve(t)
t.Cleanup(func() { _ = listener.Close() })
return s
}
func (s *malformedDoQServer) serve(t *testing.T) {
for {
conn, err := s.listener.Accept(context.Background())
if err != nil {
if strings.Contains(err.Error(), "server closed") {
return
}
return
}
go s.handleConn(t, conn)
}
}
func (s *malformedDoQServer) handleConn(t *testing.T, conn *quic.Conn) {
for {
stream, err := conn.AcceptStream(context.Background())
if err != nil {
return
}
go s.handleStream(t, stream)
}
}
func (s *malformedDoQServer) handleStream(t *testing.T, stream *quic.Stream) {
defer stream.Close()
// Drain the client's DoQ-framed request so the client's writes complete
// cleanly before we reply with our attacker-controlled bytes. Using
// io.ReadFull because a single Read on a QUIC stream may return short.
lenBuf := make([]byte, 2)
if _, err := io.ReadFull(stream, lenBuf); err != nil {
return
}
msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1])
if msgLen > 0 {
discard := make([]byte, msgLen)
if _, err := io.ReadFull(stream, discard); err != nil {
return
}
}
if len(s.response) > 0 {
_, _ = stream.Write(s.response)
}
}
// newMalformedDoQUpstream builds an UpstreamConfig wired to a local
// malformed test server with the test certificate trusted via a custom
// cert pool. We bypass SetupBootstrapIP by setting BootstrapIP directly,
// so the pool dials 127.0.0.1 without any DNS lookup.
func newMalformedDoQUpstream(t *testing.T, cert *x509.Certificate, addr string) *UpstreamConfig {
t.Helper()
pool := x509.NewCertPool()
pool.AddCert(cert)
host, _, err := net.SplitHostPort(addr)
if err != nil {
t.Fatalf("split host/port %q: %v", addr, err)
}
uc := &UpstreamConfig{
Name: "doq-malformed",
Type: ResolverTypeDOQ,
Endpoint: addr,
Domain: host,
BootstrapIP: host,
Timeout: 2000,
}
uc.SetCertPool(pool)
return uc
}
// TestDoQResolve_MalformedResponse verifies that DoQ upstream
// responses violating RFC 9250 framing — fewer than 2 bytes, or a
// length prefix declaring more payload than was received — return a
// handled error instead of panicking on the length-prefix slice.
func TestDoQResolve_MalformedResponse(t *testing.T) {
tests := []struct {
name string
response []byte
}{
// Empty stream is already handled via io.EOF; locked in so a
// future change that drops that branch is caught.
{"empty response", nil},
// One byte: too short to hold the 2-octet length prefix.
{"single byte response", []byte{0x00}},
// Length prefix declares 16 bytes; payload is absent.
{"length prefix only", []byte{0x00, 0x10}},
// Length prefix declares 65535 bytes; only 1 byte of payload
// arrived.
{"length prefix larger than payload", []byte{0xFF, 0xFF, 0x00}},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := newMalformedDoQServer(t, tt.response)
uc := newMalformedDoQUpstream(t, server.cert, server.addr)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool := newDOQConnPool(ctx, uc, []string{"127.0.0.1"})
t.Cleanup(pool.CloseIdleConnections)
msg := new(dns.Msg)
msg.SetQuestion("example.com.", dns.TypeA)
msg.RecursionDesired = true
answer, err := pool.Resolve(ctx, msg)
if err == nil {
t.Fatalf("Resolve unexpectedly succeeded for malformed response %v; answer=%v", tt.response, answer)
}
if answer != nil {
t.Fatalf("Resolve returned non-nil answer alongside error: answer=%v err=%v", answer, err)
}
})
}
}