mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 14:19:55 +02:00
Compare commits
208 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5238d67846 | |||
| a9adb22458 | |||
| 2dc41af98d | |||
| 48125bd106 | |||
| 5285fdd0a0 | |||
| bf628db5c4 | |||
| d56b406e1a | |||
| b9dc5de708 | |||
| 9a4fb05491 | |||
| 3e2df49976 | |||
| 14eefb7a67 | |||
| 7a9c884333 | |||
| a8b5876883 | |||
| fbe9885c0b | |||
| 583eec1a67 | |||
| f19664f95c | |||
| b3ae0026fb | |||
| 8ddfec303f | |||
| c45778f196 | |||
| a5bdbe54a2 | |||
| 61da912f18 | |||
| a02aed2c2b | |||
| 40ff7f9dfb | |||
| c09ce32def | |||
| c5406e8a0e | |||
| b260672b1a | |||
| 0a07fc54d6 | |||
| 2f1151d44d | |||
| d0353e3ab9 | |||
| 926c583a17 | |||
| 17e34356e1 | |||
| 312fa756a5 | |||
| 145e7f81e1 | |||
| 04af7d24a1 | |||
| c5c5ae2e4b | |||
| 2bc0605a1d | |||
| 335787d40e | |||
| 1b211b5d76 | |||
| 444f908009 | |||
| f81dc508f9 | |||
| 4a55b99d70 | |||
| 5c2f9eba71 | |||
| aa2fe4d1ad | |||
| cf7c017621 | |||
| 73184e3454 | |||
| 3720ece2af | |||
| 0dc738a11e | |||
| 47ca656d59 | |||
| 4fa166298d | |||
| 77557ade85 | |||
| 5cdbf933de | |||
| 54d159a737 | |||
| 35fd373cb2 | |||
| f2b95a0040 | |||
| a8e80e85e1 | |||
| f97c3367b4 | |||
| c065818053 | |||
| 1139577eaa | |||
| 5d6a65350f | |||
| c277cca045 | |||
| fcbb832968 | |||
| a0e523758d | |||
| 5ebf428de6 | |||
| d5fe89f298 | |||
| 98b7d7f691 | |||
| c5ddcb2d75 | |||
| da63270142 | |||
| bf5f7a7dff | |||
| d3ccea76b6 | |||
| b7fef85750 | |||
| a1249cae12 | |||
| 8549aee952 | |||
| 414ee62467 | |||
| 7f68224716 | |||
| 3910bab28e | |||
| 8a4dcfd43e | |||
| 17234a846b | |||
| a51a3aa497 | |||
| 0b3424e9fd | |||
| f81b32d9b4 | |||
| a9f8090614 | |||
| 8770726f63 | |||
| ffc4f94a0a | |||
| 5edd4f0959 | |||
| e495f9626f | |||
| b45006c0d1 | |||
| d60d87f142 | |||
| 68f01622fc | |||
| 29787ae5fc | |||
| 1d0e88b001 | |||
| 8e5a53eaa3 | |||
| dcaba04dd6 | |||
| f4271ef2a1 | |||
| feb1becb3e | |||
| 7b44a2f510 | |||
| e3c3119790 | |||
| e171f0216e | |||
| 5d712ebce4 | |||
| 37a6e7a5bc | |||
| 85216ad106 | |||
| bb2e0e7517 | |||
| 8689efbe59 | |||
| 0b41fe0e3f | |||
| c3776df5c1 | |||
| 143ea4f8c1 | |||
| dd2eb1472f | |||
| 4332e4affd | |||
| e871443e76 | |||
| e9ae785625 | |||
| b1e2dc8cef | |||
| b9802fd268 | |||
| ac3f2f803c | |||
| bd6d2f3db1 | |||
| dda8d13b72 | |||
| 839c1af9d7 | |||
| e261fe55c5 | |||
| b4857a5f36 | |||
| 85d5c67047 | |||
| b4e0aef33f | |||
| 3a3ac4421c | |||
| d6ccd8f114 | |||
| 41ecc3c8e2 | |||
| a5f8da4139 | |||
| 9b47b72bb8 | |||
| 91ff3aa11f | |||
| 87264b793b | |||
| 2bc093cfc4 | |||
| 182b67de57 | |||
| 35540d0299 | |||
| 50ce80b149 | |||
| 07772e9779 | |||
| c7a6d16e00 | |||
| fe71ab0b70 | |||
| b6643d3289 | |||
| 3d6018302e | |||
| a54b11318a | |||
| 91c99e642f | |||
| 1226e2059d | |||
| c94aa54e41 | |||
| 431bf84956 | |||
| 9213b64de0 | |||
| 9d2885f7a6 | |||
| 4c0a57dbae | |||
| 6e6abae680 | |||
| fa7bd104a0 | |||
| 321c3dafc0 | |||
| eb14fe0f61 | |||
| b5985bf810 | |||
| 29362aed30 | |||
| c2671fd6f1 | |||
| 56984c7078 | |||
| f9dce3f915 | |||
| 6e99dfd54f | |||
| 67f7381ace | |||
| e2cb909329 | |||
| dbef9fcc03 | |||
| 4a07411fbd | |||
| 11951c0ca0 | |||
| 02222b75f8 | |||
| aa27817f94 | |||
| 5a4b5e11b2 | |||
| 8bd76b9f05 | |||
| 6f3c522d59 | |||
| 71787c6ec9 | |||
| f911912c0f | |||
| 896ca95ae2 | |||
| f85c77d622 | |||
| 684ba0b70d | |||
| 21b43b18e7 | |||
| 9340fa7223 | |||
| d20c1a3d0d | |||
| 812f9e7619 | |||
| ebac62e21a | |||
| ce2a791663 | |||
| db994fd483 | |||
| 7fe93e8a14 | |||
| a934ee243e | |||
| d646ecd61b | |||
| 21180b53e5 | |||
| a8808b3165 | |||
| 87c26ca3cc | |||
| e06c6932de | |||
| 51fcc38885 | |||
| 06a7bbfd87 | |||
| 50f3e52445 | |||
| 2bd62c21be | |||
| d5d5dd48aa | |||
| bb2437197a | |||
| 51bb79aa6d | |||
| 94f034fa9f | |||
| f69de8720b | |||
| 7c9d83b1a7 | |||
| a9d4d671ba | |||
| 554a219535 | |||
| 32e99006bf | |||
| 8c09d65687 | |||
| a2842755fa | |||
| b923f7fea5 | |||
| 7f30a8ff7a | |||
| 909cbd69b4 | |||
| 4f0ebf180e | |||
| 6be9673aa7 | |||
| bd9ed97d85 | |||
| 3c88a4d6ba | |||
| 2001eeb125 | |||
| a26b5dd448 | |||
| 716a0f67f3 | |||
| c1bbf6b422 |
@@ -1,5 +1,9 @@
|
||||
name: Pre-Commit Checks
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.8.5"
|
||||
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
@@ -15,7 +19,9 @@ jobs:
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
- name: Install poetry
|
||||
run: pipx install poetry==$POETRY_VERSION
|
||||
- name: Install pre-commit
|
||||
run: pip install pre-commit
|
||||
run: poetry install
|
||||
- name: Run pre-commit
|
||||
run: pre-commit run --all-files
|
||||
run: poetry run pre-commit run --all-files
|
||||
|
||||
@@ -9,7 +9,7 @@ on:
|
||||
- 0.*
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.7.1"
|
||||
POETRY_VERSION: "1.8.5"
|
||||
|
||||
jobs:
|
||||
if_release:
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
name: Security Scan
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
schedule:
|
||||
- cron: '0 0 * * 1' # Run weekly on Mondays
|
||||
workflow_dispatch: # Allow manual trigger
|
||||
|
||||
jobs:
|
||||
security_scan:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
API_KEY: PLACEHOLDER
|
||||
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install agentic-security colorama tabulate tqdm python-multipart
|
||||
|
||||
- name: Run security scan
|
||||
id: scan
|
||||
run: |
|
||||
agentic_security init
|
||||
# agentic_security ci
|
||||
@@ -1,14 +0,0 @@
|
||||
name: PyCharm Python Security Scanner
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
|
||||
jobs:
|
||||
security_checks:
|
||||
runs-on: ubuntu-latest
|
||||
name: Execute the pycharm-security action
|
||||
steps:
|
||||
- uses: actions/checkout@v1
|
||||
- name: PyCharm Python Security Scanner
|
||||
uses: tonybaloney/pycharm-security@1.19.0
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.7.1"
|
||||
POETRY_VERSION: "1.8.5"
|
||||
OPENAI_API_KEY: "sk-fake"
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -17,3 +17,9 @@ inv/
|
||||
scripts/
|
||||
docx/
|
||||
agentic_security.toml
|
||||
/venv
|
||||
*.csv
|
||||
agentic_security/agents/operator_agno.py
|
||||
.claude/
|
||||
plan.md
|
||||
auto_loop.sh
|
||||
|
||||
@@ -9,7 +9,7 @@ repos:
|
||||
args: [--py311-plus]
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.11.0
|
||||
rev: 25.11.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
@@ -20,12 +20,13 @@ repos:
|
||||
- id: flake8
|
||||
language_version: python3.11
|
||||
additional_dependencies: [flake8-docstrings]
|
||||
exclude: '^(tests)/'
|
||||
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
args: [--profile, black]
|
||||
# - repo: https://github.com/PyCQA/isort
|
||||
# rev: 7.0.0
|
||||
# hooks:
|
||||
# - id: isort
|
||||
# args: [--profile, black]
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
|
||||
+8
-1
@@ -1,5 +1,5 @@
|
||||
# Build stage
|
||||
FROM python:3.11-slim as builder
|
||||
FROM python:3.11-slim AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -14,8 +14,15 @@ RUN poetry self add "poetry-plugin-export"
|
||||
# Copy only dependency files to leverage Docker layer caching
|
||||
COPY pyproject.toml poetry.lock ./
|
||||
|
||||
# update lock file to avoid failure
|
||||
RUN poetry lock
|
||||
|
||||
# Install dependencies
|
||||
RUN poetry export -f requirements.txt --without-hashes -o requirements.txt
|
||||
|
||||
# Install wheel (required to build packages like fire)
|
||||
RUN pip install --upgrade pip setuptools wheel
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Runtime stage
|
||||
|
||||
@@ -21,9 +21,7 @@
|
||||
<a href="https://pypi.org/project/agentic-security/">
|
||||
<img alt="PyPI Version" src="https://img.shields.io/pypi/v/agentic-security?style=for-the-badge&logo=pypi&labelColor=000000&color=00CCFF" />
|
||||
</a>
|
||||
<a href="https://discord.gg/stw3DfZQ">
|
||||
<img alt="Join Discord" src="https://img.shields.io/badge/Discord-Join%20Us-black?style=for-the-badge&logo=discord&labelColor=000000&color=DD55FF" />
|
||||
</a>
|
||||
|
||||
</p>
|
||||
|
||||
|
||||
@@ -83,7 +81,7 @@ agentic_security --port=PORT --host=HOST
|
||||
|
||||
## UI 🧙
|
||||
|
||||
<img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1736433557/z0bsyzhsqlgcr3w4ovwp.gif">
|
||||
<img width="100%" alt="booking-screen" src="https://raw.githubusercontent.com/msoedov/agentic_security/refs/heads/main/docs/images/demo.gif">
|
||||
|
||||
## LLM kwargs
|
||||
|
||||
@@ -402,16 +400,31 @@ This setup ensures a continuous integration approach towards maintaining securit
|
||||
|
||||
The `Module` class is designed to manage prompt processing and interaction with external AI models and tools. It supports fetching, processing, and posting prompts asynchronously for model vulnerabilities. Check out [module.md](https://github.com/msoedov/agentic_security/blob/main/docs/module.md) for details.
|
||||
|
||||
|
||||
## MCP server
|
||||
|
||||
```shell
|
||||
pip install -U mcp
|
||||
|
||||
# From cloned directory
|
||||
mcp install agentic_security/mcp/main.py
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For more detailed information on how to use Agentic Security, including advanced features and customization options, please refer to the official documentation.
|
||||
|
||||
## Roadmap and Future Goals
|
||||
|
||||
- \[ \] Expand dataset variety
|
||||
- \[ \] Introduce two new attack vectors
|
||||
- \[ \] Develop initial attacker LLM
|
||||
- \[ \] Complete integration of OWASP Top 10 classification
|
||||
|
||||
|
||||
We’re just getting started! Here’s what’s on the horizon:
|
||||
|
||||
- **RL-Powered Attacks**: An attacker LLM trained with reinforcement learning to dynamically evolve jailbreaks and outsmart defenses.
|
||||
- **Massive Dataset Expansion**: Scaling to 100,000+ prompts across text, image, and audio modalities—curated for real-world threats.
|
||||
- **Daily Attack Updates**: Fresh attack vectors delivered daily, keeping your scans ahead of the curve.
|
||||
- **Community Modules**: A plug-and-play ecosystem where you can share and deploy custom probes, datasets, and integrations.
|
||||
|
||||
|
||||
| Tool | Source | Integrated |
|
||||
|-------------------------|-------------------------------------------------------------------------------|------------|
|
||||
@@ -423,6 +436,7 @@ For more detailed information on how to use Agentic Security, including advanced
|
||||
|
||||
Note: All dates are tentative and subject to change based on project progress and priorities.
|
||||
|
||||
|
||||
## 👋 Contributing
|
||||
|
||||
Contributions to Agentic Security are welcome! If you'd like to contribute, please follow these steps:
|
||||
@@ -439,4 +453,9 @@ Before contributing, please read the contributing guidelines.
|
||||
|
||||
Agentic Security is released under the Apache License v2.
|
||||
|
||||
|
||||
## 🚫 No Cryptocurrency Affiliation
|
||||
|
||||
Agentic Security is focused solely on AI security and has no affiliation with cryptocurrency projects, blockchain technologies, or related initiatives. Our mission is to advance the safety and reliability of AI systems—no tokens, no coins, just code.
|
||||
|
||||
## Contact us
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from .lib import AgenticSecurity
|
||||
from agentic_security.cache_config import ensure_cache_dir
|
||||
|
||||
__all__ = ["AgenticSecurity"]
|
||||
ensure_cache_dir()
|
||||
|
||||
from .lib import SecurityScanner # noqa: E402
|
||||
|
||||
__all__ = ["SecurityScanner", "ensure_cache_dir"]
|
||||
|
||||
@@ -5,7 +5,7 @@ import fire
|
||||
import uvicorn
|
||||
|
||||
from agentic_security.app import app
|
||||
from agentic_security.lib import AgenticSecurity
|
||||
from agentic_security.lib import SecurityScanner
|
||||
from agentic_security.misc.banner import init_banner
|
||||
|
||||
|
||||
@@ -32,14 +32,14 @@ class CLI:
|
||||
Run Agentic Security in CI mode.
|
||||
"""
|
||||
sys.path.append(os.path.dirname("."))
|
||||
AgenticSecurity().entrypoint()
|
||||
SecurityScanner().entrypoint()
|
||||
|
||||
def init(self, host: str = "0.0.0.0", port: int = 8718):
|
||||
"""
|
||||
Generate the default CI configuration file.
|
||||
"""
|
||||
sys.path.append(os.path.dirname("."))
|
||||
AgenticSecurity().generate_default_settings(host, port)
|
||||
SecurityScanner().generate_default_settings(host, port)
|
||||
|
||||
i = init
|
||||
|
||||
@@ -48,7 +48,7 @@ class CLI:
|
||||
List all available security checks.
|
||||
"""
|
||||
sys.path.append(os.path.dirname("."))
|
||||
AgenticSecurity().list_checks()
|
||||
SecurityScanner().list_checks()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
@@ -10,12 +9,11 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# Assuming LLMSpec is defined elsewhere; placeholder import
|
||||
from agentic_security.http_spec import LLMSpec
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
LLM_SPECS = [] # Populate with LLM spec strings if needed
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define AgentSpecification model
|
||||
@@ -248,9 +246,9 @@ async def run_crew():
|
||||
os.environ["OPENAI_API_KEY"] = os.environ.get(
|
||||
"DEEPSEEK_API_KEY", ""
|
||||
) # CrewAI uses OPENAI_API_KEY
|
||||
os.environ[
|
||||
"OPENAI_MODEL_NAME"
|
||||
] = "deepseek:chat" # Specify DeepSeek model (adjust if needed)
|
||||
os.environ["OPENAI_MODEL_NAME"] = (
|
||||
"deepseek:chat" # Specify DeepSeek model (adjust if needed)
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run_crew())
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -8,13 +7,10 @@ from pydantic_ai import Agent, RunContext, Tool
|
||||
|
||||
# Assuming LLMSpec is defined elsewhere; placeholder import
|
||||
from agentic_security.http_spec import LLMSpec
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
LLM_SPECS = [] # Populate this list with LLM spec strings if needed
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define AgentSpecification model
|
||||
class AgentSpecification(BaseModel):
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Utilities to keep cache-to-disk storage in a writable, predictable location."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def ensure_cache_dir(base_dir: Path | None = None) -> Path:
|
||||
"""Ensure ``DISK_CACHE_DIR`` points to a writable directory and create it if needed."""
|
||||
env_var = "DISK_CACHE_DIR"
|
||||
configured_path = os.environ.get(env_var) or os.environ.get(
|
||||
"AGENTIC_SECURITY_CACHE_DIR"
|
||||
)
|
||||
cache_dir = Path(
|
||||
configured_path or base_dir or Path.cwd() / ".cache" / "agentic_security"
|
||||
).expanduser()
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
os.environ[env_var] = str(cache_dir)
|
||||
return cache_dir
|
||||
|
||||
|
||||
__all__ = ["ensure_cache_dir"]
|
||||
@@ -1,9 +1,10 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import tomli
|
||||
from loguru import logger
|
||||
|
||||
SETTINGS_VERSION = 1
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
SETTINGS_VERSION = 2
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@@ -142,6 +143,13 @@ use_disk_cache = false
|
||||
retry = 3
|
||||
timeout_connect = 30
|
||||
timeout_response = 90
|
||||
|
||||
[fuzzer]
|
||||
max_prompt_lenght = 2048
|
||||
budget_multiplier = 100000000
|
||||
initial_optimizer_points = 25
|
||||
min_failure_samples = 5
|
||||
failure_rate_threshold = 0.5
|
||||
""".replace(
|
||||
"$HOST", host
|
||||
)
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
import os
|
||||
from asyncio import Event, Queue
|
||||
from typing import TypedDict
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import ORJSONResponse
|
||||
|
||||
from agentic_security.http_spec import LLMSpec
|
||||
|
||||
|
||||
class CurrentRun(TypedDict):
|
||||
id: int | None
|
||||
spec: LLMSpec | None
|
||||
|
||||
|
||||
tools_inbox: Queue = Queue()
|
||||
stop_event: Event = Event()
|
||||
current_run: str = {"spec": "", "id": ""}
|
||||
_secrets = {}
|
||||
current_run: CurrentRun = {"spec": None, "id": None}
|
||||
_secrets: dict[str, str] = {}
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -26,30 +35,34 @@ def get_stop_event() -> Event:
|
||||
return stop_event
|
||||
|
||||
|
||||
def get_current_run() -> str:
|
||||
def get_current_run() -> CurrentRun:
|
||||
"""Get the current run id."""
|
||||
return current_run
|
||||
|
||||
|
||||
def set_current_run(spec):
|
||||
"""Set the current run id."""
|
||||
def set_current_run(spec: LLMSpec) -> CurrentRun:
|
||||
"""Set the current run metadata based on a spec instance."""
|
||||
current_run["id"] = hash(id(spec))
|
||||
current_run["spec"] = spec
|
||||
return current_run
|
||||
|
||||
|
||||
def get_secrets():
|
||||
def get_secrets() -> dict[str, str]:
|
||||
return _secrets
|
||||
|
||||
|
||||
def set_secrets(secrets):
|
||||
def set_secrets(secrets: dict[str, str]) -> dict[str, str]:
|
||||
_secrets.update(secrets)
|
||||
expand_secrets(_secrets)
|
||||
return _secrets
|
||||
|
||||
|
||||
def expand_secrets(secrets):
|
||||
def expand_secrets(secrets: dict[str, str]) -> None:
|
||||
for key in secrets:
|
||||
val = secrets[key]
|
||||
if val.startswith("$"):
|
||||
secrets[key] = os.getenv(val.strip("$"))
|
||||
env_value = os.getenv(val.strip("$"))
|
||||
if env_value is not None:
|
||||
secrets[key] = env_value
|
||||
else:
|
||||
secrets[key] = None
|
||||
|
||||
@@ -1,26 +1,5 @@
|
||||
from logging import config
|
||||
from agentic_security.logutils import set_log_level_to_info
|
||||
|
||||
|
||||
def setup_logging():
|
||||
config.dictConfig(
|
||||
{
|
||||
"version": 1,
|
||||
"disable_existing_loggers": True,
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
},
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"level": "INFO",
|
||||
},
|
||||
"loggers": {
|
||||
"uvicorn.access": {
|
||||
"level": "ERROR", # Set higher log level to suppress info logs globally
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
return set_log_level_to_info()
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Advanced concurrent execution package for security scanning."""
|
||||
|
||||
from agentic_security.executor.rate_limiter import TokenBucketRateLimiter
|
||||
from agentic_security.executor.circuit_breaker import CircuitBreaker
|
||||
from agentic_security.executor.concurrent import ConcurrentExecutor, ExecutorMetrics
|
||||
|
||||
__all__ = [
|
||||
"TokenBucketRateLimiter",
|
||||
"CircuitBreaker",
|
||||
"ConcurrentExecutor",
|
||||
"ExecutorMetrics",
|
||||
]
|
||||
@@ -0,0 +1,109 @@
|
||||
"""Circuit breaker pattern for fault tolerance."""
|
||||
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
|
||||
CircuitState = Literal["closed", "open", "half_open"]
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""Circuit breaker to prevent cascading failures.
|
||||
|
||||
Implements the circuit breaker pattern with three states:
|
||||
- closed: Normal operation, requests pass through
|
||||
- open: Failure threshold exceeded, requests fail fast
|
||||
- half_open: Recovery attempt, limited requests allowed
|
||||
|
||||
Example:
|
||||
>>> breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
|
||||
>>> if breaker.is_open():
|
||||
... raise Exception("Circuit breaker is open")
|
||||
>>> try:
|
||||
... result = make_request()
|
||||
... breaker.record_success()
|
||||
>>> except Exception:
|
||||
... breaker.record_failure()
|
||||
"""
|
||||
|
||||
def __init__(self, failure_threshold: float = 0.5, recovery_timeout: int = 30):
|
||||
"""Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
failure_threshold: Failure rate (0.0-1.0) that triggers open state
|
||||
recovery_timeout: Seconds to wait before attempting recovery
|
||||
"""
|
||||
self.failure_threshold = failure_threshold
|
||||
self.recovery_timeout = recovery_timeout
|
||||
self.failures = 0
|
||||
self.successes = 0
|
||||
self.state: CircuitState = "closed"
|
||||
self.last_failure_time: float | None = None
|
||||
|
||||
def record_success(self):
|
||||
"""Record a successful request."""
|
||||
self.successes += 1
|
||||
|
||||
# If in half_open state and we have enough successes, close the circuit
|
||||
if self.state == "half_open" and self.successes >= 3:
|
||||
self.state = "closed"
|
||||
self.failures = 0
|
||||
self.successes = 0
|
||||
|
||||
def record_failure(self):
|
||||
"""Record a failed request."""
|
||||
self.failures += 1
|
||||
self.last_failure_time = time.monotonic()
|
||||
|
||||
total = self.failures + self.successes
|
||||
|
||||
# Need minimum sample size before opening circuit
|
||||
if total >= 10:
|
||||
failure_rate = self.failures / total
|
||||
if failure_rate >= self.failure_threshold:
|
||||
self.state = "open"
|
||||
|
||||
def is_open(self) -> bool:
|
||||
"""Check if circuit breaker is open.
|
||||
|
||||
Returns:
|
||||
bool: True if circuit is open and requests should be blocked
|
||||
"""
|
||||
if self.state == "open":
|
||||
# Check if we should attempt recovery
|
||||
if self.last_failure_time is not None:
|
||||
if time.monotonic() - self.last_failure_time > self.recovery_timeout:
|
||||
self.state = "half_open"
|
||||
# Reset counters for half-open state
|
||||
self.failures = 0
|
||||
self.successes = 0
|
||||
return False
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_state(self) -> CircuitState:
|
||||
"""Get current circuit breaker state.
|
||||
|
||||
Returns:
|
||||
CircuitState: Current state (closed, open, or half_open)
|
||||
"""
|
||||
return self.state
|
||||
|
||||
def get_failure_rate(self) -> float:
|
||||
"""Get current failure rate.
|
||||
|
||||
Returns:
|
||||
float: Failure rate (0.0-1.0), or 0.0 if no requests recorded
|
||||
"""
|
||||
total = self.failures + self.successes
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self.failures / total
|
||||
|
||||
def reset(self):
|
||||
"""Reset circuit breaker to initial state."""
|
||||
self.failures = 0
|
||||
self.successes = 0
|
||||
self.state = "closed"
|
||||
self.last_failure_time = None
|
||||
@@ -0,0 +1,236 @@
|
||||
"""Concurrent executor with rate limiting and circuit breaking."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from agentic_security.executor.rate_limiter import TokenBucketRateLimiter
|
||||
from agentic_security.executor.circuit_breaker import CircuitBreaker
|
||||
from agentic_security.logutils import logger
|
||||
from agentic_security.probe_actor.state import FuzzerState
|
||||
|
||||
|
||||
class ExecutorMetrics:
|
||||
"""Track executor performance metrics."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize metrics tracking."""
|
||||
self.successful_requests = 0
|
||||
self.failed_requests = 0
|
||||
self.total_latency = 0.0
|
||||
self.latencies: list[float] = []
|
||||
|
||||
def record_success(self, latency: float):
|
||||
"""Record a successful request.
|
||||
|
||||
Args:
|
||||
latency: Request latency in seconds
|
||||
"""
|
||||
self.successful_requests += 1
|
||||
self.total_latency += latency
|
||||
self.latencies.append(latency)
|
||||
|
||||
def record_failure(self):
|
||||
"""Record a failed request."""
|
||||
self.failed_requests += 1
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get current statistics.
|
||||
|
||||
Returns:
|
||||
dict: Statistics including total requests, success rate, latency metrics
|
||||
"""
|
||||
total_requests = self.successful_requests + self.failed_requests
|
||||
|
||||
if total_requests == 0:
|
||||
return {
|
||||
"total_requests": 0,
|
||||
"success_rate": 0.0,
|
||||
"avg_latency_ms": 0.0,
|
||||
"p95_latency_ms": 0.0,
|
||||
}
|
||||
|
||||
success_rate = self.successful_requests / total_requests
|
||||
avg_latency_ms = (
|
||||
(self.total_latency / self.successful_requests * 1000)
|
||||
if self.successful_requests > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
# Calculate p95 latency
|
||||
if self.latencies:
|
||||
sorted_latencies = sorted(self.latencies)
|
||||
p95_index = int(len(sorted_latencies) * 0.95)
|
||||
p95_latency_ms = (
|
||||
sorted_latencies[p95_index] * 1000
|
||||
if p95_index < len(sorted_latencies)
|
||||
else 0.0
|
||||
)
|
||||
else:
|
||||
p95_latency_ms = 0.0
|
||||
|
||||
return {
|
||||
"total_requests": total_requests,
|
||||
"successful_requests": self.successful_requests,
|
||||
"failed_requests": self.failed_requests,
|
||||
"success_rate": success_rate,
|
||||
"avg_latency_ms": avg_latency_ms,
|
||||
"p95_latency_ms": p95_latency_ms,
|
||||
}
|
||||
|
||||
|
||||
class ConcurrentExecutor:
|
||||
"""Enhanced concurrent executor with rate limiting and circuit breaking.
|
||||
|
||||
Provides advanced concurrency control for security scanning with:
|
||||
- Token bucket rate limiting
|
||||
- Circuit breaker for fault tolerance
|
||||
- Metrics collection
|
||||
- Semaphore-based concurrency limits
|
||||
|
||||
Example:
|
||||
>>> executor = ConcurrentExecutor(max_concurrent=20, rate_limit=10, burst=5)
|
||||
>>> tokens, failures = await executor.execute_batch(
|
||||
... request_factory, prompts, "module_name", fuzzer_state
|
||||
... )
|
||||
>>> print(executor.metrics.get_stats())
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_concurrent: int = 50,
|
||||
rate_limit: float = 100,
|
||||
burst: int = 20,
|
||||
failure_threshold: float = 0.5,
|
||||
recovery_timeout: int = 30,
|
||||
):
|
||||
"""Initialize concurrent executor.
|
||||
|
||||
Args:
|
||||
max_concurrent: Maximum number of concurrent requests
|
||||
rate_limit: Requests per second limit
|
||||
burst: Maximum burst size for rate limiter
|
||||
failure_threshold: Failure rate that triggers circuit breaker
|
||||
recovery_timeout: Seconds before attempting circuit recovery
|
||||
"""
|
||||
self.semaphore = asyncio.Semaphore(max_concurrent)
|
||||
self.rate_limiter = TokenBucketRateLimiter(rate_limit, burst)
|
||||
self.circuit_breaker = CircuitBreaker(failure_threshold, recovery_timeout)
|
||||
self.metrics = ExecutorMetrics()
|
||||
|
||||
logger.info(
|
||||
f"ConcurrentExecutor initialized: max_concurrent={max_concurrent}, "
|
||||
f"rate_limit={rate_limit}/s, burst={burst}"
|
||||
)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
request_factory,
|
||||
prompts: list[str],
|
||||
module_name: str,
|
||||
fuzzer_state: FuzzerState,
|
||||
) -> tuple[int, int]:
|
||||
"""Execute a batch of prompts with rate limiting and circuit breaking.
|
||||
|
||||
This is compatible with the existing process_prompt_batch signature.
|
||||
|
||||
Args:
|
||||
request_factory: Request factory with fn() method
|
||||
prompts: List of prompts to process
|
||||
module_name: Name of the module being scanned
|
||||
fuzzer_state: State tracking object
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: (total_tokens, failures)
|
||||
"""
|
||||
tasks = [
|
||||
self._execute_single(request_factory, prompt, module_name, fuzzer_state)
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Aggregate results
|
||||
total_tokens = 0
|
||||
failures = 0
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
failures += 1
|
||||
logger.error(f"Task failed with exception: {result}")
|
||||
else:
|
||||
tokens, refused = result
|
||||
total_tokens += tokens
|
||||
if refused:
|
||||
failures += 1
|
||||
|
||||
return total_tokens, failures
|
||||
|
||||
async def _execute_single(
|
||||
self,
|
||||
request_factory,
|
||||
prompt: str,
|
||||
module_name: str,
|
||||
fuzzer_state: FuzzerState,
|
||||
) -> tuple[int, bool]:
|
||||
"""Execute a single prompt with rate limiting and circuit breaking.
|
||||
|
||||
Args:
|
||||
request_factory: Request factory with fn() method
|
||||
prompt: Prompt to process
|
||||
module_name: Name of the module being scanned
|
||||
fuzzer_state: State tracking object
|
||||
|
||||
Returns:
|
||||
tuple[int, bool]: (tokens, refused)
|
||||
|
||||
Raises:
|
||||
Exception: If circuit breaker is open
|
||||
"""
|
||||
# Rate limiting
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
# Circuit breaker check
|
||||
if self.circuit_breaker.is_open():
|
||||
self.metrics.record_failure()
|
||||
raise Exception("Circuit breaker is open - too many failures")
|
||||
|
||||
# Concurrency control
|
||||
async with self.semaphore:
|
||||
start_time = time.monotonic()
|
||||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from agentic_security.probe_actor.fuzzer import process_prompt
|
||||
|
||||
tokens = 0 # Initial token count for this prompt
|
||||
result = await process_prompt(
|
||||
request_factory, prompt, tokens, module_name, fuzzer_state
|
||||
)
|
||||
|
||||
# Record success
|
||||
self.circuit_breaker.record_success()
|
||||
latency = time.monotonic() - start_time
|
||||
self.metrics.record_success(latency)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# Record failure
|
||||
self.circuit_breaker.record_failure()
|
||||
self.metrics.record_failure()
|
||||
logger.error(f"Error executing prompt: {e}")
|
||||
raise
|
||||
|
||||
def get_metrics(self) -> dict[str, Any]:
|
||||
"""Get current executor metrics.
|
||||
|
||||
Returns:
|
||||
dict: Metrics including request stats, latency, and circuit breaker state
|
||||
"""
|
||||
stats = self.metrics.get_stats()
|
||||
stats["circuit_breaker_state"] = self.circuit_breaker.get_state()
|
||||
stats["circuit_breaker_failure_rate"] = self.circuit_breaker.get_failure_rate()
|
||||
stats["available_tokens"] = self.rate_limiter.get_available_tokens()
|
||||
|
||||
return stats
|
||||
@@ -0,0 +1,63 @@
|
||||
"""Token bucket rate limiter for controlling request rate."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
class TokenBucketRateLimiter:
|
||||
"""Token bucket rate limiter with configurable rate and burst capacity.
|
||||
|
||||
This implements the token bucket algorithm where tokens are added at a fixed
|
||||
rate and consumed for each request. Supports bursting up to the bucket capacity.
|
||||
|
||||
Example:
|
||||
>>> limiter = TokenBucketRateLimiter(rate=10, burst=20)
|
||||
>>> await limiter.acquire() # Will wait if no tokens available
|
||||
"""
|
||||
|
||||
def __init__(self, rate: float, burst: int):
|
||||
"""Initialize rate limiter.
|
||||
|
||||
Args:
|
||||
rate: Tokens added per second (requests/sec)
|
||||
burst: Maximum bucket capacity (max concurrent burst)
|
||||
"""
|
||||
self.rate = rate
|
||||
self.burst = burst
|
||||
self.tokens = float(burst)
|
||||
self.last_update = time.monotonic()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def acquire(self):
|
||||
"""Acquire a token, waiting if necessary.
|
||||
|
||||
This method will block until a token is available.
|
||||
"""
|
||||
async with self._lock:
|
||||
now = time.monotonic()
|
||||
elapsed = now - self.last_update
|
||||
|
||||
# Add tokens based on elapsed time
|
||||
self.tokens = min(self.burst, self.tokens + elapsed * self.rate)
|
||||
self.last_update = now
|
||||
|
||||
if self.tokens >= 1:
|
||||
# Token available, consume it
|
||||
self.tokens -= 1
|
||||
return
|
||||
|
||||
# Need to wait for next token
|
||||
wait_time = (1 - self.tokens) / self.rate
|
||||
await asyncio.sleep(wait_time)
|
||||
self.tokens = 0
|
||||
self.last_update = time.monotonic()
|
||||
|
||||
def get_available_tokens(self) -> float:
|
||||
"""Get current number of available tokens (non-blocking).
|
||||
|
||||
Returns:
|
||||
float: Number of tokens currently available
|
||||
"""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self.last_update
|
||||
return min(self.burst, self.tokens + elapsed * self.rate)
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
from enum import Enum
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
@@ -68,7 +69,9 @@ class LLMSpec(BaseModel):
|
||||
|
||||
return response
|
||||
|
||||
def validate(self, prompt, encoded_image, encoded_audio, files) -> None:
|
||||
def validate(
|
||||
self, prompt: str, encoded_image: str, encoded_audio: str, files: dict | None
|
||||
) -> None:
|
||||
if self.has_files and not files:
|
||||
raise ValueError("Files are required for this request.")
|
||||
|
||||
@@ -79,7 +82,11 @@ class LLMSpec(BaseModel):
|
||||
raise ValueError("Audio is required for this request.")
|
||||
|
||||
async def probe(
|
||||
self, prompt: str, encoded_image: str = "", encoded_audio: str = "", files={}
|
||||
self,
|
||||
prompt: str,
|
||||
encoded_image: str = "",
|
||||
encoded_audio: str = "",
|
||||
files: dict | None = None,
|
||||
) -> httpx.Response:
|
||||
"""Sends an HTTP request using the `httpx` library.
|
||||
|
||||
@@ -154,27 +161,52 @@ def parse_http_spec(http_spec: str) -> LLMSpec:
|
||||
secrets = get_secrets()
|
||||
|
||||
# Split the spec by lines
|
||||
lines = http_spec.strip().split("\n")
|
||||
lines = http_spec.strip("\n").splitlines()
|
||||
if not lines:
|
||||
raise InvalidHTTPSpecError("HTTP spec is empty.")
|
||||
|
||||
# Extract the method and URL from the first line
|
||||
method, url = lines[0].split(" ")[0:2]
|
||||
request_line_parts = lines[0].split()
|
||||
if len(request_line_parts) < 2:
|
||||
raise InvalidHTTPSpecError(
|
||||
"First line of HTTP spec must include the method and URL."
|
||||
)
|
||||
method, url = request_line_parts[0], request_line_parts[1]
|
||||
|
||||
# Check url validity
|
||||
valid_url = urlparse(url)
|
||||
# if missing the correct formatting ://, urlparse.netloc will be empty
|
||||
if valid_url.scheme not in ("http", "https") or not valid_url.netloc:
|
||||
raise InvalidHTTPSpecError(
|
||||
f"Invalid URL: {url}. Ensure it starts with 'http://' or 'https://'"
|
||||
)
|
||||
|
||||
# Initialize headers and body
|
||||
headers = {}
|
||||
body = ""
|
||||
body_lines: list[str] = []
|
||||
|
||||
# Iterate over the remaining lines
|
||||
reading_headers = True
|
||||
for line in lines[1:]:
|
||||
if line == "":
|
||||
reading_headers = False
|
||||
if line.strip() == "":
|
||||
if reading_headers:
|
||||
reading_headers = False
|
||||
continue
|
||||
body_lines.append("")
|
||||
continue
|
||||
|
||||
if reading_headers:
|
||||
key, value = line.split(": ")
|
||||
if ":" not in line:
|
||||
raise InvalidHTTPSpecError(f"Invalid header line: '{line}'")
|
||||
key, value = line.split(":", maxsplit=1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if not key:
|
||||
raise InvalidHTTPSpecError("Header name cannot be empty.")
|
||||
headers[key] = value
|
||||
else:
|
||||
body += line
|
||||
body_lines.append(line)
|
||||
body = "\n".join(body_lines)
|
||||
has_files = "multipart/form-data" in headers.get("Content-Type", "")
|
||||
has_image = "<<BASE64_IMAGE>>" in body
|
||||
has_audio = "<<BASE64_AUDIO>>" in body
|
||||
|
||||
@@ -5,8 +5,6 @@ from typing import Protocol
|
||||
class IntegrationProto(Protocol):
|
||||
def __init__(
|
||||
self, prompt_groups: list, tools_inbox: asyncio.Queue, opts: dict = {}
|
||||
):
|
||||
...
|
||||
): ...
|
||||
|
||||
async def apply(self) -> list:
|
||||
...
|
||||
async def apply(self) -> list: ...
|
||||
|
||||
+16
-10
@@ -1,15 +1,16 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import colorama
|
||||
import tqdm.asyncio
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from tabulate import tabulate
|
||||
|
||||
from agentic_security.config import SettingsMixin # Importing the configuration mixin
|
||||
from agentic_security.logutils import logger
|
||||
from agentic_security.primitives import Scan
|
||||
from agentic_security.probe_data import REGISTRY
|
||||
from agentic_security.routes.scan import streaming_response_generator
|
||||
@@ -23,18 +24,20 @@ YELLOW = colorama.Fore.YELLOW
|
||||
BLUE = colorama.Fore.BLUE
|
||||
|
||||
|
||||
class AgenticSecurity(SettingsMixin):
|
||||
class SecurityScanner(SettingsMixin):
|
||||
@classmethod
|
||||
async def async_scan(
|
||||
cls,
|
||||
llmSpec: str,
|
||||
maxBudget: int,
|
||||
datasets: list[dict],
|
||||
datasets: list[dict] | None,
|
||||
max_th: float,
|
||||
optimize: bool = False,
|
||||
enableMultiStepAttack: bool = False,
|
||||
probe_datasets: list[dict] = [],
|
||||
probe_datasets: list[dict] | None = None,
|
||||
):
|
||||
datasets = copy.deepcopy(datasets) if datasets is not None else []
|
||||
probe_datasets = copy.deepcopy(probe_datasets or [])
|
||||
start_time = datetime.now()
|
||||
total_modules = len(datasets)
|
||||
completed_modules = 0
|
||||
@@ -170,15 +173,18 @@ class AgenticSecurity(SettingsMixin):
|
||||
cls,
|
||||
llmSpec: str,
|
||||
maxBudget: int = 1_000_000,
|
||||
datasets: list[dict] = REGISTRY,
|
||||
datasets: list[dict] | None = None,
|
||||
max_th: float = 0.3,
|
||||
optimize: bool = False,
|
||||
enableMultiStepAttack: bool = False,
|
||||
probe_datasets: list[dict] = [],
|
||||
only: list[str] = [],
|
||||
probe_datasets: list[dict] | None = None,
|
||||
only: list[str] | None = None,
|
||||
):
|
||||
if only:
|
||||
datasets = [d for d in datasets if d["dataset_name"] in only]
|
||||
datasets = copy.deepcopy(datasets or REGISTRY)
|
||||
probe_datasets = copy.deepcopy(probe_datasets or [])
|
||||
only_set = set(only) if only else None
|
||||
if only_set is not None:
|
||||
datasets = [d for d in datasets if d.get("dataset_name") in only_set]
|
||||
for d in datasets:
|
||||
d["selected"] = True
|
||||
return asyncio.run(
|
||||
@@ -202,7 +208,7 @@ class AgenticSecurity(SettingsMixin):
|
||||
self.load_config(self.default_path)
|
||||
logger.info("Configuration loaded successfully.")
|
||||
print(self.config)
|
||||
datasets = list(self.get_config_value("modules").values())
|
||||
datasets = list(self.get_config_value("modules", {}).values())
|
||||
for d in datasets:
|
||||
d["selected"] = True
|
||||
self.scan(
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
# import sys
|
||||
|
||||
# from loguru import logger
|
||||
|
||||
# # Define custom colors
|
||||
# BLUE = "#89CFF0"
|
||||
# BROWN = "#8B4513" # Brown for DEBUG
|
||||
|
||||
# # Define custom log level colors
|
||||
# logger.level("DEBUG", color=f"<fg {BROWN}>")
|
||||
# logger.level("INFO", color=f"<fg {BLUE}>")
|
||||
|
||||
# # Define custom log format with aligned messages and colored levels
|
||||
# LOG_FORMAT = (
|
||||
# "<level>{level:<8}</level> " # Properly formatted and colored log level
|
||||
# "<level>{message:<100}</level> " # Left-aligned message for readability
|
||||
# "<cyan>{file.name}</cyan>:<cyan>{line}</cyan>" # File name and line number in cyan
|
||||
# )
|
||||
|
||||
# # Remove default handlers and add a new one with custom formatting
|
||||
# logger.remove()
|
||||
# logger.add(sys.stdout, format=LOG_FORMAT, level="DEBUG", colorize=True)
|
||||
import logging
|
||||
import logging.config
|
||||
import time
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import wraps
|
||||
from os import getenv
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
LOGGER_NAME = None
|
||||
|
||||
LOGGING_CONFIG = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"rich": {"format": "%(message)s", "datefmt": "[%X]"},
|
||||
},
|
||||
"handlers": {
|
||||
"rich": {
|
||||
"class": "rich.logging.RichHandler",
|
||||
"level": "INFO",
|
||||
"formatter": "rich",
|
||||
"show_time": False,
|
||||
"rich_tracebacks": False,
|
||||
"show_path": lambda: True if getenv("API_RUNTIME") == "dev" else False,
|
||||
"tracebacks_show_locals": False,
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"": { # Root logger configuration
|
||||
"level": "INFO",
|
||||
"handlers": ["rich"],
|
||||
"propagate": True,
|
||||
},
|
||||
"httpx": { # Disable httpx logging
|
||||
"level": "WARNING", # Suppress DEBUG and INFO messages from httpx
|
||||
"handlers": [],
|
||||
"propagate": False,
|
||||
},
|
||||
"uvicorn.access": { # Disable uvicorn.access logging
|
||||
"level": "WARNING", # Suppress DEBUG and INFO messages from uvicorn.access
|
||||
"handlers": [],
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def configure_logging():
|
||||
# Apply the dictionary configuration
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
|
||||
# Get and return the logger
|
||||
logger = logging.getLogger(LOGGER_NAME)
|
||||
return logger
|
||||
|
||||
|
||||
logger: logging.Logger = configure_logging()
|
||||
|
||||
|
||||
def set_log_level_to_debug():
|
||||
logger = logging.getLogger(LOGGER_NAME)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
# Update handler level as well
|
||||
for handler in logger.handlers:
|
||||
handler.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def set_log_level_to_info():
|
||||
logger = logging.getLogger(LOGGER_NAME)
|
||||
logger.setLevel(logging.INFO)
|
||||
# Update handler level as well
|
||||
for handler in logger.handlers:
|
||||
handler.setLevel(logging.INFO)
|
||||
|
||||
|
||||
# Set initial log level
|
||||
set_log_level_to_info()
|
||||
|
||||
|
||||
# Define generic type variables for return type and parameters
|
||||
R = TypeVar("R")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def time_execution_sync(
|
||||
additional_text: str = "",
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
execution_time = time.time() - start_time
|
||||
logger.debug(
|
||||
f"{additional_text} Execution time: {execution_time:.2f} seconds"
|
||||
)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def time_execution_async(
|
||||
additional_text: str = "",
|
||||
) -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
|
||||
]:
|
||||
def decorator(
|
||||
func: Callable[P, Coroutine[Any, Any, R]],
|
||||
) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.time()
|
||||
result = await func(*args, **kwargs)
|
||||
execution_time = time.time() - start_time
|
||||
logger.debug(
|
||||
f"{additional_text} Execution time: {execution_time:.2f} seconds"
|
||||
)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
# Create server parameters for stdio connection
|
||||
server_params = StdioServerParameters(
|
||||
command="python", # Executable
|
||||
args=["agentic_security/mcp/main.py"], # Your server script
|
||||
env=None, # Optional environment variables
|
||||
)
|
||||
|
||||
|
||||
async def run() -> None:
|
||||
try:
|
||||
logger.info(
|
||||
"Starting stdio client session with server parameters: %s", server_params
|
||||
)
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection --> connection does not work
|
||||
logger.info("Initializing client session...")
|
||||
await session.initialize()
|
||||
|
||||
# List available prompts, resources, and tools --> no avalialbe tools
|
||||
logger.info("Listing available prompts...")
|
||||
prompts = await session.list_prompts()
|
||||
logger.info(f"Available prompts: {prompts}")
|
||||
|
||||
logger.info("Listing available resources...")
|
||||
resources = await session.list_resources()
|
||||
logger.info(f"Available resources: {resources}")
|
||||
|
||||
logger.info("Listing available tools...")
|
||||
tools = await session.list_tools()
|
||||
logger.info(f"Available tools: {tools}")
|
||||
|
||||
# Call the echo tool --> echo tool issue
|
||||
logger.info("Calling echo_tool with message...")
|
||||
echo_result = await session.call_tool(
|
||||
"echo_tool", arguments={"message": "Hello from client!"}
|
||||
)
|
||||
logger.info(f"Tool result: {echo_result}")
|
||||
|
||||
# # Read the echo resource
|
||||
# echo_content, mime_type = await session.read_resource(
|
||||
# "echo://Hello_resource"
|
||||
# )
|
||||
# logger.info(f"Resource content: {echo_content}")
|
||||
# logger.info(f"Resource MIME type: {mime_type}")
|
||||
|
||||
# # Get and use the echo prompt
|
||||
# prompt_result = await session.get_prompt(
|
||||
# "echo_prompt", arguments={"message": "Hello prompt!"}
|
||||
# )
|
||||
# logger.info(f"Prompt result: {prompt_result}")
|
||||
|
||||
logger.info("Client operations completed successfully.")
|
||||
return prompts, resources, tools
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred during client operations: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run())
|
||||
@@ -0,0 +1,108 @@
|
||||
import httpx
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Initialize MCP server
|
||||
mcp = FastMCP(
|
||||
name="Agentic Security MCP Server",
|
||||
dependencies=["httpx"],
|
||||
)
|
||||
|
||||
# FastAPI Server Configuration
|
||||
AGENTIC_SECURITY = "http://0.0.0.0:8718"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def verify_llm(spec: str) -> dict:
|
||||
"""
|
||||
Verify an LLM model specification using the FastAPI server
|
||||
|
||||
Returns:
|
||||
dict: containing the verification result form the FastAPI server
|
||||
|
||||
Args: spect(str): The specification of the LLM model to verify.
|
||||
|
||||
"""
|
||||
url = f"{AGENTIC_SECURITY}/verify"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, json={"spec": spec})
|
||||
return response.json()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def start_scan(
|
||||
llmSpec: str,
|
||||
maxBudget: int,
|
||||
optimize: bool = False,
|
||||
enableMultiStepAttack: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Start an LLM security scan via the FastAPI server.
|
||||
Returns:
|
||||
dict: The scan initiation result from the FastAPI server.
|
||||
|
||||
Args:
|
||||
llmSpec (str): The specification of the LLM model.
|
||||
maxBudget (int): The maximum budget for the scan.
|
||||
optimize (bool, optional): Whether to enable optimization during scanning. Defaults to False.
|
||||
enableMultiStepAttack (bool, optional): Whether to enable multi-step attack
|
||||
|
||||
"""
|
||||
url = f"{AGENTIC_SECURITY}/scan"
|
||||
payload = {
|
||||
"llmSpec": llmSpec,
|
||||
"maxBudget": maxBudget,
|
||||
"datasets": [],
|
||||
"optimize": optimize,
|
||||
"enableMultiStepAttack": enableMultiStepAttack,
|
||||
"probe_datasets": [],
|
||||
"secrets": {},
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, json=payload)
|
||||
return response.json()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def stop_scan() -> dict:
|
||||
"""Stop an ongoing scan via the FastAPI server.
|
||||
|
||||
Returns:
|
||||
dict: The confirmation from the FastAPI server that the scan has been stopped.
|
||||
"""
|
||||
url = f"{AGENTIC_SECURITY}/stop"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url)
|
||||
return response.json()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_data_config() -> list:
|
||||
"""
|
||||
Retrieve data configuration from the FastAPI server.
|
||||
|
||||
Returns:
|
||||
list: The response from the FastAPI server, confirming the scan has been stopped.
|
||||
"""
|
||||
url = f"{AGENTIC_SECURITY}/v1/data-config"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
return response.json()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_spec_templates() -> list:
|
||||
"""
|
||||
Retrieve data configuration from the FastAPI server.
|
||||
|
||||
Returns:
|
||||
list: The LLM specification templates from the FastAPI server.
|
||||
"""
|
||||
url = f"{AGENTIC_SECURITY}/v1/llm-specs"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
return response.json()
|
||||
|
||||
|
||||
# Run the MCP server
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
@@ -1,7 +1,8 @@
|
||||
from fastapi import Request
|
||||
from loguru import logger
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
|
||||
class LogNon200ResponsesMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
|
||||
@@ -8,14 +8,14 @@ except ImportError:
|
||||
|
||||
|
||||
def generate_banner(
|
||||
title="Agentic Security",
|
||||
font="slant",
|
||||
version="v2.1.0",
|
||||
tagline="Proactive Threat Detection & Automated Security Protocols",
|
||||
author="Developed by: [Security Team]",
|
||||
website="Website: https://github.com/msoedov/agentic_security",
|
||||
warning="",
|
||||
):
|
||||
title: str = "Agentic Security",
|
||||
font: str = "slant",
|
||||
version: str = "v2.1.0",
|
||||
tagline: str = "Proactive Threat Detection & Automated Security Protocols",
|
||||
author: str = "Developed by: [Security Team]",
|
||||
website: str = "Website: https://github.com/msoedov/agentic_security",
|
||||
warning: str | None = "", # Using Optional for warning since it might be None
|
||||
) -> str:
|
||||
"""Generate a visually enhanced banner with dynamic width and borders."""
|
||||
# Define the text elements
|
||||
|
||||
@@ -80,8 +80,13 @@ def generate_banner(
|
||||
|
||||
|
||||
def init_banner():
|
||||
return
|
||||
ver = version("agentic_security")
|
||||
print(generate_banner(version=ver))
|
||||
try:
|
||||
print(generate_banner(version=ver))
|
||||
except Exception:
|
||||
# UnicodeEncodeError with codec on some systems
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# noqa
|
||||
from agentic_security.primitives.models import CompletionRequest # noqa
|
||||
from agentic_security.primitives.models import ( # noqa
|
||||
CompletionRequest,
|
||||
FileProbeResponse,
|
||||
LLMInfo,
|
||||
Message,
|
||||
|
||||
@@ -18,13 +18,13 @@ class LLMInfo(BaseModel):
|
||||
class Scan(BaseModel):
|
||||
llmSpec: str
|
||||
maxBudget: int
|
||||
datasets: list[dict] = []
|
||||
datasets: list[dict] = Field(default_factory=list)
|
||||
optimize: bool = False
|
||||
enableMultiStepAttack: bool = False
|
||||
# MSJ only mode
|
||||
probe_datasets: list[dict] = []
|
||||
probe_datasets: list[dict] = Field(default_factory=list)
|
||||
# Set and managed by the backend
|
||||
secrets: dict[str, str] = {}
|
||||
secrets: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
def with_secrets(self, secrets) -> "Scan":
|
||||
match secrets:
|
||||
|
||||
@@ -3,32 +3,44 @@ import random
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from skopt import Optimizer
|
||||
from skopt.space import Real
|
||||
|
||||
from agentic_security.config import settings_var
|
||||
from agentic_security.http_spec import Modality
|
||||
from agentic_security.logutils import logger
|
||||
from agentic_security.primitives import Scan, ScanResult
|
||||
from agentic_security.probe_actor.cost_module import calculate_cost
|
||||
from agentic_security.probe_actor.refusal import refusal_heuristic
|
||||
from agentic_security.probe_actor.state import FuzzerState
|
||||
from agentic_security.probe_data import audio_generator, image_generator, msj_data
|
||||
from agentic_security.probe_data.data import prepare_prompts
|
||||
|
||||
# TODO: full log file
|
||||
|
||||
MAX_PROMPT_LENGTH = 2048
|
||||
BUDGET_MULTIPLIER = 100_000_000
|
||||
INITIAL_OPTIMIZER_POINTS = 25
|
||||
MIN_FAILURE_SAMPLES = 5
|
||||
FAILURE_RATE_THRESHOLD = 0.5
|
||||
MAX_PROMPT_LENGTH = settings_var("fuzzer.max_prompt_lenght", 2048)
|
||||
BUDGET_MULTIPLIER = settings_var("fuzzer.budget_multiplier", 100000000)
|
||||
INITIAL_OPTIMIZER_POINTS = settings_var("fuzzer.initial_optimizer_points", 25)
|
||||
MIN_FAILURE_SAMPLES = settings_var("fuzzer.min_failure_samples", 5)
|
||||
FAILURE_RATE_THRESHOLD = settings_var("fuzzer.failure_rate_threshold", 0.5)
|
||||
|
||||
|
||||
async def generate_prompts(
|
||||
prompts: list[str] | AsyncGenerator,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Asynchronously generates and yields individual prompts.
|
||||
|
||||
If the input is a list of strings, the function sequentially yields each string.
|
||||
If the input is an asynchronous generator, it forwards each generated prompt.
|
||||
|
||||
Args:
|
||||
prompts (list[str] | AsyncGenerator): A list of strings or an asynchronous generator of prompts.
|
||||
|
||||
Yields:
|
||||
str: An individual prompt from the list or the asynchronous generator.
|
||||
"""
|
||||
if isinstance(prompts, list):
|
||||
for prompt in prompts:
|
||||
yield prompt
|
||||
@@ -37,7 +49,21 @@ async def generate_prompts(
|
||||
yield prompt
|
||||
|
||||
|
||||
def multi_modality_spec(llm_spec):
|
||||
def get_modality_adapter(llm_spec):
|
||||
"""
|
||||
Returns the appropriate request adapter based on the modality of the LLM specification.
|
||||
|
||||
Depending on the modality of `llm_spec`, the function selects the corresponding request adapter.
|
||||
If the modality is IMAGE or AUDIO, it returns an adapter for handling the respective type.
|
||||
If the modality is TEXT or an unrecognized type, it returns `llm_spec` as is.
|
||||
|
||||
Args:
|
||||
llm_spec: An object containing modality information for the LLM.
|
||||
|
||||
Returns:
|
||||
RequestAdapter | llm_spec: An instance of the appropriate request adapter
|
||||
or the original `llm_spec` if no adaptation is needed.
|
||||
"""
|
||||
match llm_spec.modality:
|
||||
case Modality.IMAGE:
|
||||
return image_generator.RequestAdapter(llm_spec)
|
||||
@@ -50,40 +76,71 @@ def multi_modality_spec(llm_spec):
|
||||
|
||||
|
||||
async def process_prompt(
|
||||
request_factory, prompt, tokens, module_name, refusals, errors, outputs
|
||||
request_factory,
|
||||
prompt: str,
|
||||
tokens: int,
|
||||
module_name: str,
|
||||
fuzzer_state: FuzzerState,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Process a single prompt and update the token count and failure status.
|
||||
Processes a single prompt using the provided request factory and updates tracking lists.
|
||||
|
||||
This function sends the given `prompt` to the `request_factory`, checks for errors, and updates
|
||||
the `tokens`, `refusals`, `errors`, and `outputs` lists accordingly. If the request fails or
|
||||
the response indicates a refusal, the function records the issue and returns the updated token count
|
||||
along with a boolean indicating whether the prompt was refused.
|
||||
|
||||
Args:
|
||||
request_factory: An object with a `fn` method used to send the prompt.
|
||||
prompt (str): The input prompt to be processed.
|
||||
tokens (int): The current token count, which will be updated.
|
||||
module_name (str): The name of the module handling the request.
|
||||
fuzzer_state: State tracking object for the fuzzer
|
||||
|
||||
Returns:
|
||||
tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused.
|
||||
"""
|
||||
try:
|
||||
response = await request_factory.fn(prompt=prompt)
|
||||
|
||||
# Handle HTTP errors
|
||||
if response.status_code == 422:
|
||||
logger.error(f"Invalid prompt: {prompt}, error=422")
|
||||
errors.append((module_name, prompt, 422, "Invalid prompt"))
|
||||
fuzzer_state.add_error(module_name, prompt, 422, "Invalid prompt")
|
||||
return tokens, True
|
||||
|
||||
if response.status_code >= 400:
|
||||
logger.error(f"HTTP {response.status_code} {response.content=}")
|
||||
errors.append((module_name, prompt, response.status_code, response.text))
|
||||
fuzzer_state.add_error(
|
||||
module_name, prompt, response.status_code, response.text
|
||||
)
|
||||
return tokens, True
|
||||
|
||||
# Process successful response
|
||||
response_text = response.text
|
||||
tokens += len(response_text.split())
|
||||
|
||||
# Check if the response indicates a refusal
|
||||
refused = refusal_heuristic(response.json())
|
||||
if refused:
|
||||
refusals.append((module_name, prompt, response.status_code, response_text))
|
||||
fuzzer_state.add_refusal(
|
||||
module_name, prompt, response.status_code, response_text
|
||||
)
|
||||
|
||||
outputs.append((module_name, prompt, response_text, refused))
|
||||
fuzzer_state.add_output(module_name, prompt, response_text, refused)
|
||||
return tokens, refused
|
||||
|
||||
except httpx.RequestError as exc:
|
||||
logger.error(f"Request error: {exc}")
|
||||
errors.append((module_name, prompt, "?", str(exc)))
|
||||
fuzzer_state.add_error(module_name, prompt, "?", str(exc))
|
||||
return tokens, True
|
||||
except JSONDecodeError as json_decode_error:
|
||||
logger.error(f"Jason error: {json_decode_error}")
|
||||
errors.append((module_name, prompt, "?", str(json_decode_error)))
|
||||
logger.error(f"JSON error: {json_decode_error}")
|
||||
fuzzer_state.add_error(module_name, prompt, "?", str(json_decode_error))
|
||||
return tokens, True
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error: {e}")
|
||||
return tokens, False
|
||||
|
||||
|
||||
async def process_prompt_batch(
|
||||
@@ -91,14 +148,29 @@ async def process_prompt_batch(
|
||||
prompts: list[str],
|
||||
tokens: int,
|
||||
module_name: str,
|
||||
refusals,
|
||||
errors,
|
||||
outputs,
|
||||
fuzzer_state: FuzzerState,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Processes a batch of prompts asynchronously and aggregates the results.
|
||||
|
||||
This function sends multiple prompts concurrently using `process_prompt`,
|
||||
collects the token count and failure status for each prompt, and returns
|
||||
the total number of tokens processed and the number of failed prompts.
|
||||
|
||||
Args:
|
||||
request_factory: An object with a `fn` method used to send the prompts.
|
||||
prompts (list[str]): A list of input prompts to be processed.
|
||||
tokens (int): The initial token count, which will be updated.
|
||||
module_name (str): The name of the module handling the request.
|
||||
fuzzer_state: State tracking object for the fuzzer
|
||||
|
||||
Returns:
|
||||
tuple[int, int]:
|
||||
- Total number of tokens processed.
|
||||
- Number of failed prompts.
|
||||
"""
|
||||
tasks = [
|
||||
process_prompt(
|
||||
request_factory, p, tokens, module_name, refusals, errors, outputs
|
||||
)
|
||||
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
|
||||
for p in prompts
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
@@ -107,133 +179,161 @@ async def process_prompt_batch(
|
||||
return total_tokens, failures
|
||||
|
||||
|
||||
async def perform_single_shot_scan(
|
||||
async def scan_module(
|
||||
request_factory,
|
||||
max_budget: int,
|
||||
datasets: list[dict[str, str]] = [],
|
||||
tools_inbox=None,
|
||||
optimize=False,
|
||||
stop_event: asyncio.Event = None,
|
||||
secrets: dict[str, str] = {},
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Perform a standard security scan."""
|
||||
max_budget = max_budget * BUDGET_MULTIPLIER
|
||||
selected_datasets = [m for m in datasets if m["selected"]]
|
||||
request_factory = multi_modality_spec(request_factory)
|
||||
try:
|
||||
yield ScanResult.status_msg("Loading datasets...")
|
||||
prompt_modules = prepare_prompts(
|
||||
dataset_names=[m["dataset_name"] for m in selected_datasets],
|
||||
budget=max_budget,
|
||||
tools_inbox=tools_inbox,
|
||||
options=[m.get("opts", {}) for m in selected_datasets],
|
||||
module,
|
||||
fuzzer_state: FuzzerState,
|
||||
processed_prompts: int = 0,
|
||||
total_prompts: int = 0,
|
||||
max_budget: int = 0,
|
||||
optimize: bool = False,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
token_counter: dict[str, int] | None = None,
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""
|
||||
Scan a single module.
|
||||
|
||||
Args:
|
||||
request_factory: The factory for creating requests
|
||||
module: The prompt module to scan
|
||||
fuzzer_state: State tracking object for the fuzzer
|
||||
processed_prompts: Number of prompts processed so far
|
||||
total_prompts: Total number of prompts to process
|
||||
max_budget: Maximum token budget
|
||||
token_counter: Shared token counter to enforce global budget
|
||||
optimize: Whether to use optimization
|
||||
stop_event: Event to stop scanning
|
||||
|
||||
Yields:
|
||||
ScanResult objects as the scan progresses
|
||||
"""
|
||||
tokens = 0
|
||||
token_counter = token_counter or {"total": 0}
|
||||
module_failures = 0
|
||||
module_prompts = 0
|
||||
failure_rates = []
|
||||
should_stop = False
|
||||
|
||||
# Initialize optimizer if optimization is enabled
|
||||
optimizer = (
|
||||
Optimizer(
|
||||
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
|
||||
)
|
||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||
if optimize
|
||||
else None
|
||||
)
|
||||
|
||||
errors = []
|
||||
refusals = []
|
||||
outputs = []
|
||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||
processed_prompts = 0
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
yield ScanResult(
|
||||
module=module.dataset_name,
|
||||
tokens=0,
|
||||
cost=0,
|
||||
progress=0,
|
||||
failureRate=0,
|
||||
prompt="",
|
||||
latency=0,
|
||||
model="",
|
||||
).model_dump_json()
|
||||
|
||||
optimizer = (
|
||||
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
|
||||
if optimize
|
||||
else None
|
||||
async for prompt in generate_prompts(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
logger.info("Scan stopped by user.")
|
||||
yield ScanResult.status_msg("Scan stopped by user.")
|
||||
return
|
||||
|
||||
processed_prompts += 1
|
||||
module_prompts += 1
|
||||
|
||||
# Calculate progress based on total processed prompts
|
||||
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
|
||||
progress = progress % 100
|
||||
|
||||
start = time.time()
|
||||
|
||||
previous_tokens = tokens
|
||||
tokens, failed = await process_prompt(
|
||||
request_factory,
|
||||
prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
failure_rates = []
|
||||
|
||||
total_tokens = 0
|
||||
tokens = 0
|
||||
should_stop = False
|
||||
for module in prompt_modules:
|
||||
if should_stop:
|
||||
end = time.time()
|
||||
token_delta = max(tokens - previous_tokens, 0)
|
||||
token_counter["total"] += token_delta
|
||||
|
||||
if failed:
|
||||
module_failures += 1
|
||||
|
||||
failure_rate = module_failures / max(module_prompts, 1)
|
||||
failure_rates.append(failure_rate)
|
||||
cost = calculate_cost(tokens)
|
||||
|
||||
response_text = fuzzer_state.get_last_output(prompt) or ""
|
||||
|
||||
yield ScanResult(
|
||||
module=module.dataset_name,
|
||||
tokens=round(tokens / 1000, 1),
|
||||
cost=cost,
|
||||
progress=round(progress, 2),
|
||||
failureRate=round(failure_rate * 100, 2),
|
||||
prompt=prompt[:MAX_PROMPT_LENGTH],
|
||||
latency=end - start,
|
||||
model=response_text,
|
||||
).model_dump_json()
|
||||
|
||||
# Optimization logic
|
||||
if optimize and optimizer and len(failure_rates) >= MIN_FAILURE_SAMPLES:
|
||||
next_point = optimizer.ask()
|
||||
optimizer.tell(next_point, -failure_rate)
|
||||
best_failure_rate = -optimizer.get_result().fun
|
||||
if best_failure_rate > FAILURE_RATE_THRESHOLD:
|
||||
yield ScanResult.status_msg(
|
||||
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
|
||||
)
|
||||
should_stop = True
|
||||
break
|
||||
tokens = 0
|
||||
module_failures = 0
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
module_prompts = 0 # Reset for each module
|
||||
|
||||
async for prompt in generate_prompts(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
logger.info("Scan stopped by user.")
|
||||
yield ScanResult.status_msg("Scan stopped by user.")
|
||||
return
|
||||
# Budget check
|
||||
if token_counter["total"] > max_budget:
|
||||
logger.info(
|
||||
"Scan ran out of budget and stopped. %s %s",
|
||||
token_counter["total"],
|
||||
max_budget,
|
||||
)
|
||||
yield ScanResult.status_msg(
|
||||
f"Scan ran out of budget and stopped. total_tokens={token_counter['total']} max_budget={max_budget}"
|
||||
)
|
||||
should_stop = True
|
||||
break
|
||||
|
||||
processed_prompts += 1
|
||||
module_prompts += 1 # Fixed increment syntax
|
||||
# Calculate progress based on total processed prompts
|
||||
progress = (
|
||||
100 * processed_prompts / total_prompts if total_prompts else 0
|
||||
)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
total_tokens -= tokens
|
||||
start = time.time()
|
||||
tokens, failed = await process_prompt(
|
||||
request_factory,
|
||||
prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
refusals,
|
||||
errors,
|
||||
outputs,
|
||||
)
|
||||
end = time.time()
|
||||
total_tokens += tokens
|
||||
return
|
||||
|
||||
if failed:
|
||||
module_failures += 1
|
||||
failure_rate = module_failures / max(module_prompts, 1)
|
||||
failure_rates.append(failure_rate)
|
||||
cost = calculate_cost(tokens)
|
||||
|
||||
last_output = outputs[-1] if outputs else None
|
||||
if last_output and last_output[1] == prompt:
|
||||
response_text = last_output[2]
|
||||
else:
|
||||
response_text = ""
|
||||
async def with_error_handling(agen):
|
||||
"""
|
||||
Wraps an asynchronous generator with error handling.
|
||||
|
||||
yield ScanResult(
|
||||
module=module.dataset_name,
|
||||
tokens=round(tokens / 1000, 1),
|
||||
cost=cost,
|
||||
progress=round(progress, 2),
|
||||
failureRate=round(failure_rate * 100, 2),
|
||||
prompt=prompt[:MAX_PROMPT_LENGTH],
|
||||
latency=end - start,
|
||||
model=response_text,
|
||||
).model_dump_json()
|
||||
This function iterates over an asynchronous generator, yielding its values.
|
||||
If an exception occurs, it logs the error and yields a failure message.
|
||||
Finally, it ensures that a completion message is always yielded.
|
||||
|
||||
if optimize and len(failure_rates) >= 5:
|
||||
next_point = optimizer.ask()
|
||||
optimizer.tell(next_point, -failure_rate)
|
||||
best_failure_rate = -optimizer.get_result().fun
|
||||
if best_failure_rate > 0.5:
|
||||
yield ScanResult.status_msg(
|
||||
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
|
||||
)
|
||||
should_stop = True
|
||||
break
|
||||
if total_tokens > max_budget:
|
||||
logger.info(
|
||||
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
|
||||
)
|
||||
yield ScanResult.status_msg(
|
||||
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
|
||||
)
|
||||
should_stop = True
|
||||
break
|
||||
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
|
||||
failure_data = errors + refusals
|
||||
df = pd.DataFrame(
|
||||
failure_data, columns=["module", "prompt", "status_code", "content"]
|
||||
)
|
||||
df.to_csv("failures.csv", index=False)
|
||||
Args:
|
||||
agen: An asynchronous generator that produces scan results.
|
||||
|
||||
Yields:
|
||||
ScanResult: Either a successful result, an error message if an
|
||||
exception occurs, or a completion message at the end.
|
||||
"""
|
||||
try:
|
||||
async for t in agen:
|
||||
yield t
|
||||
except Exception as e:
|
||||
logger.exception("Scan failed")
|
||||
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
|
||||
@@ -241,154 +341,279 @@ async def perform_single_shot_scan(
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
|
||||
|
||||
async def perform_single_shot_scan(
|
||||
request_factory,
|
||||
max_budget: int,
|
||||
datasets: list[dict[str, str]] | None = None,
|
||||
tools_inbox=None,
|
||||
optimize: bool = False,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
secrets: dict[str, str] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Perform a standard security scan using a given request factory.
|
||||
|
||||
This function processes security scan prompts from selected datasets while
|
||||
respecting a predefined token budget. It supports optimization, failure tracking,
|
||||
and early stopping based on budget constraints or user intervention.
|
||||
|
||||
Args:
|
||||
request_factory: A factory function that generates requests for processing prompts.
|
||||
max_budget (int): The maximum token budget for the scan.
|
||||
datasets (list[dict[str, str]], optional): A list of datasets containing security prompts.
|
||||
tools_inbox: Optional additional tools for processing (default: None).
|
||||
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
|
||||
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
|
||||
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
|
||||
|
||||
Yields:
|
||||
str: JSON-encoded scan results or status messages.
|
||||
|
||||
The function iterates over prompts, processes them asynchronously, and updates
|
||||
failure statistics and token usage. If the scan exceeds the budget or failure rate is too high,
|
||||
it stops execution. Results are saved to a CSV file upon completion.
|
||||
"""
|
||||
datasets = datasets or []
|
||||
secrets = secrets or {}
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
yield ScanResult.status_msg("Loading datasets...")
|
||||
yield ScanResult.status_msg("Scan stopped by user.")
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
return
|
||||
max_budget = max_budget * BUDGET_MULTIPLIER
|
||||
selected_datasets = [m for m in datasets if m.get("selected")]
|
||||
request_factory = get_modality_adapter(request_factory)
|
||||
|
||||
yield ScanResult.status_msg("Loading datasets...")
|
||||
prompt_modules = prepare_prompts(
|
||||
dataset_names=[m["dataset_name"] for m in selected_datasets],
|
||||
budget=max_budget,
|
||||
tools_inbox=tools_inbox,
|
||||
options=[m.get("opts", {}) for m in selected_datasets],
|
||||
)
|
||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||
|
||||
fuzzer_state = FuzzerState()
|
||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||
processed_prompts = 0
|
||||
|
||||
token_counter = {"total": 0}
|
||||
for module in prompt_modules:
|
||||
module_gen = scan_module(
|
||||
request_factory=request_factory,
|
||||
module=module,
|
||||
fuzzer_state=fuzzer_state,
|
||||
processed_prompts=processed_prompts,
|
||||
total_prompts=total_prompts,
|
||||
max_budget=max_budget,
|
||||
optimize=optimize,
|
||||
stop_event=stop_event,
|
||||
token_counter=token_counter,
|
||||
)
|
||||
try:
|
||||
async for result in module_gen:
|
||||
yield result
|
||||
except Exception:
|
||||
logger.error("Module exception")
|
||||
continue
|
||||
# Update processed_prompts count
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
processed_prompts += module_size
|
||||
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
fuzzer_state.export_failures("failures.csv")
|
||||
|
||||
|
||||
async def perform_many_shot_scan(
|
||||
request_factory,
|
||||
max_budget: int,
|
||||
datasets: list[dict[str, str]] = [],
|
||||
probe_datasets: list[dict[str, str]] = [],
|
||||
datasets: list[dict[str, str]] | None = None,
|
||||
probe_datasets: list[dict[str, str]] | None = None,
|
||||
tools_inbox=None,
|
||||
optimize=False,
|
||||
stop_event: asyncio.Event = None,
|
||||
optimize: bool = False,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
probe_frequency: float = 0.2,
|
||||
max_ctx_length: int = 10_000,
|
||||
secrets: dict[str, str] = {},
|
||||
secrets: dict[str, str] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Perform a multi-step security scan with probe injection."""
|
||||
request_factory = multi_modality_spec(request_factory)
|
||||
try:
|
||||
# Load main and probe datasets
|
||||
"""
|
||||
Perform a multi-step security scan with probe injection.
|
||||
|
||||
This function executes a security scan while periodically injecting probe datasets
|
||||
to test system robustness. It tracks failures, optimizes scan efficiency,
|
||||
and ensures adherence to a predefined token budget.
|
||||
|
||||
Args:
|
||||
request_factory: A factory function that generates requests for processing prompts.
|
||||
max_budget (int): The maximum token budget for the scan.
|
||||
datasets (list[dict[str, str]], optional): The main datasets for scanning.
|
||||
probe_datasets (list[dict[str, str]], optional): Additional datasets for probe injection.
|
||||
tools_inbox: Optional tools for additional processing (default: None).
|
||||
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
|
||||
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
|
||||
probe_frequency (float, optional): The probability of probe injection (default: 0.2).
|
||||
max_ctx_length (int, optional): The maximum context length before resetting (default: 10,000 tokens).
|
||||
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
|
||||
|
||||
Yields:
|
||||
str: JSON-encoded scan results or status messages.
|
||||
|
||||
This function iterates over prompts, injects probe prompts at random intervals,
|
||||
processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold
|
||||
or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion.
|
||||
"""
|
||||
datasets = datasets or []
|
||||
probe_datasets = probe_datasets or []
|
||||
secrets = secrets or {}
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
yield ScanResult.status_msg("Loading datasets...")
|
||||
prompt_modules = prepare_prompts(
|
||||
dataset_names=[m["dataset_name"] for m in datasets if m["selected"]],
|
||||
budget=max_budget,
|
||||
tools_inbox=tools_inbox,
|
||||
)
|
||||
yield ScanResult.status_msg("Loading datasets for MSJ...")
|
||||
msj_modules = msj_data.prepare_prompts(probe_datasets)
|
||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||
|
||||
errors = []
|
||||
refusals = []
|
||||
outputs = []
|
||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||
processed_prompts = 0
|
||||
|
||||
optimizer = (
|
||||
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
|
||||
if optimize
|
||||
else None
|
||||
)
|
||||
failure_rates = []
|
||||
|
||||
for module in prompt_modules:
|
||||
module_failures = 0
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
|
||||
async for prompt in generate_prompts(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
logger.info("Scan stopped by user.")
|
||||
yield ScanResult.status_msg("Scan stopped by user.")
|
||||
return
|
||||
tokens = 0
|
||||
processed_prompts += 1
|
||||
progress = (
|
||||
100 * processed_prompts / total_prompts if total_prompts else 0
|
||||
)
|
||||
|
||||
full_prompt = ""
|
||||
msj_module = random.choice(msj_modules)
|
||||
|
||||
prompt_tokens = len(full_prompt.split())
|
||||
tokens += prompt_tokens
|
||||
|
||||
injected = False
|
||||
for _ in range(20):
|
||||
if injected:
|
||||
break
|
||||
|
||||
m_prompt = random.choice(msj_module.prompts)
|
||||
full_prompt += "\n" + m_prompt
|
||||
if tokens > max_ctx_length:
|
||||
full_prompt = "\n" + prompt
|
||||
injected = True
|
||||
|
||||
tokens, failed = await process_prompt(
|
||||
request_factory,
|
||||
full_prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
refusals,
|
||||
errors,
|
||||
outputs,
|
||||
)
|
||||
if failed:
|
||||
module_failures += 1
|
||||
break
|
||||
if injected:
|
||||
break
|
||||
|
||||
failure_rate = module_failures / max(processed_prompts, 1)
|
||||
failure_rates.append(failure_rate)
|
||||
cost = calculate_cost(tokens)
|
||||
|
||||
yield ScanResult(
|
||||
module=module.dataset_name,
|
||||
tokens=round(tokens / 1000, 1),
|
||||
cost=cost,
|
||||
progress=round(progress, 2),
|
||||
failureRate=round(failure_rate * 100, 2),
|
||||
prompt=prompt[:MAX_PROMPT_LENGTH],
|
||||
).model_dump_json()
|
||||
|
||||
if optimize and len(failure_rates) >= 5:
|
||||
next_point = optimizer.ask()
|
||||
optimizer.tell(next_point, -failure_rate)
|
||||
best_failure_rate = -optimizer.get_result().fun
|
||||
if best_failure_rate > 0.5:
|
||||
yield ScanResult.status_msg(
|
||||
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
|
||||
)
|
||||
break
|
||||
|
||||
yield ScanResult.status_msg("Scan stopped by user.")
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
return
|
||||
request_factory = get_modality_adapter(request_factory)
|
||||
# Load main and probe datasets
|
||||
yield ScanResult.status_msg("Loading datasets...")
|
||||
prompt_modules = prepare_prompts(
|
||||
dataset_names=[m["dataset_name"] for m in datasets if m["selected"]],
|
||||
budget=max_budget,
|
||||
tools_inbox=tools_inbox,
|
||||
)
|
||||
yield ScanResult.status_msg("Loading datasets for MSJ...")
|
||||
msj_modules = msj_data.prepare_prompts(probe_datasets)
|
||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||
|
||||
df = pd.DataFrame(
|
||||
errors + refusals, columns=["module", "prompt", "status_code", "content"]
|
||||
)
|
||||
df.to_csv("failures.csv", index=False)
|
||||
fuzzer_state = FuzzerState()
|
||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||
processed_prompts = 0
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Scan failed")
|
||||
yield ScanResult.status_msg(f"Scan failed: {str(e)}")
|
||||
raise e
|
||||
failure_rates = []
|
||||
|
||||
for module in prompt_modules:
|
||||
module_failures = 0
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
|
||||
async for prompt in generate_prompts(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
logger.info("Scan stopped by user.")
|
||||
yield ScanResult.status_msg("Scan stopped by user.")
|
||||
return
|
||||
|
||||
tokens = 0
|
||||
processed_prompts += 1
|
||||
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
|
||||
progress = progress % 100
|
||||
|
||||
full_prompt = ""
|
||||
msj_module = random.choice(msj_modules)
|
||||
|
||||
prompt_tokens = len(full_prompt.split())
|
||||
tokens += prompt_tokens
|
||||
|
||||
injected = False
|
||||
for _ in range(20):
|
||||
if injected:
|
||||
break
|
||||
|
||||
m_prompt = random.choice(msj_module.prompts)
|
||||
full_prompt += "\n" + m_prompt
|
||||
if tokens > max_ctx_length:
|
||||
full_prompt = "\n" + prompt
|
||||
injected = True
|
||||
|
||||
tokens, failed = await process_prompt(
|
||||
request_factory,
|
||||
full_prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
if failed:
|
||||
module_failures += 1
|
||||
break
|
||||
if injected:
|
||||
break
|
||||
|
||||
failure_rate = module_failures / max(processed_prompts, 1)
|
||||
failure_rates.append(failure_rate)
|
||||
cost = calculate_cost(tokens)
|
||||
|
||||
yield ScanResult(
|
||||
module=module.dataset_name,
|
||||
tokens=round(tokens / 1000, 1),
|
||||
cost=cost,
|
||||
progress=round(progress, 2),
|
||||
failureRate=round(failure_rate * 100, 2),
|
||||
prompt=prompt[:MAX_PROMPT_LENGTH],
|
||||
).model_dump_json()
|
||||
|
||||
if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES:
|
||||
yield ScanResult.status_msg(
|
||||
f"High failure rate detected ({failure_rate:.2%}). Stopping this module..."
|
||||
)
|
||||
break
|
||||
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
fuzzer_state.export_failures("failures.csv")
|
||||
|
||||
|
||||
def scan_router(
|
||||
request_factory,
|
||||
scan_parameters: Scan,
|
||||
tools_inbox=None,
|
||||
stop_event: asyncio.Event = None,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
):
|
||||
"""
|
||||
Route scan requests to the appropriate scanning function.
|
||||
|
||||
This function determines whether to perform a multi-step or single-shot
|
||||
security scan based on the provided scan parameters.
|
||||
|
||||
Args:
|
||||
request_factory: A factory function to generate requests for processing prompts.
|
||||
scan_parameters (Scan): An object containing the parameters for the scan, including:
|
||||
- enableMultiStepAttack (bool): Whether to perform a multi-step scan.
|
||||
- maxBudget (int): The maximum token budget for the scan.
|
||||
- datasets (list[dict[str, str]]): The datasets to scan.
|
||||
- probe_datasets (list[dict[str, str]], optional): Datasets for probe injection (multi-step only).
|
||||
- optimize (bool): Whether to enable optimization.
|
||||
- secrets (dict[str, str], optional): A dictionary of secrets for authentication.
|
||||
tools_inbox: Optional tools for additional processing (default: None).
|
||||
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
|
||||
|
||||
Returns:
|
||||
A function wrapped with `with_error_handling`, which executes either:
|
||||
- `perform_many_shot_scan` for multi-step scanning.
|
||||
- `perform_single_shot_scan` for single-shot scanning.
|
||||
|
||||
The function ensures that the appropriate scanning method is chosen based on
|
||||
the `enableMultiStepAttack` flag in `scan_parameters`.
|
||||
"""
|
||||
if scan_parameters.enableMultiStepAttack:
|
||||
return perform_many_shot_scan(
|
||||
request_factory=request_factory,
|
||||
max_budget=scan_parameters.maxBudget,
|
||||
datasets=scan_parameters.datasets,
|
||||
probe_datasets=scan_parameters.probe_datasets,
|
||||
tools_inbox=tools_inbox,
|
||||
optimize=scan_parameters.optimize,
|
||||
stop_event=stop_event,
|
||||
secrets=scan_parameters.secrets,
|
||||
return with_error_handling(
|
||||
perform_many_shot_scan(
|
||||
request_factory=request_factory,
|
||||
max_budget=scan_parameters.maxBudget,
|
||||
datasets=scan_parameters.datasets,
|
||||
probe_datasets=scan_parameters.probe_datasets,
|
||||
tools_inbox=tools_inbox,
|
||||
optimize=scan_parameters.optimize,
|
||||
stop_event=stop_event,
|
||||
secrets=scan_parameters.secrets,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return perform_single_shot_scan(
|
||||
request_factory=request_factory,
|
||||
max_budget=scan_parameters.maxBudget,
|
||||
datasets=scan_parameters.datasets,
|
||||
tools_inbox=tools_inbox,
|
||||
optimize=scan_parameters.optimize,
|
||||
stop_event=stop_event,
|
||||
secrets=scan_parameters.secrets,
|
||||
return with_error_handling(
|
||||
perform_single_shot_scan(
|
||||
request_factory=request_factory,
|
||||
max_budget=scan_parameters.maxBudget,
|
||||
datasets=scan_parameters.datasets,
|
||||
tools_inbox=tools_inbox,
|
||||
optimize=scan_parameters.optimize,
|
||||
stop_event=stop_event,
|
||||
secrets=scan_parameters.secrets,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -7,13 +6,10 @@ from pydantic import BaseModel, Field
|
||||
from pydantic_ai import Agent, RunContext
|
||||
|
||||
from agentic_security.http_spec import LLMSpec
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
LLM_SPECS = []
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentSpecification(BaseModel):
|
||||
name: str | None = Field(None, description="Name of the LLM/agent")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from agentic_security.refusal_classifier import RefusalClassifier
|
||||
from agentic_security.refusal_classifier.model import RefusalClassifier
|
||||
|
||||
classifier = RefusalClassifier()
|
||||
classifier.load_model()
|
||||
@@ -50,7 +50,6 @@ class RefusalClassifierPlugin(ABC):
|
||||
Returns:
|
||||
bool: True if the response contains a refusal, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DefaultRefusalClassifier(RefusalClassifierPlugin):
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class FuzzerState:
|
||||
"""Container for tracking scan results"""
|
||||
|
||||
def __init__(self):
|
||||
self.errors = []
|
||||
self.refusals = []
|
||||
self.outputs = []
|
||||
|
||||
def add_error(
|
||||
self,
|
||||
module_name: str,
|
||||
prompt: str,
|
||||
status_code: int | str,
|
||||
error_msg: str,
|
||||
):
|
||||
"""Add an error to the state"""
|
||||
self.errors.append((module_name, prompt, status_code, error_msg))
|
||||
|
||||
def add_refusal(
|
||||
self, module_name: str, prompt: str, status_code: int, response_text: str
|
||||
):
|
||||
"""Add a refusal to the state"""
|
||||
self.refusals.append((module_name, prompt, status_code, response_text))
|
||||
|
||||
def add_output(
|
||||
self, module_name: str, prompt: str, response_text: str, refused: bool
|
||||
):
|
||||
"""Add an output to the state"""
|
||||
self.outputs.append((module_name, prompt, response_text, refused))
|
||||
|
||||
def get_last_output(self, prompt: str) -> str | None:
|
||||
"""Get the last output for a given prompt"""
|
||||
for output in reversed(self.outputs):
|
||||
if output[1] == prompt:
|
||||
return output[2]
|
||||
return None
|
||||
|
||||
def export_failures(self, filename: str = "failures.csv"):
|
||||
"""Export failures to a CSV file"""
|
||||
failure_data = self.errors + self.refusals
|
||||
df = pd.DataFrame(
|
||||
failure_data, columns=["module", "prompt", "status_code", "content"]
|
||||
)
|
||||
df.to_csv(filename, index=False)
|
||||
@@ -1,4 +1,4 @@
|
||||
from .data import load_local_csv
|
||||
from .data import load_local_csv, load_local_csv_files
|
||||
|
||||
REGISTRY_V0 = [
|
||||
{
|
||||
@@ -484,3 +484,18 @@ REGISTRY = REGISTRY_V0 + [
|
||||
"modality": "text",
|
||||
},
|
||||
]
|
||||
|
||||
for ds in load_local_csv_files():
|
||||
REGISTRY.append(
|
||||
{
|
||||
"dataset_name": ds.dataset_name,
|
||||
"num_prompts": len(ds.prompts),
|
||||
"tokens": ds.prompts,
|
||||
"approx_cost": 0.0,
|
||||
"is_active": True,
|
||||
"source": f"Local file dataset: {ds.metadata['src']}",
|
||||
"selected": False,
|
||||
"url": "",
|
||||
"modality": "text",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
@@ -7,6 +8,14 @@ import uuid
|
||||
import httpx
|
||||
from cache_to_disk import cache_to_disk
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AudioGenerationError(Exception):
|
||||
"""Custom exception for errors during audio generation."""
|
||||
|
||||
|
||||
def encode(content: bytes) -> str:
|
||||
encoded_content = base64.b64encode(content).decode("utf-8")
|
||||
@@ -41,12 +50,25 @@ def generate_audio_mac_wav(prompt: str) -> bytes:
|
||||
# Read the WAV file into memory
|
||||
with open(temp_wav_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Subprocess error: {e}")
|
||||
raise AudioGenerationError("Failed to generate or convert audio.") from e
|
||||
except FileNotFoundError as e:
|
||||
logger.error(f"File not found: {e}")
|
||||
raise AudioGenerationError("Required file not found.") from e
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error occurred.")
|
||||
raise AudioGenerationError(
|
||||
"An unexpected error occurred during audio generation."
|
||||
) from e
|
||||
finally:
|
||||
# Clean up the temporary files
|
||||
if os.path.exists(temp_aiff_path):
|
||||
os.remove(temp_aiff_path)
|
||||
if os.path.exists(temp_wav_path):
|
||||
os.remove(temp_wav_path)
|
||||
for path in (temp_aiff_path, temp_wav_path):
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temporary file {path}: {e}")
|
||||
|
||||
# Return the audio bytes
|
||||
return audio_bytes
|
||||
|
||||
+456
-391
@@ -1,14 +1,15 @@
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
from functools import lru_cache
|
||||
from collections.abc import Callable, Iterator
|
||||
from functools import partial
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from cache_to_disk import cache_to_disk
|
||||
from datasets import load_dataset
|
||||
from loguru import logger
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
from agentic_security.probe_data import stenography_fn
|
||||
from agentic_security.probe_data.models import ProbeDataset
|
||||
from agentic_security.probe_data.modules import (
|
||||
@@ -18,282 +19,281 @@ from agentic_security.probe_data.modules import (
|
||||
inspect_ai_tool,
|
||||
rl_model,
|
||||
)
|
||||
from datasets import load_dataset
|
||||
|
||||
# Type aliases for clarity
|
||||
T = TypeVar("T")
|
||||
FilterFn = Callable[[pd.Series], bool]
|
||||
ColumnMappings = dict[str, str]
|
||||
DatasetLoader = Callable[[], ProbeDataset]
|
||||
TransformFn = Callable[[str], str]
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_dataset_general(
|
||||
dataset_name,
|
||||
dataset_split="train",
|
||||
column_mappings=None,
|
||||
filter_fn=None,
|
||||
custom_url=None,
|
||||
additional_metadata=None,
|
||||
):
|
||||
"""
|
||||
Generalized function to load datasets with flexible configurations.
|
||||
# Core data loading utilities
|
||||
def fetch_csv_content(url: str) -> str:
|
||||
"""Fetch CSV content from a URL."""
|
||||
response = httpx.get(url)
|
||||
response.raise_for_status() # Raise exception for bad responses
|
||||
return response.content.decode("utf-8")
|
||||
|
||||
:param dataset_name: Name of the dataset or URL for custom CSVs
|
||||
:param dataset_split: Split to load from the dataset (e.g., "train")
|
||||
:param column_mappings: Dictionary mapping dataset columns to expected keys, e.g., {'prompt': 'query'}
|
||||
:param filter_fn: A filtering function that takes a row and returns True/False
|
||||
:param custom_url: URL for custom CSV datasets
|
||||
:param additional_metadata: Additional metadata to include in the ProbeDataset
|
||||
:return: A ProbeDataset object with the processed data
|
||||
"""
|
||||
if custom_url:
|
||||
logger.info(f"Loading custom CSV dataset from {custom_url}")
|
||||
r = httpx.get(custom_url)
|
||||
content = r.content
|
||||
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
|
||||
else:
|
||||
logger.info(f"Loading dataset {dataset_name} from Hugging Face datasets")
|
||||
dataset = load_dataset(dataset_name)
|
||||
df = pd.DataFrame(dataset[dataset_split])
|
||||
|
||||
# Apply column mappings if provided
|
||||
if column_mappings:
|
||||
df.rename(columns=column_mappings, inplace=True)
|
||||
def load_df_from_source(source: str, is_url: bool = False) -> pd.DataFrame:
|
||||
"""Load DataFrame from either URL or Hugging Face dataset."""
|
||||
if is_url:
|
||||
logger.info(f"Loading custom CSV from {source}")
|
||||
return pd.read_csv(io.StringIO(fetch_csv_content(source)))
|
||||
logger.info(f"Loading dataset {source} from Hugging Face")
|
||||
return pd.DataFrame(load_dataset(source)["train"])
|
||||
|
||||
# Filter rows if filter_fn is provided
|
||||
|
||||
def transform_df(
|
||||
df: pd.DataFrame,
|
||||
mappings: ColumnMappings | None = None,
|
||||
filter_fn: FilterFn | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""Apply transformations to DataFrame."""
|
||||
result = df.copy()
|
||||
if mappings:
|
||||
result.rename(columns=mappings, inplace=True)
|
||||
if filter_fn:
|
||||
df = df[df.apply(filter_fn, axis=1)]
|
||||
|
||||
# Extract prompts
|
||||
prompts = df[column_mappings.get("prompt", "prompt")].tolist()
|
||||
|
||||
return ProbeDataset(
|
||||
dataset_name=dataset_name,
|
||||
metadata=additional_metadata or {},
|
||||
prompts=prompts,
|
||||
tokens=count_words_in_list(prompts),
|
||||
approx_cost=0.0,
|
||||
)
|
||||
result = result[result.apply(filter_fn, axis=1)]
|
||||
return result
|
||||
|
||||
|
||||
def count_words_in_list(str_list):
|
||||
"""Calculate the total number of words in a given list of strings.
|
||||
|
||||
:param str_list: List of strings
|
||||
:return: Total number of words across all strings in the list
|
||||
"""
|
||||
total_words = sum(len(str(s).split()) for s in str_list)
|
||||
return total_words
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_dataset_v1():
|
||||
return load_dataset_general(
|
||||
dataset_name="ShawnMenz/DAN_jailbreak",
|
||||
column_mappings={"prompt": "prompt", "jailbreak": "jailbreak"},
|
||||
filter_fn=lambda row: row["jailbreak"],
|
||||
)
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_dataset_v2():
|
||||
return load_dataset_general(
|
||||
dataset_name="deepset/prompt-injections",
|
||||
column_mappings={"prompt": "text", "jailbreak": "label"},
|
||||
filter_fn=lambda row: row["label"],
|
||||
)
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_dataset_v4():
|
||||
return load_dataset_general(
|
||||
dataset_name="notrichardren/refuse-to-answer-prompts",
|
||||
column_mappings={"prompt": "claim", "jailbreak": "label"},
|
||||
filter_fn=lambda row: row["label"],
|
||||
)
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_dataset_v3():
|
||||
return load_dataset_general(
|
||||
dataset_name="rubend18/ChatGPT-Jailbreak-Prompts",
|
||||
column_mappings={"prompt": "Prompt"},
|
||||
filter_fn=lambda row: row["label"],
|
||||
)
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_dataset_v6():
|
||||
dataset = load_dataset("markush1/LLM-Jailbreak-Classifier")
|
||||
bad_prompts = [
|
||||
p
|
||||
for p, c in zip(dataset["train"]["prompt"], dataset["train"]["classification"])
|
||||
if c == "jailbreak"
|
||||
]
|
||||
return ProbeDataset(
|
||||
dataset_name="markush1/LLM-Jailbreak-Classifier",
|
||||
metadata={},
|
||||
prompts=bad_prompts,
|
||||
tokens=count_words_in_list(bad_prompts),
|
||||
approx_cost=0.0,
|
||||
)
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_dataset_v7():
|
||||
splits = {
|
||||
"mini_JailBreakV_28K": "JailBreakV_28K/mini_JailBreakV_28K.csv",
|
||||
"JailBreakV_28K": "JailBreakV_28K/JailBreakV_28K.csv",
|
||||
}
|
||||
df = pd.read_csv(
|
||||
"hf://datasets/JailbreakV-28K/JailBreakV-28k/" + splits["JailBreakV_28K"]
|
||||
)
|
||||
bad_prompts = df["jailbreak_query"].tolist()
|
||||
print(df.shape)
|
||||
return ProbeDataset(
|
||||
dataset_name="JailbreakV-28K/JailBreakV-28k",
|
||||
metadata={},
|
||||
prompts=bad_prompts,
|
||||
tokens=count_words_in_list(bad_prompts),
|
||||
approx_cost=0.0,
|
||||
)
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_dataset_v8():
|
||||
df = pd.read_csv(
|
||||
"hf://datasets/ShawnMenz/jailbreak_sft_rm_ds/jailbreak_sft_rm_ds.csv",
|
||||
names=["jailbreak", "prompt"],
|
||||
)
|
||||
filtered = df[df["jailbreak"] == "jailbreak"]["prompt"].tolist()
|
||||
return ProbeDataset(
|
||||
dataset_name="JailbreakV-28K/JailBreakV-28k",
|
||||
metadata={},
|
||||
prompts=filtered,
|
||||
tokens=count_words_in_list(filtered),
|
||||
approx_cost=0.0,
|
||||
)
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_dataset_v5():
|
||||
ds = []
|
||||
for c in [
|
||||
"AdvBench",
|
||||
"ForbiddenQuestion",
|
||||
"MJP",
|
||||
"MaliciousInstruct",
|
||||
"QuestionList",
|
||||
]:
|
||||
dataset = load_dataset("Lemhf14/EasyJailbreak_Datasets", c)
|
||||
bad_prompts = dataset["train"]["query"]
|
||||
ds.extend(bad_prompts)
|
||||
|
||||
return ProbeDataset(
|
||||
dataset_name="Lemhf14/EasyJailbreak_Datasets",
|
||||
metadata={},
|
||||
prompts=ds,
|
||||
tokens=count_words_in_list(ds),
|
||||
approx_cost=0.0,
|
||||
)
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_generic_csv(url, name, column="prompt", predicator=None):
|
||||
r = httpx.get(url)
|
||||
content = r.content
|
||||
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
|
||||
logger.info(f"Loaded {len(df)} prompts from {url}")
|
||||
filtered_prompts = df[df.apply(predicator, axis=1)][column].tolist()
|
||||
def create_probe_dataset(
|
||||
name: str, prompts: list[str], metadata: dict[str, Any] | None = None
|
||||
) -> ProbeDataset:
|
||||
"""Create a ProbeDataset from prompts."""
|
||||
metadata = metadata or {}
|
||||
return ProbeDataset(
|
||||
dataset_name=name,
|
||||
metadata={},
|
||||
prompts=filtered_prompts,
|
||||
tokens=count_words_in_list(filtered_prompts),
|
||||
metadata=metadata,
|
||||
prompts=prompts,
|
||||
tokens=sum(len(str(s).split()) for s in prompts),
|
||||
approx_cost=0.0,
|
||||
)
|
||||
|
||||
|
||||
def prepare_prompts(dataset_names, budget, tools_inbox=None, options=[]):
|
||||
# ## Datasets used and cleaned:
|
||||
# markush1/LLM-Jailbreak-Classifier
|
||||
# 1. Open-Orca/OpenOrca
|
||||
# 2. ShawnMenz/DAN_jailbreak
|
||||
# 3. EddyLuo/JailBreakV_28K
|
||||
# 4. https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/jailbreak_prompts.csv
|
||||
# Generalized dataset loader
|
||||
@cache_to_disk()
|
||||
def load_dataset_generic(
|
||||
name: str,
|
||||
mappings: ColumnMappings | None = None,
|
||||
filter_fn: FilterFn | None = None,
|
||||
url: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> ProbeDataset:
|
||||
"""Load and process a dataset with flexible configuration."""
|
||||
try:
|
||||
df = load_df_from_source(url or name, is_url=bool(url))
|
||||
transformed_df = transform_df(df, mappings, filter_fn)
|
||||
|
||||
dataset_map = {
|
||||
"ShawnMenz/DAN_jailbreak": load_dataset_v1,
|
||||
"deepset/prompt-injections": load_dataset_v2,
|
||||
"notrichardren/refuse-to-answer-prompts": load_dataset_v4,
|
||||
"rubend18/ChatGPT-Jailbreak-Prompts": load_dataset_v3,
|
||||
"Lemhf14/EasyJailbreak_Datasets": load_dataset_v5,
|
||||
"markush1/LLM-Jailbreak-Classifier": load_dataset_v6,
|
||||
"JailbreakV-28K/JailBreakV-28k": load_dataset_v7,
|
||||
"ShawnMenz/jailbreak_sft_rm_ds": load_dataset_v8,
|
||||
"verazuo/jailbreak_llms/2023_05_07": lambda: load_generic_csv(
|
||||
url="https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_05_07.csv",
|
||||
name="verazuo/jailbreak_llms/2023_05_07",
|
||||
column="prompt",
|
||||
predicator=lambda x: bool(x["jailbreak"]),
|
||||
),
|
||||
"verazuo/jailbreak_llms/2023_12_25.csv": lambda: load_generic_csv(
|
||||
url="https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_12_25.csv.csv",
|
||||
name="verazuo/jailbreak_llms/2023_12_25.csv",
|
||||
column="prompt",
|
||||
predicator=lambda x: bool(x["jailbreak"]),
|
||||
),
|
||||
"Custom CSV": load_local_csv,
|
||||
}
|
||||
dataset_map.update(dataset_map_generics)
|
||||
# Determine which column to use as the prompt source
|
||||
prompt_col = None
|
||||
if mappings and "prompt" in mappings:
|
||||
prompt_col = mappings["prompt"]
|
||||
elif "prompt" in transformed_df.columns:
|
||||
prompt_col = "prompt"
|
||||
else:
|
||||
# Try to find a suitable text column
|
||||
text_columns = [
|
||||
col
|
||||
for col in transformed_df.columns
|
||||
if any(
|
||||
keyword in col.lower()
|
||||
for keyword in ["prompt", "text", "query", "question"]
|
||||
)
|
||||
]
|
||||
if text_columns:
|
||||
prompt_col = text_columns[0]
|
||||
logger.info(f"Using column '{prompt_col}' as prompt source")
|
||||
else:
|
||||
logger.error(f"No suitable prompt column found in dataset {name}")
|
||||
return create_probe_dataset(name, [], metadata)
|
||||
|
||||
group = []
|
||||
for dataset_name in dataset_names:
|
||||
if dataset_name in dataset_map:
|
||||
logger.info(f"Loading {dataset_name}")
|
||||
try:
|
||||
group.append(dataset_map[dataset_name]())
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {dataset_name}: {e}")
|
||||
|
||||
dynamic_datasets = {
|
||||
"AgenticBackend": lambda opts: dataset_from_iterator(
|
||||
"AgenticBackend",
|
||||
fine_tuned.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"Steganography": lambda opts: Stenography(group),
|
||||
"llm-adaptive-attacks": lambda opts: dataset_from_iterator(
|
||||
"llm-adaptive-attacks",
|
||||
adaptive_attacks.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
|
||||
),
|
||||
"Garak": lambda opts: dataset_from_iterator(
|
||||
"Garak",
|
||||
garak_tool.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
|
||||
"Reinforcement Learning Optimization",
|
||||
rl_model.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"InspectAI": lambda opts: dataset_from_iterator(
|
||||
"InspectAI",
|
||||
inspect_ai_tool.Module(group, tools_inbox=tools_inbox).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"GPT fuzzer": lambda opts: [],
|
||||
}
|
||||
|
||||
dynamic_groups = []
|
||||
options = options or [{} for _ in dataset_names]
|
||||
for dataset_name, opts in zip(dataset_names, options):
|
||||
if dataset_name in dynamic_datasets:
|
||||
logger.info(f"Loading {dataset_name}")
|
||||
|
||||
ds = dynamic_datasets[dataset_name](opts)
|
||||
|
||||
for g in ds:
|
||||
dynamic_groups.append(g)
|
||||
return group + dynamic_groups
|
||||
# Extract prompts and filter out empty ones
|
||||
prompts = [
|
||||
p
|
||||
for p in transformed_df[prompt_col].tolist()
|
||||
if p and isinstance(p, (str, int, float))
|
||||
]
|
||||
return create_probe_dataset(name, prompts, metadata)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading dataset {name}: {e}")
|
||||
return create_probe_dataset(name, [], {"error": str(e)})
|
||||
|
||||
|
||||
class Stenography:
|
||||
fn_library = {
|
||||
# Dataset-specific configurations
|
||||
DATASET_CONFIGS = {
|
||||
"ShawnMenz/DAN_jailbreak": {
|
||||
"mappings": {"prompt": "prompt"},
|
||||
"filter_fn": lambda row: row["jailbreak"],
|
||||
},
|
||||
"deepset/prompt-injections": {
|
||||
"mappings": {"prompt": "text"},
|
||||
"filter_fn": lambda row: row["label"],
|
||||
},
|
||||
"notrichardren/refuse-to-answer-prompts": {
|
||||
"mappings": {"prompt": "claim"},
|
||||
"filter_fn": lambda row: row["label"],
|
||||
},
|
||||
"rubend18/ChatGPT-Jailbreak-Prompts": {
|
||||
"mappings": {"prompt": "Prompt"},
|
||||
"filter_fn": lambda row: row["label"],
|
||||
},
|
||||
"markush1/LLM-Jailbreak-Classifier": {
|
||||
"mappings": {"prompt": "prompt"},
|
||||
"filter_fn": lambda row: row["classification"] == "jailbreak",
|
||||
},
|
||||
"ShawnMenz/jailbreak_sft_rm_ds": {
|
||||
"url": "hf://datasets/ShawnMenz/jailbreak_sft_rm_ds/jailbreak_sft_rm_ds.csv",
|
||||
"mappings": {"prompt": "prompt"},
|
||||
"filter_fn": lambda row: row["jailbreak"] == "jailbreak",
|
||||
},
|
||||
"verazuo/jailbreak_llms/2023_05_07": {
|
||||
"url": "https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_05_07.csv",
|
||||
"mappings": {"prompt": "prompt"},
|
||||
"filter_fn": lambda row: bool(row["jailbreak"]),
|
||||
},
|
||||
"verazuo/jailbreak_llms/2023_12_25": {
|
||||
"url": "https://raw.githubusercontent.com/verazuo/jailbreak_llms/main/data/prompts/jailbreak_prompts_2023_12_25.csv",
|
||||
"mappings": {"prompt": "prompt"},
|
||||
"filter_fn": lambda row: bool(row["jailbreak"]),
|
||||
},
|
||||
}
|
||||
|
||||
# Additional generic dataset configurations
|
||||
DATASET_CONFIGS_GENERICS = {
|
||||
"simonycl/aya-23-8B_advbench_jailbreak": {"mappings": {"prompt": "prompt"}},
|
||||
"acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b": {},
|
||||
"karanxa/dolphin-jailbreak-finetuning-dataset": {"mappings": {"prompt": "text"}},
|
||||
"karanxa/llama-2-jailbreak-dataset": {"mappings": {"prompt": "text"}},
|
||||
"karanxa/llama2-uncensored-jailbreak-dataset-finetuning": {
|
||||
"mappings": {"prompt": "text"}
|
||||
},
|
||||
"liuyanchen1015/Llama-3.2-1B_jailbreak_responses": {
|
||||
"mappings": {"prompt": "jailbreak_prompt_text"}
|
||||
},
|
||||
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses": {
|
||||
"mappings": {"prompt": "jailbreak_prompt_text"}
|
||||
},
|
||||
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment": {
|
||||
"mappings": {"prompt": "jailbreak_prompt_text"}
|
||||
},
|
||||
"jackhhao/jailbreak-classification": {"mappings": {"prompt": "prompt"}},
|
||||
"walledai/JailbreakBench": {"mappings": {"prompt": "prompt"}},
|
||||
"walledai/JailbreakHub": {"mappings": {"prompt": "prompt"}},
|
||||
"Granther/evil-jailbreak": {"mappings": {"prompt": "text"}},
|
||||
"sevdeawesome/jailbreak_success": {"mappings": {"prompt": "jailbreak_prompt_text"}},
|
||||
"IDA-SERICS/Disaster-tweet-jailbreaking": {"mappings": {"prompt": "prompt_attack"}},
|
||||
"GeorgeDaDude/Jailbreak_Complete_DS_labeled": {"mappings": {"prompt": "question"}},
|
||||
"dayone3nder/jailbreak_prompt_JBB_sft_trainset": {"mappings": {"prompt": "prompt"}},
|
||||
"dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset": {
|
||||
"mappings": {"prompt": "prompt"}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Dataset factory
|
||||
def create_dataset_loader(name: str, config: dict[str, Any]) -> DatasetLoader:
|
||||
"""Create a dataset loader from configuration."""
|
||||
return partial(
|
||||
load_dataset_generic,
|
||||
name=name,
|
||||
mappings=config.get("mappings"),
|
||||
filter_fn=config.get("filter_fn"),
|
||||
url=config.get("url"),
|
||||
metadata={"source": name, "config": str(config)},
|
||||
)
|
||||
|
||||
|
||||
# Specialized loaders
|
||||
@cache_to_disk()
|
||||
def load_multi_dataset(name: str, sub_datasets: list[str]) -> ProbeDataset:
|
||||
"""Load and combine multiple sub-datasets."""
|
||||
prompts = []
|
||||
for sub in sub_datasets:
|
||||
try:
|
||||
dataset = load_dataset(name, sub)
|
||||
if "query" in dataset["train"].features:
|
||||
prompts.extend(dataset["train"]["query"])
|
||||
else:
|
||||
logger.warning(f"No 'query' column in {name}/{sub}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {name}/{sub}: {e}")
|
||||
|
||||
return create_probe_dataset(
|
||||
f"{name}_combined", prompts, {"source": name, "sub_datasets": sub_datasets}
|
||||
)
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_jailbreak_v28k() -> ProbeDataset:
|
||||
"""Load JailBreakV-28K dataset."""
|
||||
try:
|
||||
df = pd.read_csv(
|
||||
"hf://datasets/JailbreakV-28K/JailBreakV-28k/JailBreakV_28K.csv"
|
||||
)
|
||||
prompts = df["jailbreak_query"].tolist()
|
||||
return create_probe_dataset(
|
||||
"JailbreakV-28K/JailBreakV-28k",
|
||||
prompts,
|
||||
{"source": "JailbreakV-28K/JailBreakV-28k"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading JailbreakV-28K: {e}")
|
||||
return create_probe_dataset("JailbreakV-28K/JailBreakV-28k", [])
|
||||
|
||||
|
||||
@cache_to_disk(1)
|
||||
def file_dataset(file) -> list[str]:
|
||||
prompts = []
|
||||
try:
|
||||
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
|
||||
if "prompt" in df.columns:
|
||||
prompts = df["prompt"].tolist()
|
||||
else:
|
||||
logger.warning(f"File {file} lacks a suitable prompt column")
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading {file}: {e}")
|
||||
return prompts
|
||||
|
||||
|
||||
def load_local_csv() -> ProbeDataset:
|
||||
"""Load prompts from local CSV files."""
|
||||
os.makedirs("./datasets", exist_ok=True)
|
||||
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
|
||||
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
|
||||
|
||||
prompts = []
|
||||
for file in csv_files:
|
||||
prompts.extend(file_dataset(file))
|
||||
return create_probe_dataset("Local CSV", prompts, {"src": str(csv_files)})
|
||||
|
||||
|
||||
def load_csv(file: str) -> ProbeDataset:
|
||||
"""Load prompts from local CSV files."""
|
||||
prompts = file_dataset(file)
|
||||
return create_probe_dataset(f"fs://{file}", prompts, {"src": str(file)})
|
||||
|
||||
|
||||
def load_local_csv_files() -> list[ProbeDataset]:
|
||||
"""Load prompts from local CSV files and return a list of ProbeDataset objects."""
|
||||
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
|
||||
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
|
||||
|
||||
datasets = []
|
||||
|
||||
for file in csv_files:
|
||||
datasets.append(create_probe_dataset(file, file_dataset(file), {"src": file}))
|
||||
return datasets
|
||||
|
||||
|
||||
# Stenography transformer
|
||||
class StenographyTransformer:
|
||||
"""Apply stenography transformations to datasets."""
|
||||
|
||||
TRANSFORMATIONS = {
|
||||
"rot5": stenography_fn.rot5,
|
||||
"rot13": stenography_fn.rot13,
|
||||
"base64": stenography_fn.base64_encode,
|
||||
@@ -306,151 +306,216 @@ class Stenography:
|
||||
"zigzag_obfuscation": stenography_fn.zigzag_obfuscation,
|
||||
}
|
||||
|
||||
def __init__(self, prompt_groups: [ProbeDataset]):
|
||||
self.prompt_groups = prompt_groups
|
||||
def __init__(self, datasets: list[ProbeDataset]):
|
||||
self.datasets = datasets
|
||||
|
||||
def __iter__(self):
|
||||
return self.apply()
|
||||
def transform(self) -> Iterator[ProbeDataset]:
|
||||
"""Generate transformed datasets."""
|
||||
if not self.datasets:
|
||||
logger.warning("No datasets provided to StenographyTransformer")
|
||||
return
|
||||
|
||||
def apply(self):
|
||||
for prompt_group in self.prompt_groups:
|
||||
size = len(prompt_group.prompts)
|
||||
for name, fn in self.fn_library.items():
|
||||
logger.info(f"Applying {name} to {prompt_group.dataset_name}")
|
||||
selected_prompts = random.sample(prompt_group.prompts, int(size * 0.25))
|
||||
prompts = [
|
||||
f"Follow the instructions in {name}: {fn(p)}"
|
||||
for p in selected_prompts
|
||||
]
|
||||
yield ProbeDataset(
|
||||
dataset_name=f"stenography.{name}({prompt_group.dataset_name})",
|
||||
metadata={},
|
||||
prompts=prompts,
|
||||
tokens=count_words_in_list(prompts),
|
||||
approx_cost=0.0,
|
||||
for dataset in self.datasets:
|
||||
if not dataset.prompts:
|
||||
logger.warning(
|
||||
f"Dataset {dataset.dataset_name} has no prompts to transform"
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert async generator (if any) to a list
|
||||
prompts = (
|
||||
list(dataset.prompts) if hasattr(dataset.prompts, "__iter__") else []
|
||||
)
|
||||
size = len(prompts)
|
||||
|
||||
if size == 0:
|
||||
continue
|
||||
|
||||
sample_size = max(1, int(size * 0.25)) # At least 1 sample
|
||||
|
||||
for name, fn in self.TRANSFORMATIONS.items():
|
||||
logger.info(f"Applying {name} to {dataset.dataset_name}")
|
||||
sampled = random.sample(prompts, min(sample_size, size))
|
||||
try:
|
||||
transformed = [
|
||||
f"Follow the instructions in {name}: {fn(str(p))}"
|
||||
for p in sampled
|
||||
]
|
||||
transformed_dataset = create_probe_dataset(
|
||||
f"stenography.{name}({dataset.dataset_name})",
|
||||
transformed,
|
||||
{"source": dataset.dataset_name, "transformation": name},
|
||||
)
|
||||
yield transformed_dataset
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error applying {name} to {dataset.dataset_name}: {e}"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def load_local_csv() -> ProbeDataset:
|
||||
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
|
||||
logger.info(f"Found {len(csv_files)} CSV files")
|
||||
logger.info(f"CSV files: {csv_files}")
|
||||
prompt_list = []
|
||||
def dataset_from_iterator(
|
||||
name: str, iterator: Iterator[str], lazy: bool = False
|
||||
) -> list[ProbeDataset]:
|
||||
"""Convert an iterator into a list of ProbeDataset objects."""
|
||||
try:
|
||||
prompts = list(iterator) if not lazy else iterator
|
||||
tokens = sum(len(str(s).split()) for s in prompts) if not lazy else 0
|
||||
dataset = ProbeDataset(
|
||||
dataset_name=name,
|
||||
metadata={"source": name, "lazy": lazy},
|
||||
prompts=prompts,
|
||||
tokens=tokens,
|
||||
approx_cost=0.0,
|
||||
lazy=lazy,
|
||||
)
|
||||
return [dataset]
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating dataset from iterator {name}: {e}")
|
||||
return [create_probe_dataset(name, [], {"error": str(e)})]
|
||||
|
||||
for file in csv_files:
|
||||
try:
|
||||
df = pd.read_csv(file)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading {file}: {e}")
|
||||
|
||||
# Main dataset preparation
|
||||
def prepare_prompts(
|
||||
dataset_names: list[str],
|
||||
budget: float,
|
||||
tools_inbox=None,
|
||||
options: list[dict] = None,
|
||||
) -> list[ProbeDataset]:
|
||||
"""Prepare datasets based on names and options."""
|
||||
# Base dataset loaders
|
||||
logger.info(f"Preparing datasets: {dataset_names}")
|
||||
dataset_loaders = {
|
||||
**{k: create_dataset_loader(k, v) for k, v in DATASET_CONFIGS.items()},
|
||||
**{k: create_dataset_loader(k, v) for k, v in DATASET_CONFIGS_GENERICS.items()},
|
||||
"Lemhf14/EasyJailbreak_Datasets": partial(
|
||||
load_multi_dataset,
|
||||
name="Lemhf14/EasyJailbreak_Datasets",
|
||||
sub_datasets=[
|
||||
"AdvBench",
|
||||
"ForbiddenQuestion",
|
||||
"MJP",
|
||||
"MaliciousInstruct",
|
||||
"QuestionList",
|
||||
],
|
||||
),
|
||||
"JailbreakV-28K/JailBreakV-28k": load_jailbreak_v28k,
|
||||
"Local CSV": load_local_csv,
|
||||
"Custom CSV": load_local_csv,
|
||||
}
|
||||
|
||||
# Dynamic dataset loaders
|
||||
dynamic_loaders = {
|
||||
"AgenticBackend": lambda opts: dataset_from_iterator(
|
||||
"AgenticBackend",
|
||||
fine_tuned.Module(
|
||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
||||
).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"Steganography": lambda opts: list(
|
||||
StenographyTransformer(opts["datasets"]).transform()
|
||||
),
|
||||
"llm-adaptive-attacks": lambda opts: dataset_from_iterator(
|
||||
"llm-adaptive-attacks",
|
||||
adaptive_attacks.Module(
|
||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
||||
).apply(),
|
||||
),
|
||||
"Garak": lambda opts: dataset_from_iterator(
|
||||
"Garak",
|
||||
garak_tool.Module(
|
||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
||||
).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
|
||||
"Reinforcement Learning Optimization",
|
||||
rl_model.Module(
|
||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
||||
).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"InspectAI": lambda opts: dataset_from_iterator(
|
||||
"InspectAI",
|
||||
inspect_ai_tool.Module([], tools_inbox=tools_inbox).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"GPT fuzzer": lambda opts: [],
|
||||
}
|
||||
|
||||
datasets = []
|
||||
options = options or [dict(datasets=datasets) for _ in dataset_names]
|
||||
|
||||
# Load base datasets
|
||||
for name, opts in zip(dataset_names, options):
|
||||
if name not in dataset_loaders:
|
||||
continue
|
||||
# Check if 'prompt' column exists
|
||||
if "prompt" in df.columns:
|
||||
prompt_list.extend(df["prompt"].tolist())
|
||||
else:
|
||||
logger.warning(f"File {file} does not contain a 'prompt' column")
|
||||
return ProbeDataset(
|
||||
dataset_name="Local CSV",
|
||||
metadata={"src": str(csv_files)},
|
||||
prompts=prompt_list,
|
||||
tokens=count_words_in_list(prompt_list),
|
||||
approx_cost=0.0,
|
||||
)
|
||||
try:
|
||||
datasets.append(dataset_loaders[name]())
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {name}: {e}")
|
||||
|
||||
# Load dynamic datasets and apply transformations
|
||||
for name, opts in zip(dataset_names, options):
|
||||
if name not in dynamic_loaders:
|
||||
continue
|
||||
logger.info(f"Loading dynamic dataset {name} {opts}")
|
||||
opts["datasets"] = datasets
|
||||
try:
|
||||
dynamic_result = dynamic_loaders[name](opts)
|
||||
datasets.extend(dynamic_result)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error loading dynamic {name}: {e}")
|
||||
|
||||
# Load csv datasets and apply transformations
|
||||
for name, opts in zip(dataset_names, options):
|
||||
if not name.endswith(".csv"):
|
||||
continue
|
||||
logger.info(f"Loading csv dataset {name} {opts}")
|
||||
datasets.append(load_csv(name))
|
||||
|
||||
return datasets
|
||||
|
||||
|
||||
def dataset_from_iterator(name: str, iterator, lazy=False) -> list:
|
||||
"""Convert an iterator into a list of prompts and create a ProbeDataset
|
||||
object.
|
||||
async def prepare_prompts_unified(configs: list) -> list[ProbeDataset]:
|
||||
"""Prepare datasets using unified loader configuration.
|
||||
|
||||
This is an alternative to prepare_prompts() that uses the UnifiedDatasetLoader
|
||||
for streamlined configuration and merging of multiple sources.
|
||||
|
||||
Args:
|
||||
name (str): The name of the dataset.
|
||||
iterator (iterator): An iterator that generates prompts.
|
||||
configs: List of InputSourceConfig objects or dicts
|
||||
|
||||
Returns:
|
||||
list: A list containing a single ProbeDataset object.
|
||||
list[ProbeDataset]: List containing the merged dataset
|
||||
|
||||
Example:
|
||||
>>> from agentic_security.probe_data.unified_loader import InputSourceConfig
|
||||
>>> configs = [
|
||||
... InputSourceConfig(
|
||||
... source_type="huggingface",
|
||||
... dataset_name="deepset/prompt-injections",
|
||||
... enabled=True,
|
||||
... weight=1.0
|
||||
... )
|
||||
... ]
|
||||
>>> datasets = await prepare_prompts_unified(configs)
|
||||
"""
|
||||
prompts = list(iterator) if not lazy else iterator
|
||||
tokens = count_words_in_list(prompts) if not lazy else 0
|
||||
dataset = ProbeDataset(
|
||||
dataset_name=name,
|
||||
metadata={},
|
||||
prompts=prompts,
|
||||
tokens=tokens,
|
||||
approx_cost=0.0,
|
||||
lazy=lazy,
|
||||
from agentic_security.probe_data.unified_loader import (
|
||||
UnifiedDatasetLoader,
|
||||
InputSourceConfig,
|
||||
)
|
||||
return [dataset]
|
||||
|
||||
# Convert dicts to InputSourceConfig if needed
|
||||
config_objects = []
|
||||
for config in configs:
|
||||
if isinstance(config, dict):
|
||||
config_objects.append(InputSourceConfig(**config))
|
||||
else:
|
||||
config_objects.append(config)
|
||||
|
||||
# TODO: refactor this abstraction
|
||||
loader = UnifiedDatasetLoader(config_objects)
|
||||
merged_dataset = await loader.load_all()
|
||||
|
||||
dataset_map_generics = {
|
||||
"simonycl/aya-23-8B_advbench_jailbreak": lambda: load_dataset_general(
|
||||
dataset_name="simonycl/aya-23-8B_advbench_jailbreak",
|
||||
column_mappings={"prompt": "prompt"},
|
||||
),
|
||||
"acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b": lambda: load_dataset_general(
|
||||
dataset_name="acmc/jailbreaks_dataset_with_perplexity_bigcode_starcoder2-3b_bigcode_starcoder2-7b"
|
||||
),
|
||||
"karanxa/dolphin-jailbreak-finetuning-dataset": lambda: load_dataset_general(
|
||||
dataset_name="karanxa/dolphin-jailbreak-finetuning-dataset",
|
||||
column_mappings={"prompt": "text"},
|
||||
),
|
||||
"karanxa/llama-2-jailbreak-dataset": lambda: load_dataset_general(
|
||||
dataset_name="karanxa/llama-2-jailbreak-dataset",
|
||||
column_mappings={"prompt": "text"},
|
||||
),
|
||||
"karanxa/llama2-uncensored-jailbreak-dataset-finetuning": lambda: load_dataset_general(
|
||||
dataset_name="karanxa/llama2-uncensored-jailbreak-dataset-finetuning",
|
||||
column_mappings={"prompt": "text"},
|
||||
),
|
||||
"liuyanchen1015/Llama-3.2-1B_jailbreak_responses": lambda: load_dataset_general(
|
||||
dataset_name="liuyanchen1015/Llama-3.2-1B_jailbreak_responses",
|
||||
column_mappings={"prompt": "jailbreak_prompt_text"},
|
||||
),
|
||||
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses": lambda: load_dataset_general(
|
||||
dataset_name="liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses",
|
||||
column_mappings={"prompt": "jailbreak_prompt_text"},
|
||||
),
|
||||
"liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment": lambda: load_dataset_general(
|
||||
dataset_name="liuyanchen1015/Llama-3.2-1B-Instruct_jailbreak_responses_with_judgment",
|
||||
column_mappings={"prompt": "jailbreak_prompt_text"},
|
||||
),
|
||||
"jackhhao/jailbreak-classification": lambda: load_dataset_general(
|
||||
dataset_name="jackhhao/jailbreak-classification",
|
||||
column_mappings={"prompt": "prompt"},
|
||||
),
|
||||
"markush1/LLM-Jailbreak-Classifier": lambda: load_dataset_general(
|
||||
dataset_name="markush1/LLM-Jailbreak-Classifier",
|
||||
column_mappings={"prompt": "prompt"},
|
||||
),
|
||||
"walledai/JailbreakBench": lambda: load_dataset_general(
|
||||
dataset_name="walledai/JailbreakBench", column_mappings={"prompt": "prompt"}
|
||||
),
|
||||
"walledai/JailbreakHub": lambda: load_dataset_general(
|
||||
dataset_name="walledai/JailbreakHub", column_mappings={"prompt": "prompt"}
|
||||
),
|
||||
"Granther/evil-jailbreak": lambda: load_dataset_general(
|
||||
dataset_name="Granther/evil-jailbreak", column_mappings={"prompt": "text"}
|
||||
),
|
||||
"sevdeawesome/jailbreak_success": lambda: load_dataset_general(
|
||||
dataset_name="sevdeawesome/jailbreak_success",
|
||||
column_mappings={"prompt": "jailbreak_prompt_text"},
|
||||
),
|
||||
"IDA-SERICS/Disaster-tweet-jailbreaking": lambda: load_dataset_general(
|
||||
dataset_name="IDA-SERICS/Disaster-tweet-jailbreaking",
|
||||
column_mappings={"prompt": "prompt_attack"},
|
||||
),
|
||||
"GeorgeDaDude/Jailbreak_Complete_DS_labeled": lambda: load_dataset_general(
|
||||
dataset_name="GeorgeDaDude/Jailbreak_Complete_DS_labeled",
|
||||
column_mappings={"prompt": "question"},
|
||||
),
|
||||
"dayone3nder/jailbreak_prompt_JBB_sft_trainset": lambda: load_dataset_general(
|
||||
dataset_name="dayone3nder/jailbreak_prompt_JBB_sft_trainset",
|
||||
column_mappings={"prompt": "prompt"},
|
||||
),
|
||||
"dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset": lambda: load_dataset_general(
|
||||
dataset_name="dayone3nder/general_safe_mix_jailbreak_prompt_JBB_trainset",
|
||||
column_mappings={"prompt": "prompt"},
|
||||
),
|
||||
}
|
||||
# Return as list for compatibility with existing code
|
||||
return [merged_dataset] if merged_dataset.prompts else []
|
||||
|
||||
@@ -3,7 +3,8 @@ import io
|
||||
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
url = "https://raw.githubusercontent.com/tml-epfl/llm-adaptive-attacks/main/harmful_behaviors/harmful_behaviors_pair.csv"
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@ import os
|
||||
import uuid as U
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
|
||||
|
||||
@@ -65,7 +66,7 @@ class Module:
|
||||
return {}
|
||||
|
||||
async def fetch_prompts(self) -> list[str]:
|
||||
api_url = "https://edge.metaheuristic.co/infer"
|
||||
api_url = "https://mcp.metaheuristic.co/infer"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {AUTH_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from loguru import logger
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
# TODO: add probes modules
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
|
||||
from loguru import logger
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
inspect_ai_task = (
|
||||
__file__.replace("inspect_ai_tool.py", "inspect_ai_task.py")
|
||||
|
||||
@@ -8,7 +8,8 @@ from typing import Deque
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
|
||||
|
||||
@@ -19,12 +20,10 @@ class PromptSelectionInterface(ABC):
|
||||
@abstractmethod
|
||||
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
|
||||
"""Selects the next prompt based on current state and guard result."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
|
||||
"""Selects the next prompts based on current state and guard result."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_rewards(
|
||||
@@ -35,7 +34,6 @@ class PromptSelectionInterface(ABC):
|
||||
passed_guard: bool,
|
||||
) -> None:
|
||||
"""Updates internal rewards based on the outcome of the last selected prompt."""
|
||||
pass
|
||||
|
||||
|
||||
class RandomPromptSelector(PromptSelectionInterface):
|
||||
@@ -120,8 +118,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
passed_guard: bool,
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class QLearningPromptSelector(PromptSelectionInterface):
|
||||
@@ -206,7 +203,11 @@ class QLearningPromptSelector(PromptSelectionInterface):
|
||||
|
||||
class Module:
|
||||
def __init__(
|
||||
self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {}
|
||||
self,
|
||||
prompt_groups: list[str],
|
||||
tools_inbox: asyncio.Queue,
|
||||
opts: dict = {},
|
||||
rl_model: PromptSelectionInterface | None = None,
|
||||
):
|
||||
self.tools_inbox = tools_inbox
|
||||
self.opts = opts
|
||||
@@ -214,8 +215,8 @@ class Module:
|
||||
self.max_prompts = self.opts.get("max_prompts", 10) # Default max M prompts
|
||||
self.run_id = U.uuid4().hex
|
||||
self.batch_size = self.opts.get("batch_size", 500)
|
||||
self.rl_model = CloudRLPromptSelector(
|
||||
prompt_groups, "https://edge.metaheuristic.co", run_id=self.run_id
|
||||
self.rl_model = rl_model or CloudRLPromptSelector(
|
||||
prompt_groups, "https://mcp.metaheuristic.co", run_id=self.run_id
|
||||
)
|
||||
|
||||
async def apply(self):
|
||||
|
||||
@@ -33,11 +33,19 @@ def mock_requests() -> Mock:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rl_selector() -> Mock:
|
||||
return CloudRLPromptSelector(
|
||||
dataset_prompts,
|
||||
api_url="https://edge.metaheuristic.co",
|
||||
)
|
||||
def mock_rl_selector(dataset_prompts) -> Mock:
|
||||
class StubSelector:
|
||||
def __init__(self, prompts: list[str]):
|
||||
self.prompts = prompts
|
||||
self.idx = 0
|
||||
|
||||
def select_next_prompts(
|
||||
self, current_prompt: str, passed_guard: bool
|
||||
) -> list[str]:
|
||||
self.idx = (self.idx + 1) % len(self.prompts)
|
||||
return [self.prompts[self.idx]]
|
||||
|
||||
return StubSelector(dataset_prompts)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -91,15 +99,18 @@ class TestCloudRLPromptSelector:
|
||||
next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True)
|
||||
assert next_prompt in dataset_prompts
|
||||
|
||||
def test_select_next_prompt_success_service(self, dataset_prompts):
|
||||
def test_select_next_prompt_success_service(self, dataset_prompts, mock_requests):
|
||||
mock_requests.return_value.status_code = 200
|
||||
mock_requests.return_value.json.return_value = {"next_prompts": ["What is AI?"]}
|
||||
|
||||
selector = CloudRLPromptSelector(
|
||||
dataset_prompts,
|
||||
api_url="https://edge.metaheuristic.co",
|
||||
api_url="https://mcp.metaheuristic.co",
|
||||
)
|
||||
next_prompt = selector.select_next_prompt(
|
||||
"How does RL work?", passed_guard=True
|
||||
)
|
||||
assert next_prompt
|
||||
assert next_prompt == "What is AI?"
|
||||
|
||||
|
||||
# Tests for QLearningPromptSelector
|
||||
@@ -188,7 +199,7 @@ class TestModule:
|
||||
async def test_apply_basic_flow(
|
||||
self, dataset_prompts, tools_inbox, mock_rl_selector
|
||||
):
|
||||
module = Module(dataset_prompts, tools_inbox)
|
||||
module = Module(dataset_prompts, tools_inbox, rl_model=mock_rl_selector)
|
||||
|
||||
count = 0
|
||||
async for prompt in module.apply():
|
||||
@@ -198,7 +209,9 @@ class TestModule:
|
||||
break
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_rl_with_tools_inbox(self, dataset_prompts, tools_inbox):
|
||||
async def test_apply_rl_with_tools_inbox(
|
||||
self, dataset_prompts, tools_inbox, mock_rl_selector
|
||||
):
|
||||
# Add a test message to the tools inbox
|
||||
test_message = {
|
||||
"message": "Test message",
|
||||
@@ -207,7 +220,7 @@ class TestModule:
|
||||
}
|
||||
await tools_inbox.put(test_message)
|
||||
|
||||
module = Module(dataset_prompts, tools_inbox)
|
||||
module = Module(dataset_prompts, tools_inbox, rl_model=mock_rl_selector)
|
||||
|
||||
async for output in module.apply():
|
||||
if output == "Test message":
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from cache_to_disk import cache_to_disk
|
||||
from cache_to_disk import cache_to_disk # noqa
|
||||
|
||||
|
||||
# TODO: refactor this class to use from .data
|
||||
@@ -22,7 +22,7 @@ class ProbeDataset:
|
||||
}
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
# @cache_to_disk(n_days_to_cache=1)
|
||||
def load_dataset_generic(name, getter=lambda x: x["train"]["prompt"]):
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
@@ -0,0 +1,252 @@
|
||||
"""Unified dataset loader for CSV, HuggingFace, and proxy sources."""
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
from agentic_security.probe_data.data import (
|
||||
load_dataset_generic,
|
||||
load_csv,
|
||||
create_probe_dataset,
|
||||
)
|
||||
from agentic_security.probe_data.models import ProbeDataset
|
||||
|
||||
|
||||
class InputSourceConfig(BaseModel):
|
||||
"""Configuration for a single input source."""
|
||||
|
||||
source_type: Literal["csv", "huggingface", "proxy"] = Field(
|
||||
description="Type of input source"
|
||||
)
|
||||
enabled: bool = Field(default=True, description="Whether this source is enabled")
|
||||
dataset_name: str = Field(description="Name/identifier of the dataset")
|
||||
weight: float = Field(
|
||||
default=1.0, ge=0.0, description="Sampling weight for merging"
|
||||
)
|
||||
|
||||
# CSV-specific fields
|
||||
path: str | None = Field(default=None, description="File path for CSV sources")
|
||||
prompt_column: str | None = Field(
|
||||
default="prompt", description="Column name containing prompts"
|
||||
)
|
||||
|
||||
# HuggingFace-specific fields
|
||||
split: str | None = Field(
|
||||
default="train", description="Dataset split to load (train/test/validation)"
|
||||
)
|
||||
max_samples: int | None = Field(
|
||||
default=None, ge=1, description="Maximum number of samples to load"
|
||||
)
|
||||
|
||||
# URL for custom sources
|
||||
url: str | None = Field(default=None, description="URL for remote CSV files")
|
||||
|
||||
|
||||
class UnifiedDatasetLoader:
|
||||
"""Loads and merges datasets from multiple sources."""
|
||||
|
||||
def __init__(self, configs: list[InputSourceConfig]):
|
||||
"""Initialize with list of input source configurations.
|
||||
|
||||
Args:
|
||||
configs: List of InputSourceConfig objects defining data sources
|
||||
"""
|
||||
self.configs = configs
|
||||
logger.info(f"Initialized UnifiedDatasetLoader with {len(configs)} sources")
|
||||
|
||||
async def load_all(self) -> ProbeDataset:
|
||||
"""Load all enabled sources and merge into a single dataset.
|
||||
|
||||
Returns:
|
||||
ProbeDataset: Merged dataset from all enabled sources
|
||||
"""
|
||||
datasets = []
|
||||
|
||||
for config in self.configs:
|
||||
if not config.enabled:
|
||||
logger.debug(f"Skipping disabled source: {config.dataset_name}")
|
||||
continue
|
||||
|
||||
try:
|
||||
dataset = await self._load_single(config)
|
||||
if dataset and dataset.prompts:
|
||||
datasets.append((dataset, config.weight))
|
||||
logger.info(
|
||||
f"Loaded {len(dataset.prompts)} prompts from {config.dataset_name} "
|
||||
f"(weight={config.weight})"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"No prompts loaded from {config.dataset_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {config.dataset_name}: {e}")
|
||||
|
||||
if not datasets:
|
||||
logger.warning("No datasets loaded successfully")
|
||||
return create_probe_dataset("unified_empty", [], {"sources": []})
|
||||
|
||||
return self._merge_weighted(datasets)
|
||||
|
||||
async def _load_single(self, config: InputSourceConfig) -> ProbeDataset:
|
||||
"""Load a single dataset based on its configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration for the source to load
|
||||
|
||||
Returns:
|
||||
ProbeDataset: Loaded dataset
|
||||
"""
|
||||
if config.source_type == "csv":
|
||||
return self._load_csv_source(config)
|
||||
elif config.source_type == "huggingface":
|
||||
return self._load_huggingface_source(config)
|
||||
elif config.source_type == "proxy":
|
||||
return self._load_proxy_source(config)
|
||||
else:
|
||||
raise ValueError(f"Unknown source type: {config.source_type}")
|
||||
|
||||
def _load_csv_source(self, config: InputSourceConfig) -> ProbeDataset:
|
||||
"""Load dataset from CSV file.
|
||||
|
||||
Args:
|
||||
config: CSV source configuration
|
||||
|
||||
Returns:
|
||||
ProbeDataset: Dataset loaded from CSV
|
||||
"""
|
||||
if config.path:
|
||||
# Local CSV file
|
||||
logger.info(f"Loading CSV from path: {config.path}")
|
||||
dataset = load_csv(config.path)
|
||||
elif config.url:
|
||||
# Remote CSV file
|
||||
logger.info(f"Loading CSV from URL: {config.url}")
|
||||
mappings = (
|
||||
{config.prompt_column: "prompt"} if config.prompt_column else None
|
||||
)
|
||||
dataset = load_dataset_generic(
|
||||
name=config.dataset_name,
|
||||
url=config.url,
|
||||
mappings=mappings,
|
||||
metadata={"source_type": "csv", "url": config.url},
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"CSV source {config.dataset_name} requires either path or url"
|
||||
)
|
||||
|
||||
# Apply max_samples limit if specified
|
||||
if config.max_samples and len(dataset.prompts) > config.max_samples:
|
||||
logger.info(
|
||||
f"Limiting {config.dataset_name} from {len(dataset.prompts)} "
|
||||
f"to {config.max_samples} samples"
|
||||
)
|
||||
dataset.prompts = dataset.prompts[: config.max_samples]
|
||||
|
||||
return dataset
|
||||
|
||||
def _load_huggingface_source(self, config: InputSourceConfig) -> ProbeDataset:
|
||||
"""Load dataset from HuggingFace.
|
||||
|
||||
Args:
|
||||
config: HuggingFace source configuration
|
||||
|
||||
Returns:
|
||||
ProbeDataset: Dataset loaded from HuggingFace
|
||||
"""
|
||||
logger.info(
|
||||
f"Loading HuggingFace dataset: {config.dataset_name} "
|
||||
f"(split={config.split})"
|
||||
)
|
||||
|
||||
# Build column mappings
|
||||
mappings = None
|
||||
if config.prompt_column and config.prompt_column != "prompt":
|
||||
mappings = {config.prompt_column: "prompt"}
|
||||
|
||||
dataset = load_dataset_generic(
|
||||
name=config.dataset_name,
|
||||
mappings=mappings,
|
||||
metadata={
|
||||
"source_type": "huggingface",
|
||||
"split": config.split,
|
||||
},
|
||||
)
|
||||
|
||||
# Apply max_samples limit if specified
|
||||
if config.max_samples and len(dataset.prompts) > config.max_samples:
|
||||
logger.info(
|
||||
f"Limiting {config.dataset_name} from {len(dataset.prompts)} "
|
||||
f"to {config.max_samples} samples"
|
||||
)
|
||||
dataset.prompts = dataset.prompts[: config.max_samples]
|
||||
|
||||
return dataset
|
||||
|
||||
def _load_proxy_source(self, config: InputSourceConfig) -> ProbeDataset:
|
||||
"""Load dataset from proxy queue (placeholder for PoC).
|
||||
|
||||
Args:
|
||||
config: Proxy source configuration
|
||||
|
||||
Returns:
|
||||
ProbeDataset: Empty dataset (proxy integration not implemented in PoC)
|
||||
"""
|
||||
logger.warning(
|
||||
f"Proxy source {config.dataset_name} not implemented in PoC - returning empty dataset"
|
||||
)
|
||||
return create_probe_dataset(
|
||||
config.dataset_name,
|
||||
[],
|
||||
{"source_type": "proxy", "status": "not_implemented"},
|
||||
)
|
||||
|
||||
def _merge_weighted(
|
||||
self, datasets: list[tuple[ProbeDataset, float]]
|
||||
) -> ProbeDataset:
|
||||
"""Merge multiple datasets with weighted sampling.
|
||||
|
||||
For PoC, this implements simple concatenation with optional weighting.
|
||||
Production version would implement proper stratified sampling.
|
||||
|
||||
Args:
|
||||
datasets: List of (ProbeDataset, weight) tuples
|
||||
|
||||
Returns:
|
||||
ProbeDataset: Merged dataset
|
||||
"""
|
||||
if not datasets:
|
||||
return create_probe_dataset("unified_empty", [], {"sources": []})
|
||||
|
||||
# For PoC: simple concatenation, repeat prompts based on weight
|
||||
all_prompts = []
|
||||
source_names = []
|
||||
total_tokens = 0
|
||||
|
||||
for dataset, weight in datasets:
|
||||
source_names.append(dataset.dataset_name)
|
||||
|
||||
# Calculate how many times to include this dataset based on weight
|
||||
# Weight of 1.0 = include once, 2.0 = include twice, etc.
|
||||
repeat_count = max(1, int(weight))
|
||||
|
||||
for _ in range(repeat_count):
|
||||
all_prompts.extend(dataset.prompts)
|
||||
|
||||
total_tokens += dataset.tokens * repeat_count
|
||||
|
||||
logger.info(
|
||||
f"Merged {len(datasets)} datasets into {len(all_prompts)} total prompts "
|
||||
f"from sources: {source_names}"
|
||||
)
|
||||
|
||||
return ProbeDataset(
|
||||
dataset_name="unified",
|
||||
metadata={
|
||||
"sources": source_names,
|
||||
"source_count": len(datasets),
|
||||
"weights": {ds.dataset_name: w for ds, w in datasets},
|
||||
},
|
||||
prompts=all_prompts,
|
||||
tokens=total_tokens,
|
||||
approx_cost=0.0,
|
||||
)
|
||||
@@ -1,8 +1,10 @@
|
||||
import importlib.resources as pkg_resources
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import joblib
|
||||
import pandas as pd
|
||||
from sklearn.exceptions import InconsistentVersionWarning
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import OneClassSVM
|
||||
@@ -70,27 +72,31 @@ class RefusalClassifier:
|
||||
"""
|
||||
Load the trained model, vectorizer, and scaler from disk.
|
||||
"""
|
||||
try:
|
||||
self.model = joblib.load(self.model_path)
|
||||
self.vectorizer = joblib.load(self.vectorizer_path)
|
||||
self.scaler = joblib.load(self.scaler_path)
|
||||
except FileNotFoundError:
|
||||
# Load from package resources
|
||||
package = (
|
||||
__package__ # This should be 'agentic_security.refusal_classifier'
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
|
||||
try:
|
||||
self.model = joblib.load(self.model_path)
|
||||
self.vectorizer = joblib.load(self.vectorizer_path)
|
||||
self.scaler = joblib.load(self.scaler_path)
|
||||
except FileNotFoundError:
|
||||
# Load from package resources
|
||||
package = (
|
||||
__package__ # This should be 'agentic_security.refusal_classifier'
|
||||
)
|
||||
|
||||
# Load model
|
||||
with pkg_resources.open_binary(package, "oneclass_svm_model.joblib") as f:
|
||||
self.model = joblib.load(f)
|
||||
# Load model
|
||||
with pkg_resources.open_binary(
|
||||
package, "oneclass_svm_model.joblib"
|
||||
) as f:
|
||||
self.model = joblib.load(f)
|
||||
|
||||
# Load vectorizer
|
||||
with pkg_resources.open_binary(package, "tfidf_vectorizer.joblib") as f:
|
||||
self.vectorizer = joblib.load(f)
|
||||
# Load vectorizer
|
||||
with pkg_resources.open_binary(package, "tfidf_vectorizer.joblib") as f:
|
||||
self.vectorizer = joblib.load(f)
|
||||
|
||||
# Load scaler
|
||||
with pkg_resources.open_binary(package, "scaler.joblib") as f:
|
||||
self.scaler = joblib.load(f)
|
||||
# Load scaler
|
||||
with pkg_resources.open_binary(package, "scaler.joblib") as f:
|
||||
self.scaler = joblib.load(f)
|
||||
|
||||
def is_refusal(self, text):
|
||||
"""
|
||||
|
||||
@@ -7,9 +7,62 @@ import pandas as pd
|
||||
from matplotlib.cm import ScalarMappable
|
||||
from matplotlib.colors import LinearSegmentedColormap, Normalize
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
def plot_security_report(table):
|
||||
from .primitives import Table
|
||||
|
||||
|
||||
def plot_security_report(table: Table) -> io.BytesIO:
|
||||
"""
|
||||
Generates a polar plot representing the security report based on the given data.
|
||||
|
||||
Args:
|
||||
table (Table): The input data table containing security metrics.
|
||||
|
||||
Returns:
|
||||
io.BytesIO: A buffer containing the generated plot image in PNG format.
|
||||
Returns an empty buffer in case of an error.
|
||||
"""
|
||||
try:
|
||||
return _plot_security_report(table=table)
|
||||
except (TypeError, ValueError, OverflowError, IndexError, Exception) as e:
|
||||
logger.error(f"Error in generating the security report: {e} {table}")
|
||||
return io.BytesIO()
|
||||
|
||||
|
||||
def generate_identifiers(data: pd.DataFrame) -> list[str]:
|
||||
"""
|
||||
Generates unique identifiers for the given dataset.
|
||||
|
||||
Args:
|
||||
data (pd.DataFrame): A pandas DataFrame containing security-related data.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of generated identifiers. Returns a list with an empty string in case of an error.
|
||||
"""
|
||||
return _generate_identifiers(data=data)
|
||||
|
||||
|
||||
def _plot_security_report(table: Table) -> io.BytesIO:
|
||||
"""
|
||||
Generates a polar plot-based security report visualizing the failure rates for different modules.
|
||||
|
||||
This function processes the input data, sorts it by failure rate, and generates a polar plot
|
||||
where each bar represents the failure rate for a specific module. The plot includes identifiers,
|
||||
color-coding based on token count, failure rate values on the bars, and a table listing the modules
|
||||
and their corresponding failure rates.
|
||||
|
||||
Args:
|
||||
table (Table): A table-like structure (e.g., pandas DataFrame) containing security report data
|
||||
with columns for failure rate, tokens, and modules.
|
||||
|
||||
Returns:
|
||||
io.BytesIO: A buffer containing the generated plot image in PNG format.
|
||||
"""
|
||||
return io.BytesIO()
|
||||
# Data preprocessing
|
||||
logger.info("Data preprocessing started.")
|
||||
|
||||
data = pd.DataFrame(table)
|
||||
|
||||
# Sort by failure rate and reset index
|
||||
@@ -20,10 +73,10 @@ def plot_security_report(table):
|
||||
fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={"projection": "polar"})
|
||||
fig.set_facecolor("#f0f0f0")
|
||||
ax.set_facecolor("#f0f0f0")
|
||||
logger.info("Plot setup complete.")
|
||||
|
||||
# Styling parameters
|
||||
colors = ["#6C5B7B", "#C06C84", "#F67280", "#F8B195"][::-1] # Pastel palette
|
||||
# colors = ["#440154", "#3b528b", "#21908c", "#5dc863"] # Viridis-inspired palette
|
||||
cmap = LinearSegmentedColormap.from_list("custom", colors, N=256)
|
||||
norm = Normalize(vmin=data["tokens"].min(), vmax=data["tokens"].max())
|
||||
|
||||
@@ -74,7 +127,10 @@ def plot_security_report(table):
|
||||
|
||||
# Title and caption
|
||||
fig.suptitle(
|
||||
"Security Report for Different Modules", fontsize=16, fontweight="bold", y=1.02
|
||||
"Security Report for Different Modules",
|
||||
fontsize=16,
|
||||
fontweight="bold",
|
||||
y=1.02,
|
||||
)
|
||||
caption = "Report generated by https://github.com/msoedov/agentic_security"
|
||||
fig.text(
|
||||
@@ -112,17 +168,12 @@ def plot_security_report(table):
|
||||
data["identifier"], data["failureRate"], data["module"]
|
||||
)
|
||||
]
|
||||
table = ax.table(
|
||||
cellText=table_data,
|
||||
loc="right",
|
||||
cellLoc="left",
|
||||
)
|
||||
table = ax.table(cellText=table_data, loc="right", cellLoc="left")
|
||||
table.auto_set_font_size(False)
|
||||
table.set_fontsize(8)
|
||||
|
||||
# Adjust table style
|
||||
table.scale(1, 0.7)
|
||||
|
||||
for (row, col), cell in table.get_celld().items():
|
||||
cell.set_edgecolor("none")
|
||||
cell.set_facecolor("#f0f0f0" if row % 2 == 0 else "#e0e0e0")
|
||||
@@ -132,17 +183,33 @@ def plot_security_report(table):
|
||||
cell.set_text_props(fontweight="bold")
|
||||
|
||||
# Adjust layout and save
|
||||
|
||||
plt.tight_layout()
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png", dpi=300, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
buf.seek(0)
|
||||
logger.info("Report successfully generated and saved to buffer.")
|
||||
return buf
|
||||
|
||||
|
||||
def generate_identifiers(data):
|
||||
def _generate_identifiers(data: pd.DataFrame) -> list[str]:
|
||||
"""
|
||||
Generates a list of unique identifiers for each row in the given DataFrame.
|
||||
|
||||
The identifiers are based on the English alphabet, with each identifier consisting
|
||||
of a letter followed by a number. The letter represents the "group" of identifiers
|
||||
(using a letter from A to Z) and the number is a counter within that group. For example:
|
||||
- A1, A2, ..., A26, B1, B2, ..., Z1, Z2, ...
|
||||
|
||||
Args:
|
||||
data (pd.DataFrame): The input DataFrame containing data for which identifiers
|
||||
are to be generated.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of unique identifiers as strings, one for each row in the DataFrame.
|
||||
"""
|
||||
data_length = len(data)
|
||||
|
||||
alphabet = string.ascii_uppercase
|
||||
num_letters = len(alphabet)
|
||||
|
||||
|
||||
@@ -0,0 +1,397 @@
|
||||
_SPECS = [
|
||||
"""POST ${SELF_URL}/v1/self-probe
|
||||
Authorization: Bearer XXXXX
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"prompt": "<<PROMPT>>"
|
||||
}
|
||||
|
||||
""",
|
||||
"""POST https://api.openai.com/v1/chat/completions
|
||||
Authorization: Bearer $OPENAI_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
|
||||
"temperature": 0.7
|
||||
}
|
||||
""",
|
||||
"""
|
||||
POST https://api.deepseek.com/chat/completions
|
||||
Authorization: Bearer $DEEPSEEK_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "deepseek-chat",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "<<PROMPT>>"}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
""",
|
||||
"""POST https://api.replicate.com/v1/models/mistralai/mixtral-8x7b-instruct-v0.1/predictions
|
||||
Authorization: Bearer $APIKEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"input": {
|
||||
"top_k": 50,
|
||||
"top_p": 0.9,
|
||||
"prompt": "Write a bedtime story about neural networks I can read to my toddler",
|
||||
"temperature": 0.6,
|
||||
"max_new_tokens": 1024,
|
||||
"prompt_template": "<s>[INST] <<PROMPT>> [/INST] ",
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
}
|
||||
""",
|
||||
"""POST https://api.groq.com/v1/request_manager/text_completion
|
||||
Authorization: Bearer $APIKEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model_id": "codellama-34b",
|
||||
"system_prompt": "You are helpful and concise coding assistant",
|
||||
"user_prompt": "<<PROMPT>>"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.together.xyz/v1/chat/completions
|
||||
Authorization: Bearer $TOGETHER_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are an expert travel guide"},
|
||||
{"role": "user", "content": "<<PROMPT>>"}
|
||||
]
|
||||
}
|
||||
""",
|
||||
"""POST ${SELF_URL}/v1/self-probe-image
|
||||
Authorization: Bearer XXXXX
|
||||
Content-Type: application/json
|
||||
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is in this image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{<<BASE64_IMAGE>>}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
""",
|
||||
"""POST ${SELF_URL}/v1/self-probe-file
|
||||
Authorization: Bearer $API_KEY
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
{
|
||||
"file": "@./sample_audio.m4a",
|
||||
"model": "whisper-large-v3"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.gemini.com/v1/generate
|
||||
Authorization: Bearer $GEMINI_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gemini-latest",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 150,
|
||||
"top_p": 1.0,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
"""POST https://api.anthropic.com/v1/complete
|
||||
Authorization: Bearer $ANTHROPIC_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "claude-v1.3",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.7,
|
||||
"max_tokens_to_sample": 256,
|
||||
"stop_sequences": ["\n\nHuman:"]
|
||||
}
|
||||
""",
|
||||
"""POST https://api.cohere.ai/generate
|
||||
Authorization: Bearer $COHERE_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "command-xlarge-nightly",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.75,
|
||||
"k": 0,
|
||||
"p": 0.75
|
||||
}
|
||||
""",
|
||||
"""POST https://<<RESOURCE_NAME>>.openai.azure.com/openai/deployments/<<DEPLOYMENT_NAME>>/completions?api-version=2023-06-01-preview
|
||||
Authorization: Bearer $AZURE_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"prompt": "<<PROMPT>>",
|
||||
"max_tokens": 150,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
"""POST https://api.assemblyai.com/v2/transcript
|
||||
Authorization: Bearer $ASSEMBLY_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"audio_url": "<<AUDIO_FILE_URL>>"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.openrouter.ai/v1/chat/completions
|
||||
Authorization: Bearer $OPENROUTER_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "openrouter-latest",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 150,
|
||||
"top_p": 0.9,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
LLM_SPECS = [
|
||||
"""POST ${SELF_URL}/v1/self-probe
|
||||
Authorization: Bearer XXXXX
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"prompt": "<<PROMPT>>"
|
||||
}
|
||||
|
||||
""",
|
||||
"""POST https://api.openai.com/v1/chat/completions
|
||||
Authorization: Bearer $OPENAI_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
|
||||
"temperature": 0.7
|
||||
}
|
||||
""",
|
||||
"""
|
||||
POST https://api.deepseek.com/chat/completions
|
||||
Authorization: Bearer $DEEPSEEK_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "deepseek-chat",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "<<PROMPT>>"}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
""",
|
||||
"""POST https://api.replicate.com/v1/models/mistralai/mixtral-8x7b-instruct-v0.1/predictions
|
||||
Authorization: Bearer $APIKEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"input": {
|
||||
"top_k": 50,
|
||||
"top_p": 0.9,
|
||||
"prompt": "Write a bedtime story about neural networks I can read to my toddler",
|
||||
"temperature": 0.6,
|
||||
"max_new_tokens": 1024,
|
||||
"prompt_template": "<s>[INST] <<PROMPT>> [/INST] ",
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
}
|
||||
""",
|
||||
"""POST https://api.groq.com/v1/request_manager/text_completion
|
||||
Authorization: Bearer $APIKEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model_id": "codellama-34b",
|
||||
"system_prompt": "You are helpful and concise coding assistant",
|
||||
"user_prompt": "<<PROMPT>>"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.together.xyz/v1/chat/completions
|
||||
Authorization: Bearer $TOGETHER_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are an expert travel guide"},
|
||||
{"role": "user", "content": "<<PROMPT>>"}
|
||||
]
|
||||
}
|
||||
""",
|
||||
"""POST ${SELF_URL}/v1/self-probe-image
|
||||
Authorization: Bearer XXXXX
|
||||
Content-Type: application/json
|
||||
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is in this image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{<<BASE64_IMAGE>>}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
""",
|
||||
"""POST ${SELF_URL}/v1/self-probe-file
|
||||
Authorization: Bearer $API_KEY
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
{
|
||||
"file": "@./sample_audio.m4a",
|
||||
"model": "whisper-large-v3"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.gemini.com/v1/generate
|
||||
Authorization: Bearer $GEMINI_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gemini-latest",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 150,
|
||||
"top_p": 1.0,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
"""POST https://api.anthropic.com/v1/complete
|
||||
Authorization: Bearer $ANTHROPIC_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "claude-v1.3",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.7,
|
||||
"max_tokens_to_sample": 256,
|
||||
"stop_sequences": ["\n\nHuman:"]
|
||||
}
|
||||
""",
|
||||
"""POST https://api.cohere.ai/generate
|
||||
Authorization: Bearer $COHERE_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "command-xlarge-nightly",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.75,
|
||||
"k": 0,
|
||||
"p": 0.75
|
||||
}
|
||||
""",
|
||||
"""POST https://<<RESOURCE_NAME>>.openai.azure.com/openai/deployments/<<DEPLOYMENT_NAME>>/completions?api-version=2023-06-01-preview
|
||||
Authorization: Bearer $AZURE_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"prompt": "<<PROMPT>>",
|
||||
"max_tokens": 150,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
"""POST https://api.assemblyai.com/v2/transcript
|
||||
Authorization: Bearer $ASSEMBLY_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"audio_url": "<<AUDIO_FILE_URL>>"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.openrouter.ai/v1/chat/completions
|
||||
Authorization: Bearer $OPENROUTER_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "openrouter-latest",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 150,
|
||||
"top_p": 0.9,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
LLM_CONFIGS = [
|
||||
{
|
||||
"name": "Custom API",
|
||||
"prompts": 40000,
|
||||
"customInstructions": "Requires api spec",
|
||||
"logo": "/icons/myshell.png",
|
||||
},
|
||||
{"name": "Open AI", "prompts": 24000, "logo": "/icons/openai.png"},
|
||||
{"name": "Deepseek v1", "prompts": 24000, "logo": "/icons/deepseek.png"},
|
||||
{"name": "Replicate", "prompts": 40000, "logo": "/icons/replicate.png"},
|
||||
{"name": "Groq", "prompts": 40000, "logo": "/icons/groq.png"},
|
||||
{"name": "Together.ai", "prompts": 40000, "logo": "/icons/together.png"},
|
||||
{
|
||||
"name": "Custom API Image",
|
||||
"prompts": 40000,
|
||||
"customInstructions": "Requires api spec",
|
||||
"modality": "Image",
|
||||
"logo": "/icons/myshell.png",
|
||||
},
|
||||
{
|
||||
"name": "Custom API Files",
|
||||
"prompts": 40000,
|
||||
"customInstructions": "Requires api spec",
|
||||
"modality": "Files",
|
||||
"logo": "/icons/myshell.png",
|
||||
},
|
||||
{"name": "Gemini", "prompts": 40000, "logo": "/icons/gemini.png"},
|
||||
{"name": "Claude", "prompts": 40000, "logo": "/icons/claude.png"},
|
||||
{"name": "Cohere", "prompts": 40000, "logo": "/icons/cohere.png"},
|
||||
{"name": "Azure OpenAI", "prompts": 40000, "logo": "/icons/azureai.png"},
|
||||
{"name": "assemblyai", "prompts": 40000, "logo": "/icons/myshell.png"},
|
||||
{"name": "OpenRouter.ai", "prompts": 40000, "logo": "/icons/openrouter.png"},
|
||||
]
|
||||
|
||||
LLM_SPECS = [dict(spec=spec, **d) for spec, d in zip(_SPECS, LLM_CONFIGS)]
|
||||
@@ -6,6 +6,7 @@ from fastapi.responses import JSONResponse
|
||||
from ..primitives import FileProbeResponse, Probe
|
||||
from ..probe_actor.refusal import REFUSAL_MARKS
|
||||
from ..probe_data import REGISTRY
|
||||
from ._specs import LLM_SPECS
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -73,6 +74,12 @@ async def data_config():
|
||||
return [m for m in REGISTRY]
|
||||
|
||||
|
||||
@router.get("/v1/llm-specs", response_model=list)
|
||||
def get_llm_specs():
|
||||
"""Returns the LLM API specifications."""
|
||||
return LLM_SPECS
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
|
||||
@@ -2,7 +2,8 @@ import random
|
||||
from asyncio import Event
|
||||
|
||||
from fastapi import APIRouter
|
||||
from loguru import logger
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
from ..core.app import get_current_run, get_tools_inbox
|
||||
from ..primitives import CompletionRequest, Settings
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
@@ -11,9 +13,11 @@ from fastapi import (
|
||||
)
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
from ..core.app import get_stop_event, get_tools_inbox, set_current_run
|
||||
from ..dependencies import InMemorySecrets, get_in_memory_secrets
|
||||
from ..http_spec import LLMSpec
|
||||
from ..http_spec import InvalidHTTPSpecError, LLMSpec
|
||||
from ..primitives import LLMInfo, Scan
|
||||
from ..probe_actor import fuzzer
|
||||
|
||||
@@ -23,9 +27,16 @@ router = APIRouter()
|
||||
@router.post("/verify")
|
||||
async def verify(
|
||||
info: LLMInfo, secrets: InMemorySecrets = Depends(get_in_memory_secrets)
|
||||
):
|
||||
) -> dict[str, int | str | float]:
|
||||
spec = LLMSpec.from_string(info.spec)
|
||||
r = await spec.verify()
|
||||
try:
|
||||
r = await spec.verify()
|
||||
except InvalidHTTPSpecError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=r.status_code, detail=r.text)
|
||||
return dict(
|
||||
@@ -36,7 +47,7 @@ async def verify(
|
||||
)
|
||||
|
||||
|
||||
def streaming_response_generator(scan_parameters: Scan):
|
||||
def streaming_response_generator(scan_parameters: Scan) -> Generator[str, Any, None]:
|
||||
request_factory = LLMSpec.from_string(scan_parameters.llmSpec)
|
||||
set_current_run(request_factory)
|
||||
|
||||
@@ -57,7 +68,7 @@ async def scan(
|
||||
scan_parameters: Scan,
|
||||
background_tasks: BackgroundTasks,
|
||||
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
|
||||
):
|
||||
) -> StreamingResponse:
|
||||
scan_parameters.with_secrets(secrets)
|
||||
return StreamingResponse(
|
||||
streaming_response_generator(scan_parameters), media_type="application/json"
|
||||
@@ -65,7 +76,7 @@ async def scan(
|
||||
|
||||
|
||||
@router.post("/stop")
|
||||
async def stop_scan():
|
||||
async def stop_scan() -> dict[str, str]:
|
||||
get_stop_event().set()
|
||||
return {"status": "Scan stopped"}
|
||||
|
||||
@@ -79,7 +90,7 @@ async def scan_csv(
|
||||
maxBudget: int = Query(10_000),
|
||||
enableMultiStepAttack: bool = Query(False),
|
||||
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
|
||||
):
|
||||
) -> StreamingResponse:
|
||||
# TODO: content dataset to fuzzer
|
||||
content = await file.read() # noqa
|
||||
llm_spec = await llmSpec.read()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import sentry_sdk
|
||||
from loguru import logger
|
||||
from sentry_sdk.integrations.logging import ignore_logger
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
from ..primitives import Settings
|
||||
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ Content-Type: application/json
|
||||
]
|
||||
`,
|
||||
`POST ${SELF_URL}/v1/self-probe-file
|
||||
Authorization: Bearer $GROQ_API_KEY
|
||||
Authorization: Bearer $API_KEY
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
{
|
||||
@@ -204,8 +204,9 @@ let LLM_CONFIGS = [
|
||||
{ name: 'Claude', prompts: 40000, logo: '/icons/claude.png' },
|
||||
{ name: 'Cohere', prompts: 40000, logo: '/icons/cohere.png' },
|
||||
{ name: 'Azure OpenAI', prompts: 40000, logo: '/icons/azureai.png' },
|
||||
{ name: 'OpenRouter.ai', prompts: 40000, logo: '/icons/openrouter.png' },
|
||||
{ name: 'assemblyai', prompts: 40000, logo: fallbackIcon },
|
||||
{ name: 'OpenRouter.ai', prompts: 40000, logo: '/icons/openrouter.png' },
|
||||
|
||||
];
|
||||
function has_image(spec) {
|
||||
return spec.includes('<<BASE64_IMAGE>>');
|
||||
|
||||
@@ -63,6 +63,32 @@
|
||||
</div>
|
||||
</section>
|
||||
</main>
|
||||
<div class="fixed top-6 right-6 z-50 space-y-3">
|
||||
<transition-group name="toast">
|
||||
<div
|
||||
v-for="toast in toasts"
|
||||
:key="toast.id"
|
||||
class="flex items-center p-3 rounded-xl shadow-xl text-white max-w-md animate-toast-in border border-opacity-30"
|
||||
:class="{
|
||||
'bg-success-toast border-accent-green': toast.type === 'success',
|
||||
'bg-error-toast border-accent-red': toast.type === 'error',
|
||||
'bg-info-toast border-accent-orange': toast.type === 'info'
|
||||
}"
|
||||
>
|
||||
<span class="flex-1 font-medium tracking-wide text-sm">{{ toast.message }}</span>
|
||||
<button
|
||||
@click="removeToast(toast.id)"
|
||||
class="ml-3 focus:outline-none hover:opacity-80 transition-opacity"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</transition-group>
|
||||
</div>
|
||||
|
||||
|
||||
<main class="max-w-6xl mx-auto space-y-8">
|
||||
<section class="bg-dark-card rounded-lg p-6 shadow-lg" v-show="false">
|
||||
<h2 class="text-2xl font-bold mb-4">Select a Config</h2>
|
||||
|
||||
@@ -25,6 +25,8 @@ var app = new Vue({
|
||||
showModules: false,
|
||||
showLogs: false,
|
||||
showConsentModal: true,
|
||||
toasts: [], // Array to store toast notifications
|
||||
toastTimeout: 3000, // Duration in milliseconds (3 seconds)
|
||||
statusDotClass: 'bg-gray-500', // Default status dot class
|
||||
statusText: 'Verified', // Default status text
|
||||
statusClass: 'bg-green-500 text-dark-bg', // Default status class
|
||||
@@ -93,21 +95,36 @@ var app = new Vue({
|
||||
|
||||
},
|
||||
methods: {
|
||||
showToast(message, type = 'success') {
|
||||
const id = Date.now(); // Unique ID for each toast
|
||||
this.toasts.push({ id, message, type });
|
||||
|
||||
// Automatically remove toast after timeout
|
||||
setTimeout(() => {
|
||||
this.removeToast(id);
|
||||
}, this.toastTimeout);
|
||||
},
|
||||
|
||||
removeToast(id) {
|
||||
this.toasts = this.toasts.filter(toast => toast.id !== id);
|
||||
},
|
||||
focusTextarea() {
|
||||
this.isFocused = true;
|
||||
self = this.$refs;
|
||||
// Remove 'self' assignment if not used elsewhere
|
||||
this.$nextTick(() => {
|
||||
// Focus the textarea after rendering
|
||||
self.textarea.focus();
|
||||
this.adjustHeight({ target: self.textarea });
|
||||
this.$refs.textarea.focus();
|
||||
this.adjustHeight({ target: this.$refs.textarea });
|
||||
});
|
||||
document.addEventListener("mousedown", this.handleClickOutside);
|
||||
|
||||
// Correct the event listener to use handleOutsideClick
|
||||
document.addEventListener("mousedown", this.handleOutsideClick);
|
||||
},
|
||||
handleOutsideClick(event) {
|
||||
if (!this.$refs.container.contains(event.target)) {
|
||||
if (!this.$refs.textarea) {
|
||||
return
|
||||
}
|
||||
if (!this.$refs.textarea.contains(event.target)) {
|
||||
this.isFocused = false;
|
||||
document.removeEventListener("mousedown", this.handleClickOutside);
|
||||
document.removeEventListener("mousedown", this.handleOutsideClick);
|
||||
}
|
||||
},
|
||||
unfocusTextarea() {
|
||||
@@ -115,7 +132,12 @@ var app = new Vue({
|
||||
},
|
||||
acceptConsent() {
|
||||
this.showConsentModal = false; // Close the modal
|
||||
localStorage.setItem('consentGiven', 'true'); // Save consent to local storage
|
||||
|
||||
try {
|
||||
localStorage.setItem('consentGiven', 'true'); // Save consent to local storage
|
||||
} catch (e) {
|
||||
this.showToast('Failed to save consent', 'error'); // Show error if saving fails
|
||||
}
|
||||
},
|
||||
|
||||
saveStateToLocalStorage() {
|
||||
@@ -156,6 +178,7 @@ var app = new Vue({
|
||||
this.integrationVerified = false;
|
||||
this.showResetConfirmation = false;
|
||||
this.enableMultiStepAttack = false;
|
||||
this.showToast('All settings have been reset to default', 'info');
|
||||
},
|
||||
confirmResetState() {
|
||||
this.showResetConfirmation = true;
|
||||
@@ -194,31 +217,39 @@ var app = new Vue({
|
||||
spec: this.modelSpec,
|
||||
};
|
||||
let startTime = performance.now(); // Capture start time
|
||||
const response = await fetch(`${SELF_URL}/verify`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
console.log(response);
|
||||
let r = await response.json();
|
||||
let endTime = performance.now(); // Capture end time
|
||||
let latency = endTime - startTime; // Calculate latency in milliseconds
|
||||
latency = latency.toFixed(3) / 1000; // Round to 2 decimal places
|
||||
this.latency = latency;
|
||||
if (!response.ok) {
|
||||
this.updateStatusDot(false);
|
||||
this.errorMsg = 'Integration verification failed:' + JSON.stringify(r);
|
||||
} else {
|
||||
this.errorMsg = '';
|
||||
this.updateStatusDot(true);
|
||||
this.okMsg = 'Integration verified';
|
||||
this.integrationVerified = true;
|
||||
// console.log('Integration verified', this.integrationVerified);
|
||||
// this.$forceUpdate();
|
||||
|
||||
try {
|
||||
const response = await fetch(`${SELF_URL}/verify`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
|
||||
let r = await response.json();
|
||||
|
||||
let endTime = performance.now(); // Capture end time
|
||||
let latency = ((endTime - startTime) / 1000).toFixed(3); // Calculate latency in milliseconds
|
||||
this.latency = latency;
|
||||
|
||||
if (!response.ok) {
|
||||
this.updateStatusDot(false);
|
||||
this.errorMsg = 'Integration verification failed:' + JSON.stringify(r);
|
||||
this.showToast('Integration verification failed', 'error');
|
||||
} else {
|
||||
this.errorMsg = '';
|
||||
this.updateStatusDot(true);
|
||||
this.okMsg = 'Integration verified';
|
||||
this.showToast('Integration verified successfully', 'success');
|
||||
this.integrationVerified = true;
|
||||
}
|
||||
} catch (error) {
|
||||
this.updateStatusDot(true);
|
||||
this.errorMsg = 'Server unreachable';
|
||||
this.showToast('Network error', 'error');
|
||||
}
|
||||
|
||||
this.saveStateToLocalStorage();
|
||||
},
|
||||
loadConfigs: async function () {
|
||||
@@ -240,6 +271,7 @@ var app = new Vue({
|
||||
this.errorMsg = '';
|
||||
this.okMsg = '';
|
||||
this.integrationVerified = false;
|
||||
this.showToast(`Config ${index + 1} selected`, 'info');
|
||||
},
|
||||
toggleModules() {
|
||||
this.showModules = !this.showModules;
|
||||
@@ -327,6 +359,7 @@ var app = new Vue({
|
||||
return
|
||||
}
|
||||
console.log('New row');
|
||||
this.showToast('New module', 'success');
|
||||
let payload = {
|
||||
table: this.mainTable,
|
||||
};
|
||||
@@ -437,6 +470,8 @@ var app = new Vue({
|
||||
}
|
||||
});
|
||||
}
|
||||
this.scanRunning = false;
|
||||
this.showToast('Scan finished successfully', 'success');
|
||||
this.saveStateToLocalStorage();
|
||||
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
technopollas: ['Technopollas', 'sans-serif'],
|
||||
},
|
||||
colors: {
|
||||
t1: {
|
||||
dark: {
|
||||
bg: '#0D0D0D', // Jet Black
|
||||
card: '#1A1A1A', // Dark Carbon Fiber
|
||||
text: '#FFFFFF',
|
||||
@@ -30,7 +30,7 @@
|
||||
yellow: '#2E4053', // Jet Black
|
||||
},
|
||||
},
|
||||
dark: {
|
||||
dark1: {
|
||||
bg: '#121212',
|
||||
card: '#1E1E1E',
|
||||
text: '#FFFFFF',
|
||||
@@ -86,6 +86,7 @@
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style>
|
||||
.scrollbar-hide::-webkit-scrollbar {
|
||||
display: none;
|
||||
@@ -96,4 +97,55 @@
|
||||
}
|
||||
|
||||
</style>
|
||||
<style>
|
||||
/* Toast-specific colors */
|
||||
.bg-success-toast {
|
||||
background: #1C3F74
|
||||
}
|
||||
|
||||
.bg-error-toast {
|
||||
background: #85144B
|
||||
}
|
||||
|
||||
.bg-info-toast {
|
||||
background: #FFC300
|
||||
}
|
||||
|
||||
.border-accent-green, .border-accent-red, .border-accent-orange {
|
||||
border-color: rgba(255, 255, 255, 0.1); /* Subtle white border for depth */
|
||||
}
|
||||
|
||||
/* Animation */
|
||||
.animate-toast-in {
|
||||
animation: tSlideIn 0.4s cubic-bezier(0.25, 0.1, 0.25, 1);
|
||||
}
|
||||
|
||||
@keyframes tSlideIn {
|
||||
from {
|
||||
transform: translateX(120%) scale(0.95);
|
||||
opacity: 0;
|
||||
}
|
||||
to {
|
||||
transform: translateX(0) scale(1);
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.toast-enter-active,
|
||||
.toast-leave-active {
|
||||
transition: all 0.4s cubic-bezier(0.25, 0.1, 0.25, 1);
|
||||
}
|
||||
|
||||
.toast-enter-from,
|
||||
.toast-leave-to {
|
||||
transform: translateX(120%) scale(0.95);
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
/* Hover effect */
|
||||
[toast-type]:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.6);
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
+13
-17
@@ -33,7 +33,7 @@ The `LLMSpec` class is the core of the HTTP specification. It provides the follo
|
||||
### Methods
|
||||
|
||||
- **`from_string(http_spec: str) -> LLMSpec`**: Parses an HTTP specification string into an `LLMSpec` object.
|
||||
- **`validate(prompt: str, encoded_image: str, encoded_audio: str, files: dict) -> None`**: Validates the request parameters based on the specified modality.
|
||||
- **`validate(prompt: str, encoded_image: str, encoded_audio: str, files: dict) -> null`**: Validates the request parameters based on the specified modality.
|
||||
- **`probe(prompt: str, encoded_image: str = "", encoded_audio: str = "", files: dict = {}) -> httpx.Response`**: Sends an HTTP request using the specified parameters.
|
||||
- **`verify() -> httpx.Response`**: Verifies the HTTP specification by sending a test request.
|
||||
|
||||
@@ -52,12 +52,11 @@ Authorization: Bearer sk-xxxxxxxxx
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
|
||||
"temperature": 0.7
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
|
||||
"temperature": 0.7
|
||||
}
|
||||
"""
|
||||
|
||||
spec = LLMSpec.from_string(http_spec)
|
||||
response = await spec.probe("What is the capital of France?")
|
||||
```
|
||||
@@ -71,12 +70,11 @@ Authorization: Bearer sk-xxxxxxxxx
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-4-vision-preview",
|
||||
"messages": [{"role": "user", "content": "What is in this image? <<BASE64_IMAGE>>"}],
|
||||
"temperature": 0.7
|
||||
"model": "gpt-4-vision-preview",
|
||||
"messages": [{"role": "user", "content": "What is in this image? <<BASE64_IMAGE>>"}],
|
||||
"temperature": 0.7
|
||||
}
|
||||
"""
|
||||
|
||||
spec = LLMSpec.from_string(http_spec)
|
||||
encoded_image = encode_image_base64_by_url("https://example.com/image.jpg")
|
||||
response = await spec.probe("What is in this image?", encoded_image=encoded_image)
|
||||
@@ -91,12 +89,11 @@ Authorization: Bearer sk-xxxxxxxxx
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "whisper-large-v3",
|
||||
"messages": [{"role": "user", "content": "Transcribe this audio: <<BASE64_AUDIO>>"}],
|
||||
"temperature": 0.7
|
||||
"model": "whisper-large-v3",
|
||||
"messages": [{"role": "user", "content": "Transcribe this audio: <<BASE64_AUDIO>>"}],
|
||||
"temperature": 0.7
|
||||
}
|
||||
"""
|
||||
|
||||
spec = LLMSpec.from_string(http_spec)
|
||||
encoded_audio = encode_audio_base64_by_url("https://example.com/audio.mp3")
|
||||
response = await spec.probe("Transcribe this audio:", encoded_audio=encoded_audio)
|
||||
@@ -111,12 +108,11 @@ Authorization: Bearer sk-xxxxxxxxx
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Process this file: <<FILE>>"}],
|
||||
"temperature": 0.7
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Process this file: <<FILE>>"}],
|
||||
"temperature": 0.7
|
||||
}
|
||||
"""
|
||||
|
||||
spec = LLMSpec.from_string(http_spec)
|
||||
files = {"file": ("document.txt", open("document.txt", "rb"))}
|
||||
response = await spec.probe("Process this file:", files=files)
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 9.0 MiB |
+1
-1
@@ -21,4 +21,4 @@ Note: Please be aware that Agentic Security is designed as a safety scanner tool
|
||||
|
||||
## UI 🧙
|
||||
|
||||
<img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1736433557/z0bsyzhsqlgcr3w4ovwp.gif">
|
||||
<img width="100%" alt="booking-screen" src="https://res.cloudinary.com/dq0w2rtm9/image/upload/v1741192668/final_aa9jhb.gif">
|
||||
|
||||
+8
-18
@@ -54,20 +54,15 @@ The `probe_data` module is a core component of the Agentic Security project, res
|
||||
|
||||
- **Classes:**
|
||||
- `PromptSelectionInterface`: Abstract base class for prompt selection strategies.
|
||||
|
||||
- Methods:
|
||||
- `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`: Selects next prompt
|
||||
- `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`: Selects multiple prompts
|
||||
- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None`: Updates rewards
|
||||
|
||||
- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> null`: Updates rewards
|
||||
- `RandomPromptSelector`: Basic random selection with history tracking.
|
||||
|
||||
- Parameters:
|
||||
- `prompts: list[str]`: List of available prompts
|
||||
- `history_size: int = 3`: Size of history to prevent cycles
|
||||
|
||||
- `CloudRLPromptSelector`: Cloud-based RL implementation with fallback.
|
||||
|
||||
- Parameters:
|
||||
- `prompts: list[str]`: List of available prompts
|
||||
- `api_url: str`: URL of RL service
|
||||
@@ -75,9 +70,7 @@ The `probe_data` module is a core component of the Agentic Security project, res
|
||||
- `history_size: int = 300`: Size of history
|
||||
- `timeout: int = 5`: Request timeout
|
||||
- `run_id: str = ""`: Unique run identifier
|
||||
|
||||
- `QLearningPromptSelector`: Local Q-learning implementation.
|
||||
|
||||
- Parameters:
|
||||
- `prompts: list[str]`: List of available prompts
|
||||
- `learning_rate: float = 0.1`: Learning rate
|
||||
@@ -86,13 +79,11 @@ The `probe_data` module is a core component of the Agentic Security project, res
|
||||
- `exploration_decay: float = 0.995`: Exploration decay rate
|
||||
- `min_exploration: float = 0.01`: Minimum exploration rate
|
||||
- `history_size: int = 300`: Size of history
|
||||
|
||||
- `Module`: Main class that uses CloudRLPromptSelector.
|
||||
|
||||
- Parameters:
|
||||
- `prompt_groups: list[str]`: Groups of prompts
|
||||
- `tools_inbox: asyncio.Queue`: Queue for tool communication
|
||||
- `opts: dict = {}`: Configuration options
|
||||
- **Module**: Main class that uses CloudRLPromptSelector.
|
||||
- Parameters:
|
||||
- `prompt_groups: list[str]`: Groups of prompts
|
||||
- `tools_inbox: asyncio.Queue`: Queue for tool communication
|
||||
- `opts: dict = {}`: Configuration options
|
||||
|
||||
## Usage Examples
|
||||
|
||||
@@ -119,10 +110,9 @@ from agentic_security.probe_data.modules.rl_model import QLearningPromptSelector
|
||||
|
||||
prompts = ["What is AI?", "Explain machine learning"]
|
||||
selector = QLearningPromptSelector(prompts)
|
||||
|
||||
current_prompt = "What is AI?"
|
||||
next_prompt = selector.select_next_prompt(current_prompt, passed_guard=True)
|
||||
selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=True)
|
||||
next_prompt = selector.select_next_prompt(current_prompt, passed_guard=true)
|
||||
selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=true)
|
||||
```
|
||||
|
||||
## Conclusion
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
:root {
|
||||
--md-primary-fg-color: #e92063;
|
||||
--md-primary-fg-color--light: #e92063;
|
||||
--md-primary-fg-color--dark: #e92063;
|
||||
--md-primary-fg-color: #2E4053;
|
||||
/* Primary color changed to pinkish */
|
||||
--md-primary-fg-color--light: #E0A3B6;
|
||||
--md-primary-fg-color--dark: #1C3F74;
|
||||
/* Dark variant changed to blue */
|
||||
}
|
||||
|
||||
|
||||
/* Revert hue value to that of pre mkdocs-material v9.4.0 */
|
||||
/* Updated slate color scheme with new background */
|
||||
[data-md-color-scheme="slate"] {
|
||||
--md-hue: 230;
|
||||
--md-default-bg-color: hsla(230, 15%, 21%, 1);
|
||||
--md-default-bg-color: #1A1A1A;
|
||||
/* Background changed to dark gray */
|
||||
}
|
||||
|
||||
.hide {
|
||||
@@ -24,12 +26,15 @@ img.index-header {
|
||||
max-width: 500px;
|
||||
}
|
||||
|
||||
/* Updated custom colors */
|
||||
.pydantic-pink {
|
||||
color: #FF007F;
|
||||
color: #E0A3B6;
|
||||
/* Updated to match new theme */
|
||||
}
|
||||
|
||||
.team-blue {
|
||||
color: #0072CE;
|
||||
color: #1C3F74;
|
||||
/* Updated to match new theme */
|
||||
}
|
||||
|
||||
.secure-green {
|
||||
@@ -67,7 +72,6 @@ img.index-header {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
|
||||
/* Hide the entire footer */
|
||||
.md-footer {
|
||||
display: none;
|
||||
|
||||
+1
-1
@@ -89,7 +89,7 @@ theme:
|
||||
name: Switch to light mode
|
||||
icon:
|
||||
repo: fontawesome/brands/github
|
||||
favicon: "https://res.cloudinary.com/dq0w2rtm9/image/upload/v1737555066/r17hrkre246doczwmvbv.png"
|
||||
favicon: https://res.cloudinary.com/dq0w2rtm9/image/upload/v1741195421/favicon_kuz6xr.png
|
||||
|
||||
extra:
|
||||
generator: false
|
||||
|
||||
Generated
+2897
-2252
File diff suppressed because it is too large
Load Diff
+33
-29
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "agentic_security"
|
||||
version = "0.5.1"
|
||||
version = "0.7.4"
|
||||
description = "Agentic LLM vulnerability scanner"
|
||||
authors = ["Alexander Miasoiedov <msoedov@gmail.com>"]
|
||||
maintainers = ["Alexander Miasoiedov <msoedov@gmail.com>"]
|
||||
@@ -28,52 +28,49 @@ agentic_security = "agentic_security.__main__:main"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
fastapi = "^0.115.8"
|
||||
uvicorn = "^0.34.0"
|
||||
fire = "0.7.0"
|
||||
fastapi = "^0.122.0"
|
||||
uvicorn = "^0.38.0"
|
||||
fire = "0.7.1"
|
||||
loguru = "^0.7.3"
|
||||
httpx = "^0.28.1"
|
||||
cache-to-disk = "^2.0.0"
|
||||
pandas = ">=1.4,<3.0"
|
||||
datasets = "^3.3.0"
|
||||
datasets = "^4.4.1"
|
||||
tabulate = ">=0.8.9,<0.10.0"
|
||||
colorama = "^0.4.4"
|
||||
matplotlib = "^3.9.2"
|
||||
pydantic = "2.10.6"
|
||||
matplotlib = "^3.10.7"
|
||||
pydantic = "^2.12.5"
|
||||
scikit-optimize = "^0.10.2"
|
||||
scikit-learn = "1.6.1"
|
||||
scikit-learn = "^1.7.2"
|
||||
numpy = ">=1.24.3,<3.0.0"
|
||||
jinja2 = "^3.1.4"
|
||||
python-multipart = "^0.0.20"
|
||||
tomli = "^2.2.1"
|
||||
rich = "13.9.4"
|
||||
tomli = "^2.3.0"
|
||||
rich = "^14.2.0"
|
||||
gTTS = "^2.5.4"
|
||||
sentry_sdk = "^2.22.0"
|
||||
orjson = "^3.10"
|
||||
pyfiglet = "^1.0.2"
|
||||
termcolor = "^2.4.0"
|
||||
|
||||
sentry_sdk = "^2.46.0"
|
||||
orjson = "^3.11.4"
|
||||
pyfiglet = "^1.0.4"
|
||||
termcolor = "^3.2.0"
|
||||
mcp = "^1.22.0"
|
||||
# garak = { version = "*", optional = true }
|
||||
|
||||
pytest-xdist = "^3.8.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
# Pytest
|
||||
pytest = "^8.3.4"
|
||||
pytest-asyncio = "^0.25.2"
|
||||
inline-snapshot = ">=0.13.3,<0.21.0"
|
||||
pytest-httpx = "^0.35.0"
|
||||
pytest-mock = "^3.14.0"
|
||||
|
||||
pytest = "^9.0.1"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
inline-snapshot = "^0.31.1"
|
||||
pytest-mock = "^3.15.1"
|
||||
# Rest
|
||||
black = ">=24.10,<26.0"
|
||||
mypy = "^1.12.0"
|
||||
pre-commit = "^4.0.1"
|
||||
huggingface-hub = ">=0.25.1,<0.29.0"
|
||||
|
||||
mypy = "^1.19.0"
|
||||
pre-commit = "^4.5.0"
|
||||
huggingface-hub = "^1.1.6"
|
||||
# Docs
|
||||
mkdocs = ">=1.4.2"
|
||||
mkdocs-material = "^9.6.4"
|
||||
mkdocstrings = ">=0.26.1"
|
||||
mkdocs-material = "^9.7.0"
|
||||
mkdocstrings = "^1.0.0"
|
||||
mkdocs-jupyter = ">=0.25.1"
|
||||
|
||||
|
||||
@@ -86,7 +83,14 @@ build-backend = "poetry.core.masonry.api"
|
||||
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--durations=5 -m 'not slow'"
|
||||
addopts = "-m 'not slow'"
|
||||
# addopts = "--durations=5 -m 'not slow' -n 3"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
markers = "slow: marks tests as slow"
|
||||
|
||||
[project]
|
||||
# MCP requires the following fields to be present in the pyproject.toml file
|
||||
name = "agentic_security"
|
||||
version = "1.0.0"
|
||||
requires-python = ">=3.11"
|
||||
|
||||
@@ -1,8 +1,43 @@
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sklearn.exceptions import InconsistentVersionWarning
|
||||
|
||||
from agentic_security.cache_config import ensure_cache_dir
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
CACHE_DIR = ensure_cache_dir(Path(__file__).parent / ".cache_to_disk")
|
||||
|
||||
from cache_to_disk import delete_old_disk_caches # noqa: E402 # isort: skip
|
||||
|
||||
# Silence noisy third-party warnings that do not impact test behavior
|
||||
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
|
||||
try:
|
||||
from langchain_core._api import LangChainDeprecationWarning
|
||||
|
||||
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
||||
except Exception: # pragma: no cover - fallback for older langchain versions
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=DeprecationWarning,
|
||||
module=r"langchain\\.agents",
|
||||
message=r".*langchain_core.pydantic_v1.*",
|
||||
)
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
if "slow" in item.keywords and not os.getenv("RUN_SLOW_TESTS"):
|
||||
pytest.skip("Skipping slow test")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
def setup_delete_old_disk_caches():
|
||||
logger.info("delete_old_disk_caches at %s", CACHE_DIR)
|
||||
try:
|
||||
delete_old_disk_caches()
|
||||
except PermissionError:
|
||||
logger.warning("Skipping cache cleanup due to permissions for %s", CACHE_DIR)
|
||||
except OSError as exc:
|
||||
logger.warning("Skipping cache cleanup due to OS error: %s", exc)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for executor package."""
|
||||
@@ -0,0 +1,209 @@
|
||||
"""Tests for CircuitBreaker."""
|
||||
|
||||
import time
|
||||
from agentic_security.executor.circuit_breaker import CircuitBreaker
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
"""Test CircuitBreaker functionality."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test circuit breaker initialization."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
|
||||
|
||||
assert breaker.failure_threshold == 0.5
|
||||
assert breaker.recovery_timeout == 30
|
||||
assert breaker.state == "closed"
|
||||
assert breaker.failures == 0
|
||||
assert breaker.successes == 0
|
||||
|
||||
def test_record_success(self):
|
||||
"""Test recording successful requests."""
|
||||
breaker = CircuitBreaker()
|
||||
|
||||
breaker.record_success()
|
||||
assert breaker.successes == 1
|
||||
assert breaker.failures == 0
|
||||
assert breaker.state == "closed"
|
||||
|
||||
def test_record_failure(self):
|
||||
"""Test recording failed requests."""
|
||||
breaker = CircuitBreaker()
|
||||
|
||||
breaker.record_failure()
|
||||
assert breaker.failures == 1
|
||||
assert breaker.successes == 0
|
||||
assert breaker.last_failure_time is not None
|
||||
|
||||
def test_circuit_opens_on_failure_threshold(self):
|
||||
"""Test that circuit opens when failure threshold is exceeded."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
|
||||
|
||||
# Record 10 requests: 6 failures, 4 successes (60% failure rate)
|
||||
for _ in range(4):
|
||||
breaker.record_success()
|
||||
for _ in range(6):
|
||||
breaker.record_failure()
|
||||
|
||||
# Circuit should be open (60% > 50% threshold)
|
||||
assert breaker.state == "open"
|
||||
assert breaker.is_open() is True
|
||||
|
||||
def test_circuit_stays_closed_below_threshold(self):
|
||||
"""Test that circuit stays closed when below threshold."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
|
||||
|
||||
# Record 10 requests: 4 failures, 6 successes (40% failure rate)
|
||||
for _ in range(6):
|
||||
breaker.record_success()
|
||||
for _ in range(4):
|
||||
breaker.record_failure()
|
||||
|
||||
# Circuit should stay closed (40% < 50% threshold)
|
||||
assert breaker.state == "closed"
|
||||
assert breaker.is_open() is False
|
||||
|
||||
def test_minimum_sample_size_required(self):
|
||||
"""Test that minimum sample size is required before opening."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5)
|
||||
|
||||
# Only 5 failures (below minimum of 10 total requests)
|
||||
for _ in range(5):
|
||||
breaker.record_failure()
|
||||
|
||||
# Circuit should stay closed (not enough samples)
|
||||
assert breaker.state == "closed"
|
||||
assert breaker.is_open() is False
|
||||
|
||||
def test_circuit_recovery_after_timeout(self):
|
||||
"""Test that circuit enters half-open state after recovery timeout."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=1)
|
||||
|
||||
# Open the circuit
|
||||
for _ in range(4):
|
||||
breaker.record_success()
|
||||
for _ in range(6):
|
||||
breaker.record_failure()
|
||||
|
||||
assert breaker.state == "open"
|
||||
|
||||
# Wait for recovery timeout
|
||||
time.sleep(1.1)
|
||||
|
||||
# Check if circuit moves to half-open
|
||||
is_open = breaker.is_open()
|
||||
assert is_open is False
|
||||
assert breaker.state == "half_open"
|
||||
|
||||
def test_half_open_to_closed_on_successes(self):
|
||||
"""Test that circuit closes from half-open after enough successes."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=1)
|
||||
|
||||
# Open the circuit
|
||||
for _ in range(4):
|
||||
breaker.record_success()
|
||||
for _ in range(6):
|
||||
breaker.record_failure()
|
||||
|
||||
# Wait for recovery
|
||||
time.sleep(1.1)
|
||||
breaker.is_open() # Triggers transition to half-open
|
||||
|
||||
assert breaker.state == "half_open"
|
||||
|
||||
# Record 3 successes
|
||||
breaker.record_success()
|
||||
breaker.record_success()
|
||||
breaker.record_success()
|
||||
|
||||
# Should transition to closed
|
||||
assert breaker.state == "closed"
|
||||
|
||||
def test_get_state(self):
|
||||
"""Test get_state method."""
|
||||
breaker = CircuitBreaker()
|
||||
|
||||
assert breaker.get_state() == "closed"
|
||||
|
||||
# Open the circuit
|
||||
for _ in range(10):
|
||||
breaker.record_failure()
|
||||
|
||||
assert breaker.get_state() == "open"
|
||||
|
||||
def test_get_failure_rate(self):
|
||||
"""Test get_failure_rate method."""
|
||||
breaker = CircuitBreaker()
|
||||
|
||||
# No requests
|
||||
assert breaker.get_failure_rate() == 0.0
|
||||
|
||||
# 3 failures, 7 successes (30% failure rate)
|
||||
for _ in range(7):
|
||||
breaker.record_success()
|
||||
for _ in range(3):
|
||||
breaker.record_failure()
|
||||
|
||||
assert breaker.get_failure_rate() == 0.3
|
||||
|
||||
def test_reset(self):
|
||||
"""Test reset method."""
|
||||
breaker = CircuitBreaker()
|
||||
|
||||
# Record some activity
|
||||
breaker.record_success()
|
||||
breaker.record_failure()
|
||||
for _ in range(10):
|
||||
breaker.record_failure()
|
||||
|
||||
# Reset
|
||||
breaker.reset()
|
||||
|
||||
# Should be back to initial state
|
||||
assert breaker.state == "closed"
|
||||
assert breaker.failures == 0
|
||||
assert breaker.successes == 0
|
||||
assert breaker.last_failure_time is None
|
||||
|
||||
def test_exact_failure_threshold(self):
|
||||
"""Test behavior at exact failure threshold."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5)
|
||||
|
||||
# Exactly 50% failure rate (5 failures, 5 successes)
|
||||
for _ in range(5):
|
||||
breaker.record_success()
|
||||
for _ in range(5):
|
||||
breaker.record_failure()
|
||||
|
||||
# Should be open (>= threshold)
|
||||
assert breaker.state == "open"
|
||||
|
||||
def test_high_failure_threshold(self):
|
||||
"""Test with high failure threshold."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.9)
|
||||
|
||||
# 80% failure rate (8 failures, 2 successes)
|
||||
for _ in range(2):
|
||||
breaker.record_success()
|
||||
for _ in range(8):
|
||||
breaker.record_failure()
|
||||
|
||||
# Should stay closed (80% < 90%)
|
||||
assert breaker.state == "closed"
|
||||
|
||||
def test_zero_recovery_timeout(self):
|
||||
"""Test with zero recovery timeout."""
|
||||
breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=0)
|
||||
|
||||
# Open the circuit
|
||||
for _ in range(10):
|
||||
breaker.record_failure()
|
||||
|
||||
assert breaker.state == "open"
|
||||
|
||||
# Should immediately allow recovery attempt
|
||||
time.sleep(0.01)
|
||||
is_open = breaker.is_open()
|
||||
|
||||
assert is_open is False
|
||||
assert breaker.state == "half_open"
|
||||
@@ -0,0 +1,279 @@
|
||||
"""Tests for ConcurrentExecutor."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch
|
||||
from agentic_security.executor.concurrent import ConcurrentExecutor, ExecutorMetrics
|
||||
from agentic_security.probe_actor.state import FuzzerState
|
||||
|
||||
|
||||
class TestExecutorMetrics:
|
||||
"""Test ExecutorMetrics functionality."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test metrics initialization."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
assert metrics.successful_requests == 0
|
||||
assert metrics.failed_requests == 0
|
||||
assert metrics.total_latency == 0.0
|
||||
assert len(metrics.latencies) == 0
|
||||
|
||||
def test_record_success(self):
|
||||
"""Test recording successful requests."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
metrics.record_success(0.5)
|
||||
metrics.record_success(0.3)
|
||||
|
||||
assert metrics.successful_requests == 2
|
||||
assert metrics.total_latency == 0.8
|
||||
assert len(metrics.latencies) == 2
|
||||
|
||||
def test_record_failure(self):
|
||||
"""Test recording failed requests."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
metrics.record_failure()
|
||||
metrics.record_failure()
|
||||
|
||||
assert metrics.failed_requests == 2
|
||||
assert metrics.successful_requests == 0
|
||||
|
||||
def test_get_stats_no_requests(self):
|
||||
"""Test get_stats with no requests."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
stats = metrics.get_stats()
|
||||
|
||||
assert stats["total_requests"] == 0
|
||||
assert stats["success_rate"] == 0.0
|
||||
assert stats["avg_latency_ms"] == 0.0
|
||||
assert stats["p95_latency_ms"] == 0.0
|
||||
|
||||
def test_get_stats_with_requests(self):
|
||||
"""Test get_stats with recorded requests."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
# Record some requests
|
||||
metrics.record_success(0.1) # 100ms
|
||||
metrics.record_success(0.2) # 200ms
|
||||
metrics.record_success(0.3) # 300ms
|
||||
metrics.record_failure()
|
||||
|
||||
stats = metrics.get_stats()
|
||||
|
||||
assert stats["total_requests"] == 4
|
||||
assert stats["successful_requests"] == 3
|
||||
assert stats["failed_requests"] == 1
|
||||
assert stats["success_rate"] == 0.75
|
||||
assert stats["avg_latency_ms"] == pytest.approx(200.0, rel=0.01)
|
||||
|
||||
def test_get_stats_p95_latency(self):
|
||||
"""Test p95 latency calculation."""
|
||||
metrics = ExecutorMetrics()
|
||||
|
||||
# Add 100 requests with varying latencies
|
||||
for i in range(100):
|
||||
metrics.record_success(i * 0.001) # 0ms to 99ms
|
||||
|
||||
stats = metrics.get_stats()
|
||||
|
||||
# p95 should be around 95ms
|
||||
assert stats["p95_latency_ms"] >= 90.0
|
||||
assert stats["p95_latency_ms"] <= 100.0
|
||||
|
||||
|
||||
class TestConcurrentExecutor:
|
||||
"""Test ConcurrentExecutor functionality."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test executor initialization."""
|
||||
executor = ConcurrentExecutor(
|
||||
max_concurrent=20,
|
||||
rate_limit=10,
|
||||
burst=5,
|
||||
failure_threshold=0.5,
|
||||
recovery_timeout=30,
|
||||
)
|
||||
|
||||
assert executor.semaphore._value == 20
|
||||
assert executor.rate_limiter.rate == 10
|
||||
assert executor.rate_limiter.burst == 5
|
||||
assert executor.circuit_breaker.failure_threshold == 0.5
|
||||
assert executor.circuit_breaker.recovery_timeout == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_batch_success(self):
|
||||
"""Test successful batch execution."""
|
||||
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
|
||||
fuzzer_state = FuzzerState()
|
||||
|
||||
# Mock request factory
|
||||
request_factory = Mock()
|
||||
|
||||
# Mock process_prompt to return success
|
||||
async def mock_process_prompt(rf, prompt, tokens, module, state):
|
||||
return (10, False) # 10 tokens, not refused
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt,
|
||||
):
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
tokens, failures = await executor.execute_batch(
|
||||
request_factory, prompts, "test_module", fuzzer_state
|
||||
)
|
||||
|
||||
assert tokens == 30 # 3 prompts * 10 tokens
|
||||
assert failures == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_batch_with_failures(self):
|
||||
"""Test batch execution with some failures."""
|
||||
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
|
||||
fuzzer_state = FuzzerState()
|
||||
|
||||
request_factory = Mock()
|
||||
|
||||
# Mock process_prompt to alternate success/failure
|
||||
call_count = [0]
|
||||
|
||||
async def mock_process_prompt(rf, prompt, tokens, module, state):
|
||||
call_count[0] += 1
|
||||
if call_count[0] % 2 == 0:
|
||||
return (10, True) # Refused
|
||||
return (10, False) # Success
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt,
|
||||
):
|
||||
prompts = ["p1", "p2", "p3", "p4"]
|
||||
tokens, failures = await executor.execute_batch(
|
||||
request_factory, prompts, "test_module", fuzzer_state
|
||||
)
|
||||
|
||||
assert tokens == 40 # 4 prompts * 10 tokens
|
||||
assert failures == 2 # 2 refused
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_batch_respects_concurrency_limit(self):
|
||||
"""Test that concurrency limit is respected."""
|
||||
executor = ConcurrentExecutor(max_concurrent=2, rate_limit=100, burst=10)
|
||||
fuzzer_state = FuzzerState()
|
||||
|
||||
request_factory = Mock()
|
||||
|
||||
# Track concurrent executions
|
||||
concurrent_count = [0]
|
||||
max_concurrent = [0]
|
||||
|
||||
async def mock_process_prompt(rf, prompt, tokens, module, state):
|
||||
concurrent_count[0] += 1
|
||||
max_concurrent[0] = max(max_concurrent[0], concurrent_count[0])
|
||||
await asyncio.sleep(0.01) # Simulate work
|
||||
concurrent_count[0] -= 1
|
||||
return (10, False)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt,
|
||||
):
|
||||
prompts = ["p1", "p2", "p3", "p4", "p5"]
|
||||
await executor.execute_batch(
|
||||
request_factory, prompts, "test_module", fuzzer_state
|
||||
)
|
||||
|
||||
# Max concurrent should not exceed limit
|
||||
assert max_concurrent[0] <= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_integration(self):
|
||||
"""Test that circuit breaker opens on failures."""
|
||||
executor = ConcurrentExecutor(
|
||||
max_concurrent=10,
|
||||
rate_limit=100,
|
||||
burst=20,
|
||||
failure_threshold=0.5,
|
||||
recovery_timeout=1,
|
||||
)
|
||||
fuzzer_state = FuzzerState()
|
||||
request_factory = Mock()
|
||||
|
||||
# Mock process_prompt to always fail
|
||||
async def mock_process_prompt_fail(rf, prompt, tokens, module, state):
|
||||
raise Exception("Request failed")
|
||||
|
||||
# First batch - all failures
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt_fail,
|
||||
):
|
||||
prompts = ["p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10"]
|
||||
tokens, failures = await executor.execute_batch(
|
||||
request_factory, prompts, "test_module", fuzzer_state
|
||||
)
|
||||
|
||||
# All should have failed
|
||||
assert failures == 10
|
||||
|
||||
# Circuit should be open now
|
||||
assert executor.circuit_breaker.state == "open"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metrics(self):
|
||||
"""Test getting executor metrics."""
|
||||
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=100, burst=10)
|
||||
fuzzer_state = FuzzerState()
|
||||
request_factory = Mock()
|
||||
|
||||
async def mock_process_prompt(rf, prompt, tokens, module, state):
|
||||
return (10, False)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt,
|
||||
):
|
||||
await executor.execute_batch(
|
||||
request_factory, ["p1", "p2"], "test_module", fuzzer_state
|
||||
)
|
||||
|
||||
metrics = executor.get_metrics()
|
||||
|
||||
assert "total_requests" in metrics
|
||||
assert "success_rate" in metrics
|
||||
assert "circuit_breaker_state" in metrics
|
||||
assert "available_tokens" in metrics
|
||||
assert metrics["total_requests"] == 2
|
||||
assert metrics["circuit_breaker_state"] == "closed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limiting_applied(self):
|
||||
"""Test that rate limiting is applied."""
|
||||
executor = ConcurrentExecutor(max_concurrent=10, rate_limit=5, burst=2)
|
||||
fuzzer_state = FuzzerState()
|
||||
request_factory = Mock()
|
||||
|
||||
async def mock_process_prompt(rf, prompt, tokens, module, state):
|
||||
return (10, False)
|
||||
|
||||
import time
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_actor.fuzzer.process_prompt",
|
||||
side_effect=mock_process_prompt,
|
||||
):
|
||||
start = time.monotonic()
|
||||
# 5 requests with rate=5/s and burst=2
|
||||
# First 2 immediate, next 3 should take ~0.6s total
|
||||
await executor.execute_batch(
|
||||
request_factory,
|
||||
["p1", "p2", "p3", "p4", "p5"],
|
||||
"test_module",
|
||||
fuzzer_state,
|
||||
)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Should take at least 0.5s (3 requests / 5 per second)
|
||||
assert elapsed >= 0.4
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Tests for TokenBucketRateLimiter."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import time
|
||||
from agentic_security.executor.rate_limiter import TokenBucketRateLimiter
|
||||
|
||||
|
||||
class TestTokenBucketRateLimiter:
|
||||
"""Test TokenBucketRateLimiter functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization(self):
|
||||
"""Test rate limiter initialization."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=20)
|
||||
|
||||
assert limiter.rate == 10
|
||||
assert limiter.burst == 20
|
||||
assert limiter.tokens == 20 # Starts full
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_with_available_tokens(self):
|
||||
"""Test acquiring tokens when they're available."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=5)
|
||||
|
||||
start = time.monotonic()
|
||||
await limiter.acquire()
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Should return immediately
|
||||
assert elapsed < 0.1
|
||||
assert limiter.tokens < 5 # One token consumed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_waits_when_no_tokens(self):
|
||||
"""Test that acquire waits when no tokens available."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=1)
|
||||
|
||||
# Consume the initial token
|
||||
await limiter.acquire()
|
||||
|
||||
# Next acquire should wait
|
||||
start = time.monotonic()
|
||||
await limiter.acquire()
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Should wait approximately 1/rate seconds (0.1s for rate=10)
|
||||
assert elapsed >= 0.08 # Allow some tolerance
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limiting(self):
|
||||
"""Test that rate limiting actually limits request rate."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=2)
|
||||
|
||||
# Make 5 requests
|
||||
start = time.monotonic()
|
||||
for _ in range(5):
|
||||
await limiter.acquire()
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# With rate=10/s and burst=2:
|
||||
# - First 2 requests are immediate (burst)
|
||||
# - Next 3 requests require waiting: 3 * (1/10) = 0.3s
|
||||
# Total should be around 0.3s
|
||||
assert elapsed >= 0.25 # Allow some tolerance
|
||||
assert elapsed < 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_burst_capacity(self):
|
||||
"""Test that burst capacity allows immediate requests."""
|
||||
limiter = TokenBucketRateLimiter(rate=5, burst=10)
|
||||
|
||||
# Make burst number of requests immediately
|
||||
start = time.monotonic()
|
||||
for _ in range(10):
|
||||
await limiter.acquire()
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# All 10 requests should be nearly immediate (using burst capacity)
|
||||
assert elapsed < 0.2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_replenishment(self):
|
||||
"""Test that tokens are replenished over time."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=5)
|
||||
|
||||
# Consume all tokens
|
||||
for _ in range(5):
|
||||
await limiter.acquire()
|
||||
|
||||
assert limiter.tokens < 1
|
||||
|
||||
# Wait for tokens to replenish
|
||||
await asyncio.sleep(0.3) # Should add 3 tokens at rate=10
|
||||
|
||||
# Should have tokens again (approximately 3)
|
||||
available = limiter.get_available_tokens()
|
||||
assert available >= 2.5
|
||||
assert available <= 3.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_tokens(self):
|
||||
"""Test get_available_tokens method."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=5)
|
||||
|
||||
# Initially full
|
||||
assert limiter.get_available_tokens() == 5
|
||||
|
||||
# After consuming one
|
||||
await limiter.acquire()
|
||||
assert limiter.get_available_tokens() < 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests(self):
|
||||
"""Test rate limiter with concurrent requests."""
|
||||
limiter = TokenBucketRateLimiter(rate=10, burst=3)
|
||||
|
||||
async def make_request(limiter):
|
||||
await limiter.acquire()
|
||||
return time.monotonic()
|
||||
|
||||
# Make 5 concurrent requests
|
||||
start = time.monotonic()
|
||||
tasks = [make_request(limiter) for _ in range(5)]
|
||||
timestamps = await asyncio.gather(*tasks)
|
||||
total_elapsed = time.monotonic() - start
|
||||
|
||||
# First 3 should be immediate (burst=3)
|
||||
# Next 2 should wait
|
||||
# Total time should be around 0.2s (2 * 1/10)
|
||||
assert total_elapsed >= 0.15
|
||||
assert total_elapsed < 0.4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_burst_capacity(self):
|
||||
"""Test that tokens don't exceed burst capacity."""
|
||||
limiter = TokenBucketRateLimiter(rate=100, burst=5)
|
||||
|
||||
# Wait longer than needed to fill
|
||||
await asyncio.sleep(0.2) # Would add 20 tokens, but capped at 5
|
||||
|
||||
# Check tokens don't exceed burst
|
||||
available = limiter.get_available_tokens()
|
||||
assert available <= 5
|
||||
assert available >= 4.5 # Close to full
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from agentic_security.primitives import Scan
|
||||
from agentic_security.probe_actor.fuzzer import (
|
||||
FuzzerState,
|
||||
generate_prompts,
|
||||
perform_many_shot_scan,
|
||||
perform_single_shot_scan,
|
||||
@@ -75,14 +76,23 @@ async def test_perform_single_shot_scan_success(prepare_prompts_mock):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("agentic_security.probe_data.msj_data.prepare_prompts")
|
||||
@patch("agentic_security.probe_data.data.prepare_prompts")
|
||||
async def test_perform_many_shot_scan_probe_injection(prepare_prompts_mock):
|
||||
async def test_perform_many_shot_scan_probe_injection(
|
||||
prepare_prompts_mock, msj_prepare_prompts_mock
|
||||
):
|
||||
# Mock main and probe prompt modules
|
||||
prepare_prompts_mock.side_effect = [
|
||||
[MagicMock(dataset_name="main_module", prompts=["main_prompt1"], lazy=False)],
|
||||
[MagicMock(dataset_name="probe_module", prompts=["probe_prompt1"], lazy=False)],
|
||||
]
|
||||
|
||||
msj_prepare_prompts_mock.return_value = [
|
||||
MagicMock(
|
||||
dataset_name="msj_probe_module", prompts=["msj_probe_prompt"], lazy=False
|
||||
)
|
||||
]
|
||||
|
||||
# Mock request_factory
|
||||
mock_response = AsyncMock()
|
||||
mock_response.fn.side_effect = [
|
||||
@@ -207,9 +217,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=[],
|
||||
errors=[],
|
||||
outputs=[],
|
||||
fuzzer_state=FuzzerState(),
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
|
||||
@@ -226,20 +234,17 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
refusals = []
|
||||
outputs = []
|
||||
fuzzer_state = FuzzerState()
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=refusals,
|
||||
errors=[],
|
||||
outputs=outputs,
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
|
||||
self.assertFalse(refusal)
|
||||
# self.assertFalse(fuzzer_state.refusals)
|
||||
|
||||
async def test_http_error_response(self):
|
||||
mock_request_factory = Mock()
|
||||
@@ -252,15 +257,13 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
refusals = []
|
||||
fuzzer_state = FuzzerState()
|
||||
await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=refusals,
|
||||
errors=[],
|
||||
outputs=[],
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
async def test_request_error(self):
|
||||
@@ -269,18 +272,14 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
side_effect=httpx.RequestError("Connection error")
|
||||
)
|
||||
|
||||
errors = []
|
||||
fuzzer_state = FuzzerState()
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=[],
|
||||
errors=errors,
|
||||
outputs=[],
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 0)
|
||||
self.assertTrue(refusal)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertIn("Connection error", errors[0][3])
|
||||
@@ -0,0 +1,360 @@
|
||||
"""Tests for unified dataset loader."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from agentic_security.probe_data.unified_loader import (
|
||||
InputSourceConfig,
|
||||
UnifiedDatasetLoader,
|
||||
)
|
||||
from agentic_security.probe_data.models import ProbeDataset
|
||||
|
||||
|
||||
class TestInputSourceConfig:
|
||||
"""Test InputSourceConfig validation."""
|
||||
|
||||
def test_csv_source_config(self):
|
||||
"""Test CSV source configuration."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="test_csv",
|
||||
path="./test.csv",
|
||||
prompt_column="prompt",
|
||||
weight=1.5,
|
||||
)
|
||||
assert config.source_type == "csv"
|
||||
assert config.dataset_name == "test_csv"
|
||||
assert config.path == "./test.csv"
|
||||
assert config.weight == 1.5
|
||||
|
||||
def test_huggingface_source_config(self):
|
||||
"""Test HuggingFace source configuration."""
|
||||
config = InputSourceConfig(
|
||||
source_type="huggingface",
|
||||
dataset_name="test/dataset",
|
||||
split="train",
|
||||
max_samples=100,
|
||||
)
|
||||
assert config.source_type == "huggingface"
|
||||
assert config.split == "train"
|
||||
assert config.max_samples == 100
|
||||
|
||||
def test_proxy_source_config(self):
|
||||
"""Test proxy source configuration."""
|
||||
config = InputSourceConfig(
|
||||
source_type="proxy",
|
||||
dataset_name="proxy_test",
|
||||
)
|
||||
assert config.source_type == "proxy"
|
||||
assert config.enabled is True # Default value
|
||||
|
||||
def test_disabled_source(self):
|
||||
"""Test disabled source configuration."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="disabled_test",
|
||||
enabled=False,
|
||||
)
|
||||
assert config.enabled is False
|
||||
|
||||
def test_weight_validation(self):
|
||||
"""Test that weight must be non-negative."""
|
||||
with pytest.raises(ValueError):
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="test",
|
||||
weight=-1.0,
|
||||
)
|
||||
|
||||
|
||||
class TestUnifiedDatasetLoader:
|
||||
"""Test UnifiedDatasetLoader functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_single_csv_source(self):
|
||||
"""Test loading a single CSV source."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="test_csv",
|
||||
path="test.csv",
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
# Mock the load_csv function
|
||||
mock_dataset = ProbeDataset(
|
||||
dataset_name="test_csv",
|
||||
prompts=["prompt1", "prompt2", "prompt3"],
|
||||
tokens=10,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
return_value=mock_dataset,
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
assert result.dataset_name == "unified"
|
||||
assert len(result.prompts) == 3
|
||||
assert result.prompts == ["prompt1", "prompt2", "prompt3"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_single_huggingface_source(self):
|
||||
"""Test loading a single HuggingFace source."""
|
||||
config = InputSourceConfig(
|
||||
source_type="huggingface",
|
||||
dataset_name="test/dataset",
|
||||
split="train",
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
# Mock the load_dataset_generic function
|
||||
mock_dataset = ProbeDataset(
|
||||
dataset_name="test/dataset",
|
||||
prompts=["hf_prompt1", "hf_prompt2"],
|
||||
tokens=8,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_dataset_generic",
|
||||
return_value=mock_dataset,
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
assert result.dataset_name == "unified"
|
||||
assert len(result.prompts) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_multiple_sources(self):
|
||||
"""Test merging multiple sources."""
|
||||
configs = [
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="csv1",
|
||||
path="test1.csv",
|
||||
weight=1.0,
|
||||
),
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="csv2",
|
||||
path="test2.csv",
|
||||
weight=2.0,
|
||||
),
|
||||
]
|
||||
loader = UnifiedDatasetLoader(configs)
|
||||
|
||||
# Mock datasets
|
||||
mock_dataset1 = ProbeDataset(
|
||||
dataset_name="csv1",
|
||||
prompts=["prompt1"],
|
||||
tokens=5,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
mock_dataset2 = ProbeDataset(
|
||||
dataset_name="csv2",
|
||||
prompts=["prompt2", "prompt3"],
|
||||
tokens=10,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
side_effect=[mock_dataset1, mock_dataset2],
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
assert result.dataset_name == "unified"
|
||||
# Weight 1.0 = include once, weight 2.0 = include twice
|
||||
# csv1: 1 prompt * 1 = 1
|
||||
# csv2: 2 prompts * 2 = 4
|
||||
assert len(result.prompts) == 5
|
||||
assert "csv1" in result.metadata["sources"]
|
||||
assert "csv2" in result.metadata["sources"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_disabled_sources(self):
|
||||
"""Test that disabled sources are skipped."""
|
||||
configs = [
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="enabled_csv",
|
||||
path="enabled.csv",
|
||||
enabled=True,
|
||||
),
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="disabled_csv",
|
||||
path="disabled.csv",
|
||||
enabled=False,
|
||||
),
|
||||
]
|
||||
loader = UnifiedDatasetLoader(configs)
|
||||
|
||||
mock_dataset = ProbeDataset(
|
||||
dataset_name="enabled_csv",
|
||||
prompts=["prompt1"],
|
||||
tokens=5,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
return_value=mock_dataset,
|
||||
) as mock_load:
|
||||
result = await loader.load_all()
|
||||
|
||||
# Should only be called once (for enabled source)
|
||||
assert mock_load.call_count == 1
|
||||
assert len(result.prompts) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_samples_limit(self):
|
||||
"""Test that max_samples limits the number of prompts."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="test_csv",
|
||||
path="test.csv",
|
||||
max_samples=2,
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
# Mock dataset with more prompts than max_samples
|
||||
mock_dataset = ProbeDataset(
|
||||
dataset_name="test_csv",
|
||||
prompts=["prompt1", "prompt2", "prompt3", "prompt4", "prompt5"],
|
||||
tokens=20,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
return_value=mock_dataset,
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
# Should be limited to 2 prompts
|
||||
assert len(result.prompts) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling(self):
|
||||
"""Test that errors are handled gracefully."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="error_csv",
|
||||
path="nonexistent.csv",
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
side_effect=Exception("File not found"),
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
# Should return empty dataset on error
|
||||
assert result.dataset_name == "unified_empty"
|
||||
assert len(result.prompts) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_source_placeholder(self):
|
||||
"""Test that proxy source returns empty dataset (not implemented in PoC)."""
|
||||
config = InputSourceConfig(
|
||||
source_type="proxy",
|
||||
dataset_name="proxy_test",
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
result = await loader.load_all()
|
||||
|
||||
# Proxy not implemented in PoC, should return empty
|
||||
assert len(result.prompts) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_weighted_sampling(self):
|
||||
"""Test weighted sampling behavior."""
|
||||
configs = [
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="low_weight",
|
||||
path="low.csv",
|
||||
weight=1.0,
|
||||
),
|
||||
InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="high_weight",
|
||||
path="high.csv",
|
||||
weight=3.0,
|
||||
),
|
||||
]
|
||||
loader = UnifiedDatasetLoader(configs)
|
||||
|
||||
mock_dataset1 = ProbeDataset(
|
||||
dataset_name="low_weight",
|
||||
prompts=["a"],
|
||||
tokens=1,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
mock_dataset2 = ProbeDataset(
|
||||
dataset_name="high_weight",
|
||||
prompts=["b"],
|
||||
tokens=1,
|
||||
approx_cost=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_csv",
|
||||
side_effect=[mock_dataset1, mock_dataset2],
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
# Weight 1.0: 1 prompt * 1 = 1
|
||||
# Weight 3.0: 1 prompt * 3 = 3
|
||||
# Total: 4 prompts
|
||||
assert len(result.prompts) == 4
|
||||
assert result.prompts.count("a") == 1
|
||||
assert result.prompts.count("b") == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_configs_list(self):
|
||||
"""Test loading with empty configs list."""
|
||||
loader = UnifiedDatasetLoader([])
|
||||
result = await loader.load_all()
|
||||
|
||||
assert result.dataset_name == "unified_empty"
|
||||
assert len(result.prompts) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csv_with_url(self):
|
||||
"""Test CSV loading from URL."""
|
||||
config = InputSourceConfig(
|
||||
source_type="csv",
|
||||
dataset_name="remote_csv",
|
||||
url="https://example.com/data.csv",
|
||||
prompt_column="text",
|
||||
)
|
||||
loader = UnifiedDatasetLoader([config])
|
||||
|
||||
mock_dataset = ProbeDataset(
|
||||
dataset_name="remote_csv",
|
||||
prompts=["remote_prompt"],
|
||||
tokens=5,
|
||||
approx_cost=0.0,
|
||||
metadata={"source_type": "csv", "url": "https://example.com/data.csv"},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"agentic_security.probe_data.unified_loader.load_dataset_generic",
|
||||
return_value=mock_dataset,
|
||||
):
|
||||
result = await loader.load_all()
|
||||
|
||||
assert len(result.prompts) == 1
|
||||
assert result.prompts[0] == "remote_prompt"
|
||||
+1
-1
@@ -3,7 +3,7 @@ from unittest.mock import patch
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from .model import RefusalClassifier
|
||||
from agentic_security.refusal_classifier.model import RefusalClassifier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1,9 +1,12 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import agentic_security.test_spec_assets as test_spec_assets
|
||||
from agentic_security.routes.scan import router
|
||||
|
||||
client = TestClient(router)
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_upload_csv_and_run():
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ..app import app
|
||||
from agentic_security.app import app
|
||||
|
||||
|
||||
def test_health_check():
|
||||
@@ -1,14 +1,15 @@
|
||||
import base64
|
||||
import io
|
||||
import random
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ..app import app
|
||||
from ..primitives import Probe
|
||||
from ..probe_actor.refusal import REFUSAL_MARKS
|
||||
from ..probe_data import REGISTRY
|
||||
from agentic_security.app import app
|
||||
from agentic_security.primitives import Probe
|
||||
from agentic_security.probe_actor.refusal import REFUSAL_MARKS
|
||||
from agentic_security.probe_data import REGISTRY
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@@ -85,8 +86,9 @@ def test_data_config_endpoint():
|
||||
|
||||
def test_refusal_rate():
|
||||
"""Test that refusal rate is approximately 20%"""
|
||||
random.seed(0)
|
||||
refusal_count = 0
|
||||
total_trials = 1000
|
||||
total_trials = 200
|
||||
|
||||
for _ in range(total_trials):
|
||||
response = client.post("/v1/self-probe", json={"prompt": "test"})
|
||||
@@ -2,11 +2,14 @@ from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from .report import router
|
||||
from agentic_security.routes.report import router
|
||||
|
||||
client = TestClient(router)
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1,13 +1,15 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ..primitives import Settings
|
||||
from .static import get_static_file, router
|
||||
from agentic_security.primitives import Settings
|
||||
from agentic_security.routes.static import get_static_file, router
|
||||
|
||||
client = TestClient(router)
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_root_route():
|
||||
@@ -0,0 +1,25 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from agentic_security.cache_config import ensure_cache_dir
|
||||
|
||||
|
||||
def test_ensure_cache_dir_creates_dir_and_sets_env(tmp_path, monkeypatch):
|
||||
monkeypatch.delenv("DISK_CACHE_DIR", raising=False)
|
||||
target_dir = tmp_path / "cache_to_disk"
|
||||
|
||||
resolved = ensure_cache_dir(target_dir)
|
||||
|
||||
assert resolved == target_dir
|
||||
assert resolved.is_dir()
|
||||
assert Path(os.environ["DISK_CACHE_DIR"]) == resolved
|
||||
|
||||
|
||||
def test_ensure_cache_dir_respects_existing_env(tmp_path, monkeypatch):
|
||||
env_dir = tmp_path / "preconfigured"
|
||||
monkeypatch.setenv("DISK_CACHE_DIR", str(env_dir))
|
||||
|
||||
resolved = ensure_cache_dir()
|
||||
|
||||
assert resolved == env_dir
|
||||
assert resolved.exists()
|
||||
@@ -1,6 +1,7 @@
|
||||
import importlib
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
@@ -8,7 +9,7 @@ import time
|
||||
import pytest
|
||||
|
||||
import agentic_security.test_spec_assets as test_spec_assets
|
||||
from agentic_security.lib import AgenticSecurity
|
||||
from agentic_security.lib import SecurityScanner
|
||||
|
||||
|
||||
def has_module(module_name):
|
||||
@@ -24,12 +25,29 @@ def test_server(request):
|
||||
preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN),
|
||||
)
|
||||
|
||||
# Give the server time to start
|
||||
time.sleep(2)
|
||||
def wait_for_port(host: str, port: int, timeout: float = 5.0) -> bool:
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.settimeout(0.2)
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
return True
|
||||
except OSError:
|
||||
time.sleep(0.1)
|
||||
return False
|
||||
|
||||
if not wait_for_port("127.0.0.1", 9094):
|
||||
server.kill()
|
||||
pytest.skip("Test server failed to start within timeout")
|
||||
|
||||
def cleanup():
|
||||
server.terminate()
|
||||
server.wait()
|
||||
try:
|
||||
server.wait(timeout=3)
|
||||
except subprocess.TimeoutExpired:
|
||||
server.kill()
|
||||
server.wait(timeout=2)
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
return server
|
||||
@@ -57,7 +75,7 @@ class TestLibraryLevel:
|
||||
maxBudget = 1000000
|
||||
max_th = 0.3
|
||||
datasets = make_test_registry()
|
||||
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
|
||||
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
|
||||
assert isinstance(result, dict)
|
||||
print(result)
|
||||
assert len(result) in [0, 1]
|
||||
@@ -68,7 +86,7 @@ class TestLibraryLevel:
|
||||
maxBudget = 1000
|
||||
max_th = 0.3
|
||||
datasets = make_test_registry()
|
||||
result = AgenticSecurity.scan(
|
||||
result = SecurityScanner.scan(
|
||||
llmSpec, maxBudget, datasets, max_th, enableMultiStepAttack=True
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
@@ -93,7 +111,7 @@ class TestLibraryLevel:
|
||||
"opts": {"port": 9094},
|
||||
},
|
||||
]
|
||||
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
|
||||
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
|
||||
assert isinstance(result, dict)
|
||||
print(result)
|
||||
assert len(result) in [0, 1]
|
||||
@@ -120,11 +138,12 @@ class TestLibraryLevel:
|
||||
"modality": "text",
|
||||
},
|
||||
]
|
||||
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
|
||||
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
|
||||
assert isinstance(result, dict)
|
||||
print(result)
|
||||
assert len(result) in [0, 1]
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_image_modality(self):
|
||||
llmSpec = test_spec_assets.IMAGE_SPEC
|
||||
maxBudget = 2
|
||||
@@ -148,7 +167,7 @@ class TestLibraryLevel:
|
||||
"modality": "text",
|
||||
},
|
||||
]
|
||||
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)
|
||||
result = SecurityScanner.scan(llmSpec, maxBudget, datasets, max_th)
|
||||
assert isinstance(result, dict)
|
||||
print(result)
|
||||
assert len(result) in [0, 1]
|
||||
@@ -164,10 +183,10 @@ class TestEntrypointCI:
|
||||
temp_path = os.path.join(tmpdir, "custom_agesec.toml")
|
||||
|
||||
# Override default_path to the temporary path
|
||||
AgenticSecurity.default_path = temp_path
|
||||
SecurityScanner.default_path = temp_path
|
||||
|
||||
# Generate the default configuration
|
||||
security = AgenticSecurity()
|
||||
security = SecurityScanner()
|
||||
security.generate_default_settings()
|
||||
|
||||
# Check that the config file was created at the temporary path
|
||||
@@ -189,17 +208,17 @@ class TestEntrypointCI:
|
||||
temp_path = os.path.join(tmpdir, "custom_agesec.toml")
|
||||
|
||||
# Override default_path to the temporary path
|
||||
AgenticSecurity.default_path = temp_path
|
||||
SecurityScanner.default_path = temp_path
|
||||
|
||||
# Generate the default configuration
|
||||
security = AgenticSecurity()
|
||||
security = SecurityScanner()
|
||||
security.generate_default_settings()
|
||||
|
||||
# Load the generated configuration
|
||||
AgenticSecurity.load_config(temp_path)
|
||||
SecurityScanner.load_config(temp_path)
|
||||
|
||||
# Validate loaded configuration
|
||||
config = AgenticSecurity.config
|
||||
config = SecurityScanner.config
|
||||
assert (
|
||||
config["general"]["maxBudget"] == 1000000
|
||||
), "maxBudget should be 1000000"
|
||||
@@ -0,0 +1,12 @@
|
||||
import pytest
|
||||
|
||||
from agentic_security.mcp.client import run
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_echo_tool():
|
||||
"""Test the echo tool functionality"""
|
||||
prompts, resources, tools = await run()
|
||||
assert prompts
|
||||
assert resources
|
||||
assert tools
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
|
||||
from agentic_security.probe_data import REGISTRY
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from agentic_security.http_spec import LLMSpec, parse_http_spec
|
||||
from agentic_security.http_spec import (
|
||||
InvalidHTTPSpecError,
|
||||
LLMSpec,
|
||||
parse_http_spec,
|
||||
)
|
||||
|
||||
|
||||
class TestParseHttpSpec:
|
||||
@@ -55,6 +59,19 @@ class TestParseHttpSpec:
|
||||
assert result.headers == {"Content-Type": "application/json"}
|
||||
assert result.body == ""
|
||||
|
||||
def test_parse_http_spec_rejects_malformed_header(self):
|
||||
http_spec = "GET http://example.com\nHeaderWithoutColon\n\n"
|
||||
|
||||
with pytest.raises(InvalidHTTPSpecError, match="Invalid header line"):
|
||||
parse_http_spec(http_spec)
|
||||
|
||||
def test_parse_http_spec_trims_header_whitespace(self):
|
||||
http_spec = "GET http://example.com\nAuthorization:Bearer token\n\n"
|
||||
|
||||
result = parse_http_spec(http_spec)
|
||||
|
||||
assert result.headers == {"Authorization": "Bearer token"}
|
||||
|
||||
|
||||
class TestLLMSpec:
|
||||
def test_validate_raises_error_for_missing_files(self):
|
||||
@@ -70,49 +87,3 @@ class TestLLMSpec:
|
||||
)
|
||||
with pytest.raises(ValueError, match="An image is required for this request."):
|
||||
spec.validate(prompt="", encoded_image="", encoded_audio="", files={})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_probe_sends_request(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
method="POST", url="http://example.com", status_code=200
|
||||
)
|
||||
spec = LLMSpec(
|
||||
method="POST",
|
||||
url="http://example.com",
|
||||
headers={},
|
||||
body='{"prompt": "<<PROMPT>>"}',
|
||||
)
|
||||
response = await spec.probe(prompt="test")
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_probe_with_files(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
method="POST", url="http://example.com", status_code=200
|
||||
)
|
||||
spec = LLMSpec(
|
||||
method="POST",
|
||||
url="http://example.com",
|
||||
headers={"Content-Type": "multipart/form-data"},
|
||||
body='{"prompt": "<<PROMPT>>"}',
|
||||
has_files=True,
|
||||
)
|
||||
files = {"file": ("filename.txt", "file content")}
|
||||
response = await spec.probe(prompt="test", files=files)
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_probe_with_image(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
method="POST", url="http://example.com", status_code=200
|
||||
)
|
||||
spec = LLMSpec(
|
||||
method="POST",
|
||||
url="http://example.com",
|
||||
headers={},
|
||||
body='{"image": "<<BASE64_IMAGE>>"}',
|
||||
has_image=True,
|
||||
)
|
||||
encoded_image = "base64encodedstring"
|
||||
response = await spec.probe(prompt="test", encoded_image=encoded_image)
|
||||
assert response.status_code == 200
|
||||
Generated
+10
-10
@@ -4266,9 +4266,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/compression": {
|
||||
"version": "1.8.0",
|
||||
"resolved": "https://registry.npmjs.org/compression/-/compression-1.8.0.tgz",
|
||||
"integrity": "sha512-k6WLKfunuqCYD3t6AsuPGvQWaKwuLLh2/xHNcX4qE+vIfDNXpSqnrhwA7O53R7WVQUnt8dVAIW+YHr7xTgOgGA==",
|
||||
"version": "1.8.1",
|
||||
"resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz",
|
||||
"integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -4276,7 +4276,7 @@
|
||||
"compressible": "~2.0.18",
|
||||
"debug": "2.6.9",
|
||||
"negotiator": "~0.6.4",
|
||||
"on-headers": "~1.0.2",
|
||||
"on-headers": "~1.1.0",
|
||||
"safe-buffer": "5.2.1",
|
||||
"vary": "~1.1.2"
|
||||
},
|
||||
@@ -6891,9 +6891,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/http-proxy-middleware": {
|
||||
"version": "2.0.7",
|
||||
"resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.7.tgz",
|
||||
"integrity": "sha512-fgVY8AV7qU7z/MmXJ/rxwbrtQH4jBQ9m7kp3llF0liB7glmFeVZFBepQb32T3y8n8k2+AEYuMPCpinYW+/CuRA==",
|
||||
"version": "2.0.9",
|
||||
"resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz",
|
||||
"integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -8419,9 +8419,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/on-headers": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz",
|
||||
"integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==",
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz",
|
||||
"integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
||||
Reference in New Issue
Block a user