diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index 1d93d17..8795d39 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -4,13 +4,14 @@ from collections.abc import AsyncGenerator import httpx import pandas as pd -from agentic_security.models.schemas import Scan, ScanResult -from agentic_security.probe_actor.refusal import refusal_heuristic -from agentic_security.probe_data.data import prepare_prompts from loguru import logger from skopt import Optimizer from skopt.space import Real +from agentic_security.models.schemas import Scan, ScanResult +from agentic_security.probe_actor.refusal import refusal_heuristic +from agentic_security.probe_data.data import prepare_prompts + async def prompt_iter(prompts: list[str] | AsyncGenerator) -> AsyncGenerator[str, None]: if isinstance(prompts, list): @@ -21,7 +22,7 @@ async def prompt_iter(prompts: list[str] | AsyncGenerator) -> AsyncGenerator[str yield p -async def perform_scan( +async def perform_single_shot_scan( request_factory, max_budget: int, datasets: list[dict[str, str]] = [], @@ -132,7 +133,7 @@ async def perform_scan( raise e -async def perform_multi_step_scan( +async def perform_many_shot_scan( request_factory, max_budget: int, datasets: list[dict[str, str]] = [], @@ -300,9 +301,8 @@ def scan_router( tools_inbox=None, stop_event: asyncio.Event = None, ): - if scan_parameters.enableMultiStepAttack: - return perform_multi_step_scan( + return perform_many_shot_scan( request_factory=request_factory, max_budget=scan_parameters.maxBudget, datasets=scan_parameters.datasets, @@ -312,7 +312,7 @@ def scan_router( stop_event=stop_event, ) else: - return perform_scan( + return perform_single_shot_scan( request_factory=request_factory, max_budget=scan_parameters.maxBudget, datasets=scan_parameters.datasets, diff --git a/agentic_security/routes/test_probe.py b/agentic_security/routes/test_probe.py new file mode 100644 index 0000000..476b01c --- /dev/null +++ b/agentic_security/routes/test_probe.py @@ -0,0 +1,97 @@ +import pytest +from fastapi.testclient import TestClient + +from ..app import app +from ..models.schemas import Probe +from ..probe_actor.refusal import REFUSAL_MARKS +from ..probe_data import REGISTRY + +client = TestClient(app) + + +def test_probe_schema_validation(): + """Test Probe model validation""" + # Valid probe + probe = Probe(prompt="test prompt") + assert probe.prompt == "test prompt" + + # Invalid probe (missing required field) + with pytest.raises(ValueError): + Probe() + + +def test_self_probe_endpoint(): + """Test /v1/self-probe endpoint""" + response = client.post("/v1/self-probe", json={"prompt": "test prompt"}) + assert response.status_code == 200 + + data = response.json() + # Verify OpenAI response structure + assert "id" in data + assert "object" in data + assert "created" in data + assert "model" in data + assert "usage" in data + assert "choices" in data + + # Verify choices structure + choices = data["choices"] + assert len(choices) == 1 + assert "message" in choices[0] + assert "role" in choices[0]["message"] + assert "content" in choices[0]["message"] + + # Verify message content contains the prompt + content = choices[0]["message"]["content"] + assert "test prompt" in content + + # Verify message is either a refusal or "This is a test!" + remaining_text = content.replace("test prompt", "").strip() + is_refusal = any(mark in remaining_text for mark in REFUSAL_MARKS) + is_test = "This is a test!" in remaining_text + assert is_refusal or is_test + + +def test_self_probe_invalid_input(): + """Test /v1/self-probe endpoint with invalid input""" + # Missing prompt field + response = client.post("/v1/self-probe", json={}) + assert response.status_code == 422 + + # Empty prompt + response = client.post("/v1/self-probe", json={"prompt": ""}) + assert response.status_code == 200 # Empty prompts are allowed by schema + + +def test_data_config_endpoint(): + """Test /v1/data-config endpoint""" + response = client.get("/v1/data-config") + assert response.status_code == 200 + + data = response.json() + assert isinstance(data, list) + assert len(data) == len(REGISTRY) + + # Verify each item in response matches REGISTRY format + for item in data: + assert isinstance(item, dict) + # Add assertions for expected fields based on REGISTRY structure + # This will depend on what fields are defined in the REGISTRY items + + +def test_refusal_rate(): + """Test that refusal rate is approximately 20%""" + refusal_count = 0 + total_trials = 1000 + + for _ in range(total_trials): + response = client.post("/v1/self-probe", json={"prompt": "test"}) + content = response.json()["choices"][0]["message"]["content"] + if any(mark in content for mark in REFUSAL_MARKS): + refusal_count += 1 + + refusal_rate = refusal_count / total_trials + # Allow for some statistical variation (±5%) + assert ( + 0.15 <= refusal_rate <= 0.25 + ), f"Refusal rate {refusal_rate} is outside expected range" diff --git a/agentic_security/routes/test_report.py b/agentic_security/routes/test_report.py new file mode 100644 index 0000000..7166376 --- /dev/null +++ b/agentic_security/routes/test_report.py @@ -0,0 +1,68 @@ +import pytest +from fastapi.testclient import TestClient +from unittest.mock import patch, Mock +from pathlib import Path +from .report import router + +client = TestClient(router) + + +@pytest.fixture +def mock_csv_exists(): + with patch.object(Path, "exists") as mock: + mock.return_value = True + yield mock + + +@pytest.fixture +def mock_csv_not_exists(): + with patch.object(Path, "exists") as mock: + mock.return_value = False + yield mock + + +def test_failures_csv_exists(mock_csv_exists): + """Test /failures endpoint when CSV file exists""" + with patch("agentic_security.routes.report.FileResponse") as mock_response: + mock_response.return_value = "mocked_response" + response = client.get("/failures") + assert response.status_code == 200 + mock_response.assert_called_once_with("failures.csv") + + +def test_failures_csv_not_exists(mock_csv_not_exists): + """Test /failures endpoint when CSV file doesn't exist""" + response = client.get("/failures") + assert response.status_code == 200 + assert response.json() == {"error": "No failures found"} + + +@pytest.mark.skip +def test_get_plot(): + """Test /plot.jpeg endpoint""" + # Mock data matching expected plot_security_report format + table_data = [ + { + "module": "SQL Injection", + "tokens": 1000, + "failureRate": 75.5, + }, + { + "module": "XSS Attack", + "tokens": 800, + "failureRate": 45.2, + }, + { + "module": "CSRF Attack", + "tokens": 600, + "failureRate": 30.8, + }, + ] + + # Mock plot_security_report function + + response = client.post("/plot.jpeg", json={"table": table_data}) + + # Verify response + assert response.status_code == 200 + assert response.headers["content-type"] == "image/jpeg" diff --git a/agentic_security/routes/test_static.py b/agentic_security/routes/test_static.py new file mode 100644 index 0000000..91700f3 --- /dev/null +++ b/agentic_security/routes/test_static.py @@ -0,0 +1,55 @@ +import pytest +from fastapi.testclient import TestClient +from pathlib import Path +from ..models.schemas import Settings +from .static import router, get_static_file +from fastapi import HTTPException + +client = TestClient(router) + + +def test_root_route(): + """Test the root route returns index.html""" + response = client.get("/") + assert response.status_code == 200 + assert "text/html" in response.headers["content-type"] + + +def test_main_js_route(): + """Test the main.js route""" + response = client.get("/main.js") + assert response.status_code == 200 + assert "application/javascript" in response.headers["content-type"] + assert "Cache-Control" in response.headers + + +def test_favicon_route(): + """Test the favicon route""" + response = client.get("/favicon.ico") + assert response.status_code == 200 + assert "image/x-icon" in response.headers["content-type"] + assert "Cache-Control" in response.headers + + +def test_telemetry_js_route_enabled(): + """Test telemetry.js route when telemetry is enabled""" + Settings.DISABLE_TELEMETRY = False + response = client.get("/telemetry.js") + assert response.status_code == 200 + assert "application/javascript" in response.headers["content-type"] + + +def test_telemetry_js_route_disabled(): + """Test telemetry.js route when telemetry is disabled""" + Settings.DISABLE_TELEMETRY = True + response = client.get("/telemetry.js") + assert response.status_code == 200 + assert "application/javascript" in response.headers["content-type"] + + +def test_get_static_file_not_found(): + """Test get_static_file with non-existent file""" + with pytest.raises(HTTPException) as exc_info: + get_static_file(Path("nonexistent.file")) + assert exc_info.value.status_code == 404 + assert exc_info.value.detail == "File not found"