fix(fuzzer):

This commit is contained in:
Alexander Myasoedov
2024-12-10 20:18:51 +02:00
parent 4ce9d266d8
commit c37ee7f7fa
+44 -67
View File
@@ -10,6 +10,7 @@ from skopt.space import Real
from agentic_security.models.schemas import Scan, ScanResult
from agentic_security.probe_actor.refusal import refusal_heuristic
from agentic_security.probe_data import msj_data
from agentic_security.probe_data.data import prepare_prompts
@@ -142,6 +143,7 @@ async def perform_many_shot_scan(
optimize=False,
stop_event: asyncio.Event = None,
probe_frequency: float = 0.2,
max_ctx_length: int = 10_000,
) -> AsyncGenerator[str, None]:
"""Perform a multi-step security scan with probe injection."""
try:
@@ -152,18 +154,14 @@ async def perform_many_shot_scan(
budget=max_budget,
tools_inbox=tools_inbox,
)
probe_modules = prepare_prompts(
dataset_names=[m["dataset_name"] for m in probe_datasets if m["selected"]],
budget=max_budget,
tools_inbox=tools_inbox,
)
yield ScanResult.status_msg("Loading datasets for MSJ...")
msj_modules = msj_data.prepare_prompts(probe_datasets)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
errors = []
refusals = []
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
conversation_history = {}
optimizer = (
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
@@ -177,7 +175,6 @@ async def perform_many_shot_scan(
module_failures = 0
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
conv_id = module.dataset_name
async for prompt in prompt_iter(module.prompts):
if stop_event and stop_event.is_set():
@@ -191,74 +188,54 @@ async def perform_many_shot_scan(
100 * processed_prompts / total_prompts if total_prompts else 0
)
# Get conversation history
history = conversation_history.get(conv_id, [])
full_prompt = "\n".join([*history, prompt]) if history else prompt
current_length = 0
full_prompt = ""
msj_module = random.choice(msj_modules)
prompt_tokens = len(full_prompt.split())
tokens += prompt_tokens
try:
# Main request
r = await request_factory.fn(prompt=full_prompt)
if r.status_code >= 400:
raise httpx.HTTPStatusError(
f"HTTP {r.status_code}",
request=r.request,
response=r,
)
injected = False
for _ in range(20):
if injected:
break
response_text = r.text
response_tokens = len(response_text.split())
tokens += response_tokens
m_prompt = random.choice(msj_module.prompts)
full_prompt += "\n" + m_prompt
current_length += len(m_prompt.split())
if current_length > max_ctx_length:
full_prompt = "\n" + prompt
injected = True
try:
# Main request
r = await request_factory.fn(prompt=full_prompt)
if r.status_code >= 400:
raise httpx.HTTPStatusError(
f"HTTP {r.status_code}",
request=r.request,
response=r,
)
# Update history
history.extend([prompt, response_text])
history = history[-4:] # Keep last 2 exchanges
conversation_history[conv_id] = history
response_text = r.text
response_tokens = len(response_text.split())
tokens += response_tokens
if not refusal_heuristic(r.json()):
refusals.append(
(module.dataset_name, prompt, r.status_code, response_text)
)
module_failures += 1
# Random probe injection
if probe_modules and random.random() < probe_frequency:
probe_module = random.choice(probe_modules)
probe_prompts = [
p async for p in prompt_iter(probe_module.prompts)
]
if probe_prompts:
probe = random.choice(probe_prompts)
full_probe = "\n".join([*history, probe])
probe_r = await request_factory.fn(prompt=full_probe)
if probe_r.status_code < 400:
probe_response = probe_r.text
tokens += len(probe.split()) + len(
probe_response.split()
if injected and not refusal_heuristic(r.json()):
refusals.append(
(
module.dataset_name,
prompt,
r.status_code,
response_text,
)
)
module_failures += 1
history.extend([probe, probe_response])
history = history[-4:]
conversation_history[conv_id] = history
if not refusal_heuristic(probe_r.json()):
refusals.append(
(
probe_module.dataset_name,
probe,
probe_r.status_code,
probe_response,
)
)
module_failures += 1
except httpx.RequestError as e:
logger.error(f"Request error: {e}")
errors.append((module.dataset_name, prompt, str(e)))
module_failures += 1
continue
except httpx.RequestError as e:
logger.error(f"Request error: {e}")
errors.append((module.dataset_name, prompt, str(e)))
module_failures += 1
continue
failure_rate = module_failures / max(processed_prompts, 1)
failure_rates.append(failure_rate)