From b1e2dc8cefd6a810dd7f7e7fce60d2e1f54006c9 Mon Sep 17 00:00:00 2001 From: nemanjaASE <93867316+nemanjaASE@users.noreply.github.com> Date: Thu, 13 Mar 2025 09:42:55 +0100 Subject: [PATCH] Add missing documentation in fuzzer.py --- agentic_security/probe_actor/fuzzer.py | 157 ++++++++++++++++++++++++- 1 file changed, 154 insertions(+), 3 deletions(-) diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index 882fc7f..a7b7180 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -29,6 +29,18 @@ FAILURE_RATE_THRESHOLD = 0.5 async def generate_prompts( prompts: list[str] | AsyncGenerator, ) -> AsyncGenerator[str, None]: + """ + Asynchronously generates and yields individual prompts. + + If the input is a list of strings, the function sequentially yields each string. + If the input is an asynchronous generator, it forwards each generated prompt. + + Args: + prompts (list[str] | AsyncGenerator): A list of strings or an asynchronous generator of prompts. + + Yields: + str: An individual prompt from the list or the asynchronous generator. + """ if isinstance(prompts, list): for prompt in prompts: yield prompt @@ -38,6 +50,20 @@ async def generate_prompts( def multi_modality_spec(llm_spec): + """ + Returns the appropriate request adapter based on the modality of the LLM specification. + + Depending on the modality of `llm_spec`, the function selects the corresponding request adapter. + If the modality is IMAGE or AUDIO, it returns an adapter for handling the respective type. + If the modality is TEXT or an unrecognized type, it returns `llm_spec` as is. + + Args: + llm_spec: An object containing modality information for the LLM. + + Returns: + RequestAdapter | llm_spec: An instance of the appropriate request adapter + or the original `llm_spec` if no adaptation is needed. + """ match llm_spec.modality: case Modality.IMAGE: return image_generator.RequestAdapter(llm_spec) @@ -53,7 +79,24 @@ async def process_prompt( request_factory, prompt, tokens, module_name, refusals, errors, outputs ) -> tuple[int, bool]: """ - Process a single prompt and update the token count and failure status. + Processes a single prompt using the provided request factory and updates tracking lists. + + This function sends the given `prompt` to the `request_factory`, checks for errors, and updates + the `tokens`, `refusals`, `errors`, and `outputs` lists accordingly. If the request fails or + the response indicates a refusal, the function records the issue and returns the updated token count + along with a boolean indicating whether the prompt was refused. + + Args: + request_factory: An object with a `fn` method used to send the prompt. + prompt (str): The input prompt to be processed. + tokens (int): The current token count, which will be updated. + module_name (str): The name of the module handling the request. + refusals (list): A list to store prompts that were refused. + errors (list): A list to store prompts that encountered errors. + outputs (list): A list to store processed prompt outputs. + + Returns: + tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused. """ try: response = await request_factory.fn(prompt=prompt) @@ -95,6 +138,27 @@ async def process_prompt_batch( errors, outputs, ) -> tuple[int, int]: + """ + Processes a batch of prompts asynchronously and aggregates the results. + + This function sends multiple prompts concurrently using `process_prompt`, + collects the token count and failure status for each prompt, and returns + the total number of tokens processed and the number of failed prompts. + + Args: + request_factory: An object with a `fn` method used to send the prompts. + prompts (list[str]): A list of input prompts to be processed. + tokens (int): The initial token count, which will be updated. + module_name (str): The name of the module handling the request. + refusals (list): A list to store prompts that were refused. + errors (list): A list to store prompts that encountered errors. + outputs (list): A list to store processed prompt outputs. + + Returns: + tuple[int, int]: + - Total number of tokens processed. + - Number of failed prompts. + """ tasks = [ process_prompt( request_factory, p, tokens, module_name, refusals, errors, outputs @@ -108,6 +172,20 @@ async def process_prompt_batch( async def with_error_handling(agen): + """ + Wraps an asynchronous generator with error handling. + + This function iterates over an asynchronous generator, yielding its values. + If an exception occurs, it logs the error and yields a failure message. + Finally, it ensures that a completion message is always yielded. + + Args: + agen: An asynchronous generator that produces scan results. + + Yields: + ScanResult: Either a successful result, an error message if an + exception occurs, or a completion message at the end. + """ try: async for t in agen: yield t @@ -127,7 +205,29 @@ async def perform_single_shot_scan( stop_event: asyncio.Event = None, secrets: dict[str, str] = {}, ) -> AsyncGenerator[str, None]: - """Perform a standard security scan.""" + """ + Perform a standard security scan using a given request factory. + + This function processes security scan prompts from selected datasets while + respecting a predefined token budget. It supports optimization, failure tracking, + and early stopping based on budget constraints or user intervention. + + Args: + request_factory: A factory function that generates requests for processing prompts. + max_budget (int): The maximum token budget for the scan. + datasets (list[dict[str, str]], optional): A list of datasets containing security prompts. + tools_inbox: Optional additional tools for processing (default: None). + optimize (bool, optional): Whether to enable failure rate optimization (default: False). + stop_event (asyncio.Event, optional): An event to signal early termination (default: None). + secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}). + + Yields: + str: JSON-encoded scan results or status messages. + + The function iterates over prompts, processes them asynchronously, and updates + failure statistics and token usage. If the scan exceeds the budget or failure rate is too high, + it stops execution. Results are saved to a CSV file upon completion. + """ max_budget = max_budget * BUDGET_MULTIPLIER selected_datasets = [m for m in datasets if m["selected"]] request_factory = multi_modality_spec(request_factory) @@ -256,7 +356,32 @@ async def perform_many_shot_scan( max_ctx_length: int = 10_000, secrets: dict[str, str] = {}, ) -> AsyncGenerator[str, None]: - """Perform a multi-step security scan with probe injection.""" + """ + Perform a multi-step security scan with probe injection. + + This function executes a security scan while periodically injecting probe datasets + to test system robustness. It tracks failures, optimizes scan efficiency, + and ensures adherence to a predefined token budget. + + Args: + request_factory: A factory function that generates requests for processing prompts. + max_budget (int): The maximum token budget for the scan. + datasets (list[dict[str, str]], optional): The main datasets for scanning. + probe_datasets (list[dict[str, str]], optional): Additional datasets for probe injection. + tools_inbox: Optional tools for additional processing (default: None). + optimize (bool, optional): Whether to enable failure rate optimization (default: False). + stop_event (asyncio.Event, optional): An event to signal early termination (default: None). + probe_frequency (float, optional): The probability of probe injection (default: 0.2). + max_ctx_length (int, optional): The maximum context length before resetting (default: 10,000 tokens). + secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}). + + Yields: + str: JSON-encoded scan results or status messages. + + This function iterates over prompts, injects probe prompts at random intervals, + processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold + or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion. + """ request_factory = multi_modality_spec(request_factory) # Load main and probe datasets yield ScanResult.status_msg("Loading datasets...") @@ -367,6 +492,32 @@ def scan_router( tools_inbox=None, stop_event: asyncio.Event = None, ): + """ + Route scan requests to the appropriate scanning function. + + This function determines whether to perform a multi-step or single-shot + security scan based on the provided scan parameters. + + Args: + request_factory: A factory function to generate requests for processing prompts. + scan_parameters (Scan): An object containing the parameters for the scan, including: + - enableMultiStepAttack (bool): Whether to perform a multi-step scan. + - maxBudget (int): The maximum token budget for the scan. + - datasets (list[dict[str, str]]): The datasets to scan. + - probe_datasets (list[dict[str, str]], optional): Datasets for probe injection (multi-step only). + - optimize (bool): Whether to enable optimization. + - secrets (dict[str, str], optional): A dictionary of secrets for authentication. + tools_inbox: Optional tools for additional processing (default: None). + stop_event (asyncio.Event, optional): An event to signal early termination (default: None). + + Returns: + A function wrapped with `with_error_handling`, which executes either: + - `perform_many_shot_scan` for multi-step scanning. + - `perform_single_shot_scan` for single-shot scanning. + + The function ensures that the appropriate scanning method is chosen based on + the `enableMultiStepAttack` flag in `scan_parameters`. + """ if scan_parameters.enableMultiStepAttack: return with_error_handling( perform_many_shot_scan(