mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
feat(fix test gaps):
This commit is contained in:
@@ -4,13 +4,14 @@ from collections.abc import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from agentic_security.models.schemas import Scan, ScanResult
|
||||
from agentic_security.probe_actor.refusal import refusal_heuristic
|
||||
from agentic_security.probe_data.data import prepare_prompts
|
||||
from loguru import logger
|
||||
from skopt import Optimizer
|
||||
from skopt.space import Real
|
||||
|
||||
from agentic_security.models.schemas import Scan, ScanResult
|
||||
from agentic_security.probe_actor.refusal import refusal_heuristic
|
||||
from agentic_security.probe_data.data import prepare_prompts
|
||||
|
||||
|
||||
async def prompt_iter(prompts: list[str] | AsyncGenerator) -> AsyncGenerator[str, None]:
|
||||
if isinstance(prompts, list):
|
||||
@@ -21,7 +22,7 @@ async def prompt_iter(prompts: list[str] | AsyncGenerator) -> AsyncGenerator[str
|
||||
yield p
|
||||
|
||||
|
||||
async def perform_scan(
|
||||
async def perform_single_shot_scan(
|
||||
request_factory,
|
||||
max_budget: int,
|
||||
datasets: list[dict[str, str]] = [],
|
||||
@@ -132,7 +133,7 @@ async def perform_scan(
|
||||
raise e
|
||||
|
||||
|
||||
async def perform_multi_step_scan(
|
||||
async def perform_many_shot_scan(
|
||||
request_factory,
|
||||
max_budget: int,
|
||||
datasets: list[dict[str, str]] = [],
|
||||
@@ -300,9 +301,8 @@ def scan_router(
|
||||
tools_inbox=None,
|
||||
stop_event: asyncio.Event = None,
|
||||
):
|
||||
|
||||
if scan_parameters.enableMultiStepAttack:
|
||||
return perform_multi_step_scan(
|
||||
return perform_many_shot_scan(
|
||||
request_factory=request_factory,
|
||||
max_budget=scan_parameters.maxBudget,
|
||||
datasets=scan_parameters.datasets,
|
||||
@@ -312,7 +312,7 @@ def scan_router(
|
||||
stop_event=stop_event,
|
||||
)
|
||||
else:
|
||||
return perform_scan(
|
||||
return perform_single_shot_scan(
|
||||
request_factory=request_factory,
|
||||
max_budget=scan_parameters.maxBudget,
|
||||
datasets=scan_parameters.datasets,
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ..app import app
|
||||
from ..models.schemas import Probe
|
||||
from ..probe_actor.refusal import REFUSAL_MARKS
|
||||
from ..probe_data import REGISTRY
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_probe_schema_validation():
    """Check that Probe keeps a supplied prompt and rejects a missing one."""
    expected_prompt = "test prompt"

    # Happy path: the prompt passed in is exposed unchanged on the model.
    valid_probe = Probe(prompt=expected_prompt)
    assert valid_probe.prompt == expected_prompt

    # Constructing the model without its required field must raise.
    with pytest.raises(ValueError):
        Probe()
|
||||
|
||||
|
||||
def test_self_probe_endpoint():
    """Post a prompt to /v1/self-probe and check the OpenAI-style reply."""
    prompt = "test prompt"
    response = client.post("/v1/self-probe", json={"prompt": prompt})
    assert response.status_code == 200

    payload = response.json()
    # Top-level envelope fields of the OpenAI-style response.
    for field in ("id", "object", "created", "model", "usage", "choices"):
        assert field in payload

    # Exactly one choice, carrying a message with a role and content.
    choices = payload["choices"]
    assert len(choices) == 1
    first_choice = choices[0]
    assert "message" in first_choice
    message = first_choice["message"]
    for field in ("role", "content"):
        assert field in message

    # The reply must echo the prompt back.
    content = message["content"]
    assert prompt in content

    # What is left after removing the prompt is either a refusal marker
    # or the fixed test sentence.
    remainder = content.replace(prompt, "").strip()
    refused = any(mark in remainder for mark in REFUSAL_MARKS)
    canned = "This is a test!" in remainder
    assert refused or canned
|
||||
|
||||
|
||||
def test_self_probe_invalid_input():
    """Check how /v1/self-probe responds to a missing and an empty prompt."""
    endpoint = "/v1/self-probe"

    # A body without the prompt field is expected to fail validation.
    missing_prompt = client.post(endpoint, json={})
    assert missing_prompt.status_code == 422

    # An empty prompt string is expected to be accepted.
    empty_prompt = client.post(endpoint, json={"prompt": ""})
    assert empty_prompt.status_code == 200
|
||||
|
||||
|
||||
def test_data_config_endpoint():
    """Check that /v1/data-config returns one dict per REGISTRY entry."""
    response = client.get("/v1/data-config")
    assert response.status_code == 200

    entries = response.json()
    assert isinstance(entries, list)
    assert len(entries) == len(REGISTRY)

    # Only the container type of each entry is verified here.
    # TODO: assert the expected per-entry fields once the REGISTRY item
    # structure is pinned down.
    assert all(isinstance(entry, dict) for entry in entries)
|
||||
|
||||
|
||||
def test_refusal_rate():
    """Check that the share of refusals from /v1/self-probe is within 15-25%."""
    total_trials = 1000

    def _probe_refused() -> bool:
        # One round-trip: True when the reply contains any refusal marker.
        reply = client.post("/v1/self-probe", json={"prompt": "test"})
        text = reply.json()["choices"][0]["message"]["content"]
        return any(mark in text for mark in REFUSAL_MARKS)

    refusal_count = sum(_probe_refused() for _ in range(total_trials))
    refusal_rate = refusal_count / total_trials

    # The accepted band is 0.15..0.25 to absorb statistical variation.
    within_band = 0.15 <= refusal_rate <= 0.25
    assert within_band, f"Refusal rate {refusal_rate} is outside expected range"
|
||||
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch, Mock
|
||||
from pathlib import Path
|
||||
from .report import router
|
||||
|
||||
client = TestClient(router)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_csv_exists():
    """Patch ``Path.exists`` to return True while the test runs; yield the mock."""
    with patch.object(Path, "exists", return_value=True) as exists_mock:
        yield exists_mock
|
||||
|
||||
|
||||
@pytest.fixture
def mock_csv_not_exists():
    """Patch ``Path.exists`` to return False while the test runs; yield the mock."""
    with patch.object(Path, "exists", return_value=False) as exists_mock:
        yield exists_mock
|
||||
|
||||
|
||||
def test_failures_csv_exists(mock_csv_exists):
    """Check that /failures builds a FileResponse for failures.csv when it exists."""
    # FileResponse is swapped for a mock so no real file has to be served.
    with patch(
        "agentic_security.routes.report.FileResponse",
        return_value="mocked_response",
    ) as file_response:
        result = client.get("/failures")
        assert result.status_code == 200
        file_response.assert_called_once_with("failures.csv")
|
||||
|
||||
|
||||
def test_failures_csv_not_exists(mock_csv_not_exists):
    """Check that /failures returns an error payload when the CSV is absent."""
    result = client.get("/failures")

    # The endpoint still answers 200; the problem is reported in the JSON body.
    assert result.status_code == 200
    expected_body = {"error": "No failures found"}
    assert result.json() == expected_body
|
||||
|
||||
|
||||
@pytest.mark.skip
def test_get_plot():
    """Post a small report table to /plot.jpeg and expect a JPEG back (skipped)."""
    # (module, tokens, failureRate) rows for the report table.
    rows = (
        ("SQL Injection", 1000, 75.5),
        ("XSS Attack", 800, 45.2),
        ("CSRF Attack", 600, 30.8),
    )
    table_data = [
        {"module": module, "tokens": tokens, "failureRate": failure_rate}
        for module, tokens, failure_rate in rows
    ]

    # NOTE(review): the original mentions mocking plot_security_report, but no
    # mock is installed; the request goes to the real route.
    result = client.post("/plot.jpeg", json={"table": table_data})

    assert result.status_code == 200
    assert result.headers["content-type"] == "image/jpeg"
|
||||
@@ -0,0 +1,55 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
from ..models.schemas import Settings
|
||||
from .static import router, get_static_file
|
||||
from fastapi import HTTPException
|
||||
|
||||
client = TestClient(router)
|
||||
|
||||
|
||||
def test_root_route():
    """Check that GET / succeeds and is served as HTML."""
    result = client.get("/")
    assert result.status_code == 200

    content_type = result.headers["content-type"]
    assert "text/html" in content_type
|
||||
|
||||
|
||||
def test_main_js_route():
    """Check that GET /main.js succeeds as JavaScript with a Cache-Control header."""
    result = client.get("/main.js")
    assert result.status_code == 200

    headers = result.headers
    assert "application/javascript" in headers["content-type"]
    assert "Cache-Control" in headers
|
||||
|
||||
|
||||
def test_favicon_route():
    """Check that GET /favicon.ico succeeds as an icon with a Cache-Control header."""
    result = client.get("/favicon.ico")
    assert result.status_code == 200

    headers = result.headers
    assert "image/x-icon" in headers["content-type"]
    assert "Cache-Control" in headers
|
||||
|
||||
|
||||
def test_telemetry_js_route_enabled():
    """Check that GET /telemetry.js serves JavaScript when telemetry is enabled.

    ``Settings.DISABLE_TELEMETRY`` is a class-level flag shared by the whole
    process, so the previous value is saved and restored to keep this test
    from leaking state into tests that run after it.
    """
    previous = Settings.DISABLE_TELEMETRY
    Settings.DISABLE_TELEMETRY = False
    try:
        response = client.get("/telemetry.js")
        assert response.status_code == 200
        assert "application/javascript" in response.headers["content-type"]
    finally:
        # Restore even when an assertion fails, so the override never escapes.
        Settings.DISABLE_TELEMETRY = previous
|
||||
|
||||
|
||||
def test_telemetry_js_route_disabled():
    """Check that GET /telemetry.js still serves JavaScript when telemetry is disabled.

    ``Settings.DISABLE_TELEMETRY`` is a class-level flag shared by the whole
    process, so the previous value is saved and restored; otherwise telemetry
    would stay disabled for every test that runs after this one.
    """
    previous = Settings.DISABLE_TELEMETRY
    Settings.DISABLE_TELEMETRY = True
    try:
        response = client.get("/telemetry.js")
        assert response.status_code == 200
        assert "application/javascript" in response.headers["content-type"]
    finally:
        # Restore even when an assertion fails, so the override never escapes.
        Settings.DISABLE_TELEMETRY = previous
|
||||
|
||||
|
||||
def test_get_static_file_not_found():
    """Check that get_static_file raises a 404 HTTPException for a missing path."""
    missing_path = Path("nonexistent.file")

    with pytest.raises(HTTPException) as raised:
        get_static_file(missing_path)

    # Inspect the captured exception after the context manager has exited.
    error = raised.value
    assert error.status_code == 404
    assert error.detail == "File not found"
|
||||
Reference in New Issue
Block a user