Files
invariant-gateway/tests/integration/resources/mcp/sse/client/main.py
T
Luca Beurer-Kellner e18c6b5bdb Add an option to add extra metadata that is pushed and passed to Guardrails during an MCP session (#47)
* use select() before readline

* support for setting static metadata for MCP sessions

* nest extra mcp metadata in metadata object

* unify session metadata

* extra metadata tests

* use empty object as parameters, if None

* list_tools as tool call

* offset indices in tests

* test: adjust addresses

* mcp: make error reporting configurable

* line logging

* log version

* verbose logging + loud exception failure

* add server and client name to policy get

* append trace even if not pushing

* port tools/list message support to SSE

* use python -m build

* adjust guardrail failure address

* support for blocking tools/list in SSE

* use error-based failure response format by default

* tools/list test

* don't list_tools in stdio connect

* flaky test: handle second possible result in anthropic streaming case

---------

Co-authored-by: knielsen404 <kristian@invariantlabs.ai>
2025-05-19 13:44:37 +02:00

113 lines
3.7 KiB
Python

"""This is a simple example of how to use the MCP client with SSE transport."""
# pylint: disable=E1101
# pylint: disable=W0201
import asyncio
from datetime import timedelta
from typing import Any, Optional
from contextlib import AsyncExitStack
from mcp import ClientSession
from mcp.client.sse import sse_client
class MCPClient:
    """MCP client for interacting with an MCP SSE server and processing queries."""

    def __init__(self):
        # The active MCP session; None until connect_to_sse_server() succeeds
        # and again after cleanup().
        self.session: Optional[ClientSession] = None
        # Owns every async context entered by connect_to_sse_server() so that
        # cleanup() unwinds them in reverse (LIFO) order, and still exits the
        # remaining ones if an earlier exit raises.
        self.exit_stack = AsyncExitStack()
        # Kept so these attributes always exist (the original API exposed
        # them); they reference the most recently created context managers.
        self._streams_context = None
        self._session_context = None

    async def connect_to_sse_server(
        self,
        server_url: str,
        headers: Optional[dict] = None,
        timeout: float = 5,
        sse_read_timeout: float = 10,
    ):
        """
        Connect to an MCP server running with SSE transport and initialize
        the MCP session.

        Args:
            server_url: URL of the MCP server
            headers: Optional headers to include in the request
            timeout: Forwarded to ``sse_client`` as ``timeout`` (default 5).
            sse_read_timeout: Forwarded to ``sse_client`` as
                ``sse_read_timeout`` (default 10).
        """
        self._streams_context = sse_client(
            url=server_url,
            timeout=timeout,
            headers=headers or {},
            sse_read_timeout=sse_read_timeout,
        )
        # enter_async_context registers the exit callback only after
        # __aenter__ succeeds, so cleanup() never exits a context that was
        # not actually entered.
        streams = await self.exit_stack.enter_async_context(self._streams_context)
        self._session_context = ClientSession(*streams)
        self.session = await self.exit_stack.enter_async_context(
            self._session_context
        )
        await self.session.initialize()

    async def cleanup(self):
        """Close the session and then the SSE streams.

        Every entered context is exited even if a later-entered one fails on
        exit (the first error is re-raised afterwards). Safe to call more than
        once, and on a client that never connected.
        """
        try:
            await self.exit_stack.aclose()
        finally:
            # Drop references so a closed session cannot be used by mistake.
            self.session = None
            self._session_context = None
            self._streams_context = None

    async def process_query(
        self, tool_name: str, tool_args: dict, read_timeout_seconds: float = 10
    ) -> Any:
        """Call ``tool_name`` with ``tool_args`` on the connected MCP server.

        Args:
            tool_name: Name of the tool to call.
            tool_args: Arguments for the tool call.
            read_timeout_seconds: Per-request read timeout in seconds
                (default 10).

        Returns:
            Whatever ``ClientSession.call_tool`` returns, unchanged.

        Raises:
            RuntimeError: if called before connect_to_sse_server() or after
                cleanup().
        """
        if self.session is None:
            raise RuntimeError(
                "MCPClient is not connected; call connect_to_sse_server() first"
            )
        return await self.session.call_tool(
            tool_name,
            tool_args,
            read_timeout_seconds=timedelta(seconds=read_timeout_seconds),
        )
async def run(
    gateway_url: str,
    mcp_server_base_url: str,
    project_name: str,
    push_to_explorer: bool,
    tool_name: str,
    tool_args: dict[str, Any],
):
    """
    Connect to the gateway over SSE, list the tools, then either return that
    listing or call the requested tool.

    Args:
        gateway_url: URL of the Invariant Gateway
        mcp_server_base_url: Base URL of the MCP server
        project_name: Name of the project in Invariant Explorer
        push_to_explorer: Whether to push traces to the Invariant Explorer
        tool_name: Name of the tool to call, or "tools/list" to return the
            tool listing itself
        tool_args: Arguments for the tool call

    Returns:
        The ``list_tools`` result when ``tool_name == "tools/list"``,
        otherwise the result of ``MCPClient.process_query``.
    """
    mcp_client = MCPClient()
    try:
        gateway_headers = {
            "MCP-SERVER-BASE-URL": mcp_server_base_url,
            "INVARIANT-PROJECT-NAME": project_name,
            "PUSH-INVARIANT-EXPLORER": str(push_to_explorer),
        }
        await mcp_client.connect_to_sse_server(
            server_url=gateway_url, headers=gateway_headers
        )
        # Tools are always listed first, even when another tool is requested.
        # NOTE(review): presumably so every session trace starts with a
        # tools/list step — confirm against the gateway tests.
        tools_listing = await mcp_client.session.list_tools()
        if tool_name != "tools/list":
            return await mcp_client.process_query(tool_name, tool_args)
        return tools_listing
    finally:
        # Give the server time to finish background work (such as pushing
        # traces to the explorer) before tearing the connection down.
        if push_to_explorer:
            await asyncio.sleep(2)
        await mcp_client.cleanup()