mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-06-13 00:07:47 +02:00
e18c6b5bdb
* use select() before readline * support for setting static metadata for MCP sessions * nest extra mcp metadata in metadata object * unify session metadata * extra metadata tests * use empty object as parameters, if None * list_tools as tool call * offset indices in tests * test: adjust addresses * mcp: make error reporting configurable * line logging * log version * verbose logging + loud exception failure * add server and client name to policy get * append trace even if not pushing * port tools/list message support to SSE * use python -m build * adjust guardrail failure address * support for blocking tools/list in SSE * use error-based failure response format by default * tools/list test * don't list_tools in stdio connect * flaky test: handle second possible result in anthropic streaming case --------- Co-authored-by: knielsen404 <kristian@invariantlabs.ai>
113 lines
3.7 KiB
Python
"""This is a simple example of how to use the MCP client with SSE transport."""
|
|
|
|
# pylint: disable=E1101
|
|
# pylint: disable=W0201
|
|
|
|
import asyncio
|
|
from datetime import timedelta
|
|
|
|
from typing import Any, Optional
|
|
from contextlib import AsyncExitStack
|
|
|
|
from mcp import ClientSession
|
|
from mcp.client.sse import sse_client
|
|
|
|
|
|
class MCPClient:
    """MCP client for interacting with an MCP SSE server and processing queries."""

    def __init__(self):
        # Populated by connect_to_sse_server(); reset to None by cleanup().
        self.session: Optional[ClientSession] = None
        # Owns every async context entered during connect_to_sse_server(), so
        # cleanup() unwinds exactly what was entered, in reverse order.
        self.exit_stack = AsyncExitStack()
        # Kept so the raw context managers stay reachable; their lifetime is
        # managed by exit_stack, not by manual __aenter__/__aexit__ calls.
        self._streams_context = None
        self._session_context = None

    async def connect_to_sse_server(
        self, server_url: str, headers: Optional[dict] = None
    ):
        """
        Connect to an MCP server running with SSE transport and initialize
        the MCP session.

        Args:
            server_url: URL of the MCP server
            headers: Optional headers to include in the request

        Each context is registered with ``self.exit_stack`` only after its
        ``__aenter__`` succeeds. If this method fails part-way, ``cleanup()``
        exits only the contexts that were actually entered.
        """
        # timeout / sse_read_timeout are forwarded to sse_client as-is
        # (presumably seconds — confirm against the mcp SDK).
        self._streams_context = sse_client(
            url=server_url,
            timeout=5,
            headers=headers or {},
            sse_read_timeout=10,
        )
        streams = await self.exit_stack.enter_async_context(self._streams_context)

        self._session_context = ClientSession(*streams)
        self.session = await self.exit_stack.enter_async_context(
            self._session_context
        )

        await self.session.initialize()

    async def cleanup(self):
        """Clean up the session and streams.

        Exits the entered contexts in reverse order of entry (session first,
        then streams). Because AsyncExitStack drives the unwinding, the streams
        context is still exited if exiting the session raises; that exception
        then propagates to the caller. Calling this more than once, or without
        a prior successful connect, is a no-op for contexts never entered.
        """
        try:
            await self.exit_stack.aclose()
        finally:
            self.session = None
            self._session_context = None
            self._streams_context = None

    async def process_query(self, tool_name: str, tool_args: dict) -> Any:
        """Call a single tool on the connected MCP server.

        Args:
            tool_name: Name of the tool to call.
            tool_args: Arguments passed to the tool.

        Returns:
            The object returned by ``ClientSession.call_tool``, unchanged.

        Raises:
            RuntimeError: if the client is not connected (connect_to_sse_server()
                was not called, or cleanup() already ran).
        """
        if self.session is None:
            raise RuntimeError(
                "MCPClient is not connected; call connect_to_sse_server() first"
            )
        return await self.session.call_tool(
            tool_name, tool_args, read_timeout_seconds=timedelta(seconds=10)
        )
|
|
|
|
|
|
async def run(
    gateway_url: str,
    mcp_server_base_url: str,
    project_name: str,
    push_to_explorer: bool,
    tool_name: str,
    tool_args: dict[str, Any],
):
    """
    Connect to the gateway over SSE, list the server's tools, then either
    return that listing or call one tool.

    Args:
        gateway_url: URL of the Invariant Gateway
        mcp_server_base_url: Base URL of the MCP server
        project_name: Name of the project in Invariant Explorer
        push_to_explorer: Whether to push traces to the Invariant Explorer
        tool_name: Name of the tool to call, or "tools/list" to return the
            tool listing instead of calling a tool
        tool_args: Arguments for the tool call

    Returns:
        The result of ``session.list_tools()`` when ``tool_name`` is
        ``"tools/list"``, otherwise the result of ``MCPClient.process_query``.
    """
    mcp_client = MCPClient()
    try:
        gateway_headers = {
            "MCP-SERVER-BASE-URL": mcp_server_base_url,
            "INVARIANT-PROJECT-NAME": project_name,
            "PUSH-INVARIANT-EXPLORER": str(push_to_explorer),
        }
        await mcp_client.connect_to_sse_server(
            server_url=gateway_url, headers=gateway_headers
        )

        # NOTE: tools/list is requested unconditionally, even when another
        # tool is called afterwards — presumably so the gateway sees/records
        # it; confirm before removing.
        tools_listing = await mcp_client.session.list_tools()

        if tool_name != "tools/list":
            return await mcp_client.process_query(tool_name, tool_args)
        return tools_listing
    finally:
        # Give the server time to finish background work such as pushing
        # traces to the explorer before the connection is torn down.
        if push_to_explorer:
            await asyncio.sleep(2)
        await mcp_client.cleanup()
|