mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 22:29:56 +02:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 317fd33480 | |||
| 56e3c91af2 | |||
| 594f8960e8 | |||
| 51a9b5de5f | |||
| 0a555b8427 | |||
| aa27817f94 | |||
| 8bd76b9f05 | |||
| 6f3c522d59 | |||
| 896ca95ae2 | |||
| f85c77d622 | |||
| 684ba0b70d | |||
| 21b43b18e7 | |||
| d20c1a3d0d | |||
| ebac62e21a |
@@ -17,3 +17,4 @@ inv/
|
||||
scripts/
|
||||
docx/
|
||||
agentic_security.toml
|
||||
/venv
|
||||
@@ -6,12 +6,30 @@ from agentic_security.core.app import expand_secrets
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals():
|
||||
"""
|
||||
Reset globals (_secrets, current_run, tools_inbox, stop_event) before each test.
|
||||
This ensures tests run in a clean state.
|
||||
"""
|
||||
from agentic_security.core.app import _secrets, current_run, get_tools_inbox, get_stop_event
|
||||
_secrets.clear()
|
||||
current_run["spec"] = ""
|
||||
current_run["id"] = ""
|
||||
# Clear tools_inbox queue
|
||||
queue = get_tools_inbox()
|
||||
while not queue.empty():
|
||||
queue.get_nowait()
|
||||
# Reset stop_event if it is set
|
||||
event = get_stop_event()
|
||||
if event.is_set():
|
||||
event.clear()
|
||||
def setup_env_vars():
|
||||
# Set up environment variables for testing
|
||||
os.environ["TEST_ENV_VAR"] = "test_value"
|
||||
|
||||
|
||||
def test_expand_secrets_with_env_var():
|
||||
os.environ["TEST_ENV_VAR"] = "test_value"
|
||||
secrets = {"secret_key": "$TEST_ENV_VAR"}
|
||||
expand_secrets(secrets)
|
||||
assert secrets["secret_key"] == "test_value"
|
||||
@@ -27,3 +45,180 @@ def test_expand_secrets_without_dollar_sign():
|
||||
secrets = {"secret_key": "plain_value"}
|
||||
expand_secrets(secrets)
|
||||
assert secrets["secret_key"] == "plain_value"
|
||||
|
||||
import asyncio
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from agentic_security.core.app import create_app, get_tools_inbox, get_stop_event, get_current_run, set_current_run, get_secrets, set_secrets, expand_secrets
|
||||
|
||||
class DummyLLMSpec:
|
||||
"""A dummy LLMSpec for testing purposes."""
|
||||
pass
|
||||
|
||||
def test_create_app():
|
||||
"""Test that create_app returns a FastAPI app with ORJSONResponse."""
|
||||
app = create_app()
|
||||
assert isinstance(app, FastAPI)
|
||||
assert app.router.default_response_class == ORJSONResponse
|
||||
|
||||
def test_get_tools_inbox():
|
||||
"""Test that get_tools_inbox returns a Queue instance."""
|
||||
queue = get_tools_inbox()
|
||||
from asyncio import Queue
|
||||
assert isinstance(queue, Queue)
|
||||
|
||||
def test_get_stop_event():
|
||||
"""Test that get_stop_event returns an Event instance."""
|
||||
event = get_stop_event()
|
||||
from asyncio import Event
|
||||
assert isinstance(event, Event)
|
||||
|
||||
def test_get_current_run_initial():
|
||||
"""Test that get_current_run returns the initial current run dictionary."""
|
||||
current = get_current_run()
|
||||
# The initial dictionary should have an empty spec and id.
|
||||
assert current["spec"] == ""
|
||||
assert current["id"] == ""
|
||||
|
||||
def test_set_current_run():
|
||||
"""Test that set_current_run updates the current run with the dummy LLMSpec."""
|
||||
dummy_spec = DummyLLMSpec()
|
||||
updated = set_current_run(dummy_spec)
|
||||
assert updated["spec"] is dummy_spec
|
||||
# Ensure that the id is computed as hash(id(dummy_spec))
|
||||
expected_id = hash(id(dummy_spec))
|
||||
assert updated["id"] == expected_id
|
||||
|
||||
def test_get_and_set_secrets():
|
||||
"""Test that set_secrets updates the secrets dictionary and get_secrets returns the updated values."""
|
||||
# Clear any previously set secrets
|
||||
secrets_before = get_secrets().copy()
|
||||
os.environ["MY_SECRET"] = "secret_value"
|
||||
new_secrets = {"key1": "$MY_SECRET", "key2": "plain"}
|
||||
updated = set_secrets(new_secrets)
|
||||
assert updated["key1"] == "secret_value"
|
||||
assert updated["key2"] == "plain"
|
||||
|
||||
def test_expand_secrets_multiple_keys():
|
||||
"""Test expand_secrets with multiple keys, including one with an environment variable,
|
||||
one with a non-existent variable, and one that is plain."""
|
||||
os.environ["TEST_ENV_VAR"] = "test_value"
|
||||
secrets = {"env_key": "$TEST_ENV_VAR", "nonexistent_key": "$NON_EXISTENT", "plain_key": "value"}
|
||||
expand_secrets(secrets)
|
||||
assert secrets["env_key"] == "test_value"
|
||||
# For a non-existent environment variable, os.getenv returns None
|
||||
assert secrets["nonexistent_key"] is None
|
||||
# Plain values should not be changed.
|
||||
assert secrets["plain_key"] == "value"
|
||||
def test_expand_secrets_with_space_after_dollar():
|
||||
"""Test expand_secrets when the value has a dollar sign followed by a space.
|
||||
Since the value does not start strictly with "$", the secret remains unchanged.
|
||||
Also verifies that the stripping in expand_secrets (via strip("$"))
|
||||
will remove both dollar and any whitespace if the value actually started with '$'.
|
||||
"""
|
||||
os.environ["SPACED_VAR"] = "spaced_value"
|
||||
secrets = {"key": "$ SPACED_VAR"}
|
||||
expand_secrets(secrets)
|
||||
# " $ SPACED_VAR" after strip("$") becomes " SPACED_VAR" which is not a valid env key so returns None.
|
||||
assert secrets["key"] is None
|
||||
|
||||
def test_set_secrets_update_existing():
|
||||
"""Test that set_secrets updates an existing secret and retains previously set keys."""
|
||||
os.environ["VAR1"] = "value1"
|
||||
os.environ["VAR2"] = "value2"
|
||||
result_first = set_secrets({"a": "$VAR1", "b": "b_val"})
|
||||
assert result_first["a"] == "value1"
|
||||
# Change VAR1 in environment and update secret "a", and add secret "c"
|
||||
os.environ["VAR1"] = "new_value1"
|
||||
result_second = set_secrets({"a": "$VAR1", "c": "$VAR2"})
|
||||
assert result_second["a"] == "new_value1"
|
||||
assert result_second["b"] == "b_val"
|
||||
assert result_second["c"] == "value2"
|
||||
|
||||
def test_tools_inbox_state():
|
||||
"""Test that get_tools_inbox returns the same queue instance
|
||||
and that the queue state persists across multiple calls.
|
||||
"""
|
||||
from asyncio import Queue
|
||||
inbox1 = get_tools_inbox()
|
||||
inbox1.put_nowait("message")
|
||||
inbox2 = get_tools_inbox()
|
||||
# inbox2 should contain the "message" from inbox1
|
||||
msg = inbox2.get_nowait()
|
||||
assert msg == "message"
|
||||
|
||||
def test_stop_event_state():
|
||||
"""Test that stop_event can be set and cleared, and its state persists."""
|
||||
event = get_stop_event()
|
||||
# Initially the event should not be set
|
||||
assert not event.is_set()
|
||||
event.set()
|
||||
assert event.is_set()
|
||||
event.clear()
|
||||
assert not event.is_set()
|
||||
|
||||
def test_set_current_run_returns_global_dict():
|
||||
"""Test that set_current_run returns the same global current_run dictionary
|
||||
as returned by get_current_run.
|
||||
"""
|
||||
dummy_spec = DummyLLMSpec()
|
||||
updated = set_current_run(dummy_spec)
|
||||
current = get_current_run()
|
||||
assert updated is current
|
||||
def test_get_secrets_initial():
|
||||
"""Test that get_secrets returns an empty dictionary initially."""
|
||||
assert get_secrets() == {}
|
||||
|
||||
def test_set_secrets_empty():
|
||||
"""Test that setting an empty secrets dictionary does not modify existing secrets."""
|
||||
# first set initial secrets
|
||||
initial = {"key": "value"}
|
||||
set_secrets(initial)
|
||||
# update with an empty dict – the existing keys remain
|
||||
result = set_secrets({})
|
||||
assert result == initial
|
||||
|
||||
def test_update_current_run_twice():
|
||||
"""Test updating current run twice with different LLMSpec values."""
|
||||
dummy1 = DummyLLMSpec()
|
||||
dummy2 = DummyLLMSpec()
|
||||
set_current_run(dummy1)
|
||||
first = get_current_run().copy()
|
||||
set_current_run(dummy2)
|
||||
second = get_current_run().copy()
|
||||
# first update should hold dummy1, second should hold dummy2
|
||||
assert first["spec"] is dummy1
|
||||
assert second["spec"] is dummy2
|
||||
# Ensure that id has changed (using hash(id(dummy_spec)))
|
||||
assert first["id"] != second["id"]
|
||||
|
||||
def test_expand_secrets_trailing_whitespace():
|
||||
"""Test expand_secrets when the secret value has trailing whitespace after the dollar sign.
|
||||
The trailing whitespace remains after stripping only the dollar sign, so the looked-up environment variable key will not match.
|
||||
"""
|
||||
os.environ["TRIM_TEST"] = "trimmed"
|
||||
secrets = {"key": "$TRIM_TEST "}
|
||||
expand_secrets(secrets)
|
||||
# Since "TRIM_TEST " (with trailing space) is not set in the environment, the secret should be None.
|
||||
assert secrets["key"] is None
|
||||
def test_expand_secrets_empty_dict():
|
||||
"""Test expand_secrets with an empty dictionary does nothing."""
|
||||
secrets = {}
|
||||
expand_secrets(secrets)
|
||||
assert secrets == {}
|
||||
|
||||
def test_expand_secrets_with_non_string_value():
|
||||
"""Test that expand_secrets raises an AttributeError when a secret value is not a string."""
|
||||
secrets = {"key": 123}
|
||||
with pytest.raises(AttributeError):
|
||||
expand_secrets(secrets)
|
||||
|
||||
def test_expand_secrets_multiple_dollar_signs():
|
||||
"""Test expand_secrets with a value that contains multiple leading dollar signs.
|
||||
The extra dollar signs are removed by the strip method.
|
||||
"""
|
||||
os.environ["MULTI_DOLLAR_VAR"] = "multi_value"
|
||||
secrets = {"key": "$$MULTI_DOLLAR_VAR"}
|
||||
expand_secrets(secrets)
|
||||
# After stripping, "$$MULTI_DOLLAR_VAR".strip("$") returns "MULTI_DOLLAR_VAR"
|
||||
assert secrets["key"] == "multi_value"
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from agentic_security.refusal_classifier import RefusalClassifier
|
||||
from agentic_security.refusal_classifier.model import RefusalClassifier
|
||||
|
||||
classifier = RefusalClassifier()
|
||||
classifier.load_model()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from agentic_security.probe_actor.refusal import DefaultRefusalClassifier
|
||||
from agentic_security.probe_actor.refusal import refusal_heuristic, RefusalClassifierManager, RefusalClassifierPlugin
|
||||
|
||||
|
||||
class TestCheckRefusal:
|
||||
@@ -11,3 +12,40 @@ class TestCheckRefusal:
|
||||
def test_empty_response(self):
|
||||
response = ""
|
||||
assert not DefaultRefusalClassifier().is_refusal(response)
|
||||
def test_non_refusal_phrase(self):
|
||||
"""Test that a response without refusal phrases returns False."""
|
||||
response = "Here is some helpful information."
|
||||
assert not DefaultRefusalClassifier().is_refusal(response)
|
||||
|
||||
def test_refusal_heuristic(self):
|
||||
"""Test that the refusal_heuristic function correctly identifies refusal phrases."""
|
||||
response = "Sorry, I cannot assist with your request."
|
||||
# refusal_heuristic will cast the argument to a string and use the plugin manager to check for refusal markers
|
||||
assert refusal_heuristic(response)
|
||||
|
||||
def test_plugin_manager_with_dummy_plugin(self):
|
||||
"""Test that RefusalClassifierManager returns True if any registered plugin returns True."""
|
||||
class AlwaysFalsePlugin(RefusalClassifierPlugin):
|
||||
def is_refusal(self, response: str) -> bool:
|
||||
return False
|
||||
|
||||
class AlwaysTruePlugin(RefusalClassifierPlugin):
|
||||
def is_refusal(self, response: str) -> bool:
|
||||
return True
|
||||
|
||||
manager = RefusalClassifierManager()
|
||||
manager.register_plugin("false", AlwaysFalsePlugin())
|
||||
manager.register_plugin("true", AlwaysTruePlugin())
|
||||
response = "Any response text"
|
||||
assert manager.is_refusal(response)
|
||||
|
||||
def test_default_classifier_with_empty_phrases(self):
|
||||
"""Test that DefaultRefusalClassifier returns False when provided with an empty refusal phrase list."""
|
||||
classifier = DefaultRefusalClassifier(refusal_phrases=[])
|
||||
response = "I do not have any info."
|
||||
assert not classifier.is_refusal(response)
|
||||
|
||||
def test_case_sensitivity(self):
|
||||
"""Test that string matching is case-sensitive."""
|
||||
response = "i'm sorry, but can you help me?" # lower-case "i'm sorry" does not match "I'm sorry" and no extra refusal phrases are present
|
||||
assert not DefaultRefusalClassifier().is_refusal(response)
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
from: python-pytest-poetry
|
||||
# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/
|
||||
Generated
+36
-2
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
@@ -786,6 +786,20 @@ files = [
|
||||
{file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "execnet"
|
||||
version = "2.1.1"
|
||||
description = "execnet: rapid multi-Python deployment"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"},
|
||||
{file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
testing = ["hatch", "pre-commit", "pytest", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "executing"
|
||||
version = "2.2.0"
|
||||
@@ -3244,6 +3258,26 @@ pytest = ">=6.2.5"
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "pytest-asyncio", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-xdist"
|
||||
version = "3.6.1"
|
||||
description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7"},
|
||||
{file = "pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
execnet = ">=2.1"
|
||||
pytest = ">=7.0.0"
|
||||
|
||||
[package.extras]
|
||||
psutil = ["psutil (>=3.0)"]
|
||||
setproctitle = ["setproctitle"]
|
||||
testing = ["filelock"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
@@ -4447,4 +4481,4 @@ propcache = ">=0.2.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "28a2b74bfafa9f93d14d2f8d1fcaffa340db212acce6469d6714d342203ad77f"
|
||||
content-hash = "35e03dba41d30cf6129a4a4f3107eca560f779205b21d3ffb2871eeffc5d5a64"
|
||||
|
||||
+2
-1
@@ -54,6 +54,7 @@ pyfiglet = "^1.0.2"
|
||||
termcolor = "^2.4.0"
|
||||
|
||||
# garak = { version = "*", optional = true }
|
||||
pytest-xdist = "3.6.1"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
@@ -86,7 +87,7 @@ build-backend = "poetry.core.masonry.api"
|
||||
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--durations=5 -m 'not slow'"
|
||||
addopts = "--durations=5 -m 'not slow' -n auto"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
markers = "slow: marks tests as slow"
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from agentic_security.probe_actor.refusal import DefaultRefusalClassifier
|
||||
|
||||
|
||||
class TestCheckRefusal:
|
||||
# The function correctly identifies a refusal phrase in the response.
|
||||
def test_identify_refusal_phrase(self):
|
||||
response = "I'm sorry, but I cannot provide that information."
|
||||
assert DefaultRefusalClassifier().is_refusal(response)
|
||||
|
||||
# The response is an empty string.
|
||||
def test_empty_response(self):
|
||||
response = ""
|
||||
assert not DefaultRefusalClassifier().is_refusal(response)
|
||||
+1
-1
@@ -3,7 +3,7 @@ from unittest.mock import patch
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from .model import RefusalClassifier
|
||||
from agentic_security.refusal_classifier.model import RefusalClassifier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ..app import app
|
||||
from agentic_security.app import app
|
||||
|
||||
|
||||
def test_health_check():
|
||||
@@ -5,10 +5,10 @@ import httpx
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ..app import app
|
||||
from ..primitives import Probe
|
||||
from ..probe_actor.refusal import REFUSAL_MARKS
|
||||
from ..probe_data import REGISTRY
|
||||
from agentic_security.app import app
|
||||
from agentic_security.primitives import Probe
|
||||
from agentic_security.probe_actor.refusal import REFUSAL_MARKS
|
||||
from agentic_security.probe_data import REGISTRY
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from .report import router
|
||||
from agentic_security.routes.report import router
|
||||
|
||||
client = TestClient(router)
|
||||
|
||||
@@ -4,8 +4,8 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ..primitives import Settings
|
||||
from .static import get_static_file, router
|
||||
from agentic_security.primitives import Settings
|
||||
from agentic_security.routes.static import get_static_file, router
|
||||
|
||||
client = TestClient(router)
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
import io
|
||||
import string
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from agentic_security.report_chart import plot_security_report, generate_identifiers
|
||||
|
||||
class TestReportChart:
|
||||
"""Test suite for agentic_security.report_chart module."""
|
||||
|
||||
def test_generate_identifiers_short(self):
|
||||
"""Test generate_identifiers with a small dataset."""
|
||||
df = pd.DataFrame([{'dummy': i} for i in range(5)])
|
||||
identifiers = generate_identifiers(df)
|
||||
expected = ['A1', 'A2', 'A3', 'A4', 'A5']
|
||||
assert identifiers == expected
|
||||
|
||||
def test_generate_identifiers_edge(self):
|
||||
"""Test generate_identifiers with more than 26 items to cover cycling over the alphabet."""
|
||||
n = 30
|
||||
df = pd.DataFrame([{'dummy': i} for i in range(n)])
|
||||
identifiers = generate_identifiers(df)
|
||||
# For i=25, identifier should be A26, and for i=26, identifier should be B1
|
||||
assert identifiers[25] == 'A26'
|
||||
assert identifiers[26] == 'B1'
|
||||
assert len(identifiers) == n
|
||||
|
||||
def test_generate_identifiers_empty(self):
|
||||
"""Test generate_identifiers with an empty dataframe."""
|
||||
df = pd.DataFrame([])
|
||||
identifiers = generate_identifiers(df)
|
||||
assert identifiers == []
|
||||
|
||||
def test_plot_security_report_png_output(self):
|
||||
"""Test plot_security_report returns valid PNG output."""
|
||||
# Create a sample table with required columns
|
||||
table = [
|
||||
{"failureRate": 10, "tokens": 100, "module": "Module1"},
|
||||
{"failureRate": 30, "tokens": 200, "module": "Module2"},
|
||||
{"failureRate": 20, "tokens": 150, "module": "Module3"},
|
||||
]
|
||||
buf = plot_security_report(table)
|
||||
# Check that buf is a BytesIO object and starts with PNG header bytes
|
||||
assert isinstance(buf, io.BytesIO)
|
||||
buf.seek(0)
|
||||
header = buf.read(8)
|
||||
assert header.startswith(b'\x89PNG')
|
||||
|
||||
def test_plot_security_report_ordering(self, monkeypatch):
|
||||
"""Test that the table embedded in the plot contains correctly sorted order by descending failure rate."""
|
||||
table = [
|
||||
{"failureRate": 15, "tokens": 110, "module": "ModuleA"},
|
||||
{"failureRate": 25, "tokens": 210, "module": "ModuleB"},
|
||||
{"failureRate": 5, "tokens": 90, "module": "ModuleC"},
|
||||
]
|
||||
result_holder = {}
|
||||
from matplotlib.axes import Axes
|
||||
original_table = Axes.table
|
||||
def fake_table(self, *args, **kwargs):
|
||||
result_holder['cellText'] = kwargs.get('cellText')
|
||||
return original_table(self, *args, **kwargs)
|
||||
monkeypatch.setattr(Axes, "table", fake_table)
|
||||
plot_security_report(table)
|
||||
cell_text = result_holder.get('cellText')
|
||||
assert cell_text is not None
|
||||
# Verify header row in the table
|
||||
assert cell_text[0] == ["Threat"]
|
||||
# Since the data are sorted (highest failure rate first), ModuleB (25.0%) should appear in one of the rows.
|
||||
found = any("ModuleB (25.0%)" in row[0] for row in cell_text[1:])
|
||||
assert found
|
||||
|
||||
def test_plot_security_report_one_entry(self):
|
||||
"""Test plot_security_report with a single entry."""
|
||||
table = [{"failureRate": 50, "tokens": 300, "module": "OnlyModule"}]
|
||||
buf = plot_security_report(table)
|
||||
assert isinstance(buf, io.BytesIO)
|
||||
buf.seek(0)
|
||||
content = buf.read()
|
||||
assert content.startswith(b'\x89PNG')
|
||||
def test_generate_identifiers_many(self):
|
||||
"""Test generate_identifiers with 52 items to verify identifier sequence."""
|
||||
n = 52
|
||||
df = pd.DataFrame([{'dummy': i} for i in range(n)])
|
||||
identifiers = generate_identifiers(df)
|
||||
assert identifiers[0] == "A1"
|
||||
assert identifiers[25] == "A26"
|
||||
assert identifiers[26] == "B1"
|
||||
assert identifiers[51] == "B26"
|
||||
|
||||
def test_plot_security_report_missing_failureRate(self):
|
||||
"""Test plot_security_report raises KeyError when 'failureRate' column is missing."""
|
||||
table = [{"tokens": 100, "module": "Mod1"}] # Missing 'failureRate'
|
||||
with pytest.raises(KeyError):
|
||||
plot_security_report(table)
|
||||
|
||||
def test_plot_security_report_missing_tokens(self):
|
||||
"""Test plot_security_report raises KeyError when 'tokens' column is missing."""
|
||||
table = [{"failureRate": 10, "module": "Mod1"}] # Missing 'tokens'
|
||||
with pytest.raises(KeyError):
|
||||
plot_security_report(table)
|
||||
|
||||
def test_plot_security_report_empty_table(self):
|
||||
"""Test plot_security_report raises KeyError when the table is empty."""
|
||||
table = []
|
||||
with pytest.raises(KeyError):
|
||||
plot_security_report(table)
|
||||
def test_plot_security_report_missing_module(self):
|
||||
"""Test plot_security_report raises KeyError when 'module' column is missing."""
|
||||
table = [{"failureRate": 10, "tokens": 100}] # Missing 'module'
|
||||
with pytest.raises(KeyError):
|
||||
plot_security_report(table)
|
||||
|
||||
def test_plot_security_report_failure_rate_labels(self, monkeypatch):
|
||||
"""Test that plot_security_report calls ax.text for each failure rate bar label."""
|
||||
table = [
|
||||
{"failureRate": 10, "tokens": 100, "module": "Mod1"},
|
||||
{"failureRate": 20, "tokens": 150, "module": "Mod2"},
|
||||
{"failureRate": 30, "tokens": 200, "module": "Mod3"},
|
||||
]
|
||||
# Count the number of times ax.text is called for drawing failure rate labels.
|
||||
call_count = [0]
|
||||
from matplotlib.axes import Axes
|
||||
original_text = Axes.text
|
||||
def fake_text(self, *args, **kwargs):
|
||||
call_count[0] += 1
|
||||
return original_text(self, *args, **kwargs)
|
||||
monkeypatch.setattr(Axes, "text", fake_text)
|
||||
plot_security_report(table)
|
||||
# The loop inside plot_security_report calls ax.text once for each data point.
|
||||
assert call_count[0] == len(table)
|
||||
|
||||
def test_plot_security_report_non_numeric_failureRate(self):
|
||||
"""Test that plot_security_report raises an exception when failureRate is non-numeric."""
|
||||
table = [{"failureRate": "invalid", "tokens": 100, "module": "ModX"}]
|
||||
with pytest.raises(Exception):
|
||||
plot_security_report(table)
|
||||
@@ -0,0 +1,126 @@
|
||||
import io
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Event
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from agentic_security.routes import scan
|
||||
|
||||
# Dummy LLMSpec for success tests
|
||||
class DummyLLMSpec:
|
||||
def __init__(self, spec_string):
|
||||
self.spec_string = spec_string
|
||||
async def verify(self):
|
||||
class DummyResponse:
|
||||
status_code = 200
|
||||
text = "verification succeeded"
|
||||
elapsed = timedelta(seconds=0.5)
|
||||
return DummyResponse()
|
||||
@classmethod
|
||||
def from_string(cls, spec_string):
|
||||
return DummyLLMSpec(spec_string)
|
||||
|
||||
# Dummy scan_router generator to simulate streaming responses
|
||||
async def dummy_scan_router(request_factory, scan_parameters, tools_inbox, stop_event):
|
||||
for i in range(2):
|
||||
yield f"result {i}"
|
||||
|
||||
# Define a dummy Secrets class for testing purposes.
|
||||
class DummySecrets:
|
||||
def __init__(self):
|
||||
self.secrets = {}
|
||||
|
||||
# Create FastAPI app for testing and include the scan router.
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = FastAPI()
|
||||
app.include_router(scan.router)
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_dependencies(monkeypatch):
|
||||
# Patch LLMSpec used in the routes with our dummy implementation.
|
||||
monkeypatch.setattr(scan, "LLMSpec", DummyLLMSpec)
|
||||
# Patch fuzzer.scan_router to use our dummy scanning generator.
|
||||
monkeypatch.setattr(scan.fuzzer, "scan_router", dummy_scan_router)
|
||||
# Patch get_stop_event to return a dummy Event.
|
||||
dummy_event = Event()
|
||||
monkeypatch.setattr(scan, "get_stop_event", lambda: dummy_event)
|
||||
# Patch get_tools_inbox to return None.
|
||||
monkeypatch.setattr(scan, "get_tools_inbox", lambda: None)
|
||||
# Patch set_current_run to be a no-op.
|
||||
monkeypatch.setattr(scan, "set_current_run", lambda x: None)
|
||||
# Patch get_in_memory_secrets to return a DummySecrets instance.
|
||||
monkeypatch.setattr(scan, "get_in_memory_secrets", lambda: DummySecrets())
|
||||
# Ensure Scan.with_secrets is a no-op if not already implemented.
|
||||
if not hasattr(scan.Scan, "with_secrets"):
|
||||
monkeypatch.setattr(scan.Scan, "with_secrets", lambda self, secrets: None)
|
||||
|
||||
def test_verify_success(client):
|
||||
"""Test /verify endpoint for a successful verification."""
|
||||
data = {"spec": "dummy"}
|
||||
response = client.post("/verify", json=data)
|
||||
res_json = response.json()
|
||||
assert response.status_code == 200
|
||||
assert res_json["status_code"] == 200
|
||||
assert res_json["body"] == "verification succeeded"
|
||||
assert "elapsed" in res_json
|
||||
assert "timestamp" in res_json
|
||||
|
||||
def test_verify_failure(client, monkeypatch):
|
||||
"""Test /verify endpoint when verification fails."""
|
||||
class DummyLLMSpecFailure:
|
||||
def __init__(self, spec_string):
|
||||
self.spec_string = spec_string
|
||||
async def verify(self):
|
||||
raise Exception("verification error")
|
||||
@classmethod
|
||||
def from_string(cls, spec_string):
|
||||
return DummyLLMSpecFailure(spec_string)
|
||||
monkeypatch.setattr(scan, "LLMSpec", DummyLLMSpecFailure)
|
||||
data = {"spec": "bad"}
|
||||
response = client.post("/verify", json=data)
|
||||
assert response.status_code == 400
|
||||
assert "verification error" in response.text
|
||||
|
||||
def test_scan(client):
|
||||
"""Test /scan endpoint to ensure streaming response works."""
|
||||
data = {"llmSpec": "dummy", "optimize": False, "maxBudget": 10, "enableMultiStepAttack": False}
|
||||
response = client.post("/scan", json=data)
|
||||
assert response.status_code == 200
|
||||
content = list(response.iter_lines())
|
||||
expected = ["result 0", "result 1"]
|
||||
assert content == expected
|
||||
|
||||
def test_stop_scan(client):
|
||||
"""Test /stop endpoint to ensure scan stopping functionality."""
|
||||
dummy_event = scan.get_stop_event()
|
||||
dummy_event.clear()
|
||||
response = client.post("/stop")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "Scan stopped"}
|
||||
assert dummy_event.is_set()
|
||||
|
||||
def test_scan_csv(client):
|
||||
"""Test /scan-csv endpoint with CSV file and llmSpec upload."""
|
||||
csv_content = b"col1,col2\nvalue1,value2"
|
||||
llm_spec_content = b"dummy"
|
||||
files = {
|
||||
"file": ("dummy.csv", csv_content, "text/csv"),
|
||||
"llmSpec": ("spec.txt", llm_spec_content, "text/plain"),
|
||||
}
|
||||
response = client.post(
|
||||
"/scan-csv",
|
||||
files=files,
|
||||
data={"optimize": "false", "maxBudget": "10", "enableMultiStepAttack": "false"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
content = list(response.iter_lines())
|
||||
expected = ["result 0", "result 1"]
|
||||
assert content == expected
|
||||
Reference in New Issue
Block a user