From a26b5dd4482ec5e871fc6917d5c7576aa9fd1b9b Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Tue, 4 Mar 2025 09:34:24 +0200 Subject: [PATCH] feat(error handling in fuzzer): --- agentic_security/probe_actor/fuzzer.py | 430 ++++++++++++------------- 1 file changed, 214 insertions(+), 216 deletions(-) diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index 76e4842..142bed6 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -107,6 +107,17 @@ async def process_prompt_batch( return total_tokens, failures +async def with_error_handling(agen): + try: + async for t in agen: + yield t + except Exception as e: + logger.exception("Scan failed") + yield ScanResult.status_msg(f"Scan failed: {str(e)}") + finally: + yield ScanResult.status_msg("Scan completed.") + + async def perform_single_shot_scan( request_factory, max_budget: int, @@ -120,125 +131,116 @@ async def perform_single_shot_scan( max_budget = max_budget * BUDGET_MULTIPLIER selected_datasets = [m for m in datasets if m["selected"]] request_factory = multi_modality_spec(request_factory) - try: - yield ScanResult.status_msg("Loading datasets...") - prompt_modules = prepare_prompts( - dataset_names=[m["dataset_name"] for m in selected_datasets], - budget=max_budget, - tools_inbox=tools_inbox, - options=[m.get("opts", {}) for m in selected_datasets], - ) - yield ScanResult.status_msg("Datasets loaded. Starting scan...") + yield ScanResult.status_msg("Loading datasets...") + prompt_modules = prepare_prompts( + dataset_names=[m["dataset_name"] for m in selected_datasets], + budget=max_budget, + tools_inbox=tools_inbox, + options=[m.get("opts", {}) for m in selected_datasets], + ) + yield ScanResult.status_msg("Datasets loaded. Starting scan...") - errors = [] - refusals = [] - outputs = [] - total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) - processed_prompts = 0 + errors = [] + refusals = [] + outputs = [] + total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) + processed_prompts = 0 - optimizer = ( - Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25) - if optimize - else None - ) - failure_rates = [] + optimizer = ( + Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25) + if optimize + else None + ) + failure_rates = [] - total_tokens = 0 + total_tokens = 0 + tokens = 0 + should_stop = False + for module in prompt_modules: + if should_stop: + break tokens = 0 - should_stop = False - for module in prompt_modules: - if should_stop: - break - tokens = 0 - module_failures = 0 - module_size = 0 if module.lazy else len(module.prompts) - logger.info(f"Scanning {module.dataset_name} {module_size}") - module_prompts = 0 # Reset for each module + module_failures = 0 + module_size = 0 if module.lazy else len(module.prompts) + logger.info(f"Scanning {module.dataset_name} {module_size}") + module_prompts = 0 # Reset for each module - async for prompt in generate_prompts(module.prompts): - if stop_event and stop_event.is_set(): - stop_event.clear() - logger.info("Scan stopped by user.") - yield ScanResult.status_msg("Scan stopped by user.") - return + async for prompt in generate_prompts(module.prompts): + if stop_event and stop_event.is_set(): + stop_event.clear() + logger.info("Scan stopped by user.") + yield ScanResult.status_msg("Scan stopped by user.") + return - processed_prompts += 1 - module_prompts += 1 # Fixed increment syntax - # Calculate progress based on total processed prompts - progress = ( - 100 * processed_prompts / total_prompts if total_prompts else 0 - ) + processed_prompts += 1 + module_prompts += 1 # Fixed increment syntax + # Calculate progress based on total processed prompts + progress = 100 * processed_prompts / total_prompts if total_prompts else 0 - total_tokens -= tokens - start = time.time() - tokens, failed = await process_prompt( - request_factory, - prompt, - tokens, - module.dataset_name, - refusals, - errors, - outputs, - ) - end = time.time() - total_tokens += tokens + total_tokens -= tokens + start = time.time() + tokens, failed = await process_prompt( + request_factory, + prompt, + tokens, + module.dataset_name, + refusals, + errors, + outputs, + ) + end = time.time() + total_tokens += tokens - if failed: - module_failures += 1 - failure_rate = module_failures / max(module_prompts, 1) - failure_rates.append(failure_rate) - cost = calculate_cost(tokens) + if failed: + module_failures += 1 + failure_rate = module_failures / max(module_prompts, 1) + failure_rates.append(failure_rate) + cost = calculate_cost(tokens) - last_output = outputs[-1] if outputs else None - if last_output and last_output[1] == prompt: - response_text = last_output[2] - else: - response_text = "" + last_output = outputs[-1] if outputs else None + if last_output and last_output[1] == prompt: + response_text = last_output[2] + else: + response_text = "" - yield ScanResult( - module=module.dataset_name, - tokens=round(tokens / 1000, 1), - cost=cost, - progress=round(progress, 2), - failureRate=round(failure_rate * 100, 2), - prompt=prompt[:MAX_PROMPT_LENGTH], - latency=end - start, - model=response_text, - ).model_dump_json() + yield ScanResult( + module=module.dataset_name, + tokens=round(tokens / 1000, 1), + cost=cost, + progress=round(progress, 2), + failureRate=round(failure_rate * 100, 2), + prompt=prompt[:MAX_PROMPT_LENGTH], + latency=end - start, + model=response_text, + ).model_dump_json() - if optimize and len(failure_rates) >= 5: - next_point = optimizer.ask() - optimizer.tell(next_point, -failure_rate) - best_failure_rate = -optimizer.get_result().fun - if best_failure_rate > 0.5: - yield ScanResult.status_msg( - f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." - ) - should_stop = True - break - if total_tokens > max_budget: - logger.info( - f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" - ) + if optimize and len(failure_rates) >= 5: + next_point = optimizer.ask() + optimizer.tell(next_point, -failure_rate) + best_failure_rate = -optimizer.get_result().fun + if best_failure_rate > 0.5: yield ScanResult.status_msg( - f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" + f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." ) should_stop = True break + if total_tokens > max_budget: + logger.info( + f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" + ) + yield ScanResult.status_msg( + f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}" + ) + should_stop = True + break - yield ScanResult.status_msg("Scan completed.") + yield ScanResult.status_msg("Scan completed.") - failure_data = errors + refusals - df = pd.DataFrame( - failure_data, columns=["module", "prompt", "status_code", "content"] - ) - df.to_csv("failures.csv", index=False) - - except Exception as e: - logger.exception("Scan failed") - yield ScanResult.status_msg(f"Scan failed: {str(e)}") - finally: - yield ScanResult.status_msg("Scan completed.") + failure_data = errors + refusals + df = pd.DataFrame( + failure_data, columns=["module", "prompt", "status_code", "content"] + ) + df.to_csv("failures.csv", index=False) async def perform_many_shot_scan( @@ -255,114 +257,106 @@ async def perform_many_shot_scan( ) -> AsyncGenerator[str, None]: """Perform a multi-step security scan with probe injection.""" request_factory = multi_modality_spec(request_factory) - try: - # Load main and probe datasets - yield ScanResult.status_msg("Loading datasets...") - prompt_modules = prepare_prompts( - dataset_names=[m["dataset_name"] for m in datasets if m["selected"]], - budget=max_budget, - tools_inbox=tools_inbox, - ) - yield ScanResult.status_msg("Loading datasets for MSJ...") - msj_modules = msj_data.prepare_prompts(probe_datasets) - yield ScanResult.status_msg("Datasets loaded. Starting scan...") + # Load main and probe datasets + yield ScanResult.status_msg("Loading datasets...") + prompt_modules = prepare_prompts( + dataset_names=[m["dataset_name"] for m in datasets if m["selected"]], + budget=max_budget, + tools_inbox=tools_inbox, + ) + yield ScanResult.status_msg("Loading datasets for MSJ...") + msj_modules = msj_data.prepare_prompts(probe_datasets) + yield ScanResult.status_msg("Datasets loaded. Starting scan...") - errors = [] - refusals = [] - outputs = [] - total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) - processed_prompts = 0 + errors = [] + refusals = [] + outputs = [] + total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) + processed_prompts = 0 - optimizer = ( - Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25) - if optimize - else None - ) - failure_rates = [] + optimizer = ( + Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25) + if optimize + else None + ) + failure_rates = [] - for module in prompt_modules: - module_failures = 0 - module_size = 0 if module.lazy else len(module.prompts) - logger.info(f"Scanning {module.dataset_name} {module_size}") + for module in prompt_modules: + module_failures = 0 + module_size = 0 if module.lazy else len(module.prompts) + logger.info(f"Scanning {module.dataset_name} {module_size}") - async for prompt in generate_prompts(module.prompts): - if stop_event and stop_event.is_set(): - stop_event.clear() - logger.info("Scan stopped by user.") - yield ScanResult.status_msg("Scan stopped by user.") - return - tokens = 0 - processed_prompts += 1 - progress = ( - 100 * processed_prompts / total_prompts if total_prompts else 0 + async for prompt in generate_prompts(module.prompts): + if stop_event and stop_event.is_set(): + stop_event.clear() + logger.info("Scan stopped by user.") + yield ScanResult.status_msg("Scan stopped by user.") + return + tokens = 0 + processed_prompts += 1 + progress = 100 * processed_prompts / total_prompts if total_prompts else 0 + + full_prompt = "" + msj_module = random.choice(msj_modules) + + prompt_tokens = len(full_prompt.split()) + tokens += prompt_tokens + + injected = False + for _ in range(20): + if injected: + break + + m_prompt = random.choice(msj_module.prompts) + full_prompt += "\n" + m_prompt + if tokens > max_ctx_length: + full_prompt = "\n" + prompt + injected = True + + tokens, failed = await process_prompt( + request_factory, + full_prompt, + tokens, + module.dataset_name, + refusals, + errors, + outputs, ) + if failed: + module_failures += 1 + break + if injected: + break - full_prompt = "" - msj_module = random.choice(msj_modules) + failure_rate = module_failures / max(processed_prompts, 1) + failure_rates.append(failure_rate) + cost = calculate_cost(tokens) - prompt_tokens = len(full_prompt.split()) - tokens += prompt_tokens + yield ScanResult( + module=module.dataset_name, + tokens=round(tokens / 1000, 1), + cost=cost, + progress=round(progress, 2), + failureRate=round(failure_rate * 100, 2), + prompt=prompt[:MAX_PROMPT_LENGTH], + ).model_dump_json() - injected = False - for _ in range(20): - if injected: - break - - m_prompt = random.choice(msj_module.prompts) - full_prompt += "\n" + m_prompt - if tokens > max_ctx_length: - full_prompt = "\n" + prompt - injected = True - - tokens, failed = await process_prompt( - request_factory, - full_prompt, - tokens, - module.dataset_name, - refusals, - errors, - outputs, + if optimize and len(failure_rates) >= 5: + next_point = optimizer.ask() + optimizer.tell(next_point, -failure_rate) + best_failure_rate = -optimizer.get_result().fun + if best_failure_rate > 0.5: + yield ScanResult.status_msg( + f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." ) - if failed: - module_failures += 1 - break - if injected: - break + break - failure_rate = module_failures / max(processed_prompts, 1) - failure_rates.append(failure_rate) - cost = calculate_cost(tokens) + yield ScanResult.status_msg("Scan completed.") - yield ScanResult( - module=module.dataset_name, - tokens=round(tokens / 1000, 1), - cost=cost, - progress=round(progress, 2), - failureRate=round(failure_rate * 100, 2), - prompt=prompt[:MAX_PROMPT_LENGTH], - ).model_dump_json() - - if optimize and len(failure_rates) >= 5: - next_point = optimizer.ask() - optimizer.tell(next_point, -failure_rate) - best_failure_rate = -optimizer.get_result().fun - if best_failure_rate > 0.5: - yield ScanResult.status_msg( - f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." - ) - break - - yield ScanResult.status_msg("Scan completed.") - - df = pd.DataFrame( - errors + refusals, columns=["module", "prompt", "status_code", "content"] - ) - df.to_csv("failures.csv", index=False) - - except Exception as e: - logger.exception("Scan failed") - yield ScanResult.status_msg(f"Scan failed: {str(e)}") - raise e + df = pd.DataFrame( + errors + refusals, columns=["module", "prompt", "status_code", "content"] + ) + df.to_csv("failures.csv", index=False) def scan_router( @@ -372,23 +366,27 @@ def scan_router( stop_event: asyncio.Event = None, ): if scan_parameters.enableMultiStepAttack: - return perform_many_shot_scan( - request_factory=request_factory, - max_budget=scan_parameters.maxBudget, - datasets=scan_parameters.datasets, - probe_datasets=scan_parameters.probe_datasets, - tools_inbox=tools_inbox, - optimize=scan_parameters.optimize, - stop_event=stop_event, - secrets=scan_parameters.secrets, + return with_error_handling( + perform_many_shot_scan( + request_factory=request_factory, + max_budget=scan_parameters.maxBudget, + datasets=scan_parameters.datasets, + probe_datasets=scan_parameters.probe_datasets, + tools_inbox=tools_inbox, + optimize=scan_parameters.optimize, + stop_event=stop_event, + secrets=scan_parameters.secrets, + ) ) else: - return perform_single_shot_scan( - request_factory=request_factory, - max_budget=scan_parameters.maxBudget, - datasets=scan_parameters.datasets, - tools_inbox=tools_inbox, - optimize=scan_parameters.optimize, - stop_event=stop_event, - secrets=scan_parameters.secrets, + return with_error_handling( + perform_single_shot_scan( + request_factory=request_factory, + max_budget=scan_parameters.maxBudget, + datasets=scan_parameters.datasets, + tools_inbox=tools_inbox, + optimize=scan_parameters.optimize, + stop_event=stop_event, + secrets=scan_parameters.secrets, + ) )