fix(_FuzzerState nt):

This commit is contained in:
Alexander Myasoedov
2025-03-12 19:18:01 +02:00
parent b4857a5f36
commit 839c1af9d7
2 changed files with 38 additions and 48 deletions
+29 -29
View File
@@ -1,6 +1,7 @@
import asyncio
import random
import time
from collections import namedtuple
from collections.abc import AsyncGenerator
from json import JSONDecodeError
@@ -26,6 +27,12 @@ MIN_FAILURE_SAMPLES = 5
FAILURE_RATE_THRESHOLD = 0.5
def _FuzzerState():
return namedtuple(
"_FuzzerState", ["errors", "refusals", "outputs"], defaults=([], [], [])
)()
async def generate_prompts(
prompts: list[str] | AsyncGenerator,
) -> AsyncGenerator[str, None]:
@@ -50,8 +57,8 @@ def multi_modality_spec(llm_spec):
async def process_prompt(
request_factory, prompt, tokens, module_name, refusals, errors, outputs
) -> tuple[int, bool]:
request_factory, prompt, tokens, module_name, fuzzer_state: _FuzzerState
):
"""
Process a single prompt and update the token count and failure status.
"""
@@ -59,30 +66,34 @@ async def process_prompt(
response = await request_factory.fn(prompt=prompt)
if response.status_code == 422:
logger.error(f"Invalid prompt: {prompt}, error=422")
errors.append((module_name, prompt, 422, "Invalid prompt"))
fuzzer_state.errors.append((module_name, prompt, 422, "Invalid prompt"))
return tokens, True
if response.status_code >= 400:
logger.error(f"HTTP {response.status_code} {response.content=}")
errors.append((module_name, prompt, response.status_code, response.text))
fuzzer_state.errors.append(
(module_name, prompt, response.status_code, response.text)
)
return tokens, True
response_text = response.text
tokens += len(response_text.split())
refused = refusal_heuristic(response.json())
if refused:
refusals.append((module_name, prompt, response.status_code, response_text))
fuzzer_state.refusals.append(
(module_name, prompt, response.status_code, response_text)
)
outputs.append((module_name, prompt, response_text, refused))
fuzzer_state.outputs.append((module_name, prompt, response_text, refused))
return tokens, refused
except httpx.RequestError as exc:
logger.error(f"Request error: {exc}")
errors.append((module_name, prompt, "?", str(exc)))
fuzzer_state.errors.append((module_name, prompt, "?", str(exc)))
return tokens, True
except JSONDecodeError as json_decode_error:
logger.error(f"Jason error: {json_decode_error}")
errors.append((module_name, prompt, "?", str(json_decode_error)))
fuzzer_state.errors.append((module_name, prompt, "?", str(json_decode_error)))
return tokens, True
except Exception:
logger.exception("Oups")
@@ -94,14 +105,10 @@ async def process_prompt_batch(
prompts: list[str],
tokens: int,
module_name: str,
refusals,
errors,
outputs,
fuzzer_state: _FuzzerState,
) -> tuple[int, int]:
tasks = [
process_prompt(
request_factory, p, tokens, module_name, refusals, errors, outputs
)
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
for p in prompts
]
results = await asyncio.gather(*tasks)
@@ -143,9 +150,7 @@ async def perform_single_shot_scan(
)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
errors = []
refusals = []
outputs = []
fuzzer_state = _FuzzerState()
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
@@ -188,9 +193,7 @@ async def perform_single_shot_scan(
prompt,
tokens,
module.dataset_name,
refusals,
errors,
outputs,
fuzzer_state=fuzzer_state,
)
end = time.time()
total_tokens += tokens
@@ -201,7 +204,7 @@ async def perform_single_shot_scan(
failure_rates.append(failure_rate)
cost = calculate_cost(tokens)
last_output = outputs[-1] if outputs else None
last_output = fuzzer_state.outputs[-1] if fuzzer_state.outputs else None
if last_output and last_output[1] == prompt:
response_text = last_output[2]
else:
@@ -240,7 +243,7 @@ async def perform_single_shot_scan(
yield ScanResult.status_msg("Scan completed.")
failure_data = errors + refusals
failure_data = fuzzer_state.errors + fuzzer_state.refusals
df = pd.DataFrame(
failure_data, columns=["module", "prompt", "status_code", "content"]
)
@@ -272,9 +275,7 @@ async def perform_many_shot_scan(
msj_modules = msj_data.prepare_prompts(probe_datasets)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
errors = []
refusals = []
outputs = []
fuzzer_state = _FuzzerState()
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
@@ -323,9 +324,7 @@ async def perform_many_shot_scan(
full_prompt,
tokens,
module.dataset_name,
refusals,
errors,
outputs,
fuzzer_state=fuzzer_state,
)
if failed:
module_failures += 1
@@ -359,7 +358,8 @@ async def perform_many_shot_scan(
yield ScanResult.status_msg("Scan completed.")
df = pd.DataFrame(
errors + refusals, columns=["module", "prompt", "status_code", "content"]
fuzzer_state.errors + fuzzer_state.refusals,
columns=["module", "prompt", "status_code", "content"],
)
df.to_csv("failures.csv", index=False)
+9 -19
View File
@@ -7,6 +7,7 @@ import pytest
from agentic_security.primitives import Scan
from agentic_security.probe_actor.fuzzer import (
_FuzzerState,
generate_prompts,
perform_many_shot_scan,
perform_single_shot_scan,
@@ -207,9 +208,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=[],
errors=[],
outputs=[],
fuzzer_state=_FuzzerState(),
)
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
@@ -226,20 +225,17 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
)
)
refusals = []
outputs = []
fuzzer_state = _FuzzerState()
tokens, refusal = await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=refusals,
errors=[],
outputs=outputs,
fuzzer_state=fuzzer_state,
)
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
self.assertFalse(refusal)
# self.assertFalse(fuzzer_state.refusals)
async def test_http_error_response(self):
mock_request_factory = Mock()
@@ -252,15 +248,13 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
)
)
refusals = []
fuzzer_state = _FuzzerState()
await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=refusals,
errors=[],
outputs=[],
fuzzer_state=fuzzer_state,
)
async def test_request_error(self):
@@ -269,18 +263,14 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
side_effect=httpx.RequestError("Connection error")
)
errors = []
fuzzer_state = _FuzzerState()
tokens, refusal = await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=[],
errors=errors,
outputs=[],
fuzzer_state=fuzzer_state,
)
self.assertEqual(tokens, 0)
self.assertTrue(refusal)
self.assertEqual(len(errors), 1)
self.assertIn("Connection error", errors[0][3])