feat(restruct tests):

This commit is contained in:
Alexander Myasoedov
2025-12-26 22:58:21 +02:00
parent 433c999600
commit ce7636fe9e
27 changed files with 0 additions and 1 deletions
View File
View File
+209
View File
@@ -0,0 +1,209 @@
"""Tests for CircuitBreaker."""
import time
from agentic_security.executor.circuit_breaker import CircuitBreaker
class TestCircuitBreaker:
    """Exercise the CircuitBreaker state machine: closed -> open -> half_open -> closed."""

    @staticmethod
    def _feed(breaker, *, successes=0, failures=0):
        """Record ``successes`` successful calls, then ``failures`` failed calls."""
        for _ in range(successes):
            breaker.record_success()
        for _ in range(failures):
            breaker.record_failure()

    def test_initialization(self):
        """Constructor arguments are stored and the breaker starts closed and empty."""
        cb = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)

        assert cb.failure_threshold == 0.5
        assert cb.recovery_timeout == 30
        assert cb.state == "closed"
        assert cb.failures == 0
        assert cb.successes == 0

    def test_record_success(self):
        """A single success bumps only the success counter."""
        cb = CircuitBreaker()

        cb.record_success()

        assert cb.successes == 1
        assert cb.failures == 0
        assert cb.state == "closed"

    def test_record_failure(self):
        """A single failure bumps only the failure counter and stamps the failure time."""
        cb = CircuitBreaker()

        cb.record_failure()

        assert cb.failures == 1
        assert cb.successes == 0
        assert cb.last_failure_time is not None

    def test_circuit_opens_on_failure_threshold(self):
        """A 60% failure rate over 10 requests trips a breaker with a 50% threshold."""
        cb = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)

        self._feed(cb, successes=4, failures=6)

        assert cb.state == "open"
        assert cb.is_open() is True

    def test_circuit_stays_closed_below_threshold(self):
        """A 40% failure rate over 10 requests leaves a 50%-threshold breaker closed."""
        cb = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)

        self._feed(cb, successes=6, failures=4)

        assert cb.state == "closed"
        assert cb.is_open() is False

    def test_minimum_sample_size_required(self):
        """Five failures alone are too few samples for the breaker to open."""
        cb = CircuitBreaker(failure_threshold=0.5)

        self._feed(cb, failures=5)

        assert cb.state == "closed"
        assert cb.is_open() is False

    def test_circuit_recovery_after_timeout(self):
        """Once the recovery timeout elapses, querying the breaker moves it to half-open."""
        cb = CircuitBreaker(failure_threshold=0.5, recovery_timeout=1)
        self._feed(cb, successes=4, failures=6)
        assert cb.state == "open"

        # Sleep slightly longer than recovery_timeout so the window has passed.
        time.sleep(1.1)

        # The transition is observed through is_open(), not by a background timer.
        assert cb.is_open() is False
        assert cb.state == "half_open"

    def test_half_open_to_closed_on_successes(self):
        """Three successes while half-open close the breaker again."""
        cb = CircuitBreaker(failure_threshold=0.5, recovery_timeout=1)
        self._feed(cb, successes=4, failures=6)

        time.sleep(1.1)
        cb.is_open()  # Querying is what triggers the move to half-open.
        assert cb.state == "half_open"

        self._feed(cb, successes=3)

        assert cb.state == "closed"

    def test_get_state(self):
        """get_state() mirrors the state before and after the breaker trips."""
        cb = CircuitBreaker()
        assert cb.get_state() == "closed"

        self._feed(cb, failures=10)

        assert cb.get_state() == "open"

    def test_get_failure_rate(self):
        """get_failure_rate() is 0.0 with no traffic and failures/total afterwards."""
        cb = CircuitBreaker()
        assert cb.get_failure_rate() == 0.0

        # 7 successes + 3 failures -> 30% failure rate.
        self._feed(cb, successes=7, failures=3)

        assert cb.get_failure_rate() == 0.3

    def test_reset(self):
        """reset() restores counters, state and the last-failure timestamp."""
        cb = CircuitBreaker()
        cb.record_success()
        cb.record_failure()
        self._feed(cb, failures=10)

        cb.reset()

        assert cb.state == "closed"
        assert cb.failures == 0
        assert cb.successes == 0
        assert cb.last_failure_time is None

    def test_exact_failure_threshold(self):
        """A failure rate exactly equal to the threshold opens the breaker (>=)."""
        cb = CircuitBreaker(failure_threshold=0.5)

        self._feed(cb, successes=5, failures=5)

        assert cb.state == "open"

    def test_high_failure_threshold(self):
        """An 80% failure rate does not trip a breaker configured for 90%."""
        cb = CircuitBreaker(failure_threshold=0.9)

        self._feed(cb, successes=2, failures=8)

        assert cb.state == "closed"

    def test_zero_recovery_timeout(self):
        """With recovery_timeout=0 an open breaker is immediately eligible for half-open."""
        cb = CircuitBreaker(failure_threshold=0.5, recovery_timeout=0)
        self._feed(cb, failures=10)
        assert cb.state == "open"

        # A tiny pause guarantees a strictly positive elapsed time.
        time.sleep(0.01)

        assert cb.is_open() is False
        assert cb.state == "half_open"
+279
View File
@@ -0,0 +1,279 @@
"""Tests for ConcurrentExecutor."""
import pytest
import asyncio
from unittest.mock import Mock, patch
from agentic_security.executor.concurrent import ConcurrentExecutor, ExecutorMetrics
from agentic_security.probe_actor.state import FuzzerState
class TestExecutorMetrics:
    """Test ExecutorMetrics counters and the statistics derived from them."""

    def test_initialization(self):
        """A fresh metrics object has zeroed counters and no recorded latencies."""
        metrics = ExecutorMetrics()

        assert metrics.successful_requests == 0
        assert metrics.failed_requests == 0
        assert metrics.total_latency == 0.0
        assert len(metrics.latencies) == 0

    def test_record_success(self):
        """Each success increments the counter and accumulates its latency."""
        metrics = ExecutorMetrics()

        metrics.record_success(0.5)
        metrics.record_success(0.3)

        assert metrics.successful_requests == 2
        # Summed floats are compared approximately: 0.5 + 0.3 only equals 0.8
        # exactly by a rounding coincidence, so an exact check would be brittle.
        assert metrics.total_latency == pytest.approx(0.8)
        assert len(metrics.latencies) == 2

    def test_record_failure(self):
        """Failures increment only the failure counter."""
        metrics = ExecutorMetrics()

        metrics.record_failure()
        metrics.record_failure()

        assert metrics.failed_requests == 2
        assert metrics.successful_requests == 0

    def test_get_stats_no_requests(self):
        """get_stats() reports zeros (and does not divide by zero) with no traffic."""
        metrics = ExecutorMetrics()

        stats = metrics.get_stats()

        assert stats["total_requests"] == 0
        assert stats["success_rate"] == 0.0
        assert stats["avg_latency_ms"] == 0.0
        assert stats["p95_latency_ms"] == 0.0

    def test_get_stats_with_requests(self):
        """get_stats() aggregates three successes and one failure."""
        metrics = ExecutorMetrics()
        metrics.record_success(0.1)  # 100ms
        metrics.record_success(0.2)  # 200ms
        metrics.record_success(0.3)  # 300ms
        metrics.record_failure()

        stats = metrics.get_stats()

        assert stats["total_requests"] == 4
        assert stats["successful_requests"] == 3
        assert stats["failed_requests"] == 1
        assert stats["success_rate"] == 0.75
        assert stats["avg_latency_ms"] == pytest.approx(200.0, rel=0.01)

    def test_get_stats_p95_latency(self):
        """The p95 latency of 0..99 ms samples lands between 90 and 100 ms."""
        metrics = ExecutorMetrics()
        # 100 samples with latencies 0ms, 1ms, ..., 99ms.
        for i in range(100):
            metrics.record_success(i * 0.001)

        stats = metrics.get_stats()

        assert stats["p95_latency_ms"] >= 90.0
        assert stats["p95_latency_ms"] <= 100.0
class TestConcurrentExecutor:
    """Exercise ConcurrentExecutor with ``process_prompt`` replaced by stubs."""

    # Dotted path of the coroutine every test swaps out for a stub.
    _PROCESS_PROMPT = "agentic_security.probe_actor.fuzzer.process_prompt"

    def test_initialization(self):
        """Constructor arguments are forwarded to the semaphore, limiter and breaker."""
        executor = ConcurrentExecutor(
            max_concurrent=20,
            rate_limit=10,
            burst=5,
            failure_threshold=0.5,
            recovery_timeout=30,
        )

        # _value is asyncio.Semaphore's private counter of free slots.
        assert executor.semaphore._value == 20
        assert executor.rate_limiter.rate == 10
        assert executor.rate_limiter.burst == 5
        assert executor.circuit_breaker.failure_threshold == 0.5
        assert executor.circuit_breaker.recovery_timeout == 30

    @pytest.mark.asyncio
    async def test_execute_batch_success(self):
        """Three accepted prompts yield summed tokens and zero failures."""
        executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
        state = FuzzerState()
        factory = Mock()

        async def mock_process_prompt(rf, prompt, tokens, module, state):
            # 10 tokens, not refused.
            return (10, False)

        with patch(self._PROCESS_PROMPT, side_effect=mock_process_prompt):
            tokens, failures = await executor.execute_batch(
                factory, ["prompt1", "prompt2", "prompt3"], "test_module", state
            )

        assert tokens == 30  # 3 prompts * 10 tokens
        assert failures == 0

    @pytest.mark.asyncio
    async def test_execute_batch_with_failures(self):
        """Refused prompts still count tokens but are reported as failures."""
        executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
        state = FuzzerState()
        factory = Mock()
        calls = 0

        async def mock_process_prompt(rf, prompt, tokens, module, state):
            # Every second invocation is reported as refused.
            nonlocal calls
            calls += 1
            refused = calls % 2 == 0
            return (10, refused)

        with patch(self._PROCESS_PROMPT, side_effect=mock_process_prompt):
            tokens, failures = await executor.execute_batch(
                factory, ["p1", "p2", "p3", "p4"], "test_module", state
            )

        assert tokens == 40  # 4 prompts * 10 tokens
        assert failures == 2  # 2 refused

    @pytest.mark.asyncio
    async def test_execute_batch_respects_concurrency_limit(self):
        """No more than ``max_concurrent`` stubs are ever in flight at once."""
        executor = ConcurrentExecutor(max_concurrent=2, rate_limit=100, burst=10)
        state = FuzzerState()
        factory = Mock()
        in_flight = 0
        peak = 0

        async def mock_process_prompt(rf, prompt, tokens, module, state):
            nonlocal in_flight, peak
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0.01)  # Hold the slot so overlap can be observed.
            in_flight -= 1
            return (10, False)

        with patch(self._PROCESS_PROMPT, side_effect=mock_process_prompt):
            await executor.execute_batch(
                factory, ["p1", "p2", "p3", "p4", "p5"], "test_module", state
            )

        assert peak <= 2

    @pytest.mark.asyncio
    async def test_circuit_breaker_integration(self):
        """Ten consecutive exceptions are all counted and trip the breaker."""
        executor = ConcurrentExecutor(
            max_concurrent=10,
            rate_limit=100,
            burst=20,
            failure_threshold=0.5,
            recovery_timeout=1,
        )
        state = FuzzerState()
        factory = Mock()

        async def mock_process_prompt_fail(rf, prompt, tokens, module, state):
            raise Exception("Request failed")

        prompts = [f"p{i}" for i in range(1, 11)]
        with patch(self._PROCESS_PROMPT, side_effect=mock_process_prompt_fail):
            tokens, failures = await executor.execute_batch(
                factory, prompts, "test_module", state
            )

        assert failures == 10
        assert executor.circuit_breaker.state == "open"

    @pytest.mark.asyncio
    async def test_get_metrics(self):
        """get_metrics() exposes request stats plus breaker and limiter state."""
        executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
        state = FuzzerState()
        factory = Mock()

        async def mock_process_prompt(rf, prompt, tokens, module, state):
            return (10, False)

        with patch(self._PROCESS_PROMPT, side_effect=mock_process_prompt):
            await executor.execute_batch(factory, ["p1", "p2"], "test_module", state)

        metrics = executor.get_metrics()

        for key in (
            "total_requests",
            "success_rate",
            "circuit_breaker_state",
            "available_tokens",
        ):
            assert key in metrics
        assert metrics["total_requests"] == 2
        assert metrics["circuit_breaker_state"] == "closed"

    @pytest.mark.asyncio
    async def test_rate_limiting_applied(self):
        """A batch larger than the burst is slowed down by the rate limiter."""
        import time

        executor = ConcurrentExecutor(max_concurrent=10, rate_limit=5, burst=2)
        state = FuzzerState()
        factory = Mock()

        async def mock_process_prompt(rf, prompt, tokens, module, state):
            return (10, False)

        with patch(self._PROCESS_PROMPT, side_effect=mock_process_prompt):
            started = time.monotonic()
            # 5 requests at rate=5/s with burst=2: the first two are expected to
            # pass immediately, the remaining three to need about 3 / 5 = 0.6s.
            await executor.execute_batch(
                factory,
                ["p1", "p2", "p3", "p4", "p5"],
                "test_module",
                state,
            )
            elapsed = time.monotonic() - started

        # Lower bound is kept below the nominal 0.6s to tolerate timer jitter.
        assert elapsed >= 0.4
+145
View File
@@ -0,0 +1,145 @@
"""Tests for TokenBucketRateLimiter."""
import asyncio
import pytest
import time
from agentic_security.executor.rate_limiter import TokenBucketRateLimiter
class TestTokenBucketRateLimiter:
    """Timing-based checks of TokenBucketRateLimiter (wall-clock, with tolerances)."""

    @pytest.mark.asyncio
    async def test_initialization(self):
        """The bucket stores rate/burst and starts full."""
        bucket = TokenBucketRateLimiter(rate=10, burst=20)

        assert bucket.rate == 10
        assert bucket.burst == 20
        assert bucket.tokens == 20  # Starts full

    @pytest.mark.asyncio
    async def test_acquire_with_available_tokens(self):
        """acquire() returns promptly and consumes a token when the bucket is full."""
        bucket = TokenBucketRateLimiter(rate=10, burst=5)

        t0 = time.monotonic()
        await bucket.acquire()
        waited = time.monotonic() - t0

        assert waited < 0.1
        assert bucket.tokens < 5  # One token consumed

    @pytest.mark.asyncio
    async def test_acquire_waits_when_no_tokens(self):
        """acquire() blocks for roughly 1/rate seconds once the bucket is empty."""
        bucket = TokenBucketRateLimiter(rate=10, burst=1)
        await bucket.acquire()  # Drain the single initial token.

        t0 = time.monotonic()
        await bucket.acquire()
        waited = time.monotonic() - t0

        # Nominal wait is 0.1s at rate=10; 0.08 leaves room for timer jitter.
        assert waited >= 0.08

    @pytest.mark.asyncio
    async def test_rate_limiting(self):
        """Requests beyond the burst are paced at the configured rate."""
        bucket = TokenBucketRateLimiter(rate=10, burst=2)

        t0 = time.monotonic()
        for _ in range(5):
            await bucket.acquire()
        waited = time.monotonic() - t0

        # burst=2 covers the first two; the other three need 3 * (1/10) = 0.3s.
        assert waited >= 0.25  # Allow some tolerance
        assert waited < 0.5

    @pytest.mark.asyncio
    async def test_burst_capacity(self):
        """A full burst worth of requests completes almost immediately."""
        bucket = TokenBucketRateLimiter(rate=5, burst=10)

        t0 = time.monotonic()
        for _ in range(10):
            await bucket.acquire()
        waited = time.monotonic() - t0

        assert waited < 0.2

    @pytest.mark.asyncio
    async def test_token_replenishment(self):
        """Tokens refill over time at ``rate`` per second."""
        bucket = TokenBucketRateLimiter(rate=10, burst=5)
        for _ in range(5):
            await bucket.acquire()
        assert bucket.tokens < 1

        await asyncio.sleep(0.3)  # Should add 3 tokens at rate=10

        refilled = bucket.get_available_tokens()
        assert refilled >= 2.5
        assert refilled <= 3.5

    @pytest.mark.asyncio
    async def test_get_available_tokens(self):
        """get_available_tokens() reports a full bucket, then fewer after acquire()."""
        bucket = TokenBucketRateLimiter(rate=10, burst=5)

        assert bucket.get_available_tokens() == 5

        await bucket.acquire()
        assert bucket.get_available_tokens() < 5

    @pytest.mark.asyncio
    async def test_concurrent_requests(self):
        """Concurrent acquirers share one bucket and are paced together."""
        bucket = TokenBucketRateLimiter(rate=10, burst=3)

        async def make_request(limiter):
            await limiter.acquire()
            return time.monotonic()

        t0 = time.monotonic()
        await asyncio.gather(*[make_request(bucket) for _ in range(5)])
        total_elapsed = time.monotonic() - t0

        # burst=3 serves three at once; the other two need about 2 * (1/10) = 0.2s.
        assert total_elapsed >= 0.15
        assert total_elapsed < 0.4

    @pytest.mark.asyncio
    async def test_max_burst_capacity(self):
        """Idle time never pushes the token count above ``burst``."""
        bucket = TokenBucketRateLimiter(rate=100, burst=5)

        await asyncio.sleep(0.2)  # Would add 20 tokens, but capped at 5

        level = bucket.get_available_tokens()
        assert level <= 5
        assert level >= 4.5  # Close to full
View File
+285
View File
@@ -0,0 +1,285 @@
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import httpx
import pytest
from agentic_security.primitives import Scan
from agentic_security.probe_actor.fuzzer import (
FuzzerState,
generate_prompts,
perform_many_shot_scan,
perform_single_shot_scan,
process_prompt,
scan_router,
)
@pytest.mark.asyncio
async def test_generate_prompts_with_list():
    """A plain list of prompts is yielded back unchanged and in order."""
    expected = ["prompt1", "prompt2", "prompt3"]

    collected = []
    async for item in generate_prompts(expected):
        collected.append(item)

    assert collected == expected
@pytest.mark.asyncio
async def test_generate_prompts_with_async_generator():
    """An async generator source is consumed and re-yielded item by item."""

    async def source():
        for index in (0, 1, 2):
            yield f"prompt{index}"

    collected = []
    async for item in generate_prompts(source()):
        collected.append(item)

    assert collected == ["prompt0", "prompt1", "prompt2"]
async def assert_scan(generator, messages):
    """Drain an async generator and assert every expected message was produced.

    Args:
        generator: Async iterable of results; each result must support the
            ``in`` operator with a message (e.g. strings).
        messages: Substrings that must each occur in at least one result.

    Returns:
        The list of all results yielded by ``generator``, in order.

    Raises:
        AssertionError: If some message is not contained in any result.
    """
    results = [r async for r in generator]
    for m in messages:
        # any() short-circuits on the first matching result, like the manual
        # found-flag loop it replaces.
        assert any(m in r for r in results), (
            f"Message '{m}' not found in results. Results: {results}"
        )
    return results
@pytest.mark.asyncio
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_perform_single_shot_scan_success(prepare_prompts_mock):
    """A single-shot scan over one eager module reports loading and completion."""
    # One non-lazy prompt module with two prompts.
    module = MagicMock(
        dataset_name="test_module",
        prompts=["test_prompt1", "test_prompt2"],
        lazy=False,
    )
    prepare_prompts_mock.return_value = [module]

    # request_factory.fn(...) resolves to a canned HTTP 200 response.
    request_factory = AsyncMock()
    request_factory.fn.return_value = AsyncMock(
        status_code=200, text="response text", json=lambda: {}
    )

    scan = perform_single_shot_scan(
        request_factory=request_factory,
        max_budget=100,
        datasets=[{"dataset_name": "test", "selected": True}],
        optimize=False,
    )

    await assert_scan(scan, ["Loading", "Scan completed."])
@pytest.mark.asyncio
@patch("agentic_security.probe_data.msj_data.prepare_prompts")
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_perform_many_shot_scan_probe_injection(
    prepare_prompts_mock, msj_prepare_prompts_mock
):
    """A many-shot scan with probe_frequency=1.0 runs to completion."""
    # Decorators apply bottom-up, so the first argument is the data.* patch.
    main_modules = [
        MagicMock(dataset_name="main_module", prompts=["main_prompt1"], lazy=False)
    ]
    probe_modules = [
        MagicMock(dataset_name="probe_module", prompts=["probe_prompt1"], lazy=False)
    ]
    # First call returns the main modules, second call the probe modules.
    prepare_prompts_mock.side_effect = [main_modules, probe_modules]
    msj_prepare_prompts_mock.return_value = [
        MagicMock(
            dataset_name="msj_probe_module", prompts=["msj_probe_prompt"], lazy=False
        )
    ]

    # request_factory.fn(...) hands out one canned response per call.
    request_factory = AsyncMock()
    request_factory.fn.side_effect = [
        AsyncMock(status_code=200, text="main response", json=lambda: {}),
        AsyncMock(status_code=200, text="probe response", json=lambda: {}),
    ]

    scan = perform_many_shot_scan(
        request_factory=request_factory,
        max_budget=100,
        datasets=[{"dataset_name": "main", "selected": True}],
        probe_datasets=[{"dataset_name": "probe", "selected": True}],
        probe_frequency=1.0,  # Always inject probes
        optimize=False,
    )

    await assert_scan(scan, ["Loading", "Scan completed."])
@pytest.mark.asyncio
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_scan_router_single_shot(prepare_prompts_mock):
    """scan_router with multi-step attacks disabled completes on an empty dataset list."""
    prepare_prompts_mock.return_value = []
    parameters = Scan(
        maxBudget=100,
        llmSpec="test",
        datasets=[],
        probe_datasets=[],
        enableMultiStepAttack=False,
        optimize=False,
    )

    scan = scan_router(
        request_factory=AsyncMock(),
        scan_parameters=parameters,
    )

    await assert_scan(scan, ["Loading", "Scan completed."])
@pytest.mark.asyncio
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_scan_router_many_shot(prepare_prompts_mock):
    """scan_router with multi-step attacks enabled completes on an empty dataset list."""
    prepare_prompts_mock.return_value = []
    # NOTE(review): test_scan_router_single_shot passes `probe_datasets=` while
    # this test passes `probeDatasets=` -- confirm which name Scan declares; the
    # other one is presumably ignored or aliased.
    parameters = Scan(
        maxBudget=100,
        datasets=[],
        llmSpec="test",
        probeDatasets=[],
        enableMultiStepAttack=True,
        optimize=False,
    )

    scan = scan_router(
        request_factory=AsyncMock(),
        scan_parameters=parameters,
    )

    assert scan is not None
    await assert_scan(scan, ["Loading", "Scan completed."])
@pytest.mark.asyncio
async def test_perform_single_shot_scan_stop_event():
    """A single-shot scan whose stop event is already set still ends cleanly."""
    # Setting the event up front simulates the user stopping the scan.
    already_stopped = asyncio.Event()
    already_stopped.set()

    async def request_mock(*args, **kwargs):
        return AsyncMock(status_code=200, text="response text", json=lambda: {})

    factory = MagicMock(fn=request_mock)
    scan = perform_single_shot_scan(
        request_factory=factory,
        max_budget=100,
        datasets=[],
        stop_event=already_stopped,
    )

    await assert_scan(scan, ["Loading", "Scan completed."])
@pytest.mark.asyncio
async def test_perform_many_shot_scan_stop_event():
    """A many-shot scan whose stop event is already set still ends cleanly."""
    # Setting the event up front simulates the user stopping the scan.
    already_stopped = asyncio.Event()
    already_stopped.set()

    async def request_mock(*args, **kwargs):
        return AsyncMock(status_code=200, text="response text", json=lambda: {})

    factory = MagicMock(fn=request_mock)
    scan = perform_many_shot_scan(
        request_factory=factory,
        max_budget=100,
        datasets=[],
        probe_datasets=[],
        stop_event=already_stopped,
    )

    await assert_scan(scan, ["Loading", "Scan completed."])
def mock_refusal_heuristic(response_json):
    """Return the payload's ``is_refusal`` entry, or ``False`` when it is absent."""
    flag = response_json.get("is_refusal", False)
    return flag
class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
    """Drive process_prompt with a request factory whose ``fn`` is an AsyncMock."""

    @staticmethod
    def _factory(**fn_kwargs):
        """Build a request factory whose ``fn`` is ``AsyncMock(**fn_kwargs)``."""
        factory = Mock()
        factory.fn = AsyncMock(**fn_kwargs)
        return factory

    async def test_successful_response_no_refusal(self):
        """A 200 response contributes one token per whitespace-separated word."""
        response = Mock(
            status_code=200,
            text="Valid response text",
            json=Mock(return_value={"is_refusal": False}),
            request="mock_request",
        )

        tokens, refusal = await process_prompt(
            request_factory=self._factory(return_value=response),
            prompt="test prompt",
            tokens=0,
            module_name="module_a",
            fuzzer_state=FuzzerState(),
        )

        self.assertEqual(tokens, 3)  # Tokens from "Valid response text"
        # NOTE(review): the test name says "no refusal" yet a truthy refusal flag
        # is expected here -- confirm which of the two is intended.
        self.assertTrue(refusal)

    async def test_successful_response_with_refusal(self):
        """A 200 response flagged as refusal still counts its tokens."""
        response = Mock(
            status_code=200,
            text="Response indicating refusal",
            json=Mock(return_value={"is_refusal": True}),
            request="mock_request",
        )
        state = FuzzerState()

        tokens, _refusal = await process_prompt(
            request_factory=self._factory(return_value=response),
            prompt="test prompt",
            tokens=0,
            module_name="module_a",
            fuzzer_state=state,
        )

        self.assertEqual(tokens, 3)  # Tokens from "Response indicating refusal"
        # TODO(review): neither the refusal flag nor state.refusals is asserted.

    async def test_http_error_response(self):
        """A 500 response is processed without raising (no further assertions)."""
        response = Mock(
            status_code=500,
            text="Internal Server Error",
            request="mock_request",
            response=Mock(),
        )
        state = FuzzerState()

        await process_prompt(
            request_factory=self._factory(return_value=response),
            prompt="test prompt",
            tokens=0,
            module_name="module_a",
            fuzzer_state=state,
        )

    async def test_request_error(self):
        """A transport-level httpx.RequestError yields zero tokens and a refusal."""
        failing_factory = self._factory(
            side_effect=httpx.RequestError("Connection error")
        )
        state = FuzzerState()

        tokens, refusal = await process_prompt(
            request_factory=failing_factory,
            prompt="test prompt",
            tokens=0,
            module_name="module_a",
            fuzzer_state=state,
        )

        self.assertEqual(tokens, 0)
        self.assertTrue(refusal)
+13
View File
@@ -0,0 +1,13 @@
from agentic_security.probe_actor.refusal import DefaultRefusalClassifier
class TestCheckRefusal:
    """Sanity checks for DefaultRefusalClassifier.is_refusal."""

    def test_identify_refusal_phrase(self):
        """A response containing a typical refusal phrase is classified as a refusal."""
        classifier = DefaultRefusalClassifier()
        text = "I'm sorry, but I cannot provide that information."

        assert classifier.is_refusal(text)

    def test_empty_response(self):
        """An empty response is not classified as a refusal."""
        classifier = DefaultRefusalClassifier()
        text = ""

        assert not classifier.is_refusal(text)
View File
@@ -0,0 +1,360 @@
"""Tests for unified dataset loader."""
import pytest
from unittest.mock import patch
from agentic_security.probe_data.unified_loader import (
InputSourceConfig,
UnifiedDatasetLoader,
)
from agentic_security.probe_data.models import ProbeDataset
class TestInputSourceConfig:
    """Field handling and validation of InputSourceConfig."""

    def test_csv_source_config(self):
        """CSV settings are stored as given."""
        cfg = InputSourceConfig(
            source_type="csv",
            dataset_name="test_csv",
            path="./test.csv",
            prompt_column="prompt",
            weight=1.5,
        )

        assert cfg.source_type == "csv"
        assert cfg.dataset_name == "test_csv"
        assert cfg.path == "./test.csv"
        assert cfg.weight == 1.5

    def test_huggingface_source_config(self):
        """HuggingFace settings (split, max_samples) are stored as given."""
        cfg = InputSourceConfig(
            source_type="huggingface",
            dataset_name="test/dataset",
            split="train",
            max_samples=100,
        )

        assert cfg.source_type == "huggingface"
        assert cfg.split == "train"
        assert cfg.max_samples == 100

    def test_proxy_source_config(self):
        """A proxy source is accepted and is enabled by default."""
        cfg = InputSourceConfig(
            source_type="proxy",
            dataset_name="proxy_test",
        )

        assert cfg.source_type == "proxy"
        assert cfg.enabled is True  # Default value

    def test_disabled_source(self):
        """enabled=False is preserved on the config."""
        cfg = InputSourceConfig(
            source_type="csv",
            dataset_name="disabled_test",
            enabled=False,
        )

        assert cfg.enabled is False

    def test_weight_validation(self):
        """A negative weight is rejected with ValueError."""
        invalid = {
            "source_type": "csv",
            "dataset_name": "test",
            "weight": -1.0,
        }

        with pytest.raises(ValueError):
            InputSourceConfig(**invalid)
class TestUnifiedDatasetLoader:
    """Exercise UnifiedDatasetLoader.load_all with the underlying loaders patched."""

    # Dotted paths of the loader functions the tests replace.
    _LOAD_CSV = "agentic_security.probe_data.unified_loader.load_csv"
    _LOAD_GENERIC = "agentic_security.probe_data.unified_loader.load_dataset_generic"

    @staticmethod
    def _dataset(name, prompts, tokens, metadata=None):
        """Build a zero-cost ProbeDataset stub; ``metadata`` defaults to a new dict."""
        return ProbeDataset(
            dataset_name=name,
            prompts=prompts,
            tokens=tokens,
            approx_cost=0.0,
            metadata={} if metadata is None else metadata,
        )

    @pytest.mark.asyncio
    async def test_load_single_csv_source(self):
        """One CSV source ends up verbatim in the unified dataset."""
        loader = UnifiedDatasetLoader(
            [
                InputSourceConfig(
                    source_type="csv",
                    dataset_name="test_csv",
                    path="test.csv",
                )
            ]
        )
        stub = self._dataset("test_csv", ["prompt1", "prompt2", "prompt3"], 10)

        with patch(self._LOAD_CSV, return_value=stub):
            unified = await loader.load_all()

        assert unified.dataset_name == "unified"
        assert len(unified.prompts) == 3
        assert unified.prompts == ["prompt1", "prompt2", "prompt3"]

    @pytest.mark.asyncio
    async def test_load_single_huggingface_source(self):
        """One HuggingFace source is loaded through load_dataset_generic."""
        loader = UnifiedDatasetLoader(
            [
                InputSourceConfig(
                    source_type="huggingface",
                    dataset_name="test/dataset",
                    split="train",
                )
            ]
        )
        stub = self._dataset("test/dataset", ["hf_prompt1", "hf_prompt2"], 8)

        with patch(self._LOAD_GENERIC, return_value=stub):
            unified = await loader.load_all()

        assert unified.dataset_name == "unified"
        assert len(unified.prompts) == 2

    @pytest.mark.asyncio
    async def test_merge_multiple_sources(self):
        """Two CSV sources are merged, each repeated according to its weight."""
        loader = UnifiedDatasetLoader(
            [
                InputSourceConfig(
                    source_type="csv",
                    dataset_name="csv1",
                    path="test1.csv",
                    weight=1.0,
                ),
                InputSourceConfig(
                    source_type="csv",
                    dataset_name="csv2",
                    path="test2.csv",
                    weight=2.0,
                ),
            ]
        )
        first = self._dataset("csv1", ["prompt1"], 5)
        second = self._dataset("csv2", ["prompt2", "prompt3"], 10)

        with patch(self._LOAD_CSV, side_effect=[first, second]):
            unified = await loader.load_all()

        assert unified.dataset_name == "unified"
        # csv1: 1 prompt * weight 1 = 1; csv2: 2 prompts * weight 2 = 4.
        assert len(unified.prompts) == 5
        assert "csv1" in unified.metadata["sources"]
        assert "csv2" in unified.metadata["sources"]

    @pytest.mark.asyncio
    async def test_handle_disabled_sources(self):
        """Sources with enabled=False are never loaded."""
        loader = UnifiedDatasetLoader(
            [
                InputSourceConfig(
                    source_type="csv",
                    dataset_name="enabled_csv",
                    path="enabled.csv",
                    enabled=True,
                ),
                InputSourceConfig(
                    source_type="csv",
                    dataset_name="disabled_csv",
                    path="disabled.csv",
                    enabled=False,
                ),
            ]
        )
        stub = self._dataset("enabled_csv", ["prompt1"], 5)

        with patch(self._LOAD_CSV, return_value=stub) as load_csv_mock:
            unified = await loader.load_all()

        # Only the enabled source may reach the loader function.
        assert load_csv_mock.call_count == 1
        assert len(unified.prompts) == 1

    @pytest.mark.asyncio
    async def test_max_samples_limit(self):
        """max_samples truncates a source that provides more prompts."""
        loader = UnifiedDatasetLoader(
            [
                InputSourceConfig(
                    source_type="csv",
                    dataset_name="test_csv",
                    path="test.csv",
                    max_samples=2,
                )
            ]
        )
        # Five prompts offered, only two allowed through.
        stub = self._dataset(
            "test_csv",
            ["prompt1", "prompt2", "prompt3", "prompt4", "prompt5"],
            20,
        )

        with patch(self._LOAD_CSV, return_value=stub):
            unified = await loader.load_all()

        assert len(unified.prompts) == 2

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """A loader exception is swallowed and yields the empty unified dataset."""
        loader = UnifiedDatasetLoader(
            [
                InputSourceConfig(
                    source_type="csv",
                    dataset_name="error_csv",
                    path="nonexistent.csv",
                )
            ]
        )

        with patch(self._LOAD_CSV, side_effect=Exception("File not found")):
            unified = await loader.load_all()

        assert unified.dataset_name == "unified_empty"
        assert len(unified.prompts) == 0

    @pytest.mark.asyncio
    async def test_proxy_source_placeholder(self):
        """A proxy source contributes no prompts (not implemented in the PoC)."""
        loader = UnifiedDatasetLoader(
            [
                InputSourceConfig(
                    source_type="proxy",
                    dataset_name="proxy_test",
                )
            ]
        )

        unified = await loader.load_all()

        assert len(unified.prompts) == 0

    @pytest.mark.asyncio
    async def test_weighted_sampling(self):
        """Each prompt appears as many times as its source's weight."""
        loader = UnifiedDatasetLoader(
            [
                InputSourceConfig(
                    source_type="csv",
                    dataset_name="low_weight",
                    path="low.csv",
                    weight=1.0,
                ),
                InputSourceConfig(
                    source_type="csv",
                    dataset_name="high_weight",
                    path="high.csv",
                    weight=3.0,
                ),
            ]
        )
        low = self._dataset("low_weight", ["a"], 1)
        high = self._dataset("high_weight", ["b"], 1)

        with patch(self._LOAD_CSV, side_effect=[low, high]):
            unified = await loader.load_all()

        # 1 prompt * weight 1 + 1 prompt * weight 3 = 4 prompts.
        assert len(unified.prompts) == 4
        assert unified.prompts.count("a") == 1
        assert unified.prompts.count("b") == 3

    @pytest.mark.asyncio
    async def test_empty_configs_list(self):
        """No configured sources produce the empty unified dataset."""
        loader = UnifiedDatasetLoader([])

        unified = await loader.load_all()

        assert unified.dataset_name == "unified_empty"
        assert len(unified.prompts) == 0

    @pytest.mark.asyncio
    async def test_csv_with_url(self):
        """A CSV source given by URL is loaded via load_dataset_generic."""
        loader = UnifiedDatasetLoader(
            [
                InputSourceConfig(
                    source_type="csv",
                    dataset_name="remote_csv",
                    url="https://example.com/data.csv",
                    prompt_column="text",
                )
            ]
        )
        stub = self._dataset(
            "remote_csv",
            ["remote_prompt"],
            5,
            metadata={"source_type": "csv", "url": "https://example.com/data.csv"},
        )

        with patch(self._LOAD_GENERIC, return_value=stub):
            unified = await loader.load_all()

        assert len(unified.prompts) == 1
        assert unified.prompts[0] == "remote_prompt"
@@ -0,0 +1,53 @@
from unittest.mock import patch
import pandas as pd
import pytest
from agentic_security.refusal_classifier.model import RefusalClassifier
@pytest.fixture
def mock_training_data():
    """Build a small in-memory DataFrame standing in for the training CSV."""
    return pd.DataFrame(
        {
            "GPT4_response": ["I cannot help with that", "I must decline"],
            "ChatGPT_response": ["I won't assist with that", "That's not appropriate"],
            "Claude_response": ["I cannot comply", "That would be unethical"],
        }
    )
@pytest.fixture
def classifier():
    """Provide a RefusalClassifier configured with test artifact paths."""
    artifact_paths = {
        "model_path": "test_model.joblib",
        "vectorizer_path": "test_vectorizer.joblib",
        "scaler_path": "test_scaler.joblib",
    }
    return RefusalClassifier(**artifact_paths)
@pytest.fixture
def trained_classifier(classifier, mock_training_data):
    """Return the classifier fixture after training it on the mock DataFrame."""
    # pandas.read_csv is stubbed so training reads the mock frame, not a file.
    read_csv_stub = patch("pandas.read_csv", return_value=mock_training_data)
    with read_csv_stub:
        classifier.train(["mock_data.csv"])
    return classifier
def test_is_refusal_without_loading():
    """Calling is_refusal on a freshly constructed classifier raises ValueError."""
    untrained = RefusalClassifier()
    expected_message = "Model, vectorizer, or scaler not loaded"

    with pytest.raises(ValueError, match=expected_message):
        untrained.is_refusal("test text")
def test_is_refusal(trained_classifier):
    """A trained classifier returns a boolean-like verdict for any input text."""
    sample_texts = (
        "I cannot help with that kind of request",  # refusal-style text
        "Here's the information you requested",  # ordinary, non-refusal text
    )
    for text in sample_texts:
        # Only the result's type is checked, not which verdict is returned.
        assert trained_classifier.is_refusal(text) in [True, False]
+25
View File
@@ -0,0 +1,25 @@
import os
from pathlib import Path
from agentic_security.cache_config import ensure_cache_dir
def test_ensure_cache_dir_creates_dir_and_sets_env(tmp_path, monkeypatch):
    """An explicit target is created on disk and exported via DISK_CACHE_DIR."""
    # Start with the variable unset so the call under test has to set it.
    monkeypatch.delenv("DISK_CACHE_DIR", raising=False)
    requested_dir = tmp_path / "cache_to_disk"

    cache_dir = ensure_cache_dir(requested_dir)

    assert cache_dir == requested_dir
    assert cache_dir.is_dir()
    assert Path(os.environ["DISK_CACHE_DIR"]) == cache_dir
def test_ensure_cache_dir_respects_existing_env(tmp_path, monkeypatch):
    """With no argument, the directory named by DISK_CACHE_DIR is used and created."""
    preconfigured_dir = tmp_path / "preconfigured"
    monkeypatch.setenv("DISK_CACHE_DIR", str(preconfigured_dir))

    cache_dir = ensure_cache_dir()

    assert cache_dir == preconfigured_dir
    assert cache_dir.exists()
+15
View File
@@ -0,0 +1,15 @@
from agentic_security.dependencies import InMemorySecrets, get_in_memory_secrets
def test_in_memory_secrets():
    """A stored secret can be read back; an unknown key reads as None."""
    store = InMemorySecrets()
    store.set_secret("api_key", "12345")

    stored_value = store.get_secret("api_key")
    assert stored_value == "12345"

    missing_value = store.get_secret("non_existent_key")
    assert missing_value is None
def test_get_in_memory_secrets():
    """The dependency provider returns a working InMemorySecrets instance."""
    store = get_in_memory_secrets()
    assert isinstance(store, InMemorySecrets)

    # Round-trip a value to show the returned store is usable.
    store.set_secret("token", "abcde")
    assert store.get_secret("token") == "abcde"
+12
View File
@@ -0,0 +1,12 @@
import pytest
from agentic_security.mcp.client import run
@pytest.mark.asyncio
async def test_mcp_echo_tool():
    """The MCP client run returns non-empty prompts, resources and tools."""
    prompts, resources, tools = await run()

    # Each of the three returned values must be truthy.
    assert prompts
    assert resources
    assert tools
+89
View File
@@ -0,0 +1,89 @@
import pytest
from agentic_security.http_spec import (
InvalidHTTPSpecError,
LLMSpec,
parse_http_spec,
)
class TestParseHttpSpec:
    """Tests for parse_http_spec: request line, headers and body handling."""

    def test_parse_simple_http_spec(self):
        """A request line, one header and a JSON body parse into the matching LLMSpec."""
        raw_spec = (
            'GET http://example.com\nContent-Type: application/json\n\n{"key": "value"}'
        )

        parsed = parse_http_spec(raw_spec)

        assert parsed == LLMSpec(
            method="GET",
            url="http://example.com",
            headers={"Content-Type": "application/json"},
            body='{"key": "value"}',
        )

    def test_parse_http_spec_with_special_characters(self):
        """Header values containing special characters are kept verbatim."""
        raw_spec = 'POST http://example.com\nX-Auth-Token: abcdefg1234567890!@#$%^&*\n\n{"key": "value"}'

        parsed = parse_http_spec(raw_spec)

        assert parsed == LLMSpec(
            method="POST",
            url="http://example.com",
            headers={"X-Auth-Token": "abcdefg1234567890!@#$%^&*"},
            body='{"key": "value"}',
        )

    def test_parse_http_spec_with_no_headers_and_no_body(self):
        """A bare request line yields empty headers and an empty body."""
        parsed = parse_http_spec("GET http://example.com")

        assert parsed.method == "GET"
        assert parsed.url == "http://example.com"
        assert parsed.headers == {}
        assert parsed.body == ""

    def test_parse_http_spec_with_headers_no_body(self):
        """Headers followed by only the blank separator give an empty body."""
        raw_spec = "GET http://example.com\nContent-Type: application/json\n\n"

        parsed = parse_http_spec(raw_spec)

        assert parsed.method == "GET"
        assert parsed.url == "http://example.com"
        assert parsed.headers == {"Content-Type": "application/json"}
        assert parsed.body == ""

    def test_parse_http_spec_rejects_malformed_header(self):
        """A header line without a colon raises InvalidHTTPSpecError."""
        raw_spec = "GET http://example.com\nHeaderWithoutColon\n\n"

        with pytest.raises(InvalidHTTPSpecError, match="Invalid header line"):
            parse_http_spec(raw_spec)

    def test_parse_http_spec_trims_header_whitespace(self):
        """A header with no space after the colon still parses to a clean key/value."""
        raw_spec = "GET http://example.com\nAuthorization:Bearer token\n\n"

        parsed_headers = parse_http_spec(raw_spec).headers

        assert parsed_headers == {"Authorization": "Bearer token"}
class TestLLMSpec:
    """Tests for LLMSpec.validate when a required input is missing."""

    def test_validate_raises_error_for_missing_files(self):
        """A spec flagged has_files rejects a call that supplies no files."""
        spec = LLMSpec(
            method="POST",
            url="http://example.com",
            headers={},
            body="",
            has_files=True,
        )
        empty_inputs = {
            "prompt": "",
            "encoded_image": "",
            "encoded_audio": "",
            "files": {},
        }

        with pytest.raises(ValueError, match="Files are required for this request."):
            spec.validate(**empty_inputs)

    def test_validate_raises_error_for_missing_image(self):
        """A spec flagged has_image rejects a call that supplies no image."""
        spec = LLMSpec(
            method="POST",
            url="http://example.com",
            headers={},
            body="",
            has_image=True,
        )
        empty_inputs = {
            "prompt": "",
            "encoded_image": "",
            "encoded_audio": "",
            "files": {},
        }

        with pytest.raises(ValueError, match="An image is required for this request."):
            spec.validate(**empty_inputs)