feat(fix test gaps):

This commit is contained in:
Alexander Myasoedov
2024-12-02 20:58:57 +02:00
parent b2a12a3a62
commit 10dc91060f
4 changed files with 228 additions and 8 deletions
+8 -8
View File
@@ -4,13 +4,14 @@ from collections.abc import AsyncGenerator
import httpx
import pandas as pd
from agentic_security.models.schemas import Scan, ScanResult
from agentic_security.probe_actor.refusal import refusal_heuristic
from agentic_security.probe_data.data import prepare_prompts
from loguru import logger
from skopt import Optimizer
from skopt.space import Real
from agentic_security.models.schemas import Scan, ScanResult
from agentic_security.probe_actor.refusal import refusal_heuristic
from agentic_security.probe_data.data import prepare_prompts
async def prompt_iter(prompts: list[str] | AsyncGenerator) -> AsyncGenerator[str, None]:
if isinstance(prompts, list):
@@ -21,7 +22,7 @@ async def prompt_iter(prompts: list[str] | AsyncGenerator) -> AsyncGenerator[str
yield p
async def perform_scan(
async def perform_single_shot_scan(
request_factory,
max_budget: int,
datasets: list[dict[str, str]] = [],
@@ -132,7 +133,7 @@ async def perform_scan(
raise e
async def perform_multi_step_scan(
async def perform_many_shot_scan(
request_factory,
max_budget: int,
datasets: list[dict[str, str]] = [],
@@ -300,9 +301,8 @@ def scan_router(
tools_inbox=None,
stop_event: asyncio.Event = None,
):
if scan_parameters.enableMultiStepAttack:
return perform_multi_step_scan(
return perform_many_shot_scan(
request_factory=request_factory,
max_budget=scan_parameters.maxBudget,
datasets=scan_parameters.datasets,
@@ -312,7 +312,7 @@ def scan_router(
stop_event=stop_event,
)
else:
return perform_scan(
return perform_single_shot_scan(
request_factory=request_factory,
max_budget=scan_parameters.maxBudget,
datasets=scan_parameters.datasets,
+97
View File
@@ -0,0 +1,97 @@
import pytest
from fastapi.testclient import TestClient
from ..app import app
from ..models.schemas import Probe
from ..probe_actor.refusal import REFUSAL_MARKS
from ..probe_data import REGISTRY
client = TestClient(app)
def test_probe_schema_validation():
"""Test Probe model validation"""
# Valid probe
probe = Probe(prompt="test prompt")
assert probe.prompt == "test prompt"
# Invalid probe (missing required field)
with pytest.raises(ValueError):
Probe()
def test_self_probe_endpoint():
"""Test /v1/self-probe endpoint"""
response = client.post("/v1/self-probe", json={"prompt": "test prompt"})
assert response.status_code == 200
data = response.json()
# Verify OpenAI response structure
assert "id" in data
assert "object" in data
assert "created" in data
assert "model" in data
assert "usage" in data
assert "choices" in data
# Verify choices structure
choices = data["choices"]
assert len(choices) == 1
assert "message" in choices[0]
assert "role" in choices[0]["message"]
assert "content" in choices[0]["message"]
# Verify message content contains the prompt
content = choices[0]["message"]["content"]
assert "test prompt" in content
# Verify message is either a refusal or "This is a test!"
remaining_text = content.replace("test prompt", "").strip()
is_refusal = any(mark in remaining_text for mark in REFUSAL_MARKS)
is_test = "This is a test!" in remaining_text
assert is_refusal or is_test
def test_self_probe_invalid_input():
"""Test /v1/self-probe endpoint with invalid input"""
# Missing prompt field
response = client.post("/v1/self-probe", json={})
assert response.status_code == 422
# Empty prompt
response = client.post("/v1/self-probe", json={"prompt": ""})
assert response.status_code == 200 # Empty prompts are allowed by schema
def test_data_config_endpoint():
"""Test /v1/data-config endpoint"""
response = client.get("/v1/data-config")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == len(REGISTRY)
# Verify each item in response matches REGISTRY format
for item in data:
assert isinstance(item, dict)
# Add assertions for expected fields based on REGISTRY structure
# This will depend on what fields are defined in the REGISTRY items
def test_refusal_rate():
"""Test that refusal rate is approximately 20%"""
refusal_count = 0
total_trials = 1000
for _ in range(total_trials):
response = client.post("/v1/self-probe", json={"prompt": "test"})
content = response.json()["choices"][0]["message"]["content"]
if any(mark in content for mark in REFUSAL_MARKS):
refusal_count += 1
refusal_rate = refusal_count / total_trials
# Allow for some statistical variation (±5%)
assert (
0.15 <= refusal_rate <= 0.25
), f"Refusal rate {refusal_rate} is outside expected range"
+68
View File
@@ -0,0 +1,68 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch, Mock
from pathlib import Path
from .report import router
client = TestClient(router)
@pytest.fixture
def mock_csv_exists():
with patch.object(Path, "exists") as mock:
mock.return_value = True
yield mock
@pytest.fixture
def mock_csv_not_exists():
with patch.object(Path, "exists") as mock:
mock.return_value = False
yield mock
def test_failures_csv_exists(mock_csv_exists):
"""Test /failures endpoint when CSV file exists"""
with patch("agentic_security.routes.report.FileResponse") as mock_response:
mock_response.return_value = "mocked_response"
response = client.get("/failures")
assert response.status_code == 200
mock_response.assert_called_once_with("failures.csv")
def test_failures_csv_not_exists(mock_csv_not_exists):
"""Test /failures endpoint when CSV file doesn't exist"""
response = client.get("/failures")
assert response.status_code == 200
assert response.json() == {"error": "No failures found"}
@pytest.mark.skip
def test_get_plot():
"""Test /plot.jpeg endpoint"""
# Mock data matching expected plot_security_report format
table_data = [
{
"module": "SQL Injection",
"tokens": 1000,
"failureRate": 75.5,
},
{
"module": "XSS Attack",
"tokens": 800,
"failureRate": 45.2,
},
{
"module": "CSRF Attack",
"tokens": 600,
"failureRate": 30.8,
},
]
# Mock plot_security_report function
response = client.post("/plot.jpeg", json={"table": table_data})
# Verify response
assert response.status_code == 200
assert response.headers["content-type"] == "image/jpeg"
+55
View File
@@ -0,0 +1,55 @@
import pytest
from fastapi.testclient import TestClient
from pathlib import Path
from ..models.schemas import Settings
from .static import router, get_static_file
from fastapi import HTTPException
client = TestClient(router)
def test_root_route():
"""Test the root route returns index.html"""
response = client.get("/")
assert response.status_code == 200
assert "text/html" in response.headers["content-type"]
def test_main_js_route():
"""Test the main.js route"""
response = client.get("/main.js")
assert response.status_code == 200
assert "application/javascript" in response.headers["content-type"]
assert "Cache-Control" in response.headers
def test_favicon_route():
"""Test the favicon route"""
response = client.get("/favicon.ico")
assert response.status_code == 200
assert "image/x-icon" in response.headers["content-type"]
assert "Cache-Control" in response.headers
def test_telemetry_js_route_enabled():
"""Test telemetry.js route when telemetry is enabled"""
Settings.DISABLE_TELEMETRY = False
response = client.get("/telemetry.js")
assert response.status_code == 200
assert "application/javascript" in response.headers["content-type"]
def test_telemetry_js_route_disabled():
"""Test telemetry.js route when telemetry is disabled"""
Settings.DISABLE_TELEMETRY = True
response = client.get("/telemetry.js")
assert response.status_code == 200
assert "application/javascript" in response.headers["content-type"]
def test_get_static_file_not_found():
"""Test get_static_file with non-existent file"""
with pytest.raises(HTTPException) as exc_info:
get_static_file(Path("nonexistent.file"))
assert exc_info.value.status_code == 404
assert exc_info.value.detail == "File not found"