When a tool call is blocked in the MCP POST method, add the error message to the session's list of pending error messages. The MCP SSE GET endpoint creates two queues, one for MCP server events and one for these pending error messages, and merges them into the single stream of events returned to the client.

This commit is contained in:
Hemang
2025-05-09 10:47:08 +05:30
parent 794aae0326
commit edd9fd9a5c
4 changed files with 220 additions and 66 deletions
+25
View File
@@ -24,6 +24,7 @@ This allows you to _observe and debug_ your agents in [Invariant Explorer](https
- [x] **Single Line Setup**: Just change the base URL of your LLM provider to the Invariant Gateway.
- [x] **Intercepts agents on an LLM-level** for better debugging and analysis.
- [x] **Tool Calling and Computer Use Support** to capture all forms of agentic interactions.
- [x] **MCP Protocol Support** for both standard I/O and Server-Sent Events (SSE) transports.
- [x] **Seamless forwarding and LLM streaming** to OpenAI, Anthropic, and other LLM providers.
- [x] **Store and organize runtime traces** in the [Invariant Explorer](https://explorer.invariantlabs.ai/).
@@ -277,6 +278,30 @@ export ANTHROPIC_API_KEY={your-anthropic-api-key};invariant-auth={your-invariant
This setup ensures that SWE-agent works seamlessly with Invariant Gateway, maintaining compatibility while enabling full functionality. 🚀
### **Using MCP with Invariant Gateway**
Invariant Gateway supports MCP tool calling over both the stdio and SSE transports.
For stdio-transport-based MCP, follow the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main/gateway/mcp).
For SSE transport based MCP, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server:
* Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally).
* Use the following configuration to connect to the local Gateway instance:
```python
await client.connect_to_sse_server(
server_url="http://localhost:8005/api/v1/gateway/mcp/sse",
headers={
"MCP-SERVER-BASE-URL": "<The base URL to your MCP server>",
"INVARIANT-PROJECT-NAME": "<The Invariant dataset name>",
"PUSH-INVARIANT-EXPLORER": "true",
},
)
```
If no `INVARIANT-PROJECT-NAME` header is specified but `PUSH-INVARIANT-EXPLORER` is set to "true", a new Invariant project will be created and the MCP traces will be pushed there.
You can also configure blocking or logging guardrails for the project by visiting the Explorer.
---
## **Run the Gateway Locally**
+28 -2
View File
@@ -43,6 +43,9 @@ class McpSession(BaseModel):
blocking_guardrails=[], logging_guardrails=[]
)
)
# When tool calls are blocked, the error message is stored here
# and sent to the client via the SSE stream.
pending_error_messages: List[dict] = Field(default_factory=list)
# Lock to maintain in-order pushes to explorer
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
@@ -214,6 +217,29 @@ class McpSession(BaseModel):
except Exception as e: # pylint: disable=broad-except
print(f"[MCP SSE] Error pushing trace for session {self.session_id}: {e}")
async def add_pending_error_message(self, error_message: dict) -> None:
    """
    Record an error message on this session as pending.

    The message is appended to `pending_error_messages` while the
    session lock is held.

    Args:
        error_message: The error message to append
    """
    async with self.session_lock():
        pending = self.pending_error_messages
        # pylint: disable=no-member
        pending.append(error_message)
async def get_pending_error_messages(self) -> List[dict]:
    """
    Drain the pending error messages of this session.

    While the session lock is held, the current messages are copied into
    a new list and the session's own list is reset to empty.

    Returns:
        List[dict]: The messages that were pending when the call was made
    """
    async with self.session_lock():
        drained = [*self.pending_error_messages]
        self.pending_error_messages = []
    return drained
class SseHeaderAttributes(BaseModel):
"""
@@ -235,8 +261,8 @@ class SseHeaderAttributes(BaseModel):
SseHeaderAttributes: An instance with values extracted from headers
"""
# Extract and process header values
project_name = headers.get("PROJECT-NAME")
push_explorer_header = headers.get("PUSH-EXPLORER", "false").lower()
project_name = headers.get("INVARIANT-PROJECT-NAME")
push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower()
# Determine explorer_dataset
if project_name:
-1
View File
@@ -26,7 +26,6 @@ from gateway.mcp.mcp_context import McpContext
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
UTF_8_ENCODING = "utf-8"
MCP_INITIALIZE = "initialize"
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
+167 -63
View File
@@ -92,17 +92,10 @@ async def mcp_post_gateway(
session_id=session_id, request_json=request_json
)
if is_blocked:
# If blocked, hook_tool_call_result contains the block message.
# Forward the block message result back to the caller.
# The original request is not passed to the MCP process.
return Response(
content=json.dumps(hook_tool_call_result),
status_code=403,
headers={
"X-Proxied-By": "mcp-gateway",
"Content-Type": "application/json",
},
)
# Add the error message to the session.
# The error message is sent back to the client using the SSE stream.
await session.add_pending_error_message(hook_tool_call_result)
return Response(content="Accepted", status_code=202)
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
try:
@@ -150,60 +143,139 @@ async def mcp_get_sse_gateway(
query_params = dict(request.query_params)
response_headers = {}
filtered_headers = {
k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_SSE_HEADERS
}
sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)
async def event_generator():
async with httpx.AsyncClient(
timeout=httpx.Timeout(CLIENT_TIMEOUT),
headers={
k: v
for k, v in request.headers.items()
if k.lower() in MCP_SERVER_SSE_HEADERS
},
) as client:
try:
async with aconnect_sse(
client,
"GET",
mcp_server_sse_endpoint,
params=query_params,
) as event_source:
if event_source.response.status_code != 200:
error_content = await event_source.response.aread()
raise HTTPException(
status_code=event_source.response.status_code,
detail=error_content,
)
"""
Generate a merged stream of MCP server events and pending error messages.
The pending error messages are added in the POST messages handler.
This function runs in a loop, yielding events as they arrive.
"""
mcp_server_events_queue = asyncio.Queue()
pending_error_messages_queue = asyncio.Queue()
tasks = set()
session_id = None
session_id = None
try:
# MCP Server Events Processor
async def process_mcp_server_events():
"""Connect to MCP server and process its events."""
nonlocal session_id
async for sse in event_source.aiter_sse():
event_bytes = (
f"event: {sse.event}\ndata: {sse.data}\n\n".encode("utf-8")
)
match sse.event:
case "endpoint":
(
event_bytes,
session_id,
) = await _handle_endpoint_event(
sse,
sse_header_attributes=SseHeaderAttributes.from_request_headers(
request.headers
),
async with httpx.AsyncClient(
timeout=httpx.Timeout(CLIENT_TIMEOUT)
) as client:
try:
async with aconnect_sse(
client,
"GET",
mcp_server_sse_endpoint,
headers=filtered_headers,
params=query_params,
) as event_source:
if event_source.response.status_code != 200:
error_content = await event_source.response.aread()
raise HTTPException(
status_code=event_source.response.status_code,
detail=error_content,
)
case "message":
if session_id:
event_bytes = await _handle_message_event(
session_id=session_id, sse=sse
)
yield event_bytes
except httpx.StreamClosed as e:
print(f"[MCP SSE] Stream closed: {str(e)}", flush=True)
except httpx.RequestError as e:
print(f"[MCP SSE] Request error: {str(e)}", flush=True)
except Exception as e: # pylint: disable=broad-except
print(f"[MCP SSE] Unexpected error: {str(e)}", flush=True)
async for sse in event_source.aiter_sse():
if sse.event == "endpoint":
(
event_bytes,
extracted_id,
) = await _handle_endpoint_event(
sse, sse_header_attributes
)
session_id = extracted_id
if (
session_id
and "process_error_messages_task"
not in locals()
):
process_error_messages_task = (
asyncio.create_task(
_check_for_pending_error_messages(
session_id,
pending_error_messages_queue,
)
)
)
tasks.add(process_error_messages_task)
process_error_messages_task.add_done_callback(
tasks.discard
)
elif sse.event == "message" and session_id:
# Process message event
event_bytes = await _handle_message_event(
session_id, sse
)
else:
# Pass through other event types
# pylint: disable=line-too-long
event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(
"utf-8"
)
# Put the processed event in the queue
await mcp_server_events_queue.put(event_bytes)
except httpx.StreamClosed as e:
print(f"Server stream closed: {e}", flush=True)
except Exception as e:
print(f"Error processing server events: {e}", flush=True)
# Start server events processor
mcp_server_events_task = asyncio.create_task(process_mcp_server_events())
tasks.add(mcp_server_events_task)
mcp_server_events_task.add_done_callback(tasks.discard)
# Main event loop: merge MCP server events and pending error messages
while True:
# Create futures for both queues
mcp_server_event_future = asyncio.create_task(
mcp_server_events_queue.get()
)
pending_error_message_future = asyncio.create_task(
pending_error_messages_queue.get()
)
# Wait for either queue to have an item, with timeout
done, pending = await asyncio.wait(
[mcp_server_event_future, pending_error_message_future],
return_when=asyncio.FIRST_COMPLETED,
timeout=0.25,
)
for future in pending:
future.cancel()
# Timeout occurred and no future completed.
if not done:
continue
for future in done:
try:
event = await future
yield event
except asyncio.CancelledError:
# Future was cancelled, continue
continue
finally:
# Clean up all tasks
for task in tasks:
task.cancel()
# Wait for all tasks to complete
if tasks:
await asyncio.wait(tasks, timeout=2)
# Return the streaming response
return StreamingResponse(
@@ -286,11 +358,15 @@ async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict
}
result = response_json
session = session_store.get_session(session_id)
guardrailing_result = await session.get_guardrails_check_result(
guardrails_result = await session.get_guardrails_check_result(
message, action=GuardrailAction.BLOCK
)
if guardrailing_result and guardrailing_result.get("errors", []):
if (
guardrails_result
and guardrails_result.get("errors", [])
and _check_if_new_errors(session_id, guardrails_result)
):
# If the request is blocked, return a message indicating the block reason.
result = {
"jsonrpc": "2.0",
@@ -298,12 +374,12 @@ async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict
"error": {
"code": -32600,
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
% guardrailing_result["errors"],
% guardrails_result["errors"],
},
}
# Push trace to the explorer - don't block on its response
asyncio.create_task(
session_store.add_message_to_session(session_id, message, guardrailing_result)
session_store.add_message_to_session(session_id, message, guardrails_result)
)
return result
@@ -391,6 +467,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
# Update the event bytes with hook_tool_call_response.
# hook_tool_call_response is same as response_json if no guardrail is violated.
# If guardrail is violated, it contains the error message.
# pylint: disable=line-too-long
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
"utf-8"
)
@@ -421,3 +498,30 @@ def _check_if_new_errors(session_id: str, guardrails_result: dict) -> bool:
if annotation not in session.annotations:
return True
return False
async def _check_for_pending_error_messages(
    session_id: str, pending_error_messages_queue: asyncio.Queue
):
    """
    Poll a session for pending error messages and enqueue them as SSE bytes.

    Once per second, the session is looked up in `session_store`, its
    pending error messages are drained, and each one is encoded as an
    `event: message` frame and put on `pending_error_messages_queue`.
    Any exception raised while polling is printed and the loop continues.
    Cancelling the task ends the loop without raising.

    Args:
        session_id: The id used to look up the session in `session_store`
        pending_error_messages_queue: The queue that receives the encoded
            error events
    """
    try:
        while True:
            try:
                current_session = session_store.get_session(session_id)
                for pending in await current_session.get_pending_error_messages():
                    payload = json.dumps(pending)
                    frame = f"event: message\ndata: {payload}\n\n".encode("utf-8")
                    await pending_error_messages_queue.put(frame)
            except Exception as exc:  # pylint: disable=broad-except
                print(f"Error checking for messages: {exc}", flush=True)
            # Wait one second between polls, whether or not the poll succeeded.
            await asyncio.sleep(1)
    except asyncio.CancelledError:
        # Task was cancelled, exit gracefully
        return