Add tests for MCP streamable HTTP route for json/sse and stateless/stateful servers.

This commit is contained in:
Hemang
2025-05-27 17:00:13 +02:00
committed by Hemang Sarkar
parent 34979ed18d
commit 115ae5f36b
10 changed files with 364 additions and 51 deletions

4
run.sh
View File

@@ -142,9 +142,9 @@ integration_tests() {
# Generate latest whl file for the invariant-gateway package.
# This is required to run the integration tests.
pip install build
pip install build --quiet
rm -rf dist
python -m build
python -m build > /dev/null 2>&1
WHEEL_FILE=$(ls dist/*.whl | head -n 1)
echo "WHEEL_FILE: $WHEEL_FILE"

View File

@@ -106,8 +106,8 @@ services:
timeout: 5s
retries: 5
# MCP SSE server used in integration tests
mcp-messenger-sse-server:
# MCP SSE server used in integration tests
build:
context: ${GATEWAY_ROOT_PATH}
dockerfile: ${GATEWAY_ROOT_PATH}/tests/integration/resources/mcp/sse/messenger_server/Dockerfile.mcp-server
@@ -117,11 +117,74 @@ services:
ports:
- "8123:8123"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8123/sse"]
test: [ "CMD", "curl", "-f", "http://localhost:8123/sse" ]
interval: 3s
timeout: 5s
retries: 5
# MCP Streamable HTTP server with json_response=True, stateless_http=True
# to use in integration tests
mcp-messenger-streamable-json-stateless-server:
build:
context: ${GATEWAY_ROOT_PATH}
dockerfile: ${GATEWAY_ROOT_PATH}/tests/integration/resources/mcp/streamable/messenger_server/Dockerfile.mcp-server
container_name: invariant-gateway-test-mcp-streamable-json-stateless-server
networks:
- invariant-gateway-web-test
ports:
- "8124:8124"
environment:
PORT: 8124
TRANSPORT: json
STATEFUL: false
# MCP Streamable HTTP server with json_response=True, stateless_http=False
# to use in integration tests
mcp-messenger-streamable-json-stateful-server:
build:
context: ${GATEWAY_ROOT_PATH}
dockerfile: ${GATEWAY_ROOT_PATH}/tests/integration/resources/mcp/streamable/messenger_server/Dockerfile.mcp-server
container_name: invariant-gateway-test-mcp-streamable-json-stateful-server
networks:
- invariant-gateway-web-test
ports:
- "8125:8125"
environment:
PORT: 8125
TRANSPORT: json
STATEFUL: true
# MCP Streamable HTTP server with json_response=False, stateless_http=True
# to use in integration tests
mcp-messenger-streamable-sse-stateless-server:
build:
context: ${GATEWAY_ROOT_PATH}
dockerfile: ${GATEWAY_ROOT_PATH}/tests/integration/resources/mcp/streamable/messenger_server/Dockerfile.mcp-server
container_name: invariant-gateway-test-mcp-streamable-sse-stateless-server
networks:
- invariant-gateway-web-test
ports:
- "8126:8126"
environment:
PORT: 8126
TRANSPORT: sse
STATEFUL: false
# MCP Streamable HTTP server with json_response=False, stateless_http=False
# to use in integration tests
mcp-messenger-streamable-sse-stateful-server:
build:
context: ${GATEWAY_ROOT_PATH}
dockerfile: ${GATEWAY_ROOT_PATH}/tests/integration/resources/mcp/streamable/messenger_server/Dockerfile.mcp-server
container_name: invariant-gateway-test-mcp-streamable-sse-stateful-server
networks:
- invariant-gateway-web-test
ports:
- "8127:8127"
environment:
PORT: 8127
TRANSPORT: sse
STATEFUL: true
networks:
invariant-gateway-web-test:

View File

@@ -5,6 +5,7 @@ import uuid
from resources.mcp.sse.client.main import run as mcp_sse_client_run
from resources.mcp.stdio.client.main import run as mcp_stdio_client_run
from resources.mcp.streamable.client.main import run as mcp_streamable_client_run
from utils import create_dataset, add_guardrail_to_dataset
import httpx
@@ -13,13 +14,45 @@ import requests
from mcp.shared.exceptions import McpError
# Taken from docker-compose.test.yml
MCP_SSE_SERVER_HOST = "mcp-messenger-sse-server"
MCP_SSE_SERVER_PORT = 8123
MCP_STREAMABLE_HOSTS = {
"streamable-json-stateless": {
"host": "mcp-messenger-streamable-json-stateless-server",
"port": 8124,
},
"streamable-json-stateful": {
"host": "mcp-messenger-streamable-json-stateful-server",
"port": 8125,
},
"streamable-sse-stateless": {
"host": "mcp-messenger-streamable-sse-stateless-server",
"port": 8126,
},
"streamable-sse-stateful": {
"host": "mcp-messenger-streamable-sse-stateful-server",
"port": 8127,
},
}
def _get_headers(project_name: str, push_to_explorer: bool = True) -> dict[str, str]:
def _get_mcp_sse_server_base_url() -> str:
return f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}"
def _get_streamable_server_base_url(transport: str) -> str:
if transport not in MCP_STREAMABLE_HOSTS:
raise ValueError(f"Unknown transport: {transport}")
host_info = MCP_STREAMABLE_HOSTS[transport]
return f"http://{host_info['host']}:{host_info['port']}"
def _get_headers(
server_base_url: str, project_name: str, push_to_explorer: bool = True
) -> dict[str, str]:
return {
"MCP-SERVER-BASE-URL": f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
"MCP-SERVER-BASE-URL": server_base_url,
"INVARIANT-PROJECT-NAME": project_name,
"PUSH-INVARIANT-EXPLORER": str(push_to_explorer),
}
@@ -28,19 +61,20 @@ def _get_headers(project_name: str, push_to_explorer: bool = True) -> dict[str,
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize(
"push_to_explorer, transport",
"transport",
[
(False, "stdio"),
(False, "sse"),
(True, "stdio"),
(True, "sse"),
"stdio",
"sse",
"streamable-json-stateless",
"streamable-json-stateful",
"streamable-sse-stateless",
"streamable-sse-stateful",
],
)
async def test_mcp_with_gateway(
explorer_api_url,
invariant_gateway_package_whl_file,
gateway_url,
push_to_explorer,
transport,
):
"""Test MCP gateway and verify trace is pushed to explorer"""
@@ -50,21 +84,31 @@ async def test_mcp_with_gateway(
if transport == "sse":
result = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
push_to_explorer=push_to_explorer,
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(project_name, push_to_explorer),
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
)
else:
elif transport == "stdio":
result = await mcp_stdio_client_run(
invariant_gateway_package_whl_file,
project_name,
server_script_path="resources/mcp/stdio/messenger_server/main.py",
push_to_explorer=push_to_explorer,
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
metadata_keys={"my-custom-key": "value1", "my-custom-key-2": "value2"},
)
else:
result = await mcp_streamable_client_run(
gateway_url + "/api/v1/gateway/mcp/streamable",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(
_get_streamable_server_base_url(transport), project_name, True
),
)
assert result.isError is False
assert (
@@ -72,37 +116,50 @@ async def test_mcp_with_gateway(
and result.content[0].text == "What is your favorite food?\n"
)
if push_to_explorer:
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
metadata = trace["extra_metadata"]
assert (
metadata["source"] == "mcp"
and metadata["mcp_client"] == "mcp"
and metadata["mcp_server"] == "messenger_server"
)
assert trace["messages"][2]["role"] == "assistant"
assert trace["messages"][2]["tool_calls"][0]["function"] == {
"name": "get_last_message_from_user",
"arguments": {"username": "Alice"},
}
assert trace["messages"][3]["role"] == "tool"
assert trace["messages"][3]["content"] == [
{"type": "text", "text": "What is your favorite food?\n"}
]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
metadata = trace["extra_metadata"]
assert (
metadata["source"] == "mcp"
and metadata["mcp_client"] == "mcp"
and metadata["mcp_server"] == "messenger_server"
)
if transport == "streamable-json-stateless":
assert metadata["server_response_type"] == "json"
assert metadata["is_stateless_http_server"] is True
elif transport == "streamable-json-stateful":
assert metadata["server_response_type"] == "json"
assert metadata["is_stateless_http_server"] is False
elif transport == "streamable-sse-stateless":
assert metadata["server_response_type"] == "sse"
assert metadata["is_stateless_http_server"] is True
elif transport == "streamable-sse-stateful":
assert metadata["server_response_type"] == "sse"
assert metadata["is_stateless_http_server"] is False
assert trace["messages"][2]["role"] == "assistant"
assert trace["messages"][2]["tool_calls"][0]["function"] == {
"name": "get_last_message_from_user",
"arguments": {"username": "Alice"},
}
assert trace["messages"][3]["role"] == "tool"
assert trace["messages"][3]["content"] == [
{"type": "text", "text": "What is your favorite food?\n"}
]
@pytest.mark.asyncio
@@ -142,7 +199,7 @@ async def test_mcp_with_gateway_and_logging_guardrails(
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(project_name, True),
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
)
else:
result = await mcp_stdio_client_run(
@@ -251,7 +308,9 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(project_name, True),
headers=_get_headers(
_get_mcp_sse_server_base_url(), project_name, True
),
)
else:
_ = await mcp_stdio_client_run(
@@ -353,7 +412,9 @@ async def test_mcp_with_gateway_hybrid_guardrails(
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(project_name, True),
headers=_get_headers(
_get_mcp_sse_server_base_url(), project_name, True
),
)
else:
_ = await mcp_stdio_client_run(
@@ -470,7 +531,7 @@ async def test_mcp_tool_list_blocking(
push_to_explorer=True,
tool_name="tools/list",
tool_args={},
headers=_get_headers(project_name, True),
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
)
else:
tools_result = await mcp_stdio_client_run(

View File

@@ -1,4 +1,4 @@
FROM python:3.12-slim
FROM python:3.13-alpine
WORKDIR /app

View File

@@ -0,0 +1,101 @@
"""This is a simple example of how to use the MCP client with Streamable HTTP transport."""
# pylint: disable=E1101
# pylint: disable=W0201
# pylint: disable=C2801
import asyncio
from datetime import timedelta
from typing import Any, Optional
from contextlib import AsyncExitStack
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
class MCPClient:
"""MCP Client for interacting with a MCP Streamable HTTP server and processing queries"""
def __init__(self):
# Initialize session and client objects
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
self._streams_context = None # Initialize these to None
self._session_context = None # so they always exist
async def connect_to_streamable_server(
self, server_url: str, headers: Optional[dict] = None
):
"""
Connect to an MCP server running with Streamable HTTP transport
Args:
server_url: URL of the MCP server
headers: Optional headers to include in the request
"""
# Store the context managers so they stay alive
self._streams_context = streamablehttp_client(
url=server_url,
headers=headers or {},
timeout=timedelta(seconds=5),
sse_read_timeout=timedelta(seconds=10),
)
read_stream, write_stream, session_id = await self._streams_context.__aenter__()
self.session_id = session_id
self._session_context = ClientSession(read_stream, write_stream)
self.session: ClientSession = await self._session_context.__aenter__()
await self.session.initialize()
async def cleanup(self):
"""Properly clean up the session and streams"""
if self._session_context:
await self._session_context.__aexit__(None, None, None)
if self._streams_context:
await self._streams_context.__aexit__(None, None, None)
async def process_query(self, tool_name: str, tool_args: dict) -> str:
"""Process a query using MCP server"""
result = await self.session.call_tool(
tool_name, tool_args, read_timeout_seconds=timedelta(seconds=10)
)
return result
async def run(
gateway_url: str,
push_to_explorer: bool,
tool_name: str,
tool_args: dict[str, Any],
headers: dict[str, str] = None,
):
"""
Run the MCP client with the given parameters.
Args:
gateway_url: URL of the Invariant Gateway
push_to_explorer: Whether to push traces to the Invariant Explorer
tool_name: Name of the tool to call
tool_args: Arguments for the tool call
"""
client = MCPClient()
try:
await client.connect_to_streamable_server(
server_url=gateway_url, headers=headers or {}
)
# list tools
listed_tools = await client.session.list_tools()
# call tool
if tool_name == "tools/list":
return listed_tools
else:
return await client.process_query(tool_name, tool_args)
finally:
# Sleep for a while to allow the server to process the background tasks
# like pushing traces to the explorer
if push_to_explorer:
await asyncio.sleep(2)
await client.cleanup()

View File

@@ -0,0 +1,18 @@
FROM python:3.13-alpine
WORKDIR /app
# Copy the messenger server code
COPY tests/integration/resources/mcp/streamable/messenger_server /app/messenger_server
# Install dependencies
RUN pip install --no-cache-dir "uvicorn[standard]" "httpx" "mcp[cli]" "starlette"
# Default values (will be overridden in compose)
ENV HOST=0.0.0.0
ENV PORT=8124
ENV TRANSPORT=json
ENV STATEFUL=false
# Use environment variables in the CMD
CMD ["sh", "-c", "python messenger_server/main.py --host $HOST --port $PORT --transport $TRANSPORT --stateful $STATEFUL"]

View File

@@ -0,0 +1,70 @@
"""This is a messenger server implementation that returns a few messages based on the username."""
import argparse
import hashlib
import os
import uvicorn
from mcp.server.fastmcp import FastMCP
# Read config from environment variables
TRANSPORT = os.getenv("TRANSPORT", "json").lower()
STATEFUL = os.getenv("STATEFUL", "false").lower() == "true"
# Initialize FastMCP server
mcp = FastMCP(
"messenger_server",
json_response=(TRANSPORT == "json"),
stateless_http=(not STATEFUL),
)
MESSAGES = [
"What about you?",
"What are you doing?",
"What is your name?",
"What is your favorite color?",
"What is your favorite food?",
"What is your favorite movie?",
"What is your favorite book?",
]
def _deterministic_index_from_username(username: str, limit: int) -> int:
"""Deterministically calculate the index of messages to return based on the username."""
hash_val = int(hashlib.sha256(username.encode()).hexdigest(), 16)
return hash_val % limit + 1
@mcp.tool()
async def get_last_message_from_user(username: str) -> str:
"""Get the last message sent by the username."""
return MESSAGES[_deterministic_index_from_username(username, len(MESSAGES))] + "\n"
@mcp.tool()
async def send_message(username: str, message: str) -> str:
"""Send a message to the username."""
return f"Message '{message}' sent to {username}."
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run MCP Streamable HTTP based server")
parser.add_argument("--host", help="Host to bind to", required=True)
parser.add_argument("--port", help="Port to listen on", required=True, type=int)
parser.add_argument(
"--transport",
help="Transport type (json or sse)",
default="json",
type=str,
)
parser.add_argument(
"--stateful",
help="Whether the server is stateful or stateless",
default="false",
type=str,
)
args = parser.parse_args()
uvicorn.run(mcp.streamable_http_app, host=args.host, port=args.port)