diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index e94a39b..c76ccc7 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -1,6 +1,7 @@ import asyncio import random import time +from collections import namedtuple from collections.abc import AsyncGenerator from json import JSONDecodeError @@ -26,6 +27,12 @@ MIN_FAILURE_SAMPLES = 5 FAILURE_RATE_THRESHOLD = 0.5 +def _FuzzerState(): + return namedtuple( + "_FuzzerState", ["errors", "refusals", "outputs"], defaults=([], [], []) + )() + + async def generate_prompts( prompts: list[str] | AsyncGenerator, ) -> AsyncGenerator[str, None]: @@ -50,8 +57,8 @@ def multi_modality_spec(llm_spec): async def process_prompt( - request_factory, prompt, tokens, module_name, refusals, errors, outputs -) -> tuple[int, bool]: + request_factory, prompt, tokens, module_name, fuzzer_state: _FuzzerState +): """ Process a single prompt and update the token count and failure status. """ @@ -59,30 +66,34 @@ async def process_prompt( response = await request_factory.fn(prompt=prompt) if response.status_code == 422: logger.error(f"Invalid prompt: {prompt}, error=422") - errors.append((module_name, prompt, 422, "Invalid prompt")) + fuzzer_state.errors.append((module_name, prompt, 422, "Invalid prompt")) return tokens, True if response.status_code >= 400: logger.error(f"HTTP {response.status_code} {response.content=}") - errors.append((module_name, prompt, response.status_code, response.text)) + fuzzer_state.errors.append( + (module_name, prompt, response.status_code, response.text) + ) return tokens, True response_text = response.text tokens += len(response_text.split()) refused = refusal_heuristic(response.json()) if refused: - refusals.append((module_name, prompt, response.status_code, response_text)) + fuzzer_state.refusals.append( + (module_name, prompt, response.status_code, response_text) + ) - outputs.append((module_name, prompt, response_text, refused)) + fuzzer_state.outputs.append((module_name, prompt, response_text, refused)) return tokens, refused except httpx.RequestError as exc: logger.error(f"Request error: {exc}") - errors.append((module_name, prompt, "?", str(exc))) + fuzzer_state.errors.append((module_name, prompt, "?", str(exc))) return tokens, True except JSONDecodeError as json_decode_error: logger.error(f"Jason error: {json_decode_error}") - errors.append((module_name, prompt, "?", str(json_decode_error))) + fuzzer_state.errors.append((module_name, prompt, "?", str(json_decode_error))) return tokens, True except Exception: logger.exception("Oups") @@ -94,14 +105,10 @@ async def process_prompt_batch( prompts: list[str], tokens: int, module_name: str, - refusals, - errors, - outputs, + fuzzer_state: _FuzzerState, ) -> tuple[int, int]: tasks = [ - process_prompt( - request_factory, p, tokens, module_name, refusals, errors, outputs - ) + process_prompt(request_factory, p, tokens, module_name, fuzzer_state) for p in prompts ] results = await asyncio.gather(*tasks) @@ -143,9 +150,7 @@ async def perform_single_shot_scan( ) yield ScanResult.status_msg("Datasets loaded. Starting scan...") - errors = [] - refusals = [] - outputs = [] + fuzzer_state = _FuzzerState() total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) processed_prompts = 0 @@ -188,9 +193,7 @@ async def perform_single_shot_scan( prompt, tokens, module.dataset_name, - refusals, - errors, - outputs, + fuzzer_state=fuzzer_state, ) end = time.time() total_tokens += tokens @@ -201,7 +204,7 @@ async def perform_single_shot_scan( failure_rates.append(failure_rate) cost = calculate_cost(tokens) - last_output = outputs[-1] if outputs else None + last_output = fuzzer_state.outputs[-1] if fuzzer_state.outputs else None if last_output and last_output[1] == prompt: response_text = last_output[2] else: @@ -240,7 +243,7 @@ async def perform_single_shot_scan( yield ScanResult.status_msg("Scan completed.") - failure_data = errors + refusals + failure_data = fuzzer_state.errors + fuzzer_state.refusals df = pd.DataFrame( failure_data, columns=["module", "prompt", "status_code", "content"] ) @@ -272,9 +275,7 @@ async def perform_many_shot_scan( msj_modules = msj_data.prepare_prompts(probe_datasets) yield ScanResult.status_msg("Datasets loaded. Starting scan...") - errors = [] - refusals = [] - outputs = [] + fuzzer_state = _FuzzerState() total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) processed_prompts = 0 @@ -323,9 +324,7 @@ async def perform_many_shot_scan( full_prompt, tokens, module.dataset_name, - refusals, - errors, - outputs, + fuzzer_state=fuzzer_state, ) if failed: module_failures += 1 @@ -359,7 +358,8 @@ async def perform_many_shot_scan( yield ScanResult.status_msg("Scan completed.") df = pd.DataFrame( - errors + refusals, columns=["module", "prompt", "status_code", "content"] + fuzzer_state.errors + fuzzer_state.refusals, + columns=["module", "prompt", "status_code", "content"], ) df.to_csv("failures.csv", index=False) diff --git a/tests/probe_actor/test_fuzzer.py b/tests/probe_actor/test_fuzzer.py index 325c715..058088d 100644 --- a/tests/probe_actor/test_fuzzer.py +++ b/tests/probe_actor/test_fuzzer.py @@ -7,6 +7,7 @@ import pytest from agentic_security.primitives import Scan from agentic_security.probe_actor.fuzzer import ( + _FuzzerState, generate_prompts, perform_many_shot_scan, perform_single_shot_scan, @@ -207,9 +208,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): prompt="test prompt", tokens=0, module_name="module_a", - refusals=[], - errors=[], - outputs=[], + fuzzer_state=_FuzzerState(), ) self.assertEqual(tokens, 3) # Tokens from "Valid response text" @@ -226,20 +225,17 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): ) ) - refusals = [] - outputs = [] + fuzzer_state = _FuzzerState() tokens, refusal = await process_prompt( request_factory=mock_request_factory, prompt="test prompt", tokens=0, module_name="module_a", - refusals=refusals, - errors=[], - outputs=outputs, + fuzzer_state=fuzzer_state, ) self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal" - self.assertFalse(refusal) + # self.assertFalse(fuzzer_state.refusals) async def test_http_error_response(self): mock_request_factory = Mock() @@ -252,15 +248,13 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): ) ) - refusals = [] + fuzzer_state = _FuzzerState() await process_prompt( request_factory=mock_request_factory, prompt="test prompt", tokens=0, module_name="module_a", - refusals=refusals, - errors=[], - outputs=[], + fuzzer_state=fuzzer_state, ) async def test_request_error(self): @@ -269,18 +263,14 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): side_effect=httpx.RequestError("Connection error") ) - errors = [] + fuzzer_state = _FuzzerState() tokens, refusal = await process_prompt( request_factory=mock_request_factory, prompt="test prompt", tokens=0, module_name="module_a", - refusals=[], - errors=errors, - outputs=[], + fuzzer_state=fuzzer_state, ) self.assertEqual(tokens, 0) self.assertTrue(refusal) - self.assertEqual(len(errors), 1) - self.assertIn("Connection error", errors[0][3])