mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
feat(process_prompt):
This commit is contained in:
@@ -25,6 +25,30 @@ async def generate_prompts(
|
||||
yield prompt
|
||||
|
||||
|
||||
async def process_prompt(
|
||||
request_factory, prompt, tokens, module_name, refusals, errors
|
||||
):
|
||||
try:
|
||||
response = await request_factory.fn(prompt=prompt)
|
||||
if response.status_code >= 400:
|
||||
raise httpx.HTTPStatusError(
|
||||
f"HTTP {response.status_code}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
response_text = response.text
|
||||
tokens += len(response_text.split())
|
||||
|
||||
if not refusal_heuristic(response.json()):
|
||||
refusals.append((module_name, prompt, response.status_code, response_text))
|
||||
return tokens, False
|
||||
|
||||
except httpx.RequestError as exc:
|
||||
logger.error(f"Request error: {exc}")
|
||||
errors.append((module_name, prompt, str(exc)))
|
||||
return tokens, True
|
||||
|
||||
|
||||
async def perform_single_shot_scan(
|
||||
request_factory,
|
||||
max_budget: int,
|
||||
@@ -73,34 +97,17 @@ async def perform_single_shot_scan(
|
||||
progress = (
|
||||
100 * processed_prompts / total_prompts if total_prompts else 0
|
||||
)
|
||||
prompt_tokens = len(prompt.split())
|
||||
tokens += prompt_tokens
|
||||
|
||||
try:
|
||||
r = await request_factory.fn(prompt=prompt)
|
||||
if r.status_code >= 400:
|
||||
raise httpx.HTTPStatusError(
|
||||
f"HTTP {r.status_code}",
|
||||
request=r.request,
|
||||
response=r,
|
||||
)
|
||||
|
||||
response_text = r.text
|
||||
response_tokens = len(response_text.split())
|
||||
tokens += response_tokens
|
||||
|
||||
if not refusal_heuristic(r.json()):
|
||||
refusals.append(
|
||||
(module.dataset_name, prompt, r.status_code, response_text)
|
||||
)
|
||||
module_failures += 1
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {e}")
|
||||
errors.append((module.dataset_name, prompt, str(e)))
|
||||
tokens, failed = await process_prompt(
|
||||
request_factory,
|
||||
prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
refusals,
|
||||
errors,
|
||||
)
|
||||
if failed:
|
||||
module_failures += 1
|
||||
continue
|
||||
|
||||
failure_rate = module_failures / max(processed_prompts, 1)
|
||||
failure_rates.append(failure_rate)
|
||||
cost = round(tokens * 1.5 / 1000_000, 2)
|
||||
@@ -202,42 +209,23 @@ async def perform_many_shot_scan(
|
||||
|
||||
m_prompt = random.choice(msj_module.prompts)
|
||||
full_prompt += "\n" + m_prompt
|
||||
tokens += len(m_prompt.split())
|
||||
if tokens > max_ctx_length:
|
||||
full_prompt = "\n" + prompt
|
||||
injected = True
|
||||
try:
|
||||
# Main request
|
||||
r = await request_factory.fn(prompt=full_prompt)
|
||||
if r.status_code >= 400:
|
||||
logger.error(f"HTTP {r.status_code}")
|
||||
raise httpx.HTTPStatusError(
|
||||
f"HTTP {r.status_code}",
|
||||
request=r.request,
|
||||
response=r,
|
||||
)
|
||||
|
||||
response_text = r.text
|
||||
response_tokens = len(response_text.split())
|
||||
tokens += response_tokens
|
||||
logger.info(f"Response: tokens: {full_prompt}")
|
||||
if injected and not refusal_heuristic(r.json()):
|
||||
logger.info(f"Refusal detected: {response_text}")
|
||||
refusals.append(
|
||||
(
|
||||
module.dataset_name,
|
||||
full_prompt,
|
||||
r.status_code,
|
||||
response_text,
|
||||
)
|
||||
)
|
||||
module_failures += 1
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {e}")
|
||||
errors.append((module.dataset_name, full_prompt, str(e)))
|
||||
tokens, failed = await process_prompt(
|
||||
request_factory,
|
||||
full_prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
refusals,
|
||||
errors,
|
||||
)
|
||||
if failed:
|
||||
module_failures += 1
|
||||
continue
|
||||
break
|
||||
if injected:
|
||||
break
|
||||
|
||||
failure_rate = module_failures / max(processed_prompts, 1)
|
||||
failure_rates.append(failure_rate)
|
||||
@@ -260,8 +248,7 @@ async def perform_many_shot_scan(
|
||||
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
|
||||
)
|
||||
break
|
||||
if injected:
|
||||
break
|
||||
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
|
||||
df = pd.DataFrame(
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from agentic_security.models.schemas import Scan
|
||||
from agentic_security.probe_actor.fuzzer import (
|
||||
generate_prompts,
|
||||
perform_many_shot_scan,
|
||||
perform_single_shot_scan,
|
||||
generate_prompts,
|
||||
process_prompt,
|
||||
scan_router,
|
||||
)
|
||||
|
||||
@@ -181,3 +184,99 @@ async def test_perform_many_shot_scan_stop_event():
|
||||
)
|
||||
|
||||
await assert_scan(async_gen, ["Loading", "Scan completed."])
|
||||
|
||||
|
||||
def mock_refusal_heuristic(response_json):
|
||||
return response_json.get("is_refusal", False)
|
||||
|
||||
|
||||
class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_successful_response_no_refusal(self):
|
||||
mock_request_factory = Mock()
|
||||
mock_request_factory.fn = AsyncMock(
|
||||
return_value=Mock(
|
||||
status_code=200,
|
||||
text="Valid response text",
|
||||
json=Mock(return_value={"is_refusal": False}),
|
||||
request="mock_request",
|
||||
)
|
||||
)
|
||||
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=[],
|
||||
errors=[],
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
|
||||
self.assertFalse(refusal)
|
||||
|
||||
async def test_successful_response_with_refusal(self):
|
||||
mock_request_factory = Mock()
|
||||
mock_request_factory.fn = AsyncMock(
|
||||
return_value=Mock(
|
||||
status_code=200,
|
||||
text="Response indicating refusal",
|
||||
json=Mock(return_value={"is_refusal": True}),
|
||||
request="mock_request",
|
||||
)
|
||||
)
|
||||
|
||||
refusals = []
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=refusals,
|
||||
errors=[],
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
|
||||
self.assertFalse(refusal)
|
||||
|
||||
async def test_http_error_response(self):
|
||||
mock_request_factory = Mock()
|
||||
mock_request_factory.fn = AsyncMock(
|
||||
return_value=Mock(
|
||||
status_code=500,
|
||||
text="Internal Server Error",
|
||||
request="mock_request",
|
||||
response=Mock(),
|
||||
)
|
||||
)
|
||||
|
||||
refusals = []
|
||||
with self.assertRaises(httpx.HTTPStatusError):
|
||||
await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=refusals,
|
||||
errors=[],
|
||||
)
|
||||
|
||||
async def test_request_error(self):
|
||||
mock_request_factory = Mock()
|
||||
mock_request_factory.fn = AsyncMock(
|
||||
side_effect=httpx.RequestError("Connection error")
|
||||
)
|
||||
|
||||
errors = []
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=[],
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 0)
|
||||
self.assertTrue(refusal)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertIn("Connection error", errors[0][2])
|
||||
|
||||
Reference in New Issue
Block a user