Merge branch 'main' into issue-166-missing-documentation

This commit is contained in:
nemanjaASE
2025-03-13 09:52:25 +01:00
committed by GitHub
6 changed files with 252 additions and 172 deletions
+192 -149
View File
@@ -3,9 +3,9 @@ import random
import time
from collections.abc import AsyncGenerator
from json import JSONDecodeError
from typing import Any
import httpx
import pandas as pd
from skopt import Optimizer
from skopt.space import Real
@@ -14,11 +14,11 @@ from agentic_security.logutils import logger
from agentic_security.primitives import Scan, ScanResult
from agentic_security.probe_actor.cost_module import calculate_cost
from agentic_security.probe_actor.refusal import refusal_heuristic
from agentic_security.probe_actor.state import FuzzerState
from agentic_security.probe_data import audio_generator, image_generator, msj_data
from agentic_security.probe_data.data import prepare_prompts
# TODO: full log file
# Constants
MAX_PROMPT_LENGTH = 2048
BUDGET_MULTIPLIER = 100_000_000
INITIAL_OPTIMIZER_POINTS = 25
@@ -76,7 +76,11 @@ def multi_modality_spec(llm_spec):
async def process_prompt(
request_factory, prompt, tokens, module_name, refusals, errors, outputs
request_factory,
prompt: str,
tokens: int,
module_name: str,
fuzzer_state: FuzzerState,
) -> tuple[int, bool]:
"""
Processes a single prompt using the provided request factory and updates tracking lists.
@@ -91,42 +95,52 @@ async def process_prompt(
prompt (str): The input prompt to be processed.
tokens (int): The current token count, which will be updated.
module_name (str): The name of the module handling the request.
refusals (list): A list to store prompts that were refused.
errors (list): A list to store prompts that encountered errors.
outputs (list): A list to store processed prompt outputs.
fuzzer_state: State tracking object for the fuzzer
Returns:
tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused.
"""
try:
response = await request_factory.fn(prompt=prompt)
# Handle HTTP errors
if response.status_code == 422:
logger.error(f"Invalid prompt: {prompt}, error=422")
errors.append((module_name, prompt, 422, "Invalid prompt"))
fuzzer_state.add_error(module_name, prompt, 422, "Invalid prompt")
return tokens, True
if response.status_code >= 400:
logger.error(f"HTTP {response.status_code} {response.content=}")
errors.append((module_name, prompt, response.status_code, response.text))
fuzzer_state.add_error(
module_name, prompt, response.status_code, response.text
)
return tokens, True
# Process successful response
response_text = response.text
tokens += len(response_text.split())
# Check if the response indicates a refusal
refused = refusal_heuristic(response.json())
if refused:
refusals.append((module_name, prompt, response.status_code, response_text))
fuzzer_state.add_refusal(
module_name, prompt, response.status_code, response_text
)
outputs.append((module_name, prompt, response_text, refused))
fuzzer_state.add_output(module_name, prompt, response_text, refused)
return tokens, refused
except httpx.RequestError as exc:
logger.error(f"Request error: {exc}")
errors.append((module_name, prompt, "?", str(exc)))
fuzzer_state.add_error(module_name, prompt, "?", str(exc))
return tokens, True
except JSONDecodeError as json_decode_error:
logger.error(f"Jason error: {json_decode_error}")
errors.append((module_name, prompt, "?", str(json_decode_error)))
logger.error(f"JSON error: {json_decode_error}")
fuzzer_state.add_error(module_name, prompt, "?", str(json_decode_error))
return tokens, True
except Exception as e:
logger.exception(f"Unexpected error: {e}")
return tokens, False
async def process_prompt_batch(
@@ -134,9 +148,7 @@ async def process_prompt_batch(
prompts: list[str],
tokens: int,
module_name: str,
refusals,
errors,
outputs,
fuzzer_state: FuzzerState,
) -> tuple[int, int]:
"""
Processes a batch of prompts asynchronously and aggregates the results.
@@ -150,9 +162,7 @@ async def process_prompt_batch(
prompts (list[str]): A list of input prompts to be processed.
tokens (int): The initial token count, which will be updated.
module_name (str): The name of the module handling the request.
refusals (list): A list to store prompts that were refused.
errors (list): A list to store prompts that encountered errors.
outputs (list): A list to store processed prompt outputs.
fuzzer_state: State tracking object for the fuzzer
Returns:
tuple[int, int]:
@@ -160,9 +170,7 @@ async def process_prompt_batch(
- Number of failed prompts.
"""
tasks = [
process_prompt(
request_factory, p, tokens, module_name, refusals, errors, outputs
)
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
for p in prompts
]
results = await asyncio.gather(*tasks)
@@ -171,6 +179,129 @@ async def process_prompt_batch(
return total_tokens, failures
async def scan_module(
request_factory,
module,
fuzzer_state: FuzzerState,
processed_prompts: int = 0,
total_prompts: int = 0,
max_budget: int = 0,
total_tokens: int = 0,
optimize: bool = False,
stop_event: asyncio.Event | None = None,
) -> AsyncGenerator[dict[str, Any], None]:
"""
Scan a single module.
Args:
request_factory: The factory for creating requests
module: The prompt module to scan
fuzzer_state: State tracking object for the fuzzer
processed_prompts: Number of prompts processed so far
total_prompts: Total number of prompts to process
max_budget: Maximum token budget
total_tokens: Current token count
optimize: Whether to use optimization
stop_event: Event to stop scanning
Yields:
ScanResult objects as the scan progresses
"""
tokens = 0
module_failures = 0
module_prompts = 0
failure_rates = []
should_stop = False
# Initialize optimizer if optimization is enabled
optimizer = (
Optimizer(
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
)
if optimize
else None
)
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
async for prompt in generate_prompts(module.prompts):
if stop_event and stop_event.is_set():
stop_event.clear()
logger.info("Scan stopped by user.")
yield ScanResult.status_msg("Scan stopped by user.")
return
processed_prompts += 1
module_prompts += 1
# Calculate progress based on total processed prompts
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
progress = progress % 100
total_tokens -= tokens
start = time.time()
tokens, failed = await process_prompt(
request_factory,
prompt,
tokens,
module.dataset_name,
fuzzer_state=fuzzer_state,
)
end = time.time()
total_tokens += tokens
if failed:
module_failures += 1
failure_rate = module_failures / max(module_prompts, 1)
failure_rates.append(failure_rate)
cost = calculate_cost(tokens)
response_text = fuzzer_state.get_last_output(prompt) or ""
yield ScanResult(
module=module.dataset_name,
tokens=round(tokens / 1000, 1),
cost=cost,
progress=round(progress, 2),
failureRate=round(failure_rate * 100, 2),
prompt=prompt[:MAX_PROMPT_LENGTH],
latency=end - start,
model=response_text,
).model_dump_json()
# Optimization logic
if optimize and optimizer and len(failure_rates) >= MIN_FAILURE_SAMPLES:
next_point = optimizer.ask()
optimizer.tell(next_point, -failure_rate)
best_failure_rate = -optimizer.get_result().fun
if best_failure_rate > FAILURE_RATE_THRESHOLD:
yield ScanResult.status_msg(
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
)
should_stop = True
break
# Budget check
if total_tokens > max_budget:
logger.info(
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
)
yield ScanResult.status_msg(
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
)
should_stop = True
break
if should_stop:
break
return
async def with_error_handling(agen):
"""
Wraps an asynchronous generator with error handling.
@@ -201,8 +332,8 @@ async def perform_single_shot_scan(
max_budget: int,
datasets: list[dict[str, str]] = [],
tools_inbox=None,
optimize=False,
stop_event: asyncio.Event = None,
optimize: bool = False,
stop_event: asyncio.Event | None = None,
secrets: dict[str, str] = {},
) -> AsyncGenerator[str, None]:
"""
@@ -230,7 +361,8 @@ async def perform_single_shot_scan(
"""
max_budget = max_budget * BUDGET_MULTIPLIER
selected_datasets = [m for m in datasets if m["selected"]]
request_factory = multi_modality_spec(request_factory)
request_factory = get_modality_adapter(request_factory)
yield ScanResult.status_msg("Loading datasets...")
prompt_modules = prepare_prompts(
dataset_names=[m["dataset_name"] for m in selected_datasets],
@@ -240,108 +372,35 @@ async def perform_single_shot_scan(
)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
errors = []
refusals = []
outputs = []
fuzzer_state = FuzzerState()
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
optimizer = (
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
if optimize
else None
)
failure_rates = []
total_tokens = 0
tokens = 0
should_stop = False
for module in prompt_modules:
if should_stop:
break
tokens = 0
module_failures = 0
module_gen = scan_module(
request_factory=request_factory,
module=module,
fuzzer_state=fuzzer_state,
processed_prompts=processed_prompts,
total_prompts=total_prompts,
max_budget=max_budget,
total_tokens=total_tokens,
optimize=optimize,
stop_event=stop_event,
)
try:
async for result in module_gen:
yield result
except Exception:
logger.error("Module exception")
continue
# Update processed_prompts count
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
module_prompts = 0 # Reset for each module
async for prompt in generate_prompts(module.prompts):
if stop_event and stop_event.is_set():
stop_event.clear()
logger.info("Scan stopped by user.")
yield ScanResult.status_msg("Scan stopped by user.")
return
processed_prompts += 1
module_prompts += 1 # Fixed increment syntax
# Calculate progress based on total processed prompts
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
progress = progress % 100
total_tokens -= tokens
start = time.time()
tokens, failed = await process_prompt(
request_factory,
prompt,
tokens,
module.dataset_name,
refusals,
errors,
outputs,
)
end = time.time()
total_tokens += tokens
if failed:
module_failures += 1
failure_rate = module_failures / max(module_prompts, 1)
failure_rates.append(failure_rate)
cost = calculate_cost(tokens)
last_output = outputs[-1] if outputs else None
if last_output and last_output[1] == prompt:
response_text = last_output[2]
else:
response_text = ""
yield ScanResult(
module=module.dataset_name,
tokens=round(tokens / 1000, 1),
cost=cost,
progress=round(progress, 2),
failureRate=round(failure_rate * 100, 2),
prompt=prompt[:MAX_PROMPT_LENGTH],
latency=end - start,
model=response_text,
).model_dump_json()
if optimize and len(failure_rates) >= 5:
next_point = optimizer.ask()
optimizer.tell(next_point, -failure_rate)
best_failure_rate = -optimizer.get_result().fun
if best_failure_rate > 0.5:
yield ScanResult.status_msg(
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
)
should_stop = True
break
if total_tokens > max_budget:
logger.info(
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
)
yield ScanResult.status_msg(
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
)
should_stop = True
break
processed_prompts += module_size
yield ScanResult.status_msg("Scan completed.")
failure_data = errors + refusals
df = pd.DataFrame(
failure_data, columns=["module", "prompt", "status_code", "content"]
)
df.to_csv("failures.csv", index=False)
fuzzer_state.export_failures("failures.csv")
async def perform_many_shot_scan(
@@ -350,8 +409,8 @@ async def perform_many_shot_scan(
datasets: list[dict[str, str]] = [],
probe_datasets: list[dict[str, str]] = [],
tools_inbox=None,
optimize=False,
stop_event: asyncio.Event = None,
optimize: bool = False,
stop_event: asyncio.Event | None = None,
probe_frequency: float = 0.2,
max_ctx_length: int = 10_000,
secrets: dict[str, str] = {},
@@ -382,7 +441,7 @@ async def perform_many_shot_scan(
processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold
or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion.
"""
request_factory = multi_modality_spec(request_factory)
request_factory = get_modality_adapter(request_factory)
# Load main and probe datasets
yield ScanResult.status_msg("Loading datasets...")
prompt_modules = prepare_prompts(
@@ -394,17 +453,10 @@ async def perform_many_shot_scan(
msj_modules = msj_data.prepare_prompts(probe_datasets)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
errors = []
refusals = []
outputs = []
fuzzer_state = FuzzerState()
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
optimizer = (
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
if optimize
else None
)
failure_rates = []
for module in prompt_modules:
@@ -418,6 +470,7 @@ async def perform_many_shot_scan(
logger.info("Scan stopped by user.")
yield ScanResult.status_msg("Scan stopped by user.")
return
tokens = 0
processed_prompts += 1
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
@@ -445,9 +498,7 @@ async def perform_many_shot_scan(
full_prompt,
tokens,
module.dataset_name,
refusals,
errors,
outputs,
fuzzer_state=fuzzer_state,
)
if failed:
module_failures += 1
@@ -468,29 +519,21 @@ async def perform_many_shot_scan(
prompt=prompt[:MAX_PROMPT_LENGTH],
).model_dump_json()
if optimize and len(failure_rates) >= 5:
next_point = optimizer.ask()
optimizer.tell(next_point, -failure_rate)
best_failure_rate = -optimizer.get_result().fun
if best_failure_rate > 0.5:
yield ScanResult.status_msg(
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
)
break
if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES:
yield ScanResult.status_msg(
f"High failure rate detected ({failure_rate:.2%}). Stopping this module..."
)
break
yield ScanResult.status_msg("Scan completed.")
df = pd.DataFrame(
errors + refusals, columns=["module", "prompt", "status_code", "content"]
)
df.to_csv("failures.csv", index=False)
fuzzer_state.export_failures("failures.csv")
def scan_router(
request_factory,
scan_parameters: Scan,
tools_inbox=None,
stop_event: asyncio.Event = None,
stop_event: asyncio.Event | None = None,
):
"""
Route scan requests to the appropriate scanning function.
+47
View File
@@ -0,0 +1,47 @@
import pandas as pd
class FuzzerState:
"""Container for tracking scan results"""
def __init__(self):
self.errors = []
self.refusals = []
self.outputs = []
def add_error(
self,
module_name: str,
prompt: str,
status_code: int | str,
error_msg: str,
):
"""Add an error to the state"""
self.errors.append((module_name, prompt, status_code, error_msg))
def add_refusal(
self, module_name: str, prompt: str, status_code: int, response_text: str
):
"""Add a refusal to the state"""
self.refusals.append((module_name, prompt, status_code, response_text))
def add_output(
self, module_name: str, prompt: str, response_text: str, refused: bool
):
"""Add an output to the state"""
self.outputs.append((module_name, prompt, response_text, refused))
def get_last_output(self, prompt: str) -> str | None:
"""Get the last output for a given prompt"""
for output in reversed(self.outputs):
if output[1] == prompt:
return output[2]
return None
def export_failures(self, filename: str = "failures.csv"):
"""Export failures to a CSV file"""
failure_data = self.errors + self.refusals
df = pd.DataFrame(
failure_data, columns=["module", "prompt", "status_code", "content"]
)
df.to_csv(filename, index=False)
Binary file not shown.

Before

Width:  |  Height:  |  Size: 10 MiB

Generated
+3 -3
View File
@@ -1283,14 +1283,14 @@ files = [
[[package]]
name = "inline-snapshot"
version = "0.20.3"
version = "0.20.5"
description = "golden master/snapshot/approval testing library which puts the values right into your source code"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "inline_snapshot-0.20.3-py3-none-any.whl", hash = "sha256:1ea999fbf38dd11cc72d0e1a0b9303c63d496b77bdc406a394fe2424ae842f70"},
{file = "inline_snapshot-0.20.3.tar.gz", hash = "sha256:7a353170b7e42aa89086c7ba790a973c9645523acf985532648dabd7ee2d71f2"},
{file = "inline_snapshot-0.20.5-py3-none-any.whl", hash = "sha256:3aa56acf5985d89f17ebd4df4aef00faacc49f10cdf4e6b42be701ffc9702b5a"},
{file = "inline_snapshot-0.20.5.tar.gz", hash = "sha256:d8b67c6d533c0a3f566e72608144b54da65dc3da5d0dba4169b2c56b75530fb5"},
]
[package.dependencies]
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "agentic_security"
version = "0.6.0"
version = "0.7.0"
description = "Agentic LLM vulnerability scanner"
authors = ["Alexander Miasoiedov <msoedov@gmail.com>"]
maintainers = ["Alexander Miasoiedov <msoedov@gmail.com>"]
+9 -19
View File
@@ -7,6 +7,7 @@ import pytest
from agentic_security.primitives import Scan
from agentic_security.probe_actor.fuzzer import (
FuzzerState,
generate_prompts,
perform_many_shot_scan,
perform_single_shot_scan,
@@ -207,9 +208,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=[],
errors=[],
outputs=[],
fuzzer_state=FuzzerState(),
)
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
@@ -226,20 +225,17 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
)
)
refusals = []
outputs = []
fuzzer_state = FuzzerState()
tokens, refusal = await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=refusals,
errors=[],
outputs=outputs,
fuzzer_state=fuzzer_state,
)
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
self.assertFalse(refusal)
# self.assertFalse(fuzzer_state.refusals)
async def test_http_error_response(self):
mock_request_factory = Mock()
@@ -252,15 +248,13 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
)
)
refusals = []
fuzzer_state = FuzzerState()
await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=refusals,
errors=[],
outputs=[],
fuzzer_state=fuzzer_state,
)
async def test_request_error(self):
@@ -269,18 +263,14 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
side_effect=httpx.RequestError("Connection error")
)
errors = []
fuzzer_state = FuzzerState()
tokens, refusal = await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=[],
errors=errors,
outputs=[],
fuzzer_state=fuzzer_state,
)
self.assertEqual(tokens, 0)
self.assertTrue(refusal)
self.assertEqual(len(errors), 1)
self.assertIn("Connection error", errors[0][3])