diff --git a/README.md b/README.md index d035c81..8d5be39 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ This allows you to _observe and debug_ your agents in [Invariant Explorer](https - [x] **Single Line Setup**: Just change the base URL of your LLM provider to the Invariant Gateway. - [x] **Intercepts agents on an LLM-level** for better debugging and analysis. - [x] **Tool Calling and Computer Use Support** to capture all forms of agentic interactions. +- [x] **Model Context Protocol (MCP) Support** for both standard I/O (stdio) and Server-Sent Events (SSE) transports. - [x] **Seamless forwarding and LLM streaming** to OpenAI, Anthropic, and other LLM providers. - [x] **Store and organize runtime traces** in the [Invariant Explorer](https://explorer.invariantlabs.ai/). @@ -277,6 +278,30 @@ export ANTHROPIC_API_KEY={your-anthropic-api-key};invariant-auth={your-invariant This setup ensures that SWE-agent works seamlessly with Invariant Gateway, maintaining compatibility while enabling full functionality. 🚀 +### **Using MCP with Invariant Gateway** +Invariant Gateway supports MCP tool calling over both the stdio and SSE transports. + +For stdio-based MCP, follow the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main/gateway/mcp). + +For SSE-based MCP, follow these steps to point your MCP client at a local instance of the Invariant Gateway, which will then proxy all calls to the MCP server: + +* Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally).
async def add_pending_error_message(self, error_message: dict) -> None:
    """
    Queue an error message on the session.

    The message is appended to ``pending_error_messages`` while the
    session lock is held.

    Args:
        error_message: The error message to add
    """
    async with self.session_lock():
        # pylint: disable=no-member
        self.pending_error_messages.append(error_message)


async def get_pending_error_messages(self) -> List[dict]:
    """
    Drain the pending error messages of the session.

    While the session lock is held, the current messages are copied and
    ``pending_error_messages`` is reset to an empty list, so each message
    is returned by exactly one call.

    Returns:
        List[dict]: The messages that were pending, in insertion order
    """
    async with self.session_lock():
        # Copy first, then rebind the attribute to a fresh list; the copy
        # handed to the caller is unaffected by later appends.
        drained = [*self.pending_error_messages]
        self.pending_error_messages = []
    return drained
async def event_generator():
    """
    Generate a merged stream of MCP server events and pending error messages.

    Reads ``mcp_server_sse_endpoint``, ``filtered_headers``, ``query_params``
    and ``sse_header_attributes`` from the enclosing scope.

    Two background producers feed a single queue:

    * ``process_mcp_server_events`` connects to the MCP server, rewrites
      ``endpoint`` / ``message`` events through the gateway handlers and
      enqueues the resulting SSE frames.
    * ``_check_for_pending_error_messages`` (started once a session id is
      known) enqueues the error messages added by the POST messages handler.

    The generator yields frames as they arrive and finishes when the MCP
    server stream ends or fails, instead of polling forever.

    Yields:
        bytes: An encoded SSE frame to forward to the client.
    """
    # A single queue shared by both producers removes the need to create
    # and cancel a pair of `Queue.get()` tasks on every poll interval.
    outgoing_events = asyncio.Queue()
    # Unique marker enqueued by the MCP server producer when it is finished.
    end_of_stream = object()
    tasks = set()

    def start_background_task(coroutine):
        """Schedule a coroutine and track it so it can be cancelled on exit."""
        task = asyncio.create_task(coroutine)
        tasks.add(task)
        task.add_done_callback(tasks.discard)
        return task

    async def process_mcp_server_events():
        """Connect to MCP server and enqueue its processed events."""
        session_id = None
        # Explicit guard so the pending-error poller is started only once.
        pending_errors_task = None

        try:
            async with httpx.AsyncClient(
                timeout=httpx.Timeout(CLIENT_TIMEOUT)
            ) as client:
                async with aconnect_sse(
                    client,
                    "GET",
                    mcp_server_sse_endpoint,
                    headers=filtered_headers,
                    params=query_params,
                ) as event_source:
                    if event_source.response.status_code != 200:
                        error_content = await event_source.response.aread()
                        # Raised inside a background task: it is logged by
                        # the generic handler below and ends the stream.
                        raise HTTPException(
                            status_code=event_source.response.status_code,
                            detail=error_content,
                        )

                    async for sse in event_source.aiter_sse():
                        if sse.event == "endpoint":
                            (
                                event_bytes,
                                session_id,
                            ) = await _handle_endpoint_event(
                                sse, sse_header_attributes
                            )
                            if session_id and pending_errors_task is None:
                                pending_errors_task = start_background_task(
                                    _check_for_pending_error_messages(
                                        session_id, outgoing_events
                                    )
                                )
                        elif sse.event == "message" and session_id:
                            # Process message event
                            event_bytes = await _handle_message_event(
                                session_id, sse
                            )
                        else:
                            # Pass through other event types
                            # pylint: disable=line-too-long
                            event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(
                                "utf-8"
                            )

                        # Put the processed event in the queue
                        await outgoing_events.put(event_bytes)

        except httpx.StreamClosed as e:
            print(f"Server stream closed: {e}", flush=True)
        except Exception as e:  # pylint: disable=broad-except
            print(f"Error processing server events: {e}", flush=True)
        finally:
            # Always wake the consumer, whether the stream ended normally,
            # failed, or this task was cancelled. The queue is unbounded,
            # so put_nowait cannot raise QueueFull.
            outgoing_events.put_nowait(end_of_stream)

    try:
        # Start server events processor
        start_background_task(process_mcp_server_events())

        # Main event loop: forward frames until the MCP server stream is done
        while True:
            event = await outgoing_events.get()
            if event is end_of_stream:
                break
            yield event

    finally:
        # Clean up all tasks (snapshot: done-callbacks mutate the set)
        for task in list(tasks):
            task.cancel()

        # Wait for all tasks to complete
        if tasks:
            await asyncio.wait(tasks, timeout=2)
async def _check_for_pending_error_messages(
    session_id: str, pending_error_messages_queue: asyncio.Queue
):
    """
    Poll a session for pending error messages and enqueue them as SSE frames.

    Roughly once per second the session is looked up in ``session_store``,
    its pending error messages are fetched, and each one is encoded as an
    ``event: message`` SSE frame and put on the queue. Any exception in a
    polling round is printed and the loop carries on. The coroutine runs
    until it is cancelled, at which point it returns quietly.

    Args:
        session_id: Id of the session whose pending messages are polled
        pending_error_messages_queue: Queue that receives the encoded frames
    """
    try:
        while True:
            try:
                mcp_session = session_store.get_session(session_id)
                for error_message in await mcp_session.get_pending_error_messages():
                    frame = (
                        f"event: message\ndata: {json.dumps(error_message)}\n\n".encode(
                            "utf-8"
                        )
                    )
                    await pending_error_messages_queue.put(frame)
            except Exception as e:  # pylint: disable=broad-except
                print(f"Error checking for messages: {e}", flush=True)
            # Same one-second pause after a successful and a failed round.
            await asyncio.sleep(1)
    except asyncio.CancelledError:
        # Task was cancelled, exit gracefully
        return