mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
feat(restruct tests):
This commit is contained in:
@@ -0,0 +1,209 @@
|
||||
"""Tests for CircuitBreaker."""
|
||||
|
||||
import time
|
||||
from agentic_security.executor.circuit_breaker import CircuitBreaker
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
"""Test CircuitBreaker functionality."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test circuit breaker initialization."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
|
||||
|
||||
assert breaker.failure_threshold == 0.5
|
||||
assert breaker.recovery_timeout == 30
|
||||
assert breaker.state == "closed"
|
||||
assert breaker.failures == 0
|
||||
assert breaker.successes == 0
|
||||
|
||||
def test_record_success(self):
|
||||
"""Test recording successful requests."""
|
||||
breaker = CircuitBreaker()
|
||||
|
||||
breaker.record_success()
|
||||
assert breaker.successes == 1
|
||||
assert breaker.failures == 0
|
||||
assert breaker.state == "closed"
|
||||
|
||||
def test_record_failure(self):
|
||||
"""Test recording failed requests."""
|
||||
breaker = CircuitBreaker()
|
||||
|
||||
breaker.record_failure()
|
||||
assert breaker.failures == 1
|
||||
assert breaker.successes == 0
|
||||
assert breaker.last_failure_time is not None
|
||||
|
||||
def test_circuit_opens_on_failure_threshold(self):
|
||||
"""Test that circuit opens when failure threshold is exceeded."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
|
||||
|
||||
# Record 10 requests: 6 failures, 4 successes (60% failure rate)
|
||||
for _ in range(4):
|
||||
breaker.record_success()
|
||||
for _ in range(6):
|
||||
breaker.record_failure()
|
||||
|
||||
# Circuit should be open (60% > 50% threshold)
|
||||
assert breaker.state == "open"
|
||||
assert breaker.is_open() is True
|
||||
|
||||
def test_circuit_stays_closed_below_threshold(self):
|
||||
"""Test that circuit stays closed when below threshold."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
|
||||
|
||||
# Record 10 requests: 4 failures, 6 successes (40% failure rate)
|
||||
for _ in range(6):
|
||||
breaker.record_success()
|
||||
for _ in range(4):
|
||||
breaker.record_failure()
|
||||
|
||||
# Circuit should stay closed (40% < 50% threshold)
|
||||
assert breaker.state == "closed"
|
||||
assert breaker.is_open() is False
|
||||
|
||||
def test_minimum_sample_size_required(self):
|
||||
"""Test that minimum sample size is required before opening."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5)
|
||||
|
||||
# Only 5 failures (below minimum of 10 total requests)
|
||||
for _ in range(5):
|
||||
breaker.record_failure()
|
||||
|
||||
# Circuit should stay closed (not enough samples)
|
||||
assert breaker.state == "closed"
|
||||
assert breaker.is_open() is False
|
||||
|
||||
def test_circuit_recovery_after_timeout(self):
|
||||
"""Test that circuit enters half-open state after recovery timeout."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=1)
|
||||
|
||||
# Open the circuit
|
||||
for _ in range(4):
|
||||
breaker.record_success()
|
||||
for _ in range(6):
|
||||
breaker.record_failure()
|
||||
|
||||
assert breaker.state == "open"
|
||||
|
||||
# Wait for recovery timeout
|
||||
time.sleep(1.1)
|
||||
|
||||
# Check if circuit moves to half-open
|
||||
is_open = breaker.is_open()
|
||||
assert is_open is False
|
||||
assert breaker.state == "half_open"
|
||||
|
||||
def test_half_open_to_closed_on_successes(self):
|
||||
"""Test that circuit closes from half-open after enough successes."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=1)
|
||||
|
||||
# Open the circuit
|
||||
for _ in range(4):
|
||||
breaker.record_success()
|
||||
for _ in range(6):
|
||||
breaker.record_failure()
|
||||
|
||||
# Wait for recovery
|
||||
time.sleep(1.1)
|
||||
breaker.is_open() # Triggers transition to half-open
|
||||
|
||||
assert breaker.state == "half_open"
|
||||
|
||||
# Record 3 successes
|
||||
breaker.record_success()
|
||||
breaker.record_success()
|
||||
breaker.record_success()
|
||||
|
||||
# Should transition to closed
|
||||
assert breaker.state == "closed"
|
||||
|
||||
def test_get_state(self):
|
||||
"""Test get_state method."""
|
||||
breaker = CircuitBreaker()
|
||||
|
||||
assert breaker.get_state() == "closed"
|
||||
|
||||
# Open the circuit
|
||||
for _ in range(10):
|
||||
breaker.record_failure()
|
||||
|
||||
assert breaker.get_state() == "open"
|
||||
|
||||
def test_get_failure_rate(self):
|
||||
"""Test get_failure_rate method."""
|
||||
breaker = CircuitBreaker()
|
||||
|
||||
# No requests
|
||||
assert breaker.get_failure_rate() == 0.0
|
||||
|
||||
# 3 failures, 7 successes (30% failure rate)
|
||||
for _ in range(7):
|
||||
breaker.record_success()
|
||||
for _ in range(3):
|
||||
breaker.record_failure()
|
||||
|
||||
assert breaker.get_failure_rate() == 0.3
|
||||
|
||||
def test_reset(self):
|
||||
"""Test reset method."""
|
||||
breaker = CircuitBreaker()
|
||||
|
||||
# Record some activity
|
||||
breaker.record_success()
|
||||
breaker.record_failure()
|
||||
for _ in range(10):
|
||||
breaker.record_failure()
|
||||
|
||||
# Reset
|
||||
breaker.reset()
|
||||
|
||||
# Should be back to initial state
|
||||
assert breaker.state == "closed"
|
||||
assert breaker.failures == 0
|
||||
assert breaker.successes == 0
|
||||
assert breaker.last_failure_time is None
|
||||
|
||||
def test_exact_failure_threshold(self):
|
||||
"""Test behavior at exact failure threshold."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5)
|
||||
|
||||
# Exactly 50% failure rate (5 failures, 5 successes)
|
||||
for _ in range(5):
|
||||
breaker.record_success()
|
||||
for _ in range(5):
|
||||
breaker.record_failure()
|
||||
|
||||
# Should be open (>= threshold)
|
||||
assert breaker.state == "open"
|
||||
|
||||
def test_high_failure_threshold(self):
|
||||
"""Test with high failure threshold."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.9)
|
||||
|
||||
# 80% failure rate (8 failures, 2 successes)
|
||||
for _ in range(2):
|
||||
breaker.record_success()
|
||||
for _ in range(8):
|
||||
breaker.record_failure()
|
||||
|
||||
# Should stay closed (80% < 90%)
|
||||
assert breaker.state == "closed"
|
||||
|
||||
def test_zero_recovery_timeout(self):
|
||||
"""Test with zero recovery timeout."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=0)
|
||||
|
||||
# Open the circuit
|
||||
for _ in range(10):
|
||||
breaker.record_failure()
|
||||
|
||||
assert breaker.state == "open"
|
||||
|
||||
# Should immediately allow recovery attempt
|
||||
time.sleep(0.01)
|
||||
is_open = breaker.is_open()
|
||||
|
||||
assert is_open is False
|
||||
assert breaker.state == "half_open"
|
||||
@@ -0,0 +1,279 @@
|
||||
"""Tests for ConcurrentExecutor."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch
|
||||
from agentic_security.executor.concurrent import ConcurrentExecutor, ExecutorMetrics
|
||||
from agentic_security.probe_actor.state import FuzzerState
|
||||
|
||||
|
||||
class TestExecutorMetrics:
|
||||
"""Test ExecutorMetrics functionality."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test metrics initialization."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
assert metrics.successful_requests == 0
|
||||
assert metrics.failed_requests == 0
|
||||
assert metrics.total_latency == 0.0
|
||||
assert len(metrics.latencies) == 0
|
||||
|
||||
def test_record_success(self):
|
||||
"""Test recording successful requests."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
metrics.record_success(0.5)
|
||||
metrics.record_success(0.3)
|
||||
|
||||
assert metrics.successful_requests == 2
|
||||
assert metrics.total_latency == 0.8
|
||||
assert len(metrics.latencies) == 2
|
||||
|
||||
def test_record_failure(self):
|
||||
"""Test recording failed requests."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
metrics.record_failure()
|
||||
metrics.record_failure()
|
||||
|
||||
assert metrics.failed_requests == 2
|
||||
assert metrics.successful_requests == 0
|
||||
|
||||
def test_get_stats_no_requests(self):
|
||||
"""Test get_stats with no requests."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
stats = metrics.get_stats()
|
||||
|
||||
assert stats["total_requests"] == 0
|
||||
assert stats["success_rate"] == 0.0
|
||||
assert stats["avg_latency_ms"] == 0.0
|
||||
assert stats["p95_latency_ms"] == 0.0
|
||||
|
||||
def test_get_stats_with_requests(self):
|
||||
"""Test get_stats with recorded requests."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
# Record some requests
|
||||
metrics.record_success(0.1) # 100ms
|
||||
metrics.record_success(0.2) # 200ms
|
||||
metrics.record_success(0.3) # 300ms
|
||||
metrics.record_failure()
|
||||
|
||||
stats = metrics.get_stats()
|
||||
|
||||
assert stats["total_requests"] == 4
|
||||
assert stats["successful_requests"] == 3
|
||||
assert stats["failed_requests"] == 1
|
||||
assert stats["success_rate"] == 0.75
|
||||
assert stats["avg_latency_ms"] == pytest.approx(200.0, rel=0.01)
|
||||
|
||||
def test_get_stats_p95_latency(self):
|
||||
"""Test p95 latency calculation."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
# Add 100 requests with varying latencies
|
||||
for i in range(100):
|
||||
metrics.record_success(i * 0.001) # 0ms to 99ms
|
||||
|
||||
stats = metrics.get_stats()
|
||||
|
||||
# p95 should be around 95ms
|
||||
assert stats["p95_latency_ms"] >= 90.0
|
||||
assert stats["p95_latency_ms"] <= 100.0
|
||||
|
||||
|
||||
class TestConcurrentExecutor:
|
||||
"""Test ConcurrentExecutor functionality."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test executor initialization."""
|
||||
executor = ConcurrentExecutor(
|
||||
max_concurrent=20,
|
||||
rate_limit=10,
|
||||
burst=5,
|
||||
failure_threshold=0.5,
|
||||
recovery_timeout=30,
|
||||
)
|
||||
|
||||
assert executor.semaphore._value == 20
|
||||
assert executor.rate_limiter.rate == 10
|
||||
assert executor.rate_limiter.burst == 5
|
||||
assert executor.circuit_breaker.failure_threshold == 0.5
|
||||
assert executor.circuit_breaker.recovery_timeout == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_batch_success(self):
|
||||
"""Test successful batch execution."""
|
||||
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
|
||||
fuzzer_state = FuzzerState()
|
||||
|
||||
# Mock request factory
|
||||
request_factory = Mock()
|
||||
|
||||
# Mock process_prompt to return success
|
||||
async def mock_process_prompt(rf, prompt, tokens, module, state):
|
||||
return (10, False) # 10 tokens, not refused
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt,
|
||||
):
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
tokens, failures = await executor.execute_batch(
|
||||
request_factory, prompts, "test_module", fuzzer_state
|
||||
)
|
||||
|
||||
assert tokens == 30 # 3 prompts * 10 tokens
|
||||
assert failures == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_batch_with_failures(self):
|
||||
"""Test batch execution with some failures."""
|
||||
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
|
||||
fuzzer_state = FuzzerState()
|
||||
|
||||
request_factory = Mock()
|
||||
|
||||
# Mock process_prompt to alternate success/failure
|
||||
call_count = [0]
|
||||
|
||||
async def mock_process_prompt(rf, prompt, tokens, module, state):
|
||||
call_count[0] += 1
|
||||
if call_count[0] % 2 == 0:
|
||||
return (10, True) # Refused
|
||||
return (10, False) # Success
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt,
|
||||
):
|
||||
prompts = ["p1", "p2", "p3", "p4"]
|
||||
tokens, failures = await executor.execute_batch(
|
||||
request_factory, prompts, "test_module", fuzzer_state
|
||||
)
|
||||
|
||||
assert tokens == 40 # 4 prompts * 10 tokens
|
||||
assert failures == 2 # 2 refused
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_batch_respects_concurrency_limit(self):
|
||||
"""Test that concurrency limit is respected."""
|
||||
executor = ConcurrentExecutor(max_concurrent=2, rate_limit=100, burst=10)
|
||||
fuzzer_state = FuzzerState()
|
||||
|
||||
request_factory = Mock()
|
||||
|
||||
# Track concurrent executions
|
||||
concurrent_count = [0]
|
||||
max_concurrent = [0]
|
||||
|
||||
async def mock_process_prompt(rf, prompt, tokens, module, state):
|
||||
concurrent_count[0] += 1
|
||||
max_concurrent[0] = max(max_concurrent[0], concurrent_count[0])
|
||||
await asyncio.sleep(0.01) # Simulate work
|
||||
concurrent_count[0] -= 1
|
||||
return (10, False)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt,
|
||||
):
|
||||
prompts = ["p1", "p2", "p3", "p4", "p5"]
|
||||
await executor.execute_batch(
|
||||
request_factory, prompts, "test_module", fuzzer_state
|
||||
)
|
||||
|
||||
# Max concurrent should not exceed limit
|
||||
assert max_concurrent[0] <= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_integration(self):
|
||||
"""Test that circuit breaker opens on failures."""
|
||||
executor = ConcurrentExecutor(
|
||||
max_concurrent=10,
|
||||
rate_limit=100,
|
||||
burst=20,
|
||||
failure_threshold=0.5,
|
||||
recovery_timeout=1,
|
||||
)
|
||||
fuzzer_state = FuzzerState()
|
||||
request_factory = Mock()
|
||||
|
||||
# Mock process_prompt to always fail
|
||||
async def mock_process_prompt_fail(rf, prompt, tokens, module, state):
|
||||
raise Exception("Request failed")
|
||||
|
||||
# First batch - all failures
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt_fail,
|
||||
):
|
||||
prompts = ["p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10"]
|
||||
tokens, failures = await executor.execute_batch(
|
||||
request_factory, prompts, "test_module", fuzzer_state
|
||||
)
|
||||
|
||||
# All should have failed
|
||||
assert failures == 10
|
||||
|
||||
# Circuit should be open now
|
||||
assert executor.circuit_breaker.state == "open"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metrics(self):
|
||||
"""Test getting executor metrics."""
|
||||
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
|
||||
fuzzer_state = FuzzerState()
|
||||
request_factory = Mock()
|
||||
|
||||
async def mock_process_prompt(rf, prompt, tokens, module, state):
|
||||
return (10, False)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt,
|
||||
):
|
||||
await executor.execute_batch(
|
||||
request_factory, ["p1", "p2"], "test_module", fuzzer_state
|
||||
)
|
||||
|
||||
metrics = executor.get_metrics()
|
||||
|
||||
assert "total_requests" in metrics
|
||||
assert "success_rate" in metrics
|
||||
assert "circuit_breaker_state" in metrics
|
||||
assert "available_tokens" in metrics
|
||||
assert metrics["total_requests"] == 2
|
||||
assert metrics["circuit_breaker_state"] == "closed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limiting_applied(self):
|
||||
"""Test that rate limiting is applied."""
|
||||
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=5, burst=2)
|
||||
fuzzer_state = FuzzerState()
|
||||
request_factory = Mock()
|
||||
|
||||
async def mock_process_prompt(rf, prompt, tokens, module, state):
|
||||
return (10, False)
|
||||
|
||||
import time
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt,
|
||||
):
|
||||
start = time.monotonic()
|
||||
# 5 requests with rate=5/s and burst=2
|
||||
# First 2 immediate, next 3 should take ~0.6s total
|
||||
await executor.execute_batch(
|
||||
request_factory,
|
||||
["p1", "p2", "p3", "p4", "p5"],
|
||||
"test_module",
|
||||
fuzzer_state,
|
||||
)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Should take at least 0.5s (3 requests / 5 per second)
|
||||
assert elapsed >= 0.4
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Tests for TokenBucketRateLimiter."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import time
|
||||
from agentic_security.executor.rate_limiter import TokenBucketRateLimiter
|
||||
|
||||
|
||||
class TestTokenBucketRateLimiter:
|
||||
"""Test TokenBucketRateLimiter functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization(self):
|
||||
"""Test rate limiter initialization."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=20)
|
||||
|
||||
assert limiter.rate == 10
|
||||
assert limiter.burst == 20
|
||||
assert limiter.tokens == 20 # Starts full
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_with_available_tokens(self):
|
||||
"""Test acquiring tokens when they're available."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=5)
|
||||
|
||||
start = time.monotonic()
|
||||
await limiter.acquire()
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Should return immediately
|
||||
assert elapsed < 0.1
|
||||
assert limiter.tokens < 5 # One token consumed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_waits_when_no_tokens(self):
|
||||
"""Test that acquire waits when no tokens available."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=1)
|
||||
|
||||
# Consume the initial token
|
||||
await limiter.acquire()
|
||||
|
||||
# Next acquire should wait
|
||||
start = time.monotonic()
|
||||
await limiter.acquire()
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Should wait approximately 1/rate seconds (0.1s for rate=10)
|
||||
assert elapsed >= 0.08 # Allow some tolerance
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limiting(self):
|
||||
"""Test that rate limiting actually limits request rate."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=2)
|
||||
|
||||
# Make 5 requests
|
||||
start = time.monotonic()
|
||||
for _ in range(5):
|
||||
await limiter.acquire()
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# With rate=10/s and burst=2:
|
||||
# - First 2 requests are immediate (burst)
|
||||
# - Next 3 requests require waiting: 3 * (1/10) = 0.3s
|
||||
# Total should be around 0.3s
|
||||
assert elapsed >= 0.25 # Allow some tolerance
|
||||
assert elapsed < 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_burst_capacity(self):
|
||||
"""Test that burst capacity allows immediate requests."""
|
||||
limiter = TokenBucketRateLimiter(rate=5, burst=10)
|
||||
|
||||
# Make burst number of requests immediately
|
||||
start = time.monotonic()
|
||||
for _ in range(10):
|
||||
await limiter.acquire()
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# All 10 requests should be nearly immediate (using burst capacity)
|
||||
assert elapsed < 0.2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_replenishment(self):
|
||||
"""Test that tokens are replenished over time."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=5)
|
||||
|
||||
# Consume all tokens
|
||||
for _ in range(5):
|
||||
await limiter.acquire()
|
||||
|
||||
assert limiter.tokens < 1
|
||||
|
||||
# Wait for tokens to replenish
|
||||
await asyncio.sleep(0.3) # Should add 3 tokens at rate=10
|
||||
|
||||
# Should have tokens again (approximately 3)
|
||||
available = limiter.get_available_tokens()
|
||||
assert available >= 2.5
|
||||
assert available <= 3.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_tokens(self):
|
||||
"""Test get_available_tokens method."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=5)
|
||||
|
||||
# Initially full
|
||||
assert limiter.get_available_tokens() == 5
|
||||
|
||||
# After consuming one
|
||||
await limiter.acquire()
|
||||
assert limiter.get_available_tokens() < 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests(self):
|
||||
"""Test rate limiter with concurrent requests."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=3)
|
||||
|
||||
async def make_request(limiter):
|
||||
await limiter.acquire()
|
||||
return time.monotonic()
|
||||
|
||||
# Make 5 concurrent requests
|
||||
start = time.monotonic()
|
||||
tasks = [make_request(limiter) for _ in range(5)]
|
||||
timestamps = await asyncio.gather(*tasks)
|
||||
total_elapsed = time.monotonic() - start
|
||||
|
||||
# First 3 should be immediate (burst=3)
|
||||
# Next 2 should wait
|
||||
# Total time should be around 0.2s (2 * 1/10)
|
||||
assert total_elapsed >= 0.15
|
||||
assert total_elapsed < 0.4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_burst_capacity(self):
|
||||
"""Test that tokens don't exceed burst capacity."""
|
||||
limiter = TokenBucketRateLimiter(rate=100, burst=5)
|
||||
|
||||
# Wait longer than needed to fill
|
||||
await asyncio.sleep(0.2) # Would add 20 tokens, but capped at 5
|
||||
|
||||
# Check tokens don't exceed burst
|
||||
available = limiter.get_available_tokens()
|
||||
assert available <= 5
|
||||
assert available >= 4.5 # Close to full
|
||||
@@ -0,0 +1,285 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from agentic_security.primitives import Scan
|
||||
from agentic_security.probe_actor.fuzzer import (
|
||||
FuzzerState,
|
||||
generate_prompts,
|
||||
perform_many_shot_scan,
|
||||
perform_single_shot_scan,
|
||||
process_prompt,
|
||||
scan_router,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_prompts_with_list():
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
results = [p async for p in generate_prompts(prompts)]
|
||||
assert results == prompts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_prompts_with_async_generator():
|
||||
async def async_gen():
|
||||
for i in range(3):
|
||||
yield f"prompt{i}"
|
||||
|
||||
results = [p async for p in generate_prompts(async_gen())]
|
||||
assert results == ["prompt0", "prompt1", "prompt2"]
|
||||
|
||||
|
||||
async def assert_scan(generator, messages):
|
||||
results = [r async for r in generator]
|
||||
|
||||
for m in messages:
|
||||
found = False
|
||||
for r in results:
|
||||
if m in r:
|
||||
found = True
|
||||
break
|
||||
assert found, f"Message '{m}' not found in results. Results: {results}"
|
||||
return results
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("agentic_security.probe_data.data.prepare_prompts")
|
||||
async def test_perform_single_shot_scan_success(prepare_prompts_mock):
|
||||
# Mock prompt modules
|
||||
prepare_prompts_mock.return_value = [
|
||||
MagicMock(
|
||||
dataset_name="test_module",
|
||||
prompts=["test_prompt1", "test_prompt2"],
|
||||
lazy=False,
|
||||
)
|
||||
]
|
||||
|
||||
# Mock request_factory
|
||||
mock_response = AsyncMock()
|
||||
mock_response.fn.return_value = AsyncMock(
|
||||
status_code=200, text="response text", json=lambda: {}
|
||||
)
|
||||
request_factory = mock_response
|
||||
|
||||
async_gen = perform_single_shot_scan(
|
||||
request_factory=request_factory,
|
||||
max_budget=100,
|
||||
datasets=[{"dataset_name": "test", "selected": True}],
|
||||
optimize=False,
|
||||
)
|
||||
|
||||
await assert_scan(async_gen, ["Loading", "Scan completed."])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("agentic_security.probe_data.msj_data.prepare_prompts")
|
||||
@patch("agentic_security.probe_data.data.prepare_prompts")
|
||||
async def test_perform_many_shot_scan_probe_injection(
|
||||
prepare_prompts_mock, msj_prepare_prompts_mock
|
||||
):
|
||||
# Mock main and probe prompt modules
|
||||
prepare_prompts_mock.side_effect = [
|
||||
[MagicMock(dataset_name="main_module", prompts=["main_prompt1"], lazy=False)],
|
||||
[MagicMock(dataset_name="probe_module", prompts=["probe_prompt1"], lazy=False)],
|
||||
]
|
||||
|
||||
msj_prepare_prompts_mock.return_value = [
|
||||
MagicMock(
|
||||
dataset_name="msj_probe_module", prompts=["msj_probe_prompt"], lazy=False
|
||||
)
|
||||
]
|
||||
|
||||
# Mock request_factory
|
||||
mock_response = AsyncMock()
|
||||
mock_response.fn.side_effect = [
|
||||
AsyncMock(status_code=200, text="main response", json=lambda: {}),
|
||||
AsyncMock(status_code=200, text="probe response", json=lambda: {}),
|
||||
]
|
||||
request_factory = mock_response
|
||||
|
||||
async_gen = perform_many_shot_scan(
|
||||
request_factory=request_factory,
|
||||
max_budget=100,
|
||||
datasets=[{"dataset_name": "main", "selected": True}],
|
||||
probe_datasets=[{"dataset_name": "probe", "selected": True}],
|
||||
probe_frequency=1.0, # Always inject probes
|
||||
optimize=False,
|
||||
)
|
||||
|
||||
await assert_scan(async_gen, ["Loading", "Scan completed."])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("agentic_security.probe_data.data.prepare_prompts")
|
||||
async def test_scan_router_single_shot(prepare_prompts_mock):
|
||||
prepare_prompts_mock.return_value = []
|
||||
|
||||
request_factory = AsyncMock()
|
||||
scan_params = Scan(
|
||||
maxBudget=100,
|
||||
llmSpec="test",
|
||||
datasets=[],
|
||||
probe_datasets=[],
|
||||
enableMultiStepAttack=False,
|
||||
optimize=False,
|
||||
)
|
||||
|
||||
gen = scan_router(
|
||||
request_factory=request_factory,
|
||||
scan_parameters=scan_params,
|
||||
)
|
||||
await assert_scan(gen, ["Loading", "Scan completed."])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("agentic_security.probe_data.data.prepare_prompts")
|
||||
async def test_scan_router_many_shot(prepare_prompts_mock):
|
||||
prepare_prompts_mock.return_value = []
|
||||
|
||||
request_factory = AsyncMock()
|
||||
scan_params = Scan(
|
||||
maxBudget=100,
|
||||
datasets=[],
|
||||
llmSpec="test",
|
||||
probeDatasets=[],
|
||||
enableMultiStepAttack=True,
|
||||
optimize=False,
|
||||
)
|
||||
|
||||
gen = scan_router(
|
||||
request_factory=request_factory,
|
||||
scan_parameters=scan_params,
|
||||
)
|
||||
assert gen is not None
|
||||
|
||||
await assert_scan(gen, ["Loading", "Scan completed."])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_single_shot_scan_stop_event():
|
||||
stop_event = asyncio.Event()
|
||||
stop_event.set() # Pre-set to simulate user stopping the scan
|
||||
|
||||
async def request_mock(*args, **kwargs):
|
||||
return AsyncMock(status_code=200, text="response text", json=lambda: {})
|
||||
|
||||
async_gen = perform_single_shot_scan(
|
||||
request_factory=MagicMock(fn=request_mock),
|
||||
max_budget=100,
|
||||
datasets=[],
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
await assert_scan(async_gen, ["Loading", "Scan completed."])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_many_shot_scan_stop_event():
|
||||
stop_event = asyncio.Event()
|
||||
stop_event.set() # Pre-set to simulate user stopping the scan
|
||||
|
||||
async def request_mock(*args, **kwargs):
|
||||
return AsyncMock(status_code=200, text="response text", json=lambda: {})
|
||||
|
||||
async_gen = perform_many_shot_scan(
|
||||
request_factory=MagicMock(fn=request_mock),
|
||||
max_budget=100,
|
||||
datasets=[],
|
||||
probe_datasets=[],
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
await assert_scan(async_gen, ["Loading", "Scan completed."])
|
||||
|
||||
|
||||
def mock_refusal_heuristic(response_json):
|
||||
return response_json.get("is_refusal", False)
|
||||
|
||||
|
||||
class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_successful_response_no_refusal(self):
|
||||
mock_request_factory = Mock()
|
||||
mock_request_factory.fn = AsyncMock(
|
||||
return_value=Mock(
|
||||
status_code=200,
|
||||
text="Valid response text",
|
||||
json=Mock(return_value={"is_refusal": False}),
|
||||
request="mock_request",
|
||||
)
|
||||
)
|
||||
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
fuzzer_state=FuzzerState(),
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
|
||||
self.assertTrue(refusal)
|
||||
|
||||
async def test_successful_response_with_refusal(self):
|
||||
mock_request_factory = Mock()
|
||||
mock_request_factory.fn = AsyncMock(
|
||||
return_value=Mock(
|
||||
status_code=200,
|
||||
text="Response indicating refusal",
|
||||
json=Mock(return_value={"is_refusal": True}),
|
||||
request="mock_request",
|
||||
)
|
||||
)
|
||||
|
||||
fuzzer_state = FuzzerState()
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
|
||||
# self.assertFalse(fuzzer_state.refusals)
|
||||
|
||||
async def test_http_error_response(self):
|
||||
mock_request_factory = Mock()
|
||||
mock_request_factory.fn = AsyncMock(
|
||||
return_value=Mock(
|
||||
status_code=500,
|
||||
text="Internal Server Error",
|
||||
request="mock_request",
|
||||
response=Mock(),
|
||||
)
|
||||
)
|
||||
|
||||
fuzzer_state = FuzzerState()
|
||||
await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
async def test_request_error(self):
|
||||
mock_request_factory = Mock()
|
||||
mock_request_factory.fn = AsyncMock(
|
||||
side_effect=httpx.RequestError("Connection error")
|
||||
)
|
||||
|
||||
fuzzer_state = FuzzerState()
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 0)
|
||||
self.assertTrue(refusal)
|
||||
@@ -0,0 +1,13 @@
|
||||
from agentic_security.probe_actor.refusal import DefaultRefusalClassifier
|
||||
|
||||
|
||||
class TestCheckRefusal:
|
||||
# The function correctly identifies a refusal phrase in the response.
|
||||
def test_identify_refusal_phrase(self):
|
||||
response = "I'm sorry, but I cannot provide that information."
|
||||
assert DefaultRefusalClassifier().is_refusal(response)
|
||||
|
||||
# The response is an empty string.
|
||||
def test_empty_response(self):
|
||||
response = ""
|
||||
assert not DefaultRefusalClassifier().is_refusal(response)
|
||||
@@ -0,0 +1,360 @@
|
||||
"""Tests for unified dataset loader."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from agentic_security.probe_data.unified_loader import (
|
||||
InputSourceConfig,
|
||||
UnifiedDatasetLoader,
|
||||
)
|
||||
from agentic_security.probe_data.models import ProbeDataset
|
||||
|
||||
|
||||
class TestInputSourceConfig:
|
||||
"""Test InputSourceConfig validation."""
|
||||
|
||||
def test_csv_source_config(self):
|
||||
"""Test CSV source configuration."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="test_csv",
|
||||
path="./test.csv",
|
||||
prompt_column="prompt",
|
||||
weight=1.5,
|
||||
)
|
||||
assert config.source_type == "csv"
|
||||
assert config.dataset_name == "test_csv"
|
||||
assert config.path == "./test.csv"
|
||||
assert config.weight == 1.5
|
||||
|
||||
def test_huggingface_source_config(self):
|
||||
"""Test HuggingFace source configuration."""
|
||||
config = InputSourceConfig(
|
||||
source_type="huggingface",
|
||||
dataset_name="test/dataset",
|
||||
split="train",
|
||||
max_samples=100,
|
||||
)
|
||||
assert config.source_type == "huggingface"
|
||||
assert config.split == "train"
|
||||
assert config.max_samples == 100
|
||||
|
||||
def test_proxy_source_config(self):
|
||||
"""Test proxy source configuration."""
|
||||
config = InputSourceConfig(
|
||||
source_type="proxy",
|
||||
dataset_name="proxy_test",
|
||||
)
|
||||
assert config.source_type == "proxy"
|
||||
assert config.enabled is True # Default value
|
||||
|
||||
def test_disabled_source(self):
|
||||
"""Test disabled source configuration."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="disabled_test",
|
||||
enabled=False,
|
||||
)
|
||||
assert config.enabled is False
|
||||
|
||||
def test_weight_validation(self):
|
||||
"""Test that weight must be non-negative."""
|
||||
with pytest.raises(ValueError):
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="test",
|
||||
weight=-1.0,
|
||||
)
|
||||
|
||||
|
||||
class TestUnifiedDatasetLoader:
|
||||
"""Test UnifiedDatasetLoader functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_single_csv_source(self):
|
||||
"""Test loading a single CSV source."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="test_csv",
|
||||
path="test.csv",
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
# Mock the load_csv function
|
||||
mock_dataset = ProbeDataset(
|
||||
dataset_name="test_csv",
|
||||
prompts=["prompt1", "prompt2", "prompt3"],
|
||||
tokens=10,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
return_value=mock_dataset,
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
assert result.dataset_name == "unified"
|
||||
assert len(result.prompts) == 3
|
||||
assert result.prompts == ["prompt1", "prompt2", "prompt3"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_single_huggingface_source(self):
|
||||
"""Test loading a single HuggingFace source."""
|
||||
config = InputSourceConfig(
|
||||
source_type="huggingface",
|
||||
dataset_name="test/dataset",
|
||||
split="train",
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
# Mock the load_dataset_generic function
|
||||
mock_dataset = ProbeDataset(
|
||||
dataset_name="test/dataset",
|
||||
prompts=["hf_prompt1", "hf_prompt2"],
|
||||
tokens=8,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_dataset_generic",
|
||||
return_value=mock_dataset,
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
assert result.dataset_name == "unified"
|
||||
assert len(result.prompts) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_multiple_sources(self):
|
||||
"""Test merging multiple sources."""
|
||||
configs = [
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="csv1",
|
||||
path="test1.csv",
|
||||
weight=1.0,
|
||||
),
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="csv2",
|
||||
path="test2.csv",
|
||||
weight=2.0,
|
||||
),
|
||||
]
|
||||
loader = UnifiedDatasetLoader(configs)
|
||||
|
||||
# Mock datasets
|
||||
mock_dataset1 = ProbeDataset(
|
||||
dataset_name="csv1",
|
||||
prompts=["prompt1"],
|
||||
tokens=5,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
mock_dataset2 = ProbeDataset(
|
||||
dataset_name="csv2",
|
||||
prompts=["prompt2", "prompt3"],
|
||||
tokens=10,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
side_effect=[mock_dataset1, mock_dataset2],
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
assert result.dataset_name == "unified"
|
||||
# Weight 1.0 = include once, weight 2.0 = include twice
|
||||
# csv1: 1 prompt * 1 = 1
|
||||
# csv2: 2 prompts * 2 = 4
|
||||
assert len(result.prompts) == 5
|
||||
assert "csv1" in result.metadata["sources"]
|
||||
assert "csv2" in result.metadata["sources"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_disabled_sources(self):
|
||||
"""Test that disabled sources are skipped."""
|
||||
configs = [
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="enabled_csv",
|
||||
path="enabled.csv",
|
||||
enabled=True,
|
||||
),
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="disabled_csv",
|
||||
path="disabled.csv",
|
||||
enabled=False,
|
||||
),
|
||||
]
|
||||
loader = UnifiedDatasetLoader(configs)
|
||||
|
||||
mock_dataset = ProbeDataset(
|
||||
dataset_name="enabled_csv",
|
||||
prompts=["prompt1"],
|
||||
tokens=5,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
return_value=mock_dataset,
|
||||
) as mock_load:
|
||||
result = await loader.load_all()
|
||||
|
||||
# Should only be called once (for enabled source)
|
||||
assert mock_load.call_count == 1
|
||||
assert len(result.prompts) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_samples_limit(self):
|
||||
"""Test that max_samples limits the number of prompts."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="test_csv",
|
||||
path="test.csv",
|
||||
max_samples=2,
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
# Mock dataset with more prompts than max_samples
|
||||
mock_dataset = ProbeDataset(
|
||||
dataset_name="test_csv",
|
||||
prompts=["prompt1", "prompt2", "prompt3", "prompt4", "prompt5"],
|
||||
tokens=20,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
return_value=mock_dataset,
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
# Should be limited to 2 prompts
|
||||
assert len(result.prompts) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling(self):
|
||||
"""Test that errors are handled gracefully."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="error_csv",
|
||||
path="nonexistent.csv",
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
side_effect=Exception("File not found"),
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
# Should return empty dataset on error
|
||||
assert result.dataset_name == "unified_empty"
|
||||
assert len(result.prompts) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_source_placeholder(self):
|
||||
"""Test that proxy source returns empty dataset (not implemented in PoC)."""
|
||||
config = InputSourceConfig(
|
||||
source_type="proxy",
|
||||
dataset_name="proxy_test",
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
result = await loader.load_all()
|
||||
|
||||
# Proxy not implemented in PoC, should return empty
|
||||
assert len(result.prompts) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_weighted_sampling(self):
|
||||
"""Test weighted sampling behavior."""
|
||||
configs = [
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="low_weight",
|
||||
path="low.csv",
|
||||
weight=1.0,
|
||||
),
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="high_weight",
|
||||
path="high.csv",
|
||||
weight=3.0,
|
||||
),
|
||||
]
|
||||
loader = UnifiedDatasetLoader(configs)
|
||||
|
||||
mock_dataset1 = ProbeDataset(
|
||||
dataset_name="low_weight",
|
||||
prompts=["a"],
|
||||
tokens=1,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
mock_dataset2 = ProbeDataset(
|
||||
dataset_name="high_weight",
|
||||
prompts=["b"],
|
||||
tokens=1,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
side_effect=[mock_dataset1, mock_dataset2],
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
# Weight 1.0: 1 prompt * 1 = 1
|
||||
# Weight 3.0: 1 prompt * 3 = 3
|
||||
# Total: 4 prompts
|
||||
assert len(result.prompts) == 4
|
||||
assert result.prompts.count("a") == 1
|
||||
assert result.prompts.count("b") == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_configs_list(self):
|
||||
"""Test loading with empty configs list."""
|
||||
loader = UnifiedDatasetLoader([])
|
||||
result = await loader.load_all()
|
||||
|
||||
assert result.dataset_name == "unified_empty"
|
||||
assert len(result.prompts) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csv_with_url(self):
|
||||
"""Test CSV loading from URL."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="remote_csv",
|
||||
url="https://example.com/data.csv",
|
||||
prompt_column="text",
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
mock_dataset = ProbeDataset(
|
||||
dataset_name="remote_csv",
|
||||
prompts=["remote_prompt"],
|
||||
tokens=5,
|
||||
approx_cost=0.0,
|
||||
metadata={"source_type": "csv", "url": "https://example.com/data.csv"},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_dataset_generic",
|
||||
return_value=mock_dataset,
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
assert len(result.prompts) == 1
|
||||
assert result.prompts[0] == "remote_prompt"
|
||||
@@ -0,0 +1,53 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from agentic_security.refusal_classifier.model import RefusalClassifier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_training_data():
|
||||
"""Create mock training data CSV content"""
|
||||
data = {
|
||||
"GPT4_response": ["I cannot help with that", "I must decline"],
|
||||
"ChatGPT_response": ["I won't assist with that", "That's not appropriate"],
|
||||
"Claude_response": ["I cannot comply", "That would be unethical"],
|
||||
}
|
||||
return pd.DataFrame(data)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def classifier():
|
||||
"""Create a RefusalClassifier instance with test paths"""
|
||||
return RefusalClassifier(
|
||||
model_path="test_model.joblib",
|
||||
vectorizer_path="test_vectorizer.joblib",
|
||||
scaler_path="test_scaler.joblib",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trained_classifier(classifier, mock_training_data):
|
||||
"""Create a trained classifier with mock data"""
|
||||
with patch("pandas.read_csv", return_value=mock_training_data):
|
||||
classifier.train(["mock_data.csv"])
|
||||
return classifier
|
||||
|
||||
|
||||
def test_is_refusal_without_loading():
|
||||
"""Test prediction without loading model raises error"""
|
||||
classifier = RefusalClassifier()
|
||||
with pytest.raises(ValueError, match="Model, vectorizer, or scaler not loaded"):
|
||||
classifier.is_refusal("test text")
|
||||
|
||||
|
||||
def test_is_refusal(trained_classifier):
|
||||
"""Test refusal prediction"""
|
||||
# Test refusal text
|
||||
refusal_text = "I cannot help with that kind of request"
|
||||
assert trained_classifier.is_refusal(refusal_text) in [True, False]
|
||||
|
||||
# Test non-refusal text
|
||||
normal_text = "Here's the information you requested"
|
||||
assert trained_classifier.is_refusal(normal_text) in [True, False]
|
||||
@@ -0,0 +1,25 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from agentic_security.cache_config import ensure_cache_dir
|
||||
|
||||
|
||||
def test_ensure_cache_dir_creates_dir_and_sets_env(tmp_path, monkeypatch):
|
||||
monkeypatch.delenv("DISK_CACHE_DIR", raising=False)
|
||||
target_dir = tmp_path / "cache_to_disk"
|
||||
|
||||
resolved = ensure_cache_dir(target_dir)
|
||||
|
||||
assert resolved == target_dir
|
||||
assert resolved.is_dir()
|
||||
assert Path(os.environ["DISK_CACHE_DIR"]) == resolved
|
||||
|
||||
|
||||
def test_ensure_cache_dir_respects_existing_env(tmp_path, monkeypatch):
|
||||
env_dir = tmp_path / "preconfigured"
|
||||
monkeypatch.setenv("DISK_CACHE_DIR", str(env_dir))
|
||||
|
||||
resolved = ensure_cache_dir()
|
||||
|
||||
assert resolved == env_dir
|
||||
assert resolved.exists()
|
||||
@@ -0,0 +1,15 @@
|
||||
from agentic_security.dependencies import InMemorySecrets, get_in_memory_secrets
|
||||
|
||||
|
||||
def test_in_memory_secrets():
|
||||
secrets = InMemorySecrets()
|
||||
secrets.set_secret("api_key", "12345")
|
||||
assert secrets.get_secret("api_key") == "12345"
|
||||
assert secrets.get_secret("non_existent_key") is None
|
||||
|
||||
|
||||
def test_get_in_memory_secrets():
|
||||
secrets = get_in_memory_secrets()
|
||||
assert isinstance(secrets, InMemorySecrets)
|
||||
secrets.set_secret("token", "abcde")
|
||||
assert secrets.get_secret("token") == "abcde"
|
||||
@@ -0,0 +1,12 @@
|
||||
import pytest
|
||||
|
||||
from agentic_security.mcp.client import run
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_echo_tool():
|
||||
"""Test the echo tool functionality"""
|
||||
prompts, resources, tools = await run()
|
||||
assert prompts
|
||||
assert resources
|
||||
assert tools
|
||||
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
|
||||
from agentic_security.http_spec import (
|
||||
InvalidHTTPSpecError,
|
||||
LLMSpec,
|
||||
parse_http_spec,
|
||||
)
|
||||
|
||||
|
||||
class TestParseHttpSpec:
|
||||
# Should correctly parse a simple HTTP spec with headers and body
|
||||
def test_parse_simple_http_spec(self):
|
||||
http_spec = (
|
||||
'GET http://example.com\nContent-Type: application/json\n\n{"key": "value"}'
|
||||
)
|
||||
expected_spec = LLMSpec(
|
||||
method="GET",
|
||||
url="http://example.com",
|
||||
headers={"Content-Type": "application/json"},
|
||||
body='{"key": "value"}',
|
||||
)
|
||||
assert parse_http_spec(http_spec) == expected_spec
|
||||
|
||||
# Should correctly parse a HTTP spec with headers containing special characters
|
||||
def test_parse_http_spec_with_special_characters(self):
|
||||
http_spec = 'POST http://example.com\nX-Auth-Token: abcdefg1234567890!@#$%^&*\n\n{"key": "value"}'
|
||||
expected_spec = LLMSpec(
|
||||
method="POST",
|
||||
url="http://example.com",
|
||||
headers={"X-Auth-Token": "abcdefg1234567890!@#$%^&*"},
|
||||
body='{"key": "value"}',
|
||||
)
|
||||
assert parse_http_spec(http_spec) == expected_spec
|
||||
|
||||
# Should correctly parse a spec with no headers and no body
|
||||
def test_parse_http_spec_with_no_headers_and_no_body(self):
|
||||
# Arrange
|
||||
http_spec = "GET http://example.com"
|
||||
|
||||
# Act
|
||||
result = parse_http_spec(http_spec)
|
||||
|
||||
# Assert
|
||||
assert result.method == "GET"
|
||||
assert result.url == "http://example.com"
|
||||
assert result.headers == {}
|
||||
assert result.body == ""
|
||||
|
||||
def test_parse_http_spec_with_headers_no_body(self):
|
||||
# Arrange
|
||||
http_spec = "GET http://example.com\nContent-Type: application/json\n\n"
|
||||
|
||||
# Act
|
||||
result = parse_http_spec(http_spec)
|
||||
|
||||
# Assert
|
||||
assert result.method == "GET"
|
||||
assert result.url == "http://example.com"
|
||||
assert result.headers == {"Content-Type": "application/json"}
|
||||
assert result.body == ""
|
||||
|
||||
def test_parse_http_spec_rejects_malformed_header(self):
|
||||
http_spec = "GET http://example.com\nHeaderWithoutColon\n\n"
|
||||
|
||||
with pytest.raises(InvalidHTTPSpecError, match="Invalid header line"):
|
||||
parse_http_spec(http_spec)
|
||||
|
||||
def test_parse_http_spec_trims_header_whitespace(self):
|
||||
http_spec = "GET http://example.com\nAuthorization:Bearer token\n\n"
|
||||
|
||||
result = parse_http_spec(http_spec)
|
||||
|
||||
assert result.headers == {"Authorization": "Bearer token"}
|
||||
|
||||
|
||||
class TestLLMSpec:
|
||||
def test_validate_raises_error_for_missing_files(self):
|
||||
spec = LLMSpec(
|
||||
method="POST", url="http://example.com", headers={}, body="", has_files=True
|
||||
)
|
||||
with pytest.raises(ValueError, match="Files are required for this request."):
|
||||
spec.validate(prompt="", encoded_image="", encoded_audio="", files={})
|
||||
|
||||
def test_validate_raises_error_for_missing_image(self):
|
||||
spec = LLMSpec(
|
||||
method="POST", url="http://example.com", headers={}, body="", has_image=True
|
||||
)
|
||||
with pytest.raises(ValueError, match="An image is required for this request."):
|
||||
spec.validate(prompt="", encoded_image="", encoded_audio="", files={})
|
||||
Reference in New Issue
Block a user