diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index c76ccc7..088bb70 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -1,9 +1,9 @@ import asyncio import random import time -from collections import namedtuple from collections.abc import AsyncGenerator from json import JSONDecodeError +from typing import Any import httpx import pandas as pd @@ -18,8 +18,7 @@ from agentic_security.probe_actor.refusal import refusal_heuristic from agentic_security.probe_data import audio_generator, image_generator, msj_data from agentic_security.probe_data.data import prepare_prompts -# TODO: full log file - +# Constants MAX_PROMPT_LENGTH = 2048 BUDGET_MULTIPLIER = 100_000_000 INITIAL_OPTIMIZER_POINTS = 25 @@ -27,15 +26,56 @@ MIN_FAILURE_SAMPLES = 5 FAILURE_RATE_THRESHOLD = 0.5 -def _FuzzerState(): - return namedtuple( - "_FuzzerState", ["errors", "refusals", "outputs"], defaults=([], [], []) - )() +class FuzzerState: + """Container for tracking scan results""" + + def __init__(self): + self.errors = [] + self.refusals = [] + self.outputs = [] + + def add_error( + self, + module_name: str, + prompt: str, + status_code: int | str, + error_msg: str, + ): + """Add an error to the state""" + self.errors.append((module_name, prompt, status_code, error_msg)) + + def add_refusal( + self, module_name: str, prompt: str, status_code: int, response_text: str + ): + """Add a refusal to the state""" + self.refusals.append((module_name, prompt, status_code, response_text)) + + def add_output( + self, module_name: str, prompt: str, response_text: str, refused: bool + ): + """Add an output to the state""" + self.outputs.append((module_name, prompt, response_text, refused)) + + def get_last_output(self, prompt: str) -> str | None: + """Get the last output for a given prompt""" + for output in reversed(self.outputs): + if output[1] == prompt: + return output[2] + return None + + def export_failures(self, filename: str = "failures.csv"): + """Export failures to a CSV file""" + failure_data = self.errors + self.refusals + df = pd.DataFrame( + failure_data, columns=["module", "prompt", "status_code", "content"] + ) + df.to_csv(filename, index=False) async def generate_prompts( prompts: list[str] | AsyncGenerator, ) -> AsyncGenerator[str, None]: + """Convert list of prompts or async generator to a unified async generator.""" if isinstance(prompts, list): for prompt in prompts: yield prompt @@ -44,7 +84,8 @@ async def generate_prompts( yield prompt -def multi_modality_spec(llm_spec): +def get_modality_adapter(llm_spec): + """Get the appropriate modality adapter based on the LLM spec.""" match llm_spec.modality: case Modality.IMAGE: return image_generator.RequestAdapter(llm_spec) @@ -57,46 +98,65 @@ def multi_modality_spec(llm_spec): async def process_prompt( - request_factory, prompt, tokens, module_name, fuzzer_state: _FuzzerState -): + request_factory, + prompt: str, + tokens: int, + module_name: str, + fuzzer_state: FuzzerState, +) -> tuple[int, bool]: """ Process a single prompt and update the token count and failure status. + + Args: + request_factory: The factory for creating requests + prompt: The prompt to process + tokens: Current token count + module_name: Name of the module being processed + fuzzer_state: State tracking object for the fuzzer + + Returns: + Tuple of (updated token count, whether the prompt resulted in a failure) """ try: response = await request_factory.fn(prompt=prompt) + + # Handle HTTP errors if response.status_code == 422: logger.error(f"Invalid prompt: {prompt}, error=422") - fuzzer_state.errors.append((module_name, prompt, 422, "Invalid prompt")) + fuzzer_state.add_error(module_name, prompt, 422, "Invalid prompt") return tokens, True if response.status_code >= 400: logger.error(f"HTTP {response.status_code} {response.content=}") - fuzzer_state.errors.append( - (module_name, prompt, response.status_code, response.text) + fuzzer_state.add_error( + module_name, prompt, response.status_code, response.text ) return tokens, True + + # Process successful response response_text = response.text tokens += len(response_text.split()) + # Check if the response indicates a refusal refused = refusal_heuristic(response.json()) if refused: - fuzzer_state.refusals.append( - (module_name, prompt, response.status_code, response_text) + fuzzer_state.add_refusal( + module_name, prompt, response.status_code, response_text ) - fuzzer_state.outputs.append((module_name, prompt, response_text, refused)) + fuzzer_state.add_output(module_name, prompt, response_text, refused) return tokens, refused except httpx.RequestError as exc: logger.error(f"Request error: {exc}") - fuzzer_state.errors.append((module_name, prompt, "?", str(exc))) + fuzzer_state.add_error(module_name, prompt, "?", str(exc)) return tokens, True except JSONDecodeError as json_decode_error: - logger.error(f"Jason error: {json_decode_error}") - fuzzer_state.errors.append((module_name, prompt, "?", str(json_decode_error))) + logger.error(f"JSON error: {json_decode_error}") + fuzzer_state.add_error(module_name, prompt, "?", str(json_decode_error)) return tokens, True - except Exception: - logger.exception("Oups") + except Exception as e: + logger.exception(f"Unexpected error: {e}") return tokens, False @@ -105,8 +165,21 @@ async def process_prompt_batch( prompts: list[str], tokens: int, module_name: str, - fuzzer_state: _FuzzerState, + fuzzer_state: FuzzerState, ) -> tuple[int, int]: + """ + Process a batch of prompts in parallel. + + Args: + request_factory: The factory for creating requests + prompts: List of prompts to process + tokens: Current token count + module_name: Name of the module being processed + fuzzer_state: State tracking object for the fuzzer + + Returns: + Tuple of (updated token count, number of failures) + """ tasks = [ process_prompt(request_factory, p, tokens, module_name, fuzzer_state) for p in prompts @@ -117,7 +190,124 @@ async def process_prompt_batch( return total_tokens, failures +async def scan_module( + request_factory, + module, + fuzzer_state: FuzzerState, + processed_prompts: int = 0, + total_prompts: int = 0, + max_budget: int = 0, + total_tokens: int = 0, + optimize: bool = False, + optimizer=None, + stop_event: asyncio.Event | None = None, +) -> AsyncGenerator[dict[str, Any], None]: + """ + Scan a single module. + + Args: + request_factory: The factory for creating requests + module: The prompt module to scan + fuzzer_state: State tracking object for the fuzzer + processed_prompts: Number of prompts processed so far + total_prompts: Total number of prompts to process + max_budget: Maximum token budget + total_tokens: Current token count + optimize: Whether to use optimization + optimizer: The optimizer to use + stop_event: Event to stop scanning + + Yields: + ScanResult objects as the scan progresses + """ + tokens = 0 + module_failures = 0 + module_prompts = 0 + failure_rates = [] + should_stop = False + + module_size = 0 if module.lazy else len(module.prompts) + logger.info(f"Scanning {module.dataset_name} {module_size}") + + async for prompt in generate_prompts(module.prompts): + if stop_event and stop_event.is_set(): + stop_event.clear() + logger.info("Scan stopped by user.") + yield ScanResult.status_msg("Scan stopped by user.") + return + + processed_prompts += 1 + module_prompts += 1 + + # Calculate progress based on total processed prompts + progress = 100 * processed_prompts / total_prompts if total_prompts else 0 + progress = progress % 100 + + total_tokens -= tokens + start = time.time() + + tokens, failed = await process_prompt( + request_factory, + prompt, + tokens, + module.dataset_name, + fuzzer_state=fuzzer_state, + ) + + end = time.time() + total_tokens += tokens + + if failed: + module_failures += 1 + + failure_rate = module_failures / max(module_prompts, 1) + failure_rates.append(failure_rate) + cost = calculate_cost(tokens) + + response_text = fuzzer_state.get_last_output(prompt) or "" + + yield ScanResult( + module=module.dataset_name, + tokens=round(tokens / 1000, 1), + cost=cost, + progress=round(progress, 2), + failureRate=round(failure_rate * 100, 2), + prompt=prompt[:MAX_PROMPT_LENGTH], + latency=end - start, + model=response_text, + ).model_dump_json() + + # Optimization logic + if optimize and optimizer and len(failure_rates) >= MIN_FAILURE_SAMPLES: + next_point = optimizer.ask() + optimizer.tell(next_point, -failure_rate) + best_failure_rate = -optimizer.get_result().fun + if best_failure_rate > FAILURE_RATE_THRESHOLD: + yield ScanResult.status_msg( + f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." + ) + should_stop = True + break + + # Budget check + if total_tokens > max_budget: + logger.info( + f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" + ) + yield ScanResult.status_msg( + f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" + ) + should_stop = True + break + + if should_stop: + break + + return + + async def with_error_handling(agen): + """Wrapper to handle errors in async generators.""" try: async for t in agen: yield t @@ -133,14 +323,29 @@ async def perform_single_shot_scan( max_budget: int, datasets: list[dict[str, str]] = [], tools_inbox=None, - optimize=False, - stop_event: asyncio.Event = None, + optimize: bool = False, + stop_event: asyncio.Event | None = None, secrets: dict[str, str] = {}, ) -> AsyncGenerator[str, None]: - """Perform a standard security scan.""" + """ + Perform a standard security scan across all selected datasets. + + Args: + request_factory: The factory for creating requests + max_budget: Maximum token budget + datasets: List of datasets to scan + tools_inbox: Tools inbox + optimize: Whether to use optimization + stop_event: Event to stop scanning + secrets: Secrets to use in the scan + + Yields: + ScanResult objects as the scan progresses + """ max_budget = max_budget * BUDGET_MULTIPLIER selected_datasets = [m for m in datasets if m["selected"]] - request_factory = multi_modality_spec(request_factory) + request_factory = get_modality_adapter(request_factory) + yield ScanResult.status_msg("Loading datasets...") prompt_modules = prepare_prompts( dataset_names=[m["dataset_name"] for m in selected_datasets], @@ -150,104 +355,44 @@ async def perform_single_shot_scan( ) yield ScanResult.status_msg("Datasets loaded. Starting scan...") - fuzzer_state = _FuzzerState() + fuzzer_state = FuzzerState() total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) processed_prompts = 0 optimizer = ( - Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25) + Optimizer( + [Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS + ) if optimize else None ) - failure_rates = [] total_tokens = 0 - tokens = 0 - should_stop = False for module in prompt_modules: - if should_stop: - break - tokens = 0 - module_failures = 0 + module_gen = scan_module( + request_factory=request_factory, + module=module, + fuzzer_state=fuzzer_state, + processed_prompts=processed_prompts, + total_prompts=total_prompts, + max_budget=max_budget, + total_tokens=total_tokens, + optimize=optimize, + optimizer=optimizer, + stop_event=stop_event, + ) + try: + async for result in module_gen: + yield result + except Exception: + logger.error("Module exception") + continue + # Update processed_prompts count module_size = 0 if module.lazy else len(module.prompts) - logger.info(f"Scanning {module.dataset_name} {module_size}") - module_prompts = 0 # Reset for each module - - async for prompt in generate_prompts(module.prompts): - if stop_event and stop_event.is_set(): - stop_event.clear() - logger.info("Scan stopped by user.") - yield ScanResult.status_msg("Scan stopped by user.") - return - - processed_prompts += 1 - module_prompts += 1 # Fixed increment syntax - # Calculate progress based on total processed prompts - progress = 100 * processed_prompts / total_prompts if total_prompts else 0 - progress = progress % 100 - - total_tokens -= tokens - start = time.time() - tokens, failed = await process_prompt( - request_factory, - prompt, - tokens, - module.dataset_name, - fuzzer_state=fuzzer_state, - ) - end = time.time() - total_tokens += tokens - - if failed: - module_failures += 1 - failure_rate = module_failures / max(module_prompts, 1) - failure_rates.append(failure_rate) - cost = calculate_cost(tokens) - - last_output = fuzzer_state.outputs[-1] if fuzzer_state.outputs else None - if last_output and last_output[1] == prompt: - response_text = last_output[2] - else: - response_text = "" - - yield ScanResult( - module=module.dataset_name, - tokens=round(tokens / 1000, 1), - cost=cost, - progress=round(progress, 2), - failureRate=round(failure_rate * 100, 2), - prompt=prompt[:MAX_PROMPT_LENGTH], - latency=end - start, - model=response_text, - ).model_dump_json() - - if optimize and len(failure_rates) >= 5: - next_point = optimizer.ask() - optimizer.tell(next_point, -failure_rate) - best_failure_rate = -optimizer.get_result().fun - if best_failure_rate > 0.5: - yield ScanResult.status_msg( - f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." - ) - should_stop = True - break - if total_tokens > max_budget: - logger.info( - f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" - ) - yield ScanResult.status_msg( - f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" - ) - should_stop = True - break + processed_prompts += module_size yield ScanResult.status_msg("Scan completed.") - - failure_data = fuzzer_state.errors + fuzzer_state.refusals - df = pd.DataFrame( - failure_data, columns=["module", "prompt", "status_code", "content"] - ) - df.to_csv("failures.csv", index=False) + fuzzer_state.export_failures("failures.csv") async def perform_many_shot_scan( @@ -256,14 +401,32 @@ async def perform_many_shot_scan( datasets: list[dict[str, str]] = [], probe_datasets: list[dict[str, str]] = [], tools_inbox=None, - optimize=False, - stop_event: asyncio.Event = None, + optimize: bool = False, + stop_event: asyncio.Event | None = None, probe_frequency: float = 0.2, max_ctx_length: int = 10_000, secrets: dict[str, str] = {}, ) -> AsyncGenerator[str, None]: - """Perform a multi-step security scan with probe injection.""" - request_factory = multi_modality_spec(request_factory) + """ + Perform a multi-step security scan with probe injection. + + Args: + request_factory: The factory for creating requests + max_budget: Maximum token budget + datasets: List of datasets to scan + probe_datasets: List of probe datasets to inject + tools_inbox: Tools inbox + optimize: Whether to use optimization + stop_event: Event to stop scanning + probe_frequency: Frequency of probe injection + max_ctx_length: Maximum context length + secrets: Secrets to use in the scan + + Yields: + ScanResult objects as the scan progresses + """ + request_factory = get_modality_adapter(request_factory) + # Load main and probe datasets yield ScanResult.status_msg("Loading datasets...") prompt_modules = prepare_prompts( @@ -275,12 +438,14 @@ async def perform_many_shot_scan( msj_modules = msj_data.prepare_prompts(probe_datasets) yield ScanResult.status_msg("Datasets loaded. Starting scan...") - fuzzer_state = _FuzzerState() + fuzzer_state = FuzzerState() total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) processed_prompts = 0 optimizer = ( - Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25) + Optimizer( + [Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS + ) if optimize else None ) @@ -297,6 +462,7 @@ async def perform_many_shot_scan( logger.info("Scan stopped by user.") yield ScanResult.status_msg("Scan stopped by user.") return + tokens = 0 processed_prompts += 1 progress = 100 * processed_prompts / total_prompts if total_prompts else 0 @@ -345,31 +511,38 @@ async def perform_many_shot_scan( prompt=prompt[:MAX_PROMPT_LENGTH], ).model_dump_json() - if optimize and len(failure_rates) >= 5: + if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES: next_point = optimizer.ask() optimizer.tell(next_point, -failure_rate) best_failure_rate = -optimizer.get_result().fun - if best_failure_rate > 0.5: + if best_failure_rate > FAILURE_RATE_THRESHOLD: yield ScanResult.status_msg( f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." ) break yield ScanResult.status_msg("Scan completed.") - - df = pd.DataFrame( - fuzzer_state.errors + fuzzer_state.refusals, - columns=["module", "prompt", "status_code", "content"], - ) - df.to_csv("failures.csv", index=False) + fuzzer_state.export_failures("failures.csv") def scan_router( request_factory, scan_parameters: Scan, tools_inbox=None, - stop_event: asyncio.Event = None, + stop_event: asyncio.Event | None = None, ): + """ + Route to the appropriate scan function based on scan parameters. + + Args: + request_factory: The factory for creating requests + scan_parameters: Scan parameters + tools_inbox: Tools inbox + stop_event: Event to stop scanning + + Returns: + Async generator of scan results + """ if scan_parameters.enableMultiStepAttack: return with_error_handling( perform_many_shot_scan( diff --git a/tests/probe_actor/test_fuzzer.py b/tests/probe_actor/test_fuzzer.py index 058088d..30cabba 100644 --- a/tests/probe_actor/test_fuzzer.py +++ b/tests/probe_actor/test_fuzzer.py @@ -7,7 +7,7 @@ import pytest from agentic_security.primitives import Scan from agentic_security.probe_actor.fuzzer import ( - _FuzzerState, + FuzzerState, generate_prompts, perform_many_shot_scan, perform_single_shot_scan, @@ -208,7 +208,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): prompt="test prompt", tokens=0, module_name="module_a", - fuzzer_state=_FuzzerState(), + fuzzer_state=FuzzerState(), ) self.assertEqual(tokens, 3) # Tokens from "Valid response text" @@ -225,7 +225,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): ) ) - fuzzer_state = _FuzzerState() + fuzzer_state = FuzzerState() tokens, refusal = await process_prompt( request_factory=mock_request_factory, prompt="test prompt", @@ -248,7 +248,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): ) ) - fuzzer_state = _FuzzerState() + fuzzer_state = FuzzerState() await process_prompt( request_factory=mock_request_factory, prompt="test prompt", @@ -263,7 +263,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): side_effect=httpx.RequestError("Connection error") ) - fuzzer_state = _FuzzerState() + fuzzer_state = FuzzerState() tokens, refusal = await process_prompt( request_factory=mock_request_factory, prompt="test prompt",