mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Add tests for MCP streamable HTTP route for json/sse and stateless/stateful servers.
This commit is contained in:
4
run.sh
4
run.sh
@@ -142,9 +142,9 @@ integration_tests() {
|
||||
|
||||
# Generate latest whl file for the invariant-gateway package.
|
||||
# This is required to run the integration tests.
|
||||
pip install build
|
||||
pip install build --quiet
|
||||
rm -rf dist
|
||||
python -m build
|
||||
python -m build > /dev/null 2>&1
|
||||
WHEEL_FILE=$(ls dist/*.whl | head -n 1)
|
||||
echo "WHEEL_FILE: $WHEEL_FILE"
|
||||
|
||||
|
||||
@@ -106,8 +106,8 @@ services:
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
# MCP SSE server used in integration tests
|
||||
mcp-messenger-sse-server:
|
||||
# MCP SSE server used in integration tests
|
||||
build:
|
||||
context: ${GATEWAY_ROOT_PATH}
|
||||
dockerfile: ${GATEWAY_ROOT_PATH}/tests/integration/resources/mcp/sse/messenger_server/Dockerfile.mcp-server
|
||||
@@ -117,11 +117,74 @@ services:
|
||||
ports:
|
||||
- "8123:8123"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8123/sse"]
|
||||
test: [ "CMD", "curl", "-f", "http://localhost:8123/sse" ]
|
||||
interval: 3s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
# MCP Streamable HTTP server with json_response=True, stateless_http=True
|
||||
# to use in integration tests
|
||||
mcp-messenger-streamable-json-stateless-server:
|
||||
build:
|
||||
context: ${GATEWAY_ROOT_PATH}
|
||||
dockerfile: ${GATEWAY_ROOT_PATH}/tests/integration/resources/mcp/streamable/messenger_server/Dockerfile.mcp-server
|
||||
container_name: invariant-gateway-test-mcp-streamable-json-stateless-server
|
||||
networks:
|
||||
- invariant-gateway-web-test
|
||||
ports:
|
||||
- "8124:8124"
|
||||
environment:
|
||||
PORT: 8124
|
||||
TRANSPORT: json
|
||||
STATEFUL: false
|
||||
|
||||
# MCP Streamable HTTP server with json_response=True, stateless_http=False
|
||||
# to use in integration tests
|
||||
mcp-messenger-streamable-json-stateful-server:
|
||||
build:
|
||||
context: ${GATEWAY_ROOT_PATH}
|
||||
dockerfile: ${GATEWAY_ROOT_PATH}/tests/integration/resources/mcp/streamable/messenger_server/Dockerfile.mcp-server
|
||||
container_name: invariant-gateway-test-mcp-streamable-json-stateful-server
|
||||
networks:
|
||||
- invariant-gateway-web-test
|
||||
ports:
|
||||
- "8125:8125"
|
||||
environment:
|
||||
PORT: 8125
|
||||
TRANSPORT: json
|
||||
STATEFUL: true
|
||||
|
||||
# MCP Streamable HTTP server with json_response=False, stateless_http=True
|
||||
# to use in integration tests
|
||||
mcp-messenger-streamable-sse-stateless-server:
|
||||
build:
|
||||
context: ${GATEWAY_ROOT_PATH}
|
||||
dockerfile: ${GATEWAY_ROOT_PATH}/tests/integration/resources/mcp/streamable/messenger_server/Dockerfile.mcp-server
|
||||
container_name: invariant-gateway-test-mcp-streamable-sse-stateless-server
|
||||
networks:
|
||||
- invariant-gateway-web-test
|
||||
ports:
|
||||
- "8126:8126"
|
||||
environment:
|
||||
PORT: 8126
|
||||
TRANSPORT: sse
|
||||
STATEFUL: false
|
||||
|
||||
# MCP Streamable HTTP server with json_response=False, stateless_http=False
|
||||
# to use in integration tests
|
||||
mcp-messenger-streamable-sse-stateful-server:
|
||||
build:
|
||||
context: ${GATEWAY_ROOT_PATH}
|
||||
dockerfile: ${GATEWAY_ROOT_PATH}/tests/integration/resources/mcp/streamable/messenger_server/Dockerfile.mcp-server
|
||||
container_name: invariant-gateway-test-mcp-streamable-sse-stateful-server
|
||||
networks:
|
||||
- invariant-gateway-web-test
|
||||
ports:
|
||||
- "8127:8127"
|
||||
environment:
|
||||
PORT: 8127
|
||||
TRANSPORT: sse
|
||||
STATEFUL: true
|
||||
|
||||
networks:
|
||||
invariant-gateway-web-test:
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
|
||||
from resources.mcp.sse.client.main import run as mcp_sse_client_run
|
||||
from resources.mcp.stdio.client.main import run as mcp_stdio_client_run
|
||||
from resources.mcp.streamable.client.main import run as mcp_streamable_client_run
|
||||
from utils import create_dataset, add_guardrail_to_dataset
|
||||
|
||||
import httpx
|
||||
@@ -13,13 +14,45 @@ import requests
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
# Taken from docker-compose.test.yml
|
||||
MCP_SSE_SERVER_HOST = "mcp-messenger-sse-server"
|
||||
MCP_SSE_SERVER_PORT = 8123
|
||||
MCP_STREAMABLE_HOSTS = {
|
||||
"streamable-json-stateless": {
|
||||
"host": "mcp-messenger-streamable-json-stateless-server",
|
||||
"port": 8124,
|
||||
},
|
||||
"streamable-json-stateful": {
|
||||
"host": "mcp-messenger-streamable-json-stateful-server",
|
||||
"port": 8125,
|
||||
},
|
||||
"streamable-sse-stateless": {
|
||||
"host": "mcp-messenger-streamable-sse-stateless-server",
|
||||
"port": 8126,
|
||||
},
|
||||
"streamable-sse-stateful": {
|
||||
"host": "mcp-messenger-streamable-sse-stateful-server",
|
||||
"port": 8127,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_headers(project_name: str, push_to_explorer: bool = True) -> dict[str, str]:
|
||||
def _get_mcp_sse_server_base_url() -> str:
|
||||
return f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}"
|
||||
|
||||
|
||||
def _get_streamable_server_base_url(transport: str) -> str:
|
||||
if transport not in MCP_STREAMABLE_HOSTS:
|
||||
raise ValueError(f"Unknown transport: {transport}")
|
||||
host_info = MCP_STREAMABLE_HOSTS[transport]
|
||||
return f"http://{host_info['host']}:{host_info['port']}"
|
||||
|
||||
|
||||
def _get_headers(
|
||||
server_base_url: str, project_name: str, push_to_explorer: bool = True
|
||||
) -> dict[str, str]:
|
||||
return {
|
||||
"MCP-SERVER-BASE-URL": f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
|
||||
"MCP-SERVER-BASE-URL": server_base_url,
|
||||
"INVARIANT-PROJECT-NAME": project_name,
|
||||
"PUSH-INVARIANT-EXPLORER": str(push_to_explorer),
|
||||
}
|
||||
@@ -28,19 +61,20 @@ def _get_headers(project_name: str, push_to_explorer: bool = True) -> dict[str,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize(
|
||||
"push_to_explorer, transport",
|
||||
"transport",
|
||||
[
|
||||
(False, "stdio"),
|
||||
(False, "sse"),
|
||||
(True, "stdio"),
|
||||
(True, "sse"),
|
||||
"stdio",
|
||||
"sse",
|
||||
"streamable-json-stateless",
|
||||
"streamable-json-stateful",
|
||||
"streamable-sse-stateless",
|
||||
"streamable-sse-stateful",
|
||||
],
|
||||
)
|
||||
async def test_mcp_with_gateway(
|
||||
explorer_api_url,
|
||||
invariant_gateway_package_whl_file,
|
||||
gateway_url,
|
||||
push_to_explorer,
|
||||
transport,
|
||||
):
|
||||
"""Test MCP gateway and verify trace is pushed to explorer"""
|
||||
@@ -50,21 +84,31 @@ async def test_mcp_with_gateway(
|
||||
if transport == "sse":
|
||||
result = await mcp_sse_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/sse",
|
||||
push_to_explorer=push_to_explorer,
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(project_name, push_to_explorer),
|
||||
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
|
||||
)
|
||||
else:
|
||||
elif transport == "stdio":
|
||||
result = await mcp_stdio_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
server_script_path="resources/mcp/stdio/messenger_server/main.py",
|
||||
push_to_explorer=push_to_explorer,
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
metadata_keys={"my-custom-key": "value1", "my-custom-key-2": "value2"},
|
||||
)
|
||||
else:
|
||||
result = await mcp_streamable_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/streamable",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(
|
||||
_get_streamable_server_base_url(transport), project_name, True
|
||||
),
|
||||
)
|
||||
|
||||
assert result.isError is False
|
||||
assert (
|
||||
@@ -72,37 +116,50 @@ async def test_mcp_with_gateway(
|
||||
and result.content[0].text == "What is your favorite food?\n"
|
||||
)
|
||||
|
||||
if push_to_explorer:
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
metadata = trace["extra_metadata"]
|
||||
assert (
|
||||
metadata["source"] == "mcp"
|
||||
and metadata["mcp_client"] == "mcp"
|
||||
and metadata["mcp_server"] == "messenger_server"
|
||||
)
|
||||
assert trace["messages"][2]["role"] == "assistant"
|
||||
assert trace["messages"][2]["tool_calls"][0]["function"] == {
|
||||
"name": "get_last_message_from_user",
|
||||
"arguments": {"username": "Alice"},
|
||||
}
|
||||
assert trace["messages"][3]["role"] == "tool"
|
||||
assert trace["messages"][3]["content"] == [
|
||||
{"type": "text", "text": "What is your favorite food?\n"}
|
||||
]
|
||||
# Fetch the trace
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
|
||||
metadata = trace["extra_metadata"]
|
||||
assert (
|
||||
metadata["source"] == "mcp"
|
||||
and metadata["mcp_client"] == "mcp"
|
||||
and metadata["mcp_server"] == "messenger_server"
|
||||
)
|
||||
if transport == "streamable-json-stateless":
|
||||
assert metadata["server_response_type"] == "json"
|
||||
assert metadata["is_stateless_http_server"] is True
|
||||
elif transport == "streamable-json-stateful":
|
||||
assert metadata["server_response_type"] == "json"
|
||||
assert metadata["is_stateless_http_server"] is False
|
||||
elif transport == "streamable-sse-stateless":
|
||||
assert metadata["server_response_type"] == "sse"
|
||||
assert metadata["is_stateless_http_server"] is True
|
||||
elif transport == "streamable-sse-stateful":
|
||||
assert metadata["server_response_type"] == "sse"
|
||||
assert metadata["is_stateless_http_server"] is False
|
||||
|
||||
assert trace["messages"][2]["role"] == "assistant"
|
||||
assert trace["messages"][2]["tool_calls"][0]["function"] == {
|
||||
"name": "get_last_message_from_user",
|
||||
"arguments": {"username": "Alice"},
|
||||
}
|
||||
assert trace["messages"][3]["role"] == "tool"
|
||||
assert trace["messages"][3]["content"] == [
|
||||
{"type": "text", "text": "What is your favorite food?\n"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -142,7 +199,7 @@ async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(project_name, True),
|
||||
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
|
||||
)
|
||||
else:
|
||||
result = await mcp_stdio_client_run(
|
||||
@@ -251,7 +308,9 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(project_name, True),
|
||||
headers=_get_headers(
|
||||
_get_mcp_sse_server_base_url(), project_name, True
|
||||
),
|
||||
)
|
||||
else:
|
||||
_ = await mcp_stdio_client_run(
|
||||
@@ -353,7 +412,9 @@ async def test_mcp_with_gateway_hybrid_guardrails(
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(project_name, True),
|
||||
headers=_get_headers(
|
||||
_get_mcp_sse_server_base_url(), project_name, True
|
||||
),
|
||||
)
|
||||
else:
|
||||
_ = await mcp_stdio_client_run(
|
||||
@@ -470,7 +531,7 @@ async def test_mcp_tool_list_blocking(
|
||||
push_to_explorer=True,
|
||||
tool_name="tools/list",
|
||||
tool_args={},
|
||||
headers=_get_headers(project_name, True),
|
||||
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
|
||||
)
|
||||
else:
|
||||
tools_result = await mcp_stdio_client_run(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.12-slim
|
||||
FROM python:3.13-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
101
tests/integration/resources/mcp/streamable/client/main.py
Normal file
101
tests/integration/resources/mcp/streamable/client/main.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""This is a simple example of how to use the MCP client with Streamable HTTP transport."""
|
||||
|
||||
# pylint: disable=E1101
|
||||
# pylint: disable=W0201
|
||||
# pylint: disable=C2801
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
|
||||
from typing import Any, Optional
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""MCP Client for interacting with a MCP Streamable HTTP server and processing queries"""
|
||||
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self._streams_context = None # Initialize these to None
|
||||
self._session_context = None # so they always exist
|
||||
|
||||
async def connect_to_streamable_server(
|
||||
self, server_url: str, headers: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Connect to an MCP server running with Streamable HTTP transport
|
||||
|
||||
Args:
|
||||
server_url: URL of the MCP server
|
||||
headers: Optional headers to include in the request
|
||||
"""
|
||||
# Store the context managers so they stay alive
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=server_url,
|
||||
headers=headers or {},
|
||||
timeout=timedelta(seconds=5),
|
||||
sse_read_timeout=timedelta(seconds=10),
|
||||
)
|
||||
read_stream, write_stream, session_id = await self._streams_context.__aenter__()
|
||||
self.session_id = session_id
|
||||
|
||||
self._session_context = ClientSession(read_stream, write_stream)
|
||||
self.session: ClientSession = await self._session_context.__aenter__()
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
async def cleanup(self):
|
||||
"""Properly clean up the session and streams"""
|
||||
if self._session_context:
|
||||
await self._session_context.__aexit__(None, None, None)
|
||||
if self._streams_context:
|
||||
await self._streams_context.__aexit__(None, None, None)
|
||||
|
||||
async def process_query(self, tool_name: str, tool_args: dict) -> str:
|
||||
"""Process a query using MCP server"""
|
||||
result = await self.session.call_tool(
|
||||
tool_name, tool_args, read_timeout_seconds=timedelta(seconds=10)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def run(
|
||||
gateway_url: str,
|
||||
push_to_explorer: bool,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
headers: dict[str, str] = None,
|
||||
):
|
||||
"""
|
||||
Run the MCP client with the given parameters.
|
||||
|
||||
Args:
|
||||
gateway_url: URL of the Invariant Gateway
|
||||
push_to_explorer: Whether to push traces to the Invariant Explorer
|
||||
tool_name: Name of the tool to call
|
||||
tool_args: Arguments for the tool call
|
||||
|
||||
"""
|
||||
client = MCPClient()
|
||||
try:
|
||||
await client.connect_to_streamable_server(
|
||||
server_url=gateway_url, headers=headers or {}
|
||||
)
|
||||
# list tools
|
||||
listed_tools = await client.session.list_tools()
|
||||
# call tool
|
||||
if tool_name == "tools/list":
|
||||
return listed_tools
|
||||
else:
|
||||
return await client.process_query(tool_name, tool_args)
|
||||
finally:
|
||||
# Sleep for a while to allow the server to process the background tasks
|
||||
# like pushing traces to the explorer
|
||||
if push_to_explorer:
|
||||
await asyncio.sleep(2)
|
||||
await client.cleanup()
|
||||
@@ -0,0 +1,18 @@
|
||||
FROM python:3.13-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy the messenger server code
|
||||
COPY tests/integration/resources/mcp/streamable/messenger_server /app/messenger_server
|
||||
|
||||
# Install dependencies
|
||||
RUN pip install --no-cache-dir "uvicorn[standard]" "httpx" "mcp[cli]" "starlette"
|
||||
|
||||
# Default values (will be overridden in compose)
|
||||
ENV HOST=0.0.0.0
|
||||
ENV PORT=8124
|
||||
ENV TRANSPORT=json
|
||||
ENV STATEFUL=false
|
||||
|
||||
# Use environment variables in the CMD
|
||||
CMD ["sh", "-c", "python messenger_server/main.py --host $HOST --port $PORT --transport $TRANSPORT --stateful $STATEFUL"]
|
||||
@@ -0,0 +1,70 @@
|
||||
"""This is a messenger server implementation that returns a few messages based on the username."""
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Read config from environment variables
|
||||
TRANSPORT = os.getenv("TRANSPORT", "json").lower()
|
||||
STATEFUL = os.getenv("STATEFUL", "false").lower() == "true"
|
||||
|
||||
# Initialize FastMCP server
|
||||
mcp = FastMCP(
|
||||
"messenger_server",
|
||||
json_response=(TRANSPORT == "json"),
|
||||
stateless_http=(not STATEFUL),
|
||||
)
|
||||
|
||||
|
||||
MESSAGES = [
|
||||
"What about you?",
|
||||
"What are you doing?",
|
||||
"What is your name?",
|
||||
"What is your favorite color?",
|
||||
"What is your favorite food?",
|
||||
"What is your favorite movie?",
|
||||
"What is your favorite book?",
|
||||
]
|
||||
|
||||
|
||||
def _deterministic_index_from_username(username: str, limit: int) -> int:
|
||||
"""Deterministically calculate the index of messages to return based on the username."""
|
||||
hash_val = int(hashlib.sha256(username.encode()).hexdigest(), 16)
|
||||
return hash_val % limit + 1
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_last_message_from_user(username: str) -> str:
|
||||
"""Get the last message sent by the username."""
|
||||
return MESSAGES[_deterministic_index_from_username(username, len(MESSAGES))] + "\n"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def send_message(username: str, message: str) -> str:
|
||||
"""Send a message to the username."""
|
||||
return f"Message '{message}' sent to {username}."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run MCP Streamable HTTP based server")
|
||||
parser.add_argument("--host", help="Host to bind to", required=True)
|
||||
parser.add_argument("--port", help="Port to listen on", required=True, type=int)
|
||||
parser.add_argument(
|
||||
"--transport",
|
||||
help="Transport type (json or sse)",
|
||||
default="json",
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stateful",
|
||||
help="Whether the server is stateful or stateless",
|
||||
default="false",
|
||||
type=str,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
uvicorn.run(mcp.streamable_http_app, host=args.host, port=args.port)
|
||||
Reference in New Issue
Block a user