mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-25 14:49:57 +02:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e752ebaeeb | |||
| 2549194bd1 | |||
| 4c580ea1b8 |
@@ -9,7 +9,7 @@ on:
|
|||||||
- 0.*
|
- 0.*
|
||||||
|
|
||||||
env:
|
env:
|
||||||
POETRY_VERSION: "1.8.5"
|
POETRY_VERSION: "1.7.1"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
if_release:
|
if_release:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ on:
|
|||||||
branches: [main]
|
branches: [main]
|
||||||
|
|
||||||
env:
|
env:
|
||||||
POETRY_VERSION: "1.8.5"
|
POETRY_VERSION: "1.7.1"
|
||||||
OPENAI_API_KEY: "sk-fake"
|
OPENAI_API_KEY: "sk-fake"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|||||||
@@ -17,5 +17,3 @@ inv/
|
|||||||
scripts/
|
scripts/
|
||||||
docx/
|
docx/
|
||||||
agentic_security.toml
|
agentic_security.toml
|
||||||
/venv
|
|
||||||
*.csv
|
|
||||||
|
|||||||
+1
-8
@@ -1,5 +1,5 @@
|
|||||||
# Build stage
|
# Build stage
|
||||||
FROM python:3.11-slim AS builder
|
FROM python:3.11-slim as builder
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
@@ -14,15 +14,8 @@ RUN poetry self add "poetry-plugin-export"
|
|||||||
# Copy only dependency files to leverage Docker layer caching
|
# Copy only dependency files to leverage Docker layer caching
|
||||||
COPY pyproject.toml poetry.lock ./
|
COPY pyproject.toml poetry.lock ./
|
||||||
|
|
||||||
# update lock file to avoid failure
|
|
||||||
RUN poetry lock
|
|
||||||
|
|
||||||
# Install dependencies
|
# Install dependencies
|
||||||
RUN poetry export -f requirements.txt --without-hashes -o requirements.txt
|
RUN poetry export -f requirements.txt --without-hashes -o requirements.txt
|
||||||
|
|
||||||
# Install wheel (required to build packages like fire)
|
|
||||||
RUN pip install --upgrade pip setuptools wheel
|
|
||||||
|
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
# Runtime stage
|
# Runtime stage
|
||||||
|
|||||||
@@ -21,7 +21,9 @@
|
|||||||
<a href="https://pypi.org/project/agentic-security/">
|
<a href="https://pypi.org/project/agentic-security/">
|
||||||
<img alt="PyPI Version" src="https://img.shields.io/pypi/v/agentic-security?style=for-the-badge&logo=pypi&labelColor=000000&color=00CCFF" />
|
<img alt="PyPI Version" src="https://img.shields.io/pypi/v/agentic-security?style=for-the-badge&logo=pypi&labelColor=000000&color=00CCFF" />
|
||||||
</a>
|
</a>
|
||||||
|
<a href="https://discord.gg/stw3DfZQ">
|
||||||
|
<img alt="Join Discord" src="https://img.shields.io/badge/Discord-Join%20Us-black?style=for-the-badge&logo=discord&labelColor=000000&color=DD55FF" />
|
||||||
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|
||||||
@@ -81,7 +83,7 @@ agentic_security --port=PORT --host=HOST
|
|||||||
|
|
||||||
## UI 🧙
|
## UI 🧙
|
||||||
|
|
||||||
<img width="100%" alt="booking-screen" src="https://raw.githubusercontent.com/msoedov/agentic_security/refs/heads/main/docs/images/demo.gif">
|
<img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1736433557/z0bsyzhsqlgcr3w4ovwp.gif">
|
||||||
|
|
||||||
## LLM kwargs
|
## LLM kwargs
|
||||||
|
|
||||||
@@ -400,31 +402,16 @@ This setup ensures a continuous integration approach towards maintaining securit
|
|||||||
|
|
||||||
The `Module` class is designed to manage prompt processing and interaction with external AI models and tools. It supports fetching, processing, and posting prompts asynchronously for model vulnerabilities. Check out [module.md](https://github.com/msoedov/agentic_security/blob/main/docs/module.md) for details.
|
The `Module` class is designed to manage prompt processing and interaction with external AI models and tools. It supports fetching, processing, and posting prompts asynchronously for model vulnerabilities. Check out [module.md](https://github.com/msoedov/agentic_security/blob/main/docs/module.md) for details.
|
||||||
|
|
||||||
|
|
||||||
## MCP server
|
|
||||||
|
|
||||||
```shell
|
|
||||||
pip install -U mcp
|
|
||||||
|
|
||||||
# From cloned directory
|
|
||||||
mcp install agentic_security/mcp/main.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
For more detailed information on how to use Agentic Security, including advanced features and customization options, please refer to the official documentation.
|
For more detailed information on how to use Agentic Security, including advanced features and customization options, please refer to the official documentation.
|
||||||
|
|
||||||
## Roadmap and Future Goals
|
## Roadmap and Future Goals
|
||||||
|
|
||||||
|
- \[ \] Expand dataset variety
|
||||||
|
- \[ \] Introduce two new attack vectors
|
||||||
We’re just getting started! Here’s what’s on the horizon:
|
- \[ \] Develop initial attacker LLM
|
||||||
|
- \[ \] Complete integration of OWASP Top 10 classification
|
||||||
- **RL-Powered Attacks**: An attacker LLM trained with reinforcement learning to dynamically evolve jailbreaks and outsmart defenses.
|
|
||||||
- **Massive Dataset Expansion**: Scaling to 100,000+ prompts across text, image, and audio modalities—curated for real-world threats.
|
|
||||||
- **Daily Attack Updates**: Fresh attack vectors delivered daily, keeping your scans ahead of the curve.
|
|
||||||
- **Community Modules**: A plug-and-play ecosystem where you can share and deploy custom probes, datasets, and integrations.
|
|
||||||
|
|
||||||
|
|
||||||
| Tool | Source | Integrated |
|
| Tool | Source | Integrated |
|
||||||
|-------------------------|-------------------------------------------------------------------------------|------------|
|
|-------------------------|-------------------------------------------------------------------------------|------------|
|
||||||
@@ -436,7 +423,6 @@ We’re just getting started! Here’s what’s on the horizon:
|
|||||||
|
|
||||||
Note: All dates are tentative and subject to change based on project progress and priorities.
|
Note: All dates are tentative and subject to change based on project progress and priorities.
|
||||||
|
|
||||||
|
|
||||||
## 👋 Contributing
|
## 👋 Contributing
|
||||||
|
|
||||||
Contributions to Agentic Security are welcome! If you'd like to contribute, please follow these steps:
|
Contributions to Agentic Security are welcome! If you'd like to contribute, please follow these steps:
|
||||||
@@ -453,9 +439,4 @@ Before contributing, please read the contributing guidelines.
|
|||||||
|
|
||||||
Agentic Security is released under the Apache License v2.
|
Agentic Security is released under the Apache License v2.
|
||||||
|
|
||||||
|
|
||||||
## 🚫 No Cryptocurrency Affiliation
|
|
||||||
|
|
||||||
Agentic Security is focused solely on AI security and has no affiliation with cryptocurrency projects, blockchain technologies, or related initiatives. Our mission is to advance the safety and reliability of AI systems—no tokens, no coins, just code.
|
|
||||||
|
|
||||||
## Contact us
|
## Contact us
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from .lib import SecurityScanner
|
from .lib import AgenticSecurity
|
||||||
|
|
||||||
__all__ = ["SecurityScanner"]
|
__all__ = ["AgenticSecurity"]
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import fire
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from agentic_security.app import app
|
from agentic_security.app import app
|
||||||
from agentic_security.lib import SecurityScanner
|
from agentic_security.lib import AgenticSecurity
|
||||||
from agentic_security.misc.banner import init_banner
|
from agentic_security.misc.banner import init_banner
|
||||||
|
|
||||||
|
|
||||||
@@ -32,14 +32,14 @@ class CLI:
|
|||||||
Run Agentic Security in CI mode.
|
Run Agentic Security in CI mode.
|
||||||
"""
|
"""
|
||||||
sys.path.append(os.path.dirname("."))
|
sys.path.append(os.path.dirname("."))
|
||||||
SecurityScanner().entrypoint()
|
AgenticSecurity().entrypoint()
|
||||||
|
|
||||||
def init(self, host: str = "0.0.0.0", port: int = 8718):
|
def init(self, host: str = "0.0.0.0", port: int = 8718):
|
||||||
"""
|
"""
|
||||||
Generate the default CI configuration file.
|
Generate the default CI configuration file.
|
||||||
"""
|
"""
|
||||||
sys.path.append(os.path.dirname("."))
|
sys.path.append(os.path.dirname("."))
|
||||||
SecurityScanner().generate_default_settings(host, port)
|
AgenticSecurity().generate_default_settings(host, port)
|
||||||
|
|
||||||
i = init
|
i = init
|
||||||
|
|
||||||
@@ -48,7 +48,7 @@ class CLI:
|
|||||||
List all available security checks.
|
List all available security checks.
|
||||||
"""
|
"""
|
||||||
sys.path.append(os.path.dirname("."))
|
sys.path.append(os.path.dirname("."))
|
||||||
SecurityScanner().list_checks()
|
AgenticSecurity().list_checks()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -9,11 +10,12 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
|
|
||||||
# Assuming LLMSpec is defined elsewhere; placeholder import
|
# Assuming LLMSpec is defined elsewhere; placeholder import
|
||||||
from agentic_security.http_spec import LLMSpec
|
from agentic_security.http_spec import LLMSpec
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
LLM_SPECS = [] # Populate with LLM spec strings if needed
|
LLM_SPECS = [] # Populate with LLM spec strings if needed
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Define AgentSpecification model
|
# Define AgentSpecification model
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -7,10 +8,13 @@ from pydantic_ai import Agent, RunContext, Tool
|
|||||||
|
|
||||||
# Assuming LLMSpec is defined elsewhere; placeholder import
|
# Assuming LLMSpec is defined elsewhere; placeholder import
|
||||||
from agentic_security.http_spec import LLMSpec
|
from agentic_security.http_spec import LLMSpec
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
LLM_SPECS = [] # Populate this list with LLM spec strings if needed
|
LLM_SPECS = [] # Populate this list with LLM spec strings if needed
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Define AgentSpecification model
|
# Define AgentSpecification model
|
||||||
class AgentSpecification(BaseModel):
|
class AgentSpecification(BaseModel):
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
import tomli
|
import tomli
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from agentic_security.logutils import logger
|
SETTINGS_VERSION = 1
|
||||||
|
|
||||||
SETTINGS_VERSION = 2
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
@@ -143,13 +142,6 @@ use_disk_cache = false
|
|||||||
retry = 3
|
retry = 3
|
||||||
timeout_connect = 30
|
timeout_connect = 30
|
||||||
timeout_response = 90
|
timeout_response = 90
|
||||||
|
|
||||||
[fuzzer]
|
|
||||||
max_prompt_lenght = 2048
|
|
||||||
budget_multiplier = 100000000
|
|
||||||
initial_optimizer_points = 25
|
|
||||||
min_failure_samples = 5
|
|
||||||
failure_rate_threshold = 0.5
|
|
||||||
""".replace(
|
""".replace(
|
||||||
"$HOST", host
|
"$HOST", host
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,14 +4,10 @@ from asyncio import Event, Queue
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
|
|
||||||
from agentic_security.http_spec import LLMSpec
|
|
||||||
|
|
||||||
tools_inbox: Queue = Queue()
|
tools_inbox: Queue = Queue()
|
||||||
stop_event: Event = Event()
|
stop_event: Event = Event()
|
||||||
current_run: str = {"spec": "", "id": ""}
|
current_run: str = {"spec": "", "id": ""}
|
||||||
_secrets: dict[str, str] = {}
|
_secrets = {}
|
||||||
|
|
||||||
current_run: dict[str, int | LLMSpec] = {"spec": "", "id": ""}
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
@@ -30,29 +26,29 @@ def get_stop_event() -> Event:
|
|||||||
return stop_event
|
return stop_event
|
||||||
|
|
||||||
|
|
||||||
def get_current_run() -> dict[str, int | LLMSpec]:
|
def get_current_run() -> str:
|
||||||
"""Get the current run id."""
|
"""Get the current run id."""
|
||||||
return current_run
|
return current_run
|
||||||
|
|
||||||
|
|
||||||
def set_current_run(spec: LLMSpec) -> dict[str, int | LLMSpec]:
|
def set_current_run(spec):
|
||||||
"""Set the current run id."""
|
"""Set the current run id."""
|
||||||
current_run["id"] = hash(id(spec))
|
current_run["id"] = hash(id(spec))
|
||||||
current_run["spec"] = spec
|
current_run["spec"] = spec
|
||||||
return current_run
|
return current_run
|
||||||
|
|
||||||
|
|
||||||
def get_secrets() -> dict[str, str]:
|
def get_secrets():
|
||||||
return _secrets
|
return _secrets
|
||||||
|
|
||||||
|
|
||||||
def set_secrets(secrets: dict[str, str]) -> dict[str, str]:
|
def set_secrets(secrets):
|
||||||
_secrets.update(secrets)
|
_secrets.update(secrets)
|
||||||
expand_secrets(_secrets)
|
expand_secrets(_secrets)
|
||||||
return _secrets
|
return _secrets
|
||||||
|
|
||||||
|
|
||||||
def expand_secrets(secrets: dict[str, str]) -> None:
|
def expand_secrets(secrets):
|
||||||
for key in secrets:
|
for key in secrets:
|
||||||
val = secrets[key]
|
val = secrets[key]
|
||||||
if val.startswith("$"):
|
if val.startswith("$"):
|
||||||
|
|||||||
@@ -1,5 +1,26 @@
|
|||||||
from agentic_security.logutils import set_log_level_to_info
|
from logging import config
|
||||||
|
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
return set_log_level_to_info()
|
config.dictConfig(
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"disable_existing_loggers": True,
|
||||||
|
"handlers": {
|
||||||
|
"console": {
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"handlers": ["console"],
|
||||||
|
"level": "INFO",
|
||||||
|
},
|
||||||
|
"loggers": {
|
||||||
|
"uvicorn.access": {
|
||||||
|
"level": "ERROR", # Set higher log level to suppress info logs globally
|
||||||
|
"handlers": ["console"],
|
||||||
|
"propagate": False,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import base64
|
import base64
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -160,14 +159,6 @@ def parse_http_spec(http_spec: str) -> LLMSpec:
|
|||||||
# Extract the method and URL from the first line
|
# Extract the method and URL from the first line
|
||||||
method, url = lines[0].split(" ")[0:2]
|
method, url = lines[0].split(" ")[0:2]
|
||||||
|
|
||||||
# Check url validity
|
|
||||||
valid_url = urlparse(url)
|
|
||||||
# if missing the correct formatting ://, urlparse.netloc will be empty
|
|
||||||
if valid_url.scheme not in ("http", "https") or not valid_url.netloc:
|
|
||||||
raise InvalidHTTPSpecError(
|
|
||||||
f"Invalid URL: {url}. Ensure it starts with 'http://' or 'https://'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize headers and body
|
# Initialize headers and body
|
||||||
headers = {}
|
headers = {}
|
||||||
body = ""
|
body = ""
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ from datetime import datetime
|
|||||||
|
|
||||||
import colorama
|
import colorama
|
||||||
import tqdm.asyncio
|
import tqdm.asyncio
|
||||||
|
from loguru import logger
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
|
||||||
from agentic_security.config import SettingsMixin # Importing the configuration mixin
|
from agentic_security.config import SettingsMixin # Importing the configuration mixin
|
||||||
from agentic_security.logutils import logger
|
|
||||||
from agentic_security.primitives import Scan
|
from agentic_security.primitives import Scan
|
||||||
from agentic_security.probe_data import REGISTRY
|
from agentic_security.probe_data import REGISTRY
|
||||||
from agentic_security.routes.scan import streaming_response_generator
|
from agentic_security.routes.scan import streaming_response_generator
|
||||||
@@ -23,7 +23,7 @@ YELLOW = colorama.Fore.YELLOW
|
|||||||
BLUE = colorama.Fore.BLUE
|
BLUE = colorama.Fore.BLUE
|
||||||
|
|
||||||
|
|
||||||
class SecurityScanner(SettingsMixin):
|
class AgenticSecurity(SettingsMixin):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def async_scan(
|
async def async_scan(
|
||||||
cls,
|
cls,
|
||||||
@@ -202,7 +202,7 @@ class SecurityScanner(SettingsMixin):
|
|||||||
self.load_config(self.default_path)
|
self.load_config(self.default_path)
|
||||||
logger.info("Configuration loaded successfully.")
|
logger.info("Configuration loaded successfully.")
|
||||||
print(self.config)
|
print(self.config)
|
||||||
datasets = list(self.get_config_value("modules", {}).values())
|
datasets = list(self.get_config_value("modules").values())
|
||||||
for d in datasets:
|
for d in datasets:
|
||||||
d["selected"] = True
|
d["selected"] = True
|
||||||
self.scan(
|
self.scan(
|
||||||
|
|||||||
@@ -1,146 +0,0 @@
|
|||||||
# import sys
|
|
||||||
|
|
||||||
# from loguru import logger
|
|
||||||
|
|
||||||
# # Define custom colors
|
|
||||||
# BLUE = "#89CFF0"
|
|
||||||
# BROWN = "#8B4513" # Brown for DEBUG
|
|
||||||
|
|
||||||
# # Define custom log level colors
|
|
||||||
# logger.level("DEBUG", color=f"<fg {BROWN}>")
|
|
||||||
# logger.level("INFO", color=f"<fg {BLUE}>")
|
|
||||||
|
|
||||||
# # Define custom log format with aligned messages and colored levels
|
|
||||||
# LOG_FORMAT = (
|
|
||||||
# "<level>{level:<8}</level> " # Properly formatted and colored log level
|
|
||||||
# "<level>{message:<100}</level> " # Left-aligned message for readability
|
|
||||||
# "<cyan>{file.name}</cyan>:<cyan>{line}</cyan>" # File name and line number in cyan
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # Remove default handlers and add a new one with custom formatting
|
|
||||||
# logger.remove()
|
|
||||||
# logger.add(sys.stdout, format=LOG_FORMAT, level="DEBUG", colorize=True)
|
|
||||||
import logging
|
|
||||||
import logging.config
|
|
||||||
import time
|
|
||||||
from collections.abc import Callable, Coroutine
|
|
||||||
from functools import wraps
|
|
||||||
from os import getenv
|
|
||||||
from typing import Any, ParamSpec, TypeVar
|
|
||||||
|
|
||||||
LOGGER_NAME = None
|
|
||||||
|
|
||||||
LOGGING_CONFIG = {
|
|
||||||
"version": 1,
|
|
||||||
"disable_existing_loggers": False,
|
|
||||||
"formatters": {
|
|
||||||
"rich": {"format": "%(message)s", "datefmt": "[%X]"},
|
|
||||||
},
|
|
||||||
"handlers": {
|
|
||||||
"rich": {
|
|
||||||
"class": "rich.logging.RichHandler",
|
|
||||||
"level": "INFO",
|
|
||||||
"formatter": "rich",
|
|
||||||
"show_time": False,
|
|
||||||
"rich_tracebacks": False,
|
|
||||||
"show_path": lambda: True if getenv("API_RUNTIME") == "dev" else False,
|
|
||||||
"tracebacks_show_locals": False,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"loggers": {
|
|
||||||
"": { # Root logger configuration
|
|
||||||
"level": "INFO",
|
|
||||||
"handlers": ["rich"],
|
|
||||||
"propagate": True,
|
|
||||||
},
|
|
||||||
"httpx": { # Disable httpx logging
|
|
||||||
"level": "WARNING", # Suppress DEBUG and INFO messages from httpx
|
|
||||||
"handlers": [],
|
|
||||||
"propagate": False,
|
|
||||||
},
|
|
||||||
"uvicorn.access": { # Disable uvicorn.access logging
|
|
||||||
"level": "WARNING", # Suppress DEBUG and INFO messages from uvicorn.access
|
|
||||||
"handlers": [],
|
|
||||||
"propagate": False,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def configure_logging():
|
|
||||||
# Apply the dictionary configuration
|
|
||||||
logging.config.dictConfig(LOGGING_CONFIG)
|
|
||||||
|
|
||||||
# Get and return the logger
|
|
||||||
logger = logging.getLogger(LOGGER_NAME)
|
|
||||||
return logger
|
|
||||||
|
|
||||||
|
|
||||||
logger: logging.Logger = configure_logging()
|
|
||||||
|
|
||||||
|
|
||||||
def set_log_level_to_debug():
|
|
||||||
logger = logging.getLogger(LOGGER_NAME)
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
# Update handler level as well
|
|
||||||
for handler in logger.handlers:
|
|
||||||
handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
|
|
||||||
def set_log_level_to_info():
|
|
||||||
logger = logging.getLogger(LOGGER_NAME)
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
# Update handler level as well
|
|
||||||
for handler in logger.handlers:
|
|
||||||
handler.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
|
|
||||||
# Set initial log level
|
|
||||||
set_log_level_to_info()
|
|
||||||
|
|
||||||
|
|
||||||
# Define generic type variables for return type and parameters
|
|
||||||
R = TypeVar("R")
|
|
||||||
P = ParamSpec("P")
|
|
||||||
|
|
||||||
|
|
||||||
def time_execution_sync(
|
|
||||||
additional_text: str = "",
|
|
||||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
|
||||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
start_time = time.time()
|
|
||||||
result = func(*args, **kwargs)
|
|
||||||
execution_time = time.time() - start_time
|
|
||||||
logger.debug(
|
|
||||||
f"{additional_text} Execution time: {execution_time:.2f} seconds"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def time_execution_async(
|
|
||||||
additional_text: str = "",
|
|
||||||
) -> Callable[
|
|
||||||
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
|
|
||||||
]:
|
|
||||||
def decorator(
|
|
||||||
func: Callable[P, Coroutine[Any, Any, R]]
|
|
||||||
) -> Callable[P, Coroutine[Any, Any, R]]:
|
|
||||||
@wraps(func)
|
|
||||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
start_time = time.time()
|
|
||||||
result = await func(*args, **kwargs)
|
|
||||||
execution_time = time.time() - start_time
|
|
||||||
logger.debug(
|
|
||||||
f"{additional_text} Execution time: {execution_time:.2f} seconds"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
from mcp import ClientSession, StdioServerParameters
|
|
||||||
from mcp.client.stdio import stdio_client
|
|
||||||
|
|
||||||
# Create server parameters for stdio connection
|
|
||||||
server_params = StdioServerParameters(
|
|
||||||
command="python", # Executable
|
|
||||||
args=["agentic_security/mcp/main.py"], # Your server script
|
|
||||||
env=None, # Optional environment variables
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def run() -> None:
|
|
||||||
async with stdio_client(server_params) as (read, write):
|
|
||||||
async with ClientSession(read, write) as session:
|
|
||||||
# Initialize the connection --> connection does not work
|
|
||||||
await session.initialize()
|
|
||||||
|
|
||||||
# List available prompts, resources, and tools --> no available tools
|
|
||||||
prompts = await session.list_prompts()
|
|
||||||
print(f"Available prompts: {prompts}")
|
|
||||||
|
|
||||||
resources = await session.list_resources()
|
|
||||||
print(f"Available resources: {resources}")
|
|
||||||
|
|
||||||
tools = await session.list_tools()
|
|
||||||
print(f"Available tools: {tools}")
|
|
||||||
|
|
||||||
# Call the echo tool --> echo tool issue
|
|
||||||
echo_result = await session.call_tool(
|
|
||||||
"echo_tool", arguments={"message": "Hello from client!"}
|
|
||||||
)
|
|
||||||
print(f"Tool result: {echo_result}")
|
|
||||||
|
|
||||||
# # Read the echo resource
|
|
||||||
# echo_content, mime_type = await session.read_resource(
|
|
||||||
# "echo://Hello_resource"
|
|
||||||
# )
|
|
||||||
# print(f"Resource content: {echo_content}")
|
|
||||||
# print(f"Resource MIME type: {mime_type}")
|
|
||||||
|
|
||||||
# # Get and use the echo prompt
|
|
||||||
# prompt_result = await session.get_prompt(
|
|
||||||
# "echo_prompt", arguments={"message": "Hello prompt!"}
|
|
||||||
# )
|
|
||||||
# print(f"Prompt result: {prompt_result}")
|
|
||||||
|
|
||||||
# You can perform additional operations here as needed
|
|
||||||
return prompts, resources, tools
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(run())
|
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
import httpx
|
|
||||||
from mcp.server.fastmcp import FastMCP
|
|
||||||
|
|
||||||
# Initialize MCP server
|
|
||||||
mcp = FastMCP(
|
|
||||||
name="Agentic Security MCP Server",
|
|
||||||
description="MCP server to interact with LLM scanning test",
|
|
||||||
dependencies=["httpx"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# FastAPI Server Configuration
|
|
||||||
AGENTIC_SECURITY = "http://0.0.0.0:8718"
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def verify_llm(spec: str) -> dict:
|
|
||||||
"""
|
|
||||||
Verify an LLM model specification using the FastAPI server
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: The verification result from the FastAPI server.
|
|
||||||
|
|
||||||
Args: spec (str): The specification of the LLM model to verify.
|
|
||||||
|
|
||||||
"""
|
|
||||||
url = f"{AGENTIC_SECURITY}/verify"
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.post(url, json={"spec": spec})
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def start_scan(
|
|
||||||
llmSpec: str,
|
|
||||||
maxBudget: int,
|
|
||||||
optimize: bool = False,
|
|
||||||
enableMultiStepAttack: bool = False,
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Start an LLM security scan via the FastAPI server.
|
|
||||||
Returns:
|
|
||||||
dict: The scan initiation result from the FastAPI server.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llmSpec (str): The specification of the LLM model.
|
|
||||||
maxBudget (int): The maximum budget for the scan.
|
|
||||||
optimize (bool, optional): Whether to enable optimization during scanning. Defaults to False.
|
|
||||||
enableMultiStepAttack (bool, optional): Whether to enable multi-step attacks. Defaults to False.
|
|
||||||
|
|
||||||
"""
|
|
||||||
url = f"{AGENTIC_SECURITY}/scan"
|
|
||||||
payload = {
|
|
||||||
"llmSpec": llmSpec,
|
|
||||||
"maxBudget": maxBudget,
|
|
||||||
"datasets": [],
|
|
||||||
"optimize": optimize,
|
|
||||||
"enableMultiStepAttack": enableMultiStepAttack,
|
|
||||||
"probe_datasets": [],
|
|
||||||
"secrets": {},
|
|
||||||
}
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.post(url, json=payload)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def stop_scan() -> dict:
|
|
||||||
"""Stop an ongoing scan via the FastAPI server.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: The confirmation from the FastAPI server that the scan has been stopped.
|
|
||||||
"""
|
|
||||||
url = f"{AGENTIC_SECURITY}/stop"
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.post(url)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def get_data_config() -> list:
|
|
||||||
"""
|
|
||||||
Retrieve data configuration from the FastAPI server.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: The data configuration returned by the FastAPI server.
|
|
||||||
"""
|
|
||||||
url = f"{AGENTIC_SECURITY}/v1/data-config"
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.get(url)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def get_spec_templates() -> list:
|
|
||||||
"""
|
|
||||||
Retrieve LLM specification templates from the FastAPI server.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: The LLM specification templates from the FastAPI server.
|
|
||||||
"""
|
|
||||||
url = f"{AGENTIC_SECURITY}/v1/llm-specs"
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.get(url)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
|
|
||||||
# Run the MCP server
|
|
||||||
if __name__ == "__main__":
|
|
||||||
mcp.run()
|
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
from loguru import logger
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
|
|
||||||
class LogNon200ResponsesMiddleware(BaseHTTPMiddleware):
|
class LogNon200ResponsesMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
|||||||
@@ -8,14 +8,14 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
def generate_banner(
|
def generate_banner(
|
||||||
title: str = "Agentic Security",
|
title="Agentic Security",
|
||||||
font: str = "slant",
|
font="slant",
|
||||||
version: str = "v2.1.0",
|
version="v2.1.0",
|
||||||
tagline: str = "Proactive Threat Detection & Automated Security Protocols",
|
tagline="Proactive Threat Detection & Automated Security Protocols",
|
||||||
author: str = "Developed by: [Security Team]",
|
author="Developed by: [Security Team]",
|
||||||
website: str = "Website: https://github.com/msoedov/agentic_security",
|
website="Website: https://github.com/msoedov/agentic_security",
|
||||||
warning: str | None = "", # Using Optional for warning since it might be None
|
warning="",
|
||||||
) -> str:
|
):
|
||||||
"""Generate a visually enhanced banner with dynamic width and borders."""
|
"""Generate a visually enhanced banner with dynamic width and borders."""
|
||||||
# Define the text elements
|
# Define the text elements
|
||||||
|
|
||||||
@@ -80,13 +80,8 @@ def generate_banner(
|
|||||||
|
|
||||||
|
|
||||||
def init_banner():
|
def init_banner():
|
||||||
return
|
|
||||||
ver = version("agentic_security")
|
ver = version("agentic_security")
|
||||||
try:
|
|
||||||
print(generate_banner(version=ver))
|
print(generate_banner(version=ver))
|
||||||
except Exception:
|
|
||||||
# UnicodeEncodeError with codec on some systems
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
# noqa
|
|
||||||
from agentic_security.primitives.models import CompletionRequest # noqa
|
|
||||||
from agentic_security.primitives.models import ( # noqa
|
from agentic_security.primitives.models import ( # noqa
|
||||||
|
CompletionRequest,
|
||||||
FileProbeResponse,
|
FileProbeResponse,
|
||||||
LLMInfo,
|
LLMInfo,
|
||||||
Message,
|
Message,
|
||||||
|
|||||||
@@ -3,44 +3,32 @@ import random
|
|||||||
import time
|
import time
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
import pandas as pd
|
||||||
|
from loguru import logger
|
||||||
from skopt import Optimizer
|
from skopt import Optimizer
|
||||||
from skopt.space import Real
|
from skopt.space import Real
|
||||||
|
|
||||||
from agentic_security.config import settings_var
|
|
||||||
from agentic_security.http_spec import Modality
|
from agentic_security.http_spec import Modality
|
||||||
from agentic_security.logutils import logger
|
|
||||||
from agentic_security.primitives import Scan, ScanResult
|
from agentic_security.primitives import Scan, ScanResult
|
||||||
from agentic_security.probe_actor.cost_module import calculate_cost
|
from agentic_security.probe_actor.cost_module import calculate_cost
|
||||||
from agentic_security.probe_actor.refusal import refusal_heuristic
|
from agentic_security.probe_actor.refusal import refusal_heuristic
|
||||||
from agentic_security.probe_actor.state import FuzzerState
|
|
||||||
from agentic_security.probe_data import audio_generator, image_generator, msj_data
|
from agentic_security.probe_data import audio_generator, image_generator, msj_data
|
||||||
from agentic_security.probe_data.data import prepare_prompts
|
from agentic_security.probe_data.data import prepare_prompts
|
||||||
|
|
||||||
MAX_PROMPT_LENGTH = settings_var("fuzzer.max_prompt_lenght", 2048)
|
# TODO: full log file
|
||||||
BUDGET_MULTIPLIER = settings_var("fuzzer.budget_multiplier", 100000000)
|
|
||||||
INITIAL_OPTIMIZER_POINTS = settings_var("fuzzer.initial_optimizer_points", 25)
|
MAX_PROMPT_LENGTH = 2048
|
||||||
MIN_FAILURE_SAMPLES = settings_var("min_failure_samples", 5)
|
BUDGET_MULTIPLIER = 100_000_000
|
||||||
FAILURE_RATE_THRESHOLD = settings_var("failure_rate_threshold", 0.5)
|
INITIAL_OPTIMIZER_POINTS = 25
|
||||||
|
MIN_FAILURE_SAMPLES = 5
|
||||||
|
FAILURE_RATE_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
|
||||||
async def generate_prompts(
|
async def generate_prompts(
|
||||||
prompts: list[str] | AsyncGenerator,
|
prompts: list[str] | AsyncGenerator,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
|
||||||
Asynchronously generates and yields individual prompts.
|
|
||||||
|
|
||||||
If the input is a list of strings, the function sequentially yields each string.
|
|
||||||
If the input is an asynchronous generator, it forwards each generated prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompts (list[str] | AsyncGenerator): A list of strings or an asynchronous generator of prompts.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
str: An individual prompt from the list or the asynchronous generator.
|
|
||||||
"""
|
|
||||||
if isinstance(prompts, list):
|
if isinstance(prompts, list):
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
yield prompt
|
yield prompt
|
||||||
@@ -49,21 +37,7 @@ async def generate_prompts(
|
|||||||
yield prompt
|
yield prompt
|
||||||
|
|
||||||
|
|
||||||
def get_modality_adapter(llm_spec):
|
def multi_modality_spec(llm_spec):
|
||||||
"""
|
|
||||||
Returns the appropriate request adapter based on the modality of the LLM specification.
|
|
||||||
|
|
||||||
Depending on the modality of `llm_spec`, the function selects the corresponding request adapter.
|
|
||||||
If the modality is IMAGE or AUDIO, it returns an adapter for handling the respective type.
|
|
||||||
If the modality is TEXT or an unrecognized type, it returns `llm_spec` as is.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llm_spec: An object containing modality information for the LLM.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
RequestAdapter | llm_spec: An instance of the appropriate request adapter
|
|
||||||
or the original `llm_spec` if no adaptation is needed.
|
|
||||||
"""
|
|
||||||
match llm_spec.modality:
|
match llm_spec.modality:
|
||||||
case Modality.IMAGE:
|
case Modality.IMAGE:
|
||||||
return image_generator.RequestAdapter(llm_spec)
|
return image_generator.RequestAdapter(llm_spec)
|
||||||
@@ -76,71 +50,40 @@ def get_modality_adapter(llm_spec):
|
|||||||
|
|
||||||
|
|
||||||
async def process_prompt(
|
async def process_prompt(
|
||||||
request_factory,
|
request_factory, prompt, tokens, module_name, refusals, errors, outputs
|
||||||
prompt: str,
|
|
||||||
tokens: int,
|
|
||||||
module_name: str,
|
|
||||||
fuzzer_state: FuzzerState,
|
|
||||||
) -> tuple[int, bool]:
|
) -> tuple[int, bool]:
|
||||||
"""
|
"""
|
||||||
Processes a single prompt using the provided request factory and updates tracking lists.
|
Process a single prompt and update the token count and failure status.
|
||||||
|
|
||||||
This function sends the given `prompt` to the `request_factory`, checks for errors, and updates
|
|
||||||
the `tokens`, `refusals`, `errors`, and `outputs` lists accordingly. If the request fails or
|
|
||||||
the response indicates a refusal, the function records the issue and returns the updated token count
|
|
||||||
along with a boolean indicating whether the prompt was refused.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request_factory: An object with a `fn` method used to send the prompt.
|
|
||||||
prompt (str): The input prompt to be processed.
|
|
||||||
tokens (int): The current token count, which will be updated.
|
|
||||||
module_name (str): The name of the module handling the request.
|
|
||||||
fuzzer_state: State tracking object for the fuzzer
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = await request_factory.fn(prompt=prompt)
|
response = await request_factory.fn(prompt=prompt)
|
||||||
|
|
||||||
# Handle HTTP errors
|
|
||||||
if response.status_code == 422:
|
if response.status_code == 422:
|
||||||
logger.error(f"Invalid prompt: {prompt}, error=422")
|
logger.error(f"Invalid prompt: {prompt}, error=422")
|
||||||
fuzzer_state.add_error(module_name, prompt, 422, "Invalid prompt")
|
errors.append((module_name, prompt, 422, "Invalid prompt"))
|
||||||
return tokens, True
|
return tokens, True
|
||||||
|
|
||||||
if response.status_code >= 400:
|
if response.status_code >= 400:
|
||||||
logger.error(f"HTTP {response.status_code} {response.content=}")
|
logger.error(f"HTTP {response.status_code} {response.content=}")
|
||||||
fuzzer_state.add_error(
|
errors.append((module_name, prompt, response.status_code, response.text))
|
||||||
module_name, prompt, response.status_code, response.text
|
|
||||||
)
|
|
||||||
return tokens, True
|
return tokens, True
|
||||||
|
|
||||||
# Process successful response
|
|
||||||
response_text = response.text
|
response_text = response.text
|
||||||
tokens += len(response_text.split())
|
tokens += len(response_text.split())
|
||||||
|
|
||||||
# Check if the response indicates a refusal
|
|
||||||
refused = refusal_heuristic(response.json())
|
refused = refusal_heuristic(response.json())
|
||||||
if refused:
|
if refused:
|
||||||
fuzzer_state.add_refusal(
|
refusals.append((module_name, prompt, response.status_code, response_text))
|
||||||
module_name, prompt, response.status_code, response_text
|
|
||||||
)
|
|
||||||
|
|
||||||
fuzzer_state.add_output(module_name, prompt, response_text, refused)
|
outputs.append((module_name, prompt, response_text, refused))
|
||||||
return tokens, refused
|
return tokens, refused
|
||||||
|
|
||||||
except httpx.RequestError as exc:
|
except httpx.RequestError as exc:
|
||||||
logger.error(f"Request error: {exc}")
|
logger.error(f"Request error: {exc}")
|
||||||
fuzzer_state.add_error(module_name, prompt, "?", str(exc))
|
errors.append((module_name, prompt, "?", str(exc)))
|
||||||
return tokens, True
|
return tokens, True
|
||||||
except JSONDecodeError as json_decode_error:
|
except JSONDecodeError as json_decode_error:
|
||||||
logger.error(f"JSON error: {json_decode_error}")
|
logger.error(f"Jason error: {json_decode_error}")
|
||||||
fuzzer_state.add_error(module_name, prompt, "?", str(json_decode_error))
|
errors.append((module_name, prompt, "?", str(json_decode_error)))
|
||||||
return tokens, True
|
return tokens, True
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Unexpected error: {e}")
|
|
||||||
return tokens, False
|
|
||||||
|
|
||||||
|
|
||||||
async def process_prompt_batch(
|
async def process_prompt_batch(
|
||||||
@@ -148,29 +91,14 @@ async def process_prompt_batch(
|
|||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
tokens: int,
|
tokens: int,
|
||||||
module_name: str,
|
module_name: str,
|
||||||
fuzzer_state: FuzzerState,
|
refusals,
|
||||||
|
errors,
|
||||||
|
outputs,
|
||||||
) -> tuple[int, int]:
|
) -> tuple[int, int]:
|
||||||
"""
|
|
||||||
Processes a batch of prompts asynchronously and aggregates the results.
|
|
||||||
|
|
||||||
This function sends multiple prompts concurrently using `process_prompt`,
|
|
||||||
collects the token count and failure status for each prompt, and returns
|
|
||||||
the total number of tokens processed and the number of failed prompts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request_factory: An object with a `fn` method used to send the prompts.
|
|
||||||
prompts (list[str]): A list of input prompts to be processed.
|
|
||||||
tokens (int): The initial token count, which will be updated.
|
|
||||||
module_name (str): The name of the module handling the request.
|
|
||||||
fuzzer_state: State tracking object for the fuzzer
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[int, int]:
|
|
||||||
- Total number of tokens processed.
|
|
||||||
- Number of failed prompts.
|
|
||||||
"""
|
|
||||||
tasks = [
|
tasks = [
|
||||||
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
|
process_prompt(
|
||||||
|
request_factory, p, tokens, module_name, refusals, errors, outputs
|
||||||
|
)
|
||||||
for p in prompts
|
for p in prompts
|
||||||
]
|
]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
@@ -179,61 +107,53 @@ async def process_prompt_batch(
|
|||||||
return total_tokens, failures
|
return total_tokens, failures
|
||||||
|
|
||||||
|
|
||||||
async def scan_module(
|
async def perform_single_shot_scan(
|
||||||
request_factory,
|
request_factory,
|
||||||
module,
|
max_budget: int,
|
||||||
fuzzer_state: FuzzerState,
|
datasets: list[dict[str, str]] = [],
|
||||||
processed_prompts: int = 0,
|
tools_inbox=None,
|
||||||
total_prompts: int = 0,
|
optimize=False,
|
||||||
max_budget: int = 0,
|
stop_event: asyncio.Event = None,
|
||||||
total_tokens: int = 0,
|
secrets: dict[str, str] = {},
|
||||||
optimize: bool = False,
|
) -> AsyncGenerator[str, None]:
|
||||||
stop_event: asyncio.Event | None = None,
|
"""Perform a standard security scan."""
|
||||||
) -> AsyncGenerator[dict[str, Any], None]:
|
max_budget = max_budget * BUDGET_MULTIPLIER
|
||||||
"""
|
selected_datasets = [m for m in datasets if m["selected"]]
|
||||||
Scan a single module.
|
request_factory = multi_modality_spec(request_factory)
|
||||||
|
try:
|
||||||
Args:
|
yield ScanResult.status_msg("Loading datasets...")
|
||||||
request_factory: The factory for creating requests
|
prompt_modules = prepare_prompts(
|
||||||
module: The prompt module to scan
|
dataset_names=[m["dataset_name"] for m in selected_datasets],
|
||||||
fuzzer_state: State tracking object for the fuzzer
|
budget=max_budget,
|
||||||
processed_prompts: Number of prompts processed so far
|
tools_inbox=tools_inbox,
|
||||||
total_prompts: Total number of prompts to process
|
options=[m.get("opts", {}) for m in selected_datasets],
|
||||||
max_budget: Maximum token budget
|
|
||||||
total_tokens: Current token count
|
|
||||||
optimize: Whether to use optimization
|
|
||||||
stop_event: Event to stop scanning
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
ScanResult objects as the scan progresses
|
|
||||||
"""
|
|
||||||
tokens = 0
|
|
||||||
module_failures = 0
|
|
||||||
module_prompts = 0
|
|
||||||
failure_rates = []
|
|
||||||
should_stop = False
|
|
||||||
|
|
||||||
# Initialize optimizer if optimization is enabled
|
|
||||||
optimizer = (
|
|
||||||
Optimizer(
|
|
||||||
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
|
|
||||||
)
|
)
|
||||||
|
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
refusals = []
|
||||||
|
outputs = []
|
||||||
|
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||||
|
processed_prompts = 0
|
||||||
|
|
||||||
|
optimizer = (
|
||||||
|
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
|
||||||
if optimize
|
if optimize
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
failure_rates = []
|
||||||
|
|
||||||
|
total_tokens = 0
|
||||||
|
tokens = 0
|
||||||
|
should_stop = False
|
||||||
|
for module in prompt_modules:
|
||||||
|
if should_stop:
|
||||||
|
break
|
||||||
|
tokens = 0
|
||||||
|
module_failures = 0
|
||||||
module_size = 0 if module.lazy else len(module.prompts)
|
module_size = 0 if module.lazy else len(module.prompts)
|
||||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||||
yield ScanResult(
|
module_prompts = 0 # Reset for each module
|
||||||
module=module.dataset_name,
|
|
||||||
tokens=0,
|
|
||||||
cost=0,
|
|
||||||
progress=0,
|
|
||||||
failureRate=0,
|
|
||||||
prompt="",
|
|
||||||
latency=0,
|
|
||||||
model="",
|
|
||||||
).model_dump_json()
|
|
||||||
|
|
||||||
async for prompt in generate_prompts(module.prompts):
|
async for prompt in generate_prompts(module.prompts):
|
||||||
if stop_event and stop_event.is_set():
|
if stop_event and stop_event.is_set():
|
||||||
@@ -243,34 +163,37 @@ async def scan_module(
|
|||||||
return
|
return
|
||||||
|
|
||||||
processed_prompts += 1
|
processed_prompts += 1
|
||||||
module_prompts += 1
|
module_prompts += 1 # Fixed increment syntax
|
||||||
|
|
||||||
# Calculate progress based on total processed prompts
|
# Calculate progress based on total processed prompts
|
||||||
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
|
progress = (
|
||||||
progress = progress % 100
|
100 * processed_prompts / total_prompts if total_prompts else 0
|
||||||
|
)
|
||||||
|
|
||||||
total_tokens -= tokens
|
total_tokens -= tokens
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
tokens, failed = await process_prompt(
|
tokens, failed = await process_prompt(
|
||||||
request_factory,
|
request_factory,
|
||||||
prompt,
|
prompt,
|
||||||
tokens,
|
tokens,
|
||||||
module.dataset_name,
|
module.dataset_name,
|
||||||
fuzzer_state=fuzzer_state,
|
refusals,
|
||||||
|
errors,
|
||||||
|
outputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
total_tokens += tokens
|
total_tokens += tokens
|
||||||
|
|
||||||
if failed:
|
if failed:
|
||||||
module_failures += 1
|
module_failures += 1
|
||||||
|
|
||||||
failure_rate = module_failures / max(module_prompts, 1)
|
failure_rate = module_failures / max(module_prompts, 1)
|
||||||
failure_rates.append(failure_rate)
|
failure_rates.append(failure_rate)
|
||||||
cost = calculate_cost(tokens)
|
cost = calculate_cost(tokens)
|
||||||
|
|
||||||
response_text = fuzzer_state.get_last_output(prompt) or ""
|
last_output = outputs[-1] if outputs else None
|
||||||
|
if last_output and last_output[1] == prompt:
|
||||||
|
response_text = last_output[2]
|
||||||
|
else:
|
||||||
|
response_text = ""
|
||||||
|
|
||||||
yield ScanResult(
|
yield ScanResult(
|
||||||
module=module.dataset_name,
|
module=module.dataset_name,
|
||||||
@@ -283,19 +206,16 @@ async def scan_module(
|
|||||||
model=response_text,
|
model=response_text,
|
||||||
).model_dump_json()
|
).model_dump_json()
|
||||||
|
|
||||||
# Optimization logic
|
if optimize and len(failure_rates) >= 5:
|
||||||
if optimize and optimizer and len(failure_rates) >= MIN_FAILURE_SAMPLES:
|
|
||||||
next_point = optimizer.ask()
|
next_point = optimizer.ask()
|
||||||
optimizer.tell(next_point, -failure_rate)
|
optimizer.tell(next_point, -failure_rate)
|
||||||
best_failure_rate = -optimizer.get_result().fun
|
best_failure_rate = -optimizer.get_result().fun
|
||||||
if best_failure_rate > FAILURE_RATE_THRESHOLD:
|
if best_failure_rate > 0.5:
|
||||||
yield ScanResult.status_msg(
|
yield ScanResult.status_msg(
|
||||||
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
|
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
|
||||||
)
|
)
|
||||||
should_stop = True
|
should_stop = True
|
||||||
break
|
break
|
||||||
|
|
||||||
# Budget check
|
|
||||||
if total_tokens > max_budget:
|
if total_tokens > max_budget:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
|
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
|
||||||
@@ -306,30 +226,14 @@ async def scan_module(
|
|||||||
should_stop = True
|
should_stop = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if should_stop:
|
yield ScanResult.status_msg("Scan completed.")
|
||||||
break
|
|
||||||
|
|
||||||
return
|
failure_data = errors + refusals
|
||||||
|
df = pd.DataFrame(
|
||||||
|
failure_data, columns=["module", "prompt", "status_code", "content"]
|
||||||
|
)
|
||||||
|
df.to_csv("failures.csv", index=False)
|
||||||
|
|
||||||
|
|
||||||
async def with_error_handling(agen):
|
|
||||||
"""
|
|
||||||
Wraps an asynchronous generator with error handling.
|
|
||||||
|
|
||||||
This function iterates over an asynchronous generator, yielding its values.
|
|
||||||
If an exception occurs, it logs the error and yields a failure message.
|
|
||||||
Finally, it ensures that a completion message is always yielded.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agen: An asynchronous generator that produces scan results.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
ScanResult: Either a successful result, an error message if an
|
|
||||||
exception occurs, or a completion message at the end.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
async for t in agen:
|
|
||||||
yield t
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Scan failed")
|
logger.exception("Scan failed")
|
||||||
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
|
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
|
||||||
@@ -337,121 +241,21 @@ async def with_error_handling(agen):
|
|||||||
yield ScanResult.status_msg("Scan completed.")
|
yield ScanResult.status_msg("Scan completed.")
|
||||||
|
|
||||||
|
|
||||||
async def perform_single_shot_scan(
|
|
||||||
request_factory,
|
|
||||||
max_budget: int,
|
|
||||||
datasets: list[dict[str, str]] = [],
|
|
||||||
tools_inbox=None,
|
|
||||||
optimize: bool = False,
|
|
||||||
stop_event: asyncio.Event | None = None,
|
|
||||||
secrets: dict[str, str] = {},
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""
|
|
||||||
Perform a standard security scan using a given request factory.
|
|
||||||
|
|
||||||
This function processes security scan prompts from selected datasets while
|
|
||||||
respecting a predefined token budget. It supports optimization, failure tracking,
|
|
||||||
and early stopping based on budget constraints or user intervention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request_factory: A factory function that generates requests for processing prompts.
|
|
||||||
max_budget (int): The maximum token budget for the scan.
|
|
||||||
datasets (list[dict[str, str]], optional): A list of datasets containing security prompts.
|
|
||||||
tools_inbox: Optional additional tools for processing (default: None).
|
|
||||||
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
|
|
||||||
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
|
|
||||||
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
str: JSON-encoded scan results or status messages.
|
|
||||||
|
|
||||||
The function iterates over prompts, processes them asynchronously, and updates
|
|
||||||
failure statistics and token usage. If the scan exceeds the budget or failure rate is too high,
|
|
||||||
it stops execution. Results are saved to a CSV file upon completion.
|
|
||||||
"""
|
|
||||||
max_budget = max_budget * BUDGET_MULTIPLIER
|
|
||||||
selected_datasets = [m for m in datasets if m["selected"]]
|
|
||||||
request_factory = get_modality_adapter(request_factory)
|
|
||||||
|
|
||||||
yield ScanResult.status_msg("Loading datasets...")
|
|
||||||
prompt_modules = prepare_prompts(
|
|
||||||
dataset_names=[m["dataset_name"] for m in selected_datasets],
|
|
||||||
budget=max_budget,
|
|
||||||
tools_inbox=tools_inbox,
|
|
||||||
options=[m.get("opts", {}) for m in selected_datasets],
|
|
||||||
)
|
|
||||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
|
||||||
|
|
||||||
fuzzer_state = FuzzerState()
|
|
||||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
|
||||||
processed_prompts = 0
|
|
||||||
|
|
||||||
total_tokens = 0
|
|
||||||
for module in prompt_modules:
|
|
||||||
module_gen = scan_module(
|
|
||||||
request_factory=request_factory,
|
|
||||||
module=module,
|
|
||||||
fuzzer_state=fuzzer_state,
|
|
||||||
processed_prompts=processed_prompts,
|
|
||||||
total_prompts=total_prompts,
|
|
||||||
max_budget=max_budget,
|
|
||||||
total_tokens=total_tokens,
|
|
||||||
optimize=optimize,
|
|
||||||
stop_event=stop_event,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
async for result in module_gen:
|
|
||||||
yield result
|
|
||||||
except Exception:
|
|
||||||
logger.error("Module exception")
|
|
||||||
continue
|
|
||||||
# Update processed_prompts count
|
|
||||||
module_size = 0 if module.lazy else len(module.prompts)
|
|
||||||
processed_prompts += module_size
|
|
||||||
|
|
||||||
yield ScanResult.status_msg("Scan completed.")
|
|
||||||
fuzzer_state.export_failures("failures.csv")
|
|
||||||
|
|
||||||
|
|
||||||
async def perform_many_shot_scan(
|
async def perform_many_shot_scan(
|
||||||
request_factory,
|
request_factory,
|
||||||
max_budget: int,
|
max_budget: int,
|
||||||
datasets: list[dict[str, str]] = [],
|
datasets: list[dict[str, str]] = [],
|
||||||
probe_datasets: list[dict[str, str]] = [],
|
probe_datasets: list[dict[str, str]] = [],
|
||||||
tools_inbox=None,
|
tools_inbox=None,
|
||||||
optimize: bool = False,
|
optimize=False,
|
||||||
stop_event: asyncio.Event | None = None,
|
stop_event: asyncio.Event = None,
|
||||||
probe_frequency: float = 0.2,
|
probe_frequency: float = 0.2,
|
||||||
max_ctx_length: int = 10_000,
|
max_ctx_length: int = 10_000,
|
||||||
secrets: dict[str, str] = {},
|
secrets: dict[str, str] = {},
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""Perform a multi-step security scan with probe injection."""
|
||||||
Perform a multi-step security scan with probe injection.
|
request_factory = multi_modality_spec(request_factory)
|
||||||
|
try:
|
||||||
This function executes a security scan while periodically injecting probe datasets
|
|
||||||
to test system robustness. It tracks failures, optimizes scan efficiency,
|
|
||||||
and ensures adherence to a predefined token budget.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request_factory: A factory function that generates requests for processing prompts.
|
|
||||||
max_budget (int): The maximum token budget for the scan.
|
|
||||||
datasets (list[dict[str, str]], optional): The main datasets for scanning.
|
|
||||||
probe_datasets (list[dict[str, str]], optional): Additional datasets for probe injection.
|
|
||||||
tools_inbox: Optional tools for additional processing (default: None).
|
|
||||||
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
|
|
||||||
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
|
|
||||||
probe_frequency (float, optional): The probability of probe injection (default: 0.2).
|
|
||||||
max_ctx_length (int, optional): The maximum context length before resetting (default: 10,000 tokens).
|
|
||||||
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
str: JSON-encoded scan results or status messages.
|
|
||||||
|
|
||||||
This function iterates over prompts, injects probe prompts at random intervals,
|
|
||||||
processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold
|
|
||||||
or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion.
|
|
||||||
"""
|
|
||||||
request_factory = get_modality_adapter(request_factory)
|
|
||||||
# Load main and probe datasets
|
# Load main and probe datasets
|
||||||
yield ScanResult.status_msg("Loading datasets...")
|
yield ScanResult.status_msg("Loading datasets...")
|
||||||
prompt_modules = prepare_prompts(
|
prompt_modules = prepare_prompts(
|
||||||
@@ -463,10 +267,17 @@ async def perform_many_shot_scan(
|
|||||||
msj_modules = msj_data.prepare_prompts(probe_datasets)
|
msj_modules = msj_data.prepare_prompts(probe_datasets)
|
||||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||||
|
|
||||||
fuzzer_state = FuzzerState()
|
errors = []
|
||||||
|
refusals = []
|
||||||
|
outputs = []
|
||||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||||
processed_prompts = 0
|
processed_prompts = 0
|
||||||
|
|
||||||
|
optimizer = (
|
||||||
|
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
|
||||||
|
if optimize
|
||||||
|
else None
|
||||||
|
)
|
||||||
failure_rates = []
|
failure_rates = []
|
||||||
|
|
||||||
for module in prompt_modules:
|
for module in prompt_modules:
|
||||||
@@ -480,11 +291,11 @@ async def perform_many_shot_scan(
|
|||||||
logger.info("Scan stopped by user.")
|
logger.info("Scan stopped by user.")
|
||||||
yield ScanResult.status_msg("Scan stopped by user.")
|
yield ScanResult.status_msg("Scan stopped by user.")
|
||||||
return
|
return
|
||||||
|
|
||||||
tokens = 0
|
tokens = 0
|
||||||
processed_prompts += 1
|
processed_prompts += 1
|
||||||
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
|
progress = (
|
||||||
progress = progress % 100
|
100 * processed_prompts / total_prompts if total_prompts else 0
|
||||||
|
)
|
||||||
|
|
||||||
full_prompt = ""
|
full_prompt = ""
|
||||||
msj_module = random.choice(msj_modules)
|
msj_module = random.choice(msj_modules)
|
||||||
@@ -508,7 +319,9 @@ async def perform_many_shot_scan(
|
|||||||
full_prompt,
|
full_prompt,
|
||||||
tokens,
|
tokens,
|
||||||
module.dataset_name,
|
module.dataset_name,
|
||||||
fuzzer_state=fuzzer_state,
|
refusals,
|
||||||
|
errors,
|
||||||
|
outputs,
|
||||||
)
|
)
|
||||||
if failed:
|
if failed:
|
||||||
module_failures += 1
|
module_failures += 1
|
||||||
@@ -529,51 +342,37 @@ async def perform_many_shot_scan(
|
|||||||
prompt=prompt[:MAX_PROMPT_LENGTH],
|
prompt=prompt[:MAX_PROMPT_LENGTH],
|
||||||
).model_dump_json()
|
).model_dump_json()
|
||||||
|
|
||||||
if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES:
|
if optimize and len(failure_rates) >= 5:
|
||||||
|
next_point = optimizer.ask()
|
||||||
|
optimizer.tell(next_point, -failure_rate)
|
||||||
|
best_failure_rate = -optimizer.get_result().fun
|
||||||
|
if best_failure_rate > 0.5:
|
||||||
yield ScanResult.status_msg(
|
yield ScanResult.status_msg(
|
||||||
f"High failure rate detected ({failure_rate:.2%}). Stopping this module..."
|
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
yield ScanResult.status_msg("Scan completed.")
|
yield ScanResult.status_msg("Scan completed.")
|
||||||
fuzzer_state.export_failures("failures.csv")
|
|
||||||
|
df = pd.DataFrame(
|
||||||
|
errors + refusals, columns=["module", "prompt", "status_code", "content"]
|
||||||
|
)
|
||||||
|
df.to_csv("failures.csv", index=False)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Scan failed")
|
||||||
|
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def scan_router(
|
def scan_router(
|
||||||
request_factory,
|
request_factory,
|
||||||
scan_parameters: Scan,
|
scan_parameters: Scan,
|
||||||
tools_inbox=None,
|
tools_inbox=None,
|
||||||
stop_event: asyncio.Event | None = None,
|
stop_event: asyncio.Event = None,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Route scan requests to the appropriate scanning function.
|
|
||||||
|
|
||||||
This function determines whether to perform a multi-step or single-shot
|
|
||||||
security scan based on the provided scan parameters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request_factory: A factory function to generate requests for processing prompts.
|
|
||||||
scan_parameters (Scan): An object containing the parameters for the scan, including:
|
|
||||||
- enableMultiStepAttack (bool): Whether to perform a multi-step scan.
|
|
||||||
- maxBudget (int): The maximum token budget for the scan.
|
|
||||||
- datasets (list[dict[str, str]]): The datasets to scan.
|
|
||||||
- probe_datasets (list[dict[str, str]], optional): Datasets for probe injection (multi-step only).
|
|
||||||
- optimize (bool): Whether to enable optimization.
|
|
||||||
- secrets (dict[str, str], optional): A dictionary of secrets for authentication.
|
|
||||||
tools_inbox: Optional tools for additional processing (default: None).
|
|
||||||
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A function wrapped with `with_error_handling`, which executes either:
|
|
||||||
- `perform_many_shot_scan` for multi-step scanning.
|
|
||||||
- `perform_single_shot_scan` for single-shot scanning.
|
|
||||||
|
|
||||||
The function ensures that the appropriate scanning method is chosen based on
|
|
||||||
the `enableMultiStepAttack` flag in `scan_parameters`.
|
|
||||||
"""
|
|
||||||
if scan_parameters.enableMultiStepAttack:
|
if scan_parameters.enableMultiStepAttack:
|
||||||
return with_error_handling(
|
return perform_many_shot_scan(
|
||||||
perform_many_shot_scan(
|
|
||||||
request_factory=request_factory,
|
request_factory=request_factory,
|
||||||
max_budget=scan_parameters.maxBudget,
|
max_budget=scan_parameters.maxBudget,
|
||||||
datasets=scan_parameters.datasets,
|
datasets=scan_parameters.datasets,
|
||||||
@@ -583,10 +382,8 @@ def scan_router(
|
|||||||
stop_event=stop_event,
|
stop_event=stop_event,
|
||||||
secrets=scan_parameters.secrets,
|
secrets=scan_parameters.secrets,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return with_error_handling(
|
return perform_single_shot_scan(
|
||||||
perform_single_shot_scan(
|
|
||||||
request_factory=request_factory,
|
request_factory=request_factory,
|
||||||
max_budget=scan_parameters.maxBudget,
|
max_budget=scan_parameters.maxBudget,
|
||||||
datasets=scan_parameters.datasets,
|
datasets=scan_parameters.datasets,
|
||||||
@@ -595,4 +392,3 @@ def scan_router(
|
|||||||
stop_event=stop_event,
|
stop_event=stop_event,
|
||||||
secrets=scan_parameters.secrets,
|
secrets=scan_parameters.secrets,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -6,10 +7,13 @@ from pydantic import BaseModel, Field
|
|||||||
from pydantic_ai import Agent, RunContext
|
from pydantic_ai import Agent, RunContext
|
||||||
|
|
||||||
from agentic_security.http_spec import LLMSpec
|
from agentic_security.http_spec import LLMSpec
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
LLM_SPECS = []
|
LLM_SPECS = []
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AgentSpecification(BaseModel):
|
class AgentSpecification(BaseModel):
|
||||||
name: str | None = Field(None, description="Name of the LLM/agent")
|
name: str | None = Field(None, description="Name of the LLM/agent")
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from agentic_security.refusal_classifier.model import RefusalClassifier
|
from agentic_security.refusal_classifier import RefusalClassifier
|
||||||
|
|
||||||
classifier = RefusalClassifier()
|
classifier = RefusalClassifier()
|
||||||
classifier.load_model()
|
classifier.load_model()
|
||||||
|
|||||||
@@ -1,47 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
|
|
||||||
|
|
||||||
class FuzzerState:
|
|
||||||
"""Container for tracking scan results"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.errors = []
|
|
||||||
self.refusals = []
|
|
||||||
self.outputs = []
|
|
||||||
|
|
||||||
def add_error(
|
|
||||||
self,
|
|
||||||
module_name: str,
|
|
||||||
prompt: str,
|
|
||||||
status_code: int | str,
|
|
||||||
error_msg: str,
|
|
||||||
):
|
|
||||||
"""Add an error to the state"""
|
|
||||||
self.errors.append((module_name, prompt, status_code, error_msg))
|
|
||||||
|
|
||||||
def add_refusal(
|
|
||||||
self, module_name: str, prompt: str, status_code: int, response_text: str
|
|
||||||
):
|
|
||||||
"""Add a refusal to the state"""
|
|
||||||
self.refusals.append((module_name, prompt, status_code, response_text))
|
|
||||||
|
|
||||||
def add_output(
|
|
||||||
self, module_name: str, prompt: str, response_text: str, refused: bool
|
|
||||||
):
|
|
||||||
"""Add an output to the state"""
|
|
||||||
self.outputs.append((module_name, prompt, response_text, refused))
|
|
||||||
|
|
||||||
def get_last_output(self, prompt: str) -> str | None:
|
|
||||||
"""Get the last output for a given prompt"""
|
|
||||||
for output in reversed(self.outputs):
|
|
||||||
if output[1] == prompt:
|
|
||||||
return output[2]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def export_failures(self, filename: str = "failures.csv"):
|
|
||||||
"""Export failures to a CSV file"""
|
|
||||||
failure_data = self.errors + self.refusals
|
|
||||||
df = pd.DataFrame(
|
|
||||||
failure_data, columns=["module", "prompt", "status_code", "content"]
|
|
||||||
)
|
|
||||||
df.to_csv(filename, index=False)
|
|
||||||
@@ -7,7 +7,6 @@ import pytest
|
|||||||
|
|
||||||
from agentic_security.primitives import Scan
|
from agentic_security.primitives import Scan
|
||||||
from agentic_security.probe_actor.fuzzer import (
|
from agentic_security.probe_actor.fuzzer import (
|
||||||
FuzzerState,
|
|
||||||
generate_prompts,
|
generate_prompts,
|
||||||
perform_many_shot_scan,
|
perform_many_shot_scan,
|
||||||
perform_single_shot_scan,
|
perform_single_shot_scan,
|
||||||
@@ -208,7 +207,9 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
|||||||
prompt="test prompt",
|
prompt="test prompt",
|
||||||
tokens=0,
|
tokens=0,
|
||||||
module_name="module_a",
|
module_name="module_a",
|
||||||
fuzzer_state=FuzzerState(),
|
refusals=[],
|
||||||
|
errors=[],
|
||||||
|
outputs=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
|
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
|
||||||
@@ -225,17 +226,20 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
fuzzer_state = FuzzerState()
|
refusals = []
|
||||||
|
outputs = []
|
||||||
tokens, refusal = await process_prompt(
|
tokens, refusal = await process_prompt(
|
||||||
request_factory=mock_request_factory,
|
request_factory=mock_request_factory,
|
||||||
prompt="test prompt",
|
prompt="test prompt",
|
||||||
tokens=0,
|
tokens=0,
|
||||||
module_name="module_a",
|
module_name="module_a",
|
||||||
fuzzer_state=fuzzer_state,
|
refusals=refusals,
|
||||||
|
errors=[],
|
||||||
|
outputs=outputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
|
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
|
||||||
# self.assertFalse(fuzzer_state.refusals)
|
self.assertFalse(refusal)
|
||||||
|
|
||||||
async def test_http_error_response(self):
|
async def test_http_error_response(self):
|
||||||
mock_request_factory = Mock()
|
mock_request_factory = Mock()
|
||||||
@@ -248,13 +252,15 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
fuzzer_state = FuzzerState()
|
refusals = []
|
||||||
await process_prompt(
|
await process_prompt(
|
||||||
request_factory=mock_request_factory,
|
request_factory=mock_request_factory,
|
||||||
prompt="test prompt",
|
prompt="test prompt",
|
||||||
tokens=0,
|
tokens=0,
|
||||||
module_name="module_a",
|
module_name="module_a",
|
||||||
fuzzer_state=fuzzer_state,
|
refusals=refusals,
|
||||||
|
errors=[],
|
||||||
|
outputs=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_request_error(self):
|
async def test_request_error(self):
|
||||||
@@ -263,14 +269,18 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
|||||||
side_effect=httpx.RequestError("Connection error")
|
side_effect=httpx.RequestError("Connection error")
|
||||||
)
|
)
|
||||||
|
|
||||||
fuzzer_state = FuzzerState()
|
errors = []
|
||||||
tokens, refusal = await process_prompt(
|
tokens, refusal = await process_prompt(
|
||||||
request_factory=mock_request_factory,
|
request_factory=mock_request_factory,
|
||||||
prompt="test prompt",
|
prompt="test prompt",
|
||||||
tokens=0,
|
tokens=0,
|
||||||
module_name="module_a",
|
module_name="module_a",
|
||||||
fuzzer_state=fuzzer_state,
|
refusals=[],
|
||||||
|
errors=errors,
|
||||||
|
outputs=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(tokens, 0)
|
self.assertEqual(tokens, 0)
|
||||||
self.assertTrue(refusal)
|
self.assertTrue(refusal)
|
||||||
|
self.assertEqual(len(errors), 1)
|
||||||
|
self.assertIn("Connection error", errors[0][3])
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from .data import load_local_csv, load_local_csv_files
|
from .data import load_local_csv
|
||||||
|
|
||||||
REGISTRY_V0 = [
|
REGISTRY_V0 = [
|
||||||
{
|
{
|
||||||
@@ -484,18 +484,3 @@ REGISTRY = REGISTRY_V0 + [
|
|||||||
"modality": "text",
|
"modality": "text",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
for ds in load_local_csv_files():
|
|
||||||
REGISTRY.append(
|
|
||||||
{
|
|
||||||
"dataset_name": ds.dataset_name,
|
|
||||||
"num_prompts": len(ds.prompts),
|
|
||||||
"tokens": ds.prompts,
|
|
||||||
"approx_cost": 0.0,
|
|
||||||
"is_active": True,
|
|
||||||
"source": f"Local file dataset: {ds.metadata['src']}",
|
|
||||||
"selected": False,
|
|
||||||
"url": "",
|
|
||||||
"modality": "text",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -8,16 +7,6 @@ import uuid
|
|||||||
import httpx
|
import httpx
|
||||||
from cache_to_disk import cache_to_disk
|
from cache_to_disk import cache_to_disk
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AudioGenerationError(Exception):
|
|
||||||
"""Custom exception for errors during audio generation."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def encode(content: bytes) -> str:
|
def encode(content: bytes) -> str:
|
||||||
encoded_content = base64.b64encode(content).decode("utf-8")
|
encoded_content = base64.b64encode(content).decode("utf-8")
|
||||||
@@ -52,25 +41,12 @@ def generate_audio_mac_wav(prompt: str) -> bytes:
|
|||||||
# Read the WAV file into memory
|
# Read the WAV file into memory
|
||||||
with open(temp_wav_path, "rb") as f:
|
with open(temp_wav_path, "rb") as f:
|
||||||
audio_bytes = f.read()
|
audio_bytes = f.read()
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
logger.error(f"Subprocess error: {e}")
|
|
||||||
raise AudioGenerationError("Failed to generate or convert audio.") from e
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
logger.error(f"File not found: {e}")
|
|
||||||
raise AudioGenerationError("Required file not found.") from e
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Unexpected error occurred.")
|
|
||||||
raise AudioGenerationError(
|
|
||||||
"An unexpected error occurred during audio generation."
|
|
||||||
) from e
|
|
||||||
finally:
|
finally:
|
||||||
for path in (temp_aiff_path, temp_wav_path):
|
# Clean up the temporary files
|
||||||
try:
|
if os.path.exists(temp_aiff_path):
|
||||||
if os.path.exists(path):
|
os.remove(temp_aiff_path)
|
||||||
os.remove(path)
|
if os.path.exists(temp_wav_path):
|
||||||
except Exception as e:
|
os.remove(temp_wav_path)
|
||||||
logger.warning(f"Failed to delete temporary file {path}: {e}")
|
|
||||||
|
|
||||||
# Return the audio bytes
|
# Return the audio bytes
|
||||||
return audio_bytes
|
return audio_bytes
|
||||||
|
|||||||
+374
-395
@@ -1,16 +1,14 @@
|
|||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from collections.abc import Callable, Iterator
|
from functools import lru_cache
|
||||||
from functools import partial
|
|
||||||
from typing import Any, TypeVar
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from cache_to_disk import cache_to_disk
|
from cache_to_disk import cache_to_disk
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from agentic_security.logutils import logger
|
|
||||||
from agentic_security.probe_data import stenography_fn
|
from agentic_security.probe_data import stenography_fn
|
||||||
from agentic_security.probe_data.models import ProbeDataset
|
from agentic_security.probe_data.models import ProbeDataset
|
||||||
from agentic_security.probe_data.modules import (
|
from agentic_security.probe_data.modules import (
|
||||||
@@ -21,279 +19,281 @@ from agentic_security.probe_data.modules import (
|
|||||||
rl_model,
|
rl_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Type aliases for clarity
|
|
||||||
T = TypeVar("T")
|
|
||||||
FilterFn = Callable[[pd.Series], bool]
|
|
||||||
ColumnMappings = dict[str, str]
|
|
||||||
DatasetLoader = Callable[[], ProbeDataset]
|
|
||||||
TransformFn = Callable[[str], str]
|
|
||||||
|
|
||||||
|
@cache_to_disk()
|
||||||
|
def load_dataset_general(
|
||||||
|
dataset_name,
|
||||||
|
dataset_split="train",
|
||||||
|
column_mappings=None,
|
||||||
|
filter_fn=None,
|
||||||
|
custom_url=None,
|
||||||
|
additional_metadata=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generalized function to load datasets with flexible configurations.
|
||||||
|
|
||||||
# Core data loading utilities
|
:param dataset_name: Name of the dataset or URL for custom CSVs
|
||||||
def fetch_csv_content(url: str) -> str:
|
:param dataset_split: Split to load from the dataset (e.g., "train")
|
||||||
"""Fetch CSV content from a URL."""
|
:param column_mappings: Dictionary mapping dataset columns to expected keys, e.g., {'prompt': 'query'}
|
||||||
response = httpx.get(url)
|
:param filter_fn: A filtering function that takes a row and returns True/False
|
||||||
response.raise_for_status() # Raise exception for bad responses
|
:param custom_url: URL for custom CSV datasets
|
||||||
return response.content.decode("utf-8")
|
:param additional_metadata: Additional metadata to include in the ProbeDataset
|
||||||
|
:return: A ProbeDataset object with the processed data
|
||||||
|
"""
|
||||||
|
if custom_url:
|
||||||
|
logger.info(f"Loading custom CSV dataset from {custom_url}")
|
||||||
|
r = httpx.get(custom_url)
|
||||||
|
content = r.content
|
||||||
|
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
|
||||||
|
else:
|
||||||
|
logger.info(f"Loading dataset {dataset_name} from Hugging Face datasets")
|
||||||
|
dataset = load_dataset(dataset_name)
|
||||||
|
df = pd.DataFrame(dataset[dataset_split])
|
||||||
|
|
||||||
|
# Apply column mappings if provided
|
||||||
|
if column_mappings:
|
||||||
|
df.rename(columns=column_mappings, inplace=True)
|
||||||
|
|
||||||
def load_df_from_source(source: str, is_url: bool = False) -> pd.DataFrame:
|
# Filter rows if filter_fn is provided
|
||||||
"""Load DataFrame from either URL or Hugging Face dataset."""
|
|
||||||
if is_url:
|
|
||||||
logger.info(f"Loading custom CSV from {source}")
|
|
||||||
return pd.read_csv(io.StringIO(fetch_csv_content(source)))
|
|
||||||
logger.info(f"Loading dataset {source} from Hugging Face")
|
|
||||||
return pd.DataFrame(load_dataset(source)["train"])
|
|
||||||
|
|
||||||
|
|
||||||
def transform_df(
|
|
||||||
df: pd.DataFrame,
|
|
||||||
mappings: ColumnMappings | None = None,
|
|
||||||
filter_fn: FilterFn | None = None,
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
"""Apply transformations to DataFrame."""
|
|
||||||
result = df.copy()
|
|
||||||
if mappings:
|
|
||||||
result.rename(columns=mappings, inplace=True)
|
|
||||||
if filter_fn:
|
if filter_fn:
|
||||||
result = result[result.apply(filter_fn, axis=1)]
|
df = df[df.apply(filter_fn, axis=1)]
|
||||||
return result
|
|
||||||
|
|
||||||
|
# Extract prompts
|
||||||
|
prompts = df[column_mappings.get("prompt", "prompt")].tolist()
|
||||||
|
|
||||||
def create_probe_dataset(
|
|
||||||
name: str, prompts: list[str], metadata: dict[str, Any] | None = None
|
|
||||||
) -> ProbeDataset:
|
|
||||||
"""Create a ProbeDataset from prompts."""
|
|
||||||
metadata = metadata or {}
|
|
||||||
return ProbeDataset(
|
return ProbeDataset(
|
||||||
dataset_name=name,
|
dataset_name=dataset_name,
|
||||||
metadata=metadata,
|
metadata=additional_metadata or {},
|
||||||
prompts=prompts,
|
prompts=prompts,
|
||||||
tokens=sum(len(str(s).split()) for s in prompts),
|
tokens=count_words_in_list(prompts),
|
||||||
approx_cost=0.0,
|
approx_cost=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Generalized dataset loader
|
def count_words_in_list(str_list):
|
||||||
|
"""Calculate the total number of words in a given list of strings.
|
||||||
|
|
||||||
|
:param str_list: List of strings
|
||||||
|
:return: Total number of words across all strings in the list
|
||||||
|
"""
|
||||||
|
total_words = sum(len(str(s).split()) for s in str_list)
|
||||||
|
return total_words
|
||||||
|
|
||||||
|
|
||||||
@cache_to_disk()
|
@cache_to_disk()
|
||||||
def load_dataset_generic(
|
def load_dataset_v1():
|
||||||
name: str,
|
return load_dataset_general(
|
||||||
mappings: ColumnMappings | None = None,
|
dataset_name="ShawnMenz/DAN_jailbreak",
|
||||||
filter_fn: FilterFn | None = None,
|
column_mappings={"prompt": "prompt", "jailbreak": "jailbreak"},
|
||||||
url: str | None = None,
|
filter_fn=lambda row: row["jailbreak"],
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
) -> ProbeDataset:
|
|
||||||
"""Load and process a dataset with flexible configuration."""
|
|
||||||
try:
|
|
||||||
df = load_df_from_source(url or name, is_url=bool(url))
|
|
||||||
transformed_df = transform_df(df, mappings, filter_fn)
|
|
||||||
|
|
||||||
# Determine which column to use as the prompt source
|
|
||||||
prompt_col = None
|
|
||||||
if mappings and "prompt" in mappings:
|
|
||||||
prompt_col = mappings["prompt"]
|
|
||||||
elif "prompt" in transformed_df.columns:
|
|
||||||
prompt_col = "prompt"
|
|
||||||
else:
|
|
||||||
# Try to find a suitable text column
|
|
||||||
text_columns = [
|
|
||||||
col
|
|
||||||
for col in transformed_df.columns
|
|
||||||
if any(
|
|
||||||
keyword in col.lower()
|
|
||||||
for keyword in ["prompt", "text", "query", "question"]
|
|
||||||
)
|
)
|
||||||
]
|
|
||||||
if text_columns:
|
|
||||||
prompt_col = text_columns[0]
|
|
||||||
logger.info(f"Using column '{prompt_col}' as prompt source")
|
|
||||||
else:
|
|
||||||
logger.error(f"No suitable prompt column found in dataset {name}")
|
|
||||||
return create_probe_dataset(name, [], metadata)
|
|
||||||
|
|
||||||
# Extract prompts and filter out empty ones
|
|
||||||
prompts = [
|
@cache_to_disk()
|
||||||
|
def load_dataset_v2():
|
||||||
|
return load_dataset_general(
|
||||||
|
dataset_name="deepset/prompt-injections",
|
||||||
|
column_mappings={"prompt": "text", "jailbreak": "label"},
|
||||||
|
filter_fn=lambda row: row["label"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cache_to_disk()
|
||||||
|
def load_dataset_v4():
|
||||||
|
return load_dataset_general(
|
||||||
|
dataset_name="notrichardren/refuse-to-answer-prompts",
|
||||||
|
column_mappings={"prompt": "claim", "jailbreak": "label"},
|
||||||
|
filter_fn=lambda row: row["label"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cache_to_disk()
|
||||||
|
def load_dataset_v3():
|
||||||
|
return load_dataset_general(
|
||||||
|
dataset_name="rubend18/ChatGPT-Jailbreak-Prompts",
|
||||||
|
column_mappings={"prompt": "Prompt"},
|
||||||
|
filter_fn=lambda row: row["label"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cache_to_disk()
|
||||||
|
def load_dataset_v6():
|
||||||
|
dataset = load_dataset("markush1/LLM-Jailbreak-Classifier")
|
||||||
|
bad_prompts = [
|
||||||
p
|
p
|
||||||
for p in transformed_df[prompt_col].tolist()
|
for p, c in zip(dataset["train"]["prompt"], dataset["train"]["classification"])
|
||||||
if p and isinstance(p, (str, int, float))
|
if c == "jailbreak"
|
||||||
]
|
]
|
||||||
return create_probe_dataset(name, prompts, metadata)
|
return ProbeDataset(
|
||||||
except Exception as e:
|
dataset_name="markush1/LLM-Jailbreak-Classifier",
|
||||||
logger.error(f"Error loading dataset {name}: {e}")
|
metadata={},
|
||||||
return create_probe_dataset(name, [], {"error": str(e)})
|
prompts=bad_prompts,
|
||||||
|
tokens=count_words_in_list(bad_prompts),
|
||||||
|
approx_cost=0.0,
|
||||||
# Dataset-specific configurations
|
|
||||||
DATASET_CONFIGS = {
|
|
||||||
"ShawnMenz/DAN_jailbreak": {
|
|
||||||
"mappings": {"prompt": "prompt"},
|
|
||||||
"filter_fn": lambda row: row["jailbreak"],
|
|
||||||
},
|
|
||||||
"deepset/prompt-injections": {
|
|
||||||
"mappings": {"prompt": "text"},
|
|
||||||
"filter_fn": lambda row: row["label"],
|
|
||||||
},
|
|
||||||
"notrichardren/refuse-to-answer-prompts": {
|
|
||||||
"mappings": {"prompt": "claim"},
|
|
||||||
"filter_fn": lambda row: row["label"],
|
|
||||||
},
|
|
||||||
"rubend18/ChatGPT-Jailbreak-Prompts": {
|
|
||||||
"mappings": {"prompt": "Prompt"},
|
|
||||||
"filter_fn": lambda row: row["label"],
|
|
||||||
},
|
|
||||||
"markush1/LLM-Jailbreak-Classifier": {
|
|
||||||
"mappings": {"prompt": "prompt"},
|
|
||||||
"filter_fn": lambda row: row["classification"] == "jailbreak",
|
|
||||||
},
|
|
||||||
"ShawnMenz/jailbreak_sft_rm_ds": {
|
|
||||||
"url": "hf://datasets/ShawnMenz/jailbreak_sft_rm_ds/jailbreak_sft_rm_ds.csv",
|
|
||||||
"mappings": {"prompt": "prompt"},
|
|
||||||
"filter_fn": lambda row: row["jailbreak"] == "jailbreak",
|
|
||||||
},
|
|
||||||
"verazuo/jailbreak_llms/2023_05_07": {
|
|
||||||
"url": "https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_05_07.csv",
|
|
||||||
"mappings": {"prompt": "prompt"},
|
|
||||||
"filter_fn": lambda row: bool(row["jailbreak"]),
|
|
||||||
},
|
|
||||||
"verazuo/jailbreak_llms/2023_12_25": {
|
|
||||||
"url": "https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_12_25.csv",
|
|
||||||
"mappings": {"prompt": "prompt"},
|
|
||||||
"filter_fn": lambda row: bool(row["jailbreak"]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Additional generic dataset configurations
|
|
||||||
DATASET_CONFIGS_GENERICS = {
|
|
||||||
"simonycl/aya-23-8B_advbench_jailbreak": {"mappings": {"prompt": "prompt"}},
|
|
||||||
"acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b": {},
|
|
||||||
"karanxa/dolphin-jailbreak-finetuning-dataset": {"mappings": {"prompt": "text"}},
|
|
||||||
"karanxa/llama-2-jailbreak-dataset": {"mappings": {"prompt": "text"}},
|
|
||||||
"karanxa/llama2-uncensored-jailbreak-dataset-finetuning": {
|
|
||||||
"mappings": {"prompt": "text"}
|
|
||||||
},
|
|
||||||
"liuyanchen1015/Llama-3.2-1B_jailbreak_responses": {
|
|
||||||
"mappings": {"prompt": "jailbreak_prompt_text"}
|
|
||||||
},
|
|
||||||
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses": {
|
|
||||||
"mappings": {"prompt": "jailbreak_prompt_text"}
|
|
||||||
},
|
|
||||||
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment": {
|
|
||||||
"mappings": {"prompt": "jailbreak_prompt_text"}
|
|
||||||
},
|
|
||||||
"jackhhao/jailbreak-classification": {"mappings": {"prompt": "prompt"}},
|
|
||||||
"walledai/JailbreakBench": {"mappings": {"prompt": "prompt"}},
|
|
||||||
"walledai/JailbreakHub": {"mappings": {"prompt": "prompt"}},
|
|
||||||
"Granther/evil-jailbreak": {"mappings": {"prompt": "text"}},
|
|
||||||
"sevdeawesome/jailbreak_success": {"mappings": {"prompt": "jailbreak_prompt_text"}},
|
|
||||||
"IDA-SERICS/Disaster-tweet-jailbreaking": {"mappings": {"prompt": "prompt_attack"}},
|
|
||||||
"GeorgeDaDude/Jailbreak_Complete_DS_labeled": {"mappings": {"prompt": "question"}},
|
|
||||||
"dayone3nder/jailbreak_prompt_JBB_sft_trainset": {"mappings": {"prompt": "prompt"}},
|
|
||||||
"dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset": {
|
|
||||||
"mappings": {"prompt": "prompt"}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Dataset factory
|
|
||||||
def create_dataset_loader(name: str, config: dict[str, Any]) -> DatasetLoader:
|
|
||||||
"""Create a dataset loader from configuration."""
|
|
||||||
return partial(
|
|
||||||
load_dataset_generic,
|
|
||||||
name=name,
|
|
||||||
mappings=config.get("mappings"),
|
|
||||||
filter_fn=config.get("filter_fn"),
|
|
||||||
url=config.get("url"),
|
|
||||||
metadata={"source": name, "config": str(config)},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Specialized loaders
|
|
||||||
@cache_to_disk()
|
|
||||||
def load_multi_dataset(name: str, sub_datasets: list[str]) -> ProbeDataset:
|
|
||||||
"""Load and combine multiple sub-datasets."""
|
|
||||||
prompts = []
|
|
||||||
for sub in sub_datasets:
|
|
||||||
try:
|
|
||||||
dataset = load_dataset(name, sub)
|
|
||||||
if "query" in dataset["train"].features:
|
|
||||||
prompts.extend(dataset["train"]["query"])
|
|
||||||
else:
|
|
||||||
logger.warning(f"No 'query' column in {name}/{sub}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading {name}/{sub}: {e}")
|
|
||||||
|
|
||||||
return create_probe_dataset(
|
|
||||||
f"{name}_combined", prompts, {"source": name, "sub_datasets": sub_datasets}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@cache_to_disk()
|
@cache_to_disk()
|
||||||
def load_jailbreak_v28k() -> ProbeDataset:
|
def load_dataset_v7():
|
||||||
"""Load JailBreakV-28K dataset."""
|
splits = {
|
||||||
try:
|
"mini_JailBreakV_28K": "JailBreakV_28K/mini_JailBreakV_28K.csv",
|
||||||
|
"JailBreakV_28K": "JailBreakV_28K/JailBreakV_28K.csv",
|
||||||
|
}
|
||||||
df = pd.read_csv(
|
df = pd.read_csv(
|
||||||
"hf://datasets/JailbreakV-28K/JailBreakV-28k/JailBreakV_28K.csv"
|
"hf://datasets/JailbreakV-28K/JailBreakV-28k/" + splits["JailBreakV_28K"]
|
||||||
)
|
)
|
||||||
prompts = df["jailbreak_query"].tolist()
|
bad_prompts = df["jailbreak_query"].tolist()
|
||||||
return create_probe_dataset(
|
print(df.shape)
|
||||||
"JailbreakV-28K/JailBreakV-28k",
|
return ProbeDataset(
|
||||||
prompts,
|
dataset_name="JailbreakV-28K/JailBreakV-28k",
|
||||||
{"source": "JailbreakV-28K/JailBreakV-28k"},
|
metadata={},
|
||||||
|
prompts=bad_prompts,
|
||||||
|
tokens=count_words_in_list(bad_prompts),
|
||||||
|
approx_cost=0.0,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading JailbreakV-28K: {e}")
|
|
||||||
return create_probe_dataset("JailbreakV-28K/JailBreakV-28k", [])
|
|
||||||
|
|
||||||
|
|
||||||
@cache_to_disk(1)
|
@cache_to_disk()
|
||||||
def file_dataset(file) -> list[str]:
|
def load_dataset_v8():
|
||||||
prompts = []
|
df = pd.read_csv(
|
||||||
|
"hf://datasets/ShawnMenz/jailbreak_sft_rm_ds/jailbreak_sft_rm_ds.csv",
|
||||||
|
names=["jailbreak", "prompt"],
|
||||||
|
)
|
||||||
|
filtered = df[df["jailbreak"] == "jailbreak"]["prompt"].tolist()
|
||||||
|
return ProbeDataset(
|
||||||
|
dataset_name="JailbreakV-28K/JailBreakV-28k",
|
||||||
|
metadata={},
|
||||||
|
prompts=filtered,
|
||||||
|
tokens=count_words_in_list(filtered),
|
||||||
|
approx_cost=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cache_to_disk()
|
||||||
|
def load_dataset_v5():
|
||||||
|
ds = []
|
||||||
|
for c in [
|
||||||
|
"AdvBench",
|
||||||
|
"ForbiddenQuestion",
|
||||||
|
"MJP",
|
||||||
|
"MaliciousInstruct",
|
||||||
|
"QuestionList",
|
||||||
|
]:
|
||||||
|
dataset = load_dataset("Lemhf14/EasyJailbreak_Datasets", c)
|
||||||
|
bad_prompts = dataset["train"]["query"]
|
||||||
|
ds.extend(bad_prompts)
|
||||||
|
|
||||||
|
return ProbeDataset(
|
||||||
|
dataset_name="Lemhf14/EasyJailbreak_Datasets",
|
||||||
|
metadata={},
|
||||||
|
prompts=ds,
|
||||||
|
tokens=count_words_in_list(ds),
|
||||||
|
approx_cost=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cache_to_disk()
|
||||||
|
def load_generic_csv(url, name, column="prompt", predicator=None):
|
||||||
|
r = httpx.get(url)
|
||||||
|
content = r.content
|
||||||
|
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
|
||||||
|
logger.info(f"Loaded {len(df)} prompts from {url}")
|
||||||
|
filtered_prompts = df[df.apply(predicator, axis=1)][column].tolist()
|
||||||
|
return ProbeDataset(
|
||||||
|
dataset_name=name,
|
||||||
|
metadata={},
|
||||||
|
prompts=filtered_prompts,
|
||||||
|
tokens=count_words_in_list(filtered_prompts),
|
||||||
|
approx_cost=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_prompts(dataset_names, budget, tools_inbox=None, options=[]):
|
||||||
|
# ## Datasets used and cleaned:
|
||||||
|
# markush1/LLM-Jailbreak-Classifier
|
||||||
|
# 1. Open-Orca/OpenOrca
|
||||||
|
# 2. ShawnMenz/DAN_jailbreak
|
||||||
|
# 3. EddyLuo/JailBreakV_28K
|
||||||
|
# 4. https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/jailbreak_prompts.csv
|
||||||
|
|
||||||
|
dataset_map = {
|
||||||
|
"ShawnMenz/DAN_jailbreak": load_dataset_v1,
|
||||||
|
"deepset/prompt-injections": load_dataset_v2,
|
||||||
|
"notrichardren/refuse-to-answer-prompts": load_dataset_v4,
|
||||||
|
"rubend18/ChatGPT-Jailbreak-Prompts": load_dataset_v3,
|
||||||
|
"Lemhf14/EasyJailbreak_Datasets": load_dataset_v5,
|
||||||
|
"markush1/LLM-Jailbreak-Classifier": load_dataset_v6,
|
||||||
|
"JailbreakV-28K/JailBreakV-28k": load_dataset_v7,
|
||||||
|
"ShawnMenz/jailbreak_sft_rm_ds": load_dataset_v8,
|
||||||
|
"verazuo/jailbreak_llms/2023_05_07": lambda: load_generic_csv(
|
||||||
|
url="https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_05_07.csv",
|
||||||
|
name="verazuo/jailbreak_llms/2023_05_07",
|
||||||
|
column="prompt",
|
||||||
|
predicator=lambda x: bool(x["jailbreak"]),
|
||||||
|
),
|
||||||
|
"verazuo/jailbreak_llms/2023_12_25.csv": lambda: load_generic_csv(
|
||||||
|
url="https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_12_25.csv.csv",
|
||||||
|
name="verazuo/jailbreak_llms/2023_12_25.csv",
|
||||||
|
column="prompt",
|
||||||
|
predicator=lambda x: bool(x["jailbreak"]),
|
||||||
|
),
|
||||||
|
"Custom CSV": load_local_csv,
|
||||||
|
}
|
||||||
|
dataset_map.update(dataset_map_generics)
|
||||||
|
|
||||||
|
group = []
|
||||||
|
for dataset_name in dataset_names:
|
||||||
|
if dataset_name in dataset_map:
|
||||||
|
logger.info(f"Loading {dataset_name}")
|
||||||
try:
|
try:
|
||||||
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
|
group.append(dataset_map[dataset_name]())
|
||||||
if "prompt" in df.columns:
|
|
||||||
prompts = df["prompt"].tolist()
|
|
||||||
else:
|
|
||||||
logger.warning(f"File {file} lacks a suitable prompt column")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error reading {file}: {e}")
|
logger.error(f"Error loading {dataset_name}: {e}")
|
||||||
return prompts
|
|
||||||
|
dynamic_datasets = {
|
||||||
|
"AgenticBackend": lambda opts: dataset_from_iterator(
|
||||||
|
"AgenticBackend",
|
||||||
|
fine_tuned.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
|
||||||
|
lazy=True,
|
||||||
|
),
|
||||||
|
"Steganography": lambda opts: Stenography(group),
|
||||||
|
"llm-adaptive-attacks": lambda opts: dataset_from_iterator(
|
||||||
|
"llm-adaptive-attacks",
|
||||||
|
adaptive_attacks.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
|
||||||
|
),
|
||||||
|
"Garak": lambda opts: dataset_from_iterator(
|
||||||
|
"Garak",
|
||||||
|
garak_tool.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
|
||||||
|
lazy=True,
|
||||||
|
),
|
||||||
|
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
|
||||||
|
"Reinforcement Learning Optimization",
|
||||||
|
rl_model.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
|
||||||
|
lazy=True,
|
||||||
|
),
|
||||||
|
"InspectAI": lambda opts: dataset_from_iterator(
|
||||||
|
"InspectAI",
|
||||||
|
inspect_ai_tool.Module(group, tools_inbox=tools_inbox).apply(),
|
||||||
|
lazy=True,
|
||||||
|
),
|
||||||
|
"GPT fuzzer": lambda opts: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
dynamic_groups = []
|
||||||
|
options = options or [{} for _ in dataset_names]
|
||||||
|
for dataset_name, opts in zip(dataset_names, options):
|
||||||
|
if dataset_name in dynamic_datasets:
|
||||||
|
logger.info(f"Loading {dataset_name}")
|
||||||
|
|
||||||
|
ds = dynamic_datasets[dataset_name](opts)
|
||||||
|
|
||||||
|
for g in ds:
|
||||||
|
dynamic_groups.append(g)
|
||||||
|
return group + dynamic_groups
|
||||||
|
|
||||||
|
|
||||||
def load_local_csv() -> ProbeDataset:
|
class Stenography:
|
||||||
"""Load prompts from local CSV files."""
|
fn_library = {
|
||||||
os.makedirs("./datasets", exist_ok=True)
|
|
||||||
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
|
|
||||||
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
|
|
||||||
|
|
||||||
prompts = []
|
|
||||||
for file in csv_files:
|
|
||||||
prompts.extend(file_dataset(file))
|
|
||||||
return create_probe_dataset("Local CSV", prompts, {"src": str(csv_files)})
|
|
||||||
|
|
||||||
|
|
||||||
def load_csv(file: str) -> ProbeDataset:
|
|
||||||
"""Load prompts from local CSV files."""
|
|
||||||
prompts = file_dataset(file)
|
|
||||||
return create_probe_dataset(f"fs://{file}", prompts, {"src": str(file)})
|
|
||||||
|
|
||||||
|
|
||||||
def load_local_csv_files() -> list[ProbeDataset]:
|
|
||||||
"""Load prompts from local CSV files and return a list of ProbeDataset objects."""
|
|
||||||
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
|
|
||||||
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
|
|
||||||
|
|
||||||
datasets = []
|
|
||||||
|
|
||||||
for file in csv_files:
|
|
||||||
datasets.append(create_probe_dataset(file, file_dataset(file), {"src": file}))
|
|
||||||
return datasets
|
|
||||||
|
|
||||||
|
|
||||||
# Stenography transformer
|
|
||||||
class StenographyTransformer:
|
|
||||||
"""Apply stenography transformations to datasets."""
|
|
||||||
|
|
||||||
TRANSFORMATIONS = {
|
|
||||||
"rot5": stenography_fn.rot5,
|
"rot5": stenography_fn.rot5,
|
||||||
"rot13": stenography_fn.rot13,
|
"rot13": stenography_fn.rot13,
|
||||||
"base64": stenography_fn.base64_encode,
|
"base64": stenography_fn.base64_encode,
|
||||||
@@ -306,172 +306,151 @@ class StenographyTransformer:
|
|||||||
"zigzag_obfuscation": stenography_fn.zigzag_obfuscation,
|
"zigzag_obfuscation": stenography_fn.zigzag_obfuscation,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, datasets: list[ProbeDataset]):
|
def __init__(self, prompt_groups: [ProbeDataset]):
|
||||||
self.datasets = datasets
|
self.prompt_groups = prompt_groups
|
||||||
|
|
||||||
def transform(self) -> Iterator[ProbeDataset]:
|
def __iter__(self):
|
||||||
"""Generate transformed datasets."""
|
return self.apply()
|
||||||
if not self.datasets:
|
|
||||||
logger.warning("No datasets provided to StenographyTransformer")
|
|
||||||
return
|
|
||||||
|
|
||||||
for dataset in self.datasets:
|
def apply(self):
|
||||||
if not dataset.prompts:
|
for prompt_group in self.prompt_groups:
|
||||||
logger.warning(
|
size = len(prompt_group.prompts)
|
||||||
f"Dataset {dataset.dataset_name} has no prompts to transform"
|
for name, fn in self.fn_library.items():
|
||||||
)
|
logger.info(f"Applying {name} to {prompt_group.dataset_name}")
|
||||||
continue
|
selected_prompts = random.sample(prompt_group.prompts, int(size * 0.25))
|
||||||
|
prompts = [
|
||||||
# Convert async generator (if any) to a list
|
f"Follow the instructions in {name}: {fn(p)}"
|
||||||
prompts = (
|
for p in selected_prompts
|
||||||
list(dataset.prompts) if hasattr(dataset.prompts, "__iter__") else []
|
|
||||||
)
|
|
||||||
size = len(prompts)
|
|
||||||
|
|
||||||
if size == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
sample_size = max(1, int(size * 0.25)) # At least 1 sample
|
|
||||||
|
|
||||||
for name, fn in self.TRANSFORMATIONS.items():
|
|
||||||
logger.info(f"Applying {name} to {dataset.dataset_name}")
|
|
||||||
sampled = random.sample(prompts, min(sample_size, size))
|
|
||||||
try:
|
|
||||||
transformed = [
|
|
||||||
f"Follow the instructions in {name}: {fn(str(p))}"
|
|
||||||
for p in sampled
|
|
||||||
]
|
]
|
||||||
transformed_dataset = create_probe_dataset(
|
yield ProbeDataset(
|
||||||
f"stenography.{name}({dataset.dataset_name})",
|
dataset_name=f"stenography.{name}({prompt_group.dataset_name})",
|
||||||
transformed,
|
metadata={},
|
||||||
{"source": dataset.dataset_name, "transformation": name},
|
prompts=prompts,
|
||||||
)
|
tokens=count_words_in_list(prompts),
|
||||||
yield transformed_dataset
|
approx_cost=0.0,
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error applying {name} to {dataset.dataset_name}: {e}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def dataset_from_iterator(
|
@lru_cache
|
||||||
name: str, iterator: Iterator[str], lazy: bool = False
|
def load_local_csv() -> ProbeDataset:
|
||||||
) -> list[ProbeDataset]:
|
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
|
||||||
"""Convert an iterator into a list of ProbeDataset objects."""
|
logger.info(f"Found {len(csv_files)} CSV files")
|
||||||
|
logger.info(f"CSV files: {csv_files}")
|
||||||
|
prompt_list = []
|
||||||
|
|
||||||
|
for file in csv_files:
|
||||||
try:
|
try:
|
||||||
|
df = pd.read_csv(file)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading {file}: {e}")
|
||||||
|
continue
|
||||||
|
# Check if 'prompt' column exists
|
||||||
|
if "prompt" in df.columns:
|
||||||
|
prompt_list.extend(df["prompt"].tolist())
|
||||||
|
else:
|
||||||
|
logger.warning(f"File {file} does not contain a 'prompt' column")
|
||||||
|
return ProbeDataset(
|
||||||
|
dataset_name="Local CSV",
|
||||||
|
metadata={"src": str(csv_files)},
|
||||||
|
prompts=prompt_list,
|
||||||
|
tokens=count_words_in_list(prompt_list),
|
||||||
|
approx_cost=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_from_iterator(name: str, iterator, lazy=False) -> list:
|
||||||
|
"""Convert an iterator into a list of prompts and create a ProbeDataset
|
||||||
|
object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the dataset.
|
||||||
|
iterator (iterator): An iterator that generates prompts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list containing a single ProbeDataset object.
|
||||||
|
"""
|
||||||
prompts = list(iterator) if not lazy else iterator
|
prompts = list(iterator) if not lazy else iterator
|
||||||
tokens = sum(len(str(s).split()) for s in prompts) if not lazy else 0
|
tokens = count_words_in_list(prompts) if not lazy else 0
|
||||||
dataset = ProbeDataset(
|
dataset = ProbeDataset(
|
||||||
dataset_name=name,
|
dataset_name=name,
|
||||||
metadata={"source": name, "lazy": lazy},
|
metadata={},
|
||||||
prompts=prompts,
|
prompts=prompts,
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
approx_cost=0.0,
|
approx_cost=0.0,
|
||||||
lazy=lazy,
|
lazy=lazy,
|
||||||
)
|
)
|
||||||
return [dataset]
|
return [dataset]
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating dataset from iterator {name}: {e}")
|
|
||||||
return [create_probe_dataset(name, [], {"error": str(e)})]
|
|
||||||
|
|
||||||
|
|
||||||
# Main dataset preparation
|
# TODO: refactor this abstraction
|
||||||
def prepare_prompts(
|
|
||||||
dataset_names: list[str],
|
dataset_map_generics = {
|
||||||
budget: float,
|
"simonycl/aya-23-8B_advbench_jailbreak": lambda: load_dataset_general(
|
||||||
tools_inbox=None,
|
dataset_name="simonycl/aya-23-8B_advbench_jailbreak",
|
||||||
options: list[dict] = None,
|
column_mappings={"prompt": "prompt"},
|
||||||
) -> list[ProbeDataset]:
|
),
|
||||||
"""Prepare datasets based on names and options."""
|
"acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b": lambda: load_dataset_general(
|
||||||
# Base dataset loaders
|
dataset_name="acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b"
|
||||||
logger.info(f"Preparing datasets: {dataset_names}")
|
),
|
||||||
dataset_loaders = {
|
"karanxa/dolphin-jailbreak-finetuning-dataset": lambda: load_dataset_general(
|
||||||
**{k: create_dataset_loader(k, v) for k, v in DATASET_CONFIGS.items()},
|
dataset_name="karanxa/dolphin-jailbreak-finetuning-dataset",
|
||||||
**{k: create_dataset_loader(k, v) for k, v in DATASET_CONFIGS_GENERICS.items()},
|
column_mappings={"prompt": "text"},
|
||||||
"Lemhf14/EasyJailbreak_Datasets": partial(
|
),
|
||||||
load_multi_dataset,
|
"karanxa/llama-2-jailbreak-dataset": lambda: load_dataset_general(
|
||||||
name="Lemhf14/EasyJailbreak_Datasets",
|
dataset_name="karanxa/llama-2-jailbreak-dataset",
|
||||||
sub_datasets=[
|
column_mappings={"prompt": "text"},
|
||||||
"AdvBench",
|
),
|
||||||
"ForbiddenQuestion",
|
"karanxa/llama2-uncensored-jailbreak-dataset-finetuning": lambda: load_dataset_general(
|
||||||
"MJP",
|
dataset_name="karanxa/llama2-uncensored-jailbreak-dataset-finetuning",
|
||||||
"MaliciousInstruct",
|
column_mappings={"prompt": "text"},
|
||||||
"QuestionList",
|
),
|
||||||
],
|
"liuyanchen1015/Llama-3.2-1B_jailbreak_responses": lambda: load_dataset_general(
|
||||||
|
dataset_name="liuyanchen1015/Llama-3.2-1B_jailbreak_responses",
|
||||||
|
column_mappings={"prompt": "jailbreak_prompt_text"},
|
||||||
|
),
|
||||||
|
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses": lambda: load_dataset_general(
|
||||||
|
dataset_name="liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses",
|
||||||
|
column_mappings={"prompt": "jailbreak_prompt_text"},
|
||||||
|
),
|
||||||
|
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment": lambda: load_dataset_general(
|
||||||
|
dataset_name="liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment",
|
||||||
|
column_mappings={"prompt": "jailbreak_prompt_text"},
|
||||||
|
),
|
||||||
|
"jackhhao/jailbreak-classification": lambda: load_dataset_general(
|
||||||
|
dataset_name="jackhhao/jailbreak-classification",
|
||||||
|
column_mappings={"prompt": "prompt"},
|
||||||
|
),
|
||||||
|
"markush1/LLM-Jailbreak-Classifier": lambda: load_dataset_general(
|
||||||
|
dataset_name="markush1/LLM-Jailbreak-Classifier",
|
||||||
|
column_mappings={"prompt": "prompt"},
|
||||||
|
),
|
||||||
|
"walledai/JailbreakBench": lambda: load_dataset_general(
|
||||||
|
dataset_name="walledai/JailbreakBench", column_mappings={"prompt": "prompt"}
|
||||||
|
),
|
||||||
|
"walledai/JailbreakHub": lambda: load_dataset_general(
|
||||||
|
dataset_name="walledai/JailbreakHub", column_mappings={"prompt": "prompt"}
|
||||||
|
),
|
||||||
|
"Granther/evil-jailbreak": lambda: load_dataset_general(
|
||||||
|
dataset_name="Granther/evil-jailbreak", column_mappings={"prompt": "text"}
|
||||||
|
),
|
||||||
|
"sevdeawesome/jailbreak_success": lambda: load_dataset_general(
|
||||||
|
dataset_name="sevdeawesome/jailbreak_success",
|
||||||
|
column_mappings={"prompt": "jailbreak_prompt_text"},
|
||||||
|
),
|
||||||
|
"IDA-SERICS/Disaster-tweet-jailbreaking": lambda: load_dataset_general(
|
||||||
|
dataset_name="IDA-SERICS/Disaster-tweet-jailbreaking",
|
||||||
|
column_mappings={"prompt": "prompt_attack"},
|
||||||
|
),
|
||||||
|
"GeorgeDaDude/Jailbreak_Complete_DS_labeled": lambda: load_dataset_general(
|
||||||
|
dataset_name="GeorgeDaDude/Jailbreak_Complete_DS_labeled",
|
||||||
|
column_mappings={"prompt": "question"},
|
||||||
|
),
|
||||||
|
"dayone3nder/jailbreak_prompt_JBB_sft_trainset": lambda: load_dataset_general(
|
||||||
|
dataset_name="dayone3nder/jailbreak_prompt_JBB_sft_trainset",
|
||||||
|
column_mappings={"prompt": "prompt"},
|
||||||
|
),
|
||||||
|
"dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset": lambda: load_dataset_general(
|
||||||
|
dataset_name="dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset",
|
||||||
|
column_mappings={"prompt": "prompt"},
|
||||||
),
|
),
|
||||||
"JailbreakV-28K/JailBreakV-28k": load_jailbreak_v28k,
|
|
||||||
"Local CSV": load_local_csv,
|
|
||||||
"Custom CSV": load_local_csv,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Dynamic dataset loaders
|
|
||||||
dynamic_loaders = {
|
|
||||||
"AgenticBackend": lambda opts: dataset_from_iterator(
|
|
||||||
"AgenticBackend",
|
|
||||||
fine_tuned.Module(
|
|
||||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
|
||||||
).apply(),
|
|
||||||
lazy=True,
|
|
||||||
),
|
|
||||||
"Steganography": lambda opts: list(
|
|
||||||
StenographyTransformer(opts["datasets"]).transform()
|
|
||||||
),
|
|
||||||
"llm-adaptive-attacks": lambda opts: dataset_from_iterator(
|
|
||||||
"llm-adaptive-attacks",
|
|
||||||
adaptive_attacks.Module(
|
|
||||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
|
||||||
).apply(),
|
|
||||||
),
|
|
||||||
"Garak": lambda opts: dataset_from_iterator(
|
|
||||||
"Garak",
|
|
||||||
garak_tool.Module(
|
|
||||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
|
||||||
).apply(),
|
|
||||||
lazy=True,
|
|
||||||
),
|
|
||||||
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
|
|
||||||
"Reinforcement Learning Optimization",
|
|
||||||
rl_model.Module(
|
|
||||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
|
||||||
).apply(),
|
|
||||||
lazy=True,
|
|
||||||
),
|
|
||||||
"InspectAI": lambda opts: dataset_from_iterator(
|
|
||||||
"InspectAI",
|
|
||||||
inspect_ai_tool.Module([], tools_inbox=tools_inbox).apply(),
|
|
||||||
lazy=True,
|
|
||||||
),
|
|
||||||
"GPT fuzzer": lambda opts: [],
|
|
||||||
}
|
|
||||||
|
|
||||||
datasets = []
|
|
||||||
options = options or [dict(datasets=datasets) for _ in dataset_names]
|
|
||||||
|
|
||||||
# Load base datasets
|
|
||||||
for name, opts in zip(dataset_names, options):
|
|
||||||
if name not in dataset_loaders:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
datasets.append(dataset_loaders[name]())
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading {name}: {e}")
|
|
||||||
|
|
||||||
# Load dynamic datasets and apply transformations
|
|
||||||
for name, opts in zip(dataset_names, options):
|
|
||||||
if name not in dynamic_loaders:
|
|
||||||
continue
|
|
||||||
logger.info(f"Loading dynamic dataset {name} {opts}")
|
|
||||||
opts["datasets"] = datasets
|
|
||||||
try:
|
|
||||||
dynamic_result = dynamic_loaders[name](opts)
|
|
||||||
datasets.extend(dynamic_result)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Error loading dynamic {name}: {e}")
|
|
||||||
|
|
||||||
# Load csv datasets and apply transformations
|
|
||||||
for name, opts in zip(dataset_names, options):
|
|
||||||
if not name.endswith(".csv"):
|
|
||||||
continue
|
|
||||||
logger.info(f"Loading csv dataset {name} {opts}")
|
|
||||||
datasets.append(load_csv(name))
|
|
||||||
|
|
||||||
return datasets
|
|
||||||
|
|||||||
@@ -3,8 +3,7 @@ import io
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from loguru import logger
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
url = "https://raw.githubusercontent.com/tml-epfl/llm-adaptive-attacks/main/harmful_behaviors/harmful_behaviors_pair.csv"
|
url = "https://raw.githubusercontent.com/tml-epfl/llm-adaptive-attacks/main/harmful_behaviors/harmful_behaviors_pair.csv"
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,7 @@ import os
|
|||||||
import uuid as U
|
import uuid as U
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from loguru import logger
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
|
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
|
||||||
|
|
||||||
@@ -66,7 +65,7 @@ class Module:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def fetch_prompts(self) -> list[str]:
|
async def fetch_prompts(self) -> list[str]:
|
||||||
api_url = "https://mcp.metaheuristic.co/infer"
|
api_url = "https://edge.metaheuristic.co/infer"
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {AUTH_TOKEN}",
|
"Authorization": f"Bearer {AUTH_TOKEN}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
from agentic_security.logutils import logger
|
from loguru import logger
|
||||||
|
|
||||||
# TODO: add probes modules
|
# TODO: add probes modules
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from agentic_security.logutils import logger
|
from loguru import logger
|
||||||
|
|
||||||
inspect_ai_task = (
|
inspect_ai_task = (
|
||||||
__file__.replace("inspect_ai_tool.py", "inspect_ai_task.py")
|
__file__.replace("inspect_ai_tool.py", "inspect_ai_task.py")
|
||||||
|
|||||||
@@ -8,8 +8,7 @@ from typing import Deque
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
from loguru import logger
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
|
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
|
||||||
|
|
||||||
@@ -216,7 +215,7 @@ class Module:
|
|||||||
self.run_id = U.uuid4().hex
|
self.run_id = U.uuid4().hex
|
||||||
self.batch_size = self.opts.get("batch_size", 500)
|
self.batch_size = self.opts.get("batch_size", 500)
|
||||||
self.rl_model = CloudRLPromptSelector(
|
self.rl_model = CloudRLPromptSelector(
|
||||||
prompt_groups, "https://mcp.metaheuristic.co", run_id=self.run_id
|
prompt_groups, "https://edge.metaheuristic.co", run_id=self.run_id
|
||||||
)
|
)
|
||||||
|
|
||||||
async def apply(self):
|
async def apply(self):
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ def mock_requests() -> Mock:
|
|||||||
def mock_rl_selector() -> Mock:
|
def mock_rl_selector() -> Mock:
|
||||||
return CloudRLPromptSelector(
|
return CloudRLPromptSelector(
|
||||||
dataset_prompts,
|
dataset_prompts,
|
||||||
api_url="https://mcp.metaheuristic.co",
|
api_url="https://edge.metaheuristic.co",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ class TestCloudRLPromptSelector:
|
|||||||
def test_select_next_prompt_success_service(self, dataset_prompts):
|
def test_select_next_prompt_success_service(self, dataset_prompts):
|
||||||
selector = CloudRLPromptSelector(
|
selector = CloudRLPromptSelector(
|
||||||
dataset_prompts,
|
dataset_prompts,
|
||||||
api_url="https://mcp.metaheuristic.co",
|
api_url="https://edge.metaheuristic.co",
|
||||||
)
|
)
|
||||||
next_prompt = selector.select_next_prompt(
|
next_prompt = selector.select_next_prompt(
|
||||||
"How does RL work?", passed_guard=True
|
"How does RL work?", passed_guard=True
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from cache_to_disk import cache_to_disk # noqa
|
from cache_to_disk import cache_to_disk
|
||||||
|
|
||||||
|
|
||||||
# TODO: refactor this class to use from .data
|
# TODO: refactor this class to use from .data
|
||||||
@@ -22,7 +22,7 @@ class ProbeDataset:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# @cache_to_disk(n_days_to_cache=1)
|
@cache_to_disk()
|
||||||
def load_dataset_generic(name, getter=lambda x: x["train"]["prompt"]):
|
def load_dataset_generic(name, getter=lambda x: x["train"]["prompt"]):
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -3,7 +3,7 @@ from unittest.mock import patch
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from agentic_security.refusal_classifier.model import RefusalClassifier
|
from .model import RefusalClassifier
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -7,61 +7,9 @@ import pandas as pd
|
|||||||
from matplotlib.cm import ScalarMappable
|
from matplotlib.cm import ScalarMappable
|
||||||
from matplotlib.colors import LinearSegmentedColormap, Normalize
|
from matplotlib.colors import LinearSegmentedColormap, Normalize
|
||||||
|
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
from .primitives import Table
|
def plot_security_report(table):
|
||||||
|
|
||||||
|
|
||||||
def plot_security_report(table: Table) -> io.BytesIO:
|
|
||||||
"""
|
|
||||||
Generates a polar plot representing the security report based on the given data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
table (Table): The input data table containing security metrics.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
io.BytesIO: A buffer containing the generated plot image in PNG format.
|
|
||||||
Returns an empty buffer in case of an error.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return _plot_security_report(table=table)
|
|
||||||
except (TypeError, ValueError, OverflowError, IndexError, Exception) as e:
|
|
||||||
logger.error(f"Error in generating the security report: {e} {table}")
|
|
||||||
return io.BytesIO()
|
|
||||||
|
|
||||||
|
|
||||||
def generate_identifiers(data: pd.DataFrame) -> list[str]:
|
|
||||||
"""
|
|
||||||
Generates unique identifiers for the given dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data (pd.DataFrame): A pandas DataFrame containing security-related data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: A list of generated identifiers. Returns a list with an empty string in case of an error.
|
|
||||||
"""
|
|
||||||
return _generate_identifiers(data=data)
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_security_report(table: Table) -> io.BytesIO:
|
|
||||||
"""
|
|
||||||
Generates a polar plot-based security report visualizing the failure rates for different modules.
|
|
||||||
|
|
||||||
This function processes the input data, sorts it by failure rate, and generates a polar plot
|
|
||||||
where each bar represents the failure rate for a specific module. The plot includes identifiers,
|
|
||||||
color-coding based on token count, failure rate values on the bars, and a table listing the modules
|
|
||||||
and their corresponding failure rates.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
table (Table): A table-like structure (e.g., pandas DataFrame) containing security report data
|
|
||||||
with columns for failure rate, tokens, and modules.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
io.BytesIO: A buffer containing the generated plot image in PNG format.
|
|
||||||
"""
|
|
||||||
# Data preprocessing
|
# Data preprocessing
|
||||||
logger.info("Data preprocessing started.")
|
|
||||||
|
|
||||||
data = pd.DataFrame(table)
|
data = pd.DataFrame(table)
|
||||||
|
|
||||||
# Sort by failure rate and reset index
|
# Sort by failure rate and reset index
|
||||||
@@ -72,10 +20,10 @@ def _plot_security_report(table: Table) -> io.BytesIO:
|
|||||||
fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={"projection": "polar"})
|
fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={"projection": "polar"})
|
||||||
fig.set_facecolor("#f0f0f0")
|
fig.set_facecolor("#f0f0f0")
|
||||||
ax.set_facecolor("#f0f0f0")
|
ax.set_facecolor("#f0f0f0")
|
||||||
logger.info("Plot setup complete.")
|
|
||||||
|
|
||||||
# Styling parameters
|
# Styling parameters
|
||||||
colors = ["#6C5B7B", "#C06C84", "#F67280", "#F8B195"][::-1] # Pastel palette
|
colors = ["#6C5B7B", "#C06C84", "#F67280", "#F8B195"][::-1] # Pastel palette
|
||||||
|
# colors = ["#440154", "#3b528b", "#21908c", "#5dc863"] # Viridis-inspired palette
|
||||||
cmap = LinearSegmentedColormap.from_list("custom", colors, N=256)
|
cmap = LinearSegmentedColormap.from_list("custom", colors, N=256)
|
||||||
norm = Normalize(vmin=data["tokens"].min(), vmax=data["tokens"].max())
|
norm = Normalize(vmin=data["tokens"].min(), vmax=data["tokens"].max())
|
||||||
|
|
||||||
@@ -126,10 +74,7 @@ def _plot_security_report(table: Table) -> io.BytesIO:
|
|||||||
|
|
||||||
# Title and caption
|
# Title and caption
|
||||||
fig.suptitle(
|
fig.suptitle(
|
||||||
"Security Report for Different Modules",
|
"Security Report for Different Modules", fontsize=16, fontweight="bold", y=1.02
|
||||||
fontsize=16,
|
|
||||||
fontweight="bold",
|
|
||||||
y=1.02,
|
|
||||||
)
|
)
|
||||||
caption = "Report generated by https://github.com/msoedov/agentic_security"
|
caption = "Report generated by https://github.com/msoedov/agentic_security"
|
||||||
fig.text(
|
fig.text(
|
||||||
@@ -167,12 +112,17 @@ def _plot_security_report(table: Table) -> io.BytesIO:
|
|||||||
data["identifier"], data["failureRate"], data["module"]
|
data["identifier"], data["failureRate"], data["module"]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
table = ax.table(cellText=table_data, loc="right", cellLoc="left")
|
table = ax.table(
|
||||||
|
cellText=table_data,
|
||||||
|
loc="right",
|
||||||
|
cellLoc="left",
|
||||||
|
)
|
||||||
table.auto_set_font_size(False)
|
table.auto_set_font_size(False)
|
||||||
table.set_fontsize(8)
|
table.set_fontsize(8)
|
||||||
|
|
||||||
# Adjust table style
|
# Adjust table style
|
||||||
table.scale(1, 0.7)
|
table.scale(1, 0.7)
|
||||||
|
|
||||||
for (row, col), cell in table.get_celld().items():
|
for (row, col), cell in table.get_celld().items():
|
||||||
cell.set_edgecolor("none")
|
cell.set_edgecolor("none")
|
||||||
cell.set_facecolor("#f0f0f0" if row % 2 == 0 else "#e0e0e0")
|
cell.set_facecolor("#f0f0f0" if row % 2 == 0 else "#e0e0e0")
|
||||||
@@ -182,33 +132,17 @@ def _plot_security_report(table: Table) -> io.BytesIO:
|
|||||||
cell.set_text_props(fontweight="bold")
|
cell.set_text_props(fontweight="bold")
|
||||||
|
|
||||||
# Adjust layout and save
|
# Adjust layout and save
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
plt.savefig(buf, format="png", dpi=300, bbox_inches="tight")
|
plt.savefig(buf, format="png", dpi=300, bbox_inches="tight")
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
buf.seek(0)
|
buf.seek(0)
|
||||||
logger.info("Report successfully generated and saved to buffer.")
|
|
||||||
return buf
|
return buf
|
||||||
|
|
||||||
|
|
||||||
def _generate_identifiers(data: pd.DataFrame) -> list[str]:
|
def generate_identifiers(data):
|
||||||
"""
|
|
||||||
Generates a list of unique identifiers for each row in the given DataFrame.
|
|
||||||
|
|
||||||
The identifiers are based on the English alphabet, with each identifier consisting
|
|
||||||
of a letter followed by a number. The letter represents the "group" of identifiers
|
|
||||||
(using a letter from A to Z) and the number is a counter within that group. For example:
|
|
||||||
- A1, A2, ..., A26, B1, B2, ..., Z1, Z2, ...
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data (pd.DataFrame): The input DataFrame containing data for which identifiers
|
|
||||||
are to be generated.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: A list of unique identifiers as strings, one for each row in the DataFrame.
|
|
||||||
"""
|
|
||||||
data_length = len(data)
|
data_length = len(data)
|
||||||
|
|
||||||
alphabet = string.ascii_uppercase
|
alphabet = string.ascii_uppercase
|
||||||
num_letters = len(alphabet)
|
num_letters = len(alphabet)
|
||||||
|
|
||||||
|
|||||||
@@ -1,397 +0,0 @@
|
|||||||
_SPECS = [
|
|
||||||
"""POST ${SELF_URL}/v1/self-probe
|
|
||||||
Authorization: Bearer XXXXX
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"prompt": "<<PROMPT>>"
|
|
||||||
}
|
|
||||||
|
|
||||||
""",
|
|
||||||
"""POST https://api.openai.com/v1/chat/completions
|
|
||||||
Authorization: Bearer $OPENAI_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "gpt-3.5-turbo",
|
|
||||||
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
|
|
||||||
"temperature": 0.7
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""
|
|
||||||
POST https://api.deepseek.com/chat/completions
|
|
||||||
Authorization: Bearer $DEEPSEEK_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "deepseek-chat",
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "<<PROMPT>>"}
|
|
||||||
],
|
|
||||||
"stream": false
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.replicate.com/v1/models/mistralai/mixtral-8x7b-instruct-v0.1/predictions
|
|
||||||
Authorization: Bearer $APIKEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"input": {
|
|
||||||
"top_k": 50,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"prompt": "Write a bedtime story about neural networks I can read to my toddler",
|
|
||||||
"temperature": 0.6,
|
|
||||||
"max_new_tokens": 1024,
|
|
||||||
"prompt_template": "<s>[INST] <<PROMPT>> [/INST] ",
|
|
||||||
"presence_penalty": 0,
|
|
||||||
"frequency_penalty": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.groq.com/v1/request_manager/text_completion
|
|
||||||
Authorization: Bearer $APIKEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model_id": "codellama-34b",
|
|
||||||
"system_prompt": "You are helpful and concise coding assistant",
|
|
||||||
"user_prompt": "<<PROMPT>>"
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.together.xyz/v1/chat/completions
|
|
||||||
Authorization: Bearer $TOGETHER_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are an expert travel guide"},
|
|
||||||
{"role": "user", "content": "<<PROMPT>>"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST ${SELF_URL}/v1/self-probe-image
|
|
||||||
Authorization: Bearer XXXXX
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What is in this image?",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{<<BASE64_IMAGE>>}"
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
""",
|
|
||||||
"""POST ${SELF_URL}/v1/self-probe-file
|
|
||||||
Authorization: Bearer $API_KEY
|
|
||||||
Content-Type: multipart/form-data
|
|
||||||
|
|
||||||
{
|
|
||||||
"file": "@./sample_audio.m4a",
|
|
||||||
"model": "whisper-large-v3"
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.gemini.com/v1/generate
|
|
||||||
Authorization: Bearer $GEMINI_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "gemini-latest",
|
|
||||||
"prompt": "<<PROMPT>>",
|
|
||||||
"temperature": 0.8,
|
|
||||||
"max_tokens": 150,
|
|
||||||
"top_p": 1.0,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
"presence_penalty": 0
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.anthropic.com/v1/complete
|
|
||||||
Authorization: Bearer $ANTHROPIC_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "claude-v1.3",
|
|
||||||
"prompt": "<<PROMPT>>",
|
|
||||||
"temperature": 0.7,
|
|
||||||
"max_tokens_to_sample": 256,
|
|
||||||
"stop_sequences": ["\n\nHuman:"]
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.cohere.ai/generate
|
|
||||||
Authorization: Bearer $COHERE_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "command-xlarge-nightly",
|
|
||||||
"prompt": "<<PROMPT>>",
|
|
||||||
"max_tokens": 300,
|
|
||||||
"temperature": 0.75,
|
|
||||||
"k": 0,
|
|
||||||
"p": 0.75
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://<<RESOURCE_NAME>>.openai.azure.com/openai/deployments/<<DEPLOYMENT_NAME>>/completions?api-version=2023-06-01-preview
|
|
||||||
Authorization: Bearer $AZURE_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"prompt": "<<PROMPT>>",
|
|
||||||
"max_tokens": 150,
|
|
||||||
"temperature": 0.7,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
"presence_penalty": 0
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.assemblyai.com/v2/transcript
|
|
||||||
Authorization: Bearer $ASSEMBLY_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"audio_url": "<<AUDIO_FILE_URL>>"
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.openrouter.ai/v1/chat/completions
|
|
||||||
Authorization: Bearer $OPENROUTER_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "openrouter-latest",
|
|
||||||
"prompt": "<<PROMPT>>",
|
|
||||||
"temperature": 0.7,
|
|
||||||
"max_tokens": 150,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
"presence_penalty": 0
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
LLM_SPECS = [
|
|
||||||
"""POST ${SELF_URL}/v1/self-probe
|
|
||||||
Authorization: Bearer XXXXX
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"prompt": "<<PROMPT>>"
|
|
||||||
}
|
|
||||||
|
|
||||||
""",
|
|
||||||
"""POST https://api.openai.com/v1/chat/completions
|
|
||||||
Authorization: Bearer $OPENAI_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "gpt-3.5-turbo",
|
|
||||||
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
|
|
||||||
"temperature": 0.7
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""
|
|
||||||
POST https://api.deepseek.com/chat/completions
|
|
||||||
Authorization: Bearer $DEEPSEEK_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "deepseek-chat",
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "<<PROMPT>>"}
|
|
||||||
],
|
|
||||||
"stream": false
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.replicate.com/v1/models/mistralai/mixtral-8x7b-instruct-v0.1/predictions
|
|
||||||
Authorization: Bearer $APIKEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"input": {
|
|
||||||
"top_k": 50,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"prompt": "Write a bedtime story about neural networks I can read to my toddler",
|
|
||||||
"temperature": 0.6,
|
|
||||||
"max_new_tokens": 1024,
|
|
||||||
"prompt_template": "<s>[INST] <<PROMPT>> [/INST] ",
|
|
||||||
"presence_penalty": 0,
|
|
||||||
"frequency_penalty": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.groq.com/v1/request_manager/text_completion
|
|
||||||
Authorization: Bearer $APIKEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model_id": "codellama-34b",
|
|
||||||
"system_prompt": "You are helpful and concise coding assistant",
|
|
||||||
"user_prompt": "<<PROMPT>>"
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.together.xyz/v1/chat/completions
|
|
||||||
Authorization: Bearer $TOGETHER_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are an expert travel guide"},
|
|
||||||
{"role": "user", "content": "<<PROMPT>>"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST ${SELF_URL}/v1/self-probe-image
|
|
||||||
Authorization: Bearer XXXXX
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What is in this image?",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{<<BASE64_IMAGE>>}"
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
""",
|
|
||||||
"""POST ${SELF_URL}/v1/self-probe-file
|
|
||||||
Authorization: Bearer $API_KEY
|
|
||||||
Content-Type: multipart/form-data
|
|
||||||
|
|
||||||
{
|
|
||||||
"file": "@./sample_audio.m4a",
|
|
||||||
"model": "whisper-large-v3"
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.gemini.com/v1/generate
|
|
||||||
Authorization: Bearer $GEMINI_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "gemini-latest",
|
|
||||||
"prompt": "<<PROMPT>>",
|
|
||||||
"temperature": 0.8,
|
|
||||||
"max_tokens": 150,
|
|
||||||
"top_p": 1.0,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
"presence_penalty": 0
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.anthropic.com/v1/complete
|
|
||||||
Authorization: Bearer $ANTHROPIC_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "claude-v1.3",
|
|
||||||
"prompt": "<<PROMPT>>",
|
|
||||||
"temperature": 0.7,
|
|
||||||
"max_tokens_to_sample": 256,
|
|
||||||
"stop_sequences": ["\n\nHuman:"]
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.cohere.ai/generate
|
|
||||||
Authorization: Bearer $COHERE_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "command-xlarge-nightly",
|
|
||||||
"prompt": "<<PROMPT>>",
|
|
||||||
"max_tokens": 300,
|
|
||||||
"temperature": 0.75,
|
|
||||||
"k": 0,
|
|
||||||
"p": 0.75
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://<<RESOURCE_NAME>>.openai.azure.com/openai/deployments/<<DEPLOYMENT_NAME>>/completions?api-version=2023-06-01-preview
|
|
||||||
Authorization: Bearer $AZURE_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"prompt": "<<PROMPT>>",
|
|
||||||
"max_tokens": 150,
|
|
||||||
"temperature": 0.7,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
"presence_penalty": 0
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.assemblyai.com/v2/transcript
|
|
||||||
Authorization: Bearer $ASSEMBLY_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"audio_url": "<<AUDIO_FILE_URL>>"
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"""POST https://api.openrouter.ai/v1/chat/completions
|
|
||||||
Authorization: Bearer $OPENROUTER_API_KEY
|
|
||||||
Content-Type: application/json
|
|
||||||
|
|
||||||
{
|
|
||||||
"model": "openrouter-latest",
|
|
||||||
"prompt": "<<PROMPT>>",
|
|
||||||
"temperature": 0.7,
|
|
||||||
"max_tokens": 150,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
"presence_penalty": 0
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
LLM_CONFIGS = [
|
|
||||||
{
|
|
||||||
"name": "Custom API",
|
|
||||||
"prompts": 40000,
|
|
||||||
"customInstructions": "Requires api spec",
|
|
||||||
"logo": "/icons/myshell.png",
|
|
||||||
},
|
|
||||||
{"name": "Open AI", "prompts": 24000, "logo": "/icons/openai.png"},
|
|
||||||
{"name": "Deepseek v1", "prompts": 24000, "logo": "/icons/deepseek.png"},
|
|
||||||
{"name": "Replicate", "prompts": 40000, "logo": "/icons/replicate.png"},
|
|
||||||
{"name": "Groq", "prompts": 40000, "logo": "/icons/groq.png"},
|
|
||||||
{"name": "Together.ai", "prompts": 40000, "logo": "/icons/together.png"},
|
|
||||||
{
|
|
||||||
"name": "Custom API Image",
|
|
||||||
"prompts": 40000,
|
|
||||||
"customInstructions": "Requires api spec",
|
|
||||||
"modality": "Image",
|
|
||||||
"logo": "/icons/myshell.png",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Custom API Files",
|
|
||||||
"prompts": 40000,
|
|
||||||
"customInstructions": "Requires api spec",
|
|
||||||
"modality": "Files",
|
|
||||||
"logo": "/icons/myshell.png",
|
|
||||||
},
|
|
||||||
{"name": "Gemini", "prompts": 40000, "logo": "/icons/gemini.png"},
|
|
||||||
{"name": "Claude", "prompts": 40000, "logo": "/icons/claude.png"},
|
|
||||||
{"name": "Cohere", "prompts": 40000, "logo": "/icons/cohere.png"},
|
|
||||||
{"name": "Azure OpenAI", "prompts": 40000, "logo": "/icons/azureai.png"},
|
|
||||||
{"name": "assemblyai", "prompts": 40000, "logo": "/icons/myshell.png"},
|
|
||||||
{"name": "OpenRouter.ai", "prompts": 40000, "logo": "/icons/openrouter.png"},
|
|
||||||
]
|
|
||||||
|
|
||||||
LLM_SPECS = [dict(spec=spec, **d) for spec, d in zip(_SPECS, LLM_CONFIGS)]
|
|
||||||
@@ -6,7 +6,6 @@ from fastapi.responses import JSONResponse
|
|||||||
from ..primitives import FileProbeResponse, Probe
|
from ..primitives import FileProbeResponse, Probe
|
||||||
from ..probe_actor.refusal import REFUSAL_MARKS
|
from ..probe_actor.refusal import REFUSAL_MARKS
|
||||||
from ..probe_data import REGISTRY
|
from ..probe_data import REGISTRY
|
||||||
from ._specs import LLM_SPECS
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -74,12 +73,6 @@ async def data_config():
|
|||||||
return [m for m in REGISTRY]
|
return [m for m in REGISTRY]
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/llm-specs", response_model=list)
|
|
||||||
def get_llm_specs():
|
|
||||||
"""Returns the LLM API specifications."""
|
|
||||||
return LLM_SPECS
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/health")
|
@router.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""Health check endpoint."""
|
"""Health check endpoint."""
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ import random
|
|||||||
from asyncio import Event
|
from asyncio import Event
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
from loguru import logger
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
from ..core.app import get_current_run, get_tools_inbox
|
from ..core.app import get_current_run, get_tools_inbox
|
||||||
from ..primitives import CompletionRequest, Settings
|
from ..primitives import CompletionRequest, Settings
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
from collections.abc import Generator
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
@@ -13,11 +11,9 @@ from fastapi import (
|
|||||||
)
|
)
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
from ..core.app import get_stop_event, get_tools_inbox, set_current_run
|
from ..core.app import get_stop_event, get_tools_inbox, set_current_run
|
||||||
from ..dependencies import InMemorySecrets, get_in_memory_secrets
|
from ..dependencies import InMemorySecrets, get_in_memory_secrets
|
||||||
from ..http_spec import InvalidHTTPSpecError, LLMSpec
|
from ..http_spec import LLMSpec
|
||||||
from ..primitives import LLMInfo, Scan
|
from ..primitives import LLMInfo, Scan
|
||||||
from ..probe_actor import fuzzer
|
from ..probe_actor import fuzzer
|
||||||
|
|
||||||
@@ -27,16 +23,9 @@ router = APIRouter()
|
|||||||
@router.post("/verify")
|
@router.post("/verify")
|
||||||
async def verify(
|
async def verify(
|
||||||
info: LLMInfo, secrets: InMemorySecrets = Depends(get_in_memory_secrets)
|
info: LLMInfo, secrets: InMemorySecrets = Depends(get_in_memory_secrets)
|
||||||
) -> dict[str, int | str | float]:
|
):
|
||||||
spec = LLMSpec.from_string(info.spec)
|
spec = LLMSpec.from_string(info.spec)
|
||||||
try:
|
|
||||||
r = await spec.verify()
|
r = await spec.verify()
|
||||||
except InvalidHTTPSpecError as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(e)
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
|
|
||||||
if r.status_code >= 400:
|
if r.status_code >= 400:
|
||||||
raise HTTPException(status_code=r.status_code, detail=r.text)
|
raise HTTPException(status_code=r.status_code, detail=r.text)
|
||||||
return dict(
|
return dict(
|
||||||
@@ -47,7 +36,7 @@ async def verify(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def streaming_response_generator(scan_parameters: Scan) -> Generator[str, Any, None]:
|
def streaming_response_generator(scan_parameters: Scan):
|
||||||
request_factory = LLMSpec.from_string(scan_parameters.llmSpec)
|
request_factory = LLMSpec.from_string(scan_parameters.llmSpec)
|
||||||
set_current_run(request_factory)
|
set_current_run(request_factory)
|
||||||
|
|
||||||
@@ -68,7 +57,7 @@ async def scan(
|
|||||||
scan_parameters: Scan,
|
scan_parameters: Scan,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
|
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
|
||||||
) -> StreamingResponse:
|
):
|
||||||
scan_parameters.with_secrets(secrets)
|
scan_parameters.with_secrets(secrets)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
streaming_response_generator(scan_parameters), media_type="application/json"
|
streaming_response_generator(scan_parameters), media_type="application/json"
|
||||||
@@ -76,7 +65,7 @@ async def scan(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/stop")
|
@router.post("/stop")
|
||||||
async def stop_scan() -> dict[str, str]:
|
async def stop_scan():
|
||||||
get_stop_event().set()
|
get_stop_event().set()
|
||||||
return {"status": "Scan stopped"}
|
return {"status": "Scan stopped"}
|
||||||
|
|
||||||
@@ -90,7 +79,7 @@ async def scan_csv(
|
|||||||
maxBudget: int = Query(10_000),
|
maxBudget: int = Query(10_000),
|
||||||
enableMultiStepAttack: bool = Query(False),
|
enableMultiStepAttack: bool = Query(False),
|
||||||
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
|
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
|
||||||
) -> StreamingResponse:
|
):
|
||||||
# TODO: content dataset to fuzzer
|
# TODO: content dataset to fuzzer
|
||||||
content = await file.read() # noqa
|
content = await file.read() # noqa
|
||||||
llm_spec = await llmSpec.read()
|
llm_spec = await llmSpec.read()
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
from loguru import logger
|
||||||
from sentry_sdk.integrations.logging import ignore_logger
|
from sentry_sdk.integrations.logging import ignore_logger
|
||||||
|
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
from ..primitives import Settings
|
from ..primitives import Settings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from agentic_security.app import app
|
from ..app import app
|
||||||
|
|
||||||
|
|
||||||
def test_health_check():
|
def test_health_check():
|
||||||
@@ -5,10 +5,10 @@ import httpx
|
|||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from agentic_security.app import app
|
from ..app import app
|
||||||
from agentic_security.primitives import Probe
|
from ..primitives import Probe
|
||||||
from agentic_security.probe_actor.refusal import REFUSAL_MARKS
|
from ..probe_actor.refusal import REFUSAL_MARKS
|
||||||
from agentic_security.probe_data import REGISTRY
|
from ..probe_data import REGISTRY
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
@@ -4,7 +4,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from agentic_security.routes.report import router
|
from .report import router
|
||||||
|
|
||||||
client = TestClient(router)
|
client = TestClient(router)
|
||||||
|
|
||||||
@@ -4,8 +4,8 @@ import pytest
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from agentic_security.primitives import Settings
|
from ..primitives import Settings
|
||||||
from agentic_security.routes.static import get_static_file, router
|
from .static import get_static_file, router
|
||||||
|
|
||||||
client = TestClient(router)
|
client = TestClient(router)
|
||||||
|
|
||||||
@@ -101,7 +101,7 @@ Content-Type: application/json
|
|||||||
]
|
]
|
||||||
`,
|
`,
|
||||||
`POST ${SELF_URL}/v1/self-probe-file
|
`POST ${SELF_URL}/v1/self-probe-file
|
||||||
Authorization: Bearer $API_KEY
|
Authorization: Bearer $GROQ_API_KEY
|
||||||
Content-Type: multipart/form-data
|
Content-Type: multipart/form-data
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -204,9 +204,8 @@ let LLM_CONFIGS = [
|
|||||||
{ name: 'Claude', prompts: 40000, logo: '/icons/claude.png' },
|
{ name: 'Claude', prompts: 40000, logo: '/icons/claude.png' },
|
||||||
{ name: 'Cohere', prompts: 40000, logo: '/icons/cohere.png' },
|
{ name: 'Cohere', prompts: 40000, logo: '/icons/cohere.png' },
|
||||||
{ name: 'Azure OpenAI', prompts: 40000, logo: '/icons/azureai.png' },
|
{ name: 'Azure OpenAI', prompts: 40000, logo: '/icons/azureai.png' },
|
||||||
{ name: 'assemblyai', prompts: 40000, logo: fallbackIcon },
|
|
||||||
{ name: 'OpenRouter.ai', prompts: 40000, logo: '/icons/openrouter.png' },
|
{ name: 'OpenRouter.ai', prompts: 40000, logo: '/icons/openrouter.png' },
|
||||||
|
{ name: 'assemblyai', prompts: 40000, logo: fallbackIcon },
|
||||||
];
|
];
|
||||||
function has_image(spec) {
|
function has_image(spec) {
|
||||||
return spec.includes('<<BASE64_IMAGE>>');
|
return spec.includes('<<BASE64_IMAGE>>');
|
||||||
|
|||||||
@@ -63,32 +63,6 @@
|
|||||||
</div>
|
</div>
|
||||||
</section>
|
</section>
|
||||||
</main>
|
</main>
|
||||||
<div class="fixed top-6 right-6 z-50 space-y-3">
|
|
||||||
<transition-group name="toast">
|
|
||||||
<div
|
|
||||||
v-for="toast in toasts"
|
|
||||||
:key="toast.id"
|
|
||||||
class="flex items-center p-3 rounded-xl shadow-xl text-white max-w-md animate-toast-in border border-opacity-30"
|
|
||||||
:class="{
|
|
||||||
'bg-success-toast border-accent-green': toast.type === 'success',
|
|
||||||
'bg-error-toast border-accent-red': toast.type === 'error',
|
|
||||||
'bg-info-toast border-accent-orange': toast.type === 'info'
|
|
||||||
}"
|
|
||||||
>
|
|
||||||
<span class="flex-1 font-medium tracking-wide text-sm">{{ toast.message }}</span>
|
|
||||||
<button
|
|
||||||
@click="removeToast(toast.id)"
|
|
||||||
class="ml-3 focus:outline-none hover:opacity-80 transition-opacity"
|
|
||||||
>
|
|
||||||
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
|
||||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12" />
|
|
||||||
</svg>
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</transition-group>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
|
|
||||||
<main class="max-w-6xl mx-auto space-y-8">
|
<main class="max-w-6xl mx-auto space-y-8">
|
||||||
<section class="bg-dark-card rounded-lg p-6 shadow-lg" v-show="false">
|
<section class="bg-dark-card rounded-lg p-6 shadow-lg" v-show="false">
|
||||||
<h2 class="text-2xl font-bold mb-4">Select a Config</h2>
|
<h2 class="text-2xl font-bold mb-4">Select a Config</h2>
|
||||||
|
|||||||
@@ -25,8 +25,6 @@ var app = new Vue({
|
|||||||
showModules: false,
|
showModules: false,
|
||||||
showLogs: false,
|
showLogs: false,
|
||||||
showConsentModal: true,
|
showConsentModal: true,
|
||||||
toasts: [], // Array to store toast notifications
|
|
||||||
toastTimeout: 3000, // Duration in milliseconds (3 seconds)
|
|
||||||
statusDotClass: 'bg-gray-500', // Default status dot class
|
statusDotClass: 'bg-gray-500', // Default status dot class
|
||||||
statusText: 'Verified', // Default status text
|
statusText: 'Verified', // Default status text
|
||||||
statusClass: 'bg-green-500 text-dark-bg', // Default status class
|
statusClass: 'bg-green-500 text-dark-bg', // Default status class
|
||||||
@@ -95,36 +93,21 @@ var app = new Vue({
|
|||||||
|
|
||||||
},
|
},
|
||||||
methods: {
|
methods: {
|
||||||
showToast(message, type = 'success') {
|
|
||||||
const id = Date.now(); // Unique ID for each toast
|
|
||||||
this.toasts.push({ id, message, type });
|
|
||||||
|
|
||||||
// Automatically remove toast after timeout
|
|
||||||
setTimeout(() => {
|
|
||||||
this.removeToast(id);
|
|
||||||
}, this.toastTimeout);
|
|
||||||
},
|
|
||||||
|
|
||||||
removeToast(id) {
|
|
||||||
this.toasts = this.toasts.filter(toast => toast.id !== id);
|
|
||||||
},
|
|
||||||
focusTextarea() {
|
focusTextarea() {
|
||||||
this.isFocused = true;
|
this.isFocused = true;
|
||||||
// Remove 'self' assignment if not used elsewhere
|
self = this.$refs;
|
||||||
this.$nextTick(() => {
|
this.$nextTick(() => {
|
||||||
this.$refs.textarea.focus();
|
// Focus the textarea after rendering
|
||||||
this.adjustHeight({ target: this.$refs.textarea });
|
self.textarea.focus();
|
||||||
|
this.adjustHeight({ target: self.textarea });
|
||||||
});
|
});
|
||||||
// Correct the event listener to use handleOutsideClick
|
document.addEventListener("mousedown", this.handleClickOutside);
|
||||||
document.addEventListener("mousedown", this.handleOutsideClick);
|
|
||||||
},
|
},
|
||||||
handleOutsideClick(event) {
|
handleOutsideClick(event) {
|
||||||
if (!this.$refs.textarea) {
|
if (!this.$refs.container.contains(event.target)) {
|
||||||
return
|
|
||||||
}
|
|
||||||
if (!this.$refs.textarea.contains(event.target)) {
|
|
||||||
this.isFocused = false;
|
this.isFocused = false;
|
||||||
document.removeEventListener("mousedown", this.handleOutsideClick);
|
document.removeEventListener("mousedown", this.handleClickOutside);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
unfocusTextarea() {
|
unfocusTextarea() {
|
||||||
@@ -132,12 +115,7 @@ var app = new Vue({
|
|||||||
},
|
},
|
||||||
acceptConsent() {
|
acceptConsent() {
|
||||||
this.showConsentModal = false; // Close the modal
|
this.showConsentModal = false; // Close the modal
|
||||||
|
|
||||||
try {
|
|
||||||
localStorage.setItem('consentGiven', 'true'); // Save consent to local storage
|
localStorage.setItem('consentGiven', 'true'); // Save consent to local storage
|
||||||
} catch (e) {
|
|
||||||
this.showToast('Failed to save consent', 'error'); // Show error if saving fails
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
|
||||||
saveStateToLocalStorage() {
|
saveStateToLocalStorage() {
|
||||||
@@ -178,7 +156,6 @@ var app = new Vue({
|
|||||||
this.integrationVerified = false;
|
this.integrationVerified = false;
|
||||||
this.showResetConfirmation = false;
|
this.showResetConfirmation = false;
|
||||||
this.enableMultiStepAttack = false;
|
this.enableMultiStepAttack = false;
|
||||||
this.showToast('All settings have been reset to default', 'info');
|
|
||||||
},
|
},
|
||||||
confirmResetState() {
|
confirmResetState() {
|
||||||
this.showResetConfirmation = true;
|
this.showResetConfirmation = true;
|
||||||
@@ -217,8 +194,6 @@ var app = new Vue({
|
|||||||
spec: this.modelSpec,
|
spec: this.modelSpec,
|
||||||
};
|
};
|
||||||
let startTime = performance.now(); // Capture start time
|
let startTime = performance.now(); // Capture start time
|
||||||
|
|
||||||
try {
|
|
||||||
const response = await fetch(`${SELF_URL}/verify`, {
|
const response = await fetch(`${SELF_URL}/verify`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
@@ -226,30 +201,24 @@ var app = new Vue({
|
|||||||
},
|
},
|
||||||
body: JSON.stringify(payload),
|
body: JSON.stringify(payload),
|
||||||
});
|
});
|
||||||
|
console.log(response);
|
||||||
let r = await response.json();
|
let r = await response.json();
|
||||||
|
|
||||||
let endTime = performance.now(); // Capture end time
|
let endTime = performance.now(); // Capture end time
|
||||||
let latency = ((endTime - startTime) / 1000).toFixed(3); // Calculate latency in milliseconds
|
let latency = endTime - startTime; // Calculate latency in milliseconds
|
||||||
|
latency = latency.toFixed(3) / 1000; // Round to 2 decimal places
|
||||||
this.latency = latency;
|
this.latency = latency;
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
this.updateStatusDot(false);
|
this.updateStatusDot(false);
|
||||||
this.errorMsg = 'Integration verification failed:' + JSON.stringify(r);
|
this.errorMsg = 'Integration verification failed:' + JSON.stringify(r);
|
||||||
this.showToast('Integration verification failed', 'error');
|
|
||||||
} else {
|
} else {
|
||||||
this.errorMsg = '';
|
this.errorMsg = '';
|
||||||
this.updateStatusDot(true);
|
this.updateStatusDot(true);
|
||||||
this.okMsg = 'Integration verified';
|
this.okMsg = 'Integration verified';
|
||||||
this.showToast('Integration verified successfully', 'success');
|
|
||||||
this.integrationVerified = true;
|
this.integrationVerified = true;
|
||||||
}
|
// console.log('Integration verified', this.integrationVerified);
|
||||||
} catch (error) {
|
// this.$forceUpdate();
|
||||||
this.updateStatusDot(true);
|
|
||||||
this.errorMsg = 'Server unreachable';
|
|
||||||
this.showToast('Network error', 'error');
|
|
||||||
}
|
|
||||||
|
|
||||||
|
}
|
||||||
this.saveStateToLocalStorage();
|
this.saveStateToLocalStorage();
|
||||||
},
|
},
|
||||||
loadConfigs: async function () {
|
loadConfigs: async function () {
|
||||||
@@ -271,7 +240,6 @@ var app = new Vue({
|
|||||||
this.errorMsg = '';
|
this.errorMsg = '';
|
||||||
this.okMsg = '';
|
this.okMsg = '';
|
||||||
this.integrationVerified = false;
|
this.integrationVerified = false;
|
||||||
this.showToast(`Config ${index + 1} selected`, 'info');
|
|
||||||
},
|
},
|
||||||
toggleModules() {
|
toggleModules() {
|
||||||
this.showModules = !this.showModules;
|
this.showModules = !this.showModules;
|
||||||
@@ -359,7 +327,6 @@ var app = new Vue({
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
console.log('New row');
|
console.log('New row');
|
||||||
this.showToast('New module', 'success');
|
|
||||||
let payload = {
|
let payload = {
|
||||||
table: this.mainTable,
|
table: this.mainTable,
|
||||||
};
|
};
|
||||||
@@ -470,8 +437,6 @@ var app = new Vue({
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
this.scanRunning = false;
|
|
||||||
this.showToast('Scan finished successfully', 'success');
|
|
||||||
this.saveStateToLocalStorage();
|
this.saveStateToLocalStorage();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
technopollas: ['Technopollas', 'sans-serif'],
|
technopollas: ['Technopollas', 'sans-serif'],
|
||||||
},
|
},
|
||||||
colors: {
|
colors: {
|
||||||
dark: {
|
t1: {
|
||||||
bg: '#0D0D0D', // Jet Black
|
bg: '#0D0D0D', // Jet Black
|
||||||
card: '#1A1A1A', // Dark Carbon Fiber
|
card: '#1A1A1A', // Dark Carbon Fiber
|
||||||
text: '#FFFFFF',
|
text: '#FFFFFF',
|
||||||
@@ -30,7 +30,7 @@
|
|||||||
yellow: '#2E4053', // Jet Black
|
yellow: '#2E4053', // Jet Black
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
dark1: {
|
dark: {
|
||||||
bg: '#121212',
|
bg: '#121212',
|
||||||
card: '#1E1E1E',
|
card: '#1E1E1E',
|
||||||
text: '#FFFFFF',
|
text: '#FFFFFF',
|
||||||
@@ -86,7 +86,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style>
|
<style>
|
||||||
.scrollbar-hide::-webkit-scrollbar {
|
.scrollbar-hide::-webkit-scrollbar {
|
||||||
display: none;
|
display: none;
|
||||||
@@ -97,55 +96,4 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
</style>
|
</style>
|
||||||
<style>
|
|
||||||
/* Toast-specific colors */
|
|
||||||
.bg-success-toast {
|
|
||||||
background: #1C3F74
|
|
||||||
}
|
|
||||||
|
|
||||||
.bg-error-toast {
|
|
||||||
background: #85144B
|
|
||||||
}
|
|
||||||
|
|
||||||
.bg-info-toast {
|
|
||||||
background: #FFC300
|
|
||||||
}
|
|
||||||
|
|
||||||
.border-accent-green, .border-accent-red, .border-accent-orange {
|
|
||||||
border-color: rgba(255, 255, 255, 0.1); /* Subtle white border for depth */
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Animation */
|
|
||||||
.animate-toast-in {
|
|
||||||
animation: tSlideIn 0.4s cubic-bezier(0.25, 0.1, 0.25, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@keyframes tSlideIn {
|
|
||||||
from {
|
|
||||||
transform: translateX(120%) scale(0.95);
|
|
||||||
opacity: 0;
|
|
||||||
}
|
|
||||||
to {
|
|
||||||
transform: translateX(0) scale(1);
|
|
||||||
opacity: 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
.toast-enter-active,
|
|
||||||
.toast-leave-active {
|
|
||||||
transition: all 0.4s cubic-bezier(0.25, 0.1, 0.25, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
.toast-enter-from,
|
|
||||||
.toast-leave-to {
|
|
||||||
transform: translateX(120%) scale(0.95);
|
|
||||||
opacity: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Hover effect */
|
|
||||||
[toast-type]:hover {
|
|
||||||
transform: translateY(-2px);
|
|
||||||
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.6);
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
</head>
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import time
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import agentic_security.test_spec_assets as test_spec_assets
|
import agentic_security.test_spec_assets as test_spec_assets
|
||||||
from agentic_security.lib import SecurityScanner
|
from agentic_security.lib import AgenticSecurity
|
||||||
|
|
||||||
|
|
||||||
def has_module(module_name):
|
def has_module(module_name):
|
||||||
@@ -57,7 +57,7 @@ class TestLibraryLevel:
|
|||||||
maxBudget = 1000000
|
maxBudget = 1000000
|
||||||
max_th = 0.3
|
max_th = 0.3
|
||||||
datasets = make_test_registry()
|
datasets = make_test_registry()
|
||||||
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
|
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
print(result)
|
print(result)
|
||||||
assert len(result) in [0, 1]
|
assert len(result) in [0, 1]
|
||||||
@@ -68,7 +68,7 @@ class TestLibraryLevel:
|
|||||||
maxBudget = 1000
|
maxBudget = 1000
|
||||||
max_th = 0.3
|
max_th = 0.3
|
||||||
datasets = make_test_registry()
|
datasets = make_test_registry()
|
||||||
result = SecurityScanner.scan(
|
result = AgenticSecurity.scan(
|
||||||
llmSpec, maxBudget, datasets, max_th, enableMultiStepAttack=True
|
llmSpec, maxBudget, datasets, max_th, enableMultiStepAttack=True
|
||||||
)
|
)
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
@@ -93,7 +93,7 @@ class TestLibraryLevel:
|
|||||||
"opts": {"port": 9094},
|
"opts": {"port": 9094},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
|
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
print(result)
|
print(result)
|
||||||
assert len(result) in [0, 1]
|
assert len(result) in [0, 1]
|
||||||
@@ -120,7 +120,7 @@ class TestLibraryLevel:
|
|||||||
"modality": "text",
|
"modality": "text",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
|
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
print(result)
|
print(result)
|
||||||
assert len(result) in [0, 1]
|
assert len(result) in [0, 1]
|
||||||
@@ -148,7 +148,7 @@ class TestLibraryLevel:
|
|||||||
"modality": "text",
|
"modality": "text",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
|
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
print(result)
|
print(result)
|
||||||
assert len(result) in [0, 1]
|
assert len(result) in [0, 1]
|
||||||
@@ -164,10 +164,10 @@ class TestEntrypointCI:
|
|||||||
temp_path = os.path.join(tmpdir, "custom_agesec.toml")
|
temp_path = os.path.join(tmpdir, "custom_agesec.toml")
|
||||||
|
|
||||||
# Override default_path to the temporary path
|
# Override default_path to the temporary path
|
||||||
SecurityScanner.default_path = temp_path
|
AgenticSecurity.default_path = temp_path
|
||||||
|
|
||||||
# Generate the default configuration
|
# Generate the default configuration
|
||||||
security = SecurityScanner()
|
security = AgenticSecurity()
|
||||||
security.generate_default_settings()
|
security.generate_default_settings()
|
||||||
|
|
||||||
# Check that the config file was created at the temporary path
|
# Check that the config file was created at the temporary path
|
||||||
@@ -189,17 +189,17 @@ class TestEntrypointCI:
|
|||||||
temp_path = os.path.join(tmpdir, "custom_agesec.toml")
|
temp_path = os.path.join(tmpdir, "custom_agesec.toml")
|
||||||
|
|
||||||
# Override default_path to the temporary path
|
# Override default_path to the temporary path
|
||||||
SecurityScanner.default_path = temp_path
|
AgenticSecurity.default_path = temp_path
|
||||||
|
|
||||||
# Generate the default configuration
|
# Generate the default configuration
|
||||||
security = SecurityScanner()
|
security = AgenticSecurity()
|
||||||
security.generate_default_settings()
|
security.generate_default_settings()
|
||||||
|
|
||||||
# Load the generated configuration
|
# Load the generated configuration
|
||||||
SecurityScanner.load_config(temp_path)
|
AgenticSecurity.load_config(temp_path)
|
||||||
|
|
||||||
# Validate loaded configuration
|
# Validate loaded configuration
|
||||||
config = SecurityScanner.config
|
config = AgenticSecurity.config
|
||||||
assert (
|
assert (
|
||||||
config["general"]["maxBudget"] == 1000000
|
config["general"]["maxBudget"] == 1000000
|
||||||
), "maxBudget should be 1000000"
|
), "maxBudget should be 1000000"
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
from:python-pytest-poetry
|
||||||
|
# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/
|
||||||
+5
-36
@@ -1,54 +1,23 @@
|
|||||||
# Getting Started
|
# Getting Started
|
||||||
|
|
||||||
Welcome to Agentic Security! This guide will help you orient yourself within the project and start using the tool quickly.
|
Welcome to Agentic Security! This guide will help you get started with using the tool.
|
||||||
|
|
||||||
## Project Overview
|
|
||||||
|
|
||||||
Agentic Security is an open-source vulnerability scanner for Large Language Models (LLMs). It provides both a command line interface and a FastAPI server so you can probe models for weaknesses such as jailbreaks or refusal patterns. The tool supports multimodal attacks, multi-step scans and reinforcement-learning based probes.
|
|
||||||
|
|
||||||
## Repository Layout
|
|
||||||
|
|
||||||
```
|
|
||||||
agentic_security/
|
|
||||||
├── __main__.py - CLI entry point
|
|
||||||
├── app.py - FastAPI app assembly
|
|
||||||
├── lib.py - SecurityScanner and utilities
|
|
||||||
├── config.py - Configuration handling
|
|
||||||
├── core/ - app state and logging helpers
|
|
||||||
├── probe_actor/ - scanning logic and RL modules
|
|
||||||
├── probe_data/ - dataset registry and loaders
|
|
||||||
├── routes/ - API endpoints
|
|
||||||
└── ui/ - Web UI assets (Vue)
|
|
||||||
```
|
|
||||||
|
|
||||||
`tests/` contains unit tests, and `docs/` houses the project documentation.
|
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
1. Ensure you have completed the [installation](installation.md) steps.
|
1. Ensure you have completed the [installation](installation.md) steps.
|
||||||
2. Run the following command to start the application:
|
1. Run the following command to start the application:
|
||||||
```bash
|
```bash
|
||||||
agentic_security
|
agentic_security
|
||||||
```
|
```
|
||||||
The server will start on `http://localhost:8718`.
|
1. Access the application at `http://localhost:8718`.
|
||||||
3. Explore available commands with:
|
|
||||||
```bash
|
|
||||||
agentic_security --help
|
|
||||||
```
|
|
||||||
|
|
||||||
## Basic Usage
|
## Basic Usage
|
||||||
|
|
||||||
- To view available commands, run:
|
- To view available commands, use:
|
||||||
```bash
|
```bash
|
||||||
agentic_security --help
|
agentic_security --help
|
||||||
```
|
```
|
||||||
|
|
||||||
## Next Steps
|
## Next Steps
|
||||||
|
|
||||||
- Review the [Quickstart Guide](quickstart.md) for a fast setup walkthrough.
|
Explore the [Configuration](configuration.md) section to customize your setup.
|
||||||
- Check [http_spec.md](http_spec.md) to learn how LLM endpoints are described.
|
|
||||||
- Browse the `probe_actor` and `probe_data` modules to understand how scanning works and how datasets are loaded.
|
|
||||||
- Explore the [Configuration](configuration.md) section to customize your setup.
|
|
||||||
- Run the tests in `tests/` to verify your environment once dependencies are installed.
|
|
||||||
|
|
||||||
This guide should give you a solid foundation for exploring and extending Agentic Security. For more details, see the rest of the documentation.
|
|
||||||
|
|||||||
+5
-1
@@ -33,7 +33,7 @@ The `LLMSpec` class is the core of the HTTP specification. It provides the follo
|
|||||||
### Methods
|
### Methods
|
||||||
|
|
||||||
- **`from_string(http_spec: str) -> LLMSpec`**: Parses an HTTP specification string into an `LLMSpec` object.
|
- **`from_string(http_spec: str) -> LLMSpec`**: Parses an HTTP specification string into an `LLMSpec` object.
|
||||||
- **`validate(prompt: str, encoded_image: str, encoded_audio: str, files: dict) -> null`**: Validates the request parameters based on the specified modality.
|
- **`validate(prompt: str, encoded_image: str, encoded_audio: str, files: dict) -> None`**: Validates the request parameters based on the specified modality.
|
||||||
- **`probe(prompt: str, encoded_image: str = "", encoded_audio: str = "", files: dict = {}) -> httpx.Response`**: Sends an HTTP request using the specified parameters.
|
- **`probe(prompt: str, encoded_image: str = "", encoded_audio: str = "", files: dict = {}) -> httpx.Response`**: Sends an HTTP request using the specified parameters.
|
||||||
- **`verify() -> httpx.Response`**: Verifies the HTTP specification by sending a test request.
|
- **`verify() -> httpx.Response`**: Verifies the HTTP specification by sending a test request.
|
||||||
|
|
||||||
@@ -57,6 +57,7 @@ Content-Type: application/json
|
|||||||
"temperature": 0.7
|
"temperature": 0.7
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
spec = LLMSpec.from_string(http_spec)
|
spec = LLMSpec.from_string(http_spec)
|
||||||
response = await spec.probe("What is the capital of France?")
|
response = await spec.probe("What is the capital of France?")
|
||||||
```
|
```
|
||||||
@@ -75,6 +76,7 @@ Content-Type: application/json
|
|||||||
"temperature": 0.7
|
"temperature": 0.7
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
spec = LLMSpec.from_string(http_spec)
|
spec = LLMSpec.from_string(http_spec)
|
||||||
encoded_image = encode_image_base64_by_url("https://example.com/image.jpg")
|
encoded_image = encode_image_base64_by_url("https://example.com/image.jpg")
|
||||||
response = await spec.probe("What is in this image?", encoded_image=encoded_image)
|
response = await spec.probe("What is in this image?", encoded_image=encoded_image)
|
||||||
@@ -94,6 +96,7 @@ Content-Type: application/json
|
|||||||
"temperature": 0.7
|
"temperature": 0.7
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
spec = LLMSpec.from_string(http_spec)
|
spec = LLMSpec.from_string(http_spec)
|
||||||
encoded_audio = encode_audio_base64_by_url("https://example.com/audio.mp3")
|
encoded_audio = encode_audio_base64_by_url("https://example.com/audio.mp3")
|
||||||
response = await spec.probe("Transcribe this audio:", encoded_audio=encoded_audio)
|
response = await spec.probe("Transcribe this audio:", encoded_audio=encoded_audio)
|
||||||
@@ -113,6 +116,7 @@ Content-Type: multipart/form-data
|
|||||||
"temperature": 0.7
|
"temperature": 0.7
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
spec = LLMSpec.from_string(http_spec)
|
spec = LLMSpec.from_string(http_spec)
|
||||||
files = {"file": ("document.txt", open("document.txt", "rb"))}
|
files = {"file": ("document.txt", open("document.txt", "rb"))}
|
||||||
response = await spec.probe("Process this file:", files=files)
|
response = await spec.probe("Process this file:", files=files)
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 9.0 MiB |
+1
-1
@@ -21,4 +21,4 @@ Note: Please be aware that Agentic Security is designed as a safety scanner tool
|
|||||||
|
|
||||||
## UI 🧙
|
## UI 🧙
|
||||||
|
|
||||||
<img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1741192668/final_aa9jhb.gif">
|
<img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1736433557/z0bsyzhsqlgcr3w4ovwp.gif">
|
||||||
|
|||||||
+14
-4
@@ -54,15 +54,20 @@ The `probe_data` module is a core component of the Agentic Security project, res
|
|||||||
|
|
||||||
- **Classes:**
|
- **Classes:**
|
||||||
- `PromptSelectionInterface`: Abstract base class for prompt selection strategies.
|
- `PromptSelectionInterface`: Abstract base class for prompt selection strategies.
|
||||||
|
|
||||||
- Methods:
|
- Methods:
|
||||||
- `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`: Selects next prompt
|
- `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`: Selects next prompt
|
||||||
- `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`: Selects multiple prompts
|
- `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`: Selects multiple prompts
|
||||||
- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> null`: Updates rewards
|
- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None`: Updates rewards
|
||||||
|
|
||||||
- `RandomPromptSelector`: Basic random selection with history tracking.
|
- `RandomPromptSelector`: Basic random selection with history tracking.
|
||||||
|
|
||||||
- Parameters:
|
- Parameters:
|
||||||
- `prompts: list[str]`: List of available prompts
|
- `prompts: list[str]`: List of available prompts
|
||||||
- `history_size: int = 3`: Size of history to prevent cycles
|
- `history_size: int = 3`: Size of history to prevent cycles
|
||||||
|
|
||||||
- `CloudRLPromptSelector`: Cloud-based RL implementation with fallback.
|
- `CloudRLPromptSelector`: Cloud-based RL implementation with fallback.
|
||||||
|
|
||||||
- Parameters:
|
- Parameters:
|
||||||
- `prompts: list[str]`: List of available prompts
|
- `prompts: list[str]`: List of available prompts
|
||||||
- `api_url: str`: URL of RL service
|
- `api_url: str`: URL of RL service
|
||||||
@@ -70,7 +75,9 @@ The `probe_data` module is a core component of the Agentic Security project, res
|
|||||||
- `history_size: int = 300`: Size of history
|
- `history_size: int = 300`: Size of history
|
||||||
- `timeout: int = 5`: Request timeout
|
- `timeout: int = 5`: Request timeout
|
||||||
- `run_id: str = ""`: Unique run identifier
|
- `run_id: str = ""`: Unique run identifier
|
||||||
|
|
||||||
- `QLearningPromptSelector`: Local Q-learning implementation.
|
- `QLearningPromptSelector`: Local Q-learning implementation.
|
||||||
|
|
||||||
- Parameters:
|
- Parameters:
|
||||||
- `prompts: list[str]`: List of available prompts
|
- `prompts: list[str]`: List of available prompts
|
||||||
- `learning_rate: float = 0.1`: Learning rate
|
- `learning_rate: float = 0.1`: Learning rate
|
||||||
@@ -79,7 +86,9 @@ The `probe_data` module is a core component of the Agentic Security project, res
|
|||||||
- `exploration_decay: float = 0.995`: Exploration decay rate
|
- `exploration_decay: float = 0.995`: Exploration decay rate
|
||||||
- `min_exploration: float = 0.01`: Minimum exploration rate
|
- `min_exploration: float = 0.01`: Minimum exploration rate
|
||||||
- `history_size: int = 300`: Size of history
|
- `history_size: int = 300`: Size of history
|
||||||
- **Module**: Main class that uses CloudRLPromptSelector.
|
|
||||||
|
- `Module`: Main class that uses CloudRLPromptSelector.
|
||||||
|
|
||||||
- Parameters:
|
- Parameters:
|
||||||
- `prompt_groups: list[str]`: Groups of prompts
|
- `prompt_groups: list[str]`: Groups of prompts
|
||||||
- `tools_inbox: asyncio.Queue`: Queue for tool communication
|
- `tools_inbox: asyncio.Queue`: Queue for tool communication
|
||||||
@@ -110,9 +119,10 @@ from agentic_security.probe_data.modules.rl_model import QLearningPromptSelector
|
|||||||
|
|
||||||
prompts = ["What is AI?", "Explain machine learning"]
|
prompts = ["What is AI?", "Explain machine learning"]
|
||||||
selector = QLearningPromptSelector(prompts)
|
selector = QLearningPromptSelector(prompts)
|
||||||
|
|
||||||
current_prompt = "What is AI?"
|
current_prompt = "What is AI?"
|
||||||
next_prompt = selector.select_next_prompt(current_prompt, passed_guard=true)
|
next_prompt = selector.select_next_prompt(current_prompt, passed_guard=True)
|
||||||
selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=true)
|
selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Conclusion
|
## Conclusion
|
||||||
|
|||||||
@@ -1,16 +1,14 @@
|
|||||||
:root {
|
:root {
|
||||||
--md-primary-fg-color: #2E4053;
|
--md-primary-fg-color: #e92063;
|
||||||
/* Primary color changed to pinkish */
|
--md-primary-fg-color--light: #e92063;
|
||||||
--md-primary-fg-color--light: #E0A3B6;
|
--md-primary-fg-color--dark: #e92063;
|
||||||
--md-primary-fg-color--dark: #1C3F74;
|
|
||||||
/* Dark variant changed to blue */
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Updated slate color scheme with new background */
|
|
||||||
|
/* Revert hue value to that of pre mkdocs-material v9.4.0 */
|
||||||
[data-md-color-scheme="slate"] {
|
[data-md-color-scheme="slate"] {
|
||||||
--md-hue: 230;
|
--md-hue: 230;
|
||||||
--md-default-bg-color: #1A1A1A;
|
--md-default-bg-color: hsla(230, 15%, 21%, 1);
|
||||||
/* Background changed to dark gray */
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.hide {
|
.hide {
|
||||||
@@ -26,15 +24,12 @@ img.index-header {
|
|||||||
max-width: 500px;
|
max-width: 500px;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Updated custom colors */
|
|
||||||
.pydantic-pink {
|
.pydantic-pink {
|
||||||
color: #E0A3B6;
|
color: #FF007F;
|
||||||
/* Updated to match new theme */
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.team-blue {
|
.team-blue {
|
||||||
color: #1C3F74;
|
color: #0072CE;
|
||||||
/* Updated to match new theme */
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.secure-green {
|
.secure-green {
|
||||||
@@ -72,6 +67,7 @@ img.index-header {
|
|||||||
text-align: center;
|
text-align: center;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/* Hide the entire footer */
|
/* Hide the entire footer */
|
||||||
.md-footer {
|
.md-footer {
|
||||||
display: none;
|
display: none;
|
||||||
|
|||||||
+1
-1
@@ -89,7 +89,7 @@ theme:
|
|||||||
name: Switch to light mode
|
name: Switch to light mode
|
||||||
icon:
|
icon:
|
||||||
repo: fontawesome/brands/github
|
repo: fontawesome/brands/github
|
||||||
favicon: https://res.cloudinary.com/dq0w2rtm9/image/upload/v1741195421/favicon_kuz6xr.png
|
favicon: "https://res.cloudinary.com/dq0w2rtm9/image/upload/v1737555066/r17hrkre246doczwmvbv.png"
|
||||||
|
|
||||||
extra:
|
extra:
|
||||||
generator: false
|
generator: false
|
||||||
|
|||||||
Generated
+396
-587
File diff suppressed because it is too large
Load Diff
+3
-11
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "agentic_security"
|
name = "agentic_security"
|
||||||
version = "0.7.4"
|
version = "0.5.1"
|
||||||
description = "Agentic LLM vulnerability scanner"
|
description = "Agentic LLM vulnerability scanner"
|
||||||
authors = ["Alexander Miasoiedov <msoedov@gmail.com>"]
|
authors = ["Alexander Miasoiedov <msoedov@gmail.com>"]
|
||||||
maintainers = ["Alexander Miasoiedov <msoedov@gmail.com>"]
|
maintainers = ["Alexander Miasoiedov <msoedov@gmail.com>"]
|
||||||
@@ -52,10 +52,8 @@ sentry_sdk = "^2.22.0"
|
|||||||
orjson = "^3.10"
|
orjson = "^3.10"
|
||||||
pyfiglet = "^1.0.2"
|
pyfiglet = "^1.0.2"
|
||||||
termcolor = "^2.4.0"
|
termcolor = "^2.4.0"
|
||||||
mcp = "^1.4.1"
|
|
||||||
|
|
||||||
# garak = { version = "*", optional = true }
|
# garak = { version = "*", optional = true }
|
||||||
pytest-xdist = "3.6.1"
|
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
@@ -70,7 +68,7 @@ pytest-mock = "^3.14.0"
|
|||||||
black = ">=24.10,<26.0"
|
black = ">=24.10,<26.0"
|
||||||
mypy = "^1.12.0"
|
mypy = "^1.12.0"
|
||||||
pre-commit = "^4.0.1"
|
pre-commit = "^4.0.1"
|
||||||
huggingface-hub = ">=0.25.1,<0.30.0"
|
huggingface-hub = ">=0.25.1,<0.29.0"
|
||||||
|
|
||||||
# Docs
|
# Docs
|
||||||
mkdocs = ">=1.4.2"
|
mkdocs = ">=1.4.2"
|
||||||
@@ -88,13 +86,7 @@ build-backend = "poetry.core.masonry.api"
|
|||||||
|
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "--durations=5 -m 'not slow' -n 3"
|
addopts = "--durations=5 -m 'not slow'"
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
asyncio_default_fixture_loop_scope = "function"
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
markers = "slow: marks tests as slow"
|
markers = "slow: marks tests as slow"
|
||||||
|
|
||||||
[project]
|
|
||||||
# MCP requires the following fields to be present in the pyproject.toml file
|
|
||||||
name = "agentic_security"
|
|
||||||
version = "1.0.0"
|
|
||||||
requires-python = ">=3.11"
|
|
||||||
|
|||||||
@@ -1,17 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from cache_to_disk import delete_old_disk_caches
|
|
||||||
|
|
||||||
from agentic_security.logutils import logger
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_runtest_setup(item):
|
def pytest_runtest_setup(item):
|
||||||
if "slow" in item.keywords and not os.getenv("RUN_SLOW_TESTS"):
|
if "slow" in item.keywords and not os.getenv("RUN_SLOW_TESTS"):
|
||||||
pytest.skip("Skipping slow test")
|
pytest.skip("Skipping slow test")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True, scope="session")
|
|
||||||
def setup_delete_old_disk_caches():
|
|
||||||
logger.info("delete_old_disk_caches")
|
|
||||||
delete_old_disk_caches()
|
|
||||||
|
|||||||
@@ -0,0 +1,161 @@
|
|||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from asyncio import Queue, Event
|
||||||
|
from agentic_security.core.app import create_app, get_tools_inbox, get_stop_event, get_current_run, set_current_run
|
||||||
|
|
||||||
|
class TestApp:
|
||||||
|
"""Test suite for agentic_security.core.app module."""
|
||||||
|
|
||||||
|
def test_create_app(self):
|
||||||
|
"""Test that create_app returns a FastAPI instance."""
|
||||||
|
app = create_app()
|
||||||
|
assert isinstance(app, FastAPI)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_tools_inbox(self):
|
||||||
|
"""Test that get_tools_inbox returns the global Queue instance."""
|
||||||
|
queue1 = get_tools_inbox()
|
||||||
|
await queue1.put("test item")
|
||||||
|
queue2 = get_tools_inbox()
|
||||||
|
result = queue2.get_nowait()
|
||||||
|
assert result == "test item"
|
||||||
|
|
||||||
|
def test_get_stop_event(self):
|
||||||
|
"""Test that get_stop_event returns the global Event instance and is not set initially."""
|
||||||
|
event = get_stop_event()
|
||||||
|
assert isinstance(event, Event)
|
||||||
|
assert not event.is_set()
|
||||||
|
|
||||||
|
def test_current_run_initial(self):
|
||||||
|
"""Test that get_current_run returns the global current_run with default values initially."""
|
||||||
|
run = get_current_run()
|
||||||
|
# Default values should be empty strings
|
||||||
|
assert run["spec"] == ""
|
||||||
|
assert run["id"] == ""
|
||||||
|
|
||||||
|
def test_set_current_run(self):
|
||||||
|
"""Test that set_current_run correctly updates current_run."""
|
||||||
|
spec = "test run"
|
||||||
|
result = set_current_run(spec)
|
||||||
|
expected_id = hash(id(spec))
|
||||||
|
# Verify that spec is set correctly
|
||||||
|
assert result["spec"] == spec
|
||||||
|
assert result["id"] == expected_id
|
||||||
|
|
||||||
|
def test_current_run_after_set(self):
|
||||||
|
"""Test that get_current_run returns the updated current_run after set_current_run is called."""
|
||||||
|
spec = "another test run"
|
||||||
|
set_current_run(spec)
|
||||||
|
current = get_current_run()
|
||||||
|
assert current["spec"] == spec
|
||||||
|
assert current["id"] == hash(id(spec))
|
||||||
|
def test_tools_inbox_same_instance(self):
|
||||||
|
"""Test that get_tools_inbox returns the same Queue instance by default."""
|
||||||
|
queue1 = get_tools_inbox()
|
||||||
|
queue2 = get_tools_inbox()
|
||||||
|
assert queue1 is queue2
|
||||||
|
|
||||||
|
def test_stop_event_set(self):
|
||||||
|
"""Test that setting the stop event is reflected in subsequent calls."""
|
||||||
|
event = get_stop_event()
|
||||||
|
event.set() # set the global event
|
||||||
|
# Now, subsequent calls should return the same event which is set.
|
||||||
|
event2 = get_stop_event()
|
||||||
|
assert event2.is_set()
|
||||||
|
|
||||||
|
def test_set_current_run_with_none(self):
|
||||||
|
"""Test that set_current_run handles None as a valid input and updates current_run accordingly."""
|
||||||
|
result = set_current_run(None)
|
||||||
|
expected_id = hash(id(None))
|
||||||
|
assert result["spec"] is None
|
||||||
|
assert result["id"] == expected_id
|
||||||
|
|
||||||
|
def test_multiple_current_run_assignments(self):
|
||||||
|
"""Test multiple assignments to current_run to ensure it always updates correctly."""
|
||||||
|
first_spec = "first run"
|
||||||
|
result1 = set_current_run(first_spec)
|
||||||
|
expected_id1 = hash(id(first_spec))
|
||||||
|
assert result1["spec"] == first_spec
|
||||||
|
assert result1["id"] == expected_id1
|
||||||
|
|
||||||
|
second_spec = "second run"
|
||||||
|
result2 = set_current_run(second_spec)
|
||||||
|
expected_id2 = hash(id(second_spec))
|
||||||
|
assert result2["spec"] == second_spec
|
||||||
|
assert result2["id"] == expected_id2
|
||||||
|
|
||||||
|
current = get_current_run()
|
||||||
|
# The current_run should reflect the latest assignment.
|
||||||
|
assert current["spec"] == second_spec
|
||||||
|
assert current["id"] == expected_id2
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_tools_inbox_exception(self):
|
||||||
|
"""Test that calling get_nowait on an empty tools_inbox raises QueueEmpty."""
|
||||||
|
from asyncio import QueueEmpty
|
||||||
|
queue = get_tools_inbox()
|
||||||
|
# Clear any existing items in the queue
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
queue.get_nowait()
|
||||||
|
except QueueEmpty:
|
||||||
|
break
|
||||||
|
with pytest.raises(QueueEmpty):
|
||||||
|
queue.get_nowait()
|
||||||
|
|
||||||
|
def test_set_current_run_with_dict(self):
|
||||||
|
"""Test that set_current_run correctly handles a dictionary input as spec."""
|
||||||
|
spec = {"key": "value"}
|
||||||
|
result = set_current_run(spec)
|
||||||
|
expected_id = hash(id(spec))
|
||||||
|
assert result["spec"] == spec
|
||||||
|
assert result["id"] == expected_id
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_event_wait(self):
|
||||||
|
"""Test that waiting on the stop event returns once the event is set."""
|
||||||
|
event = get_stop_event()
|
||||||
|
event.clear() # ensure event is not set
|
||||||
|
async def waiter():
|
||||||
|
await event.wait()
|
||||||
|
return True
|
||||||
|
waiter_task = asyncio.create_task(waiter())
|
||||||
|
# Wait a moment to ensure the waiter is pending
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
assert not waiter_task.done()
|
||||||
|
event.set()
|
||||||
|
result = await waiter_task
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_set_current_run_with_int(self):
|
||||||
|
"""Test that set_current_run handles an integer input as spec."""
|
||||||
|
spec = 12345
|
||||||
|
result = set_current_run(spec)
|
||||||
|
expected_id = hash(id(spec))
|
||||||
|
assert result["spec"] == spec
|
||||||
|
assert result["id"] == expected_id
|
||||||
|
|
||||||
|
def test_create_app_routes(self):
|
||||||
|
"""Test that create_app returns a FastAPI instance with default routes available."""
|
||||||
|
app = create_app()
|
||||||
|
paths = [route.path for route in app.routes]
|
||||||
|
# Check that the default OpenAPI route exists
|
||||||
|
assert "/openapi.json" in paths
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tools_inbox_async_put_get_order(self):
|
||||||
|
"""Test that tools_inbox preserves order when items are added and retrieved asynchronously."""
|
||||||
|
queue = get_tools_inbox()
|
||||||
|
# Clear any existing items in the queue
|
||||||
|
from asyncio import QueueEmpty
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
queue.get_nowait()
|
||||||
|
except QueueEmpty:
|
||||||
|
break
|
||||||
|
items = ["first", "second", "third"]
|
||||||
|
for item in items:
|
||||||
|
await queue.put(item)
|
||||||
|
result_items = []
|
||||||
|
for _ in items:
|
||||||
|
result_items.append(await queue.get())
|
||||||
|
assert result_items == items
|
||||||
@@ -0,0 +1,341 @@
|
|||||||
|
import pytest
|
||||||
|
import base64
|
||||||
|
import httpx
|
||||||
|
import asyncio
|
||||||
|
from agentic_security.http_spec import (
|
||||||
|
LLMSpec,
|
||||||
|
parse_http_spec,
|
||||||
|
escape_special_chars_for_json,
|
||||||
|
encode_image_base64_by_url,
|
||||||
|
encode_audio_base64_by_url,
|
||||||
|
InvalidHTTPSpecError,
|
||||||
|
Modality
|
||||||
|
)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Tests for agentic_security/http_spec.py
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
def test_escape_special_chars_for_json():
|
||||||
|
"""Test escaping special characters in a prompt for JSON safety."""
|
||||||
|
prompt = 'Line1\nLine2\t"Quote"\\Backslash'
|
||||||
|
escaped = escape_special_chars_for_json(prompt)
|
||||||
|
assert '\\n' in escaped
|
||||||
|
assert '\\t' in escaped
|
||||||
|
assert '\\"' in escaped
|
||||||
|
assert '\\\\' in escaped
|
||||||
|
|
||||||
|
def test_parse_http_spec_text():
|
||||||
|
"""Test parsing a text HTTP spec without image/audio/files requirements."""
|
||||||
|
spec = "POST http://example.com/api\nContent-Type: application/json\n\nThis is a prompt: <<PROMPT>>"
|
||||||
|
llm_spec = parse_http_spec(spec)
|
||||||
|
assert llm_spec.method == "POST"
|
||||||
|
assert llm_spec.url == "http://example.com/api"
|
||||||
|
assert llm_spec.headers["Content-Type"] == "application/json"
|
||||||
|
assert "<<PROMPT>>" in llm_spec.body
|
||||||
|
assert not llm_spec.has_files
|
||||||
|
assert not llm_spec.has_image
|
||||||
|
assert not llm_spec.has_audio
|
||||||
|
|
||||||
|
def test_parse_http_spec_files():
|
||||||
|
"""Test parsing a HTTP spec with multipart/form-data header indicating files."""
|
||||||
|
spec = "PUT http://example.com/upload\nContent-Type: multipart/form-data\n\nFile upload test"
|
||||||
|
llm_spec = parse_http_spec(spec)
|
||||||
|
assert llm_spec.has_files
|
||||||
|
|
||||||
|
def test_parse_http_spec_image_audio():
|
||||||
|
"""Test parsing a HTTP spec that requires image and audio via placeholders."""
|
||||||
|
spec = "GET http://example.com/api\nContent-Type: application/json\n\nImage: <<BASE64_IMAGE>> and Audio: <<BASE64_AUDIO>>"
|
||||||
|
llm_spec = parse_http_spec(spec)
|
||||||
|
assert llm_spec.has_image
|
||||||
|
assert llm_spec.has_audio
|
||||||
|
|
||||||
|
def test_encode_image_base64_by_url(monkeypatch):
|
||||||
|
"""Test that image encoding returns the correct base64 string with prefix."""
|
||||||
|
dummy_content = b'test_image'
|
||||||
|
class DummyResponse:
|
||||||
|
def __init__(self, content):
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
def dummy_get(url):
|
||||||
|
return DummyResponse(dummy_content)
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx, "get", dummy_get)
|
||||||
|
result = encode_image_base64_by_url("http://dummyurl.com/image.jpg")
|
||||||
|
expected = "data:image/jpeg;base64," + base64.b64encode(dummy_content).decode("utf-8")
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
def test_encode_audio_base64_by_url(monkeypatch):
|
||||||
|
"""Test that audio encoding returns the correct base64 string with prefix."""
|
||||||
|
dummy_content = b'test_audio'
|
||||||
|
class DummyResponse:
|
||||||
|
def __init__(self, content):
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
def dummy_get(url):
|
||||||
|
return DummyResponse(dummy_content)
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx, "get", dummy_get)
|
||||||
|
result = encode_audio_base64_by_url("http://dummyurl.com/audio.mp3")
|
||||||
|
expected = "data:audio/mpeg;base64," + base64.b64encode(dummy_content).decode("utf-8")
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_probe_text(monkeypatch):
|
||||||
|
"""Test the probe function for text modality by replacing <<PROMPT>>."""
|
||||||
|
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"prompt\": \"<<PROMPT>>\"}"
|
||||||
|
llm_spec = parse_http_spec(spec)
|
||||||
|
|
||||||
|
async def dummy_request(self, method, url, headers, content, timeout):
|
||||||
|
return httpx.Response(200, text="ok")
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
|
||||||
|
response = await llm_spec.probe("Hello")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "ok" in response.text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_probe_with_files(monkeypatch):
|
||||||
|
"""Test that probe correctly branches to _probe_with_files when files are provided."""
|
||||||
|
spec = "POST http://example.com/api\nContent-Type: multipart/form-data\n\nFile data"
|
||||||
|
llm_spec = parse_http_spec(spec)
|
||||||
|
files = {"file": ("dummy.txt", b"data")}
|
||||||
|
|
||||||
|
async def dummy_request(self, method, url, headers, files, timeout):
|
||||||
|
return httpx.Response(200, text="file upload ok")
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
|
||||||
|
response = await llm_spec.probe("Unused", files=files)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "file upload ok" in response.text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_image(monkeypatch):
|
||||||
|
"""Test verify method branch for image modality by monkeypatching image encoder."""
|
||||||
|
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"image\": \"<<BASE64_IMAGE>>\"}"
|
||||||
|
llm_spec = parse_http_spec(spec)
|
||||||
|
|
||||||
|
# Replace the image encoder to return a dummy string
|
||||||
|
monkeypatch.setattr("agentic_security.http_spec.encode_image_base64_by_url", lambda url="": "dummy_image")
|
||||||
|
|
||||||
|
async def dummy_request(self, method, url, headers, content, timeout):
|
||||||
|
# Check that the dummy image is injected in the content
|
||||||
|
assert "dummy_image" in content
|
||||||
|
return httpx.Response(200, text="image ok")
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
|
||||||
|
response = await llm_spec.verify()
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "image ok" in response.text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_audio(monkeypatch):
|
||||||
|
"""Test verify method branch for audio modality by monkeypatching audio encoder."""
|
||||||
|
spec = "POST http://example.com/api\nContent-Type: application/json\n\n{\"audio\": \"<<BASE64_AUDIO>>\"}"
|
||||||
|
llm_spec = parse_http_spec(spec)
|
||||||
|
|
||||||
|
monkeypatch.setattr("agentic_security.http_spec.encode_audio_base64_by_url", lambda url: "dummy_audio")
|
||||||
|
|
||||||
|
async def dummy_request(self, method, url, headers, content, timeout):
|
||||||
|
# Ensure that the dummy audio string is present in the request content
|
||||||
|
assert "dummy_audio" in content
|
||||||
|
return httpx.Response(200, text="audio ok")
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
|
||||||
|
response = await llm_spec.verify()
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "audio ok" in response.text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_files(monkeypatch):
|
||||||
|
"""Test verify method branch for files modality where _probe_with_files is invoked."""
|
||||||
|
spec = "POST http://example.com/api\nContent-Type: multipart/form-data\n\nFile data"
|
||||||
|
llm_spec = parse_http_spec(spec)
|
||||||
|
|
||||||
|
async def dummy_request(self, method, url, headers, files, timeout):
|
||||||
|
return httpx.Response(200, text="files ok")
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "request", dummy_request)
|
||||||
|
response = await llm_spec.verify()
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "files ok" in response.text
|
||||||
|
|
||||||
|
def test_llm_spec_modality_property():
    """Check that each body placeholder yields the matching ``modality`` value."""
    head = "POST http://example.com/api\nContent-Type: application/json\n\n"
    cases = (
        ("Prompt: <<PROMPT>>", Modality.TEXT),
        ("Image: <<BASE64_IMAGE>>", Modality.IMAGE),
        ("Audio: <<BASE64_AUDIO>>", Modality.AUDIO),
    )
    for body, expected in cases:
        assert parse_http_spec(head + body).modality == expected
|
||||||
|
|
||||||
|
def test_from_string_invalid():
    """Check that LLMSpec.from_string raises InvalidHTTPSpecError for a malformed spec."""
    with pytest.raises(InvalidHTTPSpecError):
        LLMSpec.from_string("INVALID_SPEC")
|
||||||
|
@pytest.mark.asyncio
async def test_validate_missing_files():
    """Check that validate() raises ValueError for a multipart spec given no files."""
    multipart_spec = parse_http_spec(
        "POST http://example.com/api\n"
        "Content-Type: multipart/form-data\n"
        "\n"
        "File upload test"
    )
    no_files = {}
    with pytest.raises(ValueError, match="Files are required"):
        multipart_spec.validate("test prompt", "", "", no_files)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_validate_missing_image():
    """Check that validate() raises ValueError for an image spec given an empty image."""
    image_spec = parse_http_spec(
        "POST http://example.com/api\n"
        "Content-Type: application/json\n"
        "\n"
        "Image: <<BASE64_IMAGE>>"
    )
    # Audio is supplied; only the image argument is left empty.
    with pytest.raises(ValueError, match="An image is required"):
        image_spec.validate("test prompt", "", "dummy_audio", {})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_validate_missing_audio():
    """Check that validate() raises ValueError for an audio spec given empty audio."""
    audio_spec = parse_http_spec(
        "POST http://example.com/api\n"
        "Content-Type: application/json\n"
        "\n"
        "Audio: <<BASE64_AUDIO>>"
    )
    # An image is supplied; only the audio argument is left empty.
    with pytest.raises(ValueError, match="Audio is required"):
        audio_spec.validate("test prompt", "dummy_image", "", {})
|
||||||
|
|
||||||
|
def test_fn_alias(monkeypatch):
    """Check that LLMSpec.fn is the same function object as LLMSpec.probe.

    The alias is verified at the class level, so no spec instance is built.
    The ``monkeypatch`` fixture is unused and kept only so the signature
    stays unchanged.
    """
    assert LLMSpec.fn is LLMSpec.probe
|
||||||
|
|
||||||
|
def test_escape_special_chars_no_special():
    """Check that text with no special characters is returned unchanged by the escaper."""
    plain = "Simple text without specials"
    assert escape_special_chars_for_json(plain) == "Simple text without specials"
|
||||||
|
@pytest.mark.asyncio
async def test_probe_text_with_special_chars(monkeypatch):
    """Check that probe() sends the escaped form of a prompt with special characters."""
    llm_spec = parse_http_spec(
        "POST http://example.com/api\n"
        "Content-Type: application/json\n"
        "\n"
        '{"prompt": "<<PROMPT>>"}'
    )
    sent = {}

    async def fake_request(self, method, url, headers, content, timeout):
        # Record the outgoing body so it can be inspected after the call.
        sent["content"] = content
        return httpx.Response(200, text="ok")

    monkeypatch.setattr(httpx.AsyncClient, "request", fake_request)

    # Prompt containing a newline, a tab and double quotes.
    raw_prompt = 'Hello\nWorld\t"Test"'
    result = await llm_spec.probe(raw_prompt)

    assert escape_special_chars_for_json(raw_prompt) in sent["content"]
    assert result.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_verify_both_image_audio(monkeypatch):
    """Check verify() on a spec whose body holds both image and audio placeholders.

    Only the image encoder is stubbed here; verify() is expected to raise a
    ValueError matching "Audio is required".
    """
    head = "POST http://example.com/api\nContent-Type: application/json\n\n"
    payload = '{"audio": "<<BASE64_AUDIO>>", "image":"<<BASE64_IMAGE>>"}'
    llm_spec = parse_http_spec(head + payload)

    # Stub the image encoder with a fixed value; the audio encoder is untouched.
    monkeypatch.setattr(
        "agentic_security.http_spec.encode_image_base64_by_url",
        lambda url="": "dummy_image",
    )

    with pytest.raises(ValueError, match="Audio is required"):
        await llm_spec.verify()
|
||||||
|
|
||||||
|
def test_parse_http_spec_invalid_header_format():
    """Check that parse_http_spec raises ValueError for a header line with no colon."""
    malformed = (
        "GET http://example.com/api\n"
        "InvalidHeaderWithoutColon\n"
        "\n"
        "Body with <<PROMPT>>"
    )
    with pytest.raises(ValueError):
        parse_http_spec(malformed)
|
||||||
|
|
||||||
|
def test_from_string_valid():
    """Check that LLMSpec.from_string exposes the method and URL of a well-formed spec."""
    parsed = LLMSpec.from_string(
        "GET http://example.com/api\n"
        "Content-Type: application/json\n"
        "\n"
        '{ "prompt": "<<PROMPT>>" }'
    )
    assert parsed.method == "GET"
    assert parsed.url == "http://example.com/api"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_parse_http_spec_multiline_body():
    """Check how parse_http_spec joins a request body that spans several lines."""
    head = "PATCH http://example.com/api\nContent-Type: application/json\n\n"
    body_lines = ("Line one of body", "Line two of body", "Line three")
    llm_spec = parse_http_spec(head + "\n".join(body_lines))

    # The expected body has the lines joined with no newline between them.
    assert llm_spec.body == "Line one of bodyLine two of bodyLine three"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_encode_image_default_argument(monkeypatch):
    """Check encode_image_base64_by_url() when it is called with no arguments."""
    payload = b'default_image'

    class StubResponse:
        """Stand-in response object exposing only ``content``."""

        def __init__(self, content):
            self.content = content

    def fake_get(url):
        # The no-argument call must request a URL containing 'fluidicon.png'.
        assert "fluidicon.png" in url
        return StubResponse(payload)

    monkeypatch.setattr(httpx, "get", fake_get)

    encoded = base64.b64encode(payload).decode("utf-8")
    assert encode_image_base64_by_url() == f"data:image/jpeg;base64,{encoded}"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_probe_without_prompt_placeholder(monkeypatch):
    """Check probe() when the spec body contains no <<PROMPT>> placeholder."""
    llm_spec = parse_http_spec(
        "POST http://example.com/api\n"
        "Content-Type: application/json\n"
        "\n"
        '{"message": "No placeholder here"}'
    )
    sent = {}

    async def fake_request(self, method, url, headers, content, timeout):
        # Record the outgoing body so it can be inspected after the call.
        sent["content"] = content
        return httpx.Response(200, text="ok without placeholder")

    monkeypatch.setattr(httpx.AsyncClient, "request", fake_request)

    result = await llm_spec.probe("Ignored prompt")

    assert "No placeholder here" in sent["content"]
    assert result.status_code == 200
|
||||||
|
|
||||||
|
def test_validate_success():
    """Check that validate() does not raise when the data each spec needs is supplied."""
    # Multipart spec, with a file entry provided.
    files_spec = parse_http_spec(
        "POST http://example.com/api\n"
        "Content-Type: multipart/form-data\n"
        "\n"
        "File upload"
    )
    files_spec.validate(
        "some prompt",
        "dummy_image",
        "dummy_audio",
        {"file": ("dummy.txt", b"data")},
    )

    # JSON specs needing an image, then audio; both values are provided each time.
    json_head = "POST http://example.com/api\nContent-Type: application/json\n\n"
    for body in ("Image: <<BASE64_IMAGE>>", "Audio: <<BASE64_AUDIO>>"):
        media_spec = parse_http_spec(json_head + body)
        media_spec.validate("some prompt", "dummy_image", "dummy_audio", {})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_probe_invalid_url(monkeypatch):
    """Check that an httpx.RequestError raised by the client propagates out of probe()."""
    llm_spec = parse_http_spec(
        "GET http://nonexistent_url/api\n"
        "Content-Type: application/json\n"
        "\n"
        '{"prompt": "<<PROMPT>>"}'
    )

    async def failing_request(self, method, url, headers, content, timeout):
        # Simulate a transport failure without touching the network.
        raise httpx.RequestError("Invalid URL")

    monkeypatch.setattr(httpx.AsyncClient, "request", failing_request)

    with pytest.raises(httpx.RequestError):
        await llm_spec.probe("Test")
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from agentic_security.mcp.client import run
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
async def test_mcp_echo_tool():
    """Check that the MCP client's run() returns truthy prompts, resources and tools."""
    outcome = await run()
    prompts, resources, tools = outcome
    for group in (prompts, resources, tools):
        assert group
|
|
||||||
Generated
+3
-3
@@ -6891,9 +6891,9 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/http-proxy-middleware": {
|
"node_modules/http-proxy-middleware": {
|
||||||
"version": "2.0.9",
|
"version": "2.0.7",
|
||||||
"resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz",
|
"resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.7.tgz",
|
||||||
"integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==",
|
"integrity": "sha512-fgVY8AV7qU7z/MmXJ/rxwbrtQH4jBQ9m7kp3llF0liB7glmFeVZFBepQb32T3y8n8k2+AEYuMPCpinYW+/CuRA==",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
|||||||
Reference in New Issue
Block a user