mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
fix(fuzzer): fix network error handling
This commit is contained in:
@@ -32,6 +32,9 @@ class ScanResult(BaseModel):
|
||||
progress: float
|
||||
status: bool = False
|
||||
failureRate: float = 0.0
|
||||
prompt: str = ""
|
||||
model: str = ""
|
||||
refused: bool = False
|
||||
|
||||
@classmethod
|
||||
def status_msg(cls, msg: str) -> str:
|
||||
@@ -42,6 +45,9 @@ class ScanResult(BaseModel):
|
||||
progress=0,
|
||||
failureRate=0,
|
||||
status=True,
|
||||
prompt="",
|
||||
model="",
|
||||
refused=False,
|
||||
).model_dump_json()
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ from agentic_security.probe_data.data import prepare_prompts
|
||||
|
||||
# TODO: full log file
|
||||
|
||||
MAX_PROMPT_LENGTH = 2048
|
||||
|
||||
|
||||
async def generate_prompts(
|
||||
prompts: list[str] | AsyncGenerator,
|
||||
@@ -43,7 +45,10 @@ def multi_modality_spec(llm_spec):
|
||||
|
||||
async def process_prompt(
|
||||
request_factory, prompt, tokens, module_name, refusals, errors
|
||||
):
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Process a single prompt and update the token count and failure status.
|
||||
"""
|
||||
try:
|
||||
response = await request_factory.fn(prompt=prompt)
|
||||
if response.status_code == 422:
|
||||
@@ -52,11 +57,9 @@ async def process_prompt(
|
||||
return tokens, True
|
||||
|
||||
if response.status_code >= 400:
|
||||
raise httpx.HTTPStatusError(
|
||||
f"HTTP {response.status_code} {response.content=}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
logger.error(f"HTTP {response.status_code} {response.content=}")
|
||||
errors.append((module_name, prompt, response.status_code, response.text))
|
||||
return tokens, True
|
||||
response_text = response.text
|
||||
tokens += len(response_text.split())
|
||||
|
||||
@@ -150,6 +153,7 @@ async def perform_single_shot_scan(
|
||||
cost=cost,
|
||||
progress=round(progress, 2),
|
||||
failureRate=round(failure_rate * 100, 2),
|
||||
prompt=prompt[:MAX_PROMPT_LENGTH],
|
||||
).model_dump_json()
|
||||
|
||||
if optimize and len(failure_rates) >= 5:
|
||||
@@ -183,7 +187,9 @@ async def perform_single_shot_scan(
|
||||
except Exception as e:
|
||||
logger.exception("Scan failed")
|
||||
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
|
||||
raise e
|
||||
# raise e
|
||||
finally:
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
|
||||
|
||||
async def perform_many_shot_scan(
|
||||
@@ -281,6 +287,7 @@ async def perform_many_shot_scan(
|
||||
cost=cost,
|
||||
progress=round(progress, 2),
|
||||
failureRate=round(failure_rate * 100, 2),
|
||||
prompt=prompt[:MAX_PROMPT_LENGTH],
|
||||
).model_dump_json()
|
||||
|
||||
if optimize and len(failure_rates) >= 5:
|
||||
|
||||
@@ -250,15 +250,14 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
|
||||
refusals = []
|
||||
with self.assertRaises(httpx.HTTPStatusError):
|
||||
await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=refusals,
|
||||
errors=[],
|
||||
)
|
||||
await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=refusals,
|
||||
errors=[],
|
||||
)
|
||||
|
||||
async def test_request_error(self):
|
||||
mock_request_factory = Mock()
|
||||
|
||||
Reference in New Issue
Block a user