mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-05-27 12:52:27 +02:00
doq: validate DNS-over-QUIC response framing
DoQ responses are length-prefixed per RFC 9250. The resolver previously assumed the stream always contained at least two bytes and unpacked from buf[2:], which could panic on truncated or malicious replies. Validate the prefix against the bytes read, return a clear error, and retire the connection from the pool on framing failure. Unpack only the slice declared by the prefix so a short read cannot be misinterpreted as a full message. Add regression coverage with a small test server that returns malformed raw payloads (empty, one byte, prefix-only, prefix larger than payload).
This commit is contained in:
committed by
Cuong Manh Le
parent
65d3d468f7
commit
7b360288ed
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
@@ -181,22 +182,37 @@ func (p *doqConnPool) doResolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, er
|
||||
buf, err := io.ReadAll(stream)
|
||||
stream.Close()
|
||||
|
||||
// Return connection to pool (mark as potentially bad if error occurred)
|
||||
isGood := err == nil && len(buf) > 0
|
||||
p.putConn(conn, isGood)
|
||||
|
||||
if err != nil {
|
||||
p.putConn(conn, false)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// io.ReadAll hides io.EOF error, so check for empty buffer
|
||||
// io.ReadAll hides io.EOF error, so check for empty buffer.
|
||||
if len(buf) == 0 {
|
||||
p.putConn(conn, false)
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
// Unpack DNS response (skip 2-byte length prefix)
|
||||
// RFC 9250: each DoQ DNS message is encoded as a 2-octet length field
|
||||
// followed by the DNS message. Reject responses that are shorter than
|
||||
// the prefix or whose prefix declares more bytes than were received,
|
||||
// and retire the misbehaving connection. Without this guard, buf[2:]
|
||||
// would panic when len(buf) < 2.
|
||||
if len(buf) < 2 {
|
||||
p.putConn(conn, false)
|
||||
return nil, fmt.Errorf("malformed DoQ response: %d byte(s), need >= 2 for length prefix", len(buf))
|
||||
}
|
||||
respLen := int(buf[0])<<8 | int(buf[1])
|
||||
if 2+respLen > len(buf) {
|
||||
p.putConn(conn, false)
|
||||
return nil, fmt.Errorf("malformed DoQ response: length prefix %d exceeds payload %d", respLen, len(buf)-2)
|
||||
}
|
||||
|
||||
p.putConn(conn, true)
|
||||
|
||||
// Unpack DNS response (skip 2-byte length prefix).
|
||||
answer := new(dns.Msg)
|
||||
if err := answer.Unpack(buf[2:]); err != nil {
|
||||
if err := answer.Unpack(buf[2 : 2+respLen]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
answer.SetReply(msg)
|
||||
|
||||
+164
-1
@@ -1,4 +1,3 @@
|
||||
// test_helpers.go
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
@@ -8,6 +7,7 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
@@ -222,3 +222,166 @@ func (s *testQUICServer) handleStream(t *testing.T, stream *quic.Stream) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// malformedDoQServer is a test QUIC server that drains the client's DoQ
|
||||
// request and writes caller-supplied raw bytes back. The bytes are not
|
||||
// required to be a well-framed DoQ response, which is what lets the
|
||||
// regression tests exercise malformed-response handling.
|
||||
type malformedDoQServer struct {
|
||||
listener *quic.Listener
|
||||
cert *x509.Certificate
|
||||
addr string
|
||||
response []byte
|
||||
}
|
||||
|
||||
func newMalformedDoQServer(t *testing.T, response []byte) *malformedDoQServer {
|
||||
t.Helper()
|
||||
|
||||
testCert := generateTestCertificate(t)
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{testCert.tlsCert},
|
||||
NextProtos: []string{"doq"},
|
||||
}
|
||||
|
||||
listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create QUIC listener: %v", err)
|
||||
}
|
||||
|
||||
s := &malformedDoQServer{
|
||||
listener: listener,
|
||||
cert: testCert.cert,
|
||||
addr: listener.Addr().String(),
|
||||
response: response,
|
||||
}
|
||||
|
||||
go s.serve(t)
|
||||
t.Cleanup(func() { _ = listener.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *malformedDoQServer) serve(t *testing.T) {
|
||||
for {
|
||||
conn, err := s.listener.Accept(context.Background())
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "server closed") {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
go s.handleConn(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *malformedDoQServer) handleConn(t *testing.T, conn *quic.Conn) {
|
||||
for {
|
||||
stream, err := conn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go s.handleStream(t, stream)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *malformedDoQServer) handleStream(t *testing.T, stream *quic.Stream) {
|
||||
defer stream.Close()
|
||||
|
||||
// Drain the client's DoQ-framed request so the client's writes complete
|
||||
// cleanly before we reply with our attacker-controlled bytes. Using
|
||||
// io.ReadFull because a single Read on a QUIC stream may return short.
|
||||
lenBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(stream, lenBuf); err != nil {
|
||||
return
|
||||
}
|
||||
msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1])
|
||||
if msgLen > 0 {
|
||||
discard := make([]byte, msgLen)
|
||||
if _, err := io.ReadFull(stream, discard); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.response) > 0 {
|
||||
_, _ = stream.Write(s.response)
|
||||
}
|
||||
}
|
||||
|
||||
// newMalformedDoQUpstream builds an UpstreamConfig wired to a local
|
||||
// malformed test server with the test certificate trusted via a custom
|
||||
// cert pool. We bypass SetupBootstrapIP by setting BootstrapIP directly,
|
||||
// so the pool dials 127.0.0.1 without any DNS lookup.
|
||||
func newMalformedDoQUpstream(t *testing.T, cert *x509.Certificate, addr string) *UpstreamConfig {
|
||||
t.Helper()
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert)
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("split host/port %q: %v", addr, err)
|
||||
}
|
||||
|
||||
uc := &UpstreamConfig{
|
||||
Name: "doq-malformed",
|
||||
Type: ResolverTypeDOQ,
|
||||
Endpoint: addr,
|
||||
Domain: host,
|
||||
BootstrapIP: host,
|
||||
Timeout: 2000,
|
||||
}
|
||||
uc.SetCertPool(pool)
|
||||
return uc
|
||||
}
|
||||
|
||||
// TestDoQResolve_MalformedResponse verifies that DoQ upstream
|
||||
// responses violating RFC 9250 framing — fewer than 2 bytes, or a
|
||||
// length prefix declaring more payload than was received — return a
|
||||
// handled error instead of panicking on the length-prefix slice.
|
||||
func TestDoQResolve_MalformedResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response []byte
|
||||
}{
|
||||
// Empty stream is already handled via io.EOF; locked in so a
|
||||
// future change that drops that branch is caught.
|
||||
{"empty response", nil},
|
||||
|
||||
// One byte: too short to hold the 2-octet length prefix.
|
||||
{"single byte response", []byte{0x00}},
|
||||
|
||||
// Length prefix declares 16 bytes; payload is absent.
|
||||
{"length prefix only", []byte{0x00, 0x10}},
|
||||
|
||||
// Length prefix declares 65535 bytes; only 1 byte of payload
|
||||
// arrived.
|
||||
{"length prefix larger than payload", []byte{0xFF, 0xFF, 0x00}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := newMalformedDoQServer(t, tt.response)
|
||||
uc := newMalformedDoQUpstream(t, server.cert, server.addr)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pool := newDOQConnPool(ctx, uc, []string{"127.0.0.1"})
|
||||
t.Cleanup(pool.CloseIdleConnections)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com.", dns.TypeA)
|
||||
msg.RecursionDesired = true
|
||||
|
||||
answer, err := pool.Resolve(ctx, msg)
|
||||
if err == nil {
|
||||
t.Fatalf("Resolve unexpectedly succeeded for malformed response %v; answer=%v", tt.response, answer)
|
||||
}
|
||||
if answer != nil {
|
||||
t.Fatalf("Resolve returned non-nil answer alongside error: answer=%v err=%v", answer, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user