Merge branch 'main' of github.com:msoedov/agentic_security

This commit is contained in:
Alexander Myasoedov
2025-03-09 12:33:03 +02:00
8 changed files with 8 additions and 8 deletions
+286
View File
@@ -0,0 +1,286 @@
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import httpx
import pytest
from agentic_security.primitives import Scan
from agentic_security.probe_actor.fuzzer import (
generate_prompts,
perform_many_shot_scan,
perform_single_shot_scan,
process_prompt,
scan_router,
)
@pytest.mark.asyncio
async def test_generate_prompts_with_list():
prompts = ["prompt1", "prompt2", "prompt3"]
results = [p async for p in generate_prompts(prompts)]
assert results == prompts
@pytest.mark.asyncio
async def test_generate_prompts_with_async_generator():
async def async_gen():
for i in range(3):
yield f"prompt{i}"
results = [p async for p in generate_prompts(async_gen())]
assert results == ["prompt0", "prompt1", "prompt2"]
async def assert_scan(generator, messages):
results = [r async for r in generator]
for m in messages:
found = False
for r in results:
if m in r:
found = True
break
assert found, f"Message '{m}' not found in results. Results: {results}"
return results
@pytest.mark.asyncio
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_perform_single_shot_scan_success(prepare_prompts_mock):
# Mock prompt modules
prepare_prompts_mock.return_value = [
MagicMock(
dataset_name="test_module",
prompts=["test_prompt1", "test_prompt2"],
lazy=False,
)
]
# Mock request_factory
mock_response = AsyncMock()
mock_response.fn.return_value = AsyncMock(
status_code=200, text="response text", json=lambda: {}
)
request_factory = mock_response
async_gen = perform_single_shot_scan(
request_factory=request_factory,
max_budget=100,
datasets=[{"dataset_name": "test", "selected": True}],
optimize=False,
)
await assert_scan(async_gen, ["Loading", "Scan completed."])
@pytest.mark.asyncio
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_perform_many_shot_scan_probe_injection(prepare_prompts_mock):
# Mock main and probe prompt modules
prepare_prompts_mock.side_effect = [
[MagicMock(dataset_name="main_module", prompts=["main_prompt1"], lazy=False)],
[MagicMock(dataset_name="probe_module", prompts=["probe_prompt1"], lazy=False)],
]
# Mock request_factory
mock_response = AsyncMock()
mock_response.fn.side_effect = [
AsyncMock(status_code=200, text="main response", json=lambda: {}),
AsyncMock(status_code=200, text="probe response", json=lambda: {}),
]
request_factory = mock_response
async_gen = perform_many_shot_scan(
request_factory=request_factory,
max_budget=100,
datasets=[{"dataset_name": "main", "selected": True}],
probe_datasets=[{"dataset_name": "probe", "selected": True}],
probe_frequency=1.0, # Always inject probes
optimize=False,
)
await assert_scan(async_gen, ["Loading", "Scan completed."])
@pytest.mark.asyncio
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_scan_router_single_shot(prepare_prompts_mock):
prepare_prompts_mock.return_value = []
request_factory = AsyncMock()
scan_params = Scan(
maxBudget=100,
llmSpec="test",
datasets=[],
probe_datasets=[],
enableMultiStepAttack=False,
optimize=False,
)
gen = scan_router(
request_factory=request_factory,
scan_parameters=scan_params,
)
await assert_scan(gen, ["Loading", "Scan completed."])
@pytest.mark.asyncio
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_scan_router_many_shot(prepare_prompts_mock):
prepare_prompts_mock.return_value = []
request_factory = AsyncMock()
scan_params = Scan(
maxBudget=100,
datasets=[],
llmSpec="test",
probeDatasets=[],
enableMultiStepAttack=True,
optimize=False,
)
gen = scan_router(
request_factory=request_factory,
scan_parameters=scan_params,
)
assert gen is not None
await assert_scan(gen, ["Loading", "Scan completed."])
@pytest.mark.asyncio
async def test_perform_single_shot_scan_stop_event():
stop_event = asyncio.Event()
stop_event.set() # Pre-set to simulate user stopping the scan
async def request_mock(*args, **kwargs):
return AsyncMock(status_code=200, text="response text", json=lambda: {})
async_gen = perform_single_shot_scan(
request_factory=MagicMock(fn=request_mock),
max_budget=100,
datasets=[],
stop_event=stop_event,
)
await assert_scan(async_gen, ["Loading", "Scan completed."])
@pytest.mark.asyncio
async def test_perform_many_shot_scan_stop_event():
stop_event = asyncio.Event()
stop_event.set() # Pre-set to simulate user stopping the scan
async def request_mock(*args, **kwargs):
return AsyncMock(status_code=200, text="response text", json=lambda: {})
async_gen = perform_many_shot_scan(
request_factory=MagicMock(fn=request_mock),
max_budget=100,
datasets=[],
probe_datasets=[],
stop_event=stop_event,
)
await assert_scan(async_gen, ["Loading", "Scan completed."])
def mock_refusal_heuristic(response_json):
return response_json.get("is_refusal", False)
class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
async def test_successful_response_no_refusal(self):
mock_request_factory = Mock()
mock_request_factory.fn = AsyncMock(
return_value=Mock(
status_code=200,
text="Valid response text",
json=Mock(return_value={"is_refusal": False}),
request="mock_request",
)
)
tokens, refusal = await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=[],
errors=[],
outputs=[],
)
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
self.assertTrue(refusal)
async def test_successful_response_with_refusal(self):
mock_request_factory = Mock()
mock_request_factory.fn = AsyncMock(
return_value=Mock(
status_code=200,
text="Response indicating refusal",
json=Mock(return_value={"is_refusal": True}),
request="mock_request",
)
)
refusals = []
outputs = []
tokens, refusal = await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=refusals,
errors=[],
outputs=outputs,
)
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
self.assertFalse(refusal)
async def test_http_error_response(self):
mock_request_factory = Mock()
mock_request_factory.fn = AsyncMock(
return_value=Mock(
status_code=500,
text="Internal Server Error",
request="mock_request",
response=Mock(),
)
)
refusals = []
await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=refusals,
errors=[],
outputs=[],
)
async def test_request_error(self):
mock_request_factory = Mock()
mock_request_factory.fn = AsyncMock(
side_effect=httpx.RequestError("Connection error")
)
errors = []
tokens, refusal = await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
refusals=[],
errors=errors,
outputs=[],
)
self.assertEqual(tokens, 0)
self.assertTrue(refusal)
self.assertEqual(len(errors), 1)
self.assertIn("Connection error", errors[0][3])
+13
View File
@@ -0,0 +1,13 @@
from agentic_security.probe_actor.refusal import DefaultRefusalClassifier
class TestCheckRefusal:
# The function correctly identifies a refusal phrase in the response.
def test_identify_refusal_phrase(self):
response = "I'm sorry, but I cannot provide that information."
assert DefaultRefusalClassifier().is_refusal(response)
# The response is an empty string.
def test_empty_response(self):
response = ""
assert not DefaultRefusalClassifier().is_refusal(response)
View File
+22
View File
@@ -0,0 +1,22 @@
from fastapi.testclient import TestClient
import agentic_security.test_spec_assets as test_spec_assets
from agentic_security.routes.scan import router
client = TestClient(router)
def test_upload_csv_and_run():
# Create a sample CSV content
csv_content = "id,prompt\nspec1,value1\nspec2,value3"
# Send a POST request to the /upload-csv endpoint
response = client.post(
"/scan-csv?optimize=false&enableMultiStepAttack=false&maxBudget=1000",
files={
"file": ("test.csv", csv_content, "text/csv"),
"llmSpec": ("spec.txt", test_spec_assets.SAMPLE_SPEC, "text/plain"),
},
)
assert response.status_code == 200
assert "Scan completed." in response.text
+12
View File
@@ -0,0 +1,12 @@
from fastapi.testclient import TestClient
from agentic_security.app import app
def test_health_check():
"""Test the health check endpoint."""
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
+218
View File
@@ -0,0 +1,218 @@
import base64
import io
import httpx
import pytest
from fastapi.testclient import TestClient
from agentic_security.app import app
from agentic_security.primitives import Probe
from agentic_security.probe_actor.refusal import REFUSAL_MARKS
from agentic_security.probe_data import REGISTRY
client = TestClient(app)
def test_probe_schema_validation():
"""Test Probe model validation"""
# Valid probe
probe = Probe(prompt="test prompt")
assert probe.prompt == "test prompt"
# Invalid probe (missing required field)
with pytest.raises(ValueError):
Probe()
def test_self_probe_endpoint():
"""Test /v1/self-probe endpoint"""
response = client.post("/v1/self-probe", json={"prompt": "test prompt"})
assert response.status_code == 200
data = response.json()
# Verify OpenAI response structure
assert "id" in data
assert "object" in data
assert "created" in data
assert "model" in data
assert "usage" in data
assert "choices" in data
# Verify choices structure
choices = data["choices"]
assert len(choices) == 1
assert "message" in choices[0]
assert "role" in choices[0]["message"]
assert "content" in choices[0]["message"]
# Verify message content contains the prompt
content = choices[0]["message"]["content"]
assert "test prompt" in content
# Verify message is either a refusal or "This is a test!"
remaining_text = content.replace("test prompt", "").strip()
is_refusal = any(mark in remaining_text for mark in REFUSAL_MARKS)
is_test = "This is a test!" in remaining_text
assert is_refusal or is_test
def test_self_probe_invalid_input():
"""Test /v1/self-probe endpoint with invalid input"""
# Missing prompt field
response = client.post("/v1/self-probe", json={})
assert response.status_code == 422
# Empty prompt
response = client.post("/v1/self-probe", json={"prompt": ""})
assert response.status_code == 200 # Empty prompts are allowed by schema
def test_data_config_endpoint():
"""Test /v1/data-config endpoint"""
response = client.get("/v1/data-config")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == len(REGISTRY)
# Verify each item in response matches REGISTRY format
for item in data:
assert isinstance(item, dict)
# Add assertions for expected fields based on REGISTRY structure
# This will depend on what fields are defined in the REGISTRY items
def test_refusal_rate():
"""Test that refusal rate is approximately 20%"""
refusal_count = 0
total_trials = 1000
for _ in range(total_trials):
response = client.post("/v1/self-probe", json={"prompt": "test"})
content = response.json()["choices"][0]["message"]["content"]
if any(mark in content for mark in REFUSAL_MARKS):
refusal_count += 1
refusal_rate = refusal_count / total_trials
# Allow for some statistical variation (±5%)
assert (
0.15 <= refusal_rate <= 0.25
), f"Refusal rate {refusal_rate} is outside expected range"
def test_self_probe_file_endpoint():
"""Test /v1/self-probe-file endpoint with valid input"""
# Create a mock audio file
file_content = b"mock audio content"
file = io.BytesIO(file_content)
files = {"file": ("test.m4a", file, "audio/m4a")}
headers = {"Authorization": "Bearer test_api_key"}
response = client.post(
"/v1/self-probe-file",
files=files,
headers=headers,
data={"model": "whisper-large-v3"},
)
assert response.status_code == 200
data = response.json()
assert "text" in data
assert "model" in data
assert data["model"] == "whisper-large-v3"
def test_self_probe_file_invalid_auth():
"""Test /v1/self-probe-file endpoint with invalid authorization"""
file_content = b"mock audio content"
file = io.BytesIO(file_content)
files = {"file": ("test.m4a", file, "audio/m4a")}
# Test missing auth header
response = client.post("/v1/self-probe-file", files=files)
assert response.status_code == 422
# Test invalid auth format
headers = {"Authorization": "InvalidFormat test_api_key"}
response = client.post("/v1/self-probe-file", files=files, headers=headers)
assert response.status_code == 401
# Test empty token
headers = {"Authorization": "Bearer "}
response = client.post("/v1/self-probe-file", files=files, headers=headers)
assert response.status_code == 401
def test_self_probe_file_invalid_format():
"""Test /v1/self-probe-file endpoint with invalid file format"""
file_content = b"mock content"
file = io.BytesIO(file_content)
files = {"file": ("test.txt", file, "text/plain")}
headers = {"Authorization": "Bearer test_api_key"}
response = client.post(
"/v1/self-probe-file",
files=files,
headers=headers,
data={"model": "whisper-large-v3"},
)
assert response.status_code == 400
assert "Invalid file format" in response.json()["detail"]
def test_self_probe_file_missing_file():
"""Test /v1/self-probe-file endpoint with missing file"""
headers = {"Authorization": "Bearer test_api_key"}
response = client.post(
"/v1/self-probe-file",
headers=headers,
data={"model": "whisper-large-v3"},
)
assert response.status_code == 422
def test_self_probe_image_endpoint():
"""Test /v1/self-probe-image endpoint with valid input"""
headers = {"Authorization": "Bearer test_api_key"}
# Test with different valid payloads
payloads = [
# OpenAI-style multi-modal payload
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {"url": encode_image_base64_by_url()},
},
],
}
],
# Simple text payload
{"message": "Test message"},
# Nested payload
{"level1": {"level2": "test"}},
# Empty object
{},
# Empty array
[],
]
for payload in payloads:
response = client.post("/v1/self-probe-image", json=payload, headers=headers)
assert response.status_code == 200, (payload, response.json())
data = response.json()
assert "choices" in data
assert len(data["choices"]) == 1
assert "message" in data["choices"][0]
def encode_image_base64_by_url(url: str = "https://github.com/fluidicon.png") -> str:
"""Encode image data to base64 from a URL"""
response = httpx.get(url)
encoded_content = base64.b64encode(response.content).decode("utf-8")
return "data:image/jpeg;base64," + encoded_content
+70
View File
@@ -0,0 +1,70 @@
from pathlib import Path
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from agentic_security.routes.report import router
client = TestClient(router)
@pytest.fixture
def mock_csv_exists():
with patch.object(Path, "exists") as mock:
mock.return_value = True
yield mock
@pytest.fixture
def mock_csv_not_exists():
with patch.object(Path, "exists") as mock:
mock.return_value = False
yield mock
def test_failures_csv_exists(mock_csv_exists):
"""Test /failures endpoint when CSV file exists"""
with patch("agentic_security.routes.report.FileResponse") as mock_response:
mock_response.return_value = "mocked_response"
response = client.get("/failures")
assert response.status_code == 200
mock_response.assert_called_once_with("failures.csv")
def test_failures_csv_not_exists(mock_csv_not_exists):
"""Test /failures endpoint when CSV file doesn't exist"""
response = client.get("/failures")
assert response.status_code == 200
assert response.json() == {"error": "No failures found"}
@pytest.mark.skip
def test_get_plot():
"""Test /plot.jpeg endpoint"""
# Mock data matching expected plot_security_report format
table_data = [
{
"module": "SQL Injection",
"tokens": 1000,
"failureRate": 75.5,
},
{
"module": "XSS Attack",
"tokens": 800,
"failureRate": 45.2,
},
{
"module": "CSRF Attack",
"tokens": 600,
"failureRate": 30.8,
},
]
# Mock plot_security_report function
response = client.post("/plot.jpeg", json={"table": table_data})
# Verify response
assert response.status_code == 200
assert response.headers["content-type"] == "image/jpeg"
+57
View File
@@ -0,0 +1,57 @@
from pathlib import Path
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
from agentic_security.primitives import Settings
from agentic_security.routes.static import get_static_file, router
client = TestClient(router)
def test_root_route():
"""Test the root route returns index.html"""
response = client.get("/")
assert response.status_code == 200
assert "text/html" in response.headers["content-type"]
def test_main_js_route():
"""Test the main.js route"""
response = client.get("/main.js")
assert response.status_code == 200
assert "application/javascript" in response.headers["content-type"]
assert "Cache-Control" in response.headers
def test_favicon_route():
"""Test the favicon route"""
response = client.get("/favicon.ico")
assert response.status_code == 200
assert "image/x-icon" in response.headers["content-type"]
assert "Cache-Control" in response.headers
def test_telemetry_js_route_enabled():
"""Test telemetry.js route when telemetry is enabled"""
Settings.DISABLE_TELEMETRY = False
response = client.get("/telemetry.js")
assert response.status_code == 200
assert "application/javascript" in response.headers["content-type"]
def test_telemetry_js_route_disabled():
"""Test telemetry.js route when telemetry is disabled"""
Settings.DISABLE_TELEMETRY = True
response = client.get("/telemetry.js")
assert response.status_code == 200
assert "application/javascript" in response.headers["content-type"]
def test_get_static_file_not_found():
"""Test get_static_file with non-existent file"""
with pytest.raises(HTTPException) as exc_info:
get_static_file(Path("nonexistent.file"))
assert exc_info.value.status_code == 404
assert exc_info.value.detail == "File not found"