Compare commits

..

3 Commits

Author SHA1 Message Date
codebeaver-ai[bot] e752ebaeeb Adding codebeaver.yml 2025-03-03 18:41:09 +00:00
codebeaver-ai[bot] 2549194bd1 test: Add coverage improvement test for tests/test_http_spec.py 2025-03-03 18:41:08 +00:00
codebeaver-ai[bot] 4c580ea1b8 test: Add coverage improvement test for tests/test_app.py 2025-03-03 18:41:06 +00:00
91 changed files with 3871 additions and 7106 deletions
+2 -8
View File
@@ -1,9 +1,5 @@
name: Pre-Commit Checks
env:
POETRY_VERSION: "1.8.5"
on:
push:
branches: [main]
@@ -19,9 +15,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install poetry
run: pipx install poetry==$POETRY_VERSION
- name: Install pre-commit
run: poetry install
run: pip install pre-commit
- name: Run pre-commit
run: poetry run pre-commit run --all-files
run: pre-commit run --all-files
+1 -1
View File
@@ -9,7 +9,7 @@ on:
- 0.*
env:
POETRY_VERSION: "1.8.5"
POETRY_VERSION: "1.7.1"
jobs:
if_release:
+37
View File
@@ -0,0 +1,37 @@
name: Security Scan
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
schedule:
- cron: '0 0 * * 1' # Run weekly on Mondays
workflow_dispatch: # Allow manual trigger
jobs:
security_scan:
runs-on: ubuntu-latest
env:
API_KEY: PLACEHOLDER
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install agentic-security colorama tabulate tqdm python-multipart
- name: Run security scan
id: scan
run: |
agentic_security init
# agentic_security ci
+14
View File
@@ -0,0 +1,14 @@
name: PyCharm Python Security Scanner
on:
schedule:
- cron: "0 0 * * *"
jobs:
security_checks:
runs-on: ubuntu-latest
name: Execute the pycharm-security action
steps:
- uses: actions/checkout@v1
- name: PyCharm Python Security Scanner
uses: tonybaloney/pycharm-security@1.19.0
+1 -1
View File
@@ -7,7 +7,7 @@ on:
branches: [main]
env:
POETRY_VERSION: "1.8.5"
POETRY_VERSION: "1.7.1"
OPENAI_API_KEY: "sk-fake"
jobs:
-6
View File
@@ -17,9 +17,3 @@ inv/
scripts/
docx/
agentic_security.toml
/venv
*.csv
agentic_security/agents/operator_agno.py
.claude/
plan.md
auto_loop.sh
+6 -7
View File
@@ -9,7 +9,7 @@ repos:
args: [--py311-plus]
- repo: https://github.com/psf/black
rev: 25.11.0
rev: 23.11.0
hooks:
- id: black
language_version: python3.11
@@ -20,13 +20,12 @@ repos:
- id: flake8
language_version: python3.11
additional_dependencies: [flake8-docstrings]
exclude: '^(tests)/'
# - repo: https://github.com/PyCQA/isort
# rev: 7.0.0
# hooks:
# - id: isort
# args: [--profile, black]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
args: [--profile, black]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
+1 -8
View File
@@ -1,5 +1,5 @@
# Build stage
FROM python:3.11-slim AS builder
FROM python:3.11-slim as builder
WORKDIR /app
@@ -14,15 +14,8 @@ RUN poetry self add "poetry-plugin-export"
# Copy only dependency files to leverage Docker layer caching
COPY pyproject.toml poetry.lock ./
# update lock file to avoid failure
RUN poetry lock
# Install dependencies
RUN poetry export -f requirements.txt --without-hashes -o requirements.txt
# Install wheel (required to build packages like fire)
RUN pip install --upgrade pip setuptools wheel
RUN pip install --no-cache-dir -r requirements.txt
# Runtime stage
+8 -27
View File
@@ -21,7 +21,9 @@
<a href="https://pypi.org/project/agentic-security/">
<img alt="PyPI Version" src="https://img.shields.io/pypi/v/agentic-security?style=for-the-badge&logo=pypi&labelColor=000000&color=00CCFF" />
</a>
<a href="https://discord.gg/stw3DfZQ">
<img alt="Join Discord" src="https://img.shields.io/badge/Discord-Join%20Us-black?style=for-the-badge&logo=discord&labelColor=000000&color=DD55FF" />
</a>
</p>
@@ -81,7 +83,7 @@ agentic_security --port=PORT --host=HOST
## UI 🧙
<img width="100%" alt="booking-screen" src="https://raw.githubusercontent.com/msoedov/agentic_security/refs/heads/main/docs/images/demo.gif">
<img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1736433557/z0bsyzhsqlgcr3w4ovwp.gif">
## LLM kwargs
@@ -400,31 +402,16 @@ This setup ensures a continuous integration approach towards maintaining securit
The `Module` class is designed to manage prompt processing and interaction with external AI models and tools. It supports fetching, processing, and posting prompts asynchronously for model vulnerabilities. Check out [module.md](https://github.com/msoedov/agentic_security/blob/main/docs/module.md) for details.
## MCP server
```shell
pip install -U mcp
# From cloned directory
mcp install agentic_security/mcp/main.py
```
## Documentation
For more detailed information on how to use Agentic Security, including advanced features and customization options, please refer to the official documentation.
## Roadmap and Future Goals
Were just getting started! Heres whats on the horizon:
- **RL-Powered Attacks**: An attacker LLM trained with reinforcement learning to dynamically evolve jailbreaks and outsmart defenses.
- **Massive Dataset Expansion**: Scaling to 100,000+ prompts across text, image, and audio modalities—curated for real-world threats.
- **Daily Attack Updates**: Fresh attack vectors delivered daily, keeping your scans ahead of the curve.
- **Community Modules**: A plug-and-play ecosystem where you can share and deploy custom probes, datasets, and integrations.
- \[ \] Expand dataset variety
- \[ \] Introduce two new attack vectors
- \[ \] Develop initial attacker LLM
- \[ \] Complete integration of OWASP Top 10 classification
| Tool | Source | Integrated |
|-------------------------|-------------------------------------------------------------------------------|------------|
@@ -436,7 +423,6 @@ Were just getting started! Heres whats on the horizon:
Note: All dates are tentative and subject to change based on project progress and priorities.
## 👋 Contributing
Contributions to Agentic Security are welcome! If you'd like to contribute, please follow these steps:
@@ -453,9 +439,4 @@ Before contributing, please read the contributing guidelines.
Agentic Security is released under the Apache License v2.
## 🚫 No Cryptocurrency Affiliation
Agentic Security is focused solely on AI security and has no affiliation with cryptocurrency projects, blockchain technologies, or related initiatives. Our mission is to advance the safety and reliability of AI systems—no tokens, no coins, just code.
## Contact us
+2 -6
View File
@@ -1,7 +1,3 @@
from agentic_security.cache_config import ensure_cache_dir
from .lib import AgenticSecurity
ensure_cache_dir()
from .lib import SecurityScanner # noqa: E402
__all__ = ["SecurityScanner", "ensure_cache_dir"]
__all__ = ["AgenticSecurity"]
+4 -4
View File
@@ -5,7 +5,7 @@ import fire
import uvicorn
from agentic_security.app import app
from agentic_security.lib import SecurityScanner
from agentic_security.lib import AgenticSecurity
from agentic_security.misc.banner import init_banner
@@ -32,14 +32,14 @@ class CLI:
Run Agentic Security in CI mode.
"""
sys.path.append(os.path.dirname("."))
SecurityScanner().entrypoint()
AgenticSecurity().entrypoint()
def init(self, host: str = "0.0.0.0", port: int = 8718):
"""
Generate the default CI configuration file.
"""
sys.path.append(os.path.dirname("."))
SecurityScanner().generate_default_settings(host, port)
AgenticSecurity().generate_default_settings(host, port)
i = init
@@ -48,7 +48,7 @@ class CLI:
List all available security checks.
"""
sys.path.append(os.path.dirname("."))
SecurityScanner().list_checks()
AgenticSecurity().list_checks()
def main():
+6 -4
View File
@@ -1,4 +1,5 @@
import asyncio
import logging
import os
from typing import Any
@@ -9,11 +10,12 @@ from pydantic import BaseModel, ConfigDict, Field
# Assuming LLMSpec is defined elsewhere; placeholder import
from agentic_security.http_spec import LLMSpec
from agentic_security.logutils import logger
LLM_SPECS = [] # Populate with LLM spec strings if needed
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Define AgentSpecification model
@@ -246,9 +248,9 @@ async def run_crew():
os.environ["OPENAI_API_KEY"] = os.environ.get(
"DEEPSEEK_API_KEY", ""
) # CrewAI uses OPENAI_API_KEY
os.environ["OPENAI_MODEL_NAME"] = (
"deepseek:chat" # Specify DeepSeek model (adjust if needed)
)
os.environ[
"OPENAI_MODEL_NAME"
] = "deepseek:chat" # Specify DeepSeek model (adjust if needed)
if __name__ == "__main__":
asyncio.run(run_crew())
+5 -1
View File
@@ -1,4 +1,5 @@
import asyncio
import logging
from typing import Any
import httpx
@@ -7,10 +8,13 @@ from pydantic_ai import Agent, RunContext, Tool
# Assuming LLMSpec is defined elsewhere; placeholder import
from agentic_security.http_spec import LLMSpec
from agentic_security.logutils import logger
LLM_SPECS = [] # Populate this list with LLM spec strings if needed
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Define AgentSpecification model
class AgentSpecification(BaseModel):
-23
View File
@@ -1,23 +0,0 @@
"""Utilities to keep cache-to-disk storage in a writable, predictable location."""
from __future__ import annotations
import os
from pathlib import Path
def ensure_cache_dir(base_dir: Path | None = None) -> Path:
"""Ensure ``DISK_CACHE_DIR`` points to a writable directory and create it if needed."""
env_var = "DISK_CACHE_DIR"
configured_path = os.environ.get(env_var) or os.environ.get(
"AGENTIC_SECURITY_CACHE_DIR"
)
cache_dir = Path(
configured_path or base_dir or Path.cwd() / ".cache" / "agentic_security"
).expanduser()
cache_dir.mkdir(parents=True, exist_ok=True)
os.environ[env_var] = str(cache_dir)
return cache_dir
__all__ = ["ensure_cache_dir"]
+2 -10
View File
@@ -1,10 +1,9 @@
from functools import lru_cache
import tomli
from loguru import logger
from agentic_security.logutils import logger
SETTINGS_VERSION = 2
SETTINGS_VERSION = 1
@lru_cache(maxsize=1)
@@ -143,13 +142,6 @@ use_disk_cache = false
retry = 3
timeout_connect = 30
timeout_response = 90
[fuzzer]
max_prompt_lenght = 2048
budget_multiplier = 100000000
initial_optimizer_points = 25
min_failure_samples = 5
failure_rate_threshold = 0.5
""".replace(
"$HOST", host
)
+9 -22
View File
@@ -1,22 +1,13 @@
import os
from asyncio import Event, Queue
from typing import TypedDict
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse
from agentic_security.http_spec import LLMSpec
class CurrentRun(TypedDict):
id: int | None
spec: LLMSpec | None
tools_inbox: Queue = Queue()
stop_event: Event = Event()
current_run: CurrentRun = {"spec": None, "id": None}
_secrets: dict[str, str] = {}
current_run: str = {"spec": "", "id": ""}
_secrets = {}
def create_app() -> FastAPI:
@@ -35,34 +26,30 @@ def get_stop_event() -> Event:
return stop_event
def get_current_run() -> CurrentRun:
def get_current_run() -> str:
"""Get the current run id."""
return current_run
def set_current_run(spec: LLMSpec) -> CurrentRun:
"""Set the current run metadata based on a spec instance."""
def set_current_run(spec):
"""Set the current run id."""
current_run["id"] = hash(id(spec))
current_run["spec"] = spec
return current_run
def get_secrets() -> dict[str, str]:
def get_secrets():
return _secrets
def set_secrets(secrets: dict[str, str]) -> dict[str, str]:
def set_secrets(secrets):
_secrets.update(secrets)
expand_secrets(_secrets)
return _secrets
def expand_secrets(secrets: dict[str, str]) -> None:
def expand_secrets(secrets):
for key in secrets:
val = secrets[key]
if val.startswith("$"):
env_value = os.getenv(val.strip("$"))
if env_value is not None:
secrets[key] = env_value
else:
secrets[key] = None
secrets[key] = os.getenv(val.strip("$"))
+23 -2
View File
@@ -1,5 +1,26 @@
from agentic_security.logutils import set_log_level_to_info
from logging import config
def setup_logging():
return set_log_level_to_info()
config.dictConfig(
{
"version": 1,
"disable_existing_loggers": True,
"handlers": {
"console": {
"class": "logging.StreamHandler",
},
},
"root": {
"handlers": ["console"],
"level": "INFO",
},
"loggers": {
"uvicorn.access": {
"level": "ERROR", # Set higher log level to suppress info logs globally
"handlers": ["console"],
"propagate": False,
}
},
}
)
-12
View File
@@ -1,12 +0,0 @@
"""Advanced concurrent execution package for security scanning."""
from agentic_security.executor.rate_limiter import TokenBucketRateLimiter
from agentic_security.executor.circuit_breaker import CircuitBreaker
from agentic_security.executor.concurrent import ConcurrentExecutor, ExecutorMetrics
__all__ = [
"TokenBucketRateLimiter",
"CircuitBreaker",
"ConcurrentExecutor",
"ExecutorMetrics",
]
@@ -1,109 +0,0 @@
"""Circuit breaker pattern for fault tolerance."""
import time
from typing import Literal
CircuitState = Literal["closed", "open", "half_open"]
class CircuitBreaker:
"""Circuit breaker to prevent cascading failures.
Implements the circuit breaker pattern with three states:
- closed: Normal operation, requests pass through
- open: Failure threshold exceeded, requests fail fast
- half_open: Recovery attempt, limited requests allowed
Example:
>>> breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
>>> if breaker.is_open():
... raise Exception("Circuit breaker is open")
>>> try:
... result = make_request()
... breaker.record_success()
>>> except Exception:
... breaker.record_failure()
"""
def __init__(self, failure_threshold: float = 0.5, recovery_timeout: int = 30):
"""Initialize circuit breaker.
Args:
failure_threshold: Failure rate (0.0-1.0) that triggers open state
recovery_timeout: Seconds to wait before attempting recovery
"""
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.failures = 0
self.successes = 0
self.state: CircuitState = "closed"
self.last_failure_time: float | None = None
def record_success(self):
"""Record a successful request."""
self.successes += 1
# If in half_open state and we have enough successes, close the circuit
if self.state == "half_open" and self.successes >= 3:
self.state = "closed"
self.failures = 0
self.successes = 0
def record_failure(self):
"""Record a failed request."""
self.failures += 1
self.last_failure_time = time.monotonic()
total = self.failures + self.successes
# Need minimum sample size before opening circuit
if total >= 10:
failure_rate = self.failures / total
if failure_rate >= self.failure_threshold:
self.state = "open"
def is_open(self) -> bool:
"""Check if circuit breaker is open.
Returns:
bool: True if circuit is open and requests should be blocked
"""
if self.state == "open":
# Check if we should attempt recovery
if self.last_failure_time is not None:
if time.monotonic() - self.last_failure_time > self.recovery_timeout:
self.state = "half_open"
# Reset counters for half-open state
self.failures = 0
self.successes = 0
return False
return True
return False
def get_state(self) -> CircuitState:
"""Get current circuit breaker state.
Returns:
CircuitState: Current state (closed, open, or half_open)
"""
return self.state
def get_failure_rate(self) -> float:
"""Get current failure rate.
Returns:
float: Failure rate (0.0-1.0), or 0.0 if no requests recorded
"""
total = self.failures + self.successes
if total == 0:
return 0.0
return self.failures / total
def reset(self):
"""Reset circuit breaker to initial state."""
self.failures = 0
self.successes = 0
self.state = "closed"
self.last_failure_time = None
-236
View File
@@ -1,236 +0,0 @@
"""Concurrent executor with rate limiting and circuit breaking."""
import asyncio
import time
from typing import Any
from agentic_security.executor.rate_limiter import TokenBucketRateLimiter
from agentic_security.executor.circuit_breaker import CircuitBreaker
from agentic_security.logutils import logger
from agentic_security.probe_actor.state import FuzzerState
class ExecutorMetrics:
"""Track executor performance metrics."""
def __init__(self):
"""Initialize metrics tracking."""
self.successful_requests = 0
self.failed_requests = 0
self.total_latency = 0.0
self.latencies: list[float] = []
def record_success(self, latency: float):
"""Record a successful request.
Args:
latency: Request latency in seconds
"""
self.successful_requests += 1
self.total_latency += latency
self.latencies.append(latency)
def record_failure(self):
"""Record a failed request."""
self.failed_requests += 1
def get_stats(self) -> dict[str, Any]:
"""Get current statistics.
Returns:
dict: Statistics including total requests, success rate, latency metrics
"""
total_requests = self.successful_requests + self.failed_requests
if total_requests == 0:
return {
"total_requests": 0,
"success_rate": 0.0,
"avg_latency_ms": 0.0,
"p95_latency_ms": 0.0,
}
success_rate = self.successful_requests / total_requests
avg_latency_ms = (
(self.total_latency / self.successful_requests * 1000)
if self.successful_requests > 0
else 0.0
)
# Calculate p95 latency
if self.latencies:
sorted_latencies = sorted(self.latencies)
p95_index = int(len(sorted_latencies) * 0.95)
p95_latency_ms = (
sorted_latencies[p95_index] * 1000
if p95_index < len(sorted_latencies)
else 0.0
)
else:
p95_latency_ms = 0.0
return {
"total_requests": total_requests,
"successful_requests": self.successful_requests,
"failed_requests": self.failed_requests,
"success_rate": success_rate,
"avg_latency_ms": avg_latency_ms,
"p95_latency_ms": p95_latency_ms,
}
class ConcurrentExecutor:
"""Enhanced concurrent executor with rate limiting and circuit breaking.
Provides advanced concurrency control for security scanning with:
- Token bucket rate limiting
- Circuit breaker for fault tolerance
- Metrics collection
- Semaphore-based concurrency limits
Example:
>>> executor = ConcurrentExecutor(max_concurrent=20, rate_limit=10, burst=5)
>>> tokens, failures = await executor.execute_batch(
... request_factory, prompts, "module_name", fuzzer_state
... )
>>> print(executor.metrics.get_stats())
"""
def __init__(
self,
max_concurrent: int = 50,
rate_limit: float = 100,
burst: int = 20,
failure_threshold: float = 0.5,
recovery_timeout: int = 30,
):
"""Initialize concurrent executor.
Args:
max_concurrent: Maximum number of concurrent requests
rate_limit: Requests per second limit
burst: Maximum burst size for rate limiter
failure_threshold: Failure rate that triggers circuit breaker
recovery_timeout: Seconds before attempting circuit recovery
"""
self.semaphore = asyncio.Semaphore(max_concurrent)
self.rate_limiter = TokenBucketRateLimiter(rate_limit, burst)
self.circuit_breaker = CircuitBreaker(failure_threshold, recovery_timeout)
self.metrics = ExecutorMetrics()
logger.info(
f"ConcurrentExecutor initialized: max_concurrent={max_concurrent}, "
f"rate_limit={rate_limit}/s, burst={burst}"
)
async def execute_batch(
self,
request_factory,
prompts: list[str],
module_name: str,
fuzzer_state: FuzzerState,
) -> tuple[int, int]:
"""Execute a batch of prompts with rate limiting and circuit breaking.
This is compatible with the existing process_prompt_batch signature.
Args:
request_factory: Request factory with fn() method
prompts: List of prompts to process
module_name: Name of the module being scanned
fuzzer_state: State tracking object
Returns:
tuple[int, int]: (total_tokens, failures)
"""
tasks = [
self._execute_single(request_factory, prompt, module_name, fuzzer_state)
for prompt in prompts
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Aggregate results
total_tokens = 0
failures = 0
for result in results:
if isinstance(result, Exception):
failures += 1
logger.error(f"Task failed with exception: {result}")
else:
tokens, refused = result
total_tokens += tokens
if refused:
failures += 1
return total_tokens, failures
async def _execute_single(
self,
request_factory,
prompt: str,
module_name: str,
fuzzer_state: FuzzerState,
) -> tuple[int, bool]:
"""Execute a single prompt with rate limiting and circuit breaking.
Args:
request_factory: Request factory with fn() method
prompt: Prompt to process
module_name: Name of the module being scanned
fuzzer_state: State tracking object
Returns:
tuple[int, bool]: (tokens, refused)
Raises:
Exception: If circuit breaker is open
"""
# Rate limiting
await self.rate_limiter.acquire()
# Circuit breaker check
if self.circuit_breaker.is_open():
self.metrics.record_failure()
raise Exception("Circuit breaker is open - too many failures")
# Concurrency control
async with self.semaphore:
start_time = time.monotonic()
try:
# Import here to avoid circular dependency
from agentic_security.probe_actor.fuzzer import process_prompt
tokens = 0 # Initial token count for this prompt
result = await process_prompt(
request_factory, prompt, tokens, module_name, fuzzer_state
)
# Record success
self.circuit_breaker.record_success()
latency = time.monotonic() - start_time
self.metrics.record_success(latency)
return result
except Exception as e:
# Record failure
self.circuit_breaker.record_failure()
self.metrics.record_failure()
logger.error(f"Error executing prompt: {e}")
raise
def get_metrics(self) -> dict[str, Any]:
"""Get current executor metrics.
Returns:
dict: Metrics including request stats, latency, and circuit breaker state
"""
stats = self.metrics.get_stats()
stats["circuit_breaker_state"] = self.circuit_breaker.get_state()
stats["circuit_breaker_failure_rate"] = self.circuit_breaker.get_failure_rate()
stats["available_tokens"] = self.rate_limiter.get_available_tokens()
return stats
-63
View File
@@ -1,63 +0,0 @@
"""Token bucket rate limiter for controlling request rate."""
import asyncio
import time
class TokenBucketRateLimiter:
"""Token bucket rate limiter with configurable rate and burst capacity.
This implements the token bucket algorithm where tokens are added at a fixed
rate and consumed for each request. Supports bursting up to the bucket capacity.
Example:
>>> limiter = TokenBucketRateLimiter(rate=10, burst=20)
>>> await limiter.acquire() # Will wait if no tokens available
"""
def __init__(self, rate: float, burst: int):
"""Initialize rate limiter.
Args:
rate: Tokens added per second (requests/sec)
burst: Maximum bucket capacity (max concurrent burst)
"""
self.rate = rate
self.burst = burst
self.tokens = float(burst)
self.last_update = time.monotonic()
self._lock = asyncio.Lock()
async def acquire(self):
"""Acquire a token, waiting if necessary.
This method will block until a token is available.
"""
async with self._lock:
now = time.monotonic()
elapsed = now - self.last_update
# Add tokens based on elapsed time
self.tokens = min(self.burst, self.tokens + elapsed * self.rate)
self.last_update = now
if self.tokens >= 1:
# Token available, consume it
self.tokens -= 1
return
# Need to wait for next token
wait_time = (1 - self.tokens) / self.rate
await asyncio.sleep(wait_time)
self.tokens = 0
self.last_update = time.monotonic()
def get_available_tokens(self) -> float:
"""Get current number of available tokens (non-blocking).
Returns:
float: Number of tokens currently available
"""
now = time.monotonic()
elapsed = now - self.last_update
return min(self.burst, self.tokens + elapsed * self.rate)
+9 -41
View File
@@ -1,6 +1,5 @@
import base64
from enum import Enum
from urllib.parse import urlparse
import httpx
from pydantic import BaseModel
@@ -69,9 +68,7 @@ class LLMSpec(BaseModel):
return response
def validate(
self, prompt: str, encoded_image: str, encoded_audio: str, files: dict | None
) -> None:
def validate(self, prompt, encoded_image, encoded_audio, files) -> None:
if self.has_files and not files:
raise ValueError("Files are required for this request.")
@@ -82,11 +79,7 @@ class LLMSpec(BaseModel):
raise ValueError("Audio is required for this request.")
async def probe(
self,
prompt: str,
encoded_image: str = "",
encoded_audio: str = "",
files: dict | None = None,
self, prompt: str, encoded_image: str = "", encoded_audio: str = "", files={}
) -> httpx.Response:
"""Sends an HTTP request using the `httpx` library.
@@ -161,52 +154,27 @@ def parse_http_spec(http_spec: str) -> LLMSpec:
secrets = get_secrets()
# Split the spec by lines
lines = http_spec.strip("\n").splitlines()
if not lines:
raise InvalidHTTPSpecError("HTTP spec is empty.")
lines = http_spec.strip().split("\n")
# Extract the method and URL from the first line
request_line_parts = lines[0].split()
if len(request_line_parts) < 2:
raise InvalidHTTPSpecError(
"First line of HTTP spec must include the method and URL."
)
method, url = request_line_parts[0], request_line_parts[1]
# Check url validity
valid_url = urlparse(url)
# if missing the correct formatting ://, urlparse.netloc will be empty
if valid_url.scheme not in ("http", "https") or not valid_url.netloc:
raise InvalidHTTPSpecError(
f"Invalid URL: {url}. Ensure it starts with 'http://' or 'https://'"
)
method, url = lines[0].split(" ")[0:2]
# Initialize headers and body
headers = {}
body_lines: list[str] = []
body = ""
# Iterate over the remaining lines
reading_headers = True
for line in lines[1:]:
if line.strip() == "":
if reading_headers:
reading_headers = False
continue
body_lines.append("")
if line == "":
reading_headers = False
continue
if reading_headers:
if ":" not in line:
raise InvalidHTTPSpecError(f"Invalid header line: '{line}'")
key, value = line.split(":", maxsplit=1)
key = key.strip()
value = value.strip()
if not key:
raise InvalidHTTPSpecError("Header name cannot be empty.")
key, value = line.split(": ")
headers[key] = value
else:
body_lines.append(line)
body = "\n".join(body_lines)
body += line
has_files = "multipart/form-data" in headers.get("Content-Type", "")
has_image = "<<BASE64_IMAGE>>" in body
has_audio = "<<BASE64_AUDIO>>" in body
+4 -2
View File
@@ -5,6 +5,8 @@ from typing import Protocol
class IntegrationProto(Protocol):
def __init__(
self, prompt_groups: list, tools_inbox: asyncio.Queue, opts: dict = {}
): ...
):
...
async def apply(self) -> list: ...
async def apply(self) -> list:
...
+10 -16
View File
@@ -1,16 +1,15 @@
import asyncio
import copy
import json
from datetime import datetime
import colorama
import tqdm.asyncio
from loguru import logger
from rich.console import Console
from rich.table import Table
from tabulate import tabulate
from agentic_security.config import SettingsMixin # Importing the configuration mixin
from agentic_security.logutils import logger
from agentic_security.primitives import Scan
from agentic_security.probe_data import REGISTRY
from agentic_security.routes.scan import streaming_response_generator
@@ -24,20 +23,18 @@ YELLOW = colorama.Fore.YELLOW
BLUE = colorama.Fore.BLUE
class SecurityScanner(SettingsMixin):
class AgenticSecurity(SettingsMixin):
@classmethod
async def async_scan(
cls,
llmSpec: str,
maxBudget: int,
datasets: list[dict] | None,
datasets: list[dict],
max_th: float,
optimize: bool = False,
enableMultiStepAttack: bool = False,
probe_datasets: list[dict] | None = None,
probe_datasets: list[dict] = [],
):
datasets = copy.deepcopy(datasets) if datasets is not None else []
probe_datasets = copy.deepcopy(probe_datasets or [])
start_time = datetime.now()
total_modules = len(datasets)
completed_modules = 0
@@ -173,18 +170,15 @@ class SecurityScanner(SettingsMixin):
cls,
llmSpec: str,
maxBudget: int = 1_000_000,
datasets: list[dict] | None = None,
datasets: list[dict] = REGISTRY,
max_th: float = 0.3,
optimize: bool = False,
enableMultiStepAttack: bool = False,
probe_datasets: list[dict] | None = None,
only: list[str] | None = None,
probe_datasets: list[dict] = [],
only: list[str] = [],
):
datasets = copy.deepcopy(datasets or REGISTRY)
probe_datasets = copy.deepcopy(probe_datasets or [])
only_set = set(only) if only else None
if only_set is not None:
datasets = [d for d in datasets if d.get("dataset_name") in only_set]
if only:
datasets = [d for d in datasets if d["dataset_name"] in only]
for d in datasets:
d["selected"] = True
return asyncio.run(
@@ -208,7 +202,7 @@ class SecurityScanner(SettingsMixin):
self.load_config(self.default_path)
logger.info("Configuration loaded successfully.")
print(self.config)
datasets = list(self.get_config_value("modules", {}).values())
datasets = list(self.get_config_value("modules").values())
for d in datasets:
d["selected"] = True
self.scan(
-146
View File
@@ -1,146 +0,0 @@
# import sys
# from loguru import logger
# # Define custom colors
# BLUE = "#89CFF0"
# BROWN = "#8B4513" # Brown for DEBUG
# # Define custom log level colors
# logger.level("DEBUG", color=f"<fg {BROWN}>")
# logger.level("INFO", color=f"<fg {BLUE}>")
# # Define custom log format with aligned messages and colored levels
# LOG_FORMAT = (
# "<level>{level:<8}</level> " # Properly formatted and colored log level
# "<level>{message:<100}</level> " # Left-aligned message for readability
# "<cyan>{file.name}</cyan>:<cyan>{line}</cyan>" # File name and line number in cyan
# )
# # Remove default handlers and add a new one with custom formatting
# logger.remove()
# logger.add(sys.stdout, format=LOG_FORMAT, level="DEBUG", colorize=True)
import logging
import logging.config
import time
from collections.abc import Callable, Coroutine
from functools import wraps
from os import getenv
from typing import Any, ParamSpec, TypeVar
LOGGER_NAME = None
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"rich": {"format": "%(message)s", "datefmt": "[%X]"},
},
"handlers": {
"rich": {
"class": "rich.logging.RichHandler",
"level": "INFO",
"formatter": "rich",
"show_time": False,
"rich_tracebacks": False,
"show_path": lambda: True if getenv("API_RUNTIME") == "dev" else False,
"tracebacks_show_locals": False,
},
},
"loggers": {
"": { # Root logger configuration
"level": "INFO",
"handlers": ["rich"],
"propagate": True,
},
"httpx": { # Disable httpx logging
"level": "WARNING", # Suppress DEBUG and INFO messages from httpx
"handlers": [],
"propagate": False,
},
"uvicorn.access": { # Disable uvicorn.access logging
"level": "WARNING", # Suppress DEBUG and INFO messages from uvicorn.access
"handlers": [],
"propagate": False,
},
},
}
def configure_logging():
# Apply the dictionary configuration
logging.config.dictConfig(LOGGING_CONFIG)
# Get and return the logger
logger = logging.getLogger(LOGGER_NAME)
return logger
logger: logging.Logger = configure_logging()
def set_log_level_to_debug():
logger = logging.getLogger(LOGGER_NAME)
logger.setLevel(logging.DEBUG)
# Update handler level as well
for handler in logger.handlers:
handler.setLevel(logging.DEBUG)
def set_log_level_to_info():
logger = logging.getLogger(LOGGER_NAME)
logger.setLevel(logging.INFO)
# Update handler level as well
for handler in logger.handlers:
handler.setLevel(logging.INFO)
# Set initial log level
set_log_level_to_info()
# Define generic type variables for return type and parameters
R = TypeVar("R")
P = ParamSpec("P")
def time_execution_sync(
additional_text: str = "",
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(func: Callable[P, R]) -> Callable[P, R]:
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.time()
result = func(*args, **kwargs)
execution_time = time.time() - start_time
logger.debug(
f"{additional_text} Execution time: {execution_time:.2f} seconds"
)
return result
return wrapper
return decorator
def time_execution_async(
additional_text: str = "",
) -> Callable[
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
]:
def decorator(
func: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.time()
result = await func(*args, **kwargs)
execution_time = time.time() - start_time
logger.debug(
f"{additional_text} Execution time: {execution_time:.2f} seconds"
)
return result
return wrapper
return decorator
View File
-68
View File
@@ -1,68 +0,0 @@
import asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from agentic_security.logutils import logger
# Create server parameters for stdio connection
server_params = StdioServerParameters(
command="python", # Executable
args=["agentic_security/mcp/main.py"], # Your server script
env=None, # Optional environment variables
)
async def run() -> None:
try:
logger.info(
"Starting stdio client session with server parameters: %s", server_params
)
async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection --> connection does not work
logger.info("Initializing client session...")
await session.initialize()
# List available prompts, resources, and tools --> no avalialbe tools
logger.info("Listing available prompts...")
prompts = await session.list_prompts()
logger.info(f"Available prompts: {prompts}")
logger.info("Listing available resources...")
resources = await session.list_resources()
logger.info(f"Available resources: {resources}")
logger.info("Listing available tools...")
tools = await session.list_tools()
logger.info(f"Available tools: {tools}")
# Call the echo tool --> echo tool issue
logger.info("Calling echo_tool with message...")
echo_result = await session.call_tool(
"echo_tool", arguments={"message": "Hello from client!"}
)
logger.info(f"Tool result: {echo_result}")
# # Read the echo resource
# echo_content, mime_type = await session.read_resource(
# "echo://Hello_resource"
# )
# logger.info(f"Resource content: {echo_content}")
# logger.info(f"Resource MIME type: {mime_type}")
# # Get and use the echo prompt
# prompt_result = await session.get_prompt(
# "echo_prompt", arguments={"message": "Hello prompt!"}
# )
# logger.info(f"Prompt result: {prompt_result}")
logger.info("Client operations completed successfully.")
return prompts, resources, tools
except Exception as e:
logger.error(f"An error occurred during client operations: {e}", exc_info=True)
raise
if __name__ == "__main__":
asyncio.run(run())
-108
View File
@@ -1,108 +0,0 @@
import httpx
from mcp.server.fastmcp import FastMCP
# Initialize MCP server
mcp = FastMCP(
name="Agentic Security MCP Server",
dependencies=["httpx"],
)
# FastAPI Server Configuration
AGENTIC_SECURITY = "http://0.0.0.0:8718"
@mcp.tool()
async def verify_llm(spec: str) -> dict:
"""
Verify an LLM model specification using the FastAPI server
Returns:
dict: containing the verification result form the FastAPI server
Args: spect(str): The specification of the LLM model to verify.
"""
url = f"{AGENTIC_SECURITY}/verify"
async with httpx.AsyncClient() as client:
response = await client.post(url, json={"spec": spec})
return response.json()
@mcp.tool()
async def start_scan(
llmSpec: str,
maxBudget: int,
optimize: bool = False,
enableMultiStepAttack: bool = False,
) -> dict:
"""
Start an LLM security scan via the FastAPI server.
Returns:
dict: The scan initiation result from the FastAPI server.
Args:
llmSpec (str): The specification of the LLM model.
maxBudget (int): The maximum budget for the scan.
optimize (bool, optional): Whether to enable optimization during scanning. Defaults to False.
enableMultiStepAttack (bool, optional): Whether to enable multi-step attack
"""
url = f"{AGENTIC_SECURITY}/scan"
payload = {
"llmSpec": llmSpec,
"maxBudget": maxBudget,
"datasets": [],
"optimize": optimize,
"enableMultiStepAttack": enableMultiStepAttack,
"probe_datasets": [],
"secrets": {},
}
async with httpx.AsyncClient() as client:
response = await client.post(url, json=payload)
return response.json()
@mcp.tool()
async def stop_scan() -> dict:
"""Stop an ongoing scan via the FastAPI server.
Returns:
dict: The confirmation from the FastAPI server that the scan has been stopped.
"""
url = f"{AGENTIC_SECURITY}/stop"
async with httpx.AsyncClient() as client:
response = await client.post(url)
return response.json()
@mcp.tool()
async def get_data_config() -> list:
"""
Retrieve data configuration from the FastAPI server.
Returns:
list: The response from the FastAPI server, confirming the scan has been stopped.
"""
url = f"{AGENTIC_SECURITY}/v1/data-config"
async with httpx.AsyncClient() as client:
response = await client.get(url)
return response.json()
@mcp.tool()
async def get_spec_templates() -> list:
"""
Retrieve data configuration from the FastAPI server.
Returns:
list: The LLM specification templates from the FastAPI server.
"""
url = f"{AGENTIC_SECURITY}/v1/llm-specs"
async with httpx.AsyncClient() as client:
response = await client.get(url)
return response.json()
# Run the MCP server
if __name__ == "__main__":
mcp.run()
+1 -2
View File
@@ -1,8 +1,7 @@
from fastapi import Request
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware
from agentic_security.logutils import logger
class LogNon200ResponsesMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
+9 -14
View File
@@ -8,14 +8,14 @@ except ImportError:
def generate_banner(
title: str = "Agentic Security",
font: str = "slant",
version: str = "v2.1.0",
tagline: str = "Proactive Threat Detection & Automated Security Protocols",
author: str = "Developed by: [Security Team]",
website: str = "Website: https://github.com/msoedov/agentic_security",
warning: str | None = "", # Using Optional for warning since it might be None
) -> str:
title="Agentic Security",
font="slant",
version="v2.1.0",
tagline="Proactive Threat Detection & Automated Security Protocols",
author="Developed by: [Security Team]",
website="Website: https://github.com/msoedov/agentic_security",
warning="",
):
"""Generate a visually enhanced banner with dynamic width and borders."""
# Define the text elements
@@ -80,13 +80,8 @@ def generate_banner(
def init_banner():
return
ver = version("agentic_security")
try:
print(generate_banner(version=ver))
except Exception:
# UnicodeEncodeError with codec on some systems
pass
print(generate_banner(version=ver))
if __name__ == "__main__":
+1 -2
View File
@@ -1,6 +1,5 @@
# noqa
from agentic_security.primitives.models import CompletionRequest # noqa
from agentic_security.primitives.models import ( # noqa
CompletionRequest,
FileProbeResponse,
LLMInfo,
Message,
+3 -3
View File
@@ -18,13 +18,13 @@ class LLMInfo(BaseModel):
class Scan(BaseModel):
llmSpec: str
maxBudget: int
datasets: list[dict] = Field(default_factory=list)
datasets: list[dict] = []
optimize: bool = False
enableMultiStepAttack: bool = False
# MSJ only mode
probe_datasets: list[dict] = Field(default_factory=list)
probe_datasets: list[dict] = []
# Set and managed by the backend
secrets: dict[str, str] = Field(default_factory=dict)
secrets: dict[str, str] = {}
def with_secrets(self, secrets) -> "Scan":
match secrets:
+279 -504
View File
@@ -3,44 +3,32 @@ import random
import time
from collections.abc import AsyncGenerator
from json import JSONDecodeError
from typing import Any
import httpx
import pandas as pd
from loguru import logger
from skopt import Optimizer
from skopt.space import Real
from agentic_security.config import settings_var
from agentic_security.http_spec import Modality
from agentic_security.logutils import logger
from agentic_security.primitives import Scan, ScanResult
from agentic_security.probe_actor.cost_module import calculate_cost
from agentic_security.probe_actor.refusal import refusal_heuristic
from agentic_security.probe_actor.state import FuzzerState
from agentic_security.probe_data import audio_generator, image_generator, msj_data
from agentic_security.probe_data.data import prepare_prompts
MAX_PROMPT_LENGTH = settings_var("fuzzer.max_prompt_lenght", 2048)
BUDGET_MULTIPLIER = settings_var("fuzzer.budget_multiplier", 100000000)
INITIAL_OPTIMIZER_POINTS = settings_var("fuzzer.initial_optimizer_points", 25)
MIN_FAILURE_SAMPLES = settings_var("fuzzer.min_failure_samples", 5)
FAILURE_RATE_THRESHOLD = settings_var("fuzzer.failure_rate_threshold", 0.5)
# TODO: full log file
MAX_PROMPT_LENGTH = 2048
BUDGET_MULTIPLIER = 100_000_000
INITIAL_OPTIMIZER_POINTS = 25
MIN_FAILURE_SAMPLES = 5
FAILURE_RATE_THRESHOLD = 0.5
async def generate_prompts(
prompts: list[str] | AsyncGenerator,
) -> AsyncGenerator[str, None]:
"""
Asynchronously generates and yields individual prompts.
If the input is a list of strings, the function sequentially yields each string.
If the input is an asynchronous generator, it forwards each generated prompt.
Args:
prompts (list[str] | AsyncGenerator): A list of strings or an asynchronous generator of prompts.
Yields:
str: An individual prompt from the list or the asynchronous generator.
"""
if isinstance(prompts, list):
for prompt in prompts:
yield prompt
@@ -49,21 +37,7 @@ async def generate_prompts(
yield prompt
def get_modality_adapter(llm_spec):
"""
Returns the appropriate request adapter based on the modality of the LLM specification.
Depending on the modality of `llm_spec`, the function selects the corresponding request adapter.
If the modality is IMAGE or AUDIO, it returns an adapter for handling the respective type.
If the modality is TEXT or an unrecognized type, it returns `llm_spec` as is.
Args:
llm_spec: An object containing modality information for the LLM.
Returns:
RequestAdapter | llm_spec: An instance of the appropriate request adapter
or the original `llm_spec` if no adaptation is needed.
"""
def multi_modality_spec(llm_spec):
match llm_spec.modality:
case Modality.IMAGE:
return image_generator.RequestAdapter(llm_spec)
@@ -76,71 +50,40 @@ def get_modality_adapter(llm_spec):
async def process_prompt(
request_factory,
prompt: str,
tokens: int,
module_name: str,
fuzzer_state: FuzzerState,
request_factory, prompt, tokens, module_name, refusals, errors, outputs
) -> tuple[int, bool]:
"""
Processes a single prompt using the provided request factory and updates tracking lists.
This function sends the given `prompt` to the `request_factory`, checks for errors, and updates
the `tokens`, `refusals`, `errors`, and `outputs` lists accordingly. If the request fails or
the response indicates a refusal, the function records the issue and returns the updated token count
along with a boolean indicating whether the prompt was refused.
Args:
request_factory: An object with a `fn` method used to send the prompt.
prompt (str): The input prompt to be processed.
tokens (int): The current token count, which will be updated.
module_name (str): The name of the module handling the request.
fuzzer_state: State tracking object for the fuzzer
Returns:
tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused.
Process a single prompt and update the token count and failure status.
"""
try:
response = await request_factory.fn(prompt=prompt)
# Handle HTTP errors
if response.status_code == 422:
logger.error(f"Invalid prompt: {prompt}, error=422")
fuzzer_state.add_error(module_name, prompt, 422, "Invalid prompt")
errors.append((module_name, prompt, 422, "Invalid prompt"))
return tokens, True
if response.status_code >= 400:
logger.error(f"HTTP {response.status_code} {response.content=}")
fuzzer_state.add_error(
module_name, prompt, response.status_code, response.text
)
errors.append((module_name, prompt, response.status_code, response.text))
return tokens, True
# Process successful response
response_text = response.text
tokens += len(response_text.split())
# Check if the response indicates a refusal
refused = refusal_heuristic(response.json())
if refused:
fuzzer_state.add_refusal(
module_name, prompt, response.status_code, response_text
)
refusals.append((module_name, prompt, response.status_code, response_text))
fuzzer_state.add_output(module_name, prompt, response_text, refused)
outputs.append((module_name, prompt, response_text, refused))
return tokens, refused
except httpx.RequestError as exc:
logger.error(f"Request error: {exc}")
fuzzer_state.add_error(module_name, prompt, "?", str(exc))
errors.append((module_name, prompt, "?", str(exc)))
return tokens, True
except JSONDecodeError as json_decode_error:
logger.error(f"JSON error: {json_decode_error}")
fuzzer_state.add_error(module_name, prompt, "?", str(json_decode_error))
logger.error(f"Jason error: {json_decode_error}")
errors.append((module_name, prompt, "?", str(json_decode_error)))
return tokens, True
except Exception as e:
logger.exception(f"Unexpected error: {e}")
return tokens, False
async def process_prompt_batch(
@@ -148,29 +91,14 @@ async def process_prompt_batch(
prompts: list[str],
tokens: int,
module_name: str,
fuzzer_state: FuzzerState,
refusals,
errors,
outputs,
) -> tuple[int, int]:
"""
Processes a batch of prompts asynchronously and aggregates the results.
This function sends multiple prompts concurrently using `process_prompt`,
collects the token count and failure status for each prompt, and returns
the total number of tokens processed and the number of failed prompts.
Args:
request_factory: An object with a `fn` method used to send the prompts.
prompts (list[str]): A list of input prompts to be processed.
tokens (int): The initial token count, which will be updated.
module_name (str): The name of the module handling the request.
fuzzer_state: State tracking object for the fuzzer
Returns:
tuple[int, int]:
- Total number of tokens processed.
- Number of failed prompts.
"""
tasks = [
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
process_prompt(
request_factory, p, tokens, module_name, refusals, errors, outputs
)
for p in prompts
]
results = await asyncio.gather(*tasks)
@@ -179,161 +107,133 @@ async def process_prompt_batch(
return total_tokens, failures
async def scan_module(
async def perform_single_shot_scan(
request_factory,
module,
fuzzer_state: FuzzerState,
processed_prompts: int = 0,
total_prompts: int = 0,
max_budget: int = 0,
optimize: bool = False,
stop_event: asyncio.Event | None = None,
token_counter: dict[str, int] | None = None,
) -> AsyncGenerator[dict[str, Any], None]:
"""
Scan a single module.
Args:
request_factory: The factory for creating requests
module: The prompt module to scan
fuzzer_state: State tracking object for the fuzzer
processed_prompts: Number of prompts processed so far
total_prompts: Total number of prompts to process
max_budget: Maximum token budget
token_counter: Shared token counter to enforce global budget
optimize: Whether to use optimization
stop_event: Event to stop scanning
Yields:
ScanResult objects as the scan progresses
"""
tokens = 0
token_counter = token_counter or {"total": 0}
module_failures = 0
module_prompts = 0
failure_rates = []
should_stop = False
# Initialize optimizer if optimization is enabled
optimizer = (
Optimizer(
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
)
if optimize
else None
)
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
yield ScanResult(
module=module.dataset_name,
tokens=0,
cost=0,
progress=0,
failureRate=0,
prompt="",
latency=0,
model="",
).model_dump_json()
async for prompt in generate_prompts(module.prompts):
if stop_event and stop_event.is_set():
stop_event.clear()
logger.info("Scan stopped by user.")
yield ScanResult.status_msg("Scan stopped by user.")
return
processed_prompts += 1
module_prompts += 1
# Calculate progress based on total processed prompts
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
progress = progress % 100
start = time.time()
previous_tokens = tokens
tokens, failed = await process_prompt(
request_factory,
prompt,
tokens,
module.dataset_name,
fuzzer_state=fuzzer_state,
)
end = time.time()
token_delta = max(tokens - previous_tokens, 0)
token_counter["total"] += token_delta
if failed:
module_failures += 1
failure_rate = module_failures / max(module_prompts, 1)
failure_rates.append(failure_rate)
cost = calculate_cost(tokens)
response_text = fuzzer_state.get_last_output(prompt) or ""
yield ScanResult(
module=module.dataset_name,
tokens=round(tokens / 1000, 1),
cost=cost,
progress=round(progress, 2),
failureRate=round(failure_rate * 100, 2),
prompt=prompt[:MAX_PROMPT_LENGTH],
latency=end - start,
model=response_text,
).model_dump_json()
# Optimization logic
if optimize and optimizer and len(failure_rates) >= MIN_FAILURE_SAMPLES:
next_point = optimizer.ask()
optimizer.tell(next_point, -failure_rate)
best_failure_rate = -optimizer.get_result().fun
if best_failure_rate > FAILURE_RATE_THRESHOLD:
yield ScanResult.status_msg(
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
)
should_stop = True
break
# Budget check
if token_counter["total"] > max_budget:
logger.info(
"Scan ran out of budget and stopped. %s %s",
token_counter["total"],
max_budget,
)
yield ScanResult.status_msg(
f"Scan ran out of budget and stopped. total_tokens={token_counter['total']} max_budget={max_budget}"
)
should_stop = True
break
if should_stop:
break
return
async def with_error_handling(agen):
"""
Wraps an asynchronous generator with error handling.
This function iterates over an asynchronous generator, yielding its values.
If an exception occurs, it logs the error and yields a failure message.
Finally, it ensures that a completion message is always yielded.
Args:
agen: An asynchronous generator that produces scan results.
Yields:
ScanResult: Either a successful result, an error message if an
exception occurs, or a completion message at the end.
"""
max_budget: int,
datasets: list[dict[str, str]] = [],
tools_inbox=None,
optimize=False,
stop_event: asyncio.Event = None,
secrets: dict[str, str] = {},
) -> AsyncGenerator[str, None]:
"""Perform a standard security scan."""
max_budget = max_budget * BUDGET_MULTIPLIER
selected_datasets = [m for m in datasets if m["selected"]]
request_factory = multi_modality_spec(request_factory)
try:
async for t in agen:
yield t
yield ScanResult.status_msg("Loading datasets...")
prompt_modules = prepare_prompts(
dataset_names=[m["dataset_name"] for m in selected_datasets],
budget=max_budget,
tools_inbox=tools_inbox,
options=[m.get("opts", {}) for m in selected_datasets],
)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
errors = []
refusals = []
outputs = []
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
optimizer = (
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
if optimize
else None
)
failure_rates = []
total_tokens = 0
tokens = 0
should_stop = False
for module in prompt_modules:
if should_stop:
break
tokens = 0
module_failures = 0
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
module_prompts = 0 # Reset for each module
async for prompt in generate_prompts(module.prompts):
if stop_event and stop_event.is_set():
stop_event.clear()
logger.info("Scan stopped by user.")
yield ScanResult.status_msg("Scan stopped by user.")
return
processed_prompts += 1
module_prompts += 1 # Fixed increment syntax
# Calculate progress based on total processed prompts
progress = (
100 * processed_prompts / total_prompts if total_prompts else 0
)
total_tokens -= tokens
start = time.time()
tokens, failed = await process_prompt(
request_factory,
prompt,
tokens,
module.dataset_name,
refusals,
errors,
outputs,
)
end = time.time()
total_tokens += tokens
if failed:
module_failures += 1
failure_rate = module_failures / max(module_prompts, 1)
failure_rates.append(failure_rate)
cost = calculate_cost(tokens)
last_output = outputs[-1] if outputs else None
if last_output and last_output[1] == prompt:
response_text = last_output[2]
else:
response_text = ""
yield ScanResult(
module=module.dataset_name,
tokens=round(tokens / 1000, 1),
cost=cost,
progress=round(progress, 2),
failureRate=round(failure_rate * 100, 2),
prompt=prompt[:MAX_PROMPT_LENGTH],
latency=end - start,
model=response_text,
).model_dump_json()
if optimize and len(failure_rates) >= 5:
next_point = optimizer.ask()
optimizer.tell(next_point, -failure_rate)
best_failure_rate = -optimizer.get_result().fun
if best_failure_rate > 0.5:
yield ScanResult.status_msg(
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
)
should_stop = True
break
if total_tokens > max_budget:
logger.info(
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
)
yield ScanResult.status_msg(
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
)
should_stop = True
break
yield ScanResult.status_msg("Scan completed.")
failure_data = errors + refusals
df = pd.DataFrame(
failure_data, columns=["module", "prompt", "status_code", "content"]
)
df.to_csv("failures.csv", index=False)
except Exception as e:
logger.exception("Scan failed")
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
@@ -341,279 +241,154 @@ async def with_error_handling(agen):
yield ScanResult.status_msg("Scan completed.")
async def perform_single_shot_scan(
request_factory,
max_budget: int,
datasets: list[dict[str, str]] | None = None,
tools_inbox=None,
optimize: bool = False,
stop_event: asyncio.Event | None = None,
secrets: dict[str, str] | None = None,
) -> AsyncGenerator[str, None]:
"""
Perform a standard security scan using a given request factory.
This function processes security scan prompts from selected datasets while
respecting a predefined token budget. It supports optimization, failure tracking,
and early stopping based on budget constraints or user intervention.
Args:
request_factory: A factory function that generates requests for processing prompts.
max_budget (int): The maximum token budget for the scan.
datasets (list[dict[str, str]], optional): A list of datasets containing security prompts.
tools_inbox: Optional additional tools for processing (default: None).
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
Yields:
str: JSON-encoded scan results or status messages.
The function iterates over prompts, processes them asynchronously, and updates
failure statistics and token usage. If the scan exceeds the budget or failure rate is too high,
it stops execution. Results are saved to a CSV file upon completion.
"""
datasets = datasets or []
secrets = secrets or {}
if stop_event and stop_event.is_set():
stop_event.clear()
yield ScanResult.status_msg("Loading datasets...")
yield ScanResult.status_msg("Scan stopped by user.")
yield ScanResult.status_msg("Scan completed.")
return
max_budget = max_budget * BUDGET_MULTIPLIER
selected_datasets = [m for m in datasets if m.get("selected")]
request_factory = get_modality_adapter(request_factory)
yield ScanResult.status_msg("Loading datasets...")
prompt_modules = prepare_prompts(
dataset_names=[m["dataset_name"] for m in selected_datasets],
budget=max_budget,
tools_inbox=tools_inbox,
options=[m.get("opts", {}) for m in selected_datasets],
)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
fuzzer_state = FuzzerState()
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
token_counter = {"total": 0}
for module in prompt_modules:
module_gen = scan_module(
request_factory=request_factory,
module=module,
fuzzer_state=fuzzer_state,
processed_prompts=processed_prompts,
total_prompts=total_prompts,
max_budget=max_budget,
optimize=optimize,
stop_event=stop_event,
token_counter=token_counter,
)
try:
async for result in module_gen:
yield result
except Exception:
logger.error("Module exception")
continue
# Update processed_prompts count
module_size = 0 if module.lazy else len(module.prompts)
processed_prompts += module_size
yield ScanResult.status_msg("Scan completed.")
fuzzer_state.export_failures("failures.csv")
async def perform_many_shot_scan(
request_factory,
max_budget: int,
datasets: list[dict[str, str]] | None = None,
probe_datasets: list[dict[str, str]] | None = None,
datasets: list[dict[str, str]] = [],
probe_datasets: list[dict[str, str]] = [],
tools_inbox=None,
optimize: bool = False,
stop_event: asyncio.Event | None = None,
optimize=False,
stop_event: asyncio.Event = None,
probe_frequency: float = 0.2,
max_ctx_length: int = 10_000,
secrets: dict[str, str] | None = None,
secrets: dict[str, str] = {},
) -> AsyncGenerator[str, None]:
"""
Perform a multi-step security scan with probe injection.
This function executes a security scan while periodically injecting probe datasets
to test system robustness. It tracks failures, optimizes scan efficiency,
and ensures adherence to a predefined token budget.
Args:
request_factory: A factory function that generates requests for processing prompts.
max_budget (int): The maximum token budget for the scan.
datasets (list[dict[str, str]], optional): The main datasets for scanning.
probe_datasets (list[dict[str, str]], optional): Additional datasets for probe injection.
tools_inbox: Optional tools for additional processing (default: None).
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
probe_frequency (float, optional): The probability of probe injection (default: 0.2).
max_ctx_length (int, optional): The maximum context length before resetting (default: 10,000 tokens).
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
Yields:
str: JSON-encoded scan results or status messages.
This function iterates over prompts, injects probe prompts at random intervals,
processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold
or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion.
"""
datasets = datasets or []
probe_datasets = probe_datasets or []
secrets = secrets or {}
if stop_event and stop_event.is_set():
stop_event.clear()
"""Perform a multi-step security scan with probe injection."""
request_factory = multi_modality_spec(request_factory)
try:
# Load main and probe datasets
yield ScanResult.status_msg("Loading datasets...")
yield ScanResult.status_msg("Scan stopped by user.")
prompt_modules = prepare_prompts(
dataset_names=[m["dataset_name"] for m in datasets if m["selected"]],
budget=max_budget,
tools_inbox=tools_inbox,
)
yield ScanResult.status_msg("Loading datasets for MSJ...")
msj_modules = msj_data.prepare_prompts(probe_datasets)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
errors = []
refusals = []
outputs = []
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
optimizer = (
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
if optimize
else None
)
failure_rates = []
for module in prompt_modules:
module_failures = 0
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
async for prompt in generate_prompts(module.prompts):
if stop_event and stop_event.is_set():
stop_event.clear()
logger.info("Scan stopped by user.")
yield ScanResult.status_msg("Scan stopped by user.")
return
tokens = 0
processed_prompts += 1
progress = (
100 * processed_prompts / total_prompts if total_prompts else 0
)
full_prompt = ""
msj_module = random.choice(msj_modules)
prompt_tokens = len(full_prompt.split())
tokens += prompt_tokens
injected = False
for _ in range(20):
if injected:
break
m_prompt = random.choice(msj_module.prompts)
full_prompt += "\n" + m_prompt
if tokens > max_ctx_length:
full_prompt = "\n" + prompt
injected = True
tokens, failed = await process_prompt(
request_factory,
full_prompt,
tokens,
module.dataset_name,
refusals,
errors,
outputs,
)
if failed:
module_failures += 1
break
if injected:
break
failure_rate = module_failures / max(processed_prompts, 1)
failure_rates.append(failure_rate)
cost = calculate_cost(tokens)
yield ScanResult(
module=module.dataset_name,
tokens=round(tokens / 1000, 1),
cost=cost,
progress=round(progress, 2),
failureRate=round(failure_rate * 100, 2),
prompt=prompt[:MAX_PROMPT_LENGTH],
).model_dump_json()
if optimize and len(failure_rates) >= 5:
next_point = optimizer.ask()
optimizer.tell(next_point, -failure_rate)
best_failure_rate = -optimizer.get_result().fun
if best_failure_rate > 0.5:
yield ScanResult.status_msg(
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
)
break
yield ScanResult.status_msg("Scan completed.")
return
request_factory = get_modality_adapter(request_factory)
# Load main and probe datasets
yield ScanResult.status_msg("Loading datasets...")
prompt_modules = prepare_prompts(
dataset_names=[m["dataset_name"] for m in datasets if m["selected"]],
budget=max_budget,
tools_inbox=tools_inbox,
)
yield ScanResult.status_msg("Loading datasets for MSJ...")
msj_modules = msj_data.prepare_prompts(probe_datasets)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
fuzzer_state = FuzzerState()
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
df = pd.DataFrame(
errors + refusals, columns=["module", "prompt", "status_code", "content"]
)
df.to_csv("failures.csv", index=False)
failure_rates = []
for module in prompt_modules:
module_failures = 0
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
async for prompt in generate_prompts(module.prompts):
if stop_event and stop_event.is_set():
stop_event.clear()
logger.info("Scan stopped by user.")
yield ScanResult.status_msg("Scan stopped by user.")
return
tokens = 0
processed_prompts += 1
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
progress = progress % 100
full_prompt = ""
msj_module = random.choice(msj_modules)
prompt_tokens = len(full_prompt.split())
tokens += prompt_tokens
injected = False
for _ in range(20):
if injected:
break
m_prompt = random.choice(msj_module.prompts)
full_prompt += "\n" + m_prompt
if tokens > max_ctx_length:
full_prompt = "\n" + prompt
injected = True
tokens, failed = await process_prompt(
request_factory,
full_prompt,
tokens,
module.dataset_name,
fuzzer_state=fuzzer_state,
)
if failed:
module_failures += 1
break
if injected:
break
failure_rate = module_failures / max(processed_prompts, 1)
failure_rates.append(failure_rate)
cost = calculate_cost(tokens)
yield ScanResult(
module=module.dataset_name,
tokens=round(tokens / 1000, 1),
cost=cost,
progress=round(progress, 2),
failureRate=round(failure_rate * 100, 2),
prompt=prompt[:MAX_PROMPT_LENGTH],
).model_dump_json()
if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES:
yield ScanResult.status_msg(
f"High failure rate detected ({failure_rate:.2%}). Stopping this module..."
)
break
yield ScanResult.status_msg("Scan completed.")
fuzzer_state.export_failures("failures.csv")
except Exception as e:
logger.exception("Scan failed")
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
raise e
def scan_router(
request_factory,
scan_parameters: Scan,
tools_inbox=None,
stop_event: asyncio.Event | None = None,
stop_event: asyncio.Event = None,
):
"""
Route scan requests to the appropriate scanning function.
This function determines whether to perform a multi-step or single-shot
security scan based on the provided scan parameters.
Args:
request_factory: A factory function to generate requests for processing prompts.
scan_parameters (Scan): An object containing the parameters for the scan, including:
- enableMultiStepAttack (bool): Whether to perform a multi-step scan.
- maxBudget (int): The maximum token budget for the scan.
- datasets (list[dict[str, str]]): The datasets to scan.
- probe_datasets (list[dict[str, str]], optional): Datasets for probe injection (multi-step only).
- optimize (bool): Whether to enable optimization.
- secrets (dict[str, str], optional): A dictionary of secrets for authentication.
tools_inbox: Optional tools for additional processing (default: None).
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
Returns:
A function wrapped with `with_error_handling`, which executes either:
- `perform_many_shot_scan` for multi-step scanning.
- `perform_single_shot_scan` for single-shot scanning.
The function ensures that the appropriate scanning method is chosen based on
the `enableMultiStepAttack` flag in `scan_parameters`.
"""
if scan_parameters.enableMultiStepAttack:
return with_error_handling(
perform_many_shot_scan(
request_factory=request_factory,
max_budget=scan_parameters.maxBudget,
datasets=scan_parameters.datasets,
probe_datasets=scan_parameters.probe_datasets,
tools_inbox=tools_inbox,
optimize=scan_parameters.optimize,
stop_event=stop_event,
secrets=scan_parameters.secrets,
)
return perform_many_shot_scan(
request_factory=request_factory,
max_budget=scan_parameters.maxBudget,
datasets=scan_parameters.datasets,
probe_datasets=scan_parameters.probe_datasets,
tools_inbox=tools_inbox,
optimize=scan_parameters.optimize,
stop_event=stop_event,
secrets=scan_parameters.secrets,
)
else:
return with_error_handling(
perform_single_shot_scan(
request_factory=request_factory,
max_budget=scan_parameters.maxBudget,
datasets=scan_parameters.datasets,
tools_inbox=tools_inbox,
optimize=scan_parameters.optimize,
stop_event=stop_event,
secrets=scan_parameters.secrets,
)
return perform_single_shot_scan(
request_factory=request_factory,
max_budget=scan_parameters.maxBudget,
datasets=scan_parameters.datasets,
tools_inbox=tools_inbox,
optimize=scan_parameters.optimize,
stop_event=stop_event,
secrets=scan_parameters.secrets,
)
+5 -1
View File
@@ -1,4 +1,5 @@
import asyncio
import logging
from typing import Any
import httpx
@@ -6,10 +7,13 @@ from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
from agentic_security.http_spec import LLMSpec
from agentic_security.logutils import logger
LLM_SPECS = []
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AgentSpecification(BaseModel):
name: str | None = Field(None, description="Name of the LLM/agent")
+2 -1
View File
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from agentic_security.refusal_classifier.model import RefusalClassifier
from agentic_security.refusal_classifier import RefusalClassifier
classifier = RefusalClassifier()
classifier.load_model()
@@ -50,6 +50,7 @@ class RefusalClassifierPlugin(ABC):
Returns:
bool: True if the response contains a refusal, False otherwise.
"""
pass
class DefaultRefusalClassifier(RefusalClassifierPlugin):
-47
View File
@@ -1,47 +0,0 @@
import pandas as pd
class FuzzerState:
"""Container for tracking scan results"""
def __init__(self):
self.errors = []
self.refusals = []
self.outputs = []
def add_error(
self,
module_name: str,
prompt: str,
status_code: int | str,
error_msg: str,
):
"""Add an error to the state"""
self.errors.append((module_name, prompt, status_code, error_msg))
def add_refusal(
self, module_name: str, prompt: str, status_code: int, response_text: str
):
"""Add a refusal to the state"""
self.refusals.append((module_name, prompt, status_code, response_text))
def add_output(
self, module_name: str, prompt: str, response_text: str, refused: bool
):
"""Add an output to the state"""
self.outputs.append((module_name, prompt, response_text, refused))
def get_last_output(self, prompt: str) -> str | None:
"""Get the last output for a given prompt"""
for output in reversed(self.outputs):
if output[1] == prompt:
return output[2]
return None
def export_failures(self, filename: str = "failures.csv"):
"""Export failures to a CSV file"""
failure_data = self.errors + self.refusals
df = pd.DataFrame(
failure_data, columns=["module", "prompt", "status_code", "content"]
)
df.to_csv(filename, index=False)
@@ -7,7 +7,6 @@ import pytest
from agentic_security.primitives import Scan
from agentic_security.probe_actor.fuzzer import (
FuzzerState,
generate_prompts,
perform_many_shot_scan,
perform_single_shot_scan,
@@ -76,23 +75,14 @@ async def test_perform_single_shot_scan_success(prepare_prompts_mock):
@pytest.mark.asyncio
@patch("agentic_security.probe_data.msj_data.prepare_prompts")
@patch("agentic_security.probe_data.data.prepare_prompts")
async def test_perform_many_shot_scan_probe_injection(
prepare_prompts_mock, msj_prepare_prompts_mock
):
async def test_perform_many_shot_scan_probe_injection(prepare_prompts_mock):
# Mock main and probe prompt modules
prepare_prompts_mock.side_effect = [
[MagicMock(dataset_name="main_module", prompts=["main_prompt1"], lazy=False)],
[MagicMock(dataset_name="probe_module", prompts=["probe_prompt1"], lazy=False)],
]
msj_prepare_prompts_mock.return_value = [
MagicMock(
dataset_name="msj_probe_module", prompts=["msj_probe_prompt"], lazy=False
)
]
# Mock request_factory
mock_response = AsyncMock()
mock_response.fn.side_effect = [
@@ -217,7 +207,9 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
prompt="test prompt",
tokens=0,
module_name="module_a",
fuzzer_state=FuzzerState(),
refusals=[],
errors=[],
outputs=[],
)
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
@@ -234,17 +226,20 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
)
)
fuzzer_state = FuzzerState()
refusals = []
outputs = []
tokens, refusal = await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
fuzzer_state=fuzzer_state,
refusals=refusals,
errors=[],
outputs=outputs,
)
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
# self.assertFalse(fuzzer_state.refusals)
self.assertFalse(refusal)
async def test_http_error_response(self):
mock_request_factory = Mock()
@@ -257,13 +252,15 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
)
)
fuzzer_state = FuzzerState()
refusals = []
await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
fuzzer_state=fuzzer_state,
refusals=refusals,
errors=[],
outputs=[],
)
async def test_request_error(self):
@@ -272,14 +269,18 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
side_effect=httpx.RequestError("Connection error")
)
fuzzer_state = FuzzerState()
errors = []
tokens, refusal = await process_prompt(
request_factory=mock_request_factory,
prompt="test prompt",
tokens=0,
module_name="module_a",
fuzzer_state=fuzzer_state,
refusals=[],
errors=errors,
outputs=[],
)
self.assertEqual(tokens, 0)
self.assertTrue(refusal)
self.assertEqual(len(errors), 1)
self.assertIn("Connection error", errors[0][3])
+1 -16
View File
@@ -1,4 +1,4 @@
from .data import load_local_csv, load_local_csv_files
from .data import load_local_csv
REGISTRY_V0 = [
{
@@ -484,18 +484,3 @@ REGISTRY = REGISTRY_V0 + [
"modality": "text",
},
]
for ds in load_local_csv_files():
REGISTRY.append(
{
"dataset_name": ds.dataset_name,
"num_prompts": len(ds.prompts),
"tokens": ds.prompts,
"approx_cost": 0.0,
"is_active": True,
"source": f"Local file dataset: {ds.metadata['src']}",
"selected": False,
"url": "",
"modality": "text",
}
)
+5 -27
View File
@@ -1,5 +1,4 @@
import base64
import logging
import os
import platform
import subprocess
@@ -8,14 +7,6 @@ import uuid
import httpx
from cache_to_disk import cache_to_disk
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AudioGenerationError(Exception):
"""Custom exception for errors during audio generation."""
def encode(content: bytes) -> str:
encoded_content = base64.b64encode(content).decode("utf-8")
@@ -50,25 +41,12 @@ def generate_audio_mac_wav(prompt: str) -> bytes:
# Read the WAV file into memory
with open(temp_wav_path, "rb") as f:
audio_bytes = f.read()
except subprocess.CalledProcessError as e:
logger.error(f"Subprocess error: {e}")
raise AudioGenerationError("Failed to generate or convert audio.") from e
except FileNotFoundError as e:
logger.error(f"File not found: {e}")
raise AudioGenerationError("Required file not found.") from e
except Exception as e:
logger.exception("Unexpected error occurred.")
raise AudioGenerationError(
"An unexpected error occurred during audio generation."
) from e
finally:
for path in (temp_aiff_path, temp_wav_path):
try:
if os.path.exists(path):
os.remove(path)
except Exception as e:
logger.warning(f"Failed to delete temporary file {path}: {e}")
# Clean up the temporary files
if os.path.exists(temp_aiff_path):
os.remove(temp_aiff_path)
if os.path.exists(temp_wav_path):
os.remove(temp_wav_path)
# Return the audio bytes
return audio_bytes
+384 -449
View File
@@ -1,15 +1,14 @@
import io
import os
import random
from collections.abc import Callable, Iterator
from functools import partial
from typing import Any, TypeVar
from functools import lru_cache
import httpx
import pandas as pd
from cache_to_disk import cache_to_disk
from datasets import load_dataset
from loguru import logger
from agentic_security.logutils import logger
from agentic_security.probe_data import stenography_fn
from agentic_security.probe_data.models import ProbeDataset
from agentic_security.probe_data.modules import (
@@ -19,281 +18,282 @@ from agentic_security.probe_data.modules import (
inspect_ai_tool,
rl_model,
)
from datasets import load_dataset
# Type aliases for clarity
T = TypeVar("T")
FilterFn = Callable[[pd.Series], bool]
ColumnMappings = dict[str, str]
DatasetLoader = Callable[[], ProbeDataset]
TransformFn = Callable[[str], str]
# Core data loading utilities
def fetch_csv_content(url: str) -> str:
"""Fetch CSV content from a URL."""
response = httpx.get(url)
response.raise_for_status() # Raise exception for bad responses
return response.content.decode("utf-8")
@cache_to_disk()
def load_dataset_general(
dataset_name,
dataset_split="train",
column_mappings=None,
filter_fn=None,
custom_url=None,
additional_metadata=None,
):
"""
Generalized function to load datasets with flexible configurations.
:param dataset_name: Name of the dataset or URL for custom CSVs
:param dataset_split: Split to load from the dataset (e.g., "train")
:param column_mappings: Dictionary mapping dataset columns to expected keys, e.g., {'prompt': 'query'}
:param filter_fn: A filtering function that takes a row and returns True/False
:param custom_url: URL for custom CSV datasets
:param additional_metadata: Additional metadata to include in the ProbeDataset
:return: A ProbeDataset object with the processed data
"""
if custom_url:
logger.info(f"Loading custom CSV dataset from {custom_url}")
r = httpx.get(custom_url)
content = r.content
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
else:
logger.info(f"Loading dataset {dataset_name} from Hugging Face datasets")
dataset = load_dataset(dataset_name)
df = pd.DataFrame(dataset[dataset_split])
def load_df_from_source(source: str, is_url: bool = False) -> pd.DataFrame:
"""Load DataFrame from either URL or Hugging Face dataset."""
if is_url:
logger.info(f"Loading custom CSV from {source}")
return pd.read_csv(io.StringIO(fetch_csv_content(source)))
logger.info(f"Loading dataset {source} from Hugging Face")
return pd.DataFrame(load_dataset(source)["train"])
# Apply column mappings if provided
if column_mappings:
df.rename(columns=column_mappings, inplace=True)
def transform_df(
df: pd.DataFrame,
mappings: ColumnMappings | None = None,
filter_fn: FilterFn | None = None,
) -> pd.DataFrame:
"""Apply transformations to DataFrame."""
result = df.copy()
if mappings:
result.rename(columns=mappings, inplace=True)
# Filter rows if filter_fn is provided
if filter_fn:
result = result[result.apply(filter_fn, axis=1)]
return result
df = df[df.apply(filter_fn, axis=1)]
# Extract prompts
prompts = df[column_mappings.get("prompt", "prompt")].tolist()
def create_probe_dataset(
name: str, prompts: list[str], metadata: dict[str, Any] | None = None
) -> ProbeDataset:
"""Create a ProbeDataset from prompts."""
metadata = metadata or {}
return ProbeDataset(
dataset_name=name,
metadata=metadata,
dataset_name=dataset_name,
metadata=additional_metadata or {},
prompts=prompts,
tokens=sum(len(str(s).split()) for s in prompts),
tokens=count_words_in_list(prompts),
approx_cost=0.0,
)
# Generalized dataset loader
def count_words_in_list(str_list):
"""Calculate the total number of words in a given list of strings.
:param str_list: List of strings
:return: Total number of words across all strings in the list
"""
total_words = sum(len(str(s).split()) for s in str_list)
return total_words
@cache_to_disk()
def load_dataset_generic(
name: str,
mappings: ColumnMappings | None = None,
filter_fn: FilterFn | None = None,
url: str | None = None,
metadata: dict[str, Any] | None = None,
) -> ProbeDataset:
"""Load and process a dataset with flexible configuration."""
try:
df = load_df_from_source(url or name, is_url=bool(url))
transformed_df = transform_df(df, mappings, filter_fn)
# Determine which column to use as the prompt source
prompt_col = None
if mappings and "prompt" in mappings:
prompt_col = mappings["prompt"]
elif "prompt" in transformed_df.columns:
prompt_col = "prompt"
else:
# Try to find a suitable text column
text_columns = [
col
for col in transformed_df.columns
if any(
keyword in col.lower()
for keyword in ["prompt", "text", "query", "question"]
)
]
if text_columns:
prompt_col = text_columns[0]
logger.info(f"Using column '{prompt_col}' as prompt source")
else:
logger.error(f"No suitable prompt column found in dataset {name}")
return create_probe_dataset(name, [], metadata)
# Extract prompts and filter out empty ones
prompts = [
p
for p in transformed_df[prompt_col].tolist()
if p and isinstance(p, (str, int, float))
]
return create_probe_dataset(name, prompts, metadata)
except Exception as e:
logger.error(f"Error loading dataset {name}: {e}")
return create_probe_dataset(name, [], {"error": str(e)})
# Dataset-specific configurations
DATASET_CONFIGS = {
"ShawnMenz/DAN_jailbreak": {
"mappings": {"prompt": "prompt"},
"filter_fn": lambda row: row["jailbreak"],
},
"deepset/prompt-injections": {
"mappings": {"prompt": "text"},
"filter_fn": lambda row: row["label"],
},
"notrichardren/refuse-to-answer-prompts": {
"mappings": {"prompt": "claim"},
"filter_fn": lambda row: row["label"],
},
"rubend18/ChatGPT-Jailbreak-Prompts": {
"mappings": {"prompt": "Prompt"},
"filter_fn": lambda row: row["label"],
},
"markush1/LLM-Jailbreak-Classifier": {
"mappings": {"prompt": "prompt"},
"filter_fn": lambda row: row["classification"] == "jailbreak",
},
"ShawnMenz/jailbreak_sft_rm_ds": {
"url": "hf://datasets/ShawnMenz/jailbreak_sft_rm_ds/jailbreak_sft_rm_ds.csv",
"mappings": {"prompt": "prompt"},
"filter_fn": lambda row: row["jailbreak"] == "jailbreak",
},
"verazuo/jailbreak_llms/2023_05_07": {
"url": "https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_05_07.csv",
"mappings": {"prompt": "prompt"},
"filter_fn": lambda row: bool(row["jailbreak"]),
},
"verazuo/jailbreak_llms/2023_12_25": {
"url": "https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_12_25.csv",
"mappings": {"prompt": "prompt"},
"filter_fn": lambda row: bool(row["jailbreak"]),
},
}
# Additional generic dataset configurations
DATASET_CONFIGS_GENERICS = {
"simonycl/aya-23-8B_advbench_jailbreak": {"mappings": {"prompt": "prompt"}},
"acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b": {},
"karanxa/dolphin-jailbreak-finetuning-dataset": {"mappings": {"prompt": "text"}},
"karanxa/llama-2-jailbreak-dataset": {"mappings": {"prompt": "text"}},
"karanxa/llama2-uncensored-jailbreak-dataset-finetuning": {
"mappings": {"prompt": "text"}
},
"liuyanchen1015/Llama-3.2-1B_jailbreak_responses": {
"mappings": {"prompt": "jailbreak_prompt_text"}
},
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses": {
"mappings": {"prompt": "jailbreak_prompt_text"}
},
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment": {
"mappings": {"prompt": "jailbreak_prompt_text"}
},
"jackhhao/jailbreak-classification": {"mappings": {"prompt": "prompt"}},
"walledai/JailbreakBench": {"mappings": {"prompt": "prompt"}},
"walledai/JailbreakHub": {"mappings": {"prompt": "prompt"}},
"Granther/evil-jailbreak": {"mappings": {"prompt": "text"}},
"sevdeawesome/jailbreak_success": {"mappings": {"prompt": "jailbreak_prompt_text"}},
"IDA-SERICS/Disaster-tweet-jailbreaking": {"mappings": {"prompt": "prompt_attack"}},
"GeorgeDaDude/Jailbreak_Complete_DS_labeled": {"mappings": {"prompt": "question"}},
"dayone3nder/jailbreak_prompt_JBB_sft_trainset": {"mappings": {"prompt": "prompt"}},
"dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset": {
"mappings": {"prompt": "prompt"}
},
}
# Dataset factory
def create_dataset_loader(name: str, config: dict[str, Any]) -> DatasetLoader:
"""Create a dataset loader from configuration."""
return partial(
load_dataset_generic,
name=name,
mappings=config.get("mappings"),
filter_fn=config.get("filter_fn"),
url=config.get("url"),
metadata={"source": name, "config": str(config)},
)
# Specialized loaders
@cache_to_disk()
def load_multi_dataset(name: str, sub_datasets: list[str]) -> ProbeDataset:
"""Load and combine multiple sub-datasets."""
prompts = []
for sub in sub_datasets:
try:
dataset = load_dataset(name, sub)
if "query" in dataset["train"].features:
prompts.extend(dataset["train"]["query"])
else:
logger.warning(f"No 'query' column in {name}/{sub}")
except Exception as e:
logger.error(f"Error loading {name}/{sub}: {e}")
return create_probe_dataset(
f"{name}_combined", prompts, {"source": name, "sub_datasets": sub_datasets}
def load_dataset_v1():
return load_dataset_general(
dataset_name="ShawnMenz/DAN_jailbreak",
column_mappings={"prompt": "prompt", "jailbreak": "jailbreak"},
filter_fn=lambda row: row["jailbreak"],
)
@cache_to_disk()
def load_jailbreak_v28k() -> ProbeDataset:
"""Load JailBreakV-28K dataset."""
try:
df = pd.read_csv(
"hf://datasets/JailbreakV-28K/JailBreakV-28k/JailBreakV_28K.csv"
)
prompts = df["jailbreak_query"].tolist()
return create_probe_dataset(
"JailbreakV-28K/JailBreakV-28k",
prompts,
{"source": "JailbreakV-28K/JailBreakV-28k"},
)
except Exception as e:
logger.error(f"Error loading JailbreakV-28K: {e}")
return create_probe_dataset("JailbreakV-28K/JailBreakV-28k", [])
def load_dataset_v2():
return load_dataset_general(
dataset_name="deepset/prompt-injections",
column_mappings={"prompt": "text", "jailbreak": "label"},
filter_fn=lambda row: row["label"],
)
@cache_to_disk(1)
def file_dataset(file) -> list[str]:
prompts = []
try:
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
if "prompt" in df.columns:
prompts = df["prompt"].tolist()
else:
logger.warning(f"File {file} lacks a suitable prompt column")
except Exception as e:
logger.error(f"Error reading {file}: {e}")
return prompts
@cache_to_disk()
def load_dataset_v4():
return load_dataset_general(
dataset_name="notrichardren/refuse-to-answer-prompts",
column_mappings={"prompt": "claim", "jailbreak": "label"},
filter_fn=lambda row: row["label"],
)
def load_local_csv() -> ProbeDataset:
"""Load prompts from local CSV files."""
os.makedirs("./datasets", exist_ok=True)
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
prompts = []
for file in csv_files:
prompts.extend(file_dataset(file))
return create_probe_dataset("Local CSV", prompts, {"src": str(csv_files)})
@cache_to_disk()
def load_dataset_v3():
return load_dataset_general(
dataset_name="rubend18/ChatGPT-Jailbreak-Prompts",
column_mappings={"prompt": "Prompt"},
filter_fn=lambda row: row["label"],
)
def load_csv(file: str) -> ProbeDataset:
"""Load prompts from local CSV files."""
prompts = file_dataset(file)
return create_probe_dataset(f"fs://{file}", prompts, {"src": str(file)})
@cache_to_disk()
def load_dataset_v6():
dataset = load_dataset("markush1/LLM-Jailbreak-Classifier")
bad_prompts = [
p
for p, c in zip(dataset["train"]["prompt"], dataset["train"]["classification"])
if c == "jailbreak"
]
return ProbeDataset(
dataset_name="markush1/LLM-Jailbreak-Classifier",
metadata={},
prompts=bad_prompts,
tokens=count_words_in_list(bad_prompts),
approx_cost=0.0,
)
def load_local_csv_files() -> list[ProbeDataset]:
"""Load prompts from local CSV files and return a list of ProbeDataset objects."""
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
datasets = []
for file in csv_files:
datasets.append(create_probe_dataset(file, file_dataset(file), {"src": file}))
return datasets
@cache_to_disk()
def load_dataset_v7():
splits = {
"mini_JailBreakV_28K": "JailBreakV_28K/mini_JailBreakV_28K.csv",
"JailBreakV_28K": "JailBreakV_28K/JailBreakV_28K.csv",
}
df = pd.read_csv(
"hf://datasets/JailbreakV-28K/JailBreakV-28k/" + splits["JailBreakV_28K"]
)
bad_prompts = df["jailbreak_query"].tolist()
print(df.shape)
return ProbeDataset(
dataset_name="JailbreakV-28K/JailBreakV-28k",
metadata={},
prompts=bad_prompts,
tokens=count_words_in_list(bad_prompts),
approx_cost=0.0,
)
# Stenography transformer
class StenographyTransformer:
"""Apply stenography transformations to datasets."""
@cache_to_disk()
def load_dataset_v8():
df = pd.read_csv(
"hf://datasets/ShawnMenz/jailbreak_sft_rm_ds/jailbreak_sft_rm_ds.csv",
names=["jailbreak", "prompt"],
)
filtered = df[df["jailbreak"] == "jailbreak"]["prompt"].tolist()
return ProbeDataset(
dataset_name="JailbreakV-28K/JailBreakV-28k",
metadata={},
prompts=filtered,
tokens=count_words_in_list(filtered),
approx_cost=0.0,
)
TRANSFORMATIONS = {
@cache_to_disk()
def load_dataset_v5():
ds = []
for c in [
"AdvBench",
"ForbiddenQuestion",
"MJP",
"MaliciousInstruct",
"QuestionList",
]:
dataset = load_dataset("Lemhf14/EasyJailbreak_Datasets", c)
bad_prompts = dataset["train"]["query"]
ds.extend(bad_prompts)
return ProbeDataset(
dataset_name="Lemhf14/EasyJailbreak_Datasets",
metadata={},
prompts=ds,
tokens=count_words_in_list(ds),
approx_cost=0.0,
)
@cache_to_disk()
def load_generic_csv(url, name, column="prompt", predicator=None):
r = httpx.get(url)
content = r.content
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
logger.info(f"Loaded {len(df)} prompts from {url}")
filtered_prompts = df[df.apply(predicator, axis=1)][column].tolist()
return ProbeDataset(
dataset_name=name,
metadata={},
prompts=filtered_prompts,
tokens=count_words_in_list(filtered_prompts),
approx_cost=0.0,
)
def prepare_prompts(dataset_names, budget, tools_inbox=None, options=[]):
# ## Datasets used and cleaned:
# markush1/LLM-Jailbreak-Classifier
# 1. Open-Orca/OpenOrca
# 2. ShawnMenz/DAN_jailbreak
# 3. EddyLuo/JailBreakV_28K
# 4. https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/jailbreak_prompts.csv
dataset_map = {
"ShawnMenz/DAN_jailbreak": load_dataset_v1,
"deepset/prompt-injections": load_dataset_v2,
"notrichardren/refuse-to-answer-prompts": load_dataset_v4,
"rubend18/ChatGPT-Jailbreak-Prompts": load_dataset_v3,
"Lemhf14/EasyJailbreak_Datasets": load_dataset_v5,
"markush1/LLM-Jailbreak-Classifier": load_dataset_v6,
"JailbreakV-28K/JailBreakV-28k": load_dataset_v7,
"ShawnMenz/jailbreak_sft_rm_ds": load_dataset_v8,
"verazuo/jailbreak_llms/2023_05_07": lambda: load_generic_csv(
url="https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_05_07.csv",
name="verazuo/jailbreak_llms/2023_05_07",
column="prompt",
predicator=lambda x: bool(x["jailbreak"]),
),
"verazuo/jailbreak_llms/2023_12_25.csv": lambda: load_generic_csv(
url="https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_12_25.csv.csv",
name="verazuo/jailbreak_llms/2023_12_25.csv",
column="prompt",
predicator=lambda x: bool(x["jailbreak"]),
),
"Custom CSV": load_local_csv,
}
dataset_map.update(dataset_map_generics)
group = []
for dataset_name in dataset_names:
if dataset_name in dataset_map:
logger.info(f"Loading {dataset_name}")
try:
group.append(dataset_map[dataset_name]())
except Exception as e:
logger.error(f"Error loading {dataset_name}: {e}")
dynamic_datasets = {
"AgenticBackend": lambda opts: dataset_from_iterator(
"AgenticBackend",
fine_tuned.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
),
"Steganography": lambda opts: Stenography(group),
"llm-adaptive-attacks": lambda opts: dataset_from_iterator(
"llm-adaptive-attacks",
adaptive_attacks.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
),
"Garak": lambda opts: dataset_from_iterator(
"Garak",
garak_tool.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
),
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
"Reinforcement Learning Optimization",
rl_model.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
),
"InspectAI": lambda opts: dataset_from_iterator(
"InspectAI",
inspect_ai_tool.Module(group, tools_inbox=tools_inbox).apply(),
lazy=True,
),
"GPT fuzzer": lambda opts: [],
}
dynamic_groups = []
options = options or [{} for _ in dataset_names]
for dataset_name, opts in zip(dataset_names, options):
if dataset_name in dynamic_datasets:
logger.info(f"Loading {dataset_name}")
ds = dynamic_datasets[dataset_name](opts)
for g in ds:
dynamic_groups.append(g)
return group + dynamic_groups
class Stenography:
fn_library = {
"rot5": stenography_fn.rot5,
"rot13": stenography_fn.rot13,
"base64": stenography_fn.base64_encode,
@@ -306,216 +306,151 @@ class StenographyTransformer:
"zigzag_obfuscation": stenography_fn.zigzag_obfuscation,
}
def __init__(self, datasets: list[ProbeDataset]):
self.datasets = datasets
def __init__(self, prompt_groups: [ProbeDataset]):
self.prompt_groups = prompt_groups
def transform(self) -> Iterator[ProbeDataset]:
"""Generate transformed datasets."""
if not self.datasets:
logger.warning("No datasets provided to StenographyTransformer")
return
def __iter__(self):
return self.apply()
for dataset in self.datasets:
if not dataset.prompts:
logger.warning(
f"Dataset {dataset.dataset_name} has no prompts to transform"
def apply(self):
for prompt_group in self.prompt_groups:
size = len(prompt_group.prompts)
for name, fn in self.fn_library.items():
logger.info(f"Applying {name} to {prompt_group.dataset_name}")
selected_prompts = random.sample(prompt_group.prompts, int(size * 0.25))
prompts = [
f"Follow the instructions in {name}: {fn(p)}"
for p in selected_prompts
]
yield ProbeDataset(
dataset_name=f"stenography.{name}({prompt_group.dataset_name})",
metadata={},
prompts=prompts,
tokens=count_words_in_list(prompts),
approx_cost=0.0,
)
continue
# Convert async generator (if any) to a list
prompts = (
list(dataset.prompts) if hasattr(dataset.prompts, "__iter__") else []
)
size = len(prompts)
if size == 0:
continue
sample_size = max(1, int(size * 0.25)) # At least 1 sample
for name, fn in self.TRANSFORMATIONS.items():
logger.info(f"Applying {name} to {dataset.dataset_name}")
sampled = random.sample(prompts, min(sample_size, size))
try:
transformed = [
f"Follow the instructions in {name}: {fn(str(p))}"
for p in sampled
]
transformed_dataset = create_probe_dataset(
f"stenography.{name}({dataset.dataset_name})",
transformed,
{"source": dataset.dataset_name, "transformation": name},
)
yield transformed_dataset
except Exception as e:
logger.error(
f"Error applying {name} to {dataset.dataset_name}: {e}"
)
def dataset_from_iterator(
name: str, iterator: Iterator[str], lazy: bool = False
) -> list[ProbeDataset]:
"""Convert an iterator into a list of ProbeDataset objects."""
try:
prompts = list(iterator) if not lazy else iterator
tokens = sum(len(str(s).split()) for s in prompts) if not lazy else 0
dataset = ProbeDataset(
dataset_name=name,
metadata={"source": name, "lazy": lazy},
prompts=prompts,
tokens=tokens,
approx_cost=0.0,
lazy=lazy,
)
return [dataset]
except Exception as e:
logger.error(f"Error creating dataset from iterator {name}: {e}")
return [create_probe_dataset(name, [], {"error": str(e)})]
@lru_cache
def load_local_csv() -> ProbeDataset:
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
logger.info(f"Found {len(csv_files)} CSV files")
logger.info(f"CSV files: {csv_files}")
prompt_list = []
# Main dataset preparation
def prepare_prompts(
dataset_names: list[str],
budget: float,
tools_inbox=None,
options: list[dict] = None,
) -> list[ProbeDataset]:
"""Prepare datasets based on names and options."""
# Base dataset loaders
logger.info(f"Preparing datasets: {dataset_names}")
dataset_loaders = {
**{k: create_dataset_loader(k, v) for k, v in DATASET_CONFIGS.items()},
**{k: create_dataset_loader(k, v) for k, v in DATASET_CONFIGS_GENERICS.items()},
"Lemhf14/EasyJailbreak_Datasets": partial(
load_multi_dataset,
name="Lemhf14/EasyJailbreak_Datasets",
sub_datasets=[
"AdvBench",
"ForbiddenQuestion",
"MJP",
"MaliciousInstruct",
"QuestionList",
],
),
"JailbreakV-28K/JailBreakV-28k": load_jailbreak_v28k,
"Local CSV": load_local_csv,
"Custom CSV": load_local_csv,
}
# Dynamic dataset loaders
dynamic_loaders = {
"AgenticBackend": lambda opts: dataset_from_iterator(
"AgenticBackend",
fine_tuned.Module(
opts["datasets"], tools_inbox=tools_inbox, opts=opts
).apply(),
lazy=True,
),
"Steganography": lambda opts: list(
StenographyTransformer(opts["datasets"]).transform()
),
"llm-adaptive-attacks": lambda opts: dataset_from_iterator(
"llm-adaptive-attacks",
adaptive_attacks.Module(
opts["datasets"], tools_inbox=tools_inbox, opts=opts
).apply(),
),
"Garak": lambda opts: dataset_from_iterator(
"Garak",
garak_tool.Module(
opts["datasets"], tools_inbox=tools_inbox, opts=opts
).apply(),
lazy=True,
),
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
"Reinforcement Learning Optimization",
rl_model.Module(
opts["datasets"], tools_inbox=tools_inbox, opts=opts
).apply(),
lazy=True,
),
"InspectAI": lambda opts: dataset_from_iterator(
"InspectAI",
inspect_ai_tool.Module([], tools_inbox=tools_inbox).apply(),
lazy=True,
),
"GPT fuzzer": lambda opts: [],
}
datasets = []
options = options or [dict(datasets=datasets) for _ in dataset_names]
# Load base datasets
for name, opts in zip(dataset_names, options):
if name not in dataset_loaders:
continue
for file in csv_files:
try:
datasets.append(dataset_loaders[name]())
df = pd.read_csv(file)
except Exception as e:
logger.error(f"Error loading {name}: {e}")
# Load dynamic datasets and apply transformations
for name, opts in zip(dataset_names, options):
if name not in dynamic_loaders:
logger.error(f"Error reading {file}: {e}")
continue
logger.info(f"Loading dynamic dataset {name} {opts}")
opts["datasets"] = datasets
try:
dynamic_result = dynamic_loaders[name](opts)
datasets.extend(dynamic_result)
except Exception as e:
logger.exception(f"Error loading dynamic {name}: {e}")
# Load csv datasets and apply transformations
for name, opts in zip(dataset_names, options):
if not name.endswith(".csv"):
continue
logger.info(f"Loading csv dataset {name} {opts}")
datasets.append(load_csv(name))
return datasets
async def prepare_prompts_unified(configs: list) -> list[ProbeDataset]:
"""Prepare datasets using unified loader configuration.
This is an alternative to prepare_prompts() that uses the UnifiedDatasetLoader
for streamlined configuration and merging of multiple sources.
Args:
configs: List of InputSourceConfig objects or dicts
Returns:
list[ProbeDataset]: List containing the merged dataset
Example:
>>> from agentic_security.probe_data.unified_loader import InputSourceConfig
>>> configs = [
... InputSourceConfig(
... source_type="huggingface",
... dataset_name="deepset/prompt-injections",
... enabled=True,
... weight=1.0
... )
... ]
>>> datasets = await prepare_prompts_unified(configs)
"""
from agentic_security.probe_data.unified_loader import (
UnifiedDatasetLoader,
InputSourceConfig,
# Check if 'prompt' column exists
if "prompt" in df.columns:
prompt_list.extend(df["prompt"].tolist())
else:
logger.warning(f"File {file} does not contain a 'prompt' column")
return ProbeDataset(
dataset_name="Local CSV",
metadata={"src": str(csv_files)},
prompts=prompt_list,
tokens=count_words_in_list(prompt_list),
approx_cost=0.0,
)
# Convert dicts to InputSourceConfig if needed
config_objects = []
for config in configs:
if isinstance(config, dict):
config_objects.append(InputSourceConfig(**config))
else:
config_objects.append(config)
loader = UnifiedDatasetLoader(config_objects)
merged_dataset = await loader.load_all()
def dataset_from_iterator(name: str, iterator, lazy=False) -> list:
"""Convert an iterator into a list of prompts and create a ProbeDataset
object.
# Return as list for compatibility with existing code
return [merged_dataset] if merged_dataset.prompts else []
Args:
name (str): The name of the dataset.
iterator (iterator): An iterator that generates prompts.
Returns:
list: A list containing a single ProbeDataset object.
"""
prompts = list(iterator) if not lazy else iterator
tokens = count_words_in_list(prompts) if not lazy else 0
dataset = ProbeDataset(
dataset_name=name,
metadata={},
prompts=prompts,
tokens=tokens,
approx_cost=0.0,
lazy=lazy,
)
return [dataset]
# TODO: refactor this abstraction
dataset_map_generics = {
"simonycl/aya-23-8B_advbench_jailbreak": lambda: load_dataset_general(
dataset_name="simonycl/aya-23-8B_advbench_jailbreak",
column_mappings={"prompt": "prompt"},
),
"acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b": lambda: load_dataset_general(
dataset_name="acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b"
),
"karanxa/dolphin-jailbreak-finetuning-dataset": lambda: load_dataset_general(
dataset_name="karanxa/dolphin-jailbreak-finetuning-dataset",
column_mappings={"prompt": "text"},
),
"karanxa/llama-2-jailbreak-dataset": lambda: load_dataset_general(
dataset_name="karanxa/llama-2-jailbreak-dataset",
column_mappings={"prompt": "text"},
),
"karanxa/llama2-uncensored-jailbreak-dataset-finetuning": lambda: load_dataset_general(
dataset_name="karanxa/llama2-uncensored-jailbreak-dataset-finetuning",
column_mappings={"prompt": "text"},
),
"liuyanchen1015/Llama-3.2-1B_jailbreak_responses": lambda: load_dataset_general(
dataset_name="liuyanchen1015/Llama-3.2-1B_jailbreak_responses",
column_mappings={"prompt": "jailbreak_prompt_text"},
),
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses": lambda: load_dataset_general(
dataset_name="liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses",
column_mappings={"prompt": "jailbreak_prompt_text"},
),
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment": lambda: load_dataset_general(
dataset_name="liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment",
column_mappings={"prompt": "jailbreak_prompt_text"},
),
"jackhhao/jailbreak-classification": lambda: load_dataset_general(
dataset_name="jackhhao/jailbreak-classification",
column_mappings={"prompt": "prompt"},
),
"markush1/LLM-Jailbreak-Classifier": lambda: load_dataset_general(
dataset_name="markush1/LLM-Jailbreak-Classifier",
column_mappings={"prompt": "prompt"},
),
"walledai/JailbreakBench": lambda: load_dataset_general(
dataset_name="walledai/JailbreakBench", column_mappings={"prompt": "prompt"}
),
"walledai/JailbreakHub": lambda: load_dataset_general(
dataset_name="walledai/JailbreakHub", column_mappings={"prompt": "prompt"}
),
"Granther/evil-jailbreak": lambda: load_dataset_general(
dataset_name="Granther/evil-jailbreak", column_mappings={"prompt": "text"}
),
"sevdeawesome/jailbreak_success": lambda: load_dataset_general(
dataset_name="sevdeawesome/jailbreak_success",
column_mappings={"prompt": "jailbreak_prompt_text"},
),
"IDA-SERICS/Disaster-tweet-jailbreaking": lambda: load_dataset_general(
dataset_name="IDA-SERICS/Disaster-tweet-jailbreaking",
column_mappings={"prompt": "prompt_attack"},
),
"GeorgeDaDude/Jailbreak_Complete_DS_labeled": lambda: load_dataset_general(
dataset_name="GeorgeDaDude/Jailbreak_Complete_DS_labeled",
column_mappings={"prompt": "question"},
),
"dayone3nder/jailbreak_prompt_JBB_sft_trainset": lambda: load_dataset_general(
dataset_name="dayone3nder/jailbreak_prompt_JBB_sft_trainset",
column_mappings={"prompt": "prompt"},
),
"dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset": lambda: load_dataset_general(
dataset_name="dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset",
column_mappings={"prompt": "prompt"},
),
}
@@ -3,8 +3,7 @@ import io
import httpx
import pandas as pd
from agentic_security.logutils import logger
from loguru import logger
url = "https://raw.githubusercontent.com/tml-epfl/llm-adaptive-attacks/main/harmful_behaviors/harmful_behaviors_pair.csv"
@@ -3,8 +3,7 @@ import os
import uuid as U
import httpx
from agentic_security.logutils import logger
from loguru import logger
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
@@ -66,7 +65,7 @@ class Module:
return {}
async def fetch_prompts(self) -> list[str]:
api_url = "https://mcp.metaheuristic.co/infer"
api_url = "https://edge.metaheuristic.co/infer"
headers = {
"Authorization": f"Bearer {AUTH_TOKEN}",
"Content-Type": "application/json",
@@ -4,7 +4,7 @@ import json
import os
import subprocess
from agentic_security.logutils import logger
from loguru import logger
# TODO: add probes modules
@@ -2,7 +2,7 @@ import asyncio
import importlib.util
import os
from agentic_security.logutils import logger
from loguru import logger
inspect_ai_task = (
__file__.replace("inspect_ai_tool.py", "inspect_ai_task.py")
@@ -8,8 +8,7 @@ from typing import Deque
import numpy as np
import requests
from agentic_security.logutils import logger
from loguru import logger
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
@@ -20,10 +19,12 @@ class PromptSelectionInterface(ABC):
@abstractmethod
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
"""Selects the next prompt based on current state and guard result."""
pass
@abstractmethod
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
"""Selects the next prompts based on current state and guard result."""
pass
@abstractmethod
def update_rewards(
@@ -34,6 +35,7 @@ class PromptSelectionInterface(ABC):
passed_guard: bool,
) -> None:
"""Updates internal rewards based on the outcome of the last selected prompt."""
pass
class RandomPromptSelector(PromptSelectionInterface):
@@ -118,7 +120,8 @@ class CloudRLPromptSelector(PromptSelectionInterface):
current_prompt: str,
reward: float,
passed_guard: bool,
) -> None: ...
) -> None:
...
class QLearningPromptSelector(PromptSelectionInterface):
@@ -203,11 +206,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
class Module:
def __init__(
self,
prompt_groups: list[str],
tools_inbox: asyncio.Queue,
opts: dict = {},
rl_model: PromptSelectionInterface | None = None,
self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {}
):
self.tools_inbox = tools_inbox
self.opts = opts
@@ -215,8 +214,8 @@ class Module:
self.max_prompts = self.opts.get("max_prompts", 10) # Default max M prompts
self.run_id = U.uuid4().hex
self.batch_size = self.opts.get("batch_size", 500)
self.rl_model = rl_model or CloudRLPromptSelector(
prompt_groups, "https://mcp.metaheuristic.co", run_id=self.run_id
self.rl_model = CloudRLPromptSelector(
prompt_groups, "https://edge.metaheuristic.co", run_id=self.run_id
)
async def apply(self):
@@ -33,19 +33,11 @@ def mock_requests() -> Mock:
@pytest.fixture
def mock_rl_selector(dataset_prompts) -> Mock:
class StubSelector:
def __init__(self, prompts: list[str]):
self.prompts = prompts
self.idx = 0
def select_next_prompts(
self, current_prompt: str, passed_guard: bool
) -> list[str]:
self.idx = (self.idx + 1) % len(self.prompts)
return [self.prompts[self.idx]]
return StubSelector(dataset_prompts)
def mock_rl_selector() -> Mock:
return CloudRLPromptSelector(
dataset_prompts,
api_url="https://edge.metaheuristic.co",
)
@pytest.fixture
@@ -99,18 +91,15 @@ class TestCloudRLPromptSelector:
next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True)
assert next_prompt in dataset_prompts
def test_select_next_prompt_success_service(self, dataset_prompts, mock_requests):
mock_requests.return_value.status_code = 200
mock_requests.return_value.json.return_value = {"next_prompts": ["What is AI?"]}
def test_select_next_prompt_success_service(self, dataset_prompts):
selector = CloudRLPromptSelector(
dataset_prompts,
api_url="https://mcp.metaheuristic.co",
api_url="https://edge.metaheuristic.co",
)
next_prompt = selector.select_next_prompt(
"How does RL work?", passed_guard=True
)
assert next_prompt == "What is AI?"
assert next_prompt
# Tests for QLearningPromptSelector
@@ -199,7 +188,7 @@ class TestModule:
async def test_apply_basic_flow(
self, dataset_prompts, tools_inbox, mock_rl_selector
):
module = Module(dataset_prompts, tools_inbox, rl_model=mock_rl_selector)
module = Module(dataset_prompts, tools_inbox)
count = 0
async for prompt in module.apply():
@@ -209,9 +198,7 @@ class TestModule:
break
@pytest.mark.asyncio
async def test_apply_rl_with_tools_inbox(
self, dataset_prompts, tools_inbox, mock_rl_selector
):
async def test_apply_rl_with_tools_inbox(self, dataset_prompts, tools_inbox):
# Add a test message to the tools inbox
test_message = {
"message": "Test message",
@@ -220,7 +207,7 @@ class TestModule:
}
await tools_inbox.put(test_message)
module = Module(dataset_prompts, tools_inbox, rl_model=mock_rl_selector)
module = Module(dataset_prompts, tools_inbox)
async for output in module.apply():
if output == "Test message":
+2 -2
View File
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from cache_to_disk import cache_to_disk # noqa
from cache_to_disk import cache_to_disk
# TODO: refactor this class to use from .data
@@ -22,7 +22,7 @@ class ProbeDataset:
}
# @cache_to_disk(n_days_to_cache=1)
@cache_to_disk()
def load_dataset_generic(name, getter=lambda x: x["train"]["prompt"]):
from datasets import load_dataset
@@ -1,252 +0,0 @@
"""Unified dataset loader for CSV, HuggingFace, and proxy sources."""
from typing import Literal
from pydantic import BaseModel, Field
from agentic_security.logutils import logger
from agentic_security.probe_data.data import (
load_dataset_generic,
load_csv,
create_probe_dataset,
)
from agentic_security.probe_data.models import ProbeDataset
class InputSourceConfig(BaseModel):
"""Configuration for a single input source."""
source_type: Literal["csv", "huggingface", "proxy"] = Field(
description="Type of input source"
)
enabled: bool = Field(default=True, description="Whether this source is enabled")
dataset_name: str = Field(description="Name/identifier of the dataset")
weight: float = Field(
default=1.0, ge=0.0, description="Sampling weight for merging"
)
# CSV-specific fields
path: str | None = Field(default=None, description="File path for CSV sources")
prompt_column: str | None = Field(
default="prompt", description="Column name containing prompts"
)
# HuggingFace-specific fields
split: str | None = Field(
default="train", description="Dataset split to load (train/test/validation)"
)
max_samples: int | None = Field(
default=None, ge=1, description="Maximum number of samples to load"
)
# URL for custom sources
url: str | None = Field(default=None, description="URL for remote CSV files")
class UnifiedDatasetLoader:
"""Loads and merges datasets from multiple sources."""
def __init__(self, configs: list[InputSourceConfig]):
"""Initialize with list of input source configurations.
Args:
configs: List of InputSourceConfig objects defining data sources
"""
self.configs = configs
logger.info(f"Initialized UnifiedDatasetLoader with {len(configs)} sources")
async def load_all(self) -> ProbeDataset:
"""Load all enabled sources and merge into a single dataset.
Returns:
ProbeDataset: Merged dataset from all enabled sources
"""
datasets = []
for config in self.configs:
if not config.enabled:
logger.debug(f"Skipping disabled source: {config.dataset_name}")
continue
try:
dataset = await self._load_single(config)
if dataset and dataset.prompts:
datasets.append((dataset, config.weight))
logger.info(
f"Loaded {len(dataset.prompts)} prompts from {config.dataset_name} "
f"(weight={config.weight})"
)
else:
logger.warning(f"No prompts loaded from {config.dataset_name}")
except Exception as e:
logger.error(f"Error loading {config.dataset_name}: {e}")
if not datasets:
logger.warning("No datasets loaded successfully")
return create_probe_dataset("unified_empty", [], {"sources": []})
return self._merge_weighted(datasets)
async def _load_single(self, config: InputSourceConfig) -> ProbeDataset:
"""Load a single dataset based on its configuration.
Args:
config: Configuration for the source to load
Returns:
ProbeDataset: Loaded dataset
"""
if config.source_type == "csv":
return self._load_csv_source(config)
elif config.source_type == "huggingface":
return self._load_huggingface_source(config)
elif config.source_type == "proxy":
return self._load_proxy_source(config)
else:
raise ValueError(f"Unknown source type: {config.source_type}")
def _load_csv_source(self, config: InputSourceConfig) -> ProbeDataset:
"""Load dataset from CSV file.
Args:
config: CSV source configuration
Returns:
ProbeDataset: Dataset loaded from CSV
"""
if config.path:
# Local CSV file
logger.info(f"Loading CSV from path: {config.path}")
dataset = load_csv(config.path)
elif config.url:
# Remote CSV file
logger.info(f"Loading CSV from URL: {config.url}")
mappings = (
{config.prompt_column: "prompt"} if config.prompt_column else None
)
dataset = load_dataset_generic(
name=config.dataset_name,
url=config.url,
mappings=mappings,
metadata={"source_type": "csv", "url": config.url},
)
else:
raise ValueError(
f"CSV source {config.dataset_name} requires either path or url"
)
# Apply max_samples limit if specified
if config.max_samples and len(dataset.prompts) > config.max_samples:
logger.info(
f"Limiting {config.dataset_name} from {len(dataset.prompts)} "
f"to {config.max_samples} samples"
)
dataset.prompts = dataset.prompts[: config.max_samples]
return dataset
def _load_huggingface_source(self, config: InputSourceConfig) -> ProbeDataset:
"""Load dataset from HuggingFace.
Args:
config: HuggingFace source configuration
Returns:
ProbeDataset: Dataset loaded from HuggingFace
"""
logger.info(
f"Loading HuggingFace dataset: {config.dataset_name} "
f"(split={config.split})"
)
# Build column mappings
mappings = None
if config.prompt_column and config.prompt_column != "prompt":
mappings = {config.prompt_column: "prompt"}
dataset = load_dataset_generic(
name=config.dataset_name,
mappings=mappings,
metadata={
"source_type": "huggingface",
"split": config.split,
},
)
# Apply max_samples limit if specified
if config.max_samples and len(dataset.prompts) > config.max_samples:
logger.info(
f"Limiting {config.dataset_name} from {len(dataset.prompts)} "
f"to {config.max_samples} samples"
)
dataset.prompts = dataset.prompts[: config.max_samples]
return dataset
def _load_proxy_source(self, config: InputSourceConfig) -> ProbeDataset:
"""Load dataset from proxy queue (placeholder for PoC).
Args:
config: Proxy source configuration
Returns:
ProbeDataset: Empty dataset (proxy integration not implemented in PoC)
"""
logger.warning(
f"Proxy source {config.dataset_name} not implemented in PoC - returning empty dataset"
)
return create_probe_dataset(
config.dataset_name,
[],
{"source_type": "proxy", "status": "not_implemented"},
)
def _merge_weighted(
self, datasets: list[tuple[ProbeDataset, float]]
) -> ProbeDataset:
"""Merge multiple datasets with weighted sampling.
For PoC, this implements simple concatenation with optional weighting.
Production version would implement proper stratified sampling.
Args:
datasets: List of (ProbeDataset, weight) tuples
Returns:
ProbeDataset: Merged dataset
"""
if not datasets:
return create_probe_dataset("unified_empty", [], {"sources": []})
# For PoC: simple concatenation, repeat prompts based on weight
all_prompts = []
source_names = []
total_tokens = 0
for dataset, weight in datasets:
source_names.append(dataset.dataset_name)
# Calculate how many times to include this dataset based on weight
# Weight of 1.0 = include once, 2.0 = include twice, etc.
repeat_count = max(1, int(weight))
for _ in range(repeat_count):
all_prompts.extend(dataset.prompts)
total_tokens += dataset.tokens * repeat_count
logger.info(
f"Merged {len(datasets)} datasets into {len(all_prompts)} total prompts "
f"from sources: {source_names}"
)
return ProbeDataset(
dataset_name="unified",
metadata={
"sources": source_names,
"source_count": len(datasets),
"weights": {ds.dataset_name: w for ds, w in datasets},
},
prompts=all_prompts,
tokens=total_tokens,
approx_cost=0.0,
)
+18 -24
View File
@@ -1,10 +1,8 @@
import importlib.resources as pkg_resources
import os
import warnings
import joblib
import pandas as pd
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler
from sklearn.svm import OneClassSVM
@@ -72,31 +70,27 @@ class RefusalClassifier:
"""
Load the trained model, vectorizer, and scaler from disk.
"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
try:
self.model = joblib.load(self.model_path)
self.vectorizer = joblib.load(self.vectorizer_path)
self.scaler = joblib.load(self.scaler_path)
except FileNotFoundError:
# Load from package resources
package = (
__package__ # This should be 'agentic_security.refusal_classifier'
)
try:
self.model = joblib.load(self.model_path)
self.vectorizer = joblib.load(self.vectorizer_path)
self.scaler = joblib.load(self.scaler_path)
except FileNotFoundError:
# Load from package resources
package = (
__package__ # This should be 'agentic_security.refusal_classifier'
)
# Load model
with pkg_resources.open_binary(
package, "oneclass_svm_model.joblib"
) as f:
self.model = joblib.load(f)
# Load model
with pkg_resources.open_binary(package, "oneclass_svm_model.joblib") as f:
self.model = joblib.load(f)
# Load vectorizer
with pkg_resources.open_binary(package, "tfidf_vectorizer.joblib") as f:
self.vectorizer = joblib.load(f)
# Load vectorizer
with pkg_resources.open_binary(package, "tfidf_vectorizer.joblib") as f:
self.vectorizer = joblib.load(f)
# Load scaler
with pkg_resources.open_binary(package, "scaler.joblib") as f:
self.scaler = joblib.load(f)
# Load scaler
with pkg_resources.open_binary(package, "scaler.joblib") as f:
self.scaler = joblib.load(f)
def is_refusal(self, text):
"""
@@ -3,7 +3,7 @@ from unittest.mock import patch
import pandas as pd
import pytest
from agentic_security.refusal_classifier.model import RefusalClassifier
from .model import RefusalClassifier
@pytest.fixture
+11 -78
View File
@@ -7,62 +7,9 @@ import pandas as pd
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize
from agentic_security.logutils import logger
from .primitives import Table
def plot_security_report(table: Table) -> io.BytesIO:
"""
Generates a polar plot representing the security report based on the given data.
Args:
table (Table): The input data table containing security metrics.
Returns:
io.BytesIO: A buffer containing the generated plot image in PNG format.
Returns an empty buffer in case of an error.
"""
try:
return _plot_security_report(table=table)
except (TypeError, ValueError, OverflowError, IndexError, Exception) as e:
logger.error(f"Error in generating the security report: {e} {table}")
return io.BytesIO()
def generate_identifiers(data: pd.DataFrame) -> list[str]:
"""
Generates unique identifiers for the given dataset.
Args:
data (pd.DataFrame): A pandas DataFrame containing security-related data.
Returns:
list[str]: A list of generated identifiers. Returns a list with an empty string in case of an error.
"""
return _generate_identifiers(data=data)
def _plot_security_report(table: Table) -> io.BytesIO:
"""
Generates a polar plot-based security report visualizing the failure rates for different modules.
This function processes the input data, sorts it by failure rate, and generates a polar plot
where each bar represents the failure rate for a specific module. The plot includes identifiers,
color-coding based on token count, failure rate values on the bars, and a table listing the modules
and their corresponding failure rates.
Args:
table (Table): A table-like structure (e.g., pandas DataFrame) containing security report data
with columns for failure rate, tokens, and modules.
Returns:
io.BytesIO: A buffer containing the generated plot image in PNG format.
"""
return io.BytesIO()
def plot_security_report(table):
# Data preprocessing
logger.info("Data preprocessing started.")
data = pd.DataFrame(table)
# Sort by failure rate and reset index
@@ -73,10 +20,10 @@ def _plot_security_report(table: Table) -> io.BytesIO:
fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={"projection": "polar"})
fig.set_facecolor("#f0f0f0")
ax.set_facecolor("#f0f0f0")
logger.info("Plot setup complete.")
# Styling parameters
colors = ["#6C5B7B", "#C06C84", "#F67280", "#F8B195"][::-1] # Pastel palette
# colors = ["#440154", "#3b528b", "#21908c", "#5dc863"] # Viridis-inspired palette
cmap = LinearSegmentedColormap.from_list("custom", colors, N=256)
norm = Normalize(vmin=data["tokens"].min(), vmax=data["tokens"].max())
@@ -127,10 +74,7 @@ def _plot_security_report(table: Table) -> io.BytesIO:
# Title and caption
fig.suptitle(
"Security Report for Different Modules",
fontsize=16,
fontweight="bold",
y=1.02,
"Security Report for Different Modules", fontsize=16, fontweight="bold", y=1.02
)
caption = "Report generated by https://github.com/msoedov/agentic_security"
fig.text(
@@ -168,12 +112,17 @@ def _plot_security_report(table: Table) -> io.BytesIO:
data["identifier"], data["failureRate"], data["module"]
)
]
table = ax.table(cellText=table_data, loc="right", cellLoc="left")
table = ax.table(
cellText=table_data,
loc="right",
cellLoc="left",
)
table.auto_set_font_size(False)
table.set_fontsize(8)
# Adjust table style
table.scale(1, 0.7)
for (row, col), cell in table.get_celld().items():
cell.set_edgecolor("none")
cell.set_facecolor("#f0f0f0" if row % 2 == 0 else "#e0e0e0")
@@ -183,33 +132,17 @@ def _plot_security_report(table: Table) -> io.BytesIO:
cell.set_text_props(fontweight="bold")
# Adjust layout and save
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png", dpi=300, bbox_inches="tight")
plt.close(fig)
buf.seek(0)
logger.info("Report successfully generated and saved to buffer.")
return buf
def _generate_identifiers(data: pd.DataFrame) -> list[str]:
"""
Generates a list of unique identifiers for each row in the given DataFrame.
The identifiers are based on the English alphabet, with each identifier consisting
of a letter followed by a number. The letter represents the "group" of identifiers
(using a letter from A to Z) and the number is a counter within that group. For example:
- A1, A2, ..., A26, B1, B2, ..., Z1, Z2, ...
Args:
data (pd.DataFrame): The input DataFrame containing data for which identifiers
are to be generated.
Returns:
list[str]: A list of unique identifiers as strings, one for each row in the DataFrame.
"""
def generate_identifiers(data):
data_length = len(data)
alphabet = string.ascii_uppercase
num_letters = len(alphabet)
-397
View File
@@ -1,397 +0,0 @@
_SPECS = [
"""POST ${SELF_URL}/v1/self-probe
Authorization: Bearer XXXXX
Content-Type: application/json
{
"prompt": "<<PROMPT>>"
}
""",
"""POST https://api.openai.com/v1/chat/completions
Authorization: Bearer $OPENAI_API_KEY
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
"temperature": 0.7
}
""",
"""
POST https://api.deepseek.com/chat/completions
Authorization: Bearer $DEEPSEEK_API_KEY
Content-Type: application/json
{
"model": "deepseek-chat",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "<<PROMPT>>"}
],
"stream": false
}
""",
"""POST https://api.replicate.com/v1/models/mistralai/mixtral-8x7b-instruct-v0.1/predictions
Authorization: Bearer $APIKEY
Content-Type: application/json
{
"input": {
"top_k": 50,
"top_p": 0.9,
"prompt": "Write a bedtime story about neural networks I can read to my toddler",
"temperature": 0.6,
"max_new_tokens": 1024,
"prompt_template": "<s>[INST] <<PROMPT>> [/INST] ",
"presence_penalty": 0,
"frequency_penalty": 0
}
}
""",
"""POST https://api.groq.com/v1/request_manager/text_completion
Authorization: Bearer $APIKEY
Content-Type: application/json
{
"model_id": "codellama-34b",
"system_prompt": "You are helpful and concise coding assistant",
"user_prompt": "<<PROMPT>>"
}
""",
"""POST https://api.together.xyz/v1/chat/completions
Authorization: Bearer $TOGETHER_API_KEY
Content-Type: application/json
{
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"messages": [
{"role": "system", "content": "You are an expert travel guide"},
{"role": "user", "content": "<<PROMPT>>"}
]
}
""",
"""POST ${SELF_URL}/v1/self-probe-image
Authorization: Bearer XXXXX
Content-Type: application/json
[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{<<BASE64_IMAGE>>}"
},
},
],
}
]
""",
"""POST ${SELF_URL}/v1/self-probe-file
Authorization: Bearer $API_KEY
Content-Type: multipart/form-data
{
"file": "@./sample_audio.m4a",
"model": "whisper-large-v3"
}
""",
"""POST https://api.gemini.com/v1/generate
Authorization: Bearer $GEMINI_API_KEY
Content-Type: application/json
{
"model": "gemini-latest",
"prompt": "<<PROMPT>>",
"temperature": 0.8,
"max_tokens": 150,
"top_p": 1.0,
"frequency_penalty": 0,
"presence_penalty": 0
}
""",
"""POST https://api.anthropic.com/v1/complete
Authorization: Bearer $ANTHROPIC_API_KEY
Content-Type: application/json
{
"model": "claude-v1.3",
"prompt": "<<PROMPT>>",
"temperature": 0.7,
"max_tokens_to_sample": 256,
"stop_sequences": ["\n\nHuman:"]
}
""",
"""POST https://api.cohere.ai/generate
Authorization: Bearer $COHERE_API_KEY
Content-Type: application/json
{
"model": "command-xlarge-nightly",
"prompt": "<<PROMPT>>",
"max_tokens": 300,
"temperature": 0.75,
"k": 0,
"p": 0.75
}
""",
"""POST https://<<RESOURCE_NAME>>.openai.azure.com/openai/deployments/<<DEPLOYMENT_NAME>>/completions?api-version=2023-06-01-preview
Authorization: Bearer $AZURE_API_KEY
Content-Type: application/json
{
"prompt": "<<PROMPT>>",
"max_tokens": 150,
"temperature": 0.7,
"top_p": 0.9,
"frequency_penalty": 0,
"presence_penalty": 0
}
""",
"""POST https://api.assemblyai.com/v2/transcript
Authorization: Bearer $ASSEMBLY_API_KEY
Content-Type: application/json
{
"audio_url": "<<AUDIO_FILE_URL>>"
}
""",
"""POST https://api.openrouter.ai/v1/chat/completions
Authorization: Bearer $OPENROUTER_API_KEY
Content-Type: application/json
{
"model": "openrouter-latest",
"prompt": "<<PROMPT>>",
"temperature": 0.7,
"max_tokens": 150,
"top_p": 0.9,
"frequency_penalty": 0,
"presence_penalty": 0
}
""",
]
LLM_SPECS = [
"""POST ${SELF_URL}/v1/self-probe
Authorization: Bearer XXXXX
Content-Type: application/json
{
"prompt": "<<PROMPT>>"
}
""",
"""POST https://api.openai.com/v1/chat/completions
Authorization: Bearer $OPENAI_API_KEY
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
"temperature": 0.7
}
""",
"""
POST https://api.deepseek.com/chat/completions
Authorization: Bearer $DEEPSEEK_API_KEY
Content-Type: application/json
{
"model": "deepseek-chat",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "<<PROMPT>>"}
],
"stream": false
}
""",
"""POST https://api.replicate.com/v1/models/mistralai/mixtral-8x7b-instruct-v0.1/predictions
Authorization: Bearer $APIKEY
Content-Type: application/json
{
"input": {
"top_k": 50,
"top_p": 0.9,
"prompt": "Write a bedtime story about neural networks I can read to my toddler",
"temperature": 0.6,
"max_new_tokens": 1024,
"prompt_template": "<s>[INST] <<PROMPT>> [/INST] ",
"presence_penalty": 0,
"frequency_penalty": 0
}
}
""",
"""POST https://api.groq.com/v1/request_manager/text_completion
Authorization: Bearer $APIKEY
Content-Type: application/json
{
"model_id": "codellama-34b",
"system_prompt": "You are helpful and concise coding assistant",
"user_prompt": "<<PROMPT>>"
}
""",
"""POST https://api.together.xyz/v1/chat/completions
Authorization: Bearer $TOGETHER_API_KEY
Content-Type: application/json
{
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"messages": [
{"role": "system", "content": "You are an expert travel guide"},
{"role": "user", "content": "<<PROMPT>>"}
]
}
""",
"""POST ${SELF_URL}/v1/self-probe-image
Authorization: Bearer XXXXX
Content-Type: application/json
[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{<<BASE64_IMAGE>>}"
},
},
],
}
]
""",
"""POST ${SELF_URL}/v1/self-probe-file
Authorization: Bearer $API_KEY
Content-Type: multipart/form-data
{
"file": "@./sample_audio.m4a",
"model": "whisper-large-v3"
}
""",
"""POST https://api.gemini.com/v1/generate
Authorization: Bearer $GEMINI_API_KEY
Content-Type: application/json
{
"model": "gemini-latest",
"prompt": "<<PROMPT>>",
"temperature": 0.8,
"max_tokens": 150,
"top_p": 1.0,
"frequency_penalty": 0,
"presence_penalty": 0
}
""",
"""POST https://api.anthropic.com/v1/complete
Authorization: Bearer $ANTHROPIC_API_KEY
Content-Type: application/json
{
"model": "claude-v1.3",
"prompt": "<<PROMPT>>",
"temperature": 0.7,
"max_tokens_to_sample": 256,
"stop_sequences": ["\n\nHuman:"]
}
""",
"""POST https://api.cohere.ai/generate
Authorization: Bearer $COHERE_API_KEY
Content-Type: application/json
{
"model": "command-xlarge-nightly",
"prompt": "<<PROMPT>>",
"max_tokens": 300,
"temperature": 0.75,
"k": 0,
"p": 0.75
}
""",
"""POST https://<<RESOURCE_NAME>>.openai.azure.com/openai/deployments/<<DEPLOYMENT_NAME>>/completions?api-version=2023-06-01-preview
Authorization: Bearer $AZURE_API_KEY
Content-Type: application/json
{
"prompt": "<<PROMPT>>",
"max_tokens": 150,
"temperature": 0.7,
"top_p": 0.9,
"frequency_penalty": 0,
"presence_penalty": 0
}
""",
"""POST https://api.assemblyai.com/v2/transcript
Authorization: Bearer $ASSEMBLY_API_KEY
Content-Type: application/json
{
"audio_url": "<<AUDIO_FILE_URL>>"
}
""",
"""POST https://api.openrouter.ai/v1/chat/completions
Authorization: Bearer $OPENROUTER_API_KEY
Content-Type: application/json
{
"model": "openrouter-latest",
"prompt": "<<PROMPT>>",
"temperature": 0.7,
"max_tokens": 150,
"top_p": 0.9,
"frequency_penalty": 0,
"presence_penalty": 0
}
""",
]
LLM_CONFIGS = [
{
"name": "Custom API",
"prompts": 40000,
"customInstructions": "Requires api spec",
"logo": "/icons/myshell.png",
},
{"name": "Open AI", "prompts": 24000, "logo": "/icons/openai.png"},
{"name": "Deepseek v1", "prompts": 24000, "logo": "/icons/deepseek.png"},
{"name": "Replicate", "prompts": 40000, "logo": "/icons/replicate.png"},
{"name": "Groq", "prompts": 40000, "logo": "/icons/groq.png"},
{"name": "Together.ai", "prompts": 40000, "logo": "/icons/together.png"},
{
"name": "Custom API Image",
"prompts": 40000,
"customInstructions": "Requires api spec",
"modality": "Image",
"logo": "/icons/myshell.png",
},
{
"name": "Custom API Files",
"prompts": 40000,
"customInstructions": "Requires api spec",
"modality": "Files",
"logo": "/icons/myshell.png",
},
{"name": "Gemini", "prompts": 40000, "logo": "/icons/gemini.png"},
{"name": "Claude", "prompts": 40000, "logo": "/icons/claude.png"},
{"name": "Cohere", "prompts": 40000, "logo": "/icons/cohere.png"},
{"name": "Azure OpenAI", "prompts": 40000, "logo": "/icons/azureai.png"},
{"name": "assemblyai", "prompts": 40000, "logo": "/icons/myshell.png"},
{"name": "OpenRouter.ai", "prompts": 40000, "logo": "/icons/openrouter.png"},
]
LLM_SPECS = [dict(spec=spec, **d) for spec, d in zip(_SPECS, LLM_CONFIGS)]
-7
View File
@@ -6,7 +6,6 @@ from fastapi.responses import JSONResponse
from ..primitives import FileProbeResponse, Probe
from ..probe_actor.refusal import REFUSAL_MARKS
from ..probe_data import REGISTRY
from ._specs import LLM_SPECS
router = APIRouter()
@@ -74,12 +73,6 @@ async def data_config():
return [m for m in REGISTRY]
@router.get("/v1/llm-specs", response_model=list)
def get_llm_specs():
"""Returns the LLM API specifications."""
return LLM_SPECS
@router.get("/health")
async def health_check():
"""Health check endpoint."""
+1 -2
View File
@@ -2,8 +2,7 @@ import random
from asyncio import Event
from fastapi import APIRouter
from agentic_security.logutils import logger
from loguru import logger
from ..core.app import get_current_run, get_tools_inbox
from ..primitives import CompletionRequest, Settings
+7 -18
View File
@@ -1,6 +1,4 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any
from fastapi import (
APIRouter,
@@ -13,11 +11,9 @@ from fastapi import (
)
from fastapi.responses import StreamingResponse
from agentic_security.logutils import logger
from ..core.app import get_stop_event, get_tools_inbox, set_current_run
from ..dependencies import InMemorySecrets, get_in_memory_secrets
from ..http_spec import InvalidHTTPSpecError, LLMSpec
from ..http_spec import LLMSpec
from ..primitives import LLMInfo, Scan
from ..probe_actor import fuzzer
@@ -27,16 +23,9 @@ router = APIRouter()
@router.post("/verify")
async def verify(
info: LLMInfo, secrets: InMemorySecrets = Depends(get_in_memory_secrets)
) -> dict[str, int | str | float]:
):
spec = LLMSpec.from_string(info.spec)
try:
r = await spec.verify()
except InvalidHTTPSpecError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=400, detail=str(e))
r = await spec.verify()
if r.status_code >= 400:
raise HTTPException(status_code=r.status_code, detail=r.text)
return dict(
@@ -47,7 +36,7 @@ async def verify(
)
def streaming_response_generator(scan_parameters: Scan) -> Generator[str, Any, None]:
def streaming_response_generator(scan_parameters: Scan):
request_factory = LLMSpec.from_string(scan_parameters.llmSpec)
set_current_run(request_factory)
@@ -68,7 +57,7 @@ async def scan(
scan_parameters: Scan,
background_tasks: BackgroundTasks,
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
) -> StreamingResponse:
):
scan_parameters.with_secrets(secrets)
return StreamingResponse(
streaming_response_generator(scan_parameters), media_type="application/json"
@@ -76,7 +65,7 @@ async def scan(
@router.post("/stop")
async def stop_scan() -> dict[str, str]:
async def stop_scan():
get_stop_event().set()
return {"status": "Scan stopped"}
@@ -90,7 +79,7 @@ async def scan_csv(
maxBudget: int = Query(10_000),
enableMultiStepAttack: bool = Query(False),
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
) -> StreamingResponse:
):
# TODO: content dataset to fuzzer
content = await file.read() # noqa
llm_spec = await llmSpec.read()
+1 -2
View File
@@ -1,8 +1,7 @@
import sentry_sdk
from loguru import logger
from sentry_sdk.integrations.logging import ignore_logger
from agentic_security.logutils import logger
from ..primitives import Settings
@@ -1,12 +1,9 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
import agentic_security.test_spec_assets as test_spec_assets
from agentic_security.routes.scan import router
app = FastAPI()
app.include_router(router)
client = TestClient(app)
client = TestClient(router)
def test_upload_csv_and_run():
@@ -1,6 +1,6 @@
from fastapi.testclient import TestClient
from agentic_security.app import app
from ..app import app
def test_health_check():
@@ -1,15 +1,14 @@
import base64
import io
import random
import httpx
import pytest
from fastapi.testclient import TestClient
from agentic_security.app import app
from agentic_security.primitives import Probe
from agentic_security.probe_actor.refusal import REFUSAL_MARKS
from agentic_security.probe_data import REGISTRY
from ..app import app
from ..primitives import Probe
from ..probe_actor.refusal import REFUSAL_MARKS
from ..probe_data import REGISTRY
client = TestClient(app)
@@ -86,9 +85,8 @@ def test_data_config_endpoint():
def test_refusal_rate():
"""Test that refusal rate is approximately 20%"""
random.seed(0)
refusal_count = 0
total_trials = 200
total_trials = 1000
for _ in range(total_trials):
response = client.post("/v1/self-probe", json={"prompt": "test"})
@@ -2,14 +2,11 @@ from pathlib import Path
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from agentic_security.routes.report import router
from .report import router
app = FastAPI()
app.include_router(router)
client = TestClient(app)
client = TestClient(router)
@pytest.fixture
@@ -1,15 +1,13 @@
from pathlib import Path
import pytest
from fastapi import FastAPI, HTTPException
from fastapi import HTTPException
from fastapi.testclient import TestClient
from agentic_security.primitives import Settings
from agentic_security.routes.static import get_static_file, router
from ..primitives import Settings
from .static import get_static_file, router
app = FastAPI()
app.include_router(router)
client = TestClient(app)
client = TestClient(router)
def test_root_route():
+2 -3
View File
@@ -101,7 +101,7 @@ Content-Type: application/json
]
`,
`POST ${SELF_URL}/v1/self-probe-file
Authorization: Bearer $API_KEY
Authorization: Bearer $GROQ_API_KEY
Content-Type: multipart/form-data
{
@@ -204,9 +204,8 @@ let LLM_CONFIGS = [
{ name: 'Claude', prompts: 40000, logo: '/icons/claude.png' },
{ name: 'Cohere', prompts: 40000, logo: '/icons/cohere.png' },
{ name: 'Azure OpenAI', prompts: 40000, logo: '/icons/azureai.png' },
{ name: 'assemblyai', prompts: 40000, logo: fallbackIcon },
{ name: 'OpenRouter.ai', prompts: 40000, logo: '/icons/openrouter.png' },
{ name: 'assemblyai', prompts: 40000, logo: fallbackIcon },
];
function has_image(spec) {
return spec.includes('<<BASE64_IMAGE>>');
-26
View File
@@ -63,32 +63,6 @@
</div>
</section>
</main>
<div class="fixed top-6 right-6 z-50 space-y-3">
<transition-group name="toast">
<div
v-for="toast in toasts"
:key="toast.id"
class="flex items-center p-3 rounded-xl shadow-xl text-white max-w-md animate-toast-in border border-opacity-30"
:class="{
'bg-success-toast border-accent-green': toast.type === 'success',
'bg-error-toast border-accent-red': toast.type === 'error',
'bg-info-toast border-accent-orange': toast.type === 'info'
}"
>
<span class="flex-1 font-medium tracking-wide text-sm">{{ toast.message }}</span>
<button
@click="removeToast(toast.id)"
class="ml-3 focus:outline-none hover:opacity-80 transition-opacity"
>
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12" />
</svg>
</button>
</div>
</transition-group>
</div>
<main class="max-w-6xl mx-auto space-y-8">
<section class="bg-dark-card rounded-lg p-6 shadow-lg" v-show="false">
<h2 class="text-2xl font-bold mb-4">Select a Config</h2>
+32 -67
View File
@@ -25,8 +25,6 @@ var app = new Vue({
showModules: false,
showLogs: false,
showConsentModal: true,
toasts: [], // Array to store toast notifications
toastTimeout: 3000, // Duration in milliseconds (3 seconds)
statusDotClass: 'bg-gray-500', // Default status dot class
statusText: 'Verified', // Default status text
statusClass: 'bg-green-500 text-dark-bg', // Default status class
@@ -95,36 +93,21 @@ var app = new Vue({
},
methods: {
showToast(message, type = 'success') {
const id = Date.now(); // Unique ID for each toast
this.toasts.push({ id, message, type });
// Automatically remove toast after timeout
setTimeout(() => {
this.removeToast(id);
}, this.toastTimeout);
},
removeToast(id) {
this.toasts = this.toasts.filter(toast => toast.id !== id);
},
focusTextarea() {
this.isFocused = true;
// Remove 'self' assignment if not used elsewhere
self = this.$refs;
this.$nextTick(() => {
this.$refs.textarea.focus();
this.adjustHeight({ target: this.$refs.textarea });
// Focus the textarea after rendering
self.textarea.focus();
this.adjustHeight({ target: self.textarea });
});
// Correct the event listener to use handleOutsideClick
document.addEventListener("mousedown", this.handleOutsideClick);
document.addEventListener("mousedown", this.handleClickOutside);
},
handleOutsideClick(event) {
if (!this.$refs.textarea) {
return
}
if (!this.$refs.textarea.contains(event.target)) {
if (!this.$refs.container.contains(event.target)) {
this.isFocused = false;
document.removeEventListener("mousedown", this.handleOutsideClick);
document.removeEventListener("mousedown", this.handleClickOutside);
}
},
unfocusTextarea() {
@@ -132,12 +115,7 @@ var app = new Vue({
},
acceptConsent() {
this.showConsentModal = false; // Close the modal
try {
localStorage.setItem('consentGiven', 'true'); // Save consent to local storage
} catch (e) {
this.showToast('Failed to save consent', 'error'); // Show error if saving fails
}
localStorage.setItem('consentGiven', 'true'); // Save consent to local storage
},
saveStateToLocalStorage() {
@@ -178,7 +156,6 @@ var app = new Vue({
this.integrationVerified = false;
this.showResetConfirmation = false;
this.enableMultiStepAttack = false;
this.showToast('All settings have been reset to default', 'info');
},
confirmResetState() {
this.showResetConfirmation = true;
@@ -217,39 +194,31 @@ var app = new Vue({
spec: this.modelSpec,
};
let startTime = performance.now(); // Capture start time
try {
const response = await fetch(`${SELF_URL}/verify`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(payload),
});
let r = await response.json();
let endTime = performance.now(); // Capture end time
let latency = ((endTime - startTime) / 1000).toFixed(3); // Calculate latency in milliseconds
this.latency = latency;
if (!response.ok) {
this.updateStatusDot(false);
this.errorMsg = 'Integration verification failed:' + JSON.stringify(r);
this.showToast('Integration verification failed', 'error');
} else {
this.errorMsg = '';
this.updateStatusDot(true);
this.okMsg = 'Integration verified';
this.showToast('Integration verified successfully', 'success');
this.integrationVerified = true;
}
} catch (error) {
const response = await fetch(`${SELF_URL}/verify`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(payload),
});
console.log(response);
let r = await response.json();
let endTime = performance.now(); // Capture end time
let latency = endTime - startTime; // Calculate latency in milliseconds
latency = latency.toFixed(3) / 1000; // Round to 2 decimal places
this.latency = latency;
if (!response.ok) {
this.updateStatusDot(false);
this.errorMsg = 'Integration verification failed:' + JSON.stringify(r);
} else {
this.errorMsg = '';
this.updateStatusDot(true);
this.errorMsg = 'Server unreachable';
this.showToast('Network error', 'error');
}
this.okMsg = 'Integration verified';
this.integrationVerified = true;
// console.log('Integration verified', this.integrationVerified);
// this.$forceUpdate();
}
this.saveStateToLocalStorage();
},
loadConfigs: async function () {
@@ -271,7 +240,6 @@ var app = new Vue({
this.errorMsg = '';
this.okMsg = '';
this.integrationVerified = false;
this.showToast(`Config ${index + 1} selected`, 'info');
},
toggleModules() {
this.showModules = !this.showModules;
@@ -359,7 +327,6 @@ var app = new Vue({
return
}
console.log('New row');
this.showToast('New module', 'success');
let payload = {
table: this.mainTable,
};
@@ -470,8 +437,6 @@ var app = new Vue({
}
});
}
this.scanRunning = false;
this.showToast('Scan finished successfully', 'success');
this.saveStateToLocalStorage();
}
+2 -54
View File
@@ -19,7 +19,7 @@
technopollas: ['Technopollas', 'sans-serif'],
},
colors: {
dark: {
t1: {
bg: '#0D0D0D', // Jet Black
card: '#1A1A1A', // Dark Carbon Fiber
text: '#FFFFFF',
@@ -30,7 +30,7 @@
yellow: '#2E4053', // Jet Black
},
},
dark1: {
dark: {
bg: '#121212',
card: '#1E1E1E',
text: '#FFFFFF',
@@ -86,7 +86,6 @@
}
}
</script>
<style>
.scrollbar-hide::-webkit-scrollbar {
display: none;
@@ -97,55 +96,4 @@
}
</style>
<style>
/* Toast-specific colors */
.bg-success-toast {
background: #1C3F74
}
.bg-error-toast {
background: #85144B
}
.bg-info-toast {
background: #FFC300
}
.border-accent-green, .border-accent-red, .border-accent-orange {
border-color: rgba(255, 255, 255, 0.1); /* Subtle white border for depth */
}
/* Animation */
.animate-toast-in {
animation: tSlideIn 0.4s cubic-bezier(0.25, 0.1, 0.25, 1);
}
@keyframes tSlideIn {
from {
transform: translateX(120%) scale(0.95);
opacity: 0;
}
to {
transform: translateX(0) scale(1);
opacity: 1;
}
}
.toast-enter-active,
.toast-leave-active {
transition: all 0.4s cubic-bezier(0.25, 0.1, 0.25, 1);
}
.toast-enter-from,
.toast-leave-to {
transform: translateX(120%) scale(0.95);
opacity: 0;
}
/* Hover effect */
[toast-type]:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.6);
}
</style>
</head>
@@ -1,7 +1,6 @@
import importlib
import os
import signal
import socket
import subprocess
import tempfile
import time
@@ -9,7 +8,7 @@ import time
import pytest
import agentic_security.test_spec_assets as test_spec_assets
from agentic_security.lib import SecurityScanner
from agentic_security.lib import AgenticSecurity
def has_module(module_name):
@@ -25,29 +24,12 @@ def test_server(request):
preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN),
)
def wait_for_port(host: str, port: int, timeout: float = 5.0) -> bool:
start = time.time()
while time.time() - start < timeout:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(0.2)
try:
sock.connect((host, port))
return True
except OSError:
time.sleep(0.1)
return False
if not wait_for_port("127.0.0.1", 9094):
server.kill()
pytest.skip("Test server failed to start within timeout")
# Give the server time to start
time.sleep(2)
def cleanup():
server.terminate()
try:
server.wait(timeout=3)
except subprocess.TimeoutExpired:
server.kill()
server.wait(timeout=2)
server.wait()
request.addfinalizer(cleanup)
return server
@@ -75,7 +57,7 @@ class TestLibraryLevel:
maxBudget = 1000000
max_th = 0.3
datasets = make_test_registry()
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
assert isinstance(result, dict)
print(result)
assert len(result) in [0, 1]
@@ -86,7 +68,7 @@ class TestLibraryLevel:
maxBudget = 1000
max_th = 0.3
datasets = make_test_registry()
result = SecurityScanner.scan(
result = AgenticSecurity.scan(
llmSpec, maxBudget, datasets, max_th, enableMultiStepAttack=True
)
assert isinstance(result, dict)
@@ -111,7 +93,7 @@ class TestLibraryLevel:
"opts": {"port": 9094},
},
]
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
assert isinstance(result, dict)
print(result)
assert len(result) in [0, 1]
@@ -138,12 +120,11 @@ class TestLibraryLevel:
"modality": "text",
},
]
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
assert isinstance(result, dict)
print(result)
assert len(result) in [0, 1]
@pytest.mark.skip
def test_image_modality(self):
llmSpec = test_spec_assets.IMAGE_SPEC
maxBudget = 2
@@ -167,7 +148,7 @@ class TestLibraryLevel:
"modality": "text",
},
]
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
assert isinstance(result, dict)
print(result)
assert len(result) in [0, 1]
@@ -183,10 +164,10 @@ class TestEntrypointCI:
temp_path = os.path.join(tmpdir, "custom_agesec.toml")
# Override default_path to the temporary path
SecurityScanner.default_path = temp_path
AgenticSecurity.default_path = temp_path
# Generate the default configuration
security = SecurityScanner()
security = AgenticSecurity()
security.generate_default_settings()
# Check that the config file was created at the temporary path
@@ -208,17 +189,17 @@ class TestEntrypointCI:
temp_path = os.path.join(tmpdir, "custom_agesec.toml")
# Override default_path to the temporary path
SecurityScanner.default_path = temp_path
AgenticSecurity.default_path = temp_path
# Generate the default configuration
security = SecurityScanner()
security = AgenticSecurity()
security.generate_default_settings()
# Load the generated configuration
SecurityScanner.load_config(temp_path)
AgenticSecurity.load_config(temp_path)
# Validate loaded configuration
config = SecurityScanner.config
config = AgenticSecurity.config
assert (
config["general"]["maxBudget"] == 1000000
), "maxBudget should be 1000000"
@@ -1,10 +1,6 @@
import pytest
from agentic_security.http_spec import (
InvalidHTTPSpecError,
LLMSpec,
parse_http_spec,
)
from agentic_security.http_spec import LLMSpec, parse_http_spec
class TestParseHttpSpec:
@@ -59,19 +55,6 @@ class TestParseHttpSpec:
assert result.headers == {"Content-Type": "application/json"}
assert result.body == ""
def test_parse_http_spec_rejects_malformed_header(self):
http_spec = "GET http://example.com\nHeaderWithoutColon\n\n"
with pytest.raises(InvalidHTTPSpecError, match="Invalid header line"):
parse_http_spec(http_spec)
def test_parse_http_spec_trims_header_whitespace(self):
http_spec = "GET http://example.com\nAuthorization:Bearer token\n\n"
result = parse_http_spec(http_spec)
assert result.headers == {"Authorization": "Bearer token"}
class TestLLMSpec:
def test_validate_raises_error_for_missing_files(self):
@@ -87,3 +70,49 @@ class TestLLMSpec:
)
with pytest.raises(ValueError, match="An image is required for this request."):
spec.validate(prompt="", encoded_image="", encoded_audio="", files={})
@pytest.mark.asyncio
async def test_probe_sends_request(self, httpx_mock):
httpx_mock.add_response(
method="POST", url="http://example.com", status_code=200
)
spec = LLMSpec(
method="POST",
url="http://example.com",
headers={},
body='{"prompt": "<<PROMPT>>"}',
)
response = await spec.probe(prompt="test")
assert response.status_code == 200
@pytest.mark.asyncio
async def test_probe_with_files(self, httpx_mock):
httpx_mock.add_response(
method="POST", url="http://example.com", status_code=200
)
spec = LLMSpec(
method="POST",
url="http://example.com",
headers={"Content-Type": "multipart/form-data"},
body='{"prompt": "<<PROMPT>>"}',
has_files=True,
)
files = {"file": ("filename.txt", "file content")}
response = await spec.probe(prompt="test", files=files)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_probe_with_image(self, httpx_mock):
httpx_mock.add_response(
method="POST", url="http://example.com", status_code=200
)
spec = LLMSpec(
method="POST",
url="http://example.com",
headers={},
body='{"image": "<<BASE64_IMAGE>>"}',
has_image=True,
)
encoded_image = "base64encodedstring"
response = await spec.probe(prompt="test", encoded_image=encoded_image)
assert response.status_code == 200
+2
View File
@@ -0,0 +1,2 @@
from:python-pytest-poetry
# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/
+17 -13
View File
@@ -33,7 +33,7 @@ The `LLMSpec` class is the core of the HTTP specification. It provides the follo
### Methods
- **`from_string(http_spec: str) -> LLMSpec`**: Parses an HTTP specification string into an `LLMSpec` object.
- **`validate(prompt: str, encoded_image: str, encoded_audio: str, files: dict) -> null`**: Validates the request parameters based on the specified modality.
- **`validate(prompt: str, encoded_image: str, encoded_audio: str, files: dict) -> None`**: Validates the request parameters based on the specified modality.
- **`probe(prompt: str, encoded_image: str = "", encoded_audio: str = "", files: dict = {}) -> httpx.Response`**: Sends an HTTP request using the specified parameters.
- **`verify() -> httpx.Response`**: Verifies the HTTP specification by sending a test request.
@@ -52,11 +52,12 @@ Authorization: Bearer sk-xxxxxxxxx
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
"temperature": 0.7
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
"temperature": 0.7
}
"""
spec = LLMSpec.from_string(http_spec)
response = await spec.probe("What is the capital of France?")
```
@@ -70,11 +71,12 @@ Authorization: Bearer sk-xxxxxxxxx
Content-Type: application/json
{
"model": "gpt-4-vision-preview",
"messages": [{"role": "user", "content": "What is in this image? <<BASE64_IMAGE>>"}],
"temperature": 0.7
"model": "gpt-4-vision-preview",
"messages": [{"role": "user", "content": "What is in this image? <<BASE64_IMAGE>>"}],
"temperature": 0.7
}
"""
spec = LLMSpec.from_string(http_spec)
encoded_image = encode_image_base64_by_url("https://example.com/image.jpg")
response = await spec.probe("What is in this image?", encoded_image=encoded_image)
@@ -89,11 +91,12 @@ Authorization: Bearer sk-xxxxxxxxx
Content-Type: application/json
{
"model": "whisper-large-v3",
"messages": [{"role": "user", "content": "Transcribe this audio: <<BASE64_AUDIO>>"}],
"temperature": 0.7
"model": "whisper-large-v3",
"messages": [{"role": "user", "content": "Transcribe this audio: <<BASE64_AUDIO>>"}],
"temperature": 0.7
}
"""
spec = LLMSpec.from_string(http_spec)
encoded_audio = encode_audio_base64_by_url("https://example.com/audio.mp3")
response = await spec.probe("Transcribe this audio:", encoded_audio=encoded_audio)
@@ -108,11 +111,12 @@ Authorization: Bearer sk-xxxxxxxxx
Content-Type: multipart/form-data
{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Process this file: <<FILE>>"}],
"temperature": 0.7
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Process this file: <<FILE>>"}],
"temperature": 0.7
}
"""
spec = LLMSpec.from_string(http_spec)
files = {"file": ("document.txt", open("document.txt", "rb"))}
response = await spec.probe("Process this file:", files=files)
Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.0 MiB

+1 -1
View File
@@ -21,4 +21,4 @@ Note: Please be aware that Agentic Security is designed as a safety scanner tool
## UI 🧙
<img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1741192668/final_aa9jhb.gif">
<img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1736433557/z0bsyzhsqlgcr3w4ovwp.gif">
+18 -8
View File
@@ -54,15 +54,20 @@ The `probe_data` module is a core component of the Agentic Security project, res
- **Classes:**
- `PromptSelectionInterface`: Abstract base class for prompt selection strategies.
- Methods:
- `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`: Selects next prompt
- `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`: Selects multiple prompts
- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> null`: Updates rewards
- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None`: Updates rewards
- `RandomPromptSelector`: Basic random selection with history tracking.
- Parameters:
- `prompts: list[str]`: List of available prompts
- `history_size: int = 3`: Size of history to prevent cycles
- `CloudRLPromptSelector`: Cloud-based RL implementation with fallback.
- Parameters:
- `prompts: list[str]`: List of available prompts
- `api_url: str`: URL of RL service
@@ -70,7 +75,9 @@ The `probe_data` module is a core component of the Agentic Security project, res
- `history_size: int = 300`: Size of history
- `timeout: int = 5`: Request timeout
- `run_id: str = ""`: Unique run identifier
- `QLearningPromptSelector`: Local Q-learning implementation.
- Parameters:
- `prompts: list[str]`: List of available prompts
- `learning_rate: float = 0.1`: Learning rate
@@ -79,11 +86,13 @@ The `probe_data` module is a core component of the Agentic Security project, res
- `exploration_decay: float = 0.995`: Exploration decay rate
- `min_exploration: float = 0.01`: Minimum exploration rate
- `history_size: int = 300`: Size of history
- **Module**: Main class that uses CloudRLPromptSelector.
- Parameters:
- `prompt_groups: list[str]`: Groups of prompts
- `tools_inbox: asyncio.Queue`: Queue for tool communication
- `opts: dict = {}`: Configuration options
- `Module`: Main class that uses CloudRLPromptSelector.
- Parameters:
- `prompt_groups: list[str]`: Groups of prompts
- `tools_inbox: asyncio.Queue`: Queue for tool communication
- `opts: dict = {}`: Configuration options
## Usage Examples
@@ -110,9 +119,10 @@ from agentic_security.probe_data.modules.rl_model import QLearningPromptSelector
prompts = ["What is AI?", "Explain machine learning"]
selector = QLearningPromptSelector(prompts)
current_prompt = "What is AI?"
next_prompt = selector.select_next_prompt(current_prompt, passed_guard=true)
selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=true)
next_prompt = selector.select_next_prompt(current_prompt, passed_guard=True)
selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=True)
```
## Conclusion
+9 -13
View File
@@ -1,16 +1,14 @@
:root {
--md-primary-fg-color: #2E4053;
/* Primary color changed to pinkish */
--md-primary-fg-color--light: #E0A3B6;
--md-primary-fg-color--dark: #1C3F74;
/* Dark variant changed to blue */
--md-primary-fg-color: #e92063;
--md-primary-fg-color--light: #e92063;
--md-primary-fg-color--dark: #e92063;
}
/* Updated slate color scheme with new background */
/* Revert hue value to that of pre mkdocs-material v9.4.0 */
[data-md-color-scheme="slate"] {
--md-hue: 230;
--md-default-bg-color: #1A1A1A;
/* Background changed to dark gray */
--md-default-bg-color: hsla(230, 15%, 21%, 1);
}
.hide {
@@ -26,15 +24,12 @@ img.index-header {
max-width: 500px;
}
/* Updated custom colors */
.pydantic-pink {
color: #E0A3B6;
/* Updated to match new theme */
color: #FF007F;
}
.team-blue {
color: #1C3F74;
/* Updated to match new theme */
color: #0072CE;
}
.secure-green {
@@ -72,6 +67,7 @@ img.index-header {
text-align: center;
}
/* Hide the entire footer */
.md-footer {
display: none;
+1 -1
View File
@@ -89,7 +89,7 @@ theme:
name: Switch to light mode
icon:
repo: fontawesome/brands/github
favicon: https://res.cloudinary.com/dq0w2rtm9/image/upload/v1741195421/favicon_kuz6xr.png
favicon: "https://res.cloudinary.com/dq0w2rtm9/image/upload/v1737555066/r17hrkre246doczwmvbv.png"
extra:
generator: false
Generated
+2250 -2895
View File
File diff suppressed because it is too large Load Diff
+29 -33
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "agentic_security"
version = "0.7.4"
version = "0.5.1"
description = "Agentic LLM vulnerability scanner"
authors = ["Alexander Miasoiedov <msoedov@gmail.com>"]
maintainers = ["Alexander Miasoiedov <msoedov@gmail.com>"]
@@ -28,49 +28,52 @@ agentic_security = "agentic_security.__main__:main"
[tool.poetry.dependencies]
python = "^3.11"
fastapi = "^0.122.0"
uvicorn = "^0.38.0"
fire = "0.7.1"
fastapi = "^0.115.8"
uvicorn = "^0.34.0"
fire = "0.7.0"
loguru = "^0.7.3"
httpx = "^0.28.1"
cache-to-disk = "^2.0.0"
pandas = ">=1.4,<3.0"
datasets = "^4.4.1"
datasets = "^3.3.0"
tabulate = ">=0.8.9,<0.10.0"
colorama = "^0.4.4"
matplotlib = "^3.10.7"
pydantic = "^2.12.5"
matplotlib = "^3.9.2"
pydantic = "2.10.6"
scikit-optimize = "^0.10.2"
scikit-learn = "^1.7.2"
scikit-learn = "1.6.1"
numpy = ">=1.24.3,<3.0.0"
jinja2 = "^3.1.4"
python-multipart = "^0.0.20"
tomli = "^2.3.0"
rich = "^14.2.0"
tomli = "^2.2.1"
rich = "13.9.4"
gTTS = "^2.5.4"
sentry_sdk = "^2.46.0"
orjson = "^3.11.4"
pyfiglet = "^1.0.4"
termcolor = "^3.2.0"
mcp = "^1.22.0"
sentry_sdk = "^2.22.0"
orjson = "^3.10"
pyfiglet = "^1.0.2"
termcolor = "^2.4.0"
# garak = { version = "*", optional = true }
pytest-xdist = "^3.8.0"
[tool.poetry.group.dev.dependencies]
# Pytest
pytest = "^9.0.1"
pytest-asyncio = "^1.3.0"
inline-snapshot = "^0.31.1"
pytest-mock = "^3.15.1"
pytest = "^8.3.4"
pytest-asyncio = "^0.25.2"
inline-snapshot = ">=0.13.3,<0.21.0"
pytest-httpx = "^0.35.0"
pytest-mock = "^3.14.0"
# Rest
black = ">=24.10,<26.0"
mypy = "^1.19.0"
pre-commit = "^4.5.0"
huggingface-hub = "^1.1.6"
mypy = "^1.12.0"
pre-commit = "^4.0.1"
huggingface-hub = ">=0.25.1,<0.29.0"
# Docs
mkdocs = ">=1.4.2"
mkdocs-material = "^9.7.0"
mkdocstrings = "^1.0.0"
mkdocs-material = "^9.6.4"
mkdocstrings = ">=0.26.1"
mkdocs-jupyter = ">=0.25.1"
@@ -83,14 +86,7 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
addopts = "-m 'not slow'"
# addopts = "--durations=5 -m 'not slow' -n 3"
addopts = "--durations=5 -m 'not slow'"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
markers = "slow: marks tests as slow"
[project]
# MCP requires the following fields to be present in the pyproject.toml file
name = "agentic_security"
version = "1.0.0"
requires-python = ">=3.11"
-35
View File
@@ -1,43 +1,8 @@
import os
import warnings
from pathlib import Path
import pytest
from sklearn.exceptions import InconsistentVersionWarning
from agentic_security.cache_config import ensure_cache_dir
from agentic_security.logutils import logger
CACHE_DIR = ensure_cache_dir(Path(__file__).parent / ".cache_to_disk")
from cache_to_disk import delete_old_disk_caches # noqa: E402 # isort: skip
# Silence noisy third-party warnings that do not impact test behavior
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
try:
from langchain_core._api import LangChainDeprecationWarning
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
except Exception: # pragma: no cover - fallback for older langchain versions
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
module=r"langchain\\.agents",
message=r".*langchain_core.pydantic_v1.*",
)
def pytest_runtest_setup(item):
if "slow" in item.keywords and not os.getenv("RUN_SLOW_TESTS"):
pytest.skip("Skipping slow test")
@pytest.fixture(autouse=True, scope="session")
def setup_delete_old_disk_caches():
logger.info("delete_old_disk_caches at %s", CACHE_DIR)
try:
delete_old_disk_caches()
except PermissionError:
logger.warning("Skipping cache cleanup due to permissions for %s", CACHE_DIR)
except OSError as exc:
logger.warning("Skipping cache cleanup due to OS error: %s", exc)
-1
View File
@@ -1 +0,0 @@
"""Tests for executor package."""
-209
View File
@@ -1,209 +0,0 @@
"""Tests for CircuitBreaker."""
import time
from agentic_security.executor.circuit_breaker import CircuitBreaker
class TestCircuitBreaker:
"""Test CircuitBreaker functionality."""
def test_initialization(self):
"""Test circuit breaker initialization."""
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
assert breaker.failure_threshold == 0.5
assert breaker.recovery_timeout == 30
assert breaker.state == "closed"
assert breaker.failures == 0
assert breaker.successes == 0
def test_record_success(self):
"""Test recording successful requests."""
breaker = CircuitBreaker()
breaker.record_success()
assert breaker.successes == 1
assert breaker.failures == 0
assert breaker.state == "closed"
def test_record_failure(self):
"""Test recording failed requests."""
breaker = CircuitBreaker()
breaker.record_failure()
assert breaker.failures == 1
assert breaker.successes == 0
assert breaker.last_failure_time is not None
def test_circuit_opens_on_failure_threshold(self):
"""Test that circuit opens when failure threshold is exceeded."""
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
# Record 10 requests: 6 failures, 4 successes (60% failure rate)
for _ in range(4):
breaker.record_success()
for _ in range(6):
breaker.record_failure()
# Circuit should be open (60% > 50% threshold)
assert breaker.state == "open"
assert breaker.is_open() is True
def test_circuit_stays_closed_below_threshold(self):
"""Test that circuit stays closed when below threshold."""
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
# Record 10 requests: 4 failures, 6 successes (40% failure rate)
for _ in range(6):
breaker.record_success()
for _ in range(4):
breaker.record_failure()
# Circuit should stay closed (40% < 50% threshold)
assert breaker.state == "closed"
assert breaker.is_open() is False
def test_minimum_sample_size_required(self):
"""Test that minimum sample size is required before opening."""
breaker = CircuitBreaker(failure_threshold=0.5)
# Only 5 failures (below minimum of 10 total requests)
for _ in range(5):
breaker.record_failure()
# Circuit should stay closed (not enough samples)
assert breaker.state == "closed"
assert breaker.is_open() is False
def test_circuit_recovery_after_timeout(self):
"""Test that circuit enters half-open state after recovery timeout."""
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=1)
# Open the circuit
for _ in range(4):
breaker.record_success()
for _ in range(6):
breaker.record_failure()
assert breaker.state == "open"
# Wait for recovery timeout
time.sleep(1.1)
# Check if circuit moves to half-open
is_open = breaker.is_open()
assert is_open is False
assert breaker.state == "half_open"
def test_half_open_to_closed_on_successes(self):
"""Test that circuit closes from half-open after enough successes."""
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=1)
# Open the circuit
for _ in range(4):
breaker.record_success()
for _ in range(6):
breaker.record_failure()
# Wait for recovery
time.sleep(1.1)
breaker.is_open() # Triggers transition to half-open
assert breaker.state == "half_open"
# Record 3 successes
breaker.record_success()
breaker.record_success()
breaker.record_success()
# Should transition to closed
assert breaker.state == "closed"
def test_get_state(self):
"""Test get_state method."""
breaker = CircuitBreaker()
assert breaker.get_state() == "closed"
# Open the circuit
for _ in range(10):
breaker.record_failure()
assert breaker.get_state() == "open"
def test_get_failure_rate(self):
"""Test get_failure_rate method."""
breaker = CircuitBreaker()
# No requests
assert breaker.get_failure_rate() == 0.0
# 3 failures, 7 successes (30% failure rate)
for _ in range(7):
breaker.record_success()
for _ in range(3):
breaker.record_failure()
assert breaker.get_failure_rate() == 0.3
def test_reset(self):
"""Test reset method."""
breaker = CircuitBreaker()
# Record some activity
breaker.record_success()
breaker.record_failure()
for _ in range(10):
breaker.record_failure()
# Reset
breaker.reset()
# Should be back to initial state
assert breaker.state == "closed"
assert breaker.failures == 0
assert breaker.successes == 0
assert breaker.last_failure_time is None
def test_exact_failure_threshold(self):
"""Test behavior at exact failure threshold."""
breaker = CircuitBreaker(failure_threshold=0.5)
# Exactly 50% failure rate (5 failures, 5 successes)
for _ in range(5):
breaker.record_success()
for _ in range(5):
breaker.record_failure()
# Should be open (>= threshold)
assert breaker.state == "open"
def test_high_failure_threshold(self):
"""Test with high failure threshold."""
breaker = CircuitBreaker(failure_threshold=0.9)
# 80% failure rate (8 failures, 2 successes)
for _ in range(2):
breaker.record_success()
for _ in range(8):
breaker.record_failure()
# Should stay closed (80% < 90%)
assert breaker.state == "closed"
def test_zero_recovery_timeout(self):
"""Test with zero recovery timeout."""
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=0)
# Open the circuit
for _ in range(10):
breaker.record_failure()
assert breaker.state == "open"
# Should immediately allow recovery attempt
time.sleep(0.01)
is_open = breaker.is_open()
assert is_open is False
assert breaker.state == "half_open"
-279
View File
@@ -1,279 +0,0 @@
"""Tests for ConcurrentExecutor."""
import pytest
import asyncio
from unittest.mock import Mock, patch
from agentic_security.executor.concurrent import ConcurrentExecutor, ExecutorMetrics
from agentic_security.probe_actor.state import FuzzerState
class TestExecutorMetrics:
"""Test ExecutorMetrics functionality."""
def test_initialization(self):
"""Test metrics initialization."""
metrics = ExecutorMetrics()
assert metrics.successful_requests == 0
assert metrics.failed_requests == 0
assert metrics.total_latency == 0.0
assert len(metrics.latencies) == 0
def test_record_success(self):
"""Test recording successful requests."""
metrics = ExecutorMetrics()
metrics.record_success(0.5)
metrics.record_success(0.3)
assert metrics.successful_requests == 2
assert metrics.total_latency == 0.8
assert len(metrics.latencies) == 2
def test_record_failure(self):
"""Test recording failed requests."""
metrics = ExecutorMetrics()
metrics.record_failure()
metrics.record_failure()
assert metrics.failed_requests == 2
assert metrics.successful_requests == 0
def test_get_stats_no_requests(self):
"""Test get_stats with no requests."""
metrics = ExecutorMetrics()
stats = metrics.get_stats()
assert stats["total_requests"] == 0
assert stats["success_rate"] == 0.0
assert stats["avg_latency_ms"] == 0.0
assert stats["p95_latency_ms"] == 0.0
def test_get_stats_with_requests(self):
"""Test get_stats with recorded requests."""
metrics = ExecutorMetrics()
# Record some requests
metrics.record_success(0.1) # 100ms
metrics.record_success(0.2) # 200ms
metrics.record_success(0.3) # 300ms
metrics.record_failure()
stats = metrics.get_stats()
assert stats["total_requests"] == 4
assert stats["successful_requests"] == 3
assert stats["failed_requests"] == 1
assert stats["success_rate"] == 0.75
assert stats["avg_latency_ms"] == pytest.approx(200.0, rel=0.01)
def test_get_stats_p95_latency(self):
"""Test p95 latency calculation."""
metrics = ExecutorMetrics()
# Add 100 requests with varying latencies
for i in range(100):
metrics.record_success(i * 0.001) # 0ms to 99ms
stats = metrics.get_stats()
# p95 should be around 95ms
assert stats["p95_latency_ms"] >= 90.0
assert stats["p95_latency_ms"] <= 100.0
class TestConcurrentExecutor:
"""Test ConcurrentExecutor functionality."""
def test_initialization(self):
"""Test executor initialization."""
executor = ConcurrentExecutor(
max_concurrent=20,
rate_limit=10,
burst=5,
failure_threshold=0.5,
recovery_timeout=30,
)
assert executor.semaphore._value == 20
assert executor.rate_limiter.rate == 10
assert executor.rate_limiter.burst == 5
assert executor.circuit_breaker.failure_threshold == 0.5
assert executor.circuit_breaker.recovery_timeout == 30
@pytest.mark.asyncio
async def test_execute_batch_success(self):
"""Test successful batch execution."""
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
fuzzer_state = FuzzerState()
# Mock request factory
request_factory = Mock()
# Mock process_prompt to return success
async def mock_process_prompt(rf, prompt, tokens, module, state):
return (10, False) # 10 tokens, not refused
with patch(
"agentic_security.probe_actor.fuzzer.process_prompt",
side_effect=mock_process_prompt,
):
prompts = ["prompt1", "prompt2", "prompt3"]
tokens, failures = await executor.execute_batch(
request_factory, prompts, "test_module", fuzzer_state
)
assert tokens == 30 # 3 prompts * 10 tokens
assert failures == 0
@pytest.mark.asyncio
async def test_execute_batch_with_failures(self):
"""Test batch execution with some failures."""
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
fuzzer_state = FuzzerState()
request_factory = Mock()
# Mock process_prompt to alternate success/failure
call_count = [0]
async def mock_process_prompt(rf, prompt, tokens, module, state):
call_count[0] += 1
if call_count[0] % 2 == 0:
return (10, True) # Refused
return (10, False) # Success
with patch(
"agentic_security.probe_actor.fuzzer.process_prompt",
side_effect=mock_process_prompt,
):
prompts = ["p1", "p2", "p3", "p4"]
tokens, failures = await executor.execute_batch(
request_factory, prompts, "test_module", fuzzer_state
)
assert tokens == 40 # 4 prompts * 10 tokens
assert failures == 2 # 2 refused
@pytest.mark.asyncio
async def test_execute_batch_respects_concurrency_limit(self):
"""Test that concurrency limit is respected."""
executor = ConcurrentExecutor(max_concurrent=2, rate_limit=100, burst=10)
fuzzer_state = FuzzerState()
request_factory = Mock()
# Track concurrent executions
concurrent_count = [0]
max_concurrent = [0]
async def mock_process_prompt(rf, prompt, tokens, module, state):
concurrent_count[0] += 1
max_concurrent[0] = max(max_concurrent[0], concurrent_count[0])
await asyncio.sleep(0.01) # Simulate work
concurrent_count[0] -= 1
return (10, False)
with patch(
"agentic_security.probe_actor.fuzzer.process_prompt",
side_effect=mock_process_prompt,
):
prompts = ["p1", "p2", "p3", "p4", "p5"]
await executor.execute_batch(
request_factory, prompts, "test_module", fuzzer_state
)
# Max concurrent should not exceed limit
assert max_concurrent[0] <= 2
@pytest.mark.asyncio
async def test_circuit_breaker_integration(self):
"""Test that circuit breaker opens on failures."""
executor = ConcurrentExecutor(
max_concurrent=10,
rate_limit=100,
burst=20,
failure_threshold=0.5,
recovery_timeout=1,
)
fuzzer_state = FuzzerState()
request_factory = Mock()
# Mock process_prompt to always fail
async def mock_process_prompt_fail(rf, prompt, tokens, module, state):
raise Exception("Request failed")
# First batch - all failures
with patch(
"agentic_security.probe_actor.fuzzer.process_prompt",
side_effect=mock_process_prompt_fail,
):
prompts = ["p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10"]
tokens, failures = await executor.execute_batch(
request_factory, prompts, "test_module", fuzzer_state
)
# All should have failed
assert failures == 10
# Circuit should be open now
assert executor.circuit_breaker.state == "open"
@pytest.mark.asyncio
async def test_get_metrics(self):
"""Test getting executor metrics."""
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
fuzzer_state = FuzzerState()
request_factory = Mock()
async def mock_process_prompt(rf, prompt, tokens, module, state):
return (10, False)
with patch(
"agentic_security.probe_actor.fuzzer.process_prompt",
side_effect=mock_process_prompt,
):
await executor.execute_batch(
request_factory, ["p1", "p2"], "test_module", fuzzer_state
)
metrics = executor.get_metrics()
assert "total_requests" in metrics
assert "success_rate" in metrics
assert "circuit_breaker_state" in metrics
assert "available_tokens" in metrics
assert metrics["total_requests"] == 2
assert metrics["circuit_breaker_state"] == "closed"
@pytest.mark.asyncio
async def test_rate_limiting_applied(self):
"""Test that rate limiting is applied."""
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=5, burst=2)
fuzzer_state = FuzzerState()
request_factory = Mock()
async def mock_process_prompt(rf, prompt, tokens, module, state):
return (10, False)
import time
with patch(
"agentic_security.probe_actor.fuzzer.process_prompt",
side_effect=mock_process_prompt,
):
start = time.monotonic()
# 5 requests with rate=5/s and burst=2
# First 2 immediate, next 3 should take ~0.6s total
await executor.execute_batch(
request_factory,
["p1", "p2", "p3", "p4", "p5"],
"test_module",
fuzzer_state,
)
elapsed = time.monotonic() - start
# Should take at least 0.5s (3 requests / 5 per second)
assert elapsed >= 0.4
-145
View File
@@ -1,145 +0,0 @@
"""Tests for TokenBucketRateLimiter."""
import asyncio
import pytest
import time
from agentic_security.executor.rate_limiter import TokenBucketRateLimiter
class TestTokenBucketRateLimiter:
"""Test TokenBucketRateLimiter functionality."""
@pytest.mark.asyncio
async def test_initialization(self):
"""Test rate limiter initialization."""
limiter = TokenBucketRateLimiter(rate=10, burst=20)
assert limiter.rate == 10
assert limiter.burst == 20
assert limiter.tokens == 20 # Starts full
@pytest.mark.asyncio
async def test_acquire_with_available_tokens(self):
"""Test acquiring tokens when they're available."""
limiter = TokenBucketRateLimiter(rate=10, burst=5)
start = time.monotonic()
await limiter.acquire()
elapsed = time.monotonic() - start
# Should return immediately
assert elapsed < 0.1
assert limiter.tokens < 5 # One token consumed
@pytest.mark.asyncio
async def test_acquire_waits_when_no_tokens(self):
"""Test that acquire waits when no tokens available."""
limiter = TokenBucketRateLimiter(rate=10, burst=1)
# Consume the initial token
await limiter.acquire()
# Next acquire should wait
start = time.monotonic()
await limiter.acquire()
elapsed = time.monotonic() - start
# Should wait approximately 1/rate seconds (0.1s for rate=10)
assert elapsed >= 0.08 # Allow some tolerance
@pytest.mark.asyncio
async def test_rate_limiting(self):
"""Test that rate limiting actually limits request rate."""
limiter = TokenBucketRateLimiter(rate=10, burst=2)
# Make 5 requests
start = time.monotonic()
for _ in range(5):
await limiter.acquire()
elapsed = time.monotonic() - start
# With rate=10/s and burst=2:
# - First 2 requests are immediate (burst)
# - Next 3 requests require waiting: 3 * (1/10) = 0.3s
# Total should be around 0.3s
assert elapsed >= 0.25 # Allow some tolerance
assert elapsed < 0.5
@pytest.mark.asyncio
async def test_burst_capacity(self):
"""Test that burst capacity allows immediate requests."""
limiter = TokenBucketRateLimiter(rate=5, burst=10)
# Make burst number of requests immediately
start = time.monotonic()
for _ in range(10):
await limiter.acquire()
elapsed = time.monotonic() - start
# All 10 requests should be nearly immediate (using burst capacity)
assert elapsed < 0.2
@pytest.mark.asyncio
async def test_token_replenishment(self):
"""Test that tokens are replenished over time."""
limiter = TokenBucketRateLimiter(rate=10, burst=5)
# Consume all tokens
for _ in range(5):
await limiter.acquire()
assert limiter.tokens < 1
# Wait for tokens to replenish
await asyncio.sleep(0.3) # Should add 3 tokens at rate=10
# Should have tokens again (approximately 3)
available = limiter.get_available_tokens()
assert available >= 2.5
assert available <= 3.5
@pytest.mark.asyncio
async def test_get_available_tokens(self):
"""Test get_available_tokens method."""
limiter = TokenBucketRateLimiter(rate=10, burst=5)
# Initially full
assert limiter.get_available_tokens() == 5
# After consuming one
await limiter.acquire()
assert limiter.get_available_tokens() < 5
@pytest.mark.asyncio
async def test_concurrent_requests(self):
"""Test rate limiter with concurrent requests."""
limiter = TokenBucketRateLimiter(rate=10, burst=3)
async def make_request(limiter):
await limiter.acquire()
return time.monotonic()
# Make 5 concurrent requests
start = time.monotonic()
tasks = [make_request(limiter) for _ in range(5)]
timestamps = await asyncio.gather(*tasks)
total_elapsed = time.monotonic() - start
# First 3 should be immediate (burst=3)
# Next 2 should wait
# Total time should be around 0.2s (2 * 1/10)
assert total_elapsed >= 0.15
assert total_elapsed < 0.4
@pytest.mark.asyncio
async def test_max_burst_capacity(self):
"""Test that tokens don't exceed burst capacity."""
limiter = TokenBucketRateLimiter(rate=100, burst=5)
# Wait longer than needed to fill
await asyncio.sleep(0.2) # Would add 20 tokens, but capped at 5
# Check tokens don't exceed burst
available = limiter.get_available_tokens()
assert available <= 5
assert available >= 4.5 # Close to full
-360
View File
@@ -1,360 +0,0 @@
"""Tests for unified dataset loader."""
import pytest
from unittest.mock import patch
from agentic_security.probe_data.unified_loader import (
InputSourceConfig,
UnifiedDatasetLoader,
)
from agentic_security.probe_data.models import ProbeDataset
class TestInputSourceConfig:
"""Test InputSourceConfig validation."""
def test_csv_source_config(self):
"""Test CSV source configuration."""
config = InputSourceConfig(
source_type="csv",
dataset_name="test_csv",
path="./test.csv",
prompt_column="prompt",
weight=1.5,
)
assert config.source_type == "csv"
assert config.dataset_name == "test_csv"
assert config.path == "./test.csv"
assert config.weight == 1.5
def test_huggingface_source_config(self):
"""Test HuggingFace source configuration."""
config = InputSourceConfig(
source_type="huggingface",
dataset_name="test/dataset",
split="train",
max_samples=100,
)
assert config.source_type == "huggingface"
assert config.split == "train"
assert config.max_samples == 100
def test_proxy_source_config(self):
"""Test proxy source configuration."""
config = InputSourceConfig(
source_type="proxy",
dataset_name="proxy_test",
)
assert config.source_type == "proxy"
assert config.enabled is True # Default value
def test_disabled_source(self):
"""Test disabled source configuration."""
config = InputSourceConfig(
source_type="csv",
dataset_name="disabled_test",
enabled=False,
)
assert config.enabled is False
def test_weight_validation(self):
"""Test that weight must be non-negative."""
with pytest.raises(ValueError):
InputSourceConfig(
source_type="csv",
dataset_name="test",
weight=-1.0,
)
class TestUnifiedDatasetLoader:
"""Test UnifiedDatasetLoader functionality."""
@pytest.mark.asyncio
async def test_load_single_csv_source(self):
"""Test loading a single CSV source."""
config = InputSourceConfig(
source_type="csv",
dataset_name="test_csv",
path="test.csv",
)
loader = UnifiedDatasetLoader([config])
# Mock the load_csv function
mock_dataset = ProbeDataset(
dataset_name="test_csv",
prompts=["prompt1", "prompt2", "prompt3"],
tokens=10,
approx_cost=0.0,
metadata={},
)
with patch(
"agentic_security.probe_data.unified_loader.load_csv",
return_value=mock_dataset,
):
result = await loader.load_all()
assert result.dataset_name == "unified"
assert len(result.prompts) == 3
assert result.prompts == ["prompt1", "prompt2", "prompt3"]
@pytest.mark.asyncio
async def test_load_single_huggingface_source(self):
"""Test loading a single HuggingFace source."""
config = InputSourceConfig(
source_type="huggingface",
dataset_name="test/dataset",
split="train",
)
loader = UnifiedDatasetLoader([config])
# Mock the load_dataset_generic function
mock_dataset = ProbeDataset(
dataset_name="test/dataset",
prompts=["hf_prompt1", "hf_prompt2"],
tokens=8,
approx_cost=0.0,
metadata={},
)
with patch(
"agentic_security.probe_data.unified_loader.load_dataset_generic",
return_value=mock_dataset,
):
result = await loader.load_all()
assert result.dataset_name == "unified"
assert len(result.prompts) == 2
@pytest.mark.asyncio
async def test_merge_multiple_sources(self):
"""Test merging multiple sources."""
configs = [
InputSourceConfig(
source_type="csv",
dataset_name="csv1",
path="test1.csv",
weight=1.0,
),
InputSourceConfig(
source_type="csv",
dataset_name="csv2",
path="test2.csv",
weight=2.0,
),
]
loader = UnifiedDatasetLoader(configs)
# Mock datasets
mock_dataset1 = ProbeDataset(
dataset_name="csv1",
prompts=["prompt1"],
tokens=5,
approx_cost=0.0,
metadata={},
)
mock_dataset2 = ProbeDataset(
dataset_name="csv2",
prompts=["prompt2", "prompt3"],
tokens=10,
approx_cost=0.0,
metadata={},
)
with patch(
"agentic_security.probe_data.unified_loader.load_csv",
side_effect=[mock_dataset1, mock_dataset2],
):
result = await loader.load_all()
assert result.dataset_name == "unified"
# Weight 1.0 = include once, weight 2.0 = include twice
# csv1: 1 prompt * 1 = 1
# csv2: 2 prompts * 2 = 4
assert len(result.prompts) == 5
assert "csv1" in result.metadata["sources"]
assert "csv2" in result.metadata["sources"]
@pytest.mark.asyncio
async def test_handle_disabled_sources(self):
"""Test that disabled sources are skipped."""
configs = [
InputSourceConfig(
source_type="csv",
dataset_name="enabled_csv",
path="enabled.csv",
enabled=True,
),
InputSourceConfig(
source_type="csv",
dataset_name="disabled_csv",
path="disabled.csv",
enabled=False,
),
]
loader = UnifiedDatasetLoader(configs)
mock_dataset = ProbeDataset(
dataset_name="enabled_csv",
prompts=["prompt1"],
tokens=5,
approx_cost=0.0,
metadata={},
)
with patch(
"agentic_security.probe_data.unified_loader.load_csv",
return_value=mock_dataset,
) as mock_load:
result = await loader.load_all()
# Should only be called once (for enabled source)
assert mock_load.call_count == 1
assert len(result.prompts) == 1
@pytest.mark.asyncio
async def test_max_samples_limit(self):
"""Test that max_samples limits the number of prompts."""
config = InputSourceConfig(
source_type="csv",
dataset_name="test_csv",
path="test.csv",
max_samples=2,
)
loader = UnifiedDatasetLoader([config])
# Mock dataset with more prompts than max_samples
mock_dataset = ProbeDataset(
dataset_name="test_csv",
prompts=["prompt1", "prompt2", "prompt3", "prompt4", "prompt5"],
tokens=20,
approx_cost=0.0,
metadata={},
)
with patch(
"agentic_security.probe_data.unified_loader.load_csv",
return_value=mock_dataset,
):
result = await loader.load_all()
# Should be limited to 2 prompts
assert len(result.prompts) == 2
@pytest.mark.asyncio
async def test_error_handling(self):
"""Test that errors are handled gracefully."""
config = InputSourceConfig(
source_type="csv",
dataset_name="error_csv",
path="nonexistent.csv",
)
loader = UnifiedDatasetLoader([config])
with patch(
"agentic_security.probe_data.unified_loader.load_csv",
side_effect=Exception("File not found"),
):
result = await loader.load_all()
# Should return empty dataset on error
assert result.dataset_name == "unified_empty"
assert len(result.prompts) == 0
@pytest.mark.asyncio
async def test_proxy_source_placeholder(self):
"""Test that proxy source returns empty dataset (not implemented in PoC)."""
config = InputSourceConfig(
source_type="proxy",
dataset_name="proxy_test",
)
loader = UnifiedDatasetLoader([config])
result = await loader.load_all()
# Proxy not implemented in PoC, should return empty
assert len(result.prompts) == 0
@pytest.mark.asyncio
async def test_weighted_sampling(self):
"""Test weighted sampling behavior."""
configs = [
InputSourceConfig(
source_type="csv",
dataset_name="low_weight",
path="low.csv",
weight=1.0,
),
InputSourceConfig(
source_type="csv",
dataset_name="high_weight",
path="high.csv",
weight=3.0,
),
]
loader = UnifiedDatasetLoader(configs)
mock_dataset1 = ProbeDataset(
dataset_name="low_weight",
prompts=["a"],
tokens=1,
approx_cost=0.0,
metadata={},
)
mock_dataset2 = ProbeDataset(
dataset_name="high_weight",
prompts=["b"],
tokens=1,
approx_cost=0.0,
metadata={},
)
with patch(
"agentic_security.probe_data.unified_loader.load_csv",
side_effect=[mock_dataset1, mock_dataset2],
):
result = await loader.load_all()
# Weight 1.0: 1 prompt * 1 = 1
# Weight 3.0: 1 prompt * 3 = 3
# Total: 4 prompts
assert len(result.prompts) == 4
assert result.prompts.count("a") == 1
assert result.prompts.count("b") == 3
@pytest.mark.asyncio
async def test_empty_configs_list(self):
"""Test loading with empty configs list."""
loader = UnifiedDatasetLoader([])
result = await loader.load_all()
assert result.dataset_name == "unified_empty"
assert len(result.prompts) == 0
@pytest.mark.asyncio
async def test_csv_with_url(self):
"""Test CSV loading from URL."""
config = InputSourceConfig(
source_type="csv",
dataset_name="remote_csv",
url="https://example.com/data.csv",
prompt_column="text",
)
loader = UnifiedDatasetLoader([config])
mock_dataset = ProbeDataset(
dataset_name="remote_csv",
prompts=["remote_prompt"],
tokens=5,
approx_cost=0.0,
metadata={"source_type": "csv", "url": "https://example.com/data.csv"},
)
with patch(
"agentic_security.probe_data.unified_loader.load_dataset_generic",
return_value=mock_dataset,
):
result = await loader.load_all()
assert len(result.prompts) == 1
assert result.prompts[0] == "remote_prompt"
View File
+161
View File
@@ -0,0 +1,161 @@
import pytest
import asyncio
from fastapi import FastAPI
from asyncio import Queue, Event
from agentic_security.core.app import create_app, get_tools_inbox, get_stop_event, get_current_run, set_current_run
class TestApp:
"""Test suite for agentic_security.core.app module."""
def test_create_app(self):
"""Test that create_app returns a FastAPI instance."""
app = create_app()
assert isinstance(app, FastAPI)
@pytest.mark.asyncio
async def test_get_tools_inbox(self):
"""Test that get_tools_inbox returns the global Queue instance."""
queue1 = get_tools_inbox()
await queue1.put("test item")
queue2 = get_tools_inbox()
result = queue2.get_nowait()
assert result == "test item"
def test_get_stop_event(self):
"""Test that get_stop_event returns the global Event instance and is not set initially."""
event = get_stop_event()
assert isinstance(event, Event)
assert not event.is_set()
def test_current_run_initial(self):
"""Test that get_current_run returns the global current_run with default values initially."""
run = get_current_run()
# Default values should be empty strings
assert run["spec"] == ""
assert run["id"] == ""
def test_set_current_run(self):
"""Test that set_current_run correctly updates current_run."""
spec = "test run"
result = set_current_run(spec)
expected_id = hash(id(spec))
# Verify that spec is set correctly
assert result["spec"] == spec
assert result["id"] == expected_id
def test_current_run_after_set(self):
"""Test that get_current_run returns the updated current_run after set_current_run is called."""
spec = "another test run"
set_current_run(spec)
current = get_current_run()
assert current["spec"] == spec
assert current["id"] == hash(id(spec))
def test_tools_inbox_same_instance(self):
"""Test that get_tools_inbox returns the same Queue instance by default."""
queue1 = get_tools_inbox()
queue2 = get_tools_inbox()
assert queue1 is queue2
def test_stop_event_set(self):
"""Test that setting the stop event is reflected in subsequent calls."""
event = get_stop_event()
event.set() # set the global event
# Now, subsequent calls should return the same event which is set.
event2 = get_stop_event()
assert event2.is_set()
def test_set_current_run_with_none(self):
"""Test that set_current_run handles None as a valid input and updates current_run accordingly."""
result = set_current_run(None)
expected_id = hash(id(None))
assert result["spec"] is None
assert result["id"] == expected_id
def test_multiple_current_run_assignments(self):
"""Test multiple assignments to current_run to ensure it always updates correctly."""
first_spec = "first run"
result1 = set_current_run(first_spec)
expected_id1 = hash(id(first_spec))
assert result1["spec"] == first_spec
assert result1["id"] == expected_id1
second_spec = "second run"
result2 = set_current_run(second_spec)
expected_id2 = hash(id(second_spec))
assert result2["spec"] == second_spec
assert result2["id"] == expected_id2
current = get_current_run()
# The current_run should reflect the latest assignment.
assert current["spec"] == second_spec
assert current["id"] == expected_id2
@pytest.mark.asyncio
async def test_empty_tools_inbox_exception(self):
"""Test that calling get_nowait on an empty tools_inbox raises QueueEmpty."""
from asyncio import QueueEmpty
queue = get_tools_inbox()
# Clear any existing items in the queue
while True:
try:
queue.get_nowait()
except QueueEmpty:
break
with pytest.raises(QueueEmpty):
queue.get_nowait()
def test_set_current_run_with_dict(self):
"""Test that set_current_run correctly handles a dictionary input as spec."""
spec = {"key": "value"}
result = set_current_run(spec)
expected_id = hash(id(spec))
assert result["spec"] == spec
assert result["id"] == expected_id
@pytest.mark.asyncio
async def test_stop_event_wait(self):
"""Test that waiting on the stop event returns once the event is set."""
event = get_stop_event()
event.clear() # ensure event is not set
async def waiter():
await event.wait()
return True
waiter_task = asyncio.create_task(waiter())
# Wait a moment to ensure the waiter is pending
await asyncio.sleep(0.1)
assert not waiter_task.done()
event.set()
result = await waiter_task
assert result is True
def test_set_current_run_with_int(self):
"""Test that set_current_run handles an integer input as spec."""
spec = 12345
result = set_current_run(spec)
expected_id = hash(id(spec))
assert result["spec"] == spec
assert result["id"] == expected_id
def test_create_app_routes(self):
"""Test that create_app returns a FastAPI instance with default routes available."""
app = create_app()
paths = [route.path for route in app.routes]
# Check that the default OpenAPI route exists
assert "/openapi.json" in paths
@pytest.mark.asyncio
async def test_tools_inbox_async_put_get_order(self):
"""Test that tools_inbox preserves order when items are added and retrieved asynchronously."""
queue = get_tools_inbox()
# Clear any existing items in the queue
from asyncio import QueueEmpty
while True:
try:
queue.get_nowait()
except QueueEmpty:
break
items = ["first", "second", "third"]
for item in items:
await queue.put(item)
result_items = []
for _ in items:
result_items.append(await queue.get())
assert result_items == items
-25
View File
@@ -1,25 +0,0 @@
import os
from pathlib import Path
from agentic_security.cache_config import ensure_cache_dir
def test_ensure_cache_dir_creates_dir_and_sets_env(tmp_path, monkeypatch):
monkeypatch.delenv("DISK_CACHE_DIR", raising=False)
target_dir = tmp_path / "cache_to_disk"
resolved = ensure_cache_dir(target_dir)
assert resolved == target_dir
assert resolved.is_dir()
assert Path(os.environ["DISK_CACHE_DIR"]) == resolved
def test_ensure_cache_dir_respects_existing_env(tmp_path, monkeypatch):
env_dir = tmp_path / "preconfigured"
monkeypatch.setenv("DISK_CACHE_DIR", str(env_dir))
resolved = ensure_cache_dir()
assert resolved == env_dir
assert resolved.exists()
+341
View File
@@ -0,0 +1,341 @@
import pytest
import base64
import httpx
import asyncio
from agentic_security.http_spec import (
LLMSpec,
parse_http_spec,
escape_special_chars_for_json,
encode_image_base64_by_url,
encode_audio_base64_by_url,
InvalidHTTPSpecError,
Modality
)
################################################################################
# Tests for agentic_security/http_spec.py
################################################################################
def test_escape_special_chars_for_json():
"""Test escaping special characters in a prompt for JSON safety."""
prompt = 'Line1\nLine2\t"Quote"\\Backslash'
escaped = escape_special_chars_for_json(prompt)
assert '\\n' in escaped
assert '\\t' in escaped
assert '\\"' in escaped
assert '\\\\' in escaped
def test_parse_http_spec_text():
"""Test parsing a text HTTP spec without image/audio/files requirements."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\nThis is a prompt: <<PROMPT>>"
llm_spec = parse_http_spec(spec)
assert llm_spec.method == "POST"
assert llm_spec.url == "http://example.com/api"
assert llm_spec.headers["Content-Type"] == "application/json"
assert "<<PROMPT>>" in llm_spec.body
assert not llm_spec.has_files
assert not llm_spec.has_image
assert not llm_spec.has_audio
def test_parse_http_spec_files():
"""Test parsing a HTTP spec with multipart/form-data header indicating files."""
spec = "PUT http://example.com/upload\nContent-Type: multipart/form-data\n\nFile upload test"
llm_spec = parse_http_spec(spec)
assert llm_spec.has_files
def test_parse_http_spec_image_audio():
"""Test parsing a HTTP spec that requires image and audio via placeholders."""
spec = "GET http://example.com/api\nContent-Type: application/json\n\nImage: <<BASE64_IMAGE>> and Audio: <<BASE64_AUDIO>>"
llm_spec = parse_http_spec(spec)
assert llm_spec.has_image
assert llm_spec.has_audio
def test_encode_image_base64_by_url(monkeypatch):
"""Test that image encoding returns the correct base64 string with prefix."""
dummy_content = b'test_image'
class DummyResponse:
def __init__(self, content):
self.content = content
def dummy_get(url):
return DummyResponse(dummy_content)
monkeypatch.setattr(httpx, "get", dummy_get)
result = encode_image_base64_by_url("http://dummyurl.com/image.jpg")
expected = "data:image/jpeg;base64," + base64.b64encode(dummy_content).decode("utf-8")
assert result == expected
def test_encode_audio_base64_by_url(monkeypatch):
"""Test that audio encoding returns the correct base64 string with prefix."""
dummy_content = b'test_audio'
class DummyResponse:
def __init__(self, content):
self.content = content
def dummy_get(url):
return DummyResponse(dummy_content)
monkeypatch.setattr(httpx, "get", dummy_get)
result = encode_audio_base64_by_url("http://dummyurl.com/audio.mp3")
expected = "data:audio/mpeg;base64," + base64.b64encode(dummy_content).decode("utf-8")
assert result == expected
@pytest.mark.asyncio
async def test_probe_text(monkeypatch):
"""Test the probe function for text modality by replacing <<PROMPT>>."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"prompt\": \"<<PROMPT>>\"}"
llm_spec = parse_http_spec(spec)
async def dummy_request(self, method, url, headers, content, timeout):
return httpx.Response(200, text="ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.probe("Hello")
assert response.status_code == 200
assert "ok" in response.text
@pytest.mark.asyncio
async def test_probe_with_files(monkeypatch):
"""Test that probe correctly branches to _probe_with_files when files are provided."""
spec = "POST http://example.com/api\nContent-Type: multipart/form-data\n\nFile data"
llm_spec = parse_http_spec(spec)
files = {"file": ("dummy.txt", b"data")}
async def dummy_request(self, method, url, headers, files, timeout):
return httpx.Response(200, text="file upload ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.probe("Unused", files=files)
assert response.status_code == 200
assert "file upload ok" in response.text
@pytest.mark.asyncio
async def test_verify_image(monkeypatch):
"""Test verify method branch for image modality by monkeypatching image encoder."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"image\": \"<<BASE64_IMAGE>>\"}"
llm_spec = parse_http_spec(spec)
# Replace the image encoder to return a dummy string
monkeypatch.setattr("agentic_security.http_spec.encode_image_base64_by_url", lambda url="": "dummy_image")
async def dummy_request(self, method, url, headers, content, timeout):
# Check that the dummy image is injected in the content
assert "dummy_image" in content
return httpx.Response(200, text="image ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.verify()
assert response.status_code == 200
assert "image ok" in response.text
@pytest.mark.asyncio
async def test_verify_audio(monkeypatch):
"""Test verify method branch for audio modality by monkeypatching audio encoder."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"audio\": \"<<BASE64_AUDIO>>\"}"
llm_spec = parse_http_spec(spec)
monkeypatch.setattr("agentic_security.http_spec.encode_audio_base64_by_url", lambda url: "dummy_audio")
async def dummy_request(self, method, url, headers, content, timeout):
# Ensure that the dummy audio string is present in the request content
assert "dummy_audio" in content
return httpx.Response(200, text="audio ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.verify()
assert response.status_code == 200
assert "audio ok" in response.text
@pytest.mark.asyncio
async def test_verify_files(monkeypatch):
"""Test verify method branch for files modality where _probe_with_files is invoked."""
spec = "POST http://example.com/api\nContent-Type: multipart/form-data\n\nFile data"
llm_spec = parse_http_spec(spec)
async def dummy_request(self, method, url, headers, files, timeout):
return httpx.Response(200, text="files ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.verify()
assert response.status_code == 200
assert "files ok" in response.text
def test_llm_spec_modality_property():
"""Test that the modality property reflects the correct modality."""
spec_text = "POST http://example.com/api\nContent-Type: application/json\n\nPrompt: <<PROMPT>>"
llm_spec_text = parse_http_spec(spec_text)
assert llm_spec_text.modality == Modality.TEXT
spec_image = "POST http://example.com/api\nContent-Type: application/json\n\nImage: <<BASE64_IMAGE>>"
llm_spec_image = parse_http_spec(spec_image)
assert llm_spec_image.modality == Modality.IMAGE
spec_audio = "POST http://example.com/api\nContent-Type: application/json\n\nAudio: <<BASE64_AUDIO>>"
llm_spec_audio = parse_http_spec(spec_audio)
assert llm_spec_audio.modality == Modality.AUDIO
def test_from_string_invalid():
"""Test that LLMSpec.from_string raises an error for an invalid spec."""
invalid_spec = "INVALID_SPEC"
with pytest.raises(InvalidHTTPSpecError):
LLMSpec.from_string(invalid_spec)
@pytest.mark.asyncio
async def test_validate_missing_files():
"""Test that LLMSpec.validate raises a ValueError when files are required but missing."""
spec = "POST http://example.com/api\nContent-Type: multipart/form-data\n\nFile upload test"
llm_spec = parse_http_spec(spec)
with pytest.raises(ValueError, match="Files are required"):
llm_spec.validate("test prompt", "", "", {})
@pytest.mark.asyncio
async def test_validate_missing_image():
"""Test that LLMSpec.validate raises a ValueError when an image is required but missing."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\nImage: <<BASE64_IMAGE>>"
llm_spec = parse_http_spec(spec)
with pytest.raises(ValueError, match="An image is required"):
llm_spec.validate("test prompt", "", "dummy_audio", {})
@pytest.mark.asyncio
async def test_validate_missing_audio():
"""Test that LLMSpec.validate raises a ValueError when audio is required but missing."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\nAudio: <<BASE64_AUDIO>>"
llm_spec = parse_http_spec(spec)
with pytest.raises(ValueError, match="Audio is required"):
llm_spec.validate("test prompt", "dummy_image", "", {})
def test_fn_alias(monkeypatch):
"""Test that LLMSpec.fn is a functional alias for LLMSpec.probe."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"prompt\": \"<<PROMPT>>\"}"
llm_spec = parse_http_spec(spec)
# Instead of overriding the instance method, verify the alias at the class level.
assert LLMSpec.fn is LLMSpec.probe
def test_escape_special_chars_no_special():
"""Test that the escape function returns the original string if no special characters are present."""
prompt = "Simple text without specials"
escaped = escape_special_chars_for_json(prompt)
assert escaped == "Simple text without specials"
@pytest.mark.asyncio
async def test_probe_text_with_special_chars(monkeypatch):
"""Test probe for text modality with special characters in prompt ensuring escaped content."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"prompt\": \"<<PROMPT>>\"}"
llm_spec = parse_http_spec(spec)
captured = {}
async def dummy_request(self, method, url, headers, content, timeout):
captured['content'] = content
return httpx.Response(200, text="ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
test_prompt = 'Hello\nWorld\t"Test"'
response = await llm_spec.probe(test_prompt)
expected_escaped = escape_special_chars_for_json(test_prompt)
assert expected_escaped in captured['content']
assert response.status_code == 200
@pytest.mark.asyncio
async def test_verify_both_image_audio(monkeypatch):
"""Test verify method when both image and audio placeholders are present.
Expect a ValueError because only the image branch is triggered by pattern matching and the missing audio causes validation to fail."""
spec = ("POST http://example.com/api\nContent-Type: application/json\n\n"
"{\"audio\": \"<<BASE64_AUDIO>>\", \"image\":\"<<BASE64_IMAGE>>\"}")
llm_spec = parse_http_spec(spec)
# Monkey patch the image encoder to return a dummy value
monkeypatch.setattr("agentic_security.http_spec.encode_image_base64_by_url", lambda url="": "dummy_image")
with pytest.raises(ValueError, match="Audio is required"):
await llm_spec.verify()
def test_parse_http_spec_invalid_header_format():
"""Test that parse_http_spec raises an error when a header line doesn't have the expected 'key: value' format."""
invalid_spec = "GET http://example.com/api\nInvalidHeaderWithoutColon\n\nBody with <<PROMPT>>"
with pytest.raises(ValueError):
parse_http_spec(invalid_spec)
def test_from_string_valid():
"""Test that LLMSpec.from_string returns a valid LLMSpec object when given a proper spec string."""
spec = "GET http://example.com/api\nContent-Type: application/json\n\n{ \"prompt\": \"<<PROMPT>>\" }"
llm_spec = LLMSpec.from_string(spec)
assert llm_spec.method == "GET"
assert llm_spec.url == "http://example.com/api"
@pytest.mark.asyncio
async def test_parse_http_spec_multiline_body():
"""Test parsing an HTTP spec with a multiline body to ensure body concatenation works."""
spec = (
"PATCH http://example.com/api\n"
"Content-Type: application/json\n"
"\n"
"Line one of body\n"
"Line two of body\n"
"Line three"
)
llm_spec = parse_http_spec(spec)
# As implemented, the parser concatenates lines without newline delimiters
expected_body = "Line one of bodyLine two of bodyLine three"
assert llm_spec.body == expected_body
@pytest.mark.asyncio
async def test_encode_image_default_argument(monkeypatch):
"""Test that encode_image_base64_by_url works with its default URL argument."""
dummy_content = b'default_image'
class DummyResponse:
def __init__(self, content):
self.content = content
def dummy_get(url):
# check that the default URL (which includes 'fluidicon.png') is used
assert "fluidicon.png" in url
return DummyResponse(dummy_content)
monkeypatch.setattr(httpx, "get", dummy_get)
result = encode_image_base64_by_url()
expected = "data:image/jpeg;base64," + base64.b64encode(dummy_content).decode("utf-8")
assert result == expected
@pytest.mark.asyncio
async def test_probe_without_prompt_placeholder(monkeypatch):
"""Test the probe function when the request body does not include the <<PROMPT>> placeholder."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"message\": \"No placeholder here\"}"
llm_spec = parse_http_spec(spec)
captured = {}
async def dummy_request(self, method, url, headers, content, timeout):
captured['content'] = content
return httpx.Response(200, text="ok without placeholder")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.probe("Ignored prompt")
assert "No placeholder here" in captured['content']
assert response.status_code == 200
def test_validate_success():
"""Test that LLMSpec.validate does not raise an error when all required data is provided."""
# Test case for files: files are provided as required
spec_files = "POST http://example.com/api\nContent-Type: multipart/form-data\n\nFile upload"
llm_spec_files = parse_http_spec(spec_files)
llm_spec_files.validate("some prompt", "dummy_image", "dummy_audio", {"file": ("dummy.txt", b"data")})
# Test case for image: image is provided as required
spec_image = "POST http://example.com/api\nContent-Type: application/json\n\nImage: <<BASE64_IMAGE>>"
llm_spec_image = parse_http_spec(spec_image)
llm_spec_image.validate("some prompt", "dummy_image", "dummy_audio", {})
# Test case for audio: audio is provided as required
spec_audio = "POST http://example.com/api\nContent-Type: application/json\n\nAudio: <<BASE64_AUDIO>>"
llm_spec_audio = parse_http_spec(spec_audio)
llm_spec_audio.validate("some prompt", "dummy_image", "dummy_audio", {})
@pytest.mark.asyncio
async def test_probe_invalid_url(monkeypatch):
"""Test that probe raises an exception when the HTTP client fails due to an invalid URL."""
spec = "GET http://nonexistent_url/api\nContent-Type: application/json\n\n{\"prompt\": \"<<PROMPT>>\"}"
llm_spec = parse_http_spec(spec)
async def dummy_request(self, method, url, headers, content, timeout):
raise httpx.RequestError("Invalid URL")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
with pytest.raises(httpx.RequestError):
await llm_spec.probe("Test")
-12
View File
@@ -1,12 +0,0 @@
import pytest
from agentic_security.mcp.client import run
@pytest.mark.asyncio
async def test_mcp_echo_tool():
"""Test the echo tool functionality"""
prompts, resources, tools = await run()
assert prompts
assert resources
assert tools
+1 -1
View File
@@ -1,7 +1,7 @@
import pytest
from datasets import load_dataset
from agentic_security.probe_data import REGISTRY
from datasets import load_dataset
@pytest.mark.slow
+10 -10
View File
@@ -4266,9 +4266,9 @@
}
},
"node_modules/compression": {
"version": "1.8.1",
"resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz",
"integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==",
"version": "1.8.0",
"resolved": "https://registry.npmjs.org/compression/-/compression-1.8.0.tgz",
"integrity": "sha512-k6WLKfunuqCYD3t6AsuPGvQWaKwuLLh2/xHNcX4qE+vIfDNXpSqnrhwA7O53R7WVQUnt8dVAIW+YHr7xTgOgGA==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -4276,7 +4276,7 @@
"compressible": "~2.0.18",
"debug": "2.6.9",
"negotiator": "~0.6.4",
"on-headers": "~1.1.0",
"on-headers": "~1.0.2",
"safe-buffer": "5.2.1",
"vary": "~1.1.2"
},
@@ -6891,9 +6891,9 @@
}
},
"node_modules/http-proxy-middleware": {
"version": "2.0.9",
"resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz",
"integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==",
"version": "2.0.7",
"resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.7.tgz",
"integrity": "sha512-fgVY8AV7qU7z/MmXJ/rxwbrtQH4jBQ9m7kp3llF0liB7glmFeVZFBepQb32T3y8n8k2+AEYuMPCpinYW+/CuRA==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -8419,9 +8419,9 @@
}
},
"node_modules/on-headers": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz",
"integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==",
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz",
"integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==",
"dev": true,
"license": "MIT",
"engines": {