feat(process_prompt):

This commit is contained in:
Alexander Myasoedov
2024-12-11 18:02:26 +02:00
parent b5ecc28ab6
commit 314617651f
2 changed files with 147 additions and 61 deletions
+46 -59
View File
@@ -25,6 +25,30 @@ async def generate_prompts(
yield prompt
async def process_prompt(
    request_factory, prompt, tokens, module_name, refusals, errors
):
    """Send one prompt through ``request_factory.fn`` and record the outcome.

    Args:
        request_factory: Object whose awaitable ``fn(prompt=...)`` returns a
            response exposing ``status_code``, ``request``, ``text`` and
            ``json()``.
        prompt: Prompt text forwarded to ``request_factory.fn``.
        tokens: Running token count; the whitespace-split word count of the
            response text is added to it.
        module_name: Label stored as the first field of every recorded tuple.
        refusals: List mutated in place; receives
            ``(module_name, prompt, status_code, response_text)`` whenever
            ``refusal_heuristic`` returns a falsy value for the JSON body.
        errors: List mutated in place; receives
            ``(module_name, prompt, str(exc))`` on ``httpx.RequestError``.

    Returns:
        ``(tokens, flag)`` where ``flag`` is ``True`` only when the request
        raised ``httpx.RequestError`` and ``False`` after any response that
        was processed.

    Raises:
        httpx.HTTPStatusError: When the response status code is 400 or above.
    """
    try:
        reply = await request_factory.fn(prompt=prompt)
        status = reply.status_code
        if status >= 400:
            raise httpx.HTTPStatusError(
                f"HTTP {status}",
                request=reply.request,
                response=reply,
            )

        body = reply.text
        tokens += len(body.split())

        # An entry is recorded when the heuristic does NOT flag the body.
        heuristic_hit = refusal_heuristic(reply.json())
        if not heuristic_hit:
            refusals.append((module_name, prompt, status, body))
    except httpx.RequestError as exc:
        logger.error(f"Request error: {exc}")
        errors.append((module_name, prompt, str(exc)))
        return tokens, True
    else:
        return tokens, False
async def perform_single_shot_scan(
request_factory,
max_budget: int,
@@ -73,34 +97,17 @@ async def perform_single_shot_scan(
progress = (
100 * processed_prompts / total_prompts if total_prompts else 0
)
prompt_tokens = len(prompt.split())
tokens += prompt_tokens
try:
r = await request_factory.fn(prompt=prompt)
if r.status_code >= 400:
raise httpx.HTTPStatusError(
f"HTTP {r.status_code}",
request=r.request,
response=r,
)
response_text = r.text
response_tokens = len(response_text.split())
tokens += response_tokens
if not refusal_heuristic(r.json()):
refusals.append(
(module.dataset_name, prompt, r.status_code, response_text)
)
module_failures += 1
except httpx.RequestError as e:
logger.error(f"Request error: {e}")
errors.append((module.dataset_name, prompt, str(e)))
tokens, failed = await process_prompt(
request_factory,
prompt,
tokens,
module.dataset_name,
refusals,
errors,
)
if failed:
module_failures += 1
continue
failure_rate = module_failures / max(processed_prompts, 1)
failure_rates.append(failure_rate)
cost = round(tokens * 1.5 / 1000_000, 2)
@@ -202,42 +209,23 @@ async def perform_many_shot_scan(
m_prompt = random.choice(msj_module.prompts)
full_prompt += "\n" + m_prompt
tokens += len(m_prompt.split())
if tokens > max_ctx_length:
full_prompt = "\n" + prompt
injected = True
try:
# Main request
r = await request_factory.fn(prompt=full_prompt)
if r.status_code >= 400:
logger.error(f"HTTP {r.status_code}")
raise httpx.HTTPStatusError(
f"HTTP {r.status_code}",
request=r.request,
response=r,
)
response_text = r.text
response_tokens = len(response_text.split())
tokens += response_tokens
logger.info(f"Response: tokens: {full_prompt}")
if injected and not refusal_heuristic(r.json()):
logger.info(f"Refusal detected: {response_text}")
refusals.append(
(
module.dataset_name,
full_prompt,
r.status_code,
response_text,
)
)
module_failures += 1
except httpx.RequestError as e:
logger.error(f"Request error: {e}")
errors.append((module.dataset_name, full_prompt, str(e)))
tokens, failed = await process_prompt(
request_factory,
full_prompt,
tokens,
module.dataset_name,
refusals,
errors,
)
if failed:
module_failures += 1
continue
break
if injected:
break
failure_rate = module_failures / max(processed_prompts, 1)
failure_rates.append(failure_rate)
@@ -260,8 +248,7 @@ async def perform_many_shot_scan(
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
)
break
if injected:
break
yield ScanResult.status_msg("Scan completed.")
df = pd.DataFrame(
+101 -2
View File
@@ -1,13 +1,16 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import unittest
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import httpx
import pytest
from agentic_security.models.schemas import Scan
from agentic_security.probe_actor.fuzzer import (
generate_prompts,
perform_many_shot_scan,
perform_single_shot_scan,
generate_prompts,
process_prompt,
scan_router,
)
@@ -181,3 +184,99 @@ async def test_perform_many_shot_scan_stop_event():
)
await assert_scan(async_gen, ["Loading", "Scan completed."])
def mock_refusal_heuristic(response_json):
    """Return the payload's ``"is_refusal"`` entry, or ``False`` when absent."""
    verdict = response_json.get("is_refusal", False)
    return verdict
class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
    """Async tests for ``process_prompt`` using mocked request factories."""

    @staticmethod
    def _factory_returning(**response_attrs):
        # Request factory whose awaited ``fn`` yields a Mock response carrying
        # exactly the given attributes.
        factory = Mock()
        factory.fn = AsyncMock(return_value=Mock(**response_attrs))
        return factory

    async def test_successful_response_no_refusal(self):
        # 200 response whose JSON body carries is_refusal=False.
        factory = self._factory_returning(
            status_code=200,
            text="Valid response text",
            json=Mock(return_value={"is_refusal": False}),
            request="mock_request",
        )

        total, flag = await process_prompt(
            request_factory=factory,
            prompt="test prompt",
            tokens=0,
            module_name="module_a",
            refusals=[],
            errors=[],
        )

        self.assertEqual(total, 3)  # Tokens from "Valid response text"
        self.assertFalse(flag)

    async def test_successful_response_with_refusal(self):
        # 200 response whose JSON body carries is_refusal=True.
        factory = self._factory_returning(
            status_code=200,
            text="Response indicating refusal",
            json=Mock(return_value={"is_refusal": True}),
            request="mock_request",
        )

        total, flag = await process_prompt(
            request_factory=factory,
            prompt="test prompt",
            tokens=0,
            module_name="module_a",
            refusals=[],
            errors=[],
        )

        self.assertEqual(total, 3)  # Tokens from "Response indicating refusal"
        self.assertFalse(flag)

    async def test_http_error_response(self):
        # A 500 status must surface as httpx.HTTPStatusError to the caller.
        factory = self._factory_returning(
            status_code=500,
            text="Internal Server Error",
            request="mock_request",
            response=Mock(),
        )

        with self.assertRaises(httpx.HTTPStatusError):
            await process_prompt(
                request_factory=factory,
                prompt="test prompt",
                tokens=0,
                module_name="module_a",
                refusals=[],
                errors=[],
            )

    async def test_request_error(self):
        # A transport-level failure is recorded in ``errors`` and reported
        # through a True flag, leaving the token count untouched.
        factory = Mock()
        factory.fn = AsyncMock(side_effect=httpx.RequestError("Connection error"))
        recorded_errors = []

        total, flag = await process_prompt(
            request_factory=factory,
            prompt="test prompt",
            tokens=0,
            module_name="module_a",
            refusals=[],
            errors=recorded_errors,
        )

        self.assertEqual(total, 0)
        self.assertTrue(flag)
        self.assertEqual(len(recorded_errors), 1)
        self.assertIn("Connection error", recorded_errors[0][2])