Merge branch 'main' into mcp-metadata

This commit is contained in:
knielsen404
2025-05-12 11:09:22 +02:00
9 changed files with 913 additions and 23 deletions
+25
View File
@@ -24,6 +24,7 @@ This allows you to _observe and debug_ your agents in [Invariant Explorer](https
- [x] **Single Line Setup**: Just change the base URL of your LLM provider to the Invariant Gateway.
- [x] **Intercepts agents on an LLM-level** for better debugging and analysis.
- [x] **Tool Calling and Computer Use Support** to capture all forms of agentic interactions.
- [x] **MCP Protocol Support** for both standard I/O and Server-Sent Events (SSE) transports.
- [x] **Seamless forwarding and LLM streaming** to OpenAI, Anthropic, and other LLM providers.
- [x] **Store and organize runtime traces** in the [Invariant Explorer](https://explorer.invariantlabs.ai/).
@@ -277,6 +278,30 @@ export ANTHROPIC_API_KEY={your-anthropic-api-key};invariant-auth={your-invariant
This setup ensures that SWE-agent works seamlessly with Invariant Gateway, maintaining compatibility while enabling full functionality. 🚀
### **Using MCP with Invariant Gateway**
Invariant Gateway supports MCP tool calling over both the stdio and SSE transports.
For stdio-transport-based MCP, follow the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main/gateway/mcp).
For SSE transport based MCP, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server:
* Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally).
* Use the following configuration to connect to the local Gateway instance:
```python
await client.connect_to_sse_server(
server_url="http://localhost:8005/api/v1/gateway/mcp/sse",
headers={
"MCP-SERVER-BASE-URL": "<The base URL to your MCP server>",
"INVARIANT-PROJECT-NAME": "<The Invariant dataset name>",
"PUSH-INVARIANT-EXPLORER": "true",
},
)
```
If no `INVARIANT-PROJECT-NAME` header is specified but `PUSH-INVARIANT-EXPLORER` is set to "true", a new Invariant project will be created and the MCP traces will be pushed there.
You can also configure blocking or logging guardrails for the project in the Explorer.
---
## **Run the Gateway Locally**
+20
View File
@@ -13,3 +13,23 @@ IGNORED_HEADERS = [
]
CLIENT_TIMEOUT = 60.0
# MCP related constants
# Keys and method names looked up in the JSON payloads exchanged with MCP
# clients and servers.
MCP_METHOD = "method"
MCP_TOOL_CALL = "tools/call"
MCP_LIST_TOOLS = "tools/list"
MCP_PARAMS = "params"
MCP_RESULT = "result"
MCP_SERVER_INFO = "serverInfo"
MCP_CLIENT_INFO = "clientInfo"
# Error text returned when a tool call is blocked. The single %s placeholder
# is filled with the guardrail errors via the % operator.
# NOTE(review): unlike the definitions this replaces in the stdio MCP module,
# these literals are not .strip()-ed, so they begin and end with a newline.
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE = """
[Invariant Guardrails] The MCP tool call was blocked for security reasons.
Do not attempt to circumvent this block, rather explain to the user based
on the following output what went wrong: %s
"""
# Error text used when a server's tool listing is blocked. Takes one %s
# placeholder as well.
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE = """
[Invariant Guardrails] This server was blocked from advertising its tools due to a security guardrail failure.
The operation was blocked by Invariant Guardrails (mention this in your user report).
When users ask about this tool, inform them that it was blocked due to a security guardrail failure.
%s
"""
+324
View File
@@ -0,0 +1,324 @@
"""MCP Sessions Manager related classes"""
import asyncio
import contextlib
import os
import random
from typing import Any, Dict, List, Optional
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.append_messages import AppendMessagesRequest
from invariant_sdk.types.push_traces import PushTracesRequest
from pydantic import BaseModel, Field, PrivateAttr
from starlette.datastructures import Headers
from gateway.common.guardrails import GuardrailRuleSet, GuardrailAction
from gateway.common.request_context import RequestContext
from gateway.integrations.explorer import (
create_annotations_from_guardrails_errors,
fetch_guardrails_from_explorer,
)
from gateway.integrations.guardrails import check_guardrails
# Explorer endpoint used when the INVARIANT_API_URL environment variable is unset.
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
class McpSession(BaseModel):
    """
    Represents a single MCP session.

    A session accumulates the messages exchanged between an MCP client and
    server, evaluates them against the guardrails configured for its
    Explorer dataset and, when ``push_explorer`` is set, mirrors the trace
    (messages plus guardrail annotations) to the Invariant Explorer.
    """

    session_id: str
    messages: List[Dict[str, Any]] = Field(default_factory=list)
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Maps a JSON-RPC request id to the method of that request so that the
    # response can later be matched to the call that triggered it.
    id_to_method_mapping: Dict[int, str] = Field(default_factory=dict)
    explorer_dataset: str
    push_explorer: bool
    trace_id: Optional[str] = None
    # Number of messages that have already been pushed to the explorer.
    last_trace_length: int = 0
    # Annotations already pushed. These are the objects returned by
    # create_annotations_from_guardrails_errors (they expose .content,
    # .address and .extra_metadata), not plain dicts.
    annotations: List[Any] = Field(default_factory=list)
    guardrails: GuardrailRuleSet = Field(
        default_factory=lambda: GuardrailRuleSet(
            blocking_guardrails=[], logging_guardrails=[]
        )
    )
    # When tool calls are blocked, the error message is stored here
    # and sent to the client via the SSE stream.
    pending_error_messages: List[dict] = Field(default_factory=list)

    # Lock to maintain in-order pushes to explorer
    _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)

    async def load_guardrails(self) -> None:
        """
        Load guardrails for the session.

        This method fetches guardrails from the Invariant Explorer and assigns
        them to the session. The INVARIANT_API_KEY environment variable must be
        set; it is used to build the bearer token.
        """
        self.guardrails = await fetch_guardrails_from_explorer(
            self.explorer_dataset,
            "Bearer " + os.getenv("INVARIANT_API_KEY"),
        )

    def _deduplicate_annotations(self, new_annotations: list) -> list:
        """Return the entries of new_annotations not already stored in the session."""

        # TODO: Rely on the __eq__ method of the AnnotationCreate class directly via not in
        # to remove duplicates instead of using a custom logic.
        # This is a temporary solution until the Invariant SDK is updated.
        def _already_recorded(candidate) -> bool:
            return any(
                candidate.content == known.content
                and candidate.address == known.address
                and candidate.extra_metadata == known.extra_metadata
                for known in self.annotations
            )

        return [
            annotation
            for annotation in new_annotations
            if not _already_recorded(annotation)
        ]

    @contextlib.asynccontextmanager
    async def session_lock(self):
        """
        Context manager for the session lock.

        Usage:
            async with session.session_lock():
                # Code that requires exclusive access to the session
        """
        async with self._lock:
            yield

    async def get_guardrails_check_result(
        self,
        message: dict,
        action: GuardrailAction = GuardrailAction.BLOCK,
    ) -> dict:
        """
        Check the session history plus ``message`` against guardrails of type action.

        Args:
            message: The candidate message; it is evaluated together with the
                messages already stored but is not added to the session.
            action: Which rule set to evaluate (blocking or logging).

        Returns:
            The result of check_guardrails, or an empty dict when no
            guardrails are configured for ``action``.
        """
        # Skip if no guardrails are configured for this action
        if not (
            (self.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK)
            or (self.guardrails.logging_guardrails and action == GuardrailAction.LOG)
        ):
            return {}

        # Prepare context and select appropriate guardrails
        context = RequestContext.create(
            request_json={},
            dataset_name=self.explorer_dataset,
            invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
            guardrails=self.guardrails,
        )
        guardrails_to_check = (
            self.guardrails.blocking_guardrails
            if action == GuardrailAction.BLOCK
            else self.guardrails.logging_guardrails
        )
        return await check_guardrails(
            messages=self.messages + [message],
            guardrails=guardrails_to_check,
            context=context,
        )

    async def add_message(
        self,
        message: Dict[str, Any],
        guardrails_result: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add a message to the session and optionally push to explorer.

        Args:
            message: The message to add
            guardrails_result: Result of a guardrails check already run for
                this message, if any. Its "errors" are turned into
                annotations. May be omitted or empty.
        """
        async with self.session_lock():
            annotations = []
            if guardrails_result and guardrails_result.get("errors", []):
                annotations = create_annotations_from_guardrails_errors(
                    guardrails_result.get("errors")
                )

            if self.guardrails.logging_guardrails:
                logging_guardrails_check_result = (
                    await self.get_guardrails_check_result(
                        message, action=GuardrailAction.LOG
                    )
                )
                if (
                    logging_guardrails_check_result
                    and logging_guardrails_check_result.get("errors", [])
                ):
                    annotations.extend(
                        create_annotations_from_guardrails_errors(
                            logging_guardrails_check_result["errors"]
                        )
                    )
            deduplicated_annotations = self._deduplicate_annotations(annotations)

            # pylint: disable=no-member
            self.messages.append(message)
            # If push_explorer is enabled, push the trace
            if self.push_explorer:
                await self._push_trace_update(deduplicated_annotations)

    async def _push_trace_update(self, deduplicated_annotations: list) -> None:
        """
        Push trace updates to the explorer.

        If a trace doesn't exist, create a new one. If it does, append new messages.
        This is an internal method that should only be called within a lock.
        Failures are logged and swallowed so that tracing never breaks the
        proxied MCP traffic.
        """
        try:
            client = AsyncClient(
                api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
            )

            # If no trace exists, create a new one
            if not self.trace_id:
                # pylint: disable=no-member
                metadata = {"source": "mcp", "tools": self.metadata.get("tools", [])}
                if self.metadata.get("mcp_client_name"):
                    metadata["mcp_client"] = self.metadata.get("mcp_client_name")
                if self.metadata.get("mcp_server_name"):
                    metadata["mcp_server"] = self.metadata.get("mcp_server_name")

                response = await client.push_trace(
                    PushTracesRequest(
                        messages=[self.messages],
                        dataset=self.explorer_dataset,
                        metadata=[metadata],
                        annotations=[deduplicated_annotations],
                    )
                )
                self.trace_id = response.id[0]
            else:
                # Only the messages added since the last successful push.
                new_messages = self.messages[self.last_trace_length :]
                if new_messages:
                    await client.append_messages(
                        AppendMessagesRequest(
                            trace_id=self.trace_id,
                            messages=new_messages,
                            annotations=deduplicated_annotations,
                        )
                    )
            # pylint: disable=no-member
            self.annotations.extend(deduplicated_annotations)
            self.last_trace_length = len(self.messages)
        except Exception as e:  # pylint: disable=broad-except
            print(f"[MCP SSE] Error pushing trace for session {self.session_id}: {e}")

    async def add_pending_error_message(self, error_message: dict) -> None:
        """
        Add a pending error message to the session.

        Args:
            error_message: The error message to add
        """
        async with self.session_lock():
            # pylint: disable=no-member
            self.pending_error_messages.append(error_message)

    async def get_pending_error_messages(self) -> List[dict]:
        """
        Get all pending error messages for the session and clear them.

        Returns:
            List[dict]: A list of pending error messages
        """
        async with self.session_lock():
            messages = list(self.pending_error_messages)
            self.pending_error_messages = []
            return messages
class SseHeaderAttributes(BaseModel):
    """
    Settings derived from the headers of an incoming SSE request.
    """

    push_explorer: bool
    explorer_dataset: str

    @classmethod
    def from_request_headers(cls, headers: Headers) -> "SseHeaderAttributes":
        """
        Build the attributes from the request headers.

        Args:
            headers: FastAPI Request headers

        Returns:
            SseHeaderAttributes: ``explorer_dataset`` is the value of
            INVARIANT-PROJECT-NAME, or a generated ``mcp-capture-<n>`` name
            when that header is absent or empty; ``push_explorer`` is True
            only when PUSH-INVARIANT-EXPLORER equals "true" (case-insensitive).
        """
        dataset = headers.get("INVARIANT-PROJECT-NAME")
        push_flag = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower()

        if not dataset:
            # No project supplied: fall back to a generated capture dataset.
            dataset = f"mcp-capture-{random.randint(1, 100)}"

        return cls(push_explorer=(push_flag == "true"), explorer_dataset=dataset)
class McpSessionsManager:
    """
    A class to manage MCP sessions and their messages.

    Sessions are kept in an in-memory dict keyed by session id.
    """

    def __init__(self):
        self._sessions: dict[str, McpSession] = {}

    def session_exists(self, session_id: str) -> bool:
        """Check if a session exists"""
        return session_id in self._sessions

    async def initialize_session(
        self, session_id: str, sse_header_attributes: "SseHeaderAttributes"
    ) -> None:
        """
        Initialize a new session and load its guardrails.

        Does nothing when the session id is already registered. If loading
        the guardrails fails, the registration is rolled back and the error
        re-raised, so a session never stays reachable with an empty rule set.

        Args:
            session_id: The session ID
            sse_header_attributes: Dataset name and push flag for the session
        """
        if session_id in self._sessions:
            return

        session = McpSession(
            session_id=session_id,
            explorer_dataset=sse_header_attributes.explorer_dataset,
            push_explorer=sse_header_attributes.push_explorer,
        )
        self._sessions[session_id] = session
        try:
            # Load guardrails for the session from the explorer
            await session.load_guardrails()
        except BaseException:
            # Includes cancellation: never leave a half-initialized session.
            self._sessions.pop(session_id, None)
            raise

    def get_session(self, session_id: str) -> "McpSession":
        """
        Get a session by ID.

        Raises:
            ValueError: If no session with this id is registered.
        """
        try:
            return self._sessions[session_id]
        except KeyError:
            raise ValueError(f"Session {session_id} does not exist.") from None

    async def add_message_to_session(
        self, session_id: str, message: Dict[str, Any], guardrails_result: dict
    ) -> None:
        """
        Add a message to a session and push to explorer if enabled.

        Args:
            session_id: The session ID
            message: The message to add
            guardrails_result: The result of the guardrails check

        Raises:
            ValueError: If the session does not exist.
        """
        session = self.get_session(session_id)
        await session.add_message(message, guardrails_result)
+7 -20
View File
@@ -11,6 +11,13 @@ from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.append_messages import AppendMessagesRequest
from invariant_sdk.types.push_traces import PushTracesRequest
from gateway.common.constants import (
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
MCP_METHOD,
MCP_TOOL_CALL,
MCP_LIST_TOOLS,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.request_context import RequestContext
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
@@ -18,30 +25,10 @@ from gateway.integrations.guardrails import check_guardrails
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
from gateway.mcp.mcp_context import McpContext
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
import getpass
import socket
MCP_METHOD = "method"
UTF_8_ENCODING = "utf-8"
MCP_TOOL_CALL = "tools/call"
MCP_LIST_TOOLS = "tools/list"
MCP_INITIALIZE = "initialize"
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE = """
[Security Failure] The MCP tool call was blocked for security reasons.
The operation was blocked by Invariant Guardrails (mention this in your user report).
Do not attempt to circumvent this block, rather explain to the user based
on the following output what went wrong: %s
""".strip()
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE = """
[Security Failure] This server was blocked from advertising its tools due to a security guardrail failure.
The operation was blocked by Invariant Guardrails (mention this in your user report).
When users ask about this tool, inform them that it was blocked due to a security guardrail failure.
%s
""".strip()
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
+527
View File
@@ -0,0 +1,527 @@
"""Gateway service to forward requests to the MCP SSE servers"""
import asyncio
import json
import re
from typing import Tuple
import httpx
from httpx_sse import aconnect_sse, ServerSentEvent
from fastapi import APIRouter, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from gateway.common.constants import (
CLIENT_TIMEOUT,
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
MCP_METHOD,
MCP_TOOL_CALL,
MCP_LIST_TOOLS,
MCP_PARAMS,
MCP_RESULT,
MCP_SERVER_INFO,
MCP_CLIENT_INFO,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.mcp_sessions_manager import (
McpSessionsManager,
SseHeaderAttributes,
)
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
# Allow-list of client request headers (lower-case names) that are forwarded
# to the MCP server on POST calls.
MCP_SERVER_POST_HEADERS = {
"connection",
"accept",
"content-length",
"content-type",
}
# Allow-list of client request headers (lower-case names) that are forwarded
# to the MCP server when opening the SSE stream.
MCP_SERVER_SSE_HEADERS = {
"connection",
"accept",
"cache-control",
}
# Router holding the MCP SSE routes defined in this module.
gateway = APIRouter()
# Module-level store of MCP sessions shared by the handlers in this module.
session_store = McpSessionsManager()
@gateway.post("/mcp/sse/messages/")
async def mcp_post_gateway(
    request: Request,
) -> Response:
    """
    Proxy a client JSON-RPC message to the MCP server's messages endpoint.

    The request must carry a ``session_id`` query parameter for a session
    previously created through the SSE endpoint, and a ``mcp-server-base-url``
    header. ``tools/call`` requests are checked against blocking guardrails
    first; a blocked call is not forwarded, its error is queued on the session
    (it is delivered over the SSE stream) and 202 is returned.

    Raises:
        HTTPException: 400 for a missing/unknown session, a missing base URL
            header or a body that is not a JSON object; 500 when the upstream
            request fails.
    """
    query_params = dict(request.query_params)
    session_id = query_params.get("session_id")
    if not session_id:
        raise HTTPException(
            status_code=400,
            detail="Missing 'session_id' query parameter",
        )
    if not session_store.session_exists(session_id):
        raise HTTPException(
            status_code=400,
            detail="Session does not exist",
        )
    mcp_server_base_url = request.headers.get("mcp-server-base-url")
    if not mcp_server_base_url:
        raise HTTPException(
            status_code=400,
            detail="Missing 'mcp-server-base-url' header",
        )

    # The query string (including session_id) is supplied through `params`
    # below, so the URL itself carries no query part.
    mcp_server_messages_endpoint = (
        _convert_localhost_to_docker_host(mcp_server_base_url) + "/messages/"
    )

    request_body_bytes = await request.body()
    try:
        request_json = json.loads(request_body_bytes)
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Request body is not valid JSON"
        ) from e
    if not isinstance(request_json, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    session = session_store.get_session(session_id)
    method = request_json.get(MCP_METHOD)
    request_id = request_json.get("id")
    # `is not None` rather than truthiness: 0 is a valid JSON-RPC id and its
    # response must still be matched to the originating method.
    if method and request_id is not None:
        session.id_to_method_mapping[request_id] = method

    params = request_json.get(MCP_PARAMS)
    if isinstance(params, dict) and params.get(MCP_CLIENT_INFO):
        session.metadata["mcp_client_name"] = params.get(MCP_CLIENT_INFO).get(
            "name", ""
        )

    if method == MCP_TOOL_CALL:
        # Intercept and potentially block the request
        hook_tool_call_result, is_blocked = await _hook_tool_call(
            session_id=session_id, request_json=request_json
        )
        if is_blocked:
            # Add the error message to the session.
            # The error message is sent back to the client using the SSE stream.
            await session.add_pending_error_message(hook_tool_call_result)
            return Response(content="Accepted", status_code=202)

    forwarded_headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() in MCP_SERVER_POST_HEADERS
    }
    forwarded_headers.setdefault("content-type", "application/json")
    # httpx hands back an already-decoded body, so upstream framing/encoding
    # headers no longer describe what is sent to the client.
    stale_response_headers = {"content-encoding", "content-length", "transfer-encoding"}

    async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
        try:
            response = await client.post(
                url=mcp_server_messages_endpoint,
                headers=forwarded_headers,
                # Forward the original bytes: the client's content-length is
                # forwarded too and must keep matching the body.
                content=request_body_bytes,
                params=query_params,
            )
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers={
                    "X-Proxied-By": "mcp-gateway",
                    **{
                        k: v
                        for k, v in response.headers.items()
                        if k.lower() not in stale_response_headers
                    },
                },
            )
        except httpx.RequestError as e:
            print(f"[MCP POST] Request error: {str(e)}")
            raise HTTPException(status_code=500, detail="Request error") from e
        except Exception as e:
            print(f"[MCP POST] Unexpected error: {str(e)}")
            raise HTTPException(status_code=500, detail="Unexpected error") from e
@gateway.get("/mcp/sse")
async def mcp_get_sse_gateway(
    request: Request,
) -> StreamingResponse:
    """
    Open an SSE stream to the MCP server and relay it to the client.

    Events from the MCP server are rewritten/inspected (endpoint events are
    pointed at the gateway, message events go through the guardrail hooks)
    and merged with error messages queued by the POST handler for blocked
    tool calls. The stream to the client ends when the upstream stream ends.

    Raises:
        HTTPException: 400 when the ``mcp-server-base-url`` header is missing.
    """
    mcp_server_base_url = request.headers.get("mcp-server-base-url")
    if not mcp_server_base_url:
        raise HTTPException(
            status_code=400,
            detail="Missing 'mcp-server-base-url' header",
        )
    mcp_server_sse_endpoint = (
        _convert_localhost_to_docker_host(mcp_server_base_url) + "/sse"
    )

    query_params = dict(request.query_params)
    filtered_headers = {
        k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_SSE_HEADERS
    }
    sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)

    async def event_generator():
        """
        Generate a merged stream of MCP server events and pending error messages.

        Both producers write into one queue. The MCP server reader enqueues a
        sentinel when it stops (normally or on error) so that the generator
        terminates instead of polling forever on a dead upstream.
        """
        events_queue: asyncio.Queue = asyncio.Queue()
        upstream_closed = object()
        tasks = set()

        def track(task: asyncio.Task) -> None:
            # Keep a strong reference until the task is done.
            tasks.add(task)
            task.add_done_callback(tasks.discard)

        async def process_mcp_server_events():
            """Connect to MCP server and process its events."""
            session_id = None
            error_poller_started = False
            try:
                async with httpx.AsyncClient(
                    timeout=httpx.Timeout(CLIENT_TIMEOUT)
                ) as client:
                    async with aconnect_sse(
                        client,
                        "GET",
                        mcp_server_sse_endpoint,
                        headers=filtered_headers,
                        params=query_params,
                    ) as event_source:
                        if event_source.response.status_code != 200:
                            error_content = await event_source.response.aread()
                            # Caught and logged below; the response to the
                            # client has already started at this point.
                            raise HTTPException(
                                status_code=event_source.response.status_code,
                                detail=error_content,
                            )

                        async for sse in event_source.aiter_sse():
                            if sse.event == "endpoint":
                                (
                                    event_bytes,
                                    session_id,
                                ) = await _handle_endpoint_event(
                                    sse, sse_header_attributes
                                )
                                # Start polling for blocked-call errors once
                                # the session id is known.
                                if session_id and not error_poller_started:
                                    track(
                                        asyncio.create_task(
                                            _check_for_pending_error_messages(
                                                session_id, events_queue
                                            )
                                        )
                                    )
                                    error_poller_started = True
                            elif sse.event == "message" and session_id:
                                # Process message event
                                event_bytes = await _handle_message_event(
                                    session_id, sse
                                )
                            else:
                                # Pass through other event types
                                event_bytes = (
                                    f"event: {sse.event}\ndata: {sse.data}\n\n".encode(
                                        "utf-8"
                                    )
                                )
                            await events_queue.put(event_bytes)
            except httpx.StreamClosed as e:
                print(f"Server stream closed: {e}", flush=True)
            except Exception as e:  # pylint: disable=broad-except
                print(f"Error processing server events: {e}", flush=True)
            finally:
                # Unbounded queue, so this never blocks (also on cancellation).
                events_queue.put_nowait(upstream_closed)

        try:
            track(asyncio.create_task(process_mcp_server_events()))
            while True:
                event = await events_queue.get()
                if event is upstream_closed:
                    break
                yield event
        finally:
            # Clean up all tasks
            for task in list(tasks):
                task.cancel()
            # Wait for all tasks to complete
            if tasks:
                await asyncio.wait(tasks, timeout=2)

    # Return the streaming response
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"X-Proxied-By": "mcp-gateway"},
    )
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_BACKGROUND_TASKS: set = set()


def _run_in_background(coro) -> asyncio.Task:
    """Schedule coro as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return task


def _blocked_error_response(message_id, errors) -> dict:
    """Build the JSON-RPC error payload returned for a blocked tool call."""
    return {
        "jsonrpc": "2.0",
        "id": message_id,
        "error": {
            "code": -32600,
            # One-element tuple so the errors are formatted as a single value
            # whatever their container type.
            "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % (errors,),
        },
    }


def _is_blocked(session_id: str, guardrails_result: dict) -> bool:
    """Return True when the guardrails result contains errors not yet recorded."""
    return bool(
        guardrails_result
        and guardrails_result.get("errors", [])
        and _check_if_new_errors(session_id, guardrails_result)
    )


async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bool]:
    """
    Hook to process the request JSON before sending it to the MCP server.

    Args:
        session_id (str): The session ID associated with the request.
        request_json (dict): The request JSON to be processed.

    Returns:
        Tuple[dict, bool]: ``(error_payload, True)`` when the call is blocked
        by a guardrail, otherwise ``(request_json, False)``.
    """
    # NOTE(review): a tools/call request without "params" raises
    # AttributeError here, as it did before this change.
    tool_call = {
        "id": f"call_{request_json.get('id')}",
        "type": "function",
        "function": {
            "name": request_json.get(MCP_PARAMS).get("name"),
            "arguments": request_json.get(MCP_PARAMS).get("arguments"),
        },
    }
    message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}

    # Check for blocking guardrails - this blocks until completion
    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )
    # Decide before scheduling the push: the push records the annotations
    # that the "new errors" check compares against.
    blocked = _is_blocked(session_id, guardrails_result)

    # Push trace to the explorer - don't block on its response
    _run_in_background(
        session_store.add_message_to_session(
            session_id=session_id,
            message=message,
            guardrails_result=guardrails_result,
        )
    )

    if blocked:
        return (
            _blocked_error_response(
                request_json.get("id"), guardrails_result["errors"]
            ),
            True,
        )
    return request_json, False


async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict:
    """
    Hook to process the response JSON after receiving it from the MCP server.

    Args:
        session_id (str): The session ID associated with the request.
        response_json (dict): The response JSON to be processed.

    Returns:
        dict: The response JSON is returned if no guardrail is violated
        else an error dict is returned.
    """
    # NOTE(review): a response without a "result" member raises
    # AttributeError here, as it did before this change.
    message = {
        "role": "tool",
        "tool_call_id": f"call_{response_json.get('id')}",
        "content": response_json.get(MCP_RESULT).get("content"),
        "error": response_json.get(MCP_RESULT).get("error"),
    }

    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    result = response_json
    if _is_blocked(session_id, guardrails_result):
        # If the request is blocked, return a message indicating the block reason.
        result = _blocked_error_response(
            response_json.get("id"), guardrails_result["errors"]
        )

    # Push trace to the explorer - don't block on its response
    _run_in_background(
        session_store.add_message_to_session(session_id, message, guardrails_result)
    )
    return result
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
"""
Convert localhost or 127.0.0.1 in an address to host.docker.internal
Args:
mcp_server_base_url (str): The original server address from the header
Returns:
str: Modified server address with localhost references changed to host.docker.internal
"""
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
# Replace localhost or 127.0.0.1 with host.docker.internal
modified_address = re.sub(
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
r"\1host.docker.internal\2",
mcp_server_base_url,
)
return modified_address
return mcp_server_base_url
async def _handle_endpoint_event(
sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
) -> Tuple[bytes, str]:
"""
Handle the endpoint event type and modify the data accordingly.
For endpoint events, we need to rewrite the endpoint to use our gateway.
Args:
sse (ServerSentEvent): The original SSE object.
sse_header_attributes (SseHeaderAttributes): The header attributes from the request.
Returns:
bytes: Modified SSE data as bytes.
str: session_id extracted from the data.
"""
# Extract session_id
match = re.search(r"session_id=([^&\s]+)", sse.data)
if match:
session_id = match.group(1)
# Initialize this session in our store if needed
if not session_store.session_exists(session_id):
await session_store.initialize_session(session_id, sse_header_attributes)
# Rewrite the endpoint to use our gateway
modified_data = sse.data.replace(
"/messages/?session_id=",
"/api/v1/gateway/mcp/sse/messages/?session_id=",
)
event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode("utf-8")
return event_bytes, session_id
async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
    """
    Handle the message event type.

    Records the MCP server name and the advertised tools on the session and
    runs responses to tool calls through the guardrail hook. Any failure
    while inspecting the message is logged and the event is relayed as
    received.

    Args:
        session_id (str): The session ID associated with the request.
        sse (ServerSentEvent): The original SSE object.

    Returns:
        bytes: The SSE frame to relay to the client.
    """
    passthrough_frame = f"event: {sse.event}\ndata: {sse.data}\n\n".encode("utf-8")
    session = session_store.get_session(session_id)
    try:
        payload = json.loads(sse.data)

        result_section = payload.get(MCP_RESULT)
        if result_section and result_section.get(MCP_SERVER_INFO):
            session.metadata["mcp_server_name"] = result_section.get(
                MCP_SERVER_INFO
            ).get("name", "")

        # Which request does this response belong to?
        originating_method = session.id_to_method_mapping.get(payload.get("id"))
        if originating_method == MCP_LIST_TOOLS:
            session.metadata["tools"] = payload.get(MCP_RESULT).get("tools")
        elif originating_method == MCP_TOOL_CALL:
            # The hook returns the payload unchanged when no guardrail is
            # violated, otherwise an error payload that replaces it.
            checked_payload = await _hook_tool_call_response(
                session_id=session_id,
                response_json=payload,
            )
            return f"event: {sse.event}\ndata: {json.dumps(checked_payload)}\n\n".encode(
                "utf-8"
            )
    except json.JSONDecodeError as e:
        print(
            f"[MCP SSE] Error parsing message JSON: {e}",
            flush=True,
        )
    except Exception as e:  # pylint: disable=broad-except
        print(
            f"[MCP SSE] Error processing message: {e}",
            flush=True,
        )
    return passthrough_frame
def _check_if_new_errors(session_id: str, guardrails_result: dict) -> bool:
    """Tell whether the guardrails result holds an error the session has not recorded yet."""
    recorded = session_store.get_session(session_id).annotations
    candidates = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    return any(candidate not in recorded for candidate in candidates)
async def _check_for_pending_error_messages(
    session_id: str, pending_error_messages_queue: asyncio.Queue
):
    """
    Poll the session once per second and enqueue its pending error messages.

    Each pending message is framed as an SSE "message" event. Errors during a
    poll are logged and polling continues; the coroutine only stops when it is
    cancelled.
    """
    try:
        while True:
            try:
                pending = await session_store.get_session(
                    session_id
                ).get_pending_error_messages()
                for payload in pending:
                    frame = f"event: message\ndata: {json.dumps(payload)}\n\n".encode(
                        "utf-8"
                    )
                    await pending_error_messages_queue.put(frame)
            except Exception as e:  # pylint: disable=broad-except
                print(f"Error checking for messages: {e}", flush=True)
            # Same pause whether or not the poll succeeded.
            await asyncio.sleep(1)
    except asyncio.CancelledError:
        # Task was cancelled, exit gracefully
        return
+3
View File
@@ -7,6 +7,7 @@ from starlette_compress import CompressMiddleware
from gateway.routes.anthropic import gateway as anthropic_gateway
from gateway.routes.gemini import gateway as gemini_gateway
from gateway.routes.open_ai import gateway as open_ai_gateway
from gateway.routes.mcp_sse import gateway as mcp_sse_gateway
app = fastapi.app = fastapi.FastAPI(
docs_url="/api/v1/gateway/docs",
@@ -30,6 +31,8 @@ router.include_router(anthropic_gateway, prefix="/gateway", tags=["anthropic_gat
router.include_router(gemini_gateway, prefix="/gateway", tags=["gemini_gateway"])
router.include_router(mcp_sse_gateway, prefix="/gateway", tags=["mcp_sse_gateway"])
app.include_router(router)
+1
View File
@@ -7,6 +7,7 @@ requires-python = ">=3.12"
dependencies = [
"fastapi==0.115.7",
"httpx==0.28.1",
"httpx-sse==0.4.0",
"invariant-sdk>=0.0.11",
"starlette-compress==1.4.0",
"uvicorn==0.34.0"
@@ -195,7 +195,6 @@ async def test_response_with_tool_call(explorer_api_url, gateway_url, push_to_ex
assert response[1].role == "assistant"
assert response[1].stop_reason == "end_turn"
assert city in response[1].content[0].text.lower()
responses.append(response)
if push_to_explorer:
@@ -123,7 +123,9 @@ async def test_chat_completion_with_tool_call_without_streaming(
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
)
assert trace["messages"] == expected_messages
assert trace["messages"][:2] == expected_messages[:2]
assert "15°C" in trace["messages"][2]["content"]
assert trace["messages"][2]["role"] == "tool"
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@@ -230,4 +232,6 @@ async def test_chat_completion_with_tool_call_with_streaming(
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
)
assert trace["messages"] == expected_messages
assert trace["messages"][:2] == expected_messages[:2]
assert "15°C" in trace["messages"][2]["content"]
assert trace["messages"][2]["role"] == "tool"