fix(fix network error handling in fuzzer):

This commit is contained in:
Alexander Myasoedov
2025-02-17 18:01:38 +02:00
parent 46ef89355b
commit 99fc8cb2e7
3 changed files with 28 additions and 16 deletions
+6
View File
@@ -32,6 +32,9 @@ class ScanResult(BaseModel):
progress: float
status: bool = False
failureRate: float = 0.0
prompt: str = ""
model: str = ""
refused: bool = False
@classmethod
def status_msg(cls, msg: str) -> str:
@@ -42,6 +45,9 @@ class ScanResult(BaseModel):
progress=0,
failureRate=0,
status=True,
prompt="",
model="",
refused=False,
).model_dump_json()
+14 -7
View File
@@ -17,6 +17,8 @@ from agentic_security.probe_data.data import prepare_prompts
# TODO: full log file
MAX_PROMPT_LENGTH = 2048
async def generate_prompts(
prompts: list[str] | AsyncGenerator,
@@ -43,7 +45,10 @@ def multi_modality_spec(llm_spec):
async def process_prompt(
request_factory, prompt, tokens, module_name, refusals, errors
):
) -> tuple[int, bool]:
"""
Process a single prompt and update the token count and failure status.
"""
try:
response = await request_factory.fn(prompt=prompt)
if response.status_code == 422:
@@ -52,11 +57,9 @@ async def process_prompt(
return tokens, True
if response.status_code >= 400:
raise httpx.HTTPStatusError(
f"HTTP {response.status_code} {response.content=}",
request=response.request,
response=response,
)
logger.error(f"HTTP {response.status_code} {response.content=}")
errors.append((module_name, prompt, response.status_code, response.text))
return tokens, True
response_text = response.text
tokens += len(response_text.split())
@@ -150,6 +153,7 @@ async def perform_single_shot_scan(
cost=cost,
progress=round(progress, 2),
failureRate=round(failure_rate * 100, 2),
prompt=prompt[:MAX_PROMPT_LENGTH],
).model_dump_json()
if optimize and len(failure_rates) >= 5:
@@ -183,7 +187,9 @@ async def perform_single_shot_scan(
except Exception as e:
logger.exception("Scan failed")
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
raise e
# raise e
finally:
yield ScanResult.status_msg("Scan completed.")
async def perform_many_shot_scan(
@@ -281,6 +287,7 @@ async def perform_many_shot_scan(
cost=cost,
progress=round(progress, 2),
failureRate=round(failure_rate * 100, 2),
prompt=prompt[:MAX_PROMPT_LENGTH],
).model_dump_json()
if optimize and len(failure_rates) >= 5:
+8 -9
View File
@@ -250,15 +250,14 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
)
refusals = []
with self.assertRaises(httpx.HTTPStatusError):
await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=refusals,
errors=[],
)
await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=refusals,
errors=[],
)
async def test_request_error(self):
mock_request_factory = Mock()