mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
fix(_FuzzerState nt):
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from collections.abc import AsyncGenerator
|
||||
from json import JSONDecodeError
|
||||
|
||||
@@ -26,6 +27,12 @@ MIN_FAILURE_SAMPLES = 5
|
||||
FAILURE_RATE_THRESHOLD = 0.5
|
||||
|
||||
|
||||
def _FuzzerState():
|
||||
return namedtuple(
|
||||
"_FuzzerState", ["errors", "refusals", "outputs"], defaults=([], [], [])
|
||||
)()
|
||||
|
||||
|
||||
async def generate_prompts(
|
||||
prompts: list[str] | AsyncGenerator,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
@@ -50,8 +57,8 @@ def multi_modality_spec(llm_spec):
|
||||
|
||||
|
||||
async def process_prompt(
|
||||
request_factory, prompt, tokens, module_name, refusals, errors, outputs
|
||||
) -> tuple[int, bool]:
|
||||
request_factory, prompt, tokens, module_name, fuzzer_state: _FuzzerState
|
||||
):
|
||||
"""
|
||||
Process a single prompt and update the token count and failure status.
|
||||
"""
|
||||
@@ -59,30 +66,34 @@ async def process_prompt(
|
||||
response = await request_factory.fn(prompt=prompt)
|
||||
if response.status_code == 422:
|
||||
logger.error(f"Invalid prompt: {prompt}, error=422")
|
||||
errors.append((module_name, prompt, 422, "Invalid prompt"))
|
||||
fuzzer_state.errors.append((module_name, prompt, 422, "Invalid prompt"))
|
||||
return tokens, True
|
||||
|
||||
if response.status_code >= 400:
|
||||
logger.error(f"HTTP {response.status_code} {response.content=}")
|
||||
errors.append((module_name, prompt, response.status_code, response.text))
|
||||
fuzzer_state.errors.append(
|
||||
(module_name, prompt, response.status_code, response.text)
|
||||
)
|
||||
return tokens, True
|
||||
response_text = response.text
|
||||
tokens += len(response_text.split())
|
||||
|
||||
refused = refusal_heuristic(response.json())
|
||||
if refused:
|
||||
refusals.append((module_name, prompt, response.status_code, response_text))
|
||||
fuzzer_state.refusals.append(
|
||||
(module_name, prompt, response.status_code, response_text)
|
||||
)
|
||||
|
||||
outputs.append((module_name, prompt, response_text, refused))
|
||||
fuzzer_state.outputs.append((module_name, prompt, response_text, refused))
|
||||
return tokens, refused
|
||||
|
||||
except httpx.RequestError as exc:
|
||||
logger.error(f"Request error: {exc}")
|
||||
errors.append((module_name, prompt, "?", str(exc)))
|
||||
fuzzer_state.errors.append((module_name, prompt, "?", str(exc)))
|
||||
return tokens, True
|
||||
except JSONDecodeError as json_decode_error:
|
||||
logger.error(f"Jason error: {json_decode_error}")
|
||||
errors.append((module_name, prompt, "?", str(json_decode_error)))
|
||||
fuzzer_state.errors.append((module_name, prompt, "?", str(json_decode_error)))
|
||||
return tokens, True
|
||||
except Exception:
|
||||
logger.exception("Oups")
|
||||
@@ -94,14 +105,10 @@ async def process_prompt_batch(
|
||||
prompts: list[str],
|
||||
tokens: int,
|
||||
module_name: str,
|
||||
refusals,
|
||||
errors,
|
||||
outputs,
|
||||
fuzzer_state: _FuzzerState,
|
||||
) -> tuple[int, int]:
|
||||
tasks = [
|
||||
process_prompt(
|
||||
request_factory, p, tokens, module_name, refusals, errors, outputs
|
||||
)
|
||||
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
|
||||
for p in prompts
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
@@ -143,9 +150,7 @@ async def perform_single_shot_scan(
|
||||
)
|
||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||
|
||||
errors = []
|
||||
refusals = []
|
||||
outputs = []
|
||||
fuzzer_state = _FuzzerState()
|
||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||
processed_prompts = 0
|
||||
|
||||
@@ -188,9 +193,7 @@ async def perform_single_shot_scan(
|
||||
prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
refusals,
|
||||
errors,
|
||||
outputs,
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
end = time.time()
|
||||
total_tokens += tokens
|
||||
@@ -201,7 +204,7 @@ async def perform_single_shot_scan(
|
||||
failure_rates.append(failure_rate)
|
||||
cost = calculate_cost(tokens)
|
||||
|
||||
last_output = outputs[-1] if outputs else None
|
||||
last_output = fuzzer_state.outputs[-1] if fuzzer_state.outputs else None
|
||||
if last_output and last_output[1] == prompt:
|
||||
response_text = last_output[2]
|
||||
else:
|
||||
@@ -240,7 +243,7 @@ async def perform_single_shot_scan(
|
||||
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
|
||||
failure_data = errors + refusals
|
||||
failure_data = fuzzer_state.errors + fuzzer_state.refusals
|
||||
df = pd.DataFrame(
|
||||
failure_data, columns=["module", "prompt", "status_code", "content"]
|
||||
)
|
||||
@@ -272,9 +275,7 @@ async def perform_many_shot_scan(
|
||||
msj_modules = msj_data.prepare_prompts(probe_datasets)
|
||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||
|
||||
errors = []
|
||||
refusals = []
|
||||
outputs = []
|
||||
fuzzer_state = _FuzzerState()
|
||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||
processed_prompts = 0
|
||||
|
||||
@@ -323,9 +324,7 @@ async def perform_many_shot_scan(
|
||||
full_prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
refusals,
|
||||
errors,
|
||||
outputs,
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
if failed:
|
||||
module_failures += 1
|
||||
@@ -359,7 +358,8 @@ async def perform_many_shot_scan(
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
|
||||
df = pd.DataFrame(
|
||||
errors + refusals, columns=["module", "prompt", "status_code", "content"]
|
||||
fuzzer_state.errors + fuzzer_state.refusals,
|
||||
columns=["module", "prompt", "status_code", "content"],
|
||||
)
|
||||
df.to_csv("failures.csv", index=False)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from agentic_security.primitives import Scan
|
||||
from agentic_security.probe_actor.fuzzer import (
|
||||
_FuzzerState,
|
||||
generate_prompts,
|
||||
perform_many_shot_scan,
|
||||
perform_single_shot_scan,
|
||||
@@ -207,9 +208,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=[],
|
||||
errors=[],
|
||||
outputs=[],
|
||||
fuzzer_state=_FuzzerState(),
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
|
||||
@@ -226,20 +225,17 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
refusals = []
|
||||
outputs = []
|
||||
fuzzer_state = _FuzzerState()
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=refusals,
|
||||
errors=[],
|
||||
outputs=outputs,
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
|
||||
self.assertFalse(refusal)
|
||||
# self.assertFalse(fuzzer_state.refusals)
|
||||
|
||||
async def test_http_error_response(self):
|
||||
mock_request_factory = Mock()
|
||||
@@ -252,15 +248,13 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
refusals = []
|
||||
fuzzer_state = _FuzzerState()
|
||||
await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=refusals,
|
||||
errors=[],
|
||||
outputs=[],
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
async def test_request_error(self):
|
||||
@@ -269,18 +263,14 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
side_effect=httpx.RequestError("Connection error")
|
||||
)
|
||||
|
||||
errors = []
|
||||
fuzzer_state = _FuzzerState()
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=[],
|
||||
errors=errors,
|
||||
outputs=[],
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 0)
|
||||
self.assertTrue(refusal)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertIn("Connection error", errors[0][3])
|
||||
|
||||
Reference in New Issue
Block a user