mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Cleanup some code in test_mcp.py
This commit is contained in:
@@ -48,6 +48,15 @@ def _get_streamable_server_base_url(transport: str) -> str:
|
||||
return f"http://{host_info['host']}:{host_info['port']}"
|
||||
|
||||
|
||||
def _get_server_base_url(transport: str) -> str:
|
||||
if transport == "sse":
|
||||
return _get_mcp_sse_server_base_url()
|
||||
elif transport.startswith("streamable-"):
|
||||
return _get_streamable_server_base_url(transport)
|
||||
else:
|
||||
raise ValueError(f"Unknown transport: {transport}")
|
||||
|
||||
|
||||
def _get_headers(
|
||||
server_base_url: str, project_name: str, push_to_explorer: bool = True
|
||||
) -> dict[str, str]:
|
||||
@@ -58,6 +67,35 @@ def _get_headers(
|
||||
}
|
||||
|
||||
|
||||
async def _invoke_mcp_tool(
|
||||
transport, gateway_url, project_name, tool_name, tool_args, whl=None, push=True
|
||||
):
|
||||
if transport == "stdio":
|
||||
return await mcp_stdio_client_run(
|
||||
whl,
|
||||
project_name,
|
||||
"resources/mcp/stdio/messenger_server/main.py",
|
||||
push,
|
||||
tool_name,
|
||||
tool_args,
|
||||
)
|
||||
elif transport == "sse":
|
||||
return await mcp_sse_client_run(
|
||||
f"{gateway_url}/api/v1/gateway/mcp/sse",
|
||||
push,
|
||||
tool_name,
|
||||
tool_args,
|
||||
headers=_get_headers(_get_server_base_url(transport), project_name, push),
|
||||
)
|
||||
return await mcp_streamable_client_run(
|
||||
f"{gateway_url}/api/v1/gateway/mcp/streamable",
|
||||
push,
|
||||
tool_name,
|
||||
tool_args,
|
||||
headers=_get_headers(_get_server_base_url(transport), project_name, push),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize(
|
||||
@@ -81,34 +119,15 @@ async def test_mcp_with_gateway(
|
||||
project_name = "test-mcp-" + str(uuid.uuid4())
|
||||
|
||||
# Run the MCP client and make the tool call.
|
||||
if transport == "sse":
|
||||
result = await mcp_sse_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/sse",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
|
||||
)
|
||||
elif transport == "stdio":
|
||||
result = await mcp_stdio_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
server_script_path="resources/mcp/stdio/messenger_server/main.py",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
metadata_keys={"my-custom-key": "value1", "my-custom-key-2": "value2"},
|
||||
)
|
||||
else:
|
||||
result = await mcp_streamable_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/streamable",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(
|
||||
_get_streamable_server_base_url(transport), project_name, True
|
||||
),
|
||||
)
|
||||
result = await _invoke_mcp_tool(
|
||||
transport,
|
||||
gateway_url,
|
||||
project_name,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
whl=invariant_gateway_package_whl_file,
|
||||
push=True,
|
||||
)
|
||||
|
||||
assert result.isError is False
|
||||
assert (
|
||||
@@ -203,33 +222,15 @@ async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
)
|
||||
|
||||
# Run the MCP client and make the tool call.
|
||||
if transport == "sse":
|
||||
result = await mcp_sse_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/sse",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
|
||||
)
|
||||
elif transport == "stdio":
|
||||
result = await mcp_stdio_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
server_script_path="resources/mcp/stdio/messenger_server/main.py",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
)
|
||||
else:
|
||||
result = await mcp_streamable_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/streamable",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(
|
||||
_get_streamable_server_base_url(transport), project_name, True
|
||||
),
|
||||
)
|
||||
result = await _invoke_mcp_tool(
|
||||
transport,
|
||||
gateway_url,
|
||||
project_name,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
whl=invariant_gateway_package_whl_file,
|
||||
push=True,
|
||||
)
|
||||
|
||||
assert result.isError is False
|
||||
assert (
|
||||
@@ -658,33 +659,16 @@ async def test_mcp_tool_list_blocking(
|
||||
return
|
||||
|
||||
# Run the MCP client and make the tools/list call.
|
||||
if transport == "sse":
|
||||
tools_result = await mcp_sse_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/sse",
|
||||
push_to_explorer=True,
|
||||
tool_name="tools/list",
|
||||
tool_args={},
|
||||
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
|
||||
)
|
||||
elif transport == "stdio":
|
||||
tools_result = await mcp_stdio_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
server_script_path="resources/mcp/stdio/messenger_server/main.py",
|
||||
push_to_explorer=True,
|
||||
tool_name="tools/list",
|
||||
tool_args={},
|
||||
)
|
||||
else:
|
||||
tools_result = await mcp_streamable_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/streamable",
|
||||
push_to_explorer=True,
|
||||
tool_name="tools/list",
|
||||
tool_args={},
|
||||
headers=_get_headers(
|
||||
_get_streamable_server_base_url(transport), project_name, True
|
||||
),
|
||||
)
|
||||
# Run the MCP client and make the tool call.
|
||||
tools_result = await _invoke_mcp_tool(
|
||||
transport,
|
||||
gateway_url,
|
||||
project_name,
|
||||
tool_name="tools/list",
|
||||
tool_args={},
|
||||
whl=invariant_gateway_package_whl_file,
|
||||
push=True,
|
||||
)
|
||||
assert "blocked_get_last_message_from_user" in str(tools_result), (
|
||||
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
|
||||
+ str(tools_result)
|
||||
|
||||
Reference in New Issue
Block a user