Merge pull request #182 from nemanjaASE/issue-166-missing-documentation

Add missing documentation in fuzzer.py
This commit is contained in:
Alexander Myasoedov
2025-03-13 13:47:33 +02:00
committed by GitHub
+119 -42
View File
@@ -29,7 +29,18 @@ FAILURE_RATE_THRESHOLD = 0.5
async def generate_prompts(
prompts: list[str] | AsyncGenerator,
) -> AsyncGenerator[str, None]:
"""Convert list of prompts or async generator to a unified async generator."""
"""
Asynchronously generates and yields individual prompts.
If the input is a list of strings, the function sequentially yields each string.
If the input is an asynchronous generator, it forwards each generated prompt.
Args:
prompts (list[str] | AsyncGenerator): A list of strings or an asynchronous generator of prompts.
Yields:
str: An individual prompt from the list or the asynchronous generator.
"""
if isinstance(prompts, list):
for prompt in prompts:
yield prompt
@@ -39,7 +50,20 @@ async def generate_prompts(
def get_modality_adapter(llm_spec):
"""Get the appropriate modality adapter based on the LLM spec."""
"""
Returns the appropriate request adapter based on the modality of the LLM specification.
Depending on the modality of `llm_spec`, the function selects the corresponding request adapter.
If the modality is IMAGE or AUDIO, it returns an adapter for handling the respective type.
If the modality is TEXT or an unrecognized type, it returns `llm_spec` as is.
Args:
llm_spec: An object containing modality information for the LLM.
Returns:
RequestAdapter | llm_spec: An instance of the appropriate request adapter
or the original `llm_spec` if no adaptation is needed.
"""
match llm_spec.modality:
case Modality.IMAGE:
return image_generator.RequestAdapter(llm_spec)
@@ -59,17 +83,22 @@ async def process_prompt(
fuzzer_state: FuzzerState,
) -> tuple[int, bool]:
"""
Process a single prompt and update the token count and failure status.
Processes a single prompt using the provided request factory and updates tracking lists.
This function sends the given `prompt` to the `request_factory`, checks for errors, and updates
the `tokens`, `refusals`, `errors`, and `outputs` lists accordingly. If the request fails or
the response indicates a refusal, the function records the issue and returns the updated token count
along with a boolean indicating whether the prompt was refused.
Args:
request_factory: The factory for creating requests
prompt: The prompt to process
tokens: Current token count
module_name: Name of the module being processed
request_factory: An object with a `fn` method used to send the prompt.
prompt (str): The input prompt to be processed.
tokens (int): The current token count, which will be updated.
module_name (str): The name of the module handling the request.
fuzzer_state: State tracking object for the fuzzer
Returns:
Tuple of (updated token count, whether the prompt resulted in a failure)
tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused.
"""
try:
response = await request_factory.fn(prompt=prompt)
@@ -122,17 +151,23 @@ async def process_prompt_batch(
fuzzer_state: FuzzerState,
) -> tuple[int, int]:
"""
Process a batch of prompts in parallel.
Processes a batch of prompts asynchronously and aggregates the results.
This function sends multiple prompts concurrently using `process_prompt`,
collects the token count and failure status for each prompt, and returns
the total number of tokens processed and the number of failed prompts.
Args:
request_factory: The factory for creating requests
prompts: List of prompts to process
tokens: Current token count
module_name: Name of the module being processed
request_factory: An object with a `fn` method used to send the prompts.
prompts (list[str]): A list of input prompts to be processed.
tokens (int): The initial token count, which will be updated.
module_name (str): The name of the module handling the request.
fuzzer_state: State tracking object for the fuzzer
Returns:
Tuple of (updated token count, number of failures)
tuple[int, int]:
- Total number of tokens processed.
- Number of failed prompts.
"""
tasks = [
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
@@ -268,7 +303,20 @@ async def scan_module(
async def with_error_handling(agen):
"""Wrapper to handle errors in async generators."""
"""
Wraps an asynchronous generator with error handling.
This function iterates over an asynchronous generator, yielding its values.
If an exception occurs, it logs the error and yields a failure message.
Finally, it ensures that a completion message is always yielded.
Args:
agen: An asynchronous generator that produces scan results.
Yields:
ScanResult: Either a successful result, an error message if an
exception occurs, or a completion message at the end.
"""
try:
async for t in agen:
yield t
@@ -289,19 +337,27 @@ async def perform_single_shot_scan(
secrets: dict[str, str] = {},
) -> AsyncGenerator[str, None]:
"""
Perform a standard security scan across all selected datasets.
Perform a standard security scan using a given request factory.
This function processes security scan prompts from selected datasets while
respecting a predefined token budget. It supports optimization, failure tracking,
and early stopping based on budget constraints or user intervention.
Args:
request_factory: The factory for creating requests
max_budget: Maximum token budget
datasets: List of datasets to scan
tools_inbox: Tools inbox
optimize: Whether to use optimization
stop_event: Event to stop scanning
secrets: Secrets to use in the scan
request_factory: A factory function that generates requests for processing prompts.
max_budget (int): The maximum token budget for the scan.
datasets (list[dict[str, str]], optional): A list of datasets containing security prompts.
tools_inbox: Optional additional tools for processing (default: None).
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
Yields:
ScanResult objects as the scan progresses
str: JSON-encoded scan results or status messages.
The function iterates over prompts, processes them asynchronously, and updates
failure statistics and token usage. If the scan exceeds the budget or failure rate is too high,
it stops execution. Results are saved to a CSV file upon completion.
"""
max_budget = max_budget * BUDGET_MULTIPLIER
selected_datasets = [m for m in datasets if m["selected"]]
@@ -362,23 +418,30 @@ async def perform_many_shot_scan(
"""
Perform a multi-step security scan with probe injection.
This function executes a security scan while periodically injecting probe datasets
to test system robustness. It tracks failures, optimizes scan efficiency,
and ensures adherence to a predefined token budget.
Args:
request_factory: The factory for creating requests
max_budget: Maximum token budget
datasets: List of datasets to scan
probe_datasets: List of probe datasets to inject
tools_inbox: Tools inbox
optimize: Whether to use optimization
stop_event: Event to stop scanning
probe_frequency: Frequency of probe injection
max_ctx_length: Maximum context length
secrets: Secrets to use in the scan
request_factory: A factory function that generates requests for processing prompts.
max_budget (int): The maximum token budget for the scan.
datasets (list[dict[str, str]], optional): The main datasets for scanning.
probe_datasets (list[dict[str, str]], optional): Additional datasets for probe injection.
tools_inbox: Optional tools for additional processing (default: None).
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
probe_frequency (float, optional): The probability of probe injection (default: 0.2).
max_ctx_length (int, optional): The maximum context length before resetting (default: 10,000 tokens).
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
Yields:
ScanResult objects as the scan progresses
str: JSON-encoded scan results or status messages.
This function iterates over prompts, injects probe prompts at random intervals,
processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold
or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion.
"""
request_factory = get_modality_adapter(request_factory)
# Load main and probe datasets
yield ScanResult.status_msg("Loading datasets...")
prompt_modules = prepare_prompts(
@@ -473,16 +536,30 @@ def scan_router(
stop_event: asyncio.Event | None = None,
):
"""
Route to the appropriate scan function based on scan parameters.
Route scan requests to the appropriate scanning function.
This function determines whether to perform a multi-step or single-shot
security scan based on the provided scan parameters.
Args:
request_factory: The factory for creating requests
scan_parameters: Scan parameters
tools_inbox: Tools inbox
stop_event: Event to stop scanning
request_factory: A factory function to generate requests for processing prompts.
scan_parameters (Scan): An object containing the parameters for the scan, including:
- enableMultiStepAttack (bool): Whether to perform a multi-step scan.
- maxBudget (int): The maximum token budget for the scan.
- datasets (list[dict[str, str]]): The datasets to scan.
- probe_datasets (list[dict[str, str]], optional): Datasets for probe injection (multi-step only).
- optimize (bool): Whether to enable optimization.
- secrets (dict[str, str], optional): A dictionary of secrets for authentication.
tools_inbox: Optional tools for additional processing (default: None).
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
Returns:
Async generator of scan results
A function wrapped with `with_error_handling`, which executes either:
- `perform_many_shot_scan` for multi-step scanning.
- `perform_single_shot_scan` for single-shot scanning.
The function ensures that the appropriate scanning method is chosen based on
the `enableMultiStepAttack` flag in `scan_parameters`.
"""
if scan_parameters.enableMultiStepAttack:
return with_error_handling(