feat(add unit tests):

This commit is contained in:
Alexander Myasoedov
2025-01-03 00:10:34 +02:00
parent ffd7d710f1
commit 0a536ee999
3 changed files with 87 additions and 1 deletions
@@ -0,0 +1,68 @@
import asyncio
import pytest
from agentic_security.probe_data.modules.fine_tuned import Module
@pytest.mark.asyncio
async def test_module_initialization():
tools_inbox = asyncio.Queue()
prompt_groups = ["group1", "group2"]
opts = {"max_prompts": 1000, "batch_size": 100}
module = Module(prompt_groups, tools_inbox, opts)
assert module.max_prompts == 1000
assert module.batch_size == 100
assert module.run_id is not None
@pytest.mark.asyncio
async def test_fetch_prompts(mocker):
tools_inbox = asyncio.Queue()
prompt_groups = ["group1", "group2"]
module = Module(prompt_groups, tools_inbox)
mocker.patch(
"agentic_security.probe_data.modules.fine_tuned.httpx.AsyncClient.post",
return_value=mocker.Mock(
status_code=200, json=lambda: {"prompts": ["prompt1", "prompt2"]}
),
)
prompts = await module.fetch_prompts()
assert prompts == ["prompt1", "prompt2"]
@pytest.mark.asyncio
async def test_post_prompt(mocker):
tools_inbox = asyncio.Queue()
prompt_groups = ["group1", "group2"]
module = Module(prompt_groups, tools_inbox)
mocker.patch(
"agentic_security.probe_data.modules.fine_tuned.httpx.AsyncClient.post",
return_value=mocker.Mock(status_code=200, json=lambda: {"response": "success"}),
)
response = await module.post_prompt("test prompt")
assert response == {"response": "success"}
@pytest.mark.asyncio
async def test_apply(mocker):
tools_inbox = asyncio.Queue()
prompt_groups = ["group1", "group2"]
module = Module(prompt_groups, tools_inbox, {"max_prompts": 2, "batch_size": 1})
mocker.patch(
"agentic_security.probe_data.modules.fine_tuned.Module.fetch_prompts",
return_value=["prompt1", "prompt2"],
)
mocker.patch(
"agentic_security.probe_data.modules.fine_tuned.Module.post_prompt",
return_value={"response": "success"},
)
prompts = [prompt async for prompt in module.apply()]
# Adjust the assertion to account for batched processing
expected_prompts = ["prompt1", "prompt2", "prompt1", "prompt2"]
assert prompts == expected_prompts
Generated
+18 -1
View File
@@ -2361,6 +2361,23 @@ pytest = "==8.*"
[package.extras]
testing = ["pytest-asyncio (==0.24.*)", "pytest-cov (==6.*)"]
[[package]]
name = "pytest-mock"
version = "3.14.0"
description = "Thin-wrapper around the mock package for easier use with pytest"
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
{file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
]
[package.dependencies]
pytest = ">=6.2.5"
[package.extras]
dev = ["pre-commit", "pytest-asyncio", "tox"]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@@ -3037,4 +3054,4 @@ propcache = ">=0.2.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "b9d8d47186a7ed4e1383f914ca8446de6571eebe7bafac500981bf4356a0754a"
content-hash = "4f40d67a422a8163c8938cd74f59dfb98290c0ee1ae658d44c35ba44e1d9fd12"
+1
View File
@@ -56,6 +56,7 @@ pre-commit = "^4.0.1"
langchain-groq = "^0.2.0"
huggingface-hub = "^0.25.1"
pytest-httpx = "^0.35.0"
pytest-mock = "^3.14.0"
[tool.ruff]
line-length = 120