From 314617651f30bb8d753ae685cb1e0d8c85ff1b31 Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Wed, 11 Dec 2024 18:02:26 +0200 Subject: [PATCH] feat(process_prompt): --- agentic_security/probe_actor/fuzzer.py | 105 +++++++++----------- agentic_security/probe_actor/test_fuzzer.py | 103 ++++++++++++++++++- 2 files changed, 147 insertions(+), 61 deletions(-) diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index 9a5e697..cfb7d9b 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -25,6 +25,30 @@ async def generate_prompts( yield prompt +async def process_prompt( + request_factory, prompt, tokens, module_name, refusals, errors +): + try: + response = await request_factory.fn(prompt=prompt) + if response.status_code >= 400: + raise httpx.HTTPStatusError( + f"HTTP {response.status_code}", + request=response.request, + response=response, + ) + response_text = response.text + tokens += len(response_text.split()) + + if not refusal_heuristic(response.json()): + refusals.append((module_name, prompt, response.status_code, response_text)) + return tokens, False + + except httpx.RequestError as exc: + logger.error(f"Request error: {exc}") + errors.append((module_name, prompt, str(exc))) + return tokens, True + + async def perform_single_shot_scan( request_factory, max_budget: int, @@ -73,34 +97,17 @@ async def perform_single_shot_scan( progress = ( 100 * processed_prompts / total_prompts if total_prompts else 0 ) - prompt_tokens = len(prompt.split()) - tokens += prompt_tokens - try: - r = await request_factory.fn(prompt=prompt) - if r.status_code >= 400: - raise httpx.HTTPStatusError( - f"HTTP {r.status_code}", - request=r.request, - response=r, - ) - - response_text = r.text - response_tokens = len(response_text.split()) - tokens += response_tokens - - if not refusal_heuristic(r.json()): - refusals.append( - (module.dataset_name, prompt, r.status_code, response_text) - ) - module_failures += 1 - - except httpx.RequestError as e: - logger.error(f"Request error: {e}") - errors.append((module.dataset_name, prompt, str(e))) + tokens, failed = await process_prompt( + request_factory, + prompt, + tokens, + module.dataset_name, + refusals, + errors, + ) + if failed: module_failures += 1 - continue - failure_rate = module_failures / max(processed_prompts, 1) failure_rates.append(failure_rate) cost = round(tokens * 1.5 / 1000_000, 2) @@ -202,42 +209,23 @@ async def perform_many_shot_scan( m_prompt = random.choice(msj_module.prompts) full_prompt += "\n" + m_prompt - tokens += len(m_prompt.split()) if tokens > max_ctx_length: full_prompt = "\n" + prompt injected = True - try: - # Main request - r = await request_factory.fn(prompt=full_prompt) - if r.status_code >= 400: - logger.error(f"HTTP {r.status_code}") - raise httpx.HTTPStatusError( - f"HTTP {r.status_code}", - request=r.request, - response=r, - ) - response_text = r.text - response_tokens = len(response_text.split()) - tokens += response_tokens - logger.info(f"Response: tokens: {full_prompt}") - if injected and not refusal_heuristic(r.json()): - logger.info(f"Refusal detected: {response_text}") - refusals.append( - ( - module.dataset_name, - full_prompt, - r.status_code, - response_text, - ) - ) - module_failures += 1 - - except httpx.RequestError as e: - logger.error(f"Request error: {e}") - errors.append((module.dataset_name, full_prompt, str(e))) + tokens, failed = await process_prompt( + request_factory, + full_prompt, + tokens, + module.dataset_name, + refusals, + errors, + ) + if failed: module_failures += 1 - continue + break + if injected: + break failure_rate = module_failures / max(processed_prompts, 1) failure_rates.append(failure_rate) @@ -260,8 +248,7 @@ async def perform_many_shot_scan( f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..." ) break - if injected: - break + yield ScanResult.status_msg("Scan completed.") df = pd.DataFrame( diff --git a/agentic_security/probe_actor/test_fuzzer.py b/agentic_security/probe_actor/test_fuzzer.py index 6db487f..c92ef30 100644 --- a/agentic_security/probe_actor/test_fuzzer.py +++ b/agentic_security/probe_actor/test_fuzzer.py @@ -1,13 +1,16 @@ import asyncio -from unittest.mock import AsyncMock, MagicMock, patch +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, patch +import httpx import pytest from agentic_security.models.schemas import Scan from agentic_security.probe_actor.fuzzer import ( + generate_prompts, perform_many_shot_scan, perform_single_shot_scan, - generate_prompts, + process_prompt, scan_router, ) @@ -181,3 +184,99 @@ async def test_perform_many_shot_scan_stop_event(): ) await assert_scan(async_gen, ["Loading", "Scan completed."]) + + +def mock_refusal_heuristic(response_json): + return response_json.get("is_refusal", False) + + +class TestProcessPrompt(unittest.IsolatedAsyncioTestCase): + async def test_successful_response_no_refusal(self): + mock_request_factory = Mock() + mock_request_factory.fn = AsyncMock( + return_value=Mock( + status_code=200, + text="Valid response text", + json=Mock(return_value={"is_refusal": False}), + request="mock_request", + ) + ) + + tokens, refusal = await process_prompt( + request_factory=mock_request_factory, + prompt="test prompt", + tokens=0, + module_name="module_a", + refusals=[], + errors=[], + ) + + self.assertEqual(tokens, 3) # Tokens from "Valid response text" + self.assertFalse(refusal) + + async def test_successful_response_with_refusal(self): + mock_request_factory = Mock() + mock_request_factory.fn = AsyncMock( + return_value=Mock( + status_code=200, + text="Response indicating refusal", + json=Mock(return_value={"is_refusal": True}), + request="mock_request", + ) + ) + + refusals = [] + tokens, refusal = await process_prompt( + request_factory=mock_request_factory, + prompt="test prompt", + tokens=0, + module_name="module_a", + refusals=refusals, + errors=[], + ) + + self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal" + self.assertFalse(refusal) + + async def test_http_error_response(self): + mock_request_factory = Mock() + mock_request_factory.fn = AsyncMock( + return_value=Mock( + status_code=500, + text="Internal Server Error", + request="mock_request", + response=Mock(), + ) + ) + + refusals = [] + with self.assertRaises(httpx.HTTPStatusError): + await process_prompt( + request_factory=mock_request_factory, + prompt="test prompt", + tokens=0, + module_name="module_a", + refusals=refusals, + errors=[], + ) + + async def test_request_error(self): + mock_request_factory = Mock() + mock_request_factory.fn = AsyncMock( + side_effect=httpx.RequestError("Connection error") + ) + + errors = [] + tokens, refusal = await process_prompt( + request_factory=mock_request_factory, + prompt="test prompt", + tokens=0, + module_name="module_a", + refusals=[], + errors=errors, + ) + + self.assertEqual(tokens, 0) + self.assertTrue(refusal) + self.assertEqual(len(errors), 1) + self.assertIn("Connection error", errors[0][2])