diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index a7b7180..889d7e6 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -3,9 +3,9 @@ import random import time from collections.abc import AsyncGenerator from json import JSONDecodeError +from typing import Any import httpx -import pandas as pd from skopt import Optimizer from skopt.space import Real @@ -14,11 +14,11 @@ from agentic_security.logutils import logger from agentic_security.primitives import Scan, ScanResult from agentic_security.probe_actor.cost_module import calculate_cost from agentic_security.probe_actor.refusal import refusal_heuristic +from agentic_security.probe_actor.state import FuzzerState from agentic_security.probe_data import audio_generator, image_generator, msj_data from agentic_security.probe_data.data import prepare_prompts -# TODO: full log file - +# Constants MAX_PROMPT_LENGTH = 2048 BUDGET_MULTIPLIER = 100_000_000 INITIAL_OPTIMIZER_POINTS = 25 @@ -76,7 +76,11 @@ def multi_modality_spec(llm_spec): async def process_prompt( - request_factory, prompt, tokens, module_name, refusals, errors, outputs + request_factory, + prompt: str, + tokens: int, + module_name: str, + fuzzer_state: FuzzerState, ) -> tuple[int, bool]: """ Processes a single prompt using the provided request factory and updates tracking lists. @@ -91,42 +95,52 @@ async def process_prompt( prompt (str): The input prompt to be processed. tokens (int): The current token count, which will be updated. module_name (str): The name of the module handling the request. - refusals (list): A list to store prompts that were refused. - errors (list): A list to store prompts that encountered errors. - outputs (list): A list to store processed prompt outputs. + fuzzer_state: State tracking object for the fuzzer Returns: tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused. """ try: response = await request_factory.fn(prompt=prompt) + + # Handle HTTP errors if response.status_code == 422: logger.error(f"Invalid prompt: {prompt}, error=422") - errors.append((module_name, prompt, 422, "Invalid prompt")) + fuzzer_state.add_error(module_name, prompt, 422, "Invalid prompt") return tokens, True if response.status_code >= 400: logger.error(f"HTTP {response.status_code} {response.content=}") - errors.append((module_name, prompt, response.status_code, response.text)) + fuzzer_state.add_error( + module_name, prompt, response.status_code, response.text + ) return tokens, True + + # Process successful response response_text = response.text tokens += len(response_text.split()) + # Check if the response indicates a refusal refused = refusal_heuristic(response.json()) if refused: - refusals.append((module_name, prompt, response.status_code, response_text)) + fuzzer_state.add_refusal( + module_name, prompt, response.status_code, response_text + ) - outputs.append((module_name, prompt, response_text, refused)) + fuzzer_state.add_output(module_name, prompt, response_text, refused) return tokens, refused except httpx.RequestError as exc: logger.error(f"Request error: {exc}") - errors.append((module_name, prompt, "?", str(exc))) + fuzzer_state.add_error(module_name, prompt, "?", str(exc)) return tokens, True except JSONDecodeError as json_decode_error: - logger.error(f"Jason error: {json_decode_error}") - errors.append((module_name, prompt, "?", str(json_decode_error))) + logger.error(f"JSON error: {json_decode_error}") + fuzzer_state.add_error(module_name, prompt, "?", str(json_decode_error)) return tokens, True + except Exception as e: + logger.exception(f"Unexpected error: {e}") + return tokens, False async def process_prompt_batch( @@ -134,9 +148,7 @@ async def process_prompt_batch( prompts: list[str], tokens: int, module_name: str, - refusals, - errors, - outputs, + fuzzer_state: FuzzerState, ) -> tuple[int, int]: """ Processes a batch of prompts asynchronously and aggregates the results. @@ -150,9 +162,7 @@ async def process_prompt_batch( prompts (list[str]): A list of input prompts to be processed. tokens (int): The initial token count, which will be updated. module_name (str): The name of the module handling the request. - refusals (list): A list to store prompts that were refused. - errors (list): A list to store prompts that encountered errors. - outputs (list): A list to store processed prompt outputs. + fuzzer_state: State tracking object for the fuzzer Returns: tuple[int, int]: @@ -160,9 +170,7 @@ async def process_prompt_batch( - Number of failed prompts. """ tasks = [ - process_prompt( - request_factory, p, tokens, module_name, refusals, errors, outputs - ) + process_prompt(request_factory, p, tokens, module_name, fuzzer_state) for p in prompts ] results = await asyncio.gather(*tasks) @@ -171,6 +179,129 @@ async def process_prompt_batch( return total_tokens, failures +async def scan_module( + request_factory, + module, + fuzzer_state: FuzzerState, + processed_prompts: int = 0, + total_prompts: int = 0, + max_budget: int = 0, + total_tokens: int = 0, + optimize: bool = False, + stop_event: asyncio.Event | None = None, +) -> AsyncGenerator[dict[str, Any], None]: + """ + Scan a single module. + + Args: + request_factory: The factory for creating requests + module: The prompt module to scan + fuzzer_state: State tracking object for the fuzzer + processed_prompts: Number of prompts processed so far + total_prompts: Total number of prompts to process + max_budget: Maximum token budget + total_tokens: Current token count + optimize: Whether to use optimization + stop_event: Event to stop scanning + + Yields: + ScanResult objects as the scan progresses + """ + tokens = 0 + module_failures = 0 + module_prompts = 0 + failure_rates = [] + should_stop = False + + # Initialize optimizer if optimization is enabled + optimizer = ( + Optimizer( + [Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS + ) + if optimize + else None + ) + + module_size = 0 if module.lazy else len(module.prompts) + logger.info(f"Scanning {module.dataset_name} {module_size}") + + async for prompt in generate_prompts(module.prompts): + if stop_event and stop_event.is_set(): + stop_event.clear() + logger.info("Scan stopped by user.") + yield ScanResult.status_msg("Scan stopped by user.") + return + + processed_prompts += 1 + module_prompts += 1 + + # Calculate progress based on total processed prompts + progress = 100 * processed_prompts / total_prompts if total_prompts else 0 + progress = progress % 100 + + total_tokens -= tokens + start = time.time() + + tokens, failed = await process_prompt( + request_factory, + prompt, + tokens, + module.dataset_name, + fuzzer_state=fuzzer_state, + ) + + end = time.time() + total_tokens += tokens + + if failed: + module_failures += 1 + + failure_rate = module_failures / max(module_prompts, 1) + failure_rates.append(failure_rate) + cost = calculate_cost(tokens) + + response_text = fuzzer_state.get_last_output(prompt) or "" + + yield ScanResult( + module=module.dataset_name, + tokens=round(tokens / 1000, 1), + cost=cost, + progress=round(progress, 2), + failureRate=round(failure_rate * 100, 2), + prompt=prompt[:MAX_PROMPT_LENGTH], + latency=end - start, + model=response_text, + ).model_dump_json() + + # Optimization logic + if optimize and optimizer and len(failure_rates) >= MIN_FAILURE_SAMPLES: + next_point = optimizer.ask() + optimizer.tell(next_point, -failure_rate) + best_failure_rate = -optimizer.get_result().fun + if best_failure_rate > FAILURE_RATE_THRESHOLD: + yield ScanResult.status_msg( + f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." + ) + should_stop = True + break + + # Budget check + if total_tokens > max_budget: + logger.info( + f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" + ) + yield ScanResult.status_msg( + f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" + ) + should_stop = True + break + + if should_stop: + break + + return + + async def with_error_handling(agen): """ Wraps an asynchronous generator with error handling. @@ -201,8 +332,8 @@ async def perform_single_shot_scan( max_budget: int, datasets: list[dict[str, str]] = [], tools_inbox=None, - optimize=False, - stop_event: asyncio.Event = None, + optimize: bool = False, + stop_event: asyncio.Event | None = None, secrets: dict[str, str] = {}, ) -> AsyncGenerator[str, None]: """ @@ -230,7 +361,8 @@ async def perform_single_shot_scan( """ max_budget = max_budget * BUDGET_MULTIPLIER selected_datasets = [m for m in datasets if m["selected"]] - request_factory = multi_modality_spec(request_factory) + request_factory = get_modality_adapter(request_factory) + yield ScanResult.status_msg("Loading datasets...") prompt_modules = prepare_prompts( dataset_names=[m["dataset_name"] for m in selected_datasets], @@ -240,108 +372,35 @@ async def perform_single_shot_scan( ) yield ScanResult.status_msg("Datasets loaded. Starting scan...") - errors = [] - refusals = [] - outputs = [] + fuzzer_state = FuzzerState() total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) processed_prompts = 0 - optimizer = ( - Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25) - if optimize - else None - ) - failure_rates = [] - total_tokens = 0 - tokens = 0 - should_stop = False for module in prompt_modules: - if should_stop: - break - tokens = 0 - module_failures = 0 + module_gen = scan_module( + request_factory=request_factory, + module=module, + fuzzer_state=fuzzer_state, + processed_prompts=processed_prompts, + total_prompts=total_prompts, + max_budget=max_budget, + total_tokens=total_tokens, + optimize=optimize, + stop_event=stop_event, + ) + try: + async for result in module_gen: + yield result + except Exception: + logger.error("Module exception") + continue + # Update processed_prompts count module_size = 0 if module.lazy else len(module.prompts) - logger.info(f"Scanning {module.dataset_name} {module_size}") - module_prompts = 0 # Reset for each module - - async for prompt in generate_prompts(module.prompts): - if stop_event and stop_event.is_set(): - stop_event.clear() - logger.info("Scan stopped by user.") - yield ScanResult.status_msg("Scan stopped by user.") - return - - processed_prompts += 1 - module_prompts += 1 # Fixed increment syntax - # Calculate progress based on total processed prompts - progress = 100 * processed_prompts / total_prompts if total_prompts else 0 - progress = progress % 100 - - total_tokens -= tokens - start = time.time() - tokens, failed = await process_prompt( - request_factory, - prompt, - tokens, - module.dataset_name, - refusals, - errors, - outputs, - ) - end = time.time() - total_tokens += tokens - - if failed: - module_failures += 1 - failure_rate = module_failures / max(module_prompts, 1) - failure_rates.append(failure_rate) - cost = calculate_cost(tokens) - - last_output = outputs[-1] if outputs else None - if last_output and last_output[1] == prompt: - response_text = last_output[2] - else: - response_text = "" - - yield ScanResult( - module=module.dataset_name, - tokens=round(tokens / 1000, 1), - cost=cost, - progress=round(progress, 2), - failureRate=round(failure_rate * 100, 2), - prompt=prompt[:MAX_PROMPT_LENGTH], - latency=end - start, - model=response_text, - ).model_dump_json() - - if optimize and len(failure_rates) >= 5: - next_point = optimizer.ask() - optimizer.tell(next_point, -failure_rate) - best_failure_rate = -optimizer.get_result().fun - if best_failure_rate > 0.5: - yield ScanResult.status_msg( - f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." - ) - should_stop = True - break - if total_tokens > max_budget: - logger.info( - f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" - ) - yield ScanResult.status_msg( - f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" - ) - should_stop = True - break + processed_prompts += module_size yield ScanResult.status_msg("Scan completed.") - - failure_data = errors + refusals - df = pd.DataFrame( - failure_data, columns=["module", "prompt", "status_code", "content"] - ) - df.to_csv("failures.csv", index=False) + fuzzer_state.export_failures("failures.csv") async def perform_many_shot_scan( @@ -350,8 +409,8 @@ async def perform_many_shot_scan( datasets: list[dict[str, str]] = [], probe_datasets: list[dict[str, str]] = [], tools_inbox=None, - optimize=False, - stop_event: asyncio.Event = None, + optimize: bool = False, + stop_event: asyncio.Event | None = None, probe_frequency: float = 0.2, max_ctx_length: int = 10_000, secrets: dict[str, str] = {}, @@ -382,7 +441,7 @@ async def perform_many_shot_scan( processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion. """ - request_factory = multi_modality_spec(request_factory) + request_factory = get_modality_adapter(request_factory) # Load main and probe datasets yield ScanResult.status_msg("Loading datasets...") prompt_modules = prepare_prompts( @@ -394,17 +453,10 @@ async def perform_many_shot_scan( msj_modules = msj_data.prepare_prompts(probe_datasets) yield ScanResult.status_msg("Datasets loaded. Starting scan...") - errors = [] - refusals = [] - outputs = [] + fuzzer_state = FuzzerState() total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) processed_prompts = 0 - optimizer = ( - Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25) - if optimize - else None - ) failure_rates = [] for module in prompt_modules: @@ -418,6 +470,7 @@ async def perform_many_shot_scan( logger.info("Scan stopped by user.") yield ScanResult.status_msg("Scan stopped by user.") return + tokens = 0 processed_prompts += 1 progress = 100 * processed_prompts / total_prompts if total_prompts else 0 @@ -445,9 +498,7 @@ async def perform_many_shot_scan( full_prompt, tokens, module.dataset_name, - refusals, - errors, - outputs, + fuzzer_state=fuzzer_state, ) if failed: module_failures += 1 @@ -468,29 +519,21 @@ async def perform_many_shot_scan( prompt=prompt[:MAX_PROMPT_LENGTH], ).model_dump_json() - if optimize and len(failure_rates) >= 5: - next_point = optimizer.ask() - optimizer.tell(next_point, -failure_rate) - best_failure_rate = -optimizer.get_result().fun - if best_failure_rate > 0.5: - yield ScanResult.status_msg( - f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." - ) - break + if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES: + yield ScanResult.status_msg( + f"High failure rate detected ({failure_rate:.2%}). Stopping this module..." + ) + break yield ScanResult.status_msg("Scan completed.") - - df = pd.DataFrame( - errors + refusals, columns=["module", "prompt", "status_code", "content"] - ) - df.to_csv("failures.csv", index=False) + fuzzer_state.export_failures("failures.csv") def scan_router( request_factory, scan_parameters: Scan, tools_inbox=None, - stop_event: asyncio.Event = None, + stop_event: asyncio.Event | None = None, ): """ Route scan requests to the appropriate scanning function. diff --git a/agentic_security/probe_actor/state.py b/agentic_security/probe_actor/state.py new file mode 100644 index 0000000..8fba317 --- /dev/null +++ b/agentic_security/probe_actor/state.py @@ -0,0 +1,47 @@ +import pandas as pd + + +class FuzzerState: + """Container for tracking scan results""" + + def __init__(self): + self.errors = [] + self.refusals = [] + self.outputs = [] + + def add_error( + self, + module_name: str, + prompt: str, + status_code: int | str, + error_msg: str, + ): + """Add an error to the state""" + self.errors.append((module_name, prompt, status_code, error_msg)) + + def add_refusal( + self, module_name: str, prompt: str, status_code: int, response_text: str + ): + """Add a refusal to the state""" + self.refusals.append((module_name, prompt, status_code, response_text)) + + def add_output( + self, module_name: str, prompt: str, response_text: str, refused: bool + ): + """Add an output to the state""" + self.outputs.append((module_name, prompt, response_text, refused)) + + def get_last_output(self, prompt: str) -> str | None: + """Get the last output for a given prompt""" + for output in reversed(self.outputs): + if output[1] == prompt: + return output[2] + return None + + def export_failures(self, filename: str = "failures.csv"): + """Export failures to a CSV file""" + failure_data = self.errors + self.refusals + df = pd.DataFrame( + failure_data, columns=["module", "prompt", "status_code", "content"] + ) + df.to_csv(filename, index=False) diff --git a/docs/images/final.gif b/docs/images/final.gif deleted file mode 100644 index f3519a9..0000000 Binary files a/docs/images/final.gif and /dev/null differ diff --git a/poetry.lock b/poetry.lock index 01c82da..5529d49 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1283,14 +1283,14 @@ files = [ [[package]] name = "inline-snapshot" -version = "0.20.3" +version = "0.20.5" description = "golden master/snapshot/approval testing library which puts the values right into your source code" optional = false python-versions = ">=3.8" groups = ["dev"] files = [ - {file = "inline_snapshot-0.20.3-py3-none-any.whl", hash = "sha256:1ea999fbf38dd11cc72d0e1a0b9303c63d496b77bdc406a394fe2424ae842f70"}, - {file = "inline_snapshot-0.20.3.tar.gz", hash = "sha256:7a353170b7e42aa89086c7ba790a973c9645523acf985532648dabd7ee2d71f2"}, + {file = "inline_snapshot-0.20.5-py3-none-any.whl", hash = "sha256:3aa56acf5985d89f17ebd4df4aef00faacc49f10cdf4e6b42be701ffc9702b5a"}, + {file = "inline_snapshot-0.20.5.tar.gz", hash = "sha256:d8b67c6d533c0a3f566e72608144b54da65dc3da5d0dba4169b2c56b75530fb5"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index d164511..dbcb141 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "agentic_security" -version = "0.6.0" +version = "0.7.0" description = "Agentic LLM vulnerability scanner" authors = ["Alexander Miasoiedov "] maintainers = ["Alexander Miasoiedov "] diff --git a/tests/probe_actor/test_fuzzer.py b/tests/probe_actor/test_fuzzer.py index 325c715..30cabba 100644 --- a/tests/probe_actor/test_fuzzer.py +++ b/tests/probe_actor/test_fuzzer.py @@ -7,6 +7,7 @@ import pytest from agentic_security.primitives import Scan from agentic_security.probe_actor.fuzzer import ( + FuzzerState, generate_prompts, perform_many_shot_scan, perform_single_shot_scan, @@ -207,9 +208,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): prompt="test prompt", tokens=0, module_name="module_a", - refusals=[], - errors=[], - outputs=[], + fuzzer_state=FuzzerState(), ) self.assertEqual(tokens, 3) # Tokens from "Valid response text" @@ -226,20 +225,17 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): ) ) - refusals = [] - outputs = [] + fuzzer_state = FuzzerState() tokens, refusal = await process_prompt( request_factory=mock_request_factory, prompt="test prompt", tokens=0, module_name="module_a", - refusals=refusals, - errors=[], - outputs=outputs, + fuzzer_state=fuzzer_state, ) self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal" - self.assertFalse(refusal) + # self.assertFalse(fuzzer_state.refusals) async def test_http_error_response(self): mock_request_factory = Mock() @@ -252,15 +248,13 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): ) ) - refusals = [] + fuzzer_state = FuzzerState() await process_prompt( request_factory=mock_request_factory, prompt="test prompt", tokens=0, module_name="module_a", - refusals=refusals, - errors=[], - outputs=[], + fuzzer_state=fuzzer_state, ) async def test_request_error(self): @@ -269,18 +263,14 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): side_effect=httpx.RequestError("Connection error") ) - errors = [] + fuzzer_state = FuzzerState() tokens, refusal = await process_prompt( request_factory=mock_request_factory, prompt="test prompt", tokens=0, module_name="module_a", - refusals=[], - errors=errors, - outputs=[], + fuzzer_state=fuzzer_state, ) self.assertEqual(tokens, 0) self.assertTrue(refusal) - self.assertEqual(len(errors), 1) - self.assertIn("Connection error", errors[0][3])