diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index 6e8b39a..f902a0c 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -29,7 +29,18 @@ FAILURE_RATE_THRESHOLD = 0.5 async def generate_prompts( prompts: list[str] | AsyncGenerator, ) -> AsyncGenerator[str, None]: - """Convert list of prompts or async generator to a unified async generator.""" + """ + Asynchronously generates and yields individual prompts. + + If the input is a list of strings, the function sequentially yields each string. + If the input is an asynchronous generator, it forwards each generated prompt. + + Args: + prompts (list[str] | AsyncGenerator): A list of strings or an asynchronous generator of prompts. + + Yields: + str: An individual prompt from the list or the asynchronous generator. + """ if isinstance(prompts, list): for prompt in prompts: yield prompt @@ -39,7 +50,20 @@ async def generate_prompts( def get_modality_adapter(llm_spec): - """Get the appropriate modality adapter based on the LLM spec.""" + """ + Returns the appropriate request adapter based on the modality of the LLM specification. + + Depending on the modality of `llm_spec`, the function selects the corresponding request adapter. + If the modality is IMAGE or AUDIO, it returns an adapter for handling the respective type. + If the modality is TEXT or an unrecognized type, it returns `llm_spec` as is. + + Args: + llm_spec: An object containing modality information for the LLM. + + Returns: + RequestAdapter | llm_spec: An instance of the appropriate request adapter + or the original `llm_spec` if no adaptation is needed. + """ match llm_spec.modality: case Modality.IMAGE: return image_generator.RequestAdapter(llm_spec) @@ -59,17 +83,22 @@ async def process_prompt( fuzzer_state: FuzzerState, ) -> tuple[int, bool]: """ - Process a single prompt and update the token count and failure status. 
+ Processes a single prompt using the provided request factory and updates tracking lists. + + This function sends the given `prompt` to the `request_factory`, checks for errors, and updates + the `tokens` count and the `refusals`, `errors`, and `outputs` lists accordingly. If the request fails or + the response indicates a refusal, the function records the issue and returns the updated token count + along with a boolean indicating whether the prompt was refused. Args: - request_factory: The factory for creating requests - prompt: The prompt to process - tokens: Current token count - module_name: Name of the module being processed + request_factory: An object with a `fn` method used to send the prompt. + prompt (str): The input prompt to be processed. + tokens (int): The current token count, which will be updated. + module_name (str): The name of the module handling the request. fuzzer_state: State tracking object for the fuzzer Returns: - Tuple of (updated token count, whether the prompt resulted in a failure) + tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused. """ try: response = await request_factory.fn(prompt=prompt) @@ -122,17 +151,23 @@ async def process_prompt_batch( fuzzer_state: FuzzerState, ) -> tuple[int, int]: """ - Process a batch of prompts in parallel. + Processes a batch of prompts asynchronously and aggregates the results. + + This function sends multiple prompts concurrently using `process_prompt`, + collects the token count and failure status for each prompt, and returns + the total number of tokens processed and the number of failed prompts. Args: - request_factory: The factory for creating requests - prompts: List of prompts to process - tokens: Current token count - module_name: Name of the module being processed + request_factory: An object with a `fn` method used to send the prompts. + prompts (list[str]): A list of input prompts to be processed. + tokens (int): The initial token count, which will be updated. 
+ module_name (str): The name of the module handling the request. fuzzer_state: State tracking object for the fuzzer Returns: - Tuple of (updated token count, number of failures) + tuple[int, int]: + - Total number of tokens processed. + - Number of failed prompts. """ tasks = [ process_prompt(request_factory, p, tokens, module_name, fuzzer_state) @@ -268,7 +303,20 @@ async def scan_module( async def with_error_handling(agen): - """Wrapper to handle errors in async generators.""" + """ + Wraps an asynchronous generator with error handling. + + This function iterates over an asynchronous generator, yielding its values. + If an exception occurs, it logs the error and yields a failure message. + Finally, it ensures that a completion message is always yielded. + + Args: + agen: An asynchronous generator that produces scan results. + + Yields: + ScanResult: Either a successful result, an error message if an + exception occurs, or a completion message at the end. + """ try: async for t in agen: yield t @@ -289,19 +337,27 @@ async def perform_single_shot_scan( secrets: dict[str, str] = {}, ) -> AsyncGenerator[str, None]: """ - Perform a standard security scan across all selected datasets. + Perform a standard security scan using a given request factory. + + This function processes security scan prompts from selected datasets while + respecting a predefined token budget. It supports optimization, failure tracking, + and early stopping based on budget constraints or user intervention. Args: - request_factory: The factory for creating requests - max_budget: Maximum token budget - datasets: List of datasets to scan - tools_inbox: Tools inbox - optimize: Whether to use optimization - stop_event: Event to stop scanning - secrets: Secrets to use in the scan + request_factory: A factory function that generates requests for processing prompts. + max_budget (int): The maximum token budget for the scan. 
+ datasets (list[dict[str, str]], optional): A list of datasets containing security prompts. + tools_inbox: Optional additional tools for processing (default: None). + optimize (bool, optional): Whether to enable failure rate optimization (default: False). + stop_event (asyncio.Event, optional): An event to signal early termination (default: None). + secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}). Yields: - ScanResult objects as the scan progresses + str: JSON-encoded scan results or status messages. + + The function iterates over prompts, processes them asynchronously, and updates + failure statistics and token usage. If the scan exceeds the budget or failure rate is too high, + it stops execution. Results are saved to a CSV file upon completion. """ max_budget = max_budget * BUDGET_MULTIPLIER selected_datasets = [m for m in datasets if m["selected"]] @@ -362,23 +418,30 @@ async def perform_many_shot_scan( """ Perform a multi-step security scan with probe injection. + This function executes a security scan while periodically injecting probe datasets + to test system robustness. It tracks failures, optimizes scan efficiency, + and ensures adherence to a predefined token budget. + Args: - request_factory: The factory for creating requests - max_budget: Maximum token budget - datasets: List of datasets to scan - probe_datasets: List of probe datasets to inject - tools_inbox: Tools inbox - optimize: Whether to use optimization - stop_event: Event to stop scanning - probe_frequency: Frequency of probe injection - max_ctx_length: Maximum context length - secrets: Secrets to use in the scan + request_factory: A factory function that generates requests for processing prompts. + max_budget (int): The maximum token budget for the scan. + datasets (list[dict[str, str]], optional): The main datasets for scanning. + probe_datasets (list[dict[str, str]], optional): Additional datasets for probe injection. 
+ tools_inbox: Optional tools for additional processing (default: None). + optimize (bool, optional): Whether to enable failure rate optimization (default: False). + stop_event (asyncio.Event, optional): An event to signal early termination (default: None). + probe_frequency (float, optional): The probability of probe injection (default: 0.2). + max_ctx_length (int, optional): The maximum context length before resetting (default: 10,000 tokens). + secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}). Yields: - ScanResult objects as the scan progresses + str: JSON-encoded scan results or status messages. + + This function iterates over prompts, injects probe prompts at random intervals, + processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold + or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion. """ request_factory = get_modality_adapter(request_factory) - # Load main and probe datasets yield ScanResult.status_msg("Loading datasets...") prompt_modules = prepare_prompts( @@ -473,16 +536,30 @@ def scan_router( stop_event: asyncio.Event | None = None, ): """ - Route to the appropriate scan function based on scan parameters. + Route scan requests to the appropriate scanning function. + + This function determines whether to perform a multi-step or single-shot + security scan based on the provided scan parameters. Args: - request_factory: The factory for creating requests - scan_parameters: Scan parameters - tools_inbox: Tools inbox - stop_event: Event to stop scanning + request_factory: A factory function to generate requests for processing prompts. + scan_parameters (Scan): An object containing the parameters for the scan, including: + - enableMultiStepAttack (bool): Whether to perform a multi-step scan. + - maxBudget (int): The maximum token budget for the scan. + - datasets (list[dict[str, str]]): The datasets to scan. 
+ - probe_datasets (list[dict[str, str]], optional): Datasets for probe injection (multi-step only). + - optimize (bool): Whether to enable optimization. + - secrets (dict[str, str], optional): A dictionary of secrets for authentication. + tools_inbox: Optional tools for additional processing (default: None). + stop_event (asyncio.Event, optional): An event to signal early termination (default: None). Returns: - Async generator of scan results + An async generator produced by `with_error_handling`, which wraps either: + - `perform_many_shot_scan` for multi-step scanning. + - `perform_single_shot_scan` for single-shot scanning. + + The function ensures that the appropriate scanning method is chosen based on + the `enableMultiStepAttack` flag in `scan_parameters`. """ if scan_parameters.enableMultiStepAttack: return with_error_handling(