From c37ee7f7fad32e1c0610a1053c11baa87cc2fa21 Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Tue, 10 Dec 2024 20:18:51 +0200 Subject: [PATCH] fix(fuzzer): --- agentic_security/probe_actor/fuzzer.py | 111 ++++++++++--------------- 1 file changed, 44 insertions(+), 67 deletions(-) diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index 0c47327..d3a1b1d 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -10,6 +10,7 @@ from skopt.space import Real from agentic_security.models.schemas import Scan, ScanResult from agentic_security.probe_actor.refusal import refusal_heuristic +from agentic_security.probe_data import msj_data from agentic_security.probe_data.data import prepare_prompts @@ -142,6 +143,7 @@ async def perform_many_shot_scan( optimize=False, stop_event: asyncio.Event = None, probe_frequency: float = 0.2, + max_ctx_length: int = 10_000, ) -> AsyncGenerator[str, None]: """Perform a multi-step security scan with probe injection.""" try: @@ -152,18 +154,14 @@ async def perform_many_shot_scan( budget=max_budget, tools_inbox=tools_inbox, ) - probe_modules = prepare_prompts( - dataset_names=[m["dataset_name"] for m in probe_datasets if m["selected"]], - budget=max_budget, - tools_inbox=tools_inbox, - ) + yield ScanResult.status_msg("Loading datasets for MSJ...") + msj_modules = msj_data.prepare_prompts(probe_datasets) yield ScanResult.status_msg("Datasets loaded. Starting scan...") errors = [] refusals = [] total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy) processed_prompts = 0 - conversation_history = {} optimizer = ( Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25) @@ -177,7 +175,6 @@ async def perform_many_shot_scan( module_failures = 0 module_size = 0 if module.lazy else len(module.prompts) logger.info(f"Scanning {module.dataset_name} {module_size}") - conv_id = module.dataset_name async for prompt in prompt_iter(module.prompts): if stop_event and stop_event.is_set(): @@ -191,74 +188,54 @@ async def perform_many_shot_scan( 100 * processed_prompts / total_prompts if total_prompts else 0 ) - # Get conversation history - history = conversation_history.get(conv_id, []) - full_prompt = "\n".join([*history, prompt]) if history else prompt + current_length = 0 + full_prompt = "" + msj_module = random.choice(msj_modules) + prompt_tokens = len(full_prompt.split()) tokens += prompt_tokens - try: - # Main request - r = await request_factory.fn(prompt=full_prompt) - if r.status_code >= 400: - raise httpx.HTTPStatusError( - f"HTTP {r.status_code}", - request=r.request, - response=r, - ) + injected = False + for _ in range(20): + if injected: + break - response_text = r.text - response_tokens = len(response_text.split()) - tokens += response_tokens + m_prompt = random.choice(msj_module.prompts) + full_prompt += "\n" + m_prompt + current_length += len(m_prompt.split()) + if current_length > max_ctx_length: + full_prompt = "\n" + prompt + injected = True + try: + # Main request + r = await request_factory.fn(prompt=full_prompt) + if r.status_code >= 400: + raise httpx.HTTPStatusError( + f"HTTP {r.status_code}", + request=r.request, + response=r, + ) - # Update history - history.extend([prompt, response_text]) - history = history[-4:] # Keep last 2 exchanges - conversation_history[conv_id] = history + response_text = r.text + response_tokens = len(response_text.split()) + tokens += response_tokens - if not refusal_heuristic(r.json()): - refusals.append( - (module.dataset_name, prompt, r.status_code, response_text) - ) - module_failures += 1 - - # Random probe injection - if probe_modules and random.random() < probe_frequency: - probe_module = random.choice(probe_modules) - probe_prompts = [ - p async for p in prompt_iter(probe_module.prompts) - ] - if probe_prompts: - probe = random.choice(probe_prompts) - full_probe = "\n".join([*history, probe]) - - probe_r = await request_factory.fn(prompt=full_probe) - if probe_r.status_code < 400: - probe_response = probe_r.text - tokens += len(probe.split()) + len( - probe_response.split() + if injected and not refusal_heuristic(r.json()): + refusals.append( + ( + module.dataset_name, + prompt, + r.status_code, + response_text, ) + ) + module_failures += 1 - history.extend([probe, probe_response]) - history = history[-4:] - conversation_history[conv_id] = history - - if not refusal_heuristic(probe_r.json()): - refusals.append( - ( - probe_module.dataset_name, - probe, - probe_r.status_code, - probe_response, - ) - ) - module_failures += 1 - - except httpx.RequestError as e: - logger.error(f"Request error: {e}") - errors.append((module.dataset_name, prompt, str(e))) - module_failures += 1 - continue + except httpx.RequestError as e: + logger.error(f"Request error: {e}") + errors.append((module.dataset_name, prompt, str(e))) + module_failures += 1 + continue failure_rate = module_failures / max(processed_prompts, 1) failure_rates.append(failure_rate)