Cleanup some code in test_mcp.py

This commit is contained in:
Hemang
2025-05-27 23:01:59 +02:00
committed by Hemang Sarkar
parent bfb57029e6
commit 96826fa06d

View File

@@ -48,6 +48,15 @@ def _get_streamable_server_base_url(transport: str) -> str:
return f"http://{host_info['host']}:{host_info['port']}"
def _get_server_base_url(transport: str) -> str:
if transport == "sse":
return _get_mcp_sse_server_base_url()
elif transport.startswith("streamable-"):
return _get_streamable_server_base_url(transport)
else:
raise ValueError(f"Unknown transport: {transport}")
def _get_headers(
server_base_url: str, project_name: str, push_to_explorer: bool = True
) -> dict[str, str]:
@@ -58,6 +67,35 @@ def _get_headers(
}
async def _invoke_mcp_tool(
transport, gateway_url, project_name, tool_name, tool_args, whl=None, push=True
):
if transport == "stdio":
return await mcp_stdio_client_run(
whl,
project_name,
"resources/mcp/stdio/messenger_server/main.py",
push,
tool_name,
tool_args,
)
elif transport == "sse":
return await mcp_sse_client_run(
f"{gateway_url}/api/v1/gateway/mcp/sse",
push,
tool_name,
tool_args,
headers=_get_headers(_get_server_base_url(transport), project_name, push),
)
return await mcp_streamable_client_run(
f"{gateway_url}/api/v1/gateway/mcp/streamable",
push,
tool_name,
tool_args,
headers=_get_headers(_get_server_base_url(transport), project_name, push),
)
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize(
@@ -81,34 +119,15 @@ async def test_mcp_with_gateway(
project_name = "test-mcp-" + str(uuid.uuid4())
# Run the MCP client and make the tool call.
if transport == "sse":
result = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
)
elif transport == "stdio":
result = await mcp_stdio_client_run(
invariant_gateway_package_whl_file,
project_name,
server_script_path="resources/mcp/stdio/messenger_server/main.py",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
metadata_keys={"my-custom-key": "value1", "my-custom-key-2": "value2"},
)
else:
result = await mcp_streamable_client_run(
gateway_url + "/api/v1/gateway/mcp/streamable",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(
_get_streamable_server_base_url(transport), project_name, True
),
)
result = await _invoke_mcp_tool(
transport,
gateway_url,
project_name,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
whl=invariant_gateway_package_whl_file,
push=True,
)
assert result.isError is False
assert (
@@ -203,33 +222,15 @@ async def test_mcp_with_gateway_and_logging_guardrails(
)
# Run the MCP client and make the tool call.
if transport == "sse":
result = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
)
elif transport == "stdio":
result = await mcp_stdio_client_run(
invariant_gateway_package_whl_file,
project_name,
server_script_path="resources/mcp/stdio/messenger_server/main.py",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
)
else:
result = await mcp_streamable_client_run(
gateway_url + "/api/v1/gateway/mcp/streamable",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(
_get_streamable_server_base_url(transport), project_name, True
),
)
result = await _invoke_mcp_tool(
transport,
gateway_url,
project_name,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
whl=invariant_gateway_package_whl_file,
push=True,
)
assert result.isError is False
assert (
@@ -658,33 +659,16 @@ async def test_mcp_tool_list_blocking(
return
# Run the MCP client and make the tools/list call.
if transport == "sse":
tools_result = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
push_to_explorer=True,
tool_name="tools/list",
tool_args={},
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
)
elif transport == "stdio":
tools_result = await mcp_stdio_client_run(
invariant_gateway_package_whl_file,
project_name,
server_script_path="resources/mcp/stdio/messenger_server/main.py",
push_to_explorer=True,
tool_name="tools/list",
tool_args={},
)
else:
tools_result = await mcp_streamable_client_run(
gateway_url + "/api/v1/gateway/mcp/streamable",
push_to_explorer=True,
tool_name="tools/list",
tool_args={},
headers=_get_headers(
_get_streamable_server_base_url(transport), project_name, True
),
)
# Run the MCP client and make the tool call.
tools_result = await _invoke_mcp_tool(
transport,
gateway_url,
project_name,
tool_name="tools/list",
tool_args={},
whl=invariant_gateway_package_whl_file,
push=True,
)
assert "blocked_get_last_message_from_user" in str(tools_result), (
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
+ str(tools_result)