Compare commits

..

3 Commits

Author SHA1 Message Date
codebeaver-ai[bot] e752ebaeeb Adding codebeaver.yml 2025-03-03 18:41:09 +00:00
codebeaver-ai[bot] 2549194bd1 test: Add coverage improvement test for tests/test_http_spec.py 2025-03-03 18:41:08 +00:00
codebeaver-ai[bot] 4c580ea1b8 test: Add coverage improvement test for tests/test_app.py 2025-03-03 18:41:06 +00:00
56 changed files with 1570 additions and 1343 deletions
-1
View File
@@ -17,4 +17,3 @@ inv/
scripts/ scripts/
docx/ docx/
agentic_security.toml agentic_security.toml
/venv
+5 -15
View File
@@ -83,7 +83,7 @@ agentic_security --port=PORT --host=HOST
## UI 🧙 ## UI 🧙
<img width="100%" alt="booking-screen" src="https://raw.githubusercontent.com/msoedov/agentic_security/refs/heads/main/docs/images/demo.gif"> <img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1736433557/z0bsyzhsqlgcr3w4ovwp.gif">
## LLM kwargs ## LLM kwargs
@@ -408,15 +408,10 @@ For more detailed information on how to use Agentic Security, including advanced
## Roadmap and Future Goals ## Roadmap and Future Goals
- \[ \] Expand dataset variety
- \[ \] Introduce two new attack vectors
Were just getting started! Heres whats on the horizon: - \[ \] Develop initial attacker LLM
- \[ \] Complete integration of OWASP Top 10 classification
- **RL-Powered Attacks**: An attacker LLM trained with reinforcement learning to dynamically evolve jailbreaks and outsmart defenses.
- **Massive Dataset Expansion**: Scaling to 100,000+ prompts across text, image, and audio modalities—curated for real-world threats.
- **Daily Attack Updates**: Fresh attack vectors delivered daily, keeping your scans ahead of the curve.
- **Community Modules**: A plug-and-play ecosystem where you can share and deploy custom probes, datasets, and integrations.
| Tool | Source | Integrated | | Tool | Source | Integrated |
|-------------------------|-------------------------------------------------------------------------------|------------| |-------------------------|-------------------------------------------------------------------------------|------------|
@@ -444,9 +439,4 @@ Before contributing, please read the contributing guidelines.
Agentic Security is released under the Apache License v2. Agentic Security is released under the Apache License v2.
## 🚫 No Cryptocurrency Affiliation
Agentic Security is focused solely on AI security and has no affiliation with cryptocurrency projects, blockchain technologies, or related initiatives. Our mission is to advance the safety and reliability of AI systems—no tokens, no coins, just code.
## Contact us ## Contact us
+2 -2
View File
@@ -1,3 +1,3 @@
from .lib import SecurityScanner from .lib import AgenticSecurity
__all__ = ["SecurityScanner"] __all__ = ["AgenticSecurity"]
+4 -4
View File
@@ -5,7 +5,7 @@ import fire
import uvicorn import uvicorn
from agentic_security.app import app from agentic_security.app import app
from agentic_security.lib import SecurityScanner from agentic_security.lib import AgenticSecurity
from agentic_security.misc.banner import init_banner from agentic_security.misc.banner import init_banner
@@ -32,14 +32,14 @@ class CLI:
Run Agentic Security in CI mode. Run Agentic Security in CI mode.
""" """
sys.path.append(os.path.dirname(".")) sys.path.append(os.path.dirname("."))
SecurityScanner().entrypoint() AgenticSecurity().entrypoint()
def init(self, host: str = "0.0.0.0", port: int = 8718): def init(self, host: str = "0.0.0.0", port: int = 8718):
""" """
Generate the default CI configuration file. Generate the default CI configuration file.
""" """
sys.path.append(os.path.dirname(".")) sys.path.append(os.path.dirname("."))
SecurityScanner().generate_default_settings(host, port) AgenticSecurity().generate_default_settings(host, port)
i = init i = init
@@ -48,7 +48,7 @@ class CLI:
List all available security checks. List all available security checks.
""" """
sys.path.append(os.path.dirname(".")) sys.path.append(os.path.dirname("."))
SecurityScanner().list_checks() AgenticSecurity().list_checks()
def main(): def main():
+3 -1
View File
@@ -1,4 +1,5 @@
import asyncio import asyncio
import logging
import os import os
from typing import Any from typing import Any
@@ -9,11 +10,12 @@ from pydantic import BaseModel, ConfigDict, Field
# Assuming LLMSpec is defined elsewhere; placeholder import # Assuming LLMSpec is defined elsewhere; placeholder import
from agentic_security.http_spec import LLMSpec from agentic_security.http_spec import LLMSpec
from agentic_security.logutils import logger
LLM_SPECS = [] # Populate with LLM spec strings if needed LLM_SPECS = [] # Populate with LLM spec strings if needed
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Define AgentSpecification model # Define AgentSpecification model
+5 -1
View File
@@ -1,4 +1,5 @@
import asyncio import asyncio
import logging
from typing import Any from typing import Any
import httpx import httpx
@@ -7,10 +8,13 @@ from pydantic_ai import Agent, RunContext, Tool
# Assuming LLMSpec is defined elsewhere; placeholder import # Assuming LLMSpec is defined elsewhere; placeholder import
from agentic_security.http_spec import LLMSpec from agentic_security.http_spec import LLMSpec
from agentic_security.logutils import logger
LLM_SPECS = [] # Populate this list with LLM spec strings if needed LLM_SPECS = [] # Populate this list with LLM spec strings if needed
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Define AgentSpecification model # Define AgentSpecification model
class AgentSpecification(BaseModel): class AgentSpecification(BaseModel):
+1 -2
View File
@@ -1,8 +1,7 @@
from functools import lru_cache from functools import lru_cache
import tomli import tomli
from loguru import logger
from agentic_security.logutils import logger
SETTINGS_VERSION = 1 SETTINGS_VERSION = 1
+6 -10
View File
@@ -4,14 +4,10 @@ from asyncio import Event, Queue
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from agentic_security.http_spec import LLMSpec
tools_inbox: Queue = Queue() tools_inbox: Queue = Queue()
stop_event: Event = Event() stop_event: Event = Event()
current_run: str = {"spec": "", "id": ""} current_run: str = {"spec": "", "id": ""}
_secrets: dict[str, str] = {} _secrets = {}
current_run: dict[str, int | LLMSpec] = {"spec": "", "id": ""}
def create_app() -> FastAPI: def create_app() -> FastAPI:
@@ -30,29 +26,29 @@ def get_stop_event() -> Event:
return stop_event return stop_event
def get_current_run() -> dict[str, int | LLMSpec]: def get_current_run() -> str:
"""Get the current run id.""" """Get the current run id."""
return current_run return current_run
def set_current_run(spec: LLMSpec) -> dict[str, int | LLMSpec]: def set_current_run(spec):
"""Set the current run id.""" """Set the current run id."""
current_run["id"] = hash(id(spec)) current_run["id"] = hash(id(spec))
current_run["spec"] = spec current_run["spec"] = spec
return current_run return current_run
def get_secrets() -> dict[str, str]: def get_secrets():
return _secrets return _secrets
def set_secrets(secrets: dict[str, str]) -> dict[str, str]: def set_secrets(secrets):
_secrets.update(secrets) _secrets.update(secrets)
expand_secrets(_secrets) expand_secrets(_secrets)
return _secrets return _secrets
def expand_secrets(secrets: dict[str, str]) -> None: def expand_secrets(secrets):
for key in secrets: for key in secrets:
val = secrets[key] val = secrets[key]
if val.startswith("$"): if val.startswith("$"):
+23 -2
View File
@@ -1,5 +1,26 @@
from agentic_security.logutils import set_log_level_to_info from logging import config
def setup_logging(): def setup_logging():
return set_log_level_to_info() config.dictConfig(
{
"version": 1,
"disable_existing_loggers": True,
"handlers": {
"console": {
"class": "logging.StreamHandler",
},
},
"root": {
"handlers": ["console"],
"level": "INFO",
},
"loggers": {
"uvicorn.access": {
"level": "ERROR", # Set higher log level to suppress info logs globally
"handlers": ["console"],
"propagate": False,
}
},
}
)
-9
View File
@@ -1,6 +1,5 @@
import base64 import base64
from enum import Enum from enum import Enum
from urllib.parse import urlparse
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
@@ -160,14 +159,6 @@ def parse_http_spec(http_spec: str) -> LLMSpec:
# Extract the method and URL from the first line # Extract the method and URL from the first line
method, url = lines[0].split(" ")[0:2] method, url = lines[0].split(" ")[0:2]
# Check url validity
valid_url = urlparse(url)
# if missing the correct formatting ://, urlparse.netloc will be empty
if valid_url.scheme not in ("http", "https") or not valid_url.netloc:
raise InvalidHTTPSpecError(
f"Invalid URL: {url}. Ensure it starts with 'http://' or 'https://'"
)
# Initialize headers and body # Initialize headers and body
headers = {} headers = {}
body = "" body = ""
+3 -3
View File
@@ -4,12 +4,12 @@ from datetime import datetime
import colorama import colorama
import tqdm.asyncio import tqdm.asyncio
from loguru import logger
from rich.console import Console from rich.console import Console
from rich.table import Table from rich.table import Table
from tabulate import tabulate from tabulate import tabulate
from agentic_security.config import SettingsMixin # Importing the configuration mixin from agentic_security.config import SettingsMixin # Importing the configuration mixin
from agentic_security.logutils import logger
from agentic_security.primitives import Scan from agentic_security.primitives import Scan
from agentic_security.probe_data import REGISTRY from agentic_security.probe_data import REGISTRY
from agentic_security.routes.scan import streaming_response_generator from agentic_security.routes.scan import streaming_response_generator
@@ -23,7 +23,7 @@ YELLOW = colorama.Fore.YELLOW
BLUE = colorama.Fore.BLUE BLUE = colorama.Fore.BLUE
class SecurityScanner(SettingsMixin): class AgenticSecurity(SettingsMixin):
@classmethod @classmethod
async def async_scan( async def async_scan(
cls, cls,
@@ -202,7 +202,7 @@ class SecurityScanner(SettingsMixin):
self.load_config(self.default_path) self.load_config(self.default_path)
logger.info("Configuration loaded successfully.") logger.info("Configuration loaded successfully.")
print(self.config) print(self.config)
datasets = list(self.get_config_value("modules", {}).values()) datasets = list(self.get_config_value("modules").values())
for d in datasets: for d in datasets:
d["selected"] = True d["selected"] = True
self.scan( self.scan(
-85
View File
@@ -1,85 +0,0 @@
# import sys
# from loguru import logger
# # Define custom colors
# BLUE = "#89CFF0"
# BROWN = "#8B4513" # Brown for DEBUG
# # Define custom log level colors
# logger.level("DEBUG", color=f"<fg {BROWN}>")
# logger.level("INFO", color=f"<fg {BLUE}>")
# # Define custom log format with aligned messages and colored levels
# LOG_FORMAT = (
# "<level>{level:<8}</level> " # Properly formatted and colored log level
# "<level>{message:<100}</level> " # Left-aligned message for readability
# "<cyan>{file.name}</cyan>:<cyan>{line}</cyan>" # File name and line number in cyan
# )
# # Remove default handlers and add a new one with custom formatting
# logger.remove()
# logger.add(sys.stdout, format=LOG_FORMAT, level="DEBUG", colorize=True)
import logging
import logging.config
from os import getenv
LOGGER_NAME = None
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"rich": {"format": "%(message)s", "datefmt": "[%X]"},
},
"handlers": {
"rich": {
"class": "rich.logging.RichHandler",
"level": "INFO",
"formatter": "rich",
"show_time": False,
"rich_tracebacks": False,
"show_path": lambda: True if getenv("API_RUNTIME") == "dev" else False,
"tracebacks_show_locals": False,
},
},
"loggers": {
"": { # Root logger configuration
"level": "INFO",
"handlers": ["rich"],
"propagate": True,
},
},
}
def configure_logging():
# Apply the dictionary configuration
logging.config.dictConfig(LOGGING_CONFIG)
# Get and return the logger
logger = logging.getLogger(LOGGER_NAME)
return logger
logger: logging.Logger = configure_logging()
def set_log_level_to_debug():
logger = logging.getLogger(LOGGER_NAME)
logger.setLevel(logging.DEBUG)
# Update handler level as well
for handler in logger.handlers:
handler.setLevel(logging.DEBUG)
def set_log_level_to_info():
logger = logging.getLogger(LOGGER_NAME)
logger.setLevel(logging.INFO)
# Update handler level as well
for handler in logger.handlers:
handler.setLevel(logging.INFO)
# Set initial log level
set_log_level_to_info()
+1 -2
View File
@@ -1,8 +1,7 @@
from fastapi import Request from fastapi import Request
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from agentic_security.logutils import logger
class LogNon200ResponsesMiddleware(BaseHTTPMiddleware): class LogNon200ResponsesMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
+8 -13
View File
@@ -8,14 +8,14 @@ except ImportError:
def generate_banner( def generate_banner(
title: str = "Agentic Security", title="Agentic Security",
font: str = "slant", font="slant",
version: str = "v2.1.0", version="v2.1.0",
tagline: str = "Proactive Threat Detection & Automated Security Protocols", tagline="Proactive Threat Detection & Automated Security Protocols",
author: str = "Developed by: [Security Team]", author="Developed by: [Security Team]",
website: str = "Website: https://github.com/msoedov/agentic_security", website="Website: https://github.com/msoedov/agentic_security",
warning: str | None = "", # Using Optional for warning since it might be None warning="",
) -> str: ):
"""Generate a visually enhanced banner with dynamic width and borders.""" """Generate a visually enhanced banner with dynamic width and borders."""
# Define the text elements # Define the text elements
@@ -80,13 +80,8 @@ def generate_banner(
def init_banner(): def init_banner():
return
ver = version("agentic_security") ver = version("agentic_security")
try:
print(generate_banner(version=ver)) print(generate_banner(version=ver))
except Exception:
# UnicodeEncodeError with codec on some systems
pass
if __name__ == "__main__": if __name__ == "__main__":
+22 -22
View File
@@ -6,11 +6,11 @@ from json import JSONDecodeError
import httpx import httpx
import pandas as pd import pandas as pd
from loguru import logger
from skopt import Optimizer from skopt import Optimizer
from skopt.space import Real from skopt.space import Real
from agentic_security.http_spec import Modality from agentic_security.http_spec import Modality
from agentic_security.logutils import logger
from agentic_security.primitives import Scan, ScanResult from agentic_security.primitives import Scan, ScanResult
from agentic_security.probe_actor.cost_module import calculate_cost from agentic_security.probe_actor.cost_module import calculate_cost
from agentic_security.probe_actor.refusal import refusal_heuristic from agentic_security.probe_actor.refusal import refusal_heuristic
@@ -107,17 +107,6 @@ async def process_prompt_batch(
return total_tokens, failures return total_tokens, failures
async def with_error_handling(agen):
try:
async for t in agen:
yield t
except Exception as e:
logger.exception("Scan failed")
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
finally:
yield ScanResult.status_msg("Scan completed.")
async def perform_single_shot_scan( async def perform_single_shot_scan(
request_factory, request_factory,
max_budget: int, max_budget: int,
@@ -131,6 +120,7 @@ async def perform_single_shot_scan(
max_budget = max_budget * BUDGET_MULTIPLIER max_budget = max_budget * BUDGET_MULTIPLIER
selected_datasets = [m for m in datasets if m["selected"]] selected_datasets = [m for m in datasets if m["selected"]]
request_factory = multi_modality_spec(request_factory) request_factory = multi_modality_spec(request_factory)
try:
yield ScanResult.status_msg("Loading datasets...") yield ScanResult.status_msg("Loading datasets...")
prompt_modules = prepare_prompts( prompt_modules = prepare_prompts(
dataset_names=[m["dataset_name"] for m in selected_datasets], dataset_names=[m["dataset_name"] for m in selected_datasets],
@@ -175,8 +165,9 @@ async def perform_single_shot_scan(
processed_prompts += 1 processed_prompts += 1
module_prompts += 1 # Fixed increment syntax module_prompts += 1 # Fixed increment syntax
# Calculate progress based on total processed prompts # Calculate progress based on total processed prompts
progress = 100 * processed_prompts / total_prompts if total_prompts else 0 progress = (
progress = progress % 100 100 * processed_prompts / total_prompts if total_prompts else 0
)
total_tokens -= tokens total_tokens -= tokens
start = time.time() start = time.time()
@@ -243,6 +234,12 @@ async def perform_single_shot_scan(
) )
df.to_csv("failures.csv", index=False) df.to_csv("failures.csv", index=False)
except Exception as e:
logger.exception("Scan failed")
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
finally:
yield ScanResult.status_msg("Scan completed.")
async def perform_many_shot_scan( async def perform_many_shot_scan(
request_factory, request_factory,
@@ -258,6 +255,7 @@ async def perform_many_shot_scan(
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Perform a multi-step security scan with probe injection.""" """Perform a multi-step security scan with probe injection."""
request_factory = multi_modality_spec(request_factory) request_factory = multi_modality_spec(request_factory)
try:
# Load main and probe datasets # Load main and probe datasets
yield ScanResult.status_msg("Loading datasets...") yield ScanResult.status_msg("Loading datasets...")
prompt_modules = prepare_prompts( prompt_modules = prepare_prompts(
@@ -295,8 +293,9 @@ async def perform_many_shot_scan(
return return
tokens = 0 tokens = 0
processed_prompts += 1 processed_prompts += 1
progress = 100 * processed_prompts / total_prompts if total_prompts else 0 progress = (
progress = progress % 100 100 * processed_prompts / total_prompts if total_prompts else 0
)
full_prompt = "" full_prompt = ""
msj_module = random.choice(msj_modules) msj_module = random.choice(msj_modules)
@@ -360,6 +359,11 @@ async def perform_many_shot_scan(
) )
df.to_csv("failures.csv", index=False) df.to_csv("failures.csv", index=False)
except Exception as e:
logger.exception("Scan failed")
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
raise e
def scan_router( def scan_router(
request_factory, request_factory,
@@ -368,8 +372,7 @@ def scan_router(
stop_event: asyncio.Event = None, stop_event: asyncio.Event = None,
): ):
if scan_parameters.enableMultiStepAttack: if scan_parameters.enableMultiStepAttack:
return with_error_handling( return perform_many_shot_scan(
perform_many_shot_scan(
request_factory=request_factory, request_factory=request_factory,
max_budget=scan_parameters.maxBudget, max_budget=scan_parameters.maxBudget,
datasets=scan_parameters.datasets, datasets=scan_parameters.datasets,
@@ -379,10 +382,8 @@ def scan_router(
stop_event=stop_event, stop_event=stop_event,
secrets=scan_parameters.secrets, secrets=scan_parameters.secrets,
) )
)
else: else:
return with_error_handling( return perform_single_shot_scan(
perform_single_shot_scan(
request_factory=request_factory, request_factory=request_factory,
max_budget=scan_parameters.maxBudget, max_budget=scan_parameters.maxBudget,
datasets=scan_parameters.datasets, datasets=scan_parameters.datasets,
@@ -391,4 +392,3 @@ def scan_router(
stop_event=stop_event, stop_event=stop_event,
secrets=scan_parameters.secrets, secrets=scan_parameters.secrets,
) )
)
+5 -1
View File
@@ -1,4 +1,5 @@
import asyncio import asyncio
import logging
from typing import Any from typing import Any
import httpx import httpx
@@ -6,10 +7,13 @@ from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext from pydantic_ai import Agent, RunContext
from agentic_security.http_spec import LLMSpec from agentic_security.http_spec import LLMSpec
from agentic_security.logutils import logger
LLM_SPECS = [] LLM_SPECS = []
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AgentSpecification(BaseModel): class AgentSpecification(BaseModel):
name: str | None = Field(None, description="Name of the LLM/agent") name: str | None = Field(None, description="Name of the LLM/agent")
+1 -1
View File
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from agentic_security.refusal_classifier.model import RefusalClassifier from agentic_security.refusal_classifier import RefusalClassifier
classifier = RefusalClassifier() classifier = RefusalClassifier()
classifier.load_model() classifier.load_model()
+5 -29
View File
@@ -1,5 +1,4 @@
import base64 import base64
import logging
import os import os
import platform import platform
import subprocess import subprocess
@@ -8,16 +7,6 @@ import uuid
import httpx import httpx
from cache_to_disk import cache_to_disk from cache_to_disk import cache_to_disk
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AudioGenerationError(Exception):
"""Custom exception for errors during audio generation."""
pass
def encode(content: bytes) -> str: def encode(content: bytes) -> str:
encoded_content = base64.b64encode(content).decode("utf-8") encoded_content = base64.b64encode(content).decode("utf-8")
@@ -52,25 +41,12 @@ def generate_audio_mac_wav(prompt: str) -> bytes:
# Read the WAV file into memory # Read the WAV file into memory
with open(temp_wav_path, "rb") as f: with open(temp_wav_path, "rb") as f:
audio_bytes = f.read() audio_bytes = f.read()
except subprocess.CalledProcessError as e:
logger.error(f"Subprocess error: {e}")
raise AudioGenerationError("Failed to generate or convert audio.") from e
except FileNotFoundError as e:
logger.error(f"File not found: {e}")
raise AudioGenerationError("Required file not found.") from e
except Exception as e:
logger.exception("Unexpected error occurred.")
raise AudioGenerationError(
"An unexpected error occurred during audio generation."
) from e
finally: finally:
for path in (temp_aiff_path, temp_wav_path): # Clean up the temporary files
try: if os.path.exists(temp_aiff_path):
if os.path.exists(path): os.remove(temp_aiff_path)
os.remove(path) if os.path.exists(temp_wav_path):
except Exception as e: os.remove(temp_wav_path)
logger.warning(f"Failed to delete temporary file {path}: {e}")
# Return the audio bytes # Return the audio bytes
return audio_bytes return audio_bytes
+371 -262
View File
@@ -1,15 +1,14 @@
import io import io
import os import os
import random import random
from collections.abc import Callable, Iterator from functools import lru_cache
from functools import partial
import httpx import httpx
import pandas as pd import pandas as pd
from cache_to_disk import cache_to_disk from cache_to_disk import cache_to_disk
from datasets import load_dataset from datasets import load_dataset
from loguru import logger
from agentic_security.logutils import logger
from agentic_security.probe_data import stenography_fn from agentic_security.probe_data import stenography_fn
from agentic_security.probe_data.models import ProbeDataset from agentic_security.probe_data.models import ProbeDataset
from agentic_security.probe_data.modules import ( from agentic_security.probe_data.modules import (
@@ -20,200 +19,281 @@ from agentic_security.probe_data.modules import (
rl_model, rl_model,
) )
# Type aliases for clarity
FilterFn = Callable[[pd.Series], bool]
ColumnMappings = dict[str, str]
DatasetLoader = Callable[[], ProbeDataset]
@cache_to_disk()
def load_dataset_general(
dataset_name,
dataset_split="train",
column_mappings=None,
filter_fn=None,
custom_url=None,
additional_metadata=None,
):
"""
Generalized function to load datasets with flexible configurations.
# Core data loading utilities :param dataset_name: Name of the dataset or URL for custom CSVs
def fetch_csv_content(url: str) -> str: :param dataset_split: Split to load from the dataset (e.g., "train")
"""Fetch CSV content from a URL.""" :param column_mappings: Dictionary mapping dataset columns to expected keys, e.g., {'prompt': 'query'}
response = httpx.get(url) :param filter_fn: A filtering function that takes a row and returns True/False
return response.content.decode("utf-8") :param custom_url: URL for custom CSV datasets
:param additional_metadata: Additional metadata to include in the ProbeDataset
:return: A ProbeDataset object with the processed data
"""
if custom_url:
logger.info(f"Loading custom CSV dataset from {custom_url}")
r = httpx.get(custom_url)
content = r.content
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
else:
logger.info(f"Loading dataset {dataset_name} from Hugging Face datasets")
dataset = load_dataset(dataset_name)
df = pd.DataFrame(dataset[dataset_split])
# Apply column mappings if provided
if column_mappings:
df.rename(columns=column_mappings, inplace=True)
def load_df_from_source(source: str, is_url: bool = False) -> pd.DataFrame: # Filter rows if filter_fn is provided
"""Load DataFrame from either URL or Hugging Face dataset."""
if is_url:
logger.info(f"Loading custom CSV from {source}")
return pd.read_csv(io.StringIO(fetch_csv_content(source)))
logger.info(f"Loading dataset {source} from Hugging Face")
return pd.DataFrame(load_dataset(source)["train"])
def transform_df(
df: pd.DataFrame,
mappings: ColumnMappings | None = None,
filter_fn: FilterFn | None = None,
) -> pd.DataFrame:
"""Apply transformations to DataFrame."""
result = df.copy()
if mappings:
result.rename(columns=mappings, inplace=True)
if filter_fn: if filter_fn:
result = result[result.apply(filter_fn, axis=1)] df = df[df.apply(filter_fn, axis=1)]
return result
# Extract prompts
prompts = df[column_mappings.get("prompt", "prompt")].tolist()
def create_probe_dataset(
name: str, prompts: list[str], metadata: dict = None
) -> ProbeDataset:
"""Create a ProbeDataset from prompts."""
metadata = metadata or {}
return ProbeDataset( return ProbeDataset(
dataset_name=name, dataset_name=dataset_name,
metadata=metadata, metadata=additional_metadata or {},
prompts=prompts, prompts=prompts,
tokens=sum(len(str(s).split()) for s in prompts), tokens=count_words_in_list(prompts),
approx_cost=0.0, approx_cost=0.0,
) )
# Generalized dataset loader def count_words_in_list(str_list):
"""Calculate the total number of words in a given list of strings.
:param str_list: List of strings
:return: Total number of words across all strings in the list
"""
total_words = sum(len(str(s).split()) for s in str_list)
return total_words
@cache_to_disk() @cache_to_disk()
def load_dataset_generic( def load_dataset_v1():
name: str, return load_dataset_general(
mappings: ColumnMappings | None = None, dataset_name="ShawnMenz/DAN_jailbreak",
filter_fn: FilterFn | None = None, column_mappings={"prompt": "prompt", "jailbreak": "jailbreak"},
url: str | None = None, filter_fn=lambda row: row["jailbreak"],
metadata: dict | None = None,
) -> ProbeDataset:
"""Load and process a dataset with flexible configuration."""
df = load_df_from_source(url or name, is_url=bool(url))
transformed_df = transform_df(df, mappings, filter_fn)
prompt_col = mappings.get("prompt", "prompt") if mappings else "prompt"
prompts = transformed_df[prompt_col].tolist()
return create_probe_dataset(name, prompts, metadata)
# Dataset-specific configurations
DATASET_CONFIGS = {
"ShawnMenz/DAN_jailbreak": {
"mappings": {"prompt": "prompt"},
"filter_fn": lambda row: row["jailbreak"],
},
"deepset/prompt-injections": {
"mappings": {"prompt": "text"},
"filter_fn": lambda row: row["label"],
},
"notrichardren/refuse-to-answer-prompts": {
"mappings": {"prompt": "claim"},
"filter_fn": lambda row: row["label"],
},
"rubend18/ChatGPT-Jailbreak-Prompts": {
"mappings": {"prompt": "Prompt"},
"filter_fn": lambda row: row["label"],
},
"markush1/LLM-Jailbreak-Classifier": {
"mappings": {"prompt": "prompt"},
"filter_fn": lambda row: row["classification"] == "jailbreak",
},
"ShawnMenz/jailbreak_sft_rm_ds": {
"url": "hf://datasets/ShawnMenz/jailbreak_sft_rm_ds/jailbreak_sft_rm_ds.csv",
"mappings": {"prompt": "prompt"},
"filter_fn": lambda row: row["jailbreak"] == "jailbreak",
},
"verazuo/jailbreak_llms/2023_05_07": {
"url": "https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_05_07.csv",
"mappings": {"prompt": "prompt"},
"filter_fn": lambda row: bool(row["jailbreak"]),
},
"verazuo/jailbreak_llms/2023_12_25": {
"url": "https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_12_25.csv",
"mappings": {"prompt": "prompt"},
"filter_fn": lambda row: bool(row["jailbreak"]),
},
}
# Additional generic dataset configurations
DATASET_CONFIGS_GENERICS = {
"simonycl/aya-23-8B_advbench_jailbreak": {"mappings": {"prompt": "prompt"}},
"acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b": {},
"karanxa/dolphin-jailbreak-finetuning-dataset": {"mappings": {"prompt": "text"}},
"karanxa/llama-2-jailbreak-dataset": {"mappings": {"prompt": "text"}},
"karanxa/llama2-uncensored-jailbreak-dataset-finetuning": {
"mappings": {"prompt": "text"}
},
"liuyanchen1015/Llama-3.2-1B_jailbreak_responses": {
"mappings": {"prompt": "jailbreak_prompt_text"}
},
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses": {
"mappings": {"prompt": "jailbreak_prompt_text"}
},
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment": {
"mappings": {"prompt": "jailbreak_prompt_text"}
},
"jackhhao/jailbreak-classification": {"mappings": {"prompt": "prompt"}},
"walledai/JailbreakBench": {"mappings": {"prompt": "prompt"}},
"walledai/JailbreakHub": {"mappings": {"prompt": "prompt"}},
"Granther/evil-jailbreak": {"mappings": {"prompt": "text"}},
"sevdeawesome/jailbreak_success": {"mappings": {"prompt": "jailbreak_prompt_text"}},
"IDA-SERICS/Disaster-tweet-jailbreaking": {"mappings": {"prompt": "prompt_attack"}},
"GeorgeDaDude/Jailbreak_Complete_DS_labeled": {"mappings": {"prompt": "question"}},
"dayone3nder/jailbreak_prompt_JBB_sft_trainset": {"mappings": {"prompt": "prompt"}},
"dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset": {
"mappings": {"prompt": "prompt"}
},
}
# Dataset factory
def create_dataset_loader(name: str, config: dict) -> DatasetLoader:
"""Create a dataset loader from configuration."""
return partial(
load_dataset_generic,
name=name,
mappings=config.get("mappings"),
filter_fn=config.get("filter_fn"),
url=config.get("url"),
) )
# Specialized loaders
@cache_to_disk() @cache_to_disk()
def load_multi_dataset(name: str, sub_datasets: list[str]) -> ProbeDataset: def load_dataset_v2():
"""Load and combine multiple sub-datasets.""" return load_dataset_general(
prompts = [] dataset_name="deepset/prompt-injections",
for sub in sub_datasets: column_mappings={"prompt": "text", "jailbreak": "label"},
dataset = load_dataset(name, sub) filter_fn=lambda row: row["label"],
prompts.extend(dataset["train"]["query"]) )
return create_probe_dataset(f"{name}_combined", prompts)
@cache_to_disk() @cache_to_disk()
def load_jailbreak_v28k() -> ProbeDataset: def load_dataset_v4():
"""Load JailBreakV-28K dataset.""" return load_dataset_general(
df = pd.read_csv("hf://datasets/JailbreakV-28K/JailBreakV-28k/JailBreakV_28K.csv") dataset_name="notrichardren/refuse-to-answer-prompts",
prompts = df["jailbreak_query"].tolist() column_mappings={"prompt": "claim", "jailbreak": "label"},
return create_probe_dataset("JailbreakV-28K/JailBreakV-28k", prompts) filter_fn=lambda row: row["label"],
)
@cache_to_disk() @cache_to_disk()
def load_local_csv() -> ProbeDataset: def load_dataset_v3():
"""Load prompts from local CSV files.""" return load_dataset_general(
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")] dataset_name="rubend18/ChatGPT-Jailbreak-Prompts",
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}") column_mappings={"prompt": "Prompt"},
filter_fn=lambda row: row["label"],
)
prompts = []
for file in csv_files: @cache_to_disk()
def load_dataset_v6():
dataset = load_dataset("markush1/LLM-Jailbreak-Classifier")
bad_prompts = [
p
for p, c in zip(dataset["train"]["prompt"], dataset["train"]["classification"])
if c == "jailbreak"
]
return ProbeDataset(
dataset_name="markush1/LLM-Jailbreak-Classifier",
metadata={},
prompts=bad_prompts,
tokens=count_words_in_list(bad_prompts),
approx_cost=0.0,
)
@cache_to_disk()
def load_dataset_v7():
splits = {
"mini_JailBreakV_28K": "JailBreakV_28K/mini_JailBreakV_28K.csv",
"JailBreakV_28K": "JailBreakV_28K/JailBreakV_28K.csv",
}
df = pd.read_csv(
"hf://datasets/JailbreakV-28K/JailBreakV-28k/" + splits["JailBreakV_28K"]
)
bad_prompts = df["jailbreak_query"].tolist()
print(df.shape)
return ProbeDataset(
dataset_name="JailbreakV-28K/JailBreakV-28k",
metadata={},
prompts=bad_prompts,
tokens=count_words_in_list(bad_prompts),
approx_cost=0.0,
)
@cache_to_disk()
def load_dataset_v8():
df = pd.read_csv(
"hf://datasets/ShawnMenz/jailbreak_sft_rm_ds/jailbreak_sft_rm_ds.csv",
names=["jailbreak", "prompt"],
)
filtered = df[df["jailbreak"] == "jailbreak"]["prompt"].tolist()
return ProbeDataset(
dataset_name="JailbreakV-28K/JailBreakV-28k",
metadata={},
prompts=filtered,
tokens=count_words_in_list(filtered),
approx_cost=0.0,
)
@cache_to_disk()
def load_dataset_v5():
ds = []
for c in [
"AdvBench",
"ForbiddenQuestion",
"MJP",
"MaliciousInstruct",
"QuestionList",
]:
dataset = load_dataset("Lemhf14/EasyJailbreak_Datasets", c)
bad_prompts = dataset["train"]["query"]
ds.extend(bad_prompts)
return ProbeDataset(
dataset_name="Lemhf14/EasyJailbreak_Datasets",
metadata={},
prompts=ds,
tokens=count_words_in_list(ds),
approx_cost=0.0,
)
@cache_to_disk()
def load_generic_csv(url, name, column="prompt", predicator=None):
r = httpx.get(url)
content = r.content
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
logger.info(f"Loaded {len(df)} prompts from {url}")
filtered_prompts = df[df.apply(predicator, axis=1)][column].tolist()
return ProbeDataset(
dataset_name=name,
metadata={},
prompts=filtered_prompts,
tokens=count_words_in_list(filtered_prompts),
approx_cost=0.0,
)
def prepare_prompts(dataset_names, budget, tools_inbox=None, options=[]):
# ## Datasets used and cleaned:
# markush1/LLM-Jailbreak-Classifier
# 1. Open-Orca/OpenOrca
# 2. ShawnMenz/DAN_jailbreak
# 3. EddyLuo/JailBreakV_28K
# 4. https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/jailbreak_prompts.csv
dataset_map = {
"ShawnMenz/DAN_jailbreak": load_dataset_v1,
"deepset/prompt-injections": load_dataset_v2,
"notrichardren/refuse-to-answer-prompts": load_dataset_v4,
"rubend18/ChatGPT-Jailbreak-Prompts": load_dataset_v3,
"Lemhf14/EasyJailbreak_Datasets": load_dataset_v5,
"markush1/LLM-Jailbreak-Classifier": load_dataset_v6,
"JailbreakV-28K/JailBreakV-28k": load_dataset_v7,
"ShawnMenz/jailbreak_sft_rm_ds": load_dataset_v8,
"verazuo/jailbreak_llms/2023_05_07": lambda: load_generic_csv(
url="https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_05_07.csv",
name="verazuo/jailbreak_llms/2023_05_07",
column="prompt",
predicator=lambda x: bool(x["jailbreak"]),
),
"verazuo/jailbreak_llms/2023_12_25.csv": lambda: load_generic_csv(
url="https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_12_25.csv.csv",
name="verazuo/jailbreak_llms/2023_12_25.csv",
column="prompt",
predicator=lambda x: bool(x["jailbreak"]),
),
"Custom CSV": load_local_csv,
}
dataset_map.update(dataset_map_generics)
group = []
for dataset_name in dataset_names:
if dataset_name in dataset_map:
logger.info(f"Loading {dataset_name}")
try: try:
df = pd.read_csv(file) group.append(dataset_map[dataset_name]())
if "prompt" in df.columns:
prompts.extend(df["prompt"].tolist())
else:
logger.warning(f"File {file} lacks 'prompt' column")
except Exception as e: except Exception as e:
logger.error(f"Error reading {file}: {e}") logger.error(f"Error loading {dataset_name}: {e}")
return create_probe_dataset("Local CSV", prompts, {"src": str(csv_files)}) dynamic_datasets = {
"AgenticBackend": lambda opts: dataset_from_iterator(
"AgenticBackend",
fine_tuned.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
),
"Steganography": lambda opts: Stenography(group),
"llm-adaptive-attacks": lambda opts: dataset_from_iterator(
"llm-adaptive-attacks",
adaptive_attacks.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
),
"Garak": lambda opts: dataset_from_iterator(
"Garak",
garak_tool.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
),
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
"Reinforcement Learning Optimization",
rl_model.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
),
"InspectAI": lambda opts: dataset_from_iterator(
"InspectAI",
inspect_ai_tool.Module(group, tools_inbox=tools_inbox).apply(),
lazy=True,
),
"GPT fuzzer": lambda opts: [],
}
dynamic_groups = []
options = options or [{} for _ in dataset_names]
for dataset_name, opts in zip(dataset_names, options):
if dataset_name in dynamic_datasets:
logger.info(f"Loading {dataset_name}")
ds = dynamic_datasets[dataset_name](opts)
for g in ds:
dynamic_groups.append(g)
return group + dynamic_groups
# Stenography transformer class Stenography:
class StenographyTransformer: fn_library = {
"""Apply stenography transformations to datasets."""
TRANSFORMATIONS = {
"rot5": stenography_fn.rot5, "rot5": stenography_fn.rot5,
"rot13": stenography_fn.rot13, "rot13": stenography_fn.rot13,
"base64": stenography_fn.base64_encode, "base64": stenography_fn.base64_encode,
@@ -226,32 +306,71 @@ class StenographyTransformer:
"zigzag_obfuscation": stenography_fn.zigzag_obfuscation, "zigzag_obfuscation": stenography_fn.zigzag_obfuscation,
} }
def __init__(self, datasets: list[ProbeDataset]): def __init__(self, prompt_groups: [ProbeDataset]):
self.datasets = datasets self.prompt_groups = prompt_groups
def transform(self) -> Iterator[ProbeDataset]: def __iter__(self):
"""Generate transformed datasets.""" return self.apply()
for dataset in self.datasets:
size = len(dataset.prompts)
sample_size = int(size * 0.25)
for name, fn in self.TRANSFORMATIONS.items(): def apply(self):
logger.info(f"Applying {name} to {dataset.dataset_name}") for prompt_group in self.prompt_groups:
sampled = random.sample(dataset.prompts, sample_size) size = len(prompt_group.prompts)
transformed = [ for name, fn in self.fn_library.items():
f"Follow the instructions in {name}: {fn(p)}" for p in sampled logger.info(f"Applying {name} to {prompt_group.dataset_name}")
selected_prompts = random.sample(prompt_group.prompts, int(size * 0.25))
prompts = [
f"Follow the instructions in {name}: {fn(p)}"
for p in selected_prompts
] ]
yield create_probe_dataset( yield ProbeDataset(
f"stenography.{name}({dataset.dataset_name})", transformed dataset_name=f"stenography.{name}({prompt_group.dataset_name})",
metadata={},
prompts=prompts,
tokens=count_words_in_list(prompts),
approx_cost=0.0,
) )
def dataset_from_iterator( @lru_cache
name: str, iterator, lazy: bool = False def load_local_csv() -> ProbeDataset:
) -> list[ProbeDataset]: csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
"""Convert an iterator into a list of ProbeDataset objects.""" logger.info(f"Found {len(csv_files)} CSV files")
logger.info(f"CSV files: {csv_files}")
prompt_list = []
for file in csv_files:
try:
df = pd.read_csv(file)
except Exception as e:
logger.error(f"Error reading {file}: {e}")
continue
# Check if 'prompt' column exists
if "prompt" in df.columns:
prompt_list.extend(df["prompt"].tolist())
else:
logger.warning(f"File {file} does not contain a 'prompt' column")
return ProbeDataset(
dataset_name="Local CSV",
metadata={"src": str(csv_files)},
prompts=prompt_list,
tokens=count_words_in_list(prompt_list),
approx_cost=0.0,
)
def dataset_from_iterator(name: str, iterator, lazy=False) -> list:
"""Convert an iterator into a list of prompts and create a ProbeDataset
object.
Args:
name (str): The name of the dataset.
iterator (iterator): An iterator that generates prompts.
Returns:
list: A list containing a single ProbeDataset object.
"""
prompts = list(iterator) if not lazy else iterator prompts = list(iterator) if not lazy else iterator
tokens = sum(len(str(s).split()) for s in prompts) if not lazy else 0 tokens = count_words_in_list(prompts) if not lazy else 0
dataset = ProbeDataset( dataset = ProbeDataset(
dataset_name=name, dataset_name=name,
metadata={}, metadata={},
@@ -263,85 +382,75 @@ def dataset_from_iterator(
return [dataset] return [dataset]
# Main dataset preparation # TODO: refactor this abstraction
def prepare_prompts(
dataset_names: list[str],
budget: float,
tools_inbox=None,
options: list[dict] = None,
) -> list[ProbeDataset]:
"""Prepare datasets based on names and options."""
# Base dataset loaders
dataset_loaders = {
**{k: create_dataset_loader(k, v) for k, v in DATASET_CONFIGS.items()},
**{k: create_dataset_loader(k, v) for k, v in DATASET_CONFIGS_GENERICS.items()},
"Lemhf14/EasyJailbreak_Datasets": partial(
load_multi_dataset,
name="Lemhf14/EasyJailbreak_Datasets",
sub_datasets=[
"AdvBench",
"ForbiddenQuestion",
"MJP",
"MaliciousInstruct",
"QuestionList",
],
),
"JailbreakV-28K/JailBreakV-28k": load_jailbreak_v28k,
"Local CSV": load_local_csv,
}
# Dynamic dataset loaders dataset_map_generics = {
dynamic_loaders = { "simonycl/aya-23-8B_advbench_jailbreak": lambda: load_dataset_general(
"AgenticBackend": lambda opts: dataset_from_iterator( dataset_name="simonycl/aya-23-8B_advbench_jailbreak",
"AgenticBackend", column_mappings={"prompt": "prompt"},
fine_tuned.Module([], tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
), ),
"Steganography": lambda opts: list(StenographyTransformer([]).transform()), "acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b": lambda: load_dataset_general(
"llm-adaptive-attacks": lambda opts: dataset_from_iterator( dataset_name="acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b"
"llm-adaptive-attacks",
adaptive_attacks.Module([], tools_inbox=tools_inbox, opts=opts).apply(),
), ),
"Garak": lambda opts: dataset_from_iterator( "karanxa/dolphin-jailbreak-finetuning-dataset": lambda: load_dataset_general(
"Garak", dataset_name="karanxa/dolphin-jailbreak-finetuning-dataset",
garak_tool.Module([], tools_inbox=tools_inbox, opts=opts).apply(), column_mappings={"prompt": "text"},
lazy=True,
), ),
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator( "karanxa/llama-2-jailbreak-dataset": lambda: load_dataset_general(
"Reinforcement Learning Optimization", dataset_name="karanxa/llama-2-jailbreak-dataset",
rl_model.Module([], tools_inbox=tools_inbox, opts=opts).apply(), column_mappings={"prompt": "text"},
lazy=True,
), ),
"InspectAI": lambda opts: dataset_from_iterator( "karanxa/llama2-uncensored-jailbreak-dataset-finetuning": lambda: load_dataset_general(
"InspectAI", dataset_name="karanxa/llama2-uncensored-jailbreak-dataset-finetuning",
inspect_ai_tool.Module([], tools_inbox=tools_inbox).apply(), column_mappings={"prompt": "text"},
lazy=True,
), ),
"GPT fuzzer": lambda opts: [], "liuyanchen1015/Llama-3.2-1B_jailbreak_responses": lambda: load_dataset_general(
} dataset_name="liuyanchen1015/Llama-3.2-1B_jailbreak_responses",
column_mappings={"prompt": "jailbreak_prompt_text"},
options = options or [{} for _ in dataset_names] ),
datasets = [] "liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses": lambda: load_dataset_general(
dataset_name="liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses",
# Load base datasets column_mappings={"prompt": "jailbreak_prompt_text"},
for name, opts in zip(dataset_names, options): ),
if name in dataset_loaders: "liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment": lambda: load_dataset_general(
logger.info(f"Loading base dataset {name}") dataset_name="liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment",
try: column_mappings={"prompt": "jailbreak_prompt_text"},
datasets.append(dataset_loaders[name]()) ),
except Exception as e: "jackhhao/jailbreak-classification": lambda: load_dataset_general(
logger.error(f"Error loading {name}: {e}") dataset_name="jackhhao/jailbreak-classification",
column_mappings={"prompt": "prompt"},
# Load dynamic datasets and apply transformations ),
for name, opts in zip(dataset_names, options): "markush1/LLM-Jailbreak-Classifier": lambda: load_dataset_general(
if name in dynamic_loaders: dataset_name="markush1/LLM-Jailbreak-Classifier",
logger.info(f"Loading dynamic dataset {name}") column_mappings={"prompt": "prompt"},
try: ),
dynamic_result = dynamic_loaders[name](opts) "walledai/JailbreakBench": lambda: load_dataset_general(
datasets.extend(dynamic_result) dataset_name="walledai/JailbreakBench", column_mappings={"prompt": "prompt"}
except Exception as e: ),
logger.error(f"Error loading dynamic {name}: {e}") "walledai/JailbreakHub": lambda: load_dataset_general(
elif name == "Steganography": dataset_name="walledai/JailbreakHub", column_mappings={"prompt": "prompt"}
datasets.extend(list(StenographyTransformer(datasets).transform())) ),
"Granther/evil-jailbreak": lambda: load_dataset_general(
return datasets dataset_name="Granther/evil-jailbreak", column_mappings={"prompt": "text"}
),
"sevdeawesome/jailbreak_success": lambda: load_dataset_general(
dataset_name="sevdeawesome/jailbreak_success",
column_mappings={"prompt": "jailbreak_prompt_text"},
),
"IDA-SERICS/Disaster-tweet-jailbreaking": lambda: load_dataset_general(
dataset_name="IDA-SERICS/Disaster-tweet-jailbreaking",
column_mappings={"prompt": "prompt_attack"},
),
"GeorgeDaDude/Jailbreak_Complete_DS_labeled": lambda: load_dataset_general(
dataset_name="GeorgeDaDude/Jailbreak_Complete_DS_labeled",
column_mappings={"prompt": "question"},
),
"dayone3nder/jailbreak_prompt_JBB_sft_trainset": lambda: load_dataset_general(
dataset_name="dayone3nder/jailbreak_prompt_JBB_sft_trainset",
column_mappings={"prompt": "prompt"},
),
"dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset": lambda: load_dataset_general(
dataset_name="dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset",
column_mappings={"prompt": "prompt"},
),
}
@@ -3,8 +3,7 @@ import io
import httpx import httpx
import pandas as pd import pandas as pd
from loguru import logger
from agentic_security.logutils import logger
url = "https://raw.githubusercontent.com/tml-epfl/llm-adaptive-attacks/main/harmful_behaviors/harmful_behaviors_pair.csv" url = "https://raw.githubusercontent.com/tml-epfl/llm-adaptive-attacks/main/harmful_behaviors/harmful_behaviors_pair.csv"
@@ -3,8 +3,7 @@ import os
import uuid as U import uuid as U
import httpx import httpx
from loguru import logger
from agentic_security.logutils import logger
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b") AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
@@ -66,7 +65,7 @@ class Module:
return {} return {}
async def fetch_prompts(self) -> list[str]: async def fetch_prompts(self) -> list[str]:
api_url = "https://mcp.metaheuristic.co/infer" api_url = "https://edge.metaheuristic.co/infer"
headers = { headers = {
"Authorization": f"Bearer {AUTH_TOKEN}", "Authorization": f"Bearer {AUTH_TOKEN}",
"Content-Type": "application/json", "Content-Type": "application/json",
@@ -4,7 +4,7 @@ import json
import os import os
import subprocess import subprocess
from agentic_security.logutils import logger from loguru import logger
# TODO: add probes modules # TODO: add probes modules
@@ -2,7 +2,7 @@ import asyncio
import importlib.util import importlib.util
import os import os
from agentic_security.logutils import logger from loguru import logger
inspect_ai_task = ( inspect_ai_task = (
__file__.replace("inspect_ai_tool.py", "inspect_ai_task.py") __file__.replace("inspect_ai_tool.py", "inspect_ai_task.py")
@@ -8,8 +8,7 @@ from typing import Deque
import numpy as np import numpy as np
import requests import requests
from loguru import logger
from agentic_security.logutils import logger
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b") AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
@@ -216,7 +215,7 @@ class Module:
self.run_id = U.uuid4().hex self.run_id = U.uuid4().hex
self.batch_size = self.opts.get("batch_size", 500) self.batch_size = self.opts.get("batch_size", 500)
self.rl_model = CloudRLPromptSelector( self.rl_model = CloudRLPromptSelector(
prompt_groups, "https://mcp.metaheuristic.co", run_id=self.run_id prompt_groups, "https://edge.metaheuristic.co", run_id=self.run_id
) )
async def apply(self): async def apply(self):
@@ -36,7 +36,7 @@ def mock_requests() -> Mock:
def mock_rl_selector() -> Mock: def mock_rl_selector() -> Mock:
return CloudRLPromptSelector( return CloudRLPromptSelector(
dataset_prompts, dataset_prompts,
api_url="https://mcp.metaheuristic.co", api_url="https://edge.metaheuristic.co",
) )
@@ -94,7 +94,7 @@ class TestCloudRLPromptSelector:
def test_select_next_prompt_success_service(self, dataset_prompts): def test_select_next_prompt_success_service(self, dataset_prompts):
selector = CloudRLPromptSelector( selector = CloudRLPromptSelector(
dataset_prompts, dataset_prompts,
api_url="https://mcp.metaheuristic.co", api_url="https://edge.metaheuristic.co",
) )
next_prompt = selector.select_next_prompt( next_prompt = selector.select_next_prompt(
"How does RL work?", passed_guard=True "How does RL work?", passed_guard=True
@@ -3,7 +3,7 @@ from unittest.mock import patch
import pandas as pd import pandas as pd
import pytest import pytest
from agentic_security.refusal_classifier.model import RefusalClassifier from .model import RefusalClassifier
@pytest.fixture @pytest.fixture
+11 -81
View File
@@ -7,65 +7,9 @@ import pandas as pd
from matplotlib.cm import ScalarMappable from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize from matplotlib.colors import LinearSegmentedColormap, Normalize
from agentic_security.logutils import logger
from .primitives import Table def plot_security_report(table):
def plot_security_report(table: Table) -> io.BytesIO:
"""
Generates a polar plot representing the security report based on the given data.
Args:
table (Table): The input data table containing security metrics.
Returns:
io.BytesIO: A buffer containing the generated plot image in PNG format.
Returns an empty buffer in case of an error.
"""
try:
return _plot_security_report(table=table)
except (TypeError, ValueError, OverflowError, IndexError, Exception) as e:
logger.error(f"Error in generating the security report: {e}")
return io.BytesIO()
def generate_identifiers(data: pd.DataFrame) -> list[str]:
"""
Generates unique identifiers for the given dataset.
Args:
data (pd.DataFrame): A pandas DataFrame containing security-related data.
Returns:
list[str]: A list of generated identifiers. Returns a list with an empty string in case of an error.
"""
try:
_generate_identifiers(data=data)
except (TypeError, ValueError, Exception) as e:
logger.error(f"Error in generate_identifiers: {e}")
return [""]
def _plot_security_report(table: Table) -> io.BytesIO:
"""
Generates a polar plot-based security report visualizing the failure rates for different modules.
This function processes the input data, sorts it by failure rate, and generates a polar plot
where each bar represents the failure rate for a specific module. The plot includes identifiers,
color-coding based on token count, failure rate values on the bars, and a table listing the modules
and their corresponding failure rates.
Args:
table (Table): A table-like structure (e.g., pandas DataFrame) containing security report data
with columns for failure rate, tokens, and modules.
Returns:
io.BytesIO: A buffer containing the generated plot image in PNG format.
"""
# Data preprocessing # Data preprocessing
logger.info("Data preprocessing started.")
data = pd.DataFrame(table) data = pd.DataFrame(table)
# Sort by failure rate and reset index # Sort by failure rate and reset index
@@ -76,10 +20,10 @@ def _plot_security_report(table: Table) -> io.BytesIO:
fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={"projection": "polar"}) fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={"projection": "polar"})
fig.set_facecolor("#f0f0f0") fig.set_facecolor("#f0f0f0")
ax.set_facecolor("#f0f0f0") ax.set_facecolor("#f0f0f0")
logger.info("Plot setup complete.")
# Styling parameters # Styling parameters
colors = ["#6C5B7B", "#C06C84", "#F67280", "#F8B195"][::-1] # Pastel palette colors = ["#6C5B7B", "#C06C84", "#F67280", "#F8B195"][::-1] # Pastel palette
# colors = ["#440154", "#3b528b", "#21908c", "#5dc863"] # Viridis-inspired palette
cmap = LinearSegmentedColormap.from_list("custom", colors, N=256) cmap = LinearSegmentedColormap.from_list("custom", colors, N=256)
norm = Normalize(vmin=data["tokens"].min(), vmax=data["tokens"].max()) norm = Normalize(vmin=data["tokens"].min(), vmax=data["tokens"].max())
@@ -130,10 +74,7 @@ def _plot_security_report(table: Table) -> io.BytesIO:
# Title and caption # Title and caption
fig.suptitle( fig.suptitle(
"Security Report for Different Modules", "Security Report for Different Modules", fontsize=16, fontweight="bold", y=1.02
fontsize=16,
fontweight="bold",
y=1.02,
) )
caption = "Report generated by https://github.com/msoedov/agentic_security" caption = "Report generated by https://github.com/msoedov/agentic_security"
fig.text( fig.text(
@@ -171,12 +112,17 @@ def _plot_security_report(table: Table) -> io.BytesIO:
data["identifier"], data["failureRate"], data["module"] data["identifier"], data["failureRate"], data["module"]
) )
] ]
table = ax.table(cellText=table_data, loc="right", cellLoc="left") table = ax.table(
cellText=table_data,
loc="right",
cellLoc="left",
)
table.auto_set_font_size(False) table.auto_set_font_size(False)
table.set_fontsize(8) table.set_fontsize(8)
# Adjust table style # Adjust table style
table.scale(1, 0.7) table.scale(1, 0.7)
for (row, col), cell in table.get_celld().items(): for (row, col), cell in table.get_celld().items():
cell.set_edgecolor("none") cell.set_edgecolor("none")
cell.set_facecolor("#f0f0f0" if row % 2 == 0 else "#e0e0e0") cell.set_facecolor("#f0f0f0" if row % 2 == 0 else "#e0e0e0")
@@ -186,33 +132,17 @@ def _plot_security_report(table: Table) -> io.BytesIO:
cell.set_text_props(fontweight="bold") cell.set_text_props(fontweight="bold")
# Adjust layout and save # Adjust layout and save
plt.tight_layout() plt.tight_layout()
buf = io.BytesIO() buf = io.BytesIO()
plt.savefig(buf, format="png", dpi=300, bbox_inches="tight") plt.savefig(buf, format="png", dpi=300, bbox_inches="tight")
plt.close(fig) plt.close(fig)
buf.seek(0) buf.seek(0)
logger.info("Report successfully generated and saved to buffer.")
return buf return buf
def _generate_identifiers(data: pd.DataFrame) -> list[str]: def generate_identifiers(data):
"""
Generates a list of unique identifiers for each row in the given DataFrame.
The identifiers are based on the English alphabet, with each identifier consisting
of a letter followed by a number. The letter represents the "group" of identifiers
(using a letter from A to Z) and the number is a counter within that group. For example:
- A1, A2, ..., A26, B1, B2, ..., Z1, Z2, ...
Args:
data (pd.DataFrame): The input DataFrame containing data for which identifiers
are to be generated.
Returns:
list[str]: A list of unique identifiers as strings, one for each row in the DataFrame.
"""
data_length = len(data) data_length = len(data)
alphabet = string.ascii_uppercase alphabet = string.ascii_uppercase
num_letters = len(alphabet) num_letters = len(alphabet)
+1 -2
View File
@@ -2,8 +2,7 @@ import random
from asyncio import Event from asyncio import Event
from fastapi import APIRouter from fastapi import APIRouter
from loguru import logger
from agentic_security.logutils import logger
from ..core.app import get_current_run, get_tools_inbox from ..core.app import get_current_run, get_tools_inbox
from ..primitives import CompletionRequest, Settings from ..primitives import CompletionRequest, Settings
+5 -14
View File
@@ -1,6 +1,4 @@
from collections.abc import Generator
from datetime import datetime from datetime import datetime
from typing import Any
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
@@ -13,8 +11,6 @@ from fastapi import (
) )
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from agentic_security.logutils import logger
from ..core.app import get_stop_event, get_tools_inbox, set_current_run from ..core.app import get_stop_event, get_tools_inbox, set_current_run
from ..dependencies import InMemorySecrets, get_in_memory_secrets from ..dependencies import InMemorySecrets, get_in_memory_secrets
from ..http_spec import LLMSpec from ..http_spec import LLMSpec
@@ -27,14 +23,9 @@ router = APIRouter()
@router.post("/verify") @router.post("/verify")
async def verify( async def verify(
info: LLMInfo, secrets: InMemorySecrets = Depends(get_in_memory_secrets) info: LLMInfo, secrets: InMemorySecrets = Depends(get_in_memory_secrets)
) -> dict[str, int | str | float]: ):
spec = LLMSpec.from_string(info.spec) spec = LLMSpec.from_string(info.spec)
try:
r = await spec.verify() r = await spec.verify()
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=400, detail=str(e))
if r.status_code >= 400: if r.status_code >= 400:
raise HTTPException(status_code=r.status_code, detail=r.text) raise HTTPException(status_code=r.status_code, detail=r.text)
return dict( return dict(
@@ -45,7 +36,7 @@ async def verify(
) )
def streaming_response_generator(scan_parameters: Scan) -> Generator[str, Any, None]: def streaming_response_generator(scan_parameters: Scan):
request_factory = LLMSpec.from_string(scan_parameters.llmSpec) request_factory = LLMSpec.from_string(scan_parameters.llmSpec)
set_current_run(request_factory) set_current_run(request_factory)
@@ -66,7 +57,7 @@ async def scan(
scan_parameters: Scan, scan_parameters: Scan,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
secrets: InMemorySecrets = Depends(get_in_memory_secrets), secrets: InMemorySecrets = Depends(get_in_memory_secrets),
) -> StreamingResponse: ):
scan_parameters.with_secrets(secrets) scan_parameters.with_secrets(secrets)
return StreamingResponse( return StreamingResponse(
streaming_response_generator(scan_parameters), media_type="application/json" streaming_response_generator(scan_parameters), media_type="application/json"
@@ -74,7 +65,7 @@ async def scan(
@router.post("/stop") @router.post("/stop")
async def stop_scan() -> dict[str, str]: async def stop_scan():
get_stop_event().set() get_stop_event().set()
return {"status": "Scan stopped"} return {"status": "Scan stopped"}
@@ -88,7 +79,7 @@ async def scan_csv(
maxBudget: int = Query(10_000), maxBudget: int = Query(10_000),
enableMultiStepAttack: bool = Query(False), enableMultiStepAttack: bool = Query(False),
secrets: InMemorySecrets = Depends(get_in_memory_secrets), secrets: InMemorySecrets = Depends(get_in_memory_secrets),
) -> StreamingResponse: ):
# TODO: content dataset to fuzzer # TODO: content dataset to fuzzer
content = await file.read() # noqa content = await file.read() # noqa
llm_spec = await llmSpec.read() llm_spec = await llmSpec.read()
+1 -2
View File
@@ -1,8 +1,7 @@
import sentry_sdk import sentry_sdk
from loguru import logger
from sentry_sdk.integrations.logging import ignore_logger from sentry_sdk.integrations.logging import ignore_logger
from agentic_security.logutils import logger
from ..primitives import Settings from ..primitives import Settings
@@ -1,6 +1,6 @@
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from agentic_security.app import app from ..app import app
def test_health_check(): def test_health_check():
@@ -5,10 +5,10 @@ import httpx
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from agentic_security.app import app from ..app import app
from agentic_security.primitives import Probe from ..primitives import Probe
from agentic_security.probe_actor.refusal import REFUSAL_MARKS from ..probe_actor.refusal import REFUSAL_MARKS
from agentic_security.probe_data import REGISTRY from ..probe_data import REGISTRY
client = TestClient(app) client = TestClient(app)
@@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from agentic_security.routes.report import router from .report import router
client = TestClient(router) client = TestClient(router)
@@ -4,8 +4,8 @@ import pytest
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from agentic_security.primitives import Settings from ..primitives import Settings
from agentic_security.routes.static import get_static_file, router from .static import get_static_file, router
client = TestClient(router) client = TestClient(router)
+2 -3
View File
@@ -101,7 +101,7 @@ Content-Type: application/json
] ]
`, `,
`POST ${SELF_URL}/v1/self-probe-file `POST ${SELF_URL}/v1/self-probe-file
Authorization: Bearer $API_KEY Authorization: Bearer $GROQ_API_KEY
Content-Type: multipart/form-data Content-Type: multipart/form-data
{ {
@@ -204,9 +204,8 @@ let LLM_CONFIGS = [
{ name: 'Claude', prompts: 40000, logo: '/icons/claude.png' }, { name: 'Claude', prompts: 40000, logo: '/icons/claude.png' },
{ name: 'Cohere', prompts: 40000, logo: '/icons/cohere.png' }, { name: 'Cohere', prompts: 40000, logo: '/icons/cohere.png' },
{ name: 'Azure OpenAI', prompts: 40000, logo: '/icons/azureai.png' }, { name: 'Azure OpenAI', prompts: 40000, logo: '/icons/azureai.png' },
{ name: 'assemblyai', prompts: 40000, logo: fallbackIcon },
{ name: 'OpenRouter.ai', prompts: 40000, logo: '/icons/openrouter.png' }, { name: 'OpenRouter.ai', prompts: 40000, logo: '/icons/openrouter.png' },
{ name: 'assemblyai', prompts: 40000, logo: fallbackIcon },
]; ];
function has_image(spec) { function has_image(spec) {
return spec.includes('<<BASE64_IMAGE>>'); return spec.includes('<<BASE64_IMAGE>>');
-26
View File
@@ -63,32 +63,6 @@
</div> </div>
</section> </section>
</main> </main>
<div class="fixed top-6 right-6 z-50 space-y-3">
<transition-group name="toast">
<div
v-for="toast in toasts"
:key="toast.id"
class="flex items-center p-3 rounded-xl shadow-xl text-white max-w-md animate-toast-in border border-opacity-30"
:class="{
'bg-success-toast border-accent-green': toast.type === 'success',
'bg-error-toast border-accent-red': toast.type === 'error',
'bg-info-toast border-accent-orange': toast.type === 'info'
}"
>
<span class="flex-1 font-medium tracking-wide text-sm">{{ toast.message }}</span>
<button
@click="removeToast(toast.id)"
class="ml-3 focus:outline-none hover:opacity-80 transition-opacity"
>
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12" />
</svg>
</button>
</div>
</transition-group>
</div>
<main class="max-w-6xl mx-auto space-y-8"> <main class="max-w-6xl mx-auto space-y-8">
<section class="bg-dark-card rounded-lg p-6 shadow-lg" v-show="false"> <section class="bg-dark-card rounded-lg p-6 shadow-lg" v-show="false">
<h2 class="text-2xl font-bold mb-4">Select a Config</h2> <h2 class="text-2xl font-bold mb-4">Select a Config</h2>
-17
View File
@@ -25,8 +25,6 @@ var app = new Vue({
showModules: false, showModules: false,
showLogs: false, showLogs: false,
showConsentModal: true, showConsentModal: true,
toasts: [], // Array to store toast notifications
toastTimeout: 3000, // Duration in milliseconds (3 seconds)
statusDotClass: 'bg-gray-500', // Default status dot class statusDotClass: 'bg-gray-500', // Default status dot class
statusText: 'Verified', // Default status text statusText: 'Verified', // Default status text
statusClass: 'bg-green-500 text-dark-bg', // Default status class statusClass: 'bg-green-500 text-dark-bg', // Default status class
@@ -95,19 +93,6 @@ var app = new Vue({
}, },
methods: { methods: {
showToast(message, type = 'success') {
const id = Date.now(); // Unique ID for each toast
this.toasts.push({ id, message, type });
// Automatically remove toast after timeout
setTimeout(() => {
this.removeToast(id);
}, this.toastTimeout);
},
removeToast(id) {
this.toasts = this.toasts.filter(toast => toast.id !== id);
},
focusTextarea() { focusTextarea() {
this.isFocused = true; this.isFocused = true;
self = this.$refs; self = this.$refs;
@@ -225,12 +210,10 @@ var app = new Vue({
if (!response.ok) { if (!response.ok) {
this.updateStatusDot(false); this.updateStatusDot(false);
this.errorMsg = 'Integration verification failed:' + JSON.stringify(r); this.errorMsg = 'Integration verification failed:' + JSON.stringify(r);
this.showToast('Integration verification failed', 'error');
} else { } else {
this.errorMsg = ''; this.errorMsg = '';
this.updateStatusDot(true); this.updateStatusDot(true);
this.okMsg = 'Integration verified'; this.okMsg = 'Integration verified';
this.showToast('Integration verified successfully', 'success');
this.integrationVerified = true; this.integrationVerified = true;
// console.log('Integration verified', this.integrationVerified); // console.log('Integration verified', this.integrationVerified);
// this.$forceUpdate(); // this.$forceUpdate();
+2 -54
View File
@@ -19,7 +19,7 @@
technopollas: ['Technopollas', 'sans-serif'], technopollas: ['Technopollas', 'sans-serif'],
}, },
colors: { colors: {
dark: { t1: {
bg: '#0D0D0D', // Jet Black bg: '#0D0D0D', // Jet Black
card: '#1A1A1A', // Dark Carbon Fiber card: '#1A1A1A', // Dark Carbon Fiber
text: '#FFFFFF', text: '#FFFFFF',
@@ -30,7 +30,7 @@
yellow: '#2E4053', // Jet Black yellow: '#2E4053', // Jet Black
}, },
}, },
dark1: { dark: {
bg: '#121212', bg: '#121212',
card: '#1E1E1E', card: '#1E1E1E',
text: '#FFFFFF', text: '#FFFFFF',
@@ -86,7 +86,6 @@
} }
} }
</script> </script>
<style> <style>
.scrollbar-hide::-webkit-scrollbar { .scrollbar-hide::-webkit-scrollbar {
display: none; display: none;
@@ -97,55 +96,4 @@
} }
</style> </style>
<style>
/* Toast-specific colors */
.bg-success-toast {
background: #1C3F74
}
.bg-error-toast {
background: #85144B
}
.bg-info-toast {
background: #FFC300
}
.border-accent-green, .border-accent-red, .border-accent-orange {
border-color: rgba(255, 255, 255, 0.1); /* Subtle white border for depth */
}
/* Animation */
.animate-toast-in {
animation: tSlideIn 0.4s cubic-bezier(0.25, 0.1, 0.25, 1);
}
@keyframes tSlideIn {
from {
transform: translateX(120%) scale(0.95);
opacity: 0;
}
to {
transform: translateX(0) scale(1);
opacity: 1;
}
}
.toast-enter-active,
.toast-leave-active {
transition: all 0.4s cubic-bezier(0.25, 0.1, 0.25, 1);
}
.toast-enter-from,
.toast-leave-to {
transform: translateX(120%) scale(0.95);
opacity: 0;
}
/* Hover effect */
[toast-type]:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.6);
}
</style>
</head> </head>
@@ -8,7 +8,7 @@ import time
import pytest import pytest
import agentic_security.test_spec_assets as test_spec_assets import agentic_security.test_spec_assets as test_spec_assets
from agentic_security.lib import SecurityScanner from agentic_security.lib import AgenticSecurity
def has_module(module_name): def has_module(module_name):
@@ -57,7 +57,7 @@ class TestLibraryLevel:
maxBudget = 1000000 maxBudget = 1000000
max_th = 0.3 max_th = 0.3
datasets = make_test_registry() datasets = make_test_registry()
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th) result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
assert isinstance(result, dict) assert isinstance(result, dict)
print(result) print(result)
assert len(result) in [0, 1] assert len(result) in [0, 1]
@@ -68,7 +68,7 @@ class TestLibraryLevel:
maxBudget = 1000 maxBudget = 1000
max_th = 0.3 max_th = 0.3
datasets = make_test_registry() datasets = make_test_registry()
result = SecurityScanner.scan( result = AgenticSecurity.scan(
llmSpec, maxBudget, datasets, max_th, enableMultiStepAttack=True llmSpec, maxBudget, datasets, max_th, enableMultiStepAttack=True
) )
assert isinstance(result, dict) assert isinstance(result, dict)
@@ -93,7 +93,7 @@ class TestLibraryLevel:
"opts": {"port": 9094}, "opts": {"port": 9094},
}, },
] ]
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th) result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
assert isinstance(result, dict) assert isinstance(result, dict)
print(result) print(result)
assert len(result) in [0, 1] assert len(result) in [0, 1]
@@ -120,7 +120,7 @@ class TestLibraryLevel:
"modality": "text", "modality": "text",
}, },
] ]
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th) result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
assert isinstance(result, dict) assert isinstance(result, dict)
print(result) print(result)
assert len(result) in [0, 1] assert len(result) in [0, 1]
@@ -148,7 +148,7 @@ class TestLibraryLevel:
"modality": "text", "modality": "text",
}, },
] ]
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th) result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
assert isinstance(result, dict) assert isinstance(result, dict)
print(result) print(result)
assert len(result) in [0, 1] assert len(result) in [0, 1]
@@ -164,10 +164,10 @@ class TestEntrypointCI:
temp_path = os.path.join(tmpdir, "custom_agesec.toml") temp_path = os.path.join(tmpdir, "custom_agesec.toml")
# Override default_path to the temporary path # Override default_path to the temporary path
SecurityScanner.default_path = temp_path AgenticSecurity.default_path = temp_path
# Generate the default configuration # Generate the default configuration
security = SecurityScanner() security = AgenticSecurity()
security.generate_default_settings() security.generate_default_settings()
# Check that the config file was created at the temporary path # Check that the config file was created at the temporary path
@@ -189,17 +189,17 @@ class TestEntrypointCI:
temp_path = os.path.join(tmpdir, "custom_agesec.toml") temp_path = os.path.join(tmpdir, "custom_agesec.toml")
# Override default_path to the temporary path # Override default_path to the temporary path
SecurityScanner.default_path = temp_path AgenticSecurity.default_path = temp_path
# Generate the default configuration # Generate the default configuration
security = SecurityScanner() security = AgenticSecurity()
security.generate_default_settings() security.generate_default_settings()
# Load the generated configuration # Load the generated configuration
SecurityScanner.load_config(temp_path) AgenticSecurity.load_config(temp_path)
# Validate loaded configuration # Validate loaded configuration
config = SecurityScanner.config config = AgenticSecurity.config
assert ( assert (
config["general"]["maxBudget"] == 1000000 config["general"]["maxBudget"] == 1000000
), "maxBudget should be 1000000" ), "maxBudget should be 1000000"
+2
View File
@@ -0,0 +1,2 @@
from:python-pytest-poetry
# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/
+5 -1
View File
@@ -33,7 +33,7 @@ The `LLMSpec` class is the core of the HTTP specification. It provides the follo
### Methods ### Methods
- **`from_string(http_spec: str) -> LLMSpec`**: Parses an HTTP specification string into an `LLMSpec` object. - **`from_string(http_spec: str) -> LLMSpec`**: Parses an HTTP specification string into an `LLMSpec` object.
- **`validate(prompt: str, encoded_image: str, encoded_audio: str, files: dict) -> null`**: Validates the request parameters based on the specified modality. - **`validate(prompt: str, encoded_image: str, encoded_audio: str, files: dict) -> None`**: Validates the request parameters based on the specified modality.
- **`probe(prompt: str, encoded_image: str = "", encoded_audio: str = "", files: dict = {}) -> httpx.Response`**: Sends an HTTP request using the specified parameters. - **`probe(prompt: str, encoded_image: str = "", encoded_audio: str = "", files: dict = {}) -> httpx.Response`**: Sends an HTTP request using the specified parameters.
- **`verify() -> httpx.Response`**: Verifies the HTTP specification by sending a test request. - **`verify() -> httpx.Response`**: Verifies the HTTP specification by sending a test request.
@@ -57,6 +57,7 @@ Content-Type: application/json
"temperature": 0.7 "temperature": 0.7
} }
""" """
spec = LLMSpec.from_string(http_spec) spec = LLMSpec.from_string(http_spec)
response = await spec.probe("What is the capital of France?") response = await spec.probe("What is the capital of France?")
``` ```
@@ -75,6 +76,7 @@ Content-Type: application/json
"temperature": 0.7 "temperature": 0.7
} }
""" """
spec = LLMSpec.from_string(http_spec) spec = LLMSpec.from_string(http_spec)
encoded_image = encode_image_base64_by_url("https://example.com/image.jpg") encoded_image = encode_image_base64_by_url("https://example.com/image.jpg")
response = await spec.probe("What is in this image?", encoded_image=encoded_image) response = await spec.probe("What is in this image?", encoded_image=encoded_image)
@@ -94,6 +96,7 @@ Content-Type: application/json
"temperature": 0.7 "temperature": 0.7
} }
""" """
spec = LLMSpec.from_string(http_spec) spec = LLMSpec.from_string(http_spec)
encoded_audio = encode_audio_base64_by_url("https://example.com/audio.mp3") encoded_audio = encode_audio_base64_by_url("https://example.com/audio.mp3")
response = await spec.probe("Transcribe this audio:", encoded_audio=encoded_audio) response = await spec.probe("Transcribe this audio:", encoded_audio=encoded_audio)
@@ -113,6 +116,7 @@ Content-Type: multipart/form-data
"temperature": 0.7 "temperature": 0.7
} }
""" """
spec = LLMSpec.from_string(http_spec) spec = LLMSpec.from_string(http_spec)
files = {"file": ("document.txt", open("document.txt", "rb"))} files = {"file": ("document.txt", open("document.txt", "rb"))}
response = await spec.probe("Process this file:", files=files) response = await spec.probe("Process this file:", files=files)
Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.0 MiB

+1 -1
View File
@@ -21,4 +21,4 @@ Note: Please be aware that Agentic Security is designed as a safety scanner tool
## UI 🧙 ## UI 🧙
<img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1741192668/final_aa9jhb.gif"> <img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1736433557/z0bsyzhsqlgcr3w4ovwp.gif">
+14 -4
View File
@@ -54,15 +54,20 @@ The `probe_data` module is a core component of the Agentic Security project, res
- **Classes:** - **Classes:**
- `PromptSelectionInterface`: Abstract base class for prompt selection strategies. - `PromptSelectionInterface`: Abstract base class for prompt selection strategies.
- Methods: - Methods:
- `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`: Selects next prompt - `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`: Selects next prompt
- `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`: Selects multiple prompts - `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`: Selects multiple prompts
- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> null`: Updates rewards - `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None`: Updates rewards
- `RandomPromptSelector`: Basic random selection with history tracking. - `RandomPromptSelector`: Basic random selection with history tracking.
- Parameters: - Parameters:
- `prompts: list[str]`: List of available prompts - `prompts: list[str]`: List of available prompts
- `history_size: int = 3`: Size of history to prevent cycles - `history_size: int = 3`: Size of history to prevent cycles
- `CloudRLPromptSelector`: Cloud-based RL implementation with fallback. - `CloudRLPromptSelector`: Cloud-based RL implementation with fallback.
- Parameters: - Parameters:
- `prompts: list[str]`: List of available prompts - `prompts: list[str]`: List of available prompts
- `api_url: str`: URL of RL service - `api_url: str`: URL of RL service
@@ -70,7 +75,9 @@ The `probe_data` module is a core component of the Agentic Security project, res
- `history_size: int = 300`: Size of history - `history_size: int = 300`: Size of history
- `timeout: int = 5`: Request timeout - `timeout: int = 5`: Request timeout
- `run_id: str = ""`: Unique run identifier - `run_id: str = ""`: Unique run identifier
- `QLearningPromptSelector`: Local Q-learning implementation. - `QLearningPromptSelector`: Local Q-learning implementation.
- Parameters: - Parameters:
- `prompts: list[str]`: List of available prompts - `prompts: list[str]`: List of available prompts
- `learning_rate: float = 0.1`: Learning rate - `learning_rate: float = 0.1`: Learning rate
@@ -79,7 +86,9 @@ The `probe_data` module is a core component of the Agentic Security project, res
- `exploration_decay: float = 0.995`: Exploration decay rate - `exploration_decay: float = 0.995`: Exploration decay rate
- `min_exploration: float = 0.01`: Minimum exploration rate - `min_exploration: float = 0.01`: Minimum exploration rate
- `history_size: int = 300`: Size of history - `history_size: int = 300`: Size of history
- **Module**: Main class that uses CloudRLPromptSelector.
- `Module`: Main class that uses CloudRLPromptSelector.
- Parameters: - Parameters:
- `prompt_groups: list[str]`: Groups of prompts - `prompt_groups: list[str]`: Groups of prompts
- `tools_inbox: asyncio.Queue`: Queue for tool communication - `tools_inbox: asyncio.Queue`: Queue for tool communication
@@ -110,9 +119,10 @@ from agentic_security.probe_data.modules.rl_model import QLearningPromptSelector
prompts = ["What is AI?", "Explain machine learning"] prompts = ["What is AI?", "Explain machine learning"]
selector = QLearningPromptSelector(prompts) selector = QLearningPromptSelector(prompts)
current_prompt = "What is AI?" current_prompt = "What is AI?"
next_prompt = selector.select_next_prompt(current_prompt, passed_guard=true) next_prompt = selector.select_next_prompt(current_prompt, passed_guard=True)
selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=true) selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=True)
``` ```
## Conclusion ## Conclusion
+9 -13
View File
@@ -1,16 +1,14 @@
:root { :root {
--md-primary-fg-color: #2E4053; --md-primary-fg-color: #e92063;
/* Primary color changed to pinkish */ --md-primary-fg-color--light: #e92063;
--md-primary-fg-color--light: #E0A3B6; --md-primary-fg-color--dark: #e92063;
--md-primary-fg-color--dark: #1C3F74;
/* Dark variant changed to blue */
} }
/* Updated slate color scheme with new background */
/* Revert hue value to that of pre mkdocs-material v9.4.0 */
[data-md-color-scheme="slate"] { [data-md-color-scheme="slate"] {
--md-hue: 230; --md-hue: 230;
--md-default-bg-color: #1A1A1A; --md-default-bg-color: hsla(230, 15%, 21%, 1);
/* Background changed to dark gray */
} }
.hide { .hide {
@@ -26,15 +24,12 @@ img.index-header {
max-width: 500px; max-width: 500px;
} }
/* Updated custom colors */
.pydantic-pink { .pydantic-pink {
color: #E0A3B6; color: #FF007F;
/* Updated to match new theme */
} }
.team-blue { .team-blue {
color: #1C3F74; color: #0072CE;
/* Updated to match new theme */
} }
.secure-green { .secure-green {
@@ -72,6 +67,7 @@ img.index-header {
text-align: center; text-align: center;
} }
/* Hide the entire footer */ /* Hide the entire footer */
.md-footer { .md-footer {
display: none; display: none;
+1 -1
View File
@@ -89,7 +89,7 @@ theme:
name: Switch to light mode name: Switch to light mode
icon: icon:
repo: fontawesome/brands/github repo: fontawesome/brands/github
favicon: https://res.cloudinary.com/dq0w2rtm9/image/upload/v1741195421/favicon_kuz6xr.png favicon: "https://res.cloudinary.com/dq0w2rtm9/image/upload/v1737555066/r17hrkre246doczwmvbv.png"
extra: extra:
generator: false generator: false
Generated
+296 -402
View File
File diff suppressed because it is too large Load Diff
+2 -3
View File
@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "agentic_security" name = "agentic_security"
version = "0.7.0" version = "0.5.1"
description = "Agentic LLM vulnerability scanner" description = "Agentic LLM vulnerability scanner"
authors = ["Alexander Miasoiedov <msoedov@gmail.com>"] authors = ["Alexander Miasoiedov <msoedov@gmail.com>"]
maintainers = ["Alexander Miasoiedov <msoedov@gmail.com>"] maintainers = ["Alexander Miasoiedov <msoedov@gmail.com>"]
@@ -54,7 +54,6 @@ pyfiglet = "^1.0.2"
termcolor = "^2.4.0" termcolor = "^2.4.0"
# garak = { version = "*", optional = true } # garak = { version = "*", optional = true }
pytest-xdist = "3.6.1"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
@@ -87,7 +86,7 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = "--durations=5 -m 'not slow' -n 3" addopts = "--durations=5 -m 'not slow'"
asyncio_mode = "auto" asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function" asyncio_default_fixture_loop_scope = "function"
markers = "slow: marks tests as slow" markers = "slow: marks tests as slow"
View File
+161
View File
@@ -0,0 +1,161 @@
import pytest
import asyncio
from fastapi import FastAPI
from asyncio import Queue, Event
from agentic_security.core.app import create_app, get_tools_inbox, get_stop_event, get_current_run, set_current_run
class TestApp:
"""Test suite for agentic_security.core.app module."""
def test_create_app(self):
"""Test that create_app returns a FastAPI instance."""
app = create_app()
assert isinstance(app, FastAPI)
@pytest.mark.asyncio
async def test_get_tools_inbox(self):
"""Test that get_tools_inbox returns the global Queue instance."""
queue1 = get_tools_inbox()
await queue1.put("test item")
queue2 = get_tools_inbox()
result = queue2.get_nowait()
assert result == "test item"
def test_get_stop_event(self):
"""Test that get_stop_event returns the global Event instance and is not set initially."""
event = get_stop_event()
assert isinstance(event, Event)
assert not event.is_set()
def test_current_run_initial(self):
"""Test that get_current_run returns the global current_run with default values initially."""
run = get_current_run()
# Default values should be empty strings
assert run["spec"] == ""
assert run["id"] == ""
def test_set_current_run(self):
"""Test that set_current_run correctly updates current_run."""
spec = "test run"
result = set_current_run(spec)
expected_id = hash(id(spec))
# Verify that spec is set correctly
assert result["spec"] == spec
assert result["id"] == expected_id
def test_current_run_after_set(self):
"""Test that get_current_run returns the updated current_run after set_current_run is called."""
spec = "another test run"
set_current_run(spec)
current = get_current_run()
assert current["spec"] == spec
assert current["id"] == hash(id(spec))
def test_tools_inbox_same_instance(self):
"""Test that get_tools_inbox returns the same Queue instance by default."""
queue1 = get_tools_inbox()
queue2 = get_tools_inbox()
assert queue1 is queue2
def test_stop_event_set(self):
"""Test that setting the stop event is reflected in subsequent calls."""
event = get_stop_event()
event.set() # set the global event
# Now, subsequent calls should return the same event which is set.
event2 = get_stop_event()
assert event2.is_set()
def test_set_current_run_with_none(self):
"""Test that set_current_run handles None as a valid input and updates current_run accordingly."""
result = set_current_run(None)
expected_id = hash(id(None))
assert result["spec"] is None
assert result["id"] == expected_id
def test_multiple_current_run_assignments(self):
"""Test multiple assignments to current_run to ensure it always updates correctly."""
first_spec = "first run"
result1 = set_current_run(first_spec)
expected_id1 = hash(id(first_spec))
assert result1["spec"] == first_spec
assert result1["id"] == expected_id1
second_spec = "second run"
result2 = set_current_run(second_spec)
expected_id2 = hash(id(second_spec))
assert result2["spec"] == second_spec
assert result2["id"] == expected_id2
current = get_current_run()
# The current_run should reflect the latest assignment.
assert current["spec"] == second_spec
assert current["id"] == expected_id2
@pytest.mark.asyncio
async def test_empty_tools_inbox_exception(self):
"""Test that calling get_nowait on an empty tools_inbox raises QueueEmpty."""
from asyncio import QueueEmpty
queue = get_tools_inbox()
# Clear any existing items in the queue
while True:
try:
queue.get_nowait()
except QueueEmpty:
break
with pytest.raises(QueueEmpty):
queue.get_nowait()
def test_set_current_run_with_dict(self):
"""Test that set_current_run correctly handles a dictionary input as spec."""
spec = {"key": "value"}
result = set_current_run(spec)
expected_id = hash(id(spec))
assert result["spec"] == spec
assert result["id"] == expected_id
@pytest.mark.asyncio
async def test_stop_event_wait(self):
"""Test that waiting on the stop event returns once the event is set."""
event = get_stop_event()
event.clear() # ensure event is not set
async def waiter():
await event.wait()
return True
waiter_task = asyncio.create_task(waiter())
# Wait a moment to ensure the waiter is pending
await asyncio.sleep(0.1)
assert not waiter_task.done()
event.set()
result = await waiter_task
assert result is True
def test_set_current_run_with_int(self):
"""Test that set_current_run handles an integer input as spec."""
spec = 12345
result = set_current_run(spec)
expected_id = hash(id(spec))
assert result["spec"] == spec
assert result["id"] == expected_id
def test_create_app_routes(self):
"""Test that create_app returns a FastAPI instance with default routes available."""
app = create_app()
paths = [route.path for route in app.routes]
# Check that the default OpenAPI route exists
assert "/openapi.json" in paths
@pytest.mark.asyncio
async def test_tools_inbox_async_put_get_order(self):
"""Test that tools_inbox preserves order when items are added and retrieved asynchronously."""
queue = get_tools_inbox()
# Clear any existing items in the queue
from asyncio import QueueEmpty
while True:
try:
queue.get_nowait()
except QueueEmpty:
break
items = ["first", "second", "third"]
for item in items:
await queue.put(item)
result_items = []
for _ in items:
result_items.append(await queue.get())
assert result_items == items
+341
View File
@@ -0,0 +1,341 @@
import pytest
import base64
import httpx
import asyncio
from agentic_security.http_spec import (
LLMSpec,
parse_http_spec,
escape_special_chars_for_json,
encode_image_base64_by_url,
encode_audio_base64_by_url,
InvalidHTTPSpecError,
Modality
)
################################################################################
# Tests for agentic_security/http_spec.py
################################################################################
def test_escape_special_chars_for_json():
"""Test escaping special characters in a prompt for JSON safety."""
prompt = 'Line1\nLine2\t"Quote"\\Backslash'
escaped = escape_special_chars_for_json(prompt)
assert '\\n' in escaped
assert '\\t' in escaped
assert '\\"' in escaped
assert '\\\\' in escaped
def test_parse_http_spec_text():
"""Test parsing a text HTTP spec without image/audio/files requirements."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\nThis is a prompt: <<PROMPT>>"
llm_spec = parse_http_spec(spec)
assert llm_spec.method == "POST"
assert llm_spec.url == "http://example.com/api"
assert llm_spec.headers["Content-Type"] == "application/json"
assert "<<PROMPT>>" in llm_spec.body
assert not llm_spec.has_files
assert not llm_spec.has_image
assert not llm_spec.has_audio
def test_parse_http_spec_files():
"""Test parsing a HTTP spec with multipart/form-data header indicating files."""
spec = "PUT http://example.com/upload\nContent-Type: multipart/form-data\n\nFile upload test"
llm_spec = parse_http_spec(spec)
assert llm_spec.has_files
def test_parse_http_spec_image_audio():
"""Test parsing a HTTP spec that requires image and audio via placeholders."""
spec = "GET http://example.com/api\nContent-Type: application/json\n\nImage: <<BASE64_IMAGE>> and Audio: <<BASE64_AUDIO>>"
llm_spec = parse_http_spec(spec)
assert llm_spec.has_image
assert llm_spec.has_audio
def test_encode_image_base64_by_url(monkeypatch):
"""Test that image encoding returns the correct base64 string with prefix."""
dummy_content = b'test_image'
class DummyResponse:
def __init__(self, content):
self.content = content
def dummy_get(url):
return DummyResponse(dummy_content)
monkeypatch.setattr(httpx, "get", dummy_get)
result = encode_image_base64_by_url("http://dummyurl.com/image.jpg")
expected = "data:image/jpeg;base64," + base64.b64encode(dummy_content).decode("utf-8")
assert result == expected
def test_encode_audio_base64_by_url(monkeypatch):
"""Test that audio encoding returns the correct base64 string with prefix."""
dummy_content = b'test_audio'
class DummyResponse:
def __init__(self, content):
self.content = content
def dummy_get(url):
return DummyResponse(dummy_content)
monkeypatch.setattr(httpx, "get", dummy_get)
result = encode_audio_base64_by_url("http://dummyurl.com/audio.mp3")
expected = "data:audio/mpeg;base64," + base64.b64encode(dummy_content).decode("utf-8")
assert result == expected
@pytest.mark.asyncio
async def test_probe_text(monkeypatch):
"""Test the probe function for text modality by replacing <<PROMPT>>."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"prompt\": \"<<PROMPT>>\"}"
llm_spec = parse_http_spec(spec)
async def dummy_request(self, method, url, headers, content, timeout):
return httpx.Response(200, text="ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.probe("Hello")
assert response.status_code == 200
assert "ok" in response.text
@pytest.mark.asyncio
async def test_probe_with_files(monkeypatch):
"""Test that probe correctly branches to _probe_with_files when files are provided."""
spec = "POST http://example.com/api\nContent-Type: multipart/form-data\n\nFile data"
llm_spec = parse_http_spec(spec)
files = {"file": ("dummy.txt", b"data")}
async def dummy_request(self, method, url, headers, files, timeout):
return httpx.Response(200, text="file upload ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.probe("Unused", files=files)
assert response.status_code == 200
assert "file upload ok" in response.text
@pytest.mark.asyncio
async def test_verify_image(monkeypatch):
"""Test verify method branch for image modality by monkeypatching image encoder."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"image\": \"<<BASE64_IMAGE>>\"}"
llm_spec = parse_http_spec(spec)
# Replace the image encoder to return a dummy string
monkeypatch.setattr("agentic_security.http_spec.encode_image_base64_by_url", lambda url="": "dummy_image")
async def dummy_request(self, method, url, headers, content, timeout):
# Check that the dummy image is injected in the content
assert "dummy_image" in content
return httpx.Response(200, text="image ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.verify()
assert response.status_code == 200
assert "image ok" in response.text
@pytest.mark.asyncio
async def test_verify_audio(monkeypatch):
"""Test verify method branch for audio modality by monkeypatching audio encoder."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"audio\": \"<<BASE64_AUDIO>>\"}"
llm_spec = parse_http_spec(spec)
monkeypatch.setattr("agentic_security.http_spec.encode_audio_base64_by_url", lambda url: "dummy_audio")
async def dummy_request(self, method, url, headers, content, timeout):
# Ensure that the dummy audio string is present in the request content
assert "dummy_audio" in content
return httpx.Response(200, text="audio ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.verify()
assert response.status_code == 200
assert "audio ok" in response.text
@pytest.mark.asyncio
async def test_verify_files(monkeypatch):
"""Test verify method branch for files modality where _probe_with_files is invoked."""
spec = "POST http://example.com/api\nContent-Type: multipart/form-data\n\nFile data"
llm_spec = parse_http_spec(spec)
async def dummy_request(self, method, url, headers, files, timeout):
return httpx.Response(200, text="files ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.verify()
assert response.status_code == 200
assert "files ok" in response.text
def test_llm_spec_modality_property():
"""Test that the modality property reflects the correct modality."""
spec_text = "POST http://example.com/api\nContent-Type: application/json\n\nPrompt: <<PROMPT>>"
llm_spec_text = parse_http_spec(spec_text)
assert llm_spec_text.modality == Modality.TEXT
spec_image = "POST http://example.com/api\nContent-Type: application/json\n\nImage: <<BASE64_IMAGE>>"
llm_spec_image = parse_http_spec(spec_image)
assert llm_spec_image.modality == Modality.IMAGE
spec_audio = "POST http://example.com/api\nContent-Type: application/json\n\nAudio: <<BASE64_AUDIO>>"
llm_spec_audio = parse_http_spec(spec_audio)
assert llm_spec_audio.modality == Modality.AUDIO
def test_from_string_invalid():
"""Test that LLMSpec.from_string raises an error for an invalid spec."""
invalid_spec = "INVALID_SPEC"
with pytest.raises(InvalidHTTPSpecError):
LLMSpec.from_string(invalid_spec)
@pytest.mark.asyncio
async def test_validate_missing_files():
"""Test that LLMSpec.validate raises a ValueError when files are required but missing."""
spec = "POST http://example.com/api\nContent-Type: multipart/form-data\n\nFile upload test"
llm_spec = parse_http_spec(spec)
with pytest.raises(ValueError, match="Files are required"):
llm_spec.validate("test prompt", "", "", {})
@pytest.mark.asyncio
async def test_validate_missing_image():
"""Test that LLMSpec.validate raises a ValueError when an image is required but missing."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\nImage: <<BASE64_IMAGE>>"
llm_spec = parse_http_spec(spec)
with pytest.raises(ValueError, match="An image is required"):
llm_spec.validate("test prompt", "", "dummy_audio", {})
@pytest.mark.asyncio
async def test_validate_missing_audio():
"""Test that LLMSpec.validate raises a ValueError when audio is required but missing."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\nAudio: <<BASE64_AUDIO>>"
llm_spec = parse_http_spec(spec)
with pytest.raises(ValueError, match="Audio is required"):
llm_spec.validate("test prompt", "dummy_image", "", {})
def test_fn_alias(monkeypatch):
"""Test that LLMSpec.fn is a functional alias for LLMSpec.probe."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"prompt\": \"<<PROMPT>>\"}"
llm_spec = parse_http_spec(spec)
# Instead of overriding the instance method, verify the alias at the class level.
assert LLMSpec.fn is LLMSpec.probe
def test_escape_special_chars_no_special():
"""Test that the escape function returns the original string if no special characters are present."""
prompt = "Simple text without specials"
escaped = escape_special_chars_for_json(prompt)
assert escaped == "Simple text without specials"
@pytest.mark.asyncio
async def test_probe_text_with_special_chars(monkeypatch):
"""Test probe for text modality with special characters in prompt ensuring escaped content."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"prompt\": \"<<PROMPT>>\"}"
llm_spec = parse_http_spec(spec)
captured = {}
async def dummy_request(self, method, url, headers, content, timeout):
captured['content'] = content
return httpx.Response(200, text="ok")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
test_prompt = 'Hello\nWorld\t"Test"'
response = await llm_spec.probe(test_prompt)
expected_escaped = escape_special_chars_for_json(test_prompt)
assert expected_escaped in captured['content']
assert response.status_code == 200
@pytest.mark.asyncio
async def test_verify_both_image_audio(monkeypatch):
"""Test verify method when both image and audio placeholders are present.
Expect a ValueError because only the image branch is triggered by pattern matching and the missing audio causes validation to fail."""
spec = ("POST http://example.com/api\nContent-Type: application/json\n\n"
"{\"audio\": \"<<BASE64_AUDIO>>\", \"image\":\"<<BASE64_IMAGE>>\"}")
llm_spec = parse_http_spec(spec)
# Monkey patch the image encoder to return a dummy value
monkeypatch.setattr("agentic_security.http_spec.encode_image_base64_by_url", lambda url="": "dummy_image")
with pytest.raises(ValueError, match="Audio is required"):
await llm_spec.verify()
def test_parse_http_spec_invalid_header_format():
"""Test that parse_http_spec raises an error when a header line doesn't have the expected 'key: value' format."""
invalid_spec = "GET http://example.com/api\nInvalidHeaderWithoutColon\n\nBody with <<PROMPT>>"
with pytest.raises(ValueError):
parse_http_spec(invalid_spec)
def test_from_string_valid():
"""Test that LLMSpec.from_string returns a valid LLMSpec object when given a proper spec string."""
spec = "GET http://example.com/api\nContent-Type: application/json\n\n{ \"prompt\": \"<<PROMPT>>\" }"
llm_spec = LLMSpec.from_string(spec)
assert llm_spec.method == "GET"
assert llm_spec.url == "http://example.com/api"
@pytest.mark.asyncio
async def test_parse_http_spec_multiline_body():
"""Test parsing an HTTP spec with a multiline body to ensure body concatenation works."""
spec = (
"PATCH http://example.com/api\n"
"Content-Type: application/json\n"
"\n"
"Line one of body\n"
"Line two of body\n"
"Line three"
)
llm_spec = parse_http_spec(spec)
# As implemented, the parser concatenates lines without newline delimiters
expected_body = "Line one of bodyLine two of bodyLine three"
assert llm_spec.body == expected_body
@pytest.mark.asyncio
async def test_encode_image_default_argument(monkeypatch):
"""Test that encode_image_base64_by_url works with its default URL argument."""
dummy_content = b'default_image'
class DummyResponse:
def __init__(self, content):
self.content = content
def dummy_get(url):
# check that the default URL (which includes 'fluidicon.png') is used
assert "fluidicon.png" in url
return DummyResponse(dummy_content)
monkeypatch.setattr(httpx, "get", dummy_get)
result = encode_image_base64_by_url()
expected = "data:image/jpeg;base64," + base64.b64encode(dummy_content).decode("utf-8")
assert result == expected
@pytest.mark.asyncio
async def test_probe_without_prompt_placeholder(monkeypatch):
"""Test the probe function when the request body does not include the <<PROMPT>> placeholder."""
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"message\": \"No placeholder here\"}"
llm_spec = parse_http_spec(spec)
captured = {}
async def dummy_request(self, method, url, headers, content, timeout):
captured['content'] = content
return httpx.Response(200, text="ok without placeholder")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
response = await llm_spec.probe("Ignored prompt")
assert "No placeholder here" in captured['content']
assert response.status_code == 200
def test_validate_success():
"""Test that LLMSpec.validate does not raise an error when all required data is provided."""
# Test case for files: files are provided as required
spec_files = "POST http://example.com/api\nContent-Type: multipart/form-data\n\nFile upload"
llm_spec_files = parse_http_spec(spec_files)
llm_spec_files.validate("some prompt", "dummy_image", "dummy_audio", {"file": ("dummy.txt", b"data")})
# Test case for image: image is provided as required
spec_image = "POST http://example.com/api\nContent-Type: application/json\n\nImage: <<BASE64_IMAGE>>"
llm_spec_image = parse_http_spec(spec_image)
llm_spec_image.validate("some prompt", "dummy_image", "dummy_audio", {})
# Test case for audio: audio is provided as required
spec_audio = "POST http://example.com/api\nContent-Type: application/json\n\nAudio: <<BASE64_AUDIO>>"
llm_spec_audio = parse_http_spec(spec_audio)
llm_spec_audio.validate("some prompt", "dummy_image", "dummy_audio", {})
@pytest.mark.asyncio
async def test_probe_invalid_url(monkeypatch):
"""Test that probe raises an exception when the HTTP client fails due to an invalid URL."""
spec = "GET http://nonexistent_url/api\nContent-Type: application/json\n\n{\"prompt\": \"<<PROMPT>>\"}"
llm_spec = parse_http_spec(spec)
async def dummy_request(self, method, url, headers, content, timeout):
raise httpx.RequestError("Invalid URL")
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
with pytest.raises(httpx.RequestError):
await llm_spec.probe("Test")