From 48125bd1063ea3d698ca185954a8423085541cd9 Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Wed, 24 Dec 2025 08:10:08 +0200 Subject: [PATCH] feat(add executor): --- .gitignore | 3 + agentic_security/executor/__init__.py | 12 + agentic_security/executor/circuit_breaker.py | 109 ++++++ agentic_security/executor/concurrent.py | 232 +++++++++++ agentic_security/executor/rate_limiter.py | 63 +++ agentic_security/probe_data/data.py | 44 +++ agentic_security/probe_data/unified_loader.py | 250 ++++++++++++ auto_loop.sh | 2 +- pyproject.toml | 3 +- tests/executor/__init__.py | 1 + tests/executor/test_circuit_breaker.py | 211 ++++++++++ tests/executor/test_concurrent.py | 276 ++++++++++++++ tests/executor/test_rate_limiter.py | 146 +++++++ tests/probe_data/test_unified_loader.py | 360 ++++++++++++++++++ 14 files changed, 1710 insertions(+), 2 deletions(-) create mode 100644 agentic_security/executor/__init__.py create mode 100644 agentic_security/executor/circuit_breaker.py create mode 100644 agentic_security/executor/concurrent.py create mode 100644 agentic_security/executor/rate_limiter.py create mode 100644 agentic_security/probe_data/unified_loader.py create mode 100644 tests/executor/__init__.py create mode 100644 tests/executor/test_circuit_breaker.py create mode 100644 tests/executor/test_concurrent.py create mode 100644 tests/executor/test_rate_limiter.py create mode 100644 tests/probe_data/test_unified_loader.py diff --git a/.gitignore b/.gitignore index 93ad0cc..4dfe1a2 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,6 @@ agentic_security.toml /venv *.csv agentic_security/agents/operator_agno.py +.claude/ +plan.md +auto_loop.sh diff --git a/agentic_security/executor/__init__.py b/agentic_security/executor/__init__.py new file mode 100644 index 0000000..0240cc7 --- /dev/null +++ b/agentic_security/executor/__init__.py @@ -0,0 +1,12 @@ +"""Advanced concurrent execution package for security scanning.""" + +from agentic_security.executor.rate_limiter 
import TokenBucketRateLimiter +from agentic_security.executor.circuit_breaker import CircuitBreaker +from agentic_security.executor.concurrent import ConcurrentExecutor, ExecutorMetrics + +__all__ = [ + "TokenBucketRateLimiter", + "CircuitBreaker", + "ConcurrentExecutor", + "ExecutorMetrics", +] diff --git a/agentic_security/executor/circuit_breaker.py b/agentic_security/executor/circuit_breaker.py new file mode 100644 index 0000000..fe8a0b4 --- /dev/null +++ b/agentic_security/executor/circuit_breaker.py @@ -0,0 +1,109 @@ +"""Circuit breaker pattern for fault tolerance.""" + +import time +from typing import Literal + + +CircuitState = Literal["closed", "open", "half_open"] + + +class CircuitBreaker: + """Circuit breaker to prevent cascading failures. + + Implements the circuit breaker pattern with three states: + - closed: Normal operation, requests pass through + - open: Failure threshold exceeded, requests fail fast + - half_open: Recovery attempt, limited requests allowed + + Example: + >>> breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30) + >>> if breaker.is_open(): + ... raise Exception("Circuit breaker is open") + >>> try: + ... result = make_request() + ... breaker.record_success() + >>> except Exception: + ... breaker.record_failure() + """ + + def __init__(self, failure_threshold: float = 0.5, recovery_timeout: int = 30): + """Initialize circuit breaker. 
+ + Args: + failure_threshold: Failure rate (0.0-1.0) that triggers open state + recovery_timeout: Seconds to wait before attempting recovery + """ + self.failure_threshold = failure_threshold + self.recovery_timeout = recovery_timeout + self.failures = 0 + self.successes = 0 + self.state: CircuitState = "closed" + self.last_failure_time: float | None = None + + def record_success(self): + """Record a successful request.""" + self.successes += 1 + + # If in half_open state and we have enough successes, close the circuit + if self.state == "half_open" and self.successes >= 3: + self.state = "closed" + self.failures = 0 + self.successes = 0 + + def record_failure(self): + """Record a failed request.""" + self.failures += 1 + self.last_failure_time = time.monotonic() + + total = self.failures + self.successes + + # Need minimum sample size before opening circuit + if total >= 10: + failure_rate = self.failures / total + if failure_rate >= self.failure_threshold: + self.state = "open" + + def is_open(self) -> bool: + """Check if circuit breaker is open. + + Returns: + bool: True if circuit is open and requests should be blocked + """ + if self.state == "open": + # Check if we should attempt recovery + if self.last_failure_time is not None: + if time.monotonic() - self.last_failure_time > self.recovery_timeout: + self.state = "half_open" + # Reset counters for half-open state + self.failures = 0 + self.successes = 0 + return False + return True + + return False + + def get_state(self) -> CircuitState: + """Get current circuit breaker state. + + Returns: + CircuitState: Current state (closed, open, or half_open) + """ + return self.state + + def get_failure_rate(self) -> float: + """Get current failure rate. 
+ + Returns: + float: Failure rate (0.0-1.0), or 0.0 if no requests recorded + """ + total = self.failures + self.successes + if total == 0: + return 0.0 + return self.failures / total + + def reset(self): + """Reset circuit breaker to initial state.""" + self.failures = 0 + self.successes = 0 + self.state = "closed" + self.last_failure_time = None diff --git a/agentic_security/executor/concurrent.py b/agentic_security/executor/concurrent.py new file mode 100644 index 0000000..821b418 --- /dev/null +++ b/agentic_security/executor/concurrent.py @@ -0,0 +1,232 @@ +"""Concurrent executor with rate limiting and circuit breaking.""" + +import asyncio +import time +from typing import Any + +from agentic_security.executor.rate_limiter import TokenBucketRateLimiter +from agentic_security.executor.circuit_breaker import CircuitBreaker +from agentic_security.logutils import logger +from agentic_security.probe_actor.state import FuzzerState + + +class ExecutorMetrics: + """Track executor performance metrics.""" + + def __init__(self): + """Initialize metrics tracking.""" + self.successful_requests = 0 + self.failed_requests = 0 + self.total_latency = 0.0 + self.latencies: list[float] = [] + + def record_success(self, latency: float): + """Record a successful request. + + Args: + latency: Request latency in seconds + """ + self.successful_requests += 1 + self.total_latency += latency + self.latencies.append(latency) + + def record_failure(self): + """Record a failed request.""" + self.failed_requests += 1 + + def get_stats(self) -> dict[str, Any]: + """Get current statistics. 
+ + Returns: + dict: Statistics including total requests, success rate, latency metrics + """ + total_requests = self.successful_requests + self.failed_requests + + if total_requests == 0: + return { + "total_requests": 0, + "success_rate": 0.0, + "avg_latency_ms": 0.0, + "p95_latency_ms": 0.0, + } + + success_rate = self.successful_requests / total_requests + avg_latency_ms = ( + (self.total_latency / self.successful_requests * 1000) + if self.successful_requests > 0 + else 0.0 + ) + + # Calculate p95 latency + if self.latencies: + sorted_latencies = sorted(self.latencies) + p95_index = int(len(sorted_latencies) * 0.95) + p95_latency_ms = sorted_latencies[p95_index] * 1000 if p95_index < len(sorted_latencies) else 0.0 + else: + p95_latency_ms = 0.0 + + return { + "total_requests": total_requests, + "successful_requests": self.successful_requests, + "failed_requests": self.failed_requests, + "success_rate": success_rate, + "avg_latency_ms": avg_latency_ms, + "p95_latency_ms": p95_latency_ms, + } + + +class ConcurrentExecutor: + """Enhanced concurrent executor with rate limiting and circuit breaking. + + Provides advanced concurrency control for security scanning with: + - Token bucket rate limiting + - Circuit breaker for fault tolerance + - Metrics collection + - Semaphore-based concurrency limits + + Example: + >>> executor = ConcurrentExecutor(max_concurrent=20, rate_limit=10, burst=5) + >>> tokens, failures = await executor.execute_batch( + ... request_factory, prompts, "module_name", fuzzer_state + ... ) + >>> print(executor.metrics.get_stats()) + """ + + def __init__( + self, + max_concurrent: int = 50, + rate_limit: float = 100, + burst: int = 20, + failure_threshold: float = 0.5, + recovery_timeout: int = 30, + ): + """Initialize concurrent executor. 
+ + Args: + max_concurrent: Maximum number of concurrent requests + rate_limit: Requests per second limit + burst: Maximum burst size for rate limiter + failure_threshold: Failure rate that triggers circuit breaker + recovery_timeout: Seconds before attempting circuit recovery + """ + self.semaphore = asyncio.Semaphore(max_concurrent) + self.rate_limiter = TokenBucketRateLimiter(rate_limit, burst) + self.circuit_breaker = CircuitBreaker(failure_threshold, recovery_timeout) + self.metrics = ExecutorMetrics() + + logger.info( + f"ConcurrentExecutor initialized: max_concurrent={max_concurrent}, " + f"rate_limit={rate_limit}/s, burst={burst}" + ) + + async def execute_batch( + self, + request_factory, + prompts: list[str], + module_name: str, + fuzzer_state: FuzzerState, + ) -> tuple[int, int]: + """Execute a batch of prompts with rate limiting and circuit breaking. + + This is compatible with the existing process_prompt_batch signature. + + Args: + request_factory: Request factory with fn() method + prompts: List of prompts to process + module_name: Name of the module being scanned + fuzzer_state: State tracking object + + Returns: + tuple[int, int]: (total_tokens, failures) + """ + tasks = [ + self._execute_single(request_factory, prompt, module_name, fuzzer_state) + for prompt in prompts + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Aggregate results + total_tokens = 0 + failures = 0 + + for result in results: + if isinstance(result, Exception): + failures += 1 + logger.error(f"Task failed with exception: {result}") + else: + tokens, refused = result + total_tokens += tokens + if refused: + failures += 1 + + return total_tokens, failures + + async def _execute_single( + self, + request_factory, + prompt: str, + module_name: str, + fuzzer_state: FuzzerState, + ) -> tuple[int, bool]: + """Execute a single prompt with rate limiting and circuit breaking. 
+ + Args: + request_factory: Request factory with fn() method + prompt: Prompt to process + module_name: Name of the module being scanned + fuzzer_state: State tracking object + + Returns: + tuple[int, bool]: (tokens, refused) + + Raises: + Exception: If circuit breaker is open + """ + # Rate limiting + await self.rate_limiter.acquire() + + # Circuit breaker check + if self.circuit_breaker.is_open(): + self.metrics.record_failure() + raise Exception("Circuit breaker is open - too many failures") + + # Concurrency control + async with self.semaphore: + start_time = time.monotonic() + + try: + # Import here to avoid circular dependency + from agentic_security.probe_actor.fuzzer import process_prompt + + tokens = 0 # Initial token count for this prompt + result = await process_prompt( + request_factory, prompt, tokens, module_name, fuzzer_state + ) + + # Record success + self.circuit_breaker.record_success() + latency = time.monotonic() - start_time + self.metrics.record_success(latency) + + return result + + except Exception as e: + # Record failure + self.circuit_breaker.record_failure() + self.metrics.record_failure() + logger.error(f"Error executing prompt: {e}") + raise + + def get_metrics(self) -> dict[str, Any]: + """Get current executor metrics. 
+ + Returns: + dict: Metrics including request stats, latency, and circuit breaker state + """ + stats = self.metrics.get_stats() + stats["circuit_breaker_state"] = self.circuit_breaker.get_state() + stats["circuit_breaker_failure_rate"] = self.circuit_breaker.get_failure_rate() + stats["available_tokens"] = self.rate_limiter.get_available_tokens() + + return stats diff --git a/agentic_security/executor/rate_limiter.py b/agentic_security/executor/rate_limiter.py new file mode 100644 index 0000000..ec8e5fa --- /dev/null +++ b/agentic_security/executor/rate_limiter.py @@ -0,0 +1,63 @@ +"""Token bucket rate limiter for controlling request rate.""" + +import asyncio +import time + + +class TokenBucketRateLimiter: + """Token bucket rate limiter with configurable rate and burst capacity. + + This implements the token bucket algorithm where tokens are added at a fixed + rate and consumed for each request. Supports bursting up to the bucket capacity. + + Example: + >>> limiter = TokenBucketRateLimiter(rate=10, burst=20) + >>> await limiter.acquire() # Will wait if no tokens available + """ + + def __init__(self, rate: float, burst: int): + """Initialize rate limiter. + + Args: + rate: Tokens added per second (requests/sec) + burst: Maximum bucket capacity (max concurrent burst) + """ + self.rate = rate + self.burst = burst + self.tokens = float(burst) + self.last_update = time.monotonic() + self._lock = asyncio.Lock() + + async def acquire(self): + """Acquire a token, waiting if necessary. + + This method will block until a token is available. 
+ """ + async with self._lock: + now = time.monotonic() + elapsed = now - self.last_update + + # Add tokens based on elapsed time + self.tokens = min(self.burst, self.tokens + elapsed * self.rate) + self.last_update = now + + if self.tokens >= 1: + # Token available, consume it + self.tokens -= 1 + return + + # Need to wait for next token + wait_time = (1 - self.tokens) / self.rate + await asyncio.sleep(wait_time) + self.tokens = 0 + self.last_update = time.monotonic() + + def get_available_tokens(self) -> float: + """Get current number of available tokens (non-blocking). + + Returns: + float: Number of tokens currently available + """ + now = time.monotonic() + elapsed = now - self.last_update + return min(self.burst, self.tokens + elapsed * self.rate) diff --git a/agentic_security/probe_data/data.py b/agentic_security/probe_data/data.py index 5d7351e..023151b 100644 --- a/agentic_security/probe_data/data.py +++ b/agentic_security/probe_data/data.py @@ -475,3 +475,47 @@ def prepare_prompts( datasets.append(load_csv(name)) return datasets + + +async def prepare_prompts_unified(configs: list) -> list[ProbeDataset]: + """Prepare datasets using unified loader configuration. + + This is an alternative to prepare_prompts() that uses the UnifiedDatasetLoader + for streamlined configuration and merging of multiple sources. + + Args: + configs: List of InputSourceConfig objects or dicts + + Returns: + list[ProbeDataset]: List containing the merged dataset + + Example: + >>> from agentic_security.probe_data.unified_loader import InputSourceConfig + >>> configs = [ + ... InputSourceConfig( + ... source_type="huggingface", + ... dataset_name="deepset/prompt-injections", + ... enabled=True, + ... weight=1.0 + ... ) + ... 
] + >>> datasets = await prepare_prompts_unified(configs) + """ + from agentic_security.probe_data.unified_loader import ( + UnifiedDatasetLoader, + InputSourceConfig, + ) + + # Convert dicts to InputSourceConfig if needed + config_objects = [] + for config in configs: + if isinstance(config, dict): + config_objects.append(InputSourceConfig(**config)) + else: + config_objects.append(config) + + loader = UnifiedDatasetLoader(config_objects) + merged_dataset = await loader.load_all() + + # Return as list for compatibility with existing code + return [merged_dataset] if merged_dataset.prompts else [] diff --git a/agentic_security/probe_data/unified_loader.py b/agentic_security/probe_data/unified_loader.py new file mode 100644 index 0000000..1f077d3 --- /dev/null +++ b/agentic_security/probe_data/unified_loader.py @@ -0,0 +1,250 @@ +"""Unified dataset loader for CSV, HuggingFace, and proxy sources.""" + +from typing import Any, Literal, Optional +from pydantic import BaseModel, Field + +from agentic_security.logutils import logger +from agentic_security.probe_data.data import ( + load_dataset_generic, + load_csv, + create_probe_dataset, +) +from agentic_security.probe_data.models import ProbeDataset + + +class InputSourceConfig(BaseModel): + """Configuration for a single input source.""" + + source_type: Literal["csv", "huggingface", "proxy"] = Field( + description="Type of input source" + ) + enabled: bool = Field(default=True, description="Whether this source is enabled") + dataset_name: str = Field(description="Name/identifier of the dataset") + weight: float = Field(default=1.0, ge=0.0, description="Sampling weight for merging") + + # CSV-specific fields + path: Optional[str] = Field( + default=None, description="File path for CSV sources" + ) + prompt_column: Optional[str] = Field( + default="prompt", description="Column name containing prompts" + ) + + # HuggingFace-specific fields + split: Optional[str] = Field( + default="train", description="Dataset split to 
load (train/test/validation)" + ) + max_samples: Optional[int] = Field( + default=None, ge=1, description="Maximum number of samples to load" + ) + + # URL for custom sources + url: Optional[str] = Field( + default=None, description="URL for remote CSV files" + ) + + +class UnifiedDatasetLoader: + """Loads and merges datasets from multiple sources.""" + + def __init__(self, configs: list[InputSourceConfig]): + """Initialize with list of input source configurations. + + Args: + configs: List of InputSourceConfig objects defining data sources + """ + self.configs = configs + logger.info(f"Initialized UnifiedDatasetLoader with {len(configs)} sources") + + async def load_all(self) -> ProbeDataset: + """Load all enabled sources and merge into a single dataset. + + Returns: + ProbeDataset: Merged dataset from all enabled sources + """ + datasets = [] + + for config in self.configs: + if not config.enabled: + logger.debug(f"Skipping disabled source: {config.dataset_name}") + continue + + try: + dataset = await self._load_single(config) + if dataset and dataset.prompts: + datasets.append((dataset, config.weight)) + logger.info( + f"Loaded {len(dataset.prompts)} prompts from {config.dataset_name} " + f"(weight={config.weight})" + ) + else: + logger.warning(f"No prompts loaded from {config.dataset_name}") + except Exception as e: + logger.error(f"Error loading {config.dataset_name}: {e}") + + if not datasets: + logger.warning("No datasets loaded successfully") + return create_probe_dataset("unified_empty", [], {"sources": []}) + + return self._merge_weighted(datasets) + + async def _load_single(self, config: InputSourceConfig) -> ProbeDataset: + """Load a single dataset based on its configuration. 
+ + Args: + config: Configuration for the source to load + + Returns: + ProbeDataset: Loaded dataset + """ + if config.source_type == "csv": + return self._load_csv_source(config) + elif config.source_type == "huggingface": + return self._load_huggingface_source(config) + elif config.source_type == "proxy": + return self._load_proxy_source(config) + else: + raise ValueError(f"Unknown source type: {config.source_type}") + + def _load_csv_source(self, config: InputSourceConfig) -> ProbeDataset: + """Load dataset from CSV file. + + Args: + config: CSV source configuration + + Returns: + ProbeDataset: Dataset loaded from CSV + """ + if config.path: + # Local CSV file + logger.info(f"Loading CSV from path: {config.path}") + dataset = load_csv(config.path) + elif config.url: + # Remote CSV file + logger.info(f"Loading CSV from URL: {config.url}") + mappings = {config.prompt_column: "prompt"} if config.prompt_column else None + dataset = load_dataset_generic( + name=config.dataset_name, + url=config.url, + mappings=mappings, + metadata={"source_type": "csv", "url": config.url} + ) + else: + raise ValueError(f"CSV source {config.dataset_name} requires either path or url") + + # Apply max_samples limit if specified + if config.max_samples and len(dataset.prompts) > config.max_samples: + logger.info( + f"Limiting {config.dataset_name} from {len(dataset.prompts)} " + f"to {config.max_samples} samples" + ) + dataset.prompts = dataset.prompts[:config.max_samples] + + return dataset + + def _load_huggingface_source(self, config: InputSourceConfig) -> ProbeDataset: + """Load dataset from HuggingFace. 
+ + Args: + config: HuggingFace source configuration + + Returns: + ProbeDataset: Dataset loaded from HuggingFace + """ + logger.info( + f"Loading HuggingFace dataset: {config.dataset_name} " + f"(split={config.split})" + ) + + # Build column mappings + mappings = None + if config.prompt_column and config.prompt_column != "prompt": + mappings = {config.prompt_column: "prompt"} + + dataset = load_dataset_generic( + name=config.dataset_name, + mappings=mappings, + metadata={ + "source_type": "huggingface", + "split": config.split, + } + ) + + # Apply max_samples limit if specified + if config.max_samples and len(dataset.prompts) > config.max_samples: + logger.info( + f"Limiting {config.dataset_name} from {len(dataset.prompts)} " + f"to {config.max_samples} samples" + ) + dataset.prompts = dataset.prompts[:config.max_samples] + + return dataset + + def _load_proxy_source(self, config: InputSourceConfig) -> ProbeDataset: + """Load dataset from proxy queue (placeholder for PoC). + + Args: + config: Proxy source configuration + + Returns: + ProbeDataset: Empty dataset (proxy integration not implemented in PoC) + """ + logger.warning( + f"Proxy source {config.dataset_name} not implemented in PoC - returning empty dataset" + ) + return create_probe_dataset( + config.dataset_name, + [], + {"source_type": "proxy", "status": "not_implemented"} + ) + + def _merge_weighted( + self, datasets: list[tuple[ProbeDataset, float]] + ) -> ProbeDataset: + """Merge multiple datasets with weighted sampling. + + For PoC, this implements simple concatenation with optional weighting. + Production version would implement proper stratified sampling. 
+ + Args: + datasets: List of (ProbeDataset, weight) tuples + + Returns: + ProbeDataset: Merged dataset + """ + if not datasets: + return create_probe_dataset("unified_empty", [], {"sources": []}) + + # For PoC: simple concatenation, repeat prompts based on weight + all_prompts = [] + source_names = [] + total_tokens = 0 + + for dataset, weight in datasets: + source_names.append(dataset.dataset_name) + + # Calculate how many times to include this dataset based on weight + # Weight of 1.0 = include once, 2.0 = include twice, etc. + repeat_count = max(1, int(weight)) + + for _ in range(repeat_count): + all_prompts.extend(dataset.prompts) + + total_tokens += dataset.tokens * repeat_count + + logger.info( + f"Merged {len(datasets)} datasets into {len(all_prompts)} total prompts " + f"from sources: {source_names}" + ) + + return ProbeDataset( + dataset_name="unified", + metadata={ + "sources": source_names, + "source_count": len(datasets), + "weights": {ds.dataset_name: w for ds, w in datasets}, + }, + prompts=all_prompts, + tokens=total_tokens, + approx_cost=0.0, + ) diff --git a/auto_loop.sh b/auto_loop.sh index 353a13d..ff215c5 100644 --- a/auto_loop.sh +++ b/auto_loop.sh @@ -3,7 +3,7 @@ set -euo pipefail PROMPT="Ultrathink. You're a principal engineer. Do not ask me any questions. We need to improve the quality of this codebase. Implement improvements to codebase quality." 
-MAX_ITERS=5 +MAX_ITERS=50 MAX_EMPTY_RUNS=5 CODEX_MODEL="gpt-5.1-codex-max" diff --git a/pyproject.toml b/pyproject.toml index fb1b28e..d4cf93a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,8 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] -addopts = "--durations=5 -m 'not slow' -n 3" +addopts = "-m 'not slow'" +# addopts = "--durations=5 -m 'not slow' -n 3" asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function" markers = "slow: marks tests as slow" diff --git a/tests/executor/__init__.py b/tests/executor/__init__.py new file mode 100644 index 0000000..8339d5d --- /dev/null +++ b/tests/executor/__init__.py @@ -0,0 +1 @@ +"""Tests for executor package.""" diff --git a/tests/executor/test_circuit_breaker.py b/tests/executor/test_circuit_breaker.py new file mode 100644 index 0000000..4cf4d5b --- /dev/null +++ b/tests/executor/test_circuit_breaker.py @@ -0,0 +1,211 @@ +"""Tests for CircuitBreaker.""" + +import pytest +import time +from unittest.mock import patch +from agentic_security.executor.circuit_breaker import CircuitBreaker + + +class TestCircuitBreaker: + """Test CircuitBreaker functionality.""" + + def test_initialization(self): + """Test circuit breaker initialization.""" + breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30) + + assert breaker.failure_threshold == 0.5 + assert breaker.recovery_timeout == 30 + assert breaker.state == "closed" + assert breaker.failures == 0 + assert breaker.successes == 0 + + def test_record_success(self): + """Test recording successful requests.""" + breaker = CircuitBreaker() + + breaker.record_success() + assert breaker.successes == 1 + assert breaker.failures == 0 + assert breaker.state == "closed" + + def test_record_failure(self): + """Test recording failed requests.""" + breaker = CircuitBreaker() + + breaker.record_failure() + assert breaker.failures == 1 + assert breaker.successes == 0 + assert breaker.last_failure_time is not None + + def 
test_circuit_opens_on_failure_threshold(self): + """Test that circuit opens when failure threshold is exceeded.""" + breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30) + + # Record 10 requests: 6 failures, 4 successes (60% failure rate) + for _ in range(4): + breaker.record_success() + for _ in range(6): + breaker.record_failure() + + # Circuit should be open (60% > 50% threshold) + assert breaker.state == "open" + assert breaker.is_open() is True + + def test_circuit_stays_closed_below_threshold(self): + """Test that circuit stays closed when below threshold.""" + breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30) + + # Record 10 requests: 4 failures, 6 successes (40% failure rate) + for _ in range(6): + breaker.record_success() + for _ in range(4): + breaker.record_failure() + + # Circuit should stay closed (40% < 50% threshold) + assert breaker.state == "closed" + assert breaker.is_open() is False + + def test_minimum_sample_size_required(self): + """Test that minimum sample size is required before opening.""" + breaker = CircuitBreaker(failure_threshold=0.5) + + # Only 5 failures (below minimum of 10 total requests) + for _ in range(5): + breaker.record_failure() + + # Circuit should stay closed (not enough samples) + assert breaker.state == "closed" + assert breaker.is_open() is False + + def test_circuit_recovery_after_timeout(self): + """Test that circuit enters half-open state after recovery timeout.""" + breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=1) + + # Open the circuit + for _ in range(4): + breaker.record_success() + for _ in range(6): + breaker.record_failure() + + assert breaker.state == "open" + + # Wait for recovery timeout + time.sleep(1.1) + + # Check if circuit moves to half-open + is_open = breaker.is_open() + assert is_open is False + assert breaker.state == "half_open" + + def test_half_open_to_closed_on_successes(self): + """Test that circuit closes from half-open after enough 
successes.""" + breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=1) + + # Open the circuit + for _ in range(4): + breaker.record_success() + for _ in range(6): + breaker.record_failure() + + # Wait for recovery + time.sleep(1.1) + breaker.is_open() # Triggers transition to half-open + + assert breaker.state == "half_open" + + # Record 3 successes + breaker.record_success() + breaker.record_success() + breaker.record_success() + + # Should transition to closed + assert breaker.state == "closed" + + def test_get_state(self): + """Test get_state method.""" + breaker = CircuitBreaker() + + assert breaker.get_state() == "closed" + + # Open the circuit + for _ in range(10): + breaker.record_failure() + + assert breaker.get_state() == "open" + + def test_get_failure_rate(self): + """Test get_failure_rate method.""" + breaker = CircuitBreaker() + + # No requests + assert breaker.get_failure_rate() == 0.0 + + # 3 failures, 7 successes (30% failure rate) + for _ in range(7): + breaker.record_success() + for _ in range(3): + breaker.record_failure() + + assert breaker.get_failure_rate() == 0.3 + + def test_reset(self): + """Test reset method.""" + breaker = CircuitBreaker() + + # Record some activity + breaker.record_success() + breaker.record_failure() + for _ in range(10): + breaker.record_failure() + + # Reset + breaker.reset() + + # Should be back to initial state + assert breaker.state == "closed" + assert breaker.failures == 0 + assert breaker.successes == 0 + assert breaker.last_failure_time is None + + def test_exact_failure_threshold(self): + """Test behavior at exact failure threshold.""" + breaker = CircuitBreaker(failure_threshold=0.5) + + # Exactly 50% failure rate (5 failures, 5 successes) + for _ in range(5): + breaker.record_success() + for _ in range(5): + breaker.record_failure() + + # Should be open (>= threshold) + assert breaker.state == "open" + + def test_high_failure_threshold(self): + """Test with high failure threshold.""" + breaker = 
CircuitBreaker(failure_threshold=0.9) + + # 80% failure rate (8 failures, 2 successes) + for _ in range(2): + breaker.record_success() + for _ in range(8): + breaker.record_failure() + + # Should stay closed (80% < 90%) + assert breaker.state == "closed" + + def test_zero_recovery_timeout(self): + """Test with zero recovery timeout.""" + breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=0) + + # Open the circuit + for _ in range(10): + breaker.record_failure() + + assert breaker.state == "open" + + # Should immediately allow recovery attempt + time.sleep(0.01) + is_open = breaker.is_open() + + assert is_open is False + assert breaker.state == "half_open" diff --git a/tests/executor/test_concurrent.py b/tests/executor/test_concurrent.py new file mode 100644 index 0000000..993e26a --- /dev/null +++ b/tests/executor/test_concurrent.py @@ -0,0 +1,276 @@ +"""Tests for ConcurrentExecutor.""" + +import pytest +import asyncio +from unittest.mock import Mock, patch, AsyncMock +from agentic_security.executor.concurrent import ConcurrentExecutor, ExecutorMetrics +from agentic_security.probe_actor.state import FuzzerState + + +class TestExecutorMetrics: + """Test ExecutorMetrics functionality.""" + + def test_initialization(self): + """Test metrics initialization.""" + metrics = ExecutorMetrics() + + assert metrics.successful_requests == 0 + assert metrics.failed_requests == 0 + assert metrics.total_latency == 0.0 + assert len(metrics.latencies) == 0 + + def test_record_success(self): + """Test recording successful requests.""" + metrics = ExecutorMetrics() + + metrics.record_success(0.5) + metrics.record_success(0.3) + + assert metrics.successful_requests == 2 + assert metrics.total_latency == 0.8 + assert len(metrics.latencies) == 2 + + def test_record_failure(self): + """Test recording failed requests.""" + metrics = ExecutorMetrics() + + metrics.record_failure() + metrics.record_failure() + + assert metrics.failed_requests == 2 + assert 
metrics.successful_requests == 0 + + def test_get_stats_no_requests(self): + """Test get_stats with no requests.""" + metrics = ExecutorMetrics() + + stats = metrics.get_stats() + + assert stats["total_requests"] == 0 + assert stats["success_rate"] == 0.0 + assert stats["avg_latency_ms"] == 0.0 + assert stats["p95_latency_ms"] == 0.0 + + def test_get_stats_with_requests(self): + """Test get_stats with recorded requests.""" + metrics = ExecutorMetrics() + + # Record some requests + metrics.record_success(0.1) # 100ms + metrics.record_success(0.2) # 200ms + metrics.record_success(0.3) # 300ms + metrics.record_failure() + + stats = metrics.get_stats() + + assert stats["total_requests"] == 4 + assert stats["successful_requests"] == 3 + assert stats["failed_requests"] == 1 + assert stats["success_rate"] == 0.75 + assert stats["avg_latency_ms"] == pytest.approx(200.0, rel=0.01) + + def test_get_stats_p95_latency(self): + """Test p95 latency calculation.""" + metrics = ExecutorMetrics() + + # Add 100 requests with varying latencies + for i in range(100): + metrics.record_success(i * 0.001) # 0ms to 99ms + + stats = metrics.get_stats() + + # p95 should be around 95ms + assert stats["p95_latency_ms"] >= 90.0 + assert stats["p95_latency_ms"] <= 100.0 + + +class TestConcurrentExecutor: + """Test ConcurrentExecutor functionality.""" + + def test_initialization(self): + """Test executor initialization.""" + executor = ConcurrentExecutor( + max_concurrent=20, + rate_limit=10, + burst=5, + failure_threshold=0.5, + recovery_timeout=30, + ) + + assert executor.semaphore._value == 20 + assert executor.rate_limiter.rate == 10 + assert executor.rate_limiter.burst == 5 + assert executor.circuit_breaker.failure_threshold == 0.5 + assert executor.circuit_breaker.recovery_timeout == 30 + + @pytest.mark.asyncio + async def test_execute_batch_success(self): + """Test successful batch execution.""" + executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10) + fuzzer_state = 
FuzzerState() + + # Mock request factory + request_factory = Mock() + + # Mock process_prompt to return success + async def mock_process_prompt(rf, prompt, tokens, module, state): + return (10, False) # 10 tokens, not refused + + with patch( + "agentic_security.probe_actor.fuzzer.process_prompt", + side_effect=mock_process_prompt, + ): + prompts = ["prompt1", "prompt2", "prompt3"] + tokens, failures = await executor.execute_batch( + request_factory, prompts, "test_module", fuzzer_state + ) + + assert tokens == 30 # 3 prompts * 10 tokens + assert failures == 0 + + @pytest.mark.asyncio + async def test_execute_batch_with_failures(self): + """Test batch execution with some failures.""" + executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10) + fuzzer_state = FuzzerState() + + request_factory = Mock() + + # Mock process_prompt to alternate success/refusal (every 2nd call is refused; nothing raises) + call_count = [0] + + async def mock_process_prompt(rf, prompt, tokens, module, state): + call_count[0] += 1 + if call_count[0] % 2 == 0: + return (10, True) # Refused + return (10, False) # Success + + with patch( + "agentic_security.probe_actor.fuzzer.process_prompt", + side_effect=mock_process_prompt, + ): + prompts = ["p1", "p2", "p3", "p4"] + tokens, failures = await executor.execute_batch( + request_factory, prompts, "test_module", fuzzer_state + ) + + assert tokens == 40 # 4 prompts * 10 tokens + assert failures == 2 # 2 refused + + @pytest.mark.asyncio + async def test_execute_batch_respects_concurrency_limit(self): + """Test that concurrency limit is respected.""" + executor = ConcurrentExecutor(max_concurrent=2, rate_limit=100, burst=10) + fuzzer_state = FuzzerState() + + request_factory = Mock() + + # Track in-flight calls and their high-water mark (one-element lists so the nested mock can mutate them) + concurrent_count = [0] + max_concurrent = [0] + + async def mock_process_prompt(rf, prompt, tokens, module, state): + concurrent_count[0] += 1 + max_concurrent[0] = max(max_concurrent[0], concurrent_count[0]) + await asyncio.sleep(0.01) # Simulate work + 
concurrent_count[0] -= 1 + return (10, False) + + with patch( + "agentic_security.probe_actor.fuzzer.process_prompt", + side_effect=mock_process_prompt, + ): + prompts = ["p1", "p2", "p3", "p4", "p5"] + await executor.execute_batch( + request_factory, prompts, "test_module", fuzzer_state + ) + + # High-water mark of in-flight calls must not exceed max_concurrent=2 + assert max_concurrent[0] <= 2 + + @pytest.mark.asyncio + async def test_circuit_breaker_integration(self): + """Test that circuit breaker opens on failures.""" + executor = ConcurrentExecutor( + max_concurrent=10, + rate_limit=100, + burst=20, + failure_threshold=0.5, + recovery_timeout=1, + ) + fuzzer_state = FuzzerState() + request_factory = Mock() + + # Mock process_prompt to always raise (a hard error rather than a refusal) + async def mock_process_prompt_fail(rf, prompt, tokens, module, state): + raise Exception("Request failed") + + # First batch - all failures + with patch( + "agentic_security.probe_actor.fuzzer.process_prompt", + side_effect=mock_process_prompt_fail, + ): + prompts = ["p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10"] + tokens, failures = await executor.execute_batch( + request_factory, prompts, "test_module", fuzzer_state + ) + + # All 10 calls raised, so all 10 are counted as failures + assert failures == 10 + + # Circuit should be open now + assert executor.circuit_breaker.state == "open" + + @pytest.mark.asyncio + async def test_get_metrics(self): + """Test getting executor metrics.""" + executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10) + fuzzer_state = FuzzerState() + request_factory = Mock() + + async def mock_process_prompt(rf, prompt, tokens, module, state): + return (10, False) + + with patch( + "agentic_security.probe_actor.fuzzer.process_prompt", + side_effect=mock_process_prompt, + ): + await executor.execute_batch( + request_factory, ["p1", "p2"], "test_module", fuzzer_state + ) + + metrics = executor.get_metrics() + + assert "total_requests" in metrics + assert "success_rate" in metrics + assert "circuit_breaker_state" in 
metrics + assert "available_tokens" in metrics + assert metrics["total_requests"] == 2 + assert metrics["circuit_breaker_state"] == "closed" + + @pytest.mark.asyncio + async def test_rate_limiting_applied(self): + """Test that rate limiting is applied.""" + executor = ConcurrentExecutor(max_concurrent=10, rate_limit=5, burst=2) + fuzzer_state = FuzzerState() + request_factory = Mock() + + async def mock_process_prompt(rf, prompt, tokens, module, state): + return (10, False) + + import time + + with patch( + "agentic_security.probe_actor.fuzzer.process_prompt", + side_effect=mock_process_prompt, + ): + start = time.monotonic() + # 5 requests with rate=5/s and burst=2 + # First 2 immediate, next 3 should take ~0.6s total + await executor.execute_batch( + request_factory, ["p1", "p2", "p3", "p4", "p5"], "test_module", fuzzer_state + ) + elapsed = time.monotonic() - start + + # Expected ~0.6s (3 throttled requests / 5 per second); the floor asserted below is a looser 0.4s + assert elapsed >= 0.4 diff --git a/tests/executor/test_rate_limiter.py b/tests/executor/test_rate_limiter.py new file mode 100644 index 0000000..0a36607 --- /dev/null +++ b/tests/executor/test_rate_limiter.py @@ -0,0 +1,146 @@ +"""Tests for TokenBucketRateLimiter.""" + +import asyncio +import pytest +import time +from unittest.mock import patch +from agentic_security.executor.rate_limiter import TokenBucketRateLimiter + + +class TestTokenBucketRateLimiter: + """Test TokenBucketRateLimiter functionality.""" + + @pytest.mark.asyncio + async def test_initialization(self): + """Test rate limiter initialization.""" + limiter = TokenBucketRateLimiter(rate=10, burst=20) + + assert limiter.rate == 10 + assert limiter.burst == 20 + assert limiter.tokens == 20 # Starts full + + @pytest.mark.asyncio + async def test_acquire_with_available_tokens(self): + """Test acquiring tokens when they're available.""" + limiter = TokenBucketRateLimiter(rate=10, burst=5) + + start = time.monotonic() + await limiter.acquire() + elapsed = time.monotonic() - start + + 
# Should return immediately + assert elapsed < 0.1 + assert limiter.tokens < 5 # One token consumed + + @pytest.mark.asyncio + async def test_acquire_waits_when_no_tokens(self): + """Test that acquire waits when no tokens available.""" + limiter = TokenBucketRateLimiter(rate=10, burst=1) + + # Consume the initial token + await limiter.acquire() + + # Next acquire should wait + start = time.monotonic() + await limiter.acquire() + elapsed = time.monotonic() - start + + # Should wait approximately 1/rate seconds (0.1s for rate=10) + assert elapsed >= 0.08 # Allow some tolerance + + @pytest.mark.asyncio + async def test_rate_limiting(self): + """Test that rate limiting actually limits request rate.""" + limiter = TokenBucketRateLimiter(rate=10, burst=2) + + # Make 5 requests + start = time.monotonic() + for _ in range(5): + await limiter.acquire() + elapsed = time.monotonic() - start + + # With rate=10/s and burst=2: + # - First 2 requests are immediate (burst) + # - Next 3 requests require waiting: 3 * (1/10) = 0.3s + # Total should be around 0.3s + assert elapsed >= 0.25 # Allow some tolerance + assert elapsed < 0.5 # NOTE(review): wall-clock upper bound; may flake on a slow or heavily loaded runner + + @pytest.mark.asyncio + async def test_burst_capacity(self): + """Test that burst capacity allows immediate requests.""" + limiter = TokenBucketRateLimiter(rate=5, burst=10) + + # Make burst number of requests immediately + start = time.monotonic() + for _ in range(10): + await limiter.acquire() + elapsed = time.monotonic() - start + + # All 10 requests should be nearly immediate (using burst capacity) + assert elapsed < 0.2 + + @pytest.mark.asyncio + async def test_token_replenishment(self): + """Test that tokens are replenished over time.""" + limiter = TokenBucketRateLimiter(rate=10, burst=5) + + # Consume all tokens + for _ in range(5): + await limiter.acquire() + + assert limiter.tokens < 1 + + # Wait for tokens to replenish + await asyncio.sleep(0.3) # Should add 3 tokens at rate=10 + + # Should have tokens again (approximately 3) + available = 
limiter.get_available_tokens() + assert available >= 2.5 + assert available <= 3.5 + + @pytest.mark.asyncio + async def test_get_available_tokens(self): + """Test get_available_tokens method.""" + limiter = TokenBucketRateLimiter(rate=10, burst=5) + + # Initially full + assert limiter.get_available_tokens() == 5 + + # After consuming one + await limiter.acquire() + assert limiter.get_available_tokens() < 5 + + @pytest.mark.asyncio + async def test_concurrent_requests(self): + """Test rate limiter with concurrent requests.""" + limiter = TokenBucketRateLimiter(rate=10, burst=3) + + async def make_request(limiter): + await limiter.acquire() + return time.monotonic() + + # Make 5 concurrent requests + start = time.monotonic() + tasks = [make_request(limiter) for _ in range(5)] + timestamps = await asyncio.gather(*tasks) + total_elapsed = time.monotonic() - start + + # First 3 should be immediate (burst=3) + # Next 2 should wait + # Total time should be around 0.2s (2 * 1/10) + assert total_elapsed >= 0.15 + assert total_elapsed < 0.4 # NOTE(review): wall-clock upper bound; may flake on a slow or heavily loaded runner + + @pytest.mark.asyncio + async def test_max_burst_capacity(self): + """Test that tokens don't exceed burst capacity.""" + limiter = TokenBucketRateLimiter(rate=100, burst=5) + + # The limiter starts full; idle for longer than a full refill would take + await asyncio.sleep(0.2) # Would add 20 tokens, but capped at 5 + + # Check tokens don't exceed burst + available = limiter.get_available_tokens() + assert available <= 5 + assert available >= 4.5 # Close to full diff --git a/tests/probe_data/test_unified_loader.py b/tests/probe_data/test_unified_loader.py new file mode 100644 index 0000000..8d19ce9 --- /dev/null +++ b/tests/probe_data/test_unified_loader.py @@ -0,0 +1,360 @@ +"""Tests for unified dataset loader.""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from agentic_security.probe_data.unified_loader import ( + InputSourceConfig, + UnifiedDatasetLoader, +) +from agentic_security.probe_data.models import ProbeDataset + + +class 
TestInputSourceConfig: + """Test InputSourceConfig validation.""" + + def test_csv_source_config(self): + """Test CSV source configuration.""" + config = InputSourceConfig( + source_type="csv", + dataset_name="test_csv", + path="./test.csv", + prompt_column="prompt", + weight=1.5, + ) + assert config.source_type == "csv" + assert config.dataset_name == "test_csv" + assert config.path == "./test.csv" + assert config.weight == 1.5 + + def test_huggingface_source_config(self): + """Test HuggingFace source configuration.""" + config = InputSourceConfig( + source_type="huggingface", + dataset_name="test/dataset", + split="train", + max_samples=100, + ) + assert config.source_type == "huggingface" + assert config.split == "train" + assert config.max_samples == 100 + + def test_proxy_source_config(self): + """Test proxy source configuration.""" + config = InputSourceConfig( + source_type="proxy", + dataset_name="proxy_test", + ) + assert config.source_type == "proxy" + assert config.enabled is True # enabled defaults to True when omitted + + def test_disabled_source(self): + """Test disabled source configuration.""" + config = InputSourceConfig( + source_type="csv", + dataset_name="disabled_test", + enabled=False, + ) + assert config.enabled is False + + def test_weight_validation(self): + """Test that weight must be non-negative.""" + with pytest.raises(ValueError): + InputSourceConfig( + source_type="csv", + dataset_name="test", + weight=-1.0, + ) + + +class TestUnifiedDatasetLoader: + """Test UnifiedDatasetLoader functionality.""" + + @pytest.mark.asyncio + async def test_load_single_csv_source(self): + """Test loading a single CSV source.""" + config = InputSourceConfig( + source_type="csv", + dataset_name="test_csv", + path="test.csv", + ) + loader = UnifiedDatasetLoader([config]) + + # Dataset that the patched load_csv will return + mock_dataset = ProbeDataset( + dataset_name="test_csv", + prompts=["prompt1", "prompt2", "prompt3"], + tokens=10, + approx_cost=0.0, + metadata={} + ) + + with patch( + 
"agentic_security.probe_data.unified_loader.load_csv", + return_value=mock_dataset + ): + result = await loader.load_all() + + assert result.dataset_name == "unified" + assert len(result.prompts) == 3 + assert result.prompts == ["prompt1", "prompt2", "prompt3"] + + @pytest.mark.asyncio + async def test_load_single_huggingface_source(self): + """Test loading a single HuggingFace source.""" + config = InputSourceConfig( + source_type="huggingface", + dataset_name="test/dataset", + split="train", + ) + loader = UnifiedDatasetLoader([config]) + + # Dataset that the patched load_dataset_generic will return + mock_dataset = ProbeDataset( + dataset_name="test/dataset", + prompts=["hf_prompt1", "hf_prompt2"], + tokens=8, + approx_cost=0.0, + metadata={} + ) + + with patch( + "agentic_security.probe_data.unified_loader.load_dataset_generic", + return_value=mock_dataset + ): + result = await loader.load_all() + + assert result.dataset_name == "unified" + assert len(result.prompts) == 2 + + @pytest.mark.asyncio + async def test_merge_multiple_sources(self): + """Test merging multiple sources.""" + configs = [ + InputSourceConfig( + source_type="csv", + dataset_name="csv1", + path="test1.csv", + weight=1.0, + ), + InputSourceConfig( + source_type="csv", + dataset_name="csv2", + path="test2.csv", + weight=2.0, + ), + ] + loader = UnifiedDatasetLoader(configs) + + # One mock dataset per source; side_effect hands them out in call order + mock_dataset1 = ProbeDataset( + dataset_name="csv1", + prompts=["prompt1"], + tokens=5, + approx_cost=0.0, + metadata={} + ) + mock_dataset2 = ProbeDataset( + dataset_name="csv2", + prompts=["prompt2", "prompt3"], + tokens=10, + approx_cost=0.0, + metadata={} + ) + + with patch( + "agentic_security.probe_data.unified_loader.load_csv", + side_effect=[mock_dataset1, mock_dataset2] + ): + result = await loader.load_all() + + assert result.dataset_name == "unified" + # Weight 1.0 = include once, weight 2.0 = include twice + # csv1: 1 prompt * 1 = 1 + # csv2: 2 prompts * 2 = 4 + assert len(result.prompts) == 5 + assert 
"csv1" in result.metadata["sources"] + assert "csv2" in result.metadata["sources"] + + @pytest.mark.asyncio + async def test_handle_disabled_sources(self): + """Test that disabled sources are skipped.""" + configs = [ + InputSourceConfig( + source_type="csv", + dataset_name="enabled_csv", + path="enabled.csv", + enabled=True, + ), + InputSourceConfig( + source_type="csv", + dataset_name="disabled_csv", + path="disabled.csv", + enabled=False, + ), + ] + loader = UnifiedDatasetLoader(configs) + + mock_dataset = ProbeDataset( + dataset_name="enabled_csv", + prompts=["prompt1"], + tokens=5, + approx_cost=0.0, + metadata={} + ) + + with patch( + "agentic_security.probe_data.unified_loader.load_csv", + return_value=mock_dataset + ) as mock_load: + result = await loader.load_all() + + # Should only be called once (for enabled source) + assert mock_load.call_count == 1 + assert len(result.prompts) == 1 + + @pytest.mark.asyncio + async def test_max_samples_limit(self): + """Test that max_samples limits the number of prompts.""" + config = InputSourceConfig( + source_type="csv", + dataset_name="test_csv", + path="test.csv", + max_samples=2, + ) + loader = UnifiedDatasetLoader([config]) + + # Mock dataset with more prompts than max_samples + mock_dataset = ProbeDataset( + dataset_name="test_csv", + prompts=["prompt1", "prompt2", "prompt3", "prompt4", "prompt5"], + tokens=20, + approx_cost=0.0, + metadata={} + ) + + with patch( + "agentic_security.probe_data.unified_loader.load_csv", + return_value=mock_dataset + ): + result = await loader.load_all() + + # Should be limited to max_samples=2 prompts + assert len(result.prompts) == 2 + + @pytest.mark.asyncio + async def test_error_handling(self): + """Test that errors are handled gracefully.""" + config = InputSourceConfig( + source_type="csv", + dataset_name="error_csv", + path="nonexistent.csv", + ) + loader = UnifiedDatasetLoader([config]) + + with patch( + "agentic_security.probe_data.unified_loader.load_csv", + 
side_effect=Exception("File not found") + ): + result = await loader.load_all() + + # load_all must not raise; with its only source failing it returns the empty dataset + assert result.dataset_name == "unified_empty" + assert len(result.prompts) == 0 + + @pytest.mark.asyncio + async def test_proxy_source_placeholder(self): + """Test that proxy source returns empty dataset (not implemented in PoC).""" + config = InputSourceConfig( + source_type="proxy", + dataset_name="proxy_test", + ) + loader = UnifiedDatasetLoader([config]) + + result = await loader.load_all() + + # Proxy not implemented in PoC, should return empty + assert len(result.prompts) == 0 + + @pytest.mark.asyncio + async def test_weighted_sampling(self): + """Test weighted sampling behavior.""" + configs = [ + InputSourceConfig( + source_type="csv", + dataset_name="low_weight", + path="low.csv", + weight=1.0, + ), + InputSourceConfig( + source_type="csv", + dataset_name="high_weight", + path="high.csv", + weight=3.0, + ), + ] + loader = UnifiedDatasetLoader(configs) + + mock_dataset1 = ProbeDataset( + dataset_name="low_weight", + prompts=["a"], + tokens=1, + approx_cost=0.0, + metadata={} + ) + mock_dataset2 = ProbeDataset( + dataset_name="high_weight", + prompts=["b"], + tokens=1, + approx_cost=0.0, + metadata={} + ) + + with patch( + "agentic_security.probe_data.unified_loader.load_csv", + side_effect=[mock_dataset1, mock_dataset2] + ): + result = await loader.load_all() + + # Weight 1.0: 1 prompt * 1 = 1 + # Weight 3.0: 1 prompt * 3 = 3 + # Total: 4 prompts + assert len(result.prompts) == 4 + assert result.prompts.count("a") == 1 + assert result.prompts.count("b") == 3 + + @pytest.mark.asyncio + async def test_empty_configs_list(self): + """Test loading with empty configs list.""" + loader = UnifiedDatasetLoader([]) + result = await loader.load_all() + + assert result.dataset_name == "unified_empty" + assert len(result.prompts) == 0 + + @pytest.mark.asyncio + async def test_csv_with_url(self): + """Test CSV loading from URL.""" + config = 
InputSourceConfig( + source_type="csv", + dataset_name="remote_csv", + url="https://example.com/data.csv", + prompt_column="text", + ) + loader = UnifiedDatasetLoader([config]) + + mock_dataset = ProbeDataset( + dataset_name="remote_csv", + prompts=["remote_prompt"], + tokens=5, + approx_cost=0.0, + metadata={"source_type": "csv", "url": "https://example.com/data.csv"} + ) + + with patch( + "agentic_security.probe_data.unified_loader.load_dataset_generic", + return_value=mock_dataset + ): + result = await loader.load_all() + + assert len(result.prompts) == 1 + assert result.prompts[0] == "remote_prompt" # NOTE(review): this test patches load_dataset_generic, whereas the path-based CSV tests patch load_csv - confirm the loader really routes url= CSV sources that way