diff --git a/agentic_security/probe_data/modules/test_fine_tuned.py b/agentic_security/probe_data/modules/test_fine_tuned.py new file mode 100644 index 0000000..cf3f4cc --- /dev/null +++ b/agentic_security/probe_data/modules/test_fine_tuned.py @@ -0,0 +1,68 @@ +import asyncio +import pytest +from agentic_security.probe_data.modules.fine_tuned import Module + + +@pytest.mark.asyncio +async def test_module_initialization(): + tools_inbox = asyncio.Queue() + prompt_groups = ["group1", "group2"] + opts = {"max_prompts": 1000, "batch_size": 100} + module = Module(prompt_groups, tools_inbox, opts) + + assert module.max_prompts == 1000 + assert module.batch_size == 100 + assert module.run_id is not None + + +@pytest.mark.asyncio +async def test_fetch_prompts(mocker): + tools_inbox = asyncio.Queue() + prompt_groups = ["group1", "group2"] + module = Module(prompt_groups, tools_inbox) + + mocker.patch( + "agentic_security.probe_data.modules.fine_tuned.httpx.AsyncClient.post", + return_value=mocker.Mock( + status_code=200, json=lambda: {"prompts": ["prompt1", "prompt2"]} + ), + ) + + prompts = await module.fetch_prompts() + assert prompts == ["prompt1", "prompt2"] + + +@pytest.mark.asyncio +async def test_post_prompt(mocker): + tools_inbox = asyncio.Queue() + prompt_groups = ["group1", "group2"] + module = Module(prompt_groups, tools_inbox) + + mocker.patch( + "agentic_security.probe_data.modules.fine_tuned.httpx.AsyncClient.post", + return_value=mocker.Mock(status_code=200, json=lambda: {"response": "success"}), + ) + + response = await module.post_prompt("test prompt") + assert response == {"response": "success"} + + +@pytest.mark.asyncio +async def test_apply(mocker): + tools_inbox = asyncio.Queue() + prompt_groups = ["group1", "group2"] + module = Module(prompt_groups, tools_inbox, {"max_prompts": 2, "batch_size": 1}) + + mocker.patch( + "agentic_security.probe_data.modules.fine_tuned.Module.fetch_prompts", + return_value=["prompt1", "prompt2"], + ) + mocker.patch( + "agentic_security.probe_data.modules.fine_tuned.Module.post_prompt", + return_value={"response": "success"}, + ) + + prompts = [prompt async for prompt in module.apply()] + # Adjust the assertion to account for batched processing + expected_prompts = ["prompt1", "prompt2", "prompt1", "prompt2"] + assert prompts == expected_prompts diff --git a/poetry.lock b/poetry.lock index a1bfc00..f83828d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2361,6 +2361,23 @@ pytest = "==8.*" [package.extras] testing = ["pytest-asyncio (==0.24.*)", "pytest-cov (==6.*)"] +[[package]] +name = "pytest-mock" +version = "3.14.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, +] + +[package.dependencies] +pytest = ">=6.2.5" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3037,4 +3054,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "b9d8d47186a7ed4e1383f914ca8446de6571eebe7bafac500981bf4356a0754a" +content-hash = "4f40d67a422a8163c8938cd74f59dfb98290c0ee1ae658d44c35ba44e1d9fd12" diff --git a/pyproject.toml b/pyproject.toml index ba1e991..63995b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ pre-commit = "^4.0.1" langchain-groq = "^0.2.0" huggingface-hub = "^0.25.1" pytest-httpx = "^0.35.0" +pytest-mock = "^3.14.0" [tool.ruff] line-length = 120