mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 14:19:55 +02:00
286 lines
8.3 KiB
Python
286 lines
8.3 KiB
Python
import asyncio
|
|
import unittest
|
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from agentic_security.primitives import Scan
|
|
from agentic_security.probe_actor.fuzzer import (
|
|
FuzzerState,
|
|
generate_prompts,
|
|
perform_many_shot_scan,
|
|
perform_single_shot_scan,
|
|
process_prompt,
|
|
scan_router,
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_generate_prompts_with_list():
|
|
prompts = ["prompt1", "prompt2", "prompt3"]
|
|
results = [p async for p in generate_prompts(prompts)]
|
|
assert results == prompts
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_generate_prompts_with_async_generator():
|
|
async def async_gen():
|
|
for i in range(3):
|
|
yield f"prompt{i}"
|
|
|
|
results = [p async for p in generate_prompts(async_gen())]
|
|
assert results == ["prompt0", "prompt1", "prompt2"]
|
|
|
|
|
|
async def assert_scan(generator, messages):
|
|
results = [r async for r in generator]
|
|
|
|
for m in messages:
|
|
found = False
|
|
for r in results:
|
|
if m in r:
|
|
found = True
|
|
break
|
|
assert found, f"Message '{m}' not found in results. Results: {results}"
|
|
return results
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("agentic_security.probe_data.data.prepare_prompts")
|
|
async def test_perform_single_shot_scan_success(prepare_prompts_mock):
|
|
# Mock prompt modules
|
|
prepare_prompts_mock.return_value = [
|
|
MagicMock(
|
|
dataset_name="test_module",
|
|
prompts=["test_prompt1", "test_prompt2"],
|
|
lazy=False,
|
|
)
|
|
]
|
|
|
|
# Mock request_factory
|
|
mock_response = AsyncMock()
|
|
mock_response.fn.return_value = AsyncMock(
|
|
status_code=200, text="response text", json=lambda: {}
|
|
)
|
|
request_factory = mock_response
|
|
|
|
async_gen = perform_single_shot_scan(
|
|
request_factory=request_factory,
|
|
max_budget=100,
|
|
datasets=[{"dataset_name": "test", "selected": True}],
|
|
optimize=False,
|
|
)
|
|
|
|
await assert_scan(async_gen, ["Loading", "Scan completed."])
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("agentic_security.probe_data.msj_data.prepare_prompts")
|
|
@patch("agentic_security.probe_data.data.prepare_prompts")
|
|
async def test_perform_many_shot_scan_probe_injection(
|
|
prepare_prompts_mock, msj_prepare_prompts_mock
|
|
):
|
|
# Mock main and probe prompt modules
|
|
prepare_prompts_mock.side_effect = [
|
|
[MagicMock(dataset_name="main_module", prompts=["main_prompt1"], lazy=False)],
|
|
[MagicMock(dataset_name="probe_module", prompts=["probe_prompt1"], lazy=False)],
|
|
]
|
|
|
|
msj_prepare_prompts_mock.return_value = [
|
|
MagicMock(
|
|
dataset_name="msj_probe_module", prompts=["msj_probe_prompt"], lazy=False
|
|
)
|
|
]
|
|
|
|
# Mock request_factory
|
|
mock_response = AsyncMock()
|
|
mock_response.fn.side_effect = [
|
|
AsyncMock(status_code=200, text="main response", json=lambda: {}),
|
|
AsyncMock(status_code=200, text="probe response", json=lambda: {}),
|
|
]
|
|
request_factory = mock_response
|
|
|
|
async_gen = perform_many_shot_scan(
|
|
request_factory=request_factory,
|
|
max_budget=100,
|
|
datasets=[{"dataset_name": "main", "selected": True}],
|
|
probe_datasets=[{"dataset_name": "probe", "selected": True}],
|
|
probe_frequency=1.0, # Always inject probes
|
|
optimize=False,
|
|
)
|
|
|
|
await assert_scan(async_gen, ["Loading", "Scan completed."])
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("agentic_security.probe_data.data.prepare_prompts")
|
|
async def test_scan_router_single_shot(prepare_prompts_mock):
|
|
prepare_prompts_mock.return_value = []
|
|
|
|
request_factory = AsyncMock()
|
|
scan_params = Scan(
|
|
maxBudget=100,
|
|
llmSpec="test",
|
|
datasets=[],
|
|
probe_datasets=[],
|
|
enableMultiStepAttack=False,
|
|
optimize=False,
|
|
)
|
|
|
|
gen = scan_router(
|
|
request_factory=request_factory,
|
|
scan_parameters=scan_params,
|
|
)
|
|
await assert_scan(gen, ["Loading", "Scan completed."])
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("agentic_security.probe_data.data.prepare_prompts")
|
|
async def test_scan_router_many_shot(prepare_prompts_mock):
|
|
prepare_prompts_mock.return_value = []
|
|
|
|
request_factory = AsyncMock()
|
|
scan_params = Scan(
|
|
maxBudget=100,
|
|
datasets=[],
|
|
llmSpec="test",
|
|
probeDatasets=[],
|
|
enableMultiStepAttack=True,
|
|
optimize=False,
|
|
)
|
|
|
|
gen = scan_router(
|
|
request_factory=request_factory,
|
|
scan_parameters=scan_params,
|
|
)
|
|
assert gen is not None
|
|
|
|
await assert_scan(gen, ["Loading", "Scan completed."])
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_perform_single_shot_scan_stop_event():
|
|
stop_event = asyncio.Event()
|
|
stop_event.set() # Pre-set to simulate user stopping the scan
|
|
|
|
async def request_mock(*args, **kwargs):
|
|
return AsyncMock(status_code=200, text="response text", json=lambda: {})
|
|
|
|
async_gen = perform_single_shot_scan(
|
|
request_factory=MagicMock(fn=request_mock),
|
|
max_budget=100,
|
|
datasets=[],
|
|
stop_event=stop_event,
|
|
)
|
|
|
|
await assert_scan(async_gen, ["Loading", "Scan completed."])
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_perform_many_shot_scan_stop_event():
|
|
stop_event = asyncio.Event()
|
|
stop_event.set() # Pre-set to simulate user stopping the scan
|
|
|
|
async def request_mock(*args, **kwargs):
|
|
return AsyncMock(status_code=200, text="response text", json=lambda: {})
|
|
|
|
async_gen = perform_many_shot_scan(
|
|
request_factory=MagicMock(fn=request_mock),
|
|
max_budget=100,
|
|
datasets=[],
|
|
probe_datasets=[],
|
|
stop_event=stop_event,
|
|
)
|
|
|
|
await assert_scan(async_gen, ["Loading", "Scan completed."])
|
|
|
|
|
|
def mock_refusal_heuristic(response_json):
|
|
return response_json.get("is_refusal", False)
|
|
|
|
|
|
class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
|
async def test_successful_response_no_refusal(self):
|
|
mock_request_factory = Mock()
|
|
mock_request_factory.fn = AsyncMock(
|
|
return_value=Mock(
|
|
status_code=200,
|
|
text="Valid response text",
|
|
json=Mock(return_value={"is_refusal": False}),
|
|
request="mock_request",
|
|
)
|
|
)
|
|
|
|
tokens, refusal = await process_prompt(
|
|
request_factory=mock_request_factory,
|
|
prompt="test prompt",
|
|
tokens=0,
|
|
module_name="module_a",
|
|
fuzzer_state=FuzzerState(),
|
|
)
|
|
|
|
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
|
|
self.assertTrue(refusal)
|
|
|
|
async def test_successful_response_with_refusal(self):
|
|
mock_request_factory = Mock()
|
|
mock_request_factory.fn = AsyncMock(
|
|
return_value=Mock(
|
|
status_code=200,
|
|
text="Response indicating refusal",
|
|
json=Mock(return_value={"is_refusal": True}),
|
|
request="mock_request",
|
|
)
|
|
)
|
|
|
|
fuzzer_state = FuzzerState()
|
|
tokens, refusal = await process_prompt(
|
|
request_factory=mock_request_factory,
|
|
prompt="test prompt",
|
|
tokens=0,
|
|
module_name="module_a",
|
|
fuzzer_state=fuzzer_state,
|
|
)
|
|
|
|
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
|
|
# self.assertFalse(fuzzer_state.refusals)
|
|
|
|
async def test_http_error_response(self):
|
|
mock_request_factory = Mock()
|
|
mock_request_factory.fn = AsyncMock(
|
|
return_value=Mock(
|
|
status_code=500,
|
|
text="Internal Server Error",
|
|
request="mock_request",
|
|
response=Mock(),
|
|
)
|
|
)
|
|
|
|
fuzzer_state = FuzzerState()
|
|
await process_prompt(
|
|
request_factory=mock_request_factory,
|
|
prompt="test prompt",
|
|
tokens=0,
|
|
module_name="module_a",
|
|
fuzzer_state=fuzzer_state,
|
|
)
|
|
|
|
async def test_request_error(self):
|
|
mock_request_factory = Mock()
|
|
mock_request_factory.fn = AsyncMock(
|
|
side_effect=httpx.RequestError("Connection error")
|
|
)
|
|
|
|
fuzzer_state = FuzzerState()
|
|
tokens, refusal = await process_prompt(
|
|
request_factory=mock_request_factory,
|
|
prompt="test prompt",
|
|
tokens=0,
|
|
module_name="module_a",
|
|
fuzzer_state=fuzzer_state,
|
|
)
|
|
|
|
self.assertEqual(tokens, 0)
|
|
self.assertTrue(refusal)
|