mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
fix(fuzzer):
This commit is contained in:
@@ -10,6 +10,7 @@ from skopt.space import Real
|
||||
|
||||
from agentic_security.models.schemas import Scan, ScanResult
|
||||
from agentic_security.probe_actor.refusal import refusal_heuristic
|
||||
from agentic_security.probe_data import msj_data
|
||||
from agentic_security.probe_data.data import prepare_prompts
|
||||
|
||||
|
||||
@@ -142,6 +143,7 @@ async def perform_many_shot_scan(
|
||||
optimize=False,
|
||||
stop_event: asyncio.Event = None,
|
||||
probe_frequency: float = 0.2,
|
||||
max_ctx_length: int = 10_000,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Perform a multi-step security scan with probe injection."""
|
||||
try:
|
||||
@@ -152,18 +154,14 @@ async def perform_many_shot_scan(
|
||||
budget=max_budget,
|
||||
tools_inbox=tools_inbox,
|
||||
)
|
||||
probe_modules = prepare_prompts(
|
||||
dataset_names=[m["dataset_name"] for m in probe_datasets if m["selected"]],
|
||||
budget=max_budget,
|
||||
tools_inbox=tools_inbox,
|
||||
)
|
||||
yield ScanResult.status_msg("Loading datasets for MSJ...")
|
||||
msj_modules = msj_data.prepare_prompts(probe_datasets)
|
||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||
|
||||
errors = []
|
||||
refusals = []
|
||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||
processed_prompts = 0
|
||||
conversation_history = {}
|
||||
|
||||
optimizer = (
|
||||
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
|
||||
@@ -177,7 +175,6 @@ async def perform_many_shot_scan(
|
||||
module_failures = 0
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
conv_id = module.dataset_name
|
||||
|
||||
async for prompt in prompt_iter(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
@@ -191,74 +188,54 @@ async def perform_many_shot_scan(
|
||||
100 * processed_prompts / total_prompts if total_prompts else 0
|
||||
)
|
||||
|
||||
# Get conversation history
|
||||
history = conversation_history.get(conv_id, [])
|
||||
full_prompt = "\n".join([*history, prompt]) if history else prompt
|
||||
current_length = 0
|
||||
full_prompt = ""
|
||||
msj_module = random.choice(msj_modules)
|
||||
|
||||
prompt_tokens = len(full_prompt.split())
|
||||
tokens += prompt_tokens
|
||||
|
||||
try:
|
||||
# Main request
|
||||
r = await request_factory.fn(prompt=full_prompt)
|
||||
if r.status_code >= 400:
|
||||
raise httpx.HTTPStatusError(
|
||||
f"HTTP {r.status_code}",
|
||||
request=r.request,
|
||||
response=r,
|
||||
)
|
||||
injected = False
|
||||
for _ in range(20):
|
||||
if injected:
|
||||
break
|
||||
|
||||
response_text = r.text
|
||||
response_tokens = len(response_text.split())
|
||||
tokens += response_tokens
|
||||
m_prompt = random.choice(msj_module.prompts)
|
||||
full_prompt += "\n" + m_prompt
|
||||
current_length += len(m_prompt.split())
|
||||
if current_length > max_ctx_length:
|
||||
full_prompt = "\n" + prompt
|
||||
injected = True
|
||||
try:
|
||||
# Main request
|
||||
r = await request_factory.fn(prompt=full_prompt)
|
||||
if r.status_code >= 400:
|
||||
raise httpx.HTTPStatusError(
|
||||
f"HTTP {r.status_code}",
|
||||
request=r.request,
|
||||
response=r,
|
||||
)
|
||||
|
||||
# Update history
|
||||
history.extend([prompt, response_text])
|
||||
history = history[-4:] # Keep last 2 exchanges
|
||||
conversation_history[conv_id] = history
|
||||
response_text = r.text
|
||||
response_tokens = len(response_text.split())
|
||||
tokens += response_tokens
|
||||
|
||||
if not refusal_heuristic(r.json()):
|
||||
refusals.append(
|
||||
(module.dataset_name, prompt, r.status_code, response_text)
|
||||
)
|
||||
module_failures += 1
|
||||
|
||||
# Random probe injection
|
||||
if probe_modules and random.random() < probe_frequency:
|
||||
probe_module = random.choice(probe_modules)
|
||||
probe_prompts = [
|
||||
p async for p in prompt_iter(probe_module.prompts)
|
||||
]
|
||||
if probe_prompts:
|
||||
probe = random.choice(probe_prompts)
|
||||
full_probe = "\n".join([*history, probe])
|
||||
|
||||
probe_r = await request_factory.fn(prompt=full_probe)
|
||||
if probe_r.status_code < 400:
|
||||
probe_response = probe_r.text
|
||||
tokens += len(probe.split()) + len(
|
||||
probe_response.split()
|
||||
if injected and not refusal_heuristic(r.json()):
|
||||
refusals.append(
|
||||
(
|
||||
module.dataset_name,
|
||||
prompt,
|
||||
r.status_code,
|
||||
response_text,
|
||||
)
|
||||
)
|
||||
module_failures += 1
|
||||
|
||||
history.extend([probe, probe_response])
|
||||
history = history[-4:]
|
||||
conversation_history[conv_id] = history
|
||||
|
||||
if not refusal_heuristic(probe_r.json()):
|
||||
refusals.append(
|
||||
(
|
||||
probe_module.dataset_name,
|
||||
probe,
|
||||
probe_r.status_code,
|
||||
probe_response,
|
||||
)
|
||||
)
|
||||
module_failures += 1
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {e}")
|
||||
errors.append((module.dataset_name, prompt, str(e)))
|
||||
module_failures += 1
|
||||
continue
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {e}")
|
||||
errors.append((module.dataset_name, prompt, str(e)))
|
||||
module_failures += 1
|
||||
continue
|
||||
|
||||
failure_rate = module_failures / max(processed_prompts, 1)
|
||||
failure_rates.append(failure_rate)
|
||||
|
||||
Reference in New Issue
Block a user