mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-06-13 00:07:47 +02:00
When a tool call is blocked in the MCP POST handler, the error message is added to the session's list of pending error messages. The MCP SSE GET endpoint creates two queues, one for the MCP server events and one for these pending error messages, and merges them to return events back to the client.
This commit is contained in:
@@ -24,6 +24,7 @@ This allows you to _observe and debug_ your agents in [Invariant Explorer](https
|
||||
- [x] **Single Line Setup**: Just change the base URL of your LLM provider to the Invariant Gateway.
|
||||
- [x] **Intercepts agents on an LLM-level** for better debugging and analysis.
|
||||
- [x] **Tool Calling and Computer Use Support** to capture all forms of agentic interactions.
|
||||
- [x] **MCP Protocol Support** for both standard I/O and Server-Sent Events (SSE) transports.
|
||||
- [x] **Seamless forwarding and LLM streaming** to OpenAI, Anthropic, and other LLM providers.
|
||||
- [x] **Store and organize runtime traces** in the [Invariant Explorer](https://explorer.invariantlabs.ai/).
|
||||
|
||||
@@ -277,6 +278,30 @@ export ANTHROPIC_API_KEY={your-anthropic-api-key};invariant-auth={your-invariant
|
||||
|
||||
This setup ensures that SWE-agent works seamlessly with Invariant Gateway, maintaining compatibility while enabling full functionality. 🚀
|
||||
|
||||
### **Using MCP with Invariant Gateway**
|
||||
Invariant Gateway supports MCP (both stdio and SSE transports) tool calling.
|
||||
|
||||
For stdio-transport-based MCP, follow the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main/gateway/mcp).
|
||||
|
||||
For SSE-transport-based MCP, follow these steps to point your MCP client to a local instance of the Invariant Gateway, which will then proxy all calls to the MCP server:
|
||||
|
||||
* Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally).
|
||||
* Use the following configuration to connect to the local Gateway instance:
|
||||
```python
|
||||
await client.connect_to_sse_server(
|
||||
server_url="http://localhost:8005/api/v1/gateway/mcp/sse",
|
||||
headers={
|
||||
"MCP-SERVER-BASE-URL": "<The base URL to your MCP server>",
|
||||
"INVARIANT-PROJECT-NAME": "<The Invariant dataset name>",
|
||||
"PUSH-INVARIANT-EXPLORER": "true",
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
If no `INVARIANT-PROJECT-NAME` header is specified but `PUSH-INVARIANT-EXPLORER` is set to "true", a new Invariant project will be created and the MCP traces will be pushed there.
|
||||
|
||||
You can also specify blocking or logging guardrails for the project by visiting the Explorer.
|
||||
|
||||
---
|
||||
|
||||
## **Run the Gateway Locally**
|
||||
|
||||
@@ -43,6 +43,9 @@ class McpSession(BaseModel):
|
||||
blocking_guardrails=[], logging_guardrails=[]
|
||||
)
|
||||
)
|
||||
# When tool calls are blocked, the error message is stored here
|
||||
# and sent to the client via the SSE stream.
|
||||
pending_error_messages: List[dict] = Field(default_factory=list)
|
||||
|
||||
# Lock to maintain in-order pushes to explorer
|
||||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
@@ -214,6 +217,29 @@ class McpSession(BaseModel):
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"[MCP SSE] Error pushing trace for session {self.session_id}: {e}")
|
||||
|
||||
async def add_pending_error_message(self, error_message: dict) -> None:
    """
    Append an error message to the session's pending error messages.

    The append happens inside the async context manager returned by
    ``self.session_lock()``.

    Args:
        error_message: The error message to add
    """
    lock = self.session_lock()
    async with lock:
        # pylint: disable=no-member
        self.pending_error_messages.append(error_message)
|
||||
|
||||
async def get_pending_error_messages(self) -> List[dict]:
    """
    Drain the pending error messages of the session.

    The stored list is copied and then replaced by a new empty list, both
    inside ``self.session_lock()``, so a later call only returns messages
    that were added in the meantime.

    Returns:
        List[dict]: The messages that were pending when the call was made
    """
    async with self.session_lock():
        drained = [*self.pending_error_messages]
        self.pending_error_messages = []
    return drained
|
||||
|
||||
|
||||
class SseHeaderAttributes(BaseModel):
|
||||
"""
|
||||
@@ -235,8 +261,8 @@ class SseHeaderAttributes(BaseModel):
|
||||
SseHeaderAttributes: An instance with values extracted from headers
|
||||
"""
|
||||
# Extract and process header values
|
||||
project_name = headers.get("PROJECT-NAME")
|
||||
push_explorer_header = headers.get("PUSH-EXPLORER", "false").lower()
|
||||
project_name = headers.get("INVARIANT-PROJECT-NAME")
|
||||
push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower()
|
||||
|
||||
# Determine explorer_dataset
|
||||
if project_name:
|
||||
|
||||
@@ -26,7 +26,6 @@ from gateway.mcp.mcp_context import McpContext
|
||||
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
|
||||
|
||||
UTF_8_ENCODING = "utf-8"
|
||||
MCP_INITIALIZE = "initialize"
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
|
||||
+167
-63
@@ -92,17 +92,10 @@ async def mcp_post_gateway(
|
||||
session_id=session_id, request_json=request_json
|
||||
)
|
||||
if is_blocked:
|
||||
# If blocked, hook_tool_call_result contains the block message.
|
||||
# Forward the block message result back to the caller.
|
||||
# The original request is not passed to the MCP process.
|
||||
return Response(
|
||||
content=json.dumps(hook_tool_call_result),
|
||||
status_code=403,
|
||||
headers={
|
||||
"X-Proxied-By": "mcp-gateway",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
# Add the error message to the session.
|
||||
# The error message is sent back to the client using the SSE stream.
|
||||
await session.add_pending_error_message(hook_tool_call_result)
|
||||
return Response(content="Accepted", status_code=202)
|
||||
|
||||
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
|
||||
try:
|
||||
@@ -150,60 +143,139 @@ async def mcp_get_sse_gateway(
|
||||
|
||||
query_params = dict(request.query_params)
|
||||
response_headers = {}
|
||||
filtered_headers = {
|
||||
k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_SSE_HEADERS
|
||||
}
|
||||
sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)
|
||||
|
||||
async def event_generator():
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(CLIENT_TIMEOUT),
|
||||
headers={
|
||||
k: v
|
||||
for k, v in request.headers.items()
|
||||
if k.lower() in MCP_SERVER_SSE_HEADERS
|
||||
},
|
||||
) as client:
|
||||
try:
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
mcp_server_sse_endpoint,
|
||||
params=query_params,
|
||||
) as event_source:
|
||||
if event_source.response.status_code != 200:
|
||||
error_content = await event_source.response.aread()
|
||||
raise HTTPException(
|
||||
status_code=event_source.response.status_code,
|
||||
detail=error_content,
|
||||
)
|
||||
"""
|
||||
Generate a merged stream of MCP server events and pending error messages.
|
||||
The pending error messages are added in the POST messages handler.
|
||||
This function runs in a loop, yielding events as they arrive.
|
||||
"""
|
||||
mcp_server_events_queue = asyncio.Queue()
|
||||
pending_error_messages_queue = asyncio.Queue()
|
||||
tasks = set()
|
||||
session_id = None
|
||||
|
||||
session_id = None
|
||||
try:
|
||||
# MCP Server Events Processor
|
||||
async def process_mcp_server_events():
|
||||
"""Connect to MCP server and process its events."""
|
||||
nonlocal session_id
|
||||
|
||||
async for sse in event_source.aiter_sse():
|
||||
event_bytes = (
|
||||
f"event: {sse.event}\ndata: {sse.data}\n\n".encode("utf-8")
|
||||
)
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
(
|
||||
event_bytes,
|
||||
session_id,
|
||||
) = await _handle_endpoint_event(
|
||||
sse,
|
||||
sse_header_attributes=SseHeaderAttributes.from_request_headers(
|
||||
request.headers
|
||||
),
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(CLIENT_TIMEOUT)
|
||||
) as client:
|
||||
try:
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
mcp_server_sse_endpoint,
|
||||
headers=filtered_headers,
|
||||
params=query_params,
|
||||
) as event_source:
|
||||
if event_source.response.status_code != 200:
|
||||
error_content = await event_source.response.aread()
|
||||
raise HTTPException(
|
||||
status_code=event_source.response.status_code,
|
||||
detail=error_content,
|
||||
)
|
||||
case "message":
|
||||
if session_id:
|
||||
event_bytes = await _handle_message_event(
|
||||
session_id=session_id, sse=sse
|
||||
)
|
||||
yield event_bytes
|
||||
|
||||
except httpx.StreamClosed as e:
|
||||
print(f"[MCP SSE] Stream closed: {str(e)}", flush=True)
|
||||
except httpx.RequestError as e:
|
||||
print(f"[MCP SSE] Request error: {str(e)}", flush=True)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"[MCP SSE] Unexpected error: {str(e)}", flush=True)
|
||||
async for sse in event_source.aiter_sse():
|
||||
if sse.event == "endpoint":
|
||||
(
|
||||
event_bytes,
|
||||
extracted_id,
|
||||
) = await _handle_endpoint_event(
|
||||
sse, sse_header_attributes
|
||||
)
|
||||
session_id = extracted_id
|
||||
|
||||
if (
|
||||
session_id
|
||||
and "process_error_messages_task"
|
||||
not in locals()
|
||||
):
|
||||
process_error_messages_task = (
|
||||
asyncio.create_task(
|
||||
_check_for_pending_error_messages(
|
||||
session_id,
|
||||
pending_error_messages_queue,
|
||||
)
|
||||
)
|
||||
)
|
||||
tasks.add(process_error_messages_task)
|
||||
process_error_messages_task.add_done_callback(
|
||||
tasks.discard
|
||||
)
|
||||
|
||||
elif sse.event == "message" and session_id:
|
||||
# Process message event
|
||||
event_bytes = await _handle_message_event(
|
||||
session_id, sse
|
||||
)
|
||||
else:
|
||||
# Pass through other event types
|
||||
# pylint: disable=line-too-long
|
||||
event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
# Put the processed event in the queue
|
||||
await mcp_server_events_queue.put(event_bytes)
|
||||
|
||||
except httpx.StreamClosed as e:
|
||||
print(f"Server stream closed: {e}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error processing server events: {e}", flush=True)
|
||||
|
||||
# Start server events processor
|
||||
mcp_server_events_task = asyncio.create_task(process_mcp_server_events())
|
||||
tasks.add(mcp_server_events_task)
|
||||
mcp_server_events_task.add_done_callback(tasks.discard)
|
||||
|
||||
# Main event loop: merge MCP server events and pending error messages
|
||||
while True:
|
||||
# Create futures for both queues
|
||||
mcp_server_event_future = asyncio.create_task(
|
||||
mcp_server_events_queue.get()
|
||||
)
|
||||
pending_error_message_future = asyncio.create_task(
|
||||
pending_error_messages_queue.get()
|
||||
)
|
||||
|
||||
# Wait for either queue to have an item, with timeout
|
||||
done, pending = await asyncio.wait(
|
||||
[mcp_server_event_future, pending_error_message_future],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
timeout=0.25,
|
||||
)
|
||||
|
||||
for future in pending:
|
||||
future.cancel()
|
||||
|
||||
# Timeout occurred and no future completed.
|
||||
if not done:
|
||||
continue
|
||||
|
||||
for future in done:
|
||||
try:
|
||||
event = await future
|
||||
yield event
|
||||
except asyncio.CancelledError:
|
||||
# Future was cancelled, continue
|
||||
continue
|
||||
|
||||
finally:
|
||||
# Clean up all tasks
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to complete
|
||||
if tasks:
|
||||
await asyncio.wait(tasks, timeout=2)
|
||||
|
||||
# Return the streaming response
|
||||
return StreamingResponse(
|
||||
@@ -286,11 +358,15 @@ async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict
|
||||
}
|
||||
result = response_json
|
||||
session = session_store.get_session(session_id)
|
||||
guardrailing_result = await session.get_guardrails_check_result(
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
|
||||
if guardrailing_result and guardrailing_result.get("errors", []):
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and _check_if_new_errors(session_id, guardrails_result)
|
||||
):
|
||||
# If the request is blocked, return a message indicating the block reason.
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
@@ -298,12 +374,12 @@ async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrailing_result["errors"],
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}
|
||||
# Push trace to the explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
session_store.add_message_to_session(session_id, message, guardrailing_result)
|
||||
session_store.add_message_to_session(session_id, message, guardrails_result)
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -391,6 +467,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
# Update the event bytes with hook_tool_call_response.
|
||||
# hook_tool_call_response is same as response_json if no guardrail is violated.
|
||||
# If guardrail is violated, it contains the error message.
|
||||
# pylint: disable=line-too-long
|
||||
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
|
||||
"utf-8"
|
||||
)
|
||||
@@ -421,3 +498,30 @@ def _check_if_new_errors(session_id: str, guardrails_result: dict) -> bool:
|
||||
if annotation not in session.annotations:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _check_for_pending_error_messages(
    session_id: str,
    pending_error_messages_queue: asyncio.Queue,
    poll_interval: float = 1,
):
    """
    Periodically check for and enqueue pending error messages.

    Runs until the task is cancelled. On every iteration the session is
    looked up in ``session_store``, its pending error messages are drained,
    and each one is encoded as an SSE ``message`` event and put on
    ``pending_error_messages_queue``.

    Args:
        session_id: ID used to look up the session in ``session_store``.
        pending_error_messages_queue: Queue that receives the encoded SSE
            event bytes, one item per error message.
        poll_interval: Seconds to sleep between two checks. Defaults to 1.
    """
    try:
        while True:
            try:
                session = session_store.get_session(session_id)
                error_messages = await session.get_pending_error_messages()
            except Exception as e:  # pylint: disable=broad-except
                print(f"Error checking for messages: {e}", flush=True)
                error_messages = []

            for error_message in error_messages:
                # get_pending_error_messages() has already removed the batch
                # from the session, so a failure on one message must not
                # discard the messages that follow it.
                try:
                    payload = json.dumps(error_message)
                except Exception as e:  # pylint: disable=broad-except
                    print(f"Error checking for messages: {e}", flush=True)
                    continue
                await pending_error_messages_queue.put(
                    f"event: message\ndata: {payload}\n\n".encode("utf-8")
                )

            # Single sleep for both the success and the error path.
            await asyncio.sleep(poll_interval)
    except asyncio.CancelledError:
        # Task was cancelled, exit gracefully
        return
|
||||
|
||||
Reference in New Issue
Block a user