mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 14:19:55 +02:00
Compare commits
88 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 926c583a17 | |||
| 17e34356e1 | |||
| 312fa756a5 | |||
| 145e7f81e1 | |||
| 04af7d24a1 | |||
| c5c5ae2e4b | |||
| 2bc0605a1d | |||
| 335787d40e | |||
| 1b211b5d76 | |||
| 444f908009 | |||
| f81dc508f9 | |||
| 4a55b99d70 | |||
| 5c2f9eba71 | |||
| aa2fe4d1ad | |||
| cf7c017621 | |||
| 73184e3454 | |||
| 3720ece2af | |||
| 0dc738a11e | |||
| 47ca656d59 | |||
| 4fa166298d | |||
| 77557ade85 | |||
| 5cdbf933de | |||
| 54d159a737 | |||
| 35fd373cb2 | |||
| f2b95a0040 | |||
| a8e80e85e1 | |||
| f97c3367b4 | |||
| c065818053 | |||
| 1139577eaa | |||
| 5d6a65350f | |||
| c277cca045 | |||
| fcbb832968 | |||
| a0e523758d | |||
| 5ebf428de6 | |||
| d5fe89f298 | |||
| 98b7d7f691 | |||
| c5ddcb2d75 | |||
| da63270142 | |||
| bf5f7a7dff | |||
| d3ccea76b6 | |||
| b7fef85750 | |||
| a1249cae12 | |||
| 8549aee952 | |||
| 414ee62467 | |||
| 7f68224716 | |||
| 3910bab28e | |||
| 8a4dcfd43e | |||
| 17234a846b | |||
| a51a3aa497 | |||
| 0b3424e9fd | |||
| f81b32d9b4 | |||
| a9f8090614 | |||
| 8770726f63 | |||
| ffc4f94a0a | |||
| 5edd4f0959 | |||
| e495f9626f | |||
| b45006c0d1 | |||
| d60d87f142 | |||
| 68f01622fc | |||
| 29787ae5fc | |||
| 1d0e88b001 | |||
| 8e5a53eaa3 | |||
| dcaba04dd6 | |||
| f4271ef2a1 | |||
| feb1becb3e | |||
| 7b44a2f510 | |||
| e3c3119790 | |||
| e171f0216e | |||
| 5d712ebce4 | |||
| 37a6e7a5bc | |||
| 85216ad106 | |||
| bb2e0e7517 | |||
| 8689efbe59 | |||
| 0b41fe0e3f | |||
| c3776df5c1 | |||
| 143ea4f8c1 | |||
| dd2eb1472f | |||
| 4332e4affd | |||
| e871443e76 | |||
| e9ae785625 | |||
| b1e2dc8cef | |||
| b9802fd268 | |||
| ac3f2f803c | |||
| bd6d2f3db1 | |||
| dda8d13b72 | |||
| 839c1af9d7 | |||
| e261fe55c5 | |||
| b4857a5f36 |
@@ -9,7 +9,7 @@ on:
|
||||
- 0.*
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.7.1"
|
||||
POETRY_VERSION: "1.8.5"
|
||||
|
||||
jobs:
|
||||
if_release:
|
||||
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.7.1"
|
||||
POETRY_VERSION: "1.8.5"
|
||||
OPENAI_API_KEY: "sk-fake"
|
||||
|
||||
jobs:
|
||||
|
||||
+2
-1
@@ -17,4 +17,5 @@ inv/
|
||||
scripts/
|
||||
docx/
|
||||
agentic_security.toml
|
||||
/venv
|
||||
/venv
|
||||
*.csv
|
||||
|
||||
+8
-1
@@ -1,5 +1,5 @@
|
||||
# Build stage
|
||||
FROM python:3.11-slim as builder
|
||||
FROM python:3.11-slim AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -14,8 +14,15 @@ RUN poetry self add "poetry-plugin-export"
|
||||
# Copy only dependency files to leverage Docker layer caching
|
||||
COPY pyproject.toml poetry.lock ./
|
||||
|
||||
# update lock file to avoid failure
|
||||
RUN poetry lock
|
||||
|
||||
# Install dependencies
|
||||
RUN poetry export -f requirements.txt --without-hashes -o requirements.txt
|
||||
|
||||
# Install wheel (required to build packages like fire)
|
||||
RUN pip install --upgrade pip setuptools wheel
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Runtime stage
|
||||
|
||||
@@ -21,9 +21,7 @@
|
||||
<a href="https://pypi.org/project/agentic-security/">
|
||||
<img alt="PyPI Version" src="https://img.shields.io/pypi/v/agentic-security?style=for-the-badge&logo=pypi&labelColor=000000&color=00CCFF" />
|
||||
</a>
|
||||
<a href="https://discord.gg/stw3DfZQ">
|
||||
<img alt="Join Discord" src="https://img.shields.io/badge/Discord-Join%20Us-black?style=for-the-badge&logo=discord&labelColor=000000&color=DD55FF" />
|
||||
</a>
|
||||
|
||||
</p>
|
||||
|
||||
|
||||
@@ -402,6 +400,16 @@ This setup ensures a continuous integration approach towards maintaining securit
|
||||
|
||||
The `Module` class is designed to manage prompt processing and interaction with external AI models and tools. It supports fetching, processing, and posting prompts asynchronously for model vulnerabilities. Check out [module.md](https://github.com/msoedov/agentic_security/blob/main/docs/module.md) for details.
|
||||
|
||||
|
||||
## MCP server
|
||||
|
||||
```shell
|
||||
pip install -U mcp
|
||||
|
||||
# From cloned directory
|
||||
mcp install agentic_security/mcp/main.py
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For more detailed information on how to use Agentic Security, including advanced features and customization options, please refer to the official documentation.
|
||||
@@ -428,6 +436,7 @@ We’re just getting started! Here’s what’s on the horizon:
|
||||
|
||||
Note: All dates are tentative and subject to change based on project progress and priorities.
|
||||
|
||||
|
||||
## 👋 Contributing
|
||||
|
||||
Contributions to Agentic Security are welcome! If you'd like to contribute, please follow these steps:
|
||||
|
||||
@@ -4,7 +4,7 @@ import tomli
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
SETTINGS_VERSION = 1
|
||||
SETTINGS_VERSION = 2
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@@ -143,6 +143,13 @@ use_disk_cache = false
|
||||
retry = 3
|
||||
timeout_connect = 30
|
||||
timeout_response = 90
|
||||
|
||||
[fuzzer]
|
||||
max_prompt_lenght = 2048
|
||||
budget_multiplier = 100000000
|
||||
initial_optimizer_points = 25
|
||||
min_failure_samples = 5
|
||||
failure_rate_threshold = 0.5
|
||||
""".replace(
|
||||
"$HOST", host
|
||||
)
|
||||
|
||||
@@ -22,7 +22,11 @@
|
||||
# logger.add(sys.stdout, format=LOG_FORMAT, level="DEBUG", colorize=True)
|
||||
import logging
|
||||
import logging.config
|
||||
import time
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import wraps
|
||||
from os import getenv
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
LOGGER_NAME = None
|
||||
|
||||
@@ -49,6 +53,16 @@ LOGGING_CONFIG = {
|
||||
"handlers": ["rich"],
|
||||
"propagate": True,
|
||||
},
|
||||
"httpx": { # Disable httpx logging
|
||||
"level": "WARNING", # Suppress DEBUG and INFO messages from httpx
|
||||
"handlers": [],
|
||||
"propagate": False,
|
||||
},
|
||||
"uvicorn.access": { # Disable uvicorn.access logging
|
||||
"level": "WARNING", # Suppress DEBUG and INFO messages from uvicorn.access
|
||||
"handlers": [],
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -83,3 +97,50 @@ def set_log_level_to_info():
|
||||
|
||||
# Set initial log level
|
||||
set_log_level_to_info()
|
||||
|
||||
|
||||
# Define generic type variables for return type and parameters
|
||||
R = TypeVar("R")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def time_execution_sync(
|
||||
additional_text: str = "",
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
execution_time = time.time() - start_time
|
||||
logger.debug(
|
||||
f"{additional_text} Execution time: {execution_time:.2f} seconds"
|
||||
)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def time_execution_async(
|
||||
additional_text: str = "",
|
||||
) -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
|
||||
]:
|
||||
def decorator(
|
||||
func: Callable[P, Coroutine[Any, Any, R]]
|
||||
) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.time()
|
||||
result = await func(*args, **kwargs)
|
||||
execution_time = time.time() - start_time
|
||||
logger.debug(
|
||||
f"{additional_text} Execution time: {execution_time:.2f} seconds"
|
||||
)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
import asyncio
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
# Create server parameters for stdio connection
|
||||
server_params = StdioServerParameters(
|
||||
command="python", # Executable
|
||||
args=["agentic_security/mcp/main.py"], # Your server script
|
||||
env=None, # Optional environment variables
|
||||
)
|
||||
|
||||
|
||||
async def run() -> None:
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection --> connection does not work
|
||||
await session.initialize()
|
||||
|
||||
# List available prompts, resources, and tools --> no avalialbe tools
|
||||
prompts = await session.list_prompts()
|
||||
print(f"Available prompts: {prompts}")
|
||||
|
||||
resources = await session.list_resources()
|
||||
print(f"Available resources: {resources}")
|
||||
|
||||
tools = await session.list_tools()
|
||||
print(f"Available tools: {tools}")
|
||||
|
||||
# Call the echo tool --> echo tool iisue
|
||||
echo_result = await session.call_tool(
|
||||
"echo_tool", arguments={"message": "Hello from client!"}
|
||||
)
|
||||
print(f"Tool result: {echo_result}")
|
||||
|
||||
# # Read the echo resource
|
||||
# echo_content, mime_type = await session.read_resource(
|
||||
# "echo://Hello_resource"
|
||||
# )
|
||||
# print(f"Resource content: {echo_content}")
|
||||
# print(f"Resource MIME type: {mime_type}")
|
||||
|
||||
# # Get and use the echo prompt
|
||||
# prompt_result = await session.get_prompt(
|
||||
# "echo_prompt", arguments={"message": "Hello prompt!"}
|
||||
# )
|
||||
# print(f"Prompt result: {prompt_result}")
|
||||
|
||||
# You can perform additional operations here as needed
|
||||
return prompts, resources, tools
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run())
|
||||
@@ -0,0 +1,109 @@
|
||||
import httpx
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Initialize MCP server
|
||||
mcp = FastMCP(
|
||||
name="Agentic Security MCP Server",
|
||||
description="MCP server to interact with LLM scanning test",
|
||||
dependencies=["httpx"],
|
||||
)
|
||||
|
||||
# FastAPI Server Configuration
|
||||
AGENTIC_SECURITY = "http://0.0.0.0:8718"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def verify_llm(spec: str) -> dict:
|
||||
"""
|
||||
Verify an LLM model specification using the FastAPI server
|
||||
|
||||
Returns:
|
||||
dict: containing the verification result form the FastAPI server
|
||||
|
||||
Args: spect(str): The specification of the LLM model to verify.
|
||||
|
||||
"""
|
||||
url = f"{AGENTIC_SECURITY}/verify"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, json={"spec": spec})
|
||||
return response.json()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def start_scan(
|
||||
llmSpec: str,
|
||||
maxBudget: int,
|
||||
optimize: bool = False,
|
||||
enableMultiStepAttack: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Start an LLM security scan via the FastAPI server.
|
||||
Returns:
|
||||
dict: The scan initiation result from the FastAPI server.
|
||||
|
||||
Args:
|
||||
llmSpec (str): The specification of the LLM model.
|
||||
maxBudget (int): The maximum budget for the scan.
|
||||
optimize (bool, optional): Whether to enable optimization during scanning. Defaults to False.
|
||||
enableMultiStepAttack (bool, optional): Whether to enable multi-step attack
|
||||
|
||||
"""
|
||||
url = f"{AGENTIC_SECURITY}/scan"
|
||||
payload = {
|
||||
"llmSpec": llmSpec,
|
||||
"maxBudget": maxBudget,
|
||||
"datasets": [],
|
||||
"optimize": optimize,
|
||||
"enableMultiStepAttack": enableMultiStepAttack,
|
||||
"probe_datasets": [],
|
||||
"secrets": {},
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, json=payload)
|
||||
return response.json()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def stop_scan() -> dict:
|
||||
"""Stop an ongoing scan via the FastAPI server.
|
||||
|
||||
Returns:
|
||||
dict: The confirmation from the FastAPI server that the scan has been stopped.
|
||||
"""
|
||||
url = f"{AGENTIC_SECURITY}/stop"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url)
|
||||
return response.json()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_data_config() -> list:
|
||||
"""
|
||||
Retrieve data configuration from the FastAPI server.
|
||||
|
||||
Returns:
|
||||
list: The response from the FastAPI server, confirming the scan has been stopped.
|
||||
"""
|
||||
url = f"{AGENTIC_SECURITY}/v1/data-config"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
return response.json()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_spec_templates() -> list:
|
||||
"""
|
||||
Retrieve data configuration from the FastAPI server.
|
||||
|
||||
Returns:
|
||||
list: The LLM specification templates from the FastAPI server.
|
||||
"""
|
||||
url = f"{AGENTIC_SECURITY}/v1/llm-specs"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
return response.json()
|
||||
|
||||
|
||||
# Run the MCP server
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
@@ -1,5 +1,6 @@
|
||||
# noqa
|
||||
from agentic_security.primitives.models import CompletionRequest # noqa
|
||||
from agentic_security.primitives.models import ( # noqa
|
||||
CompletionRequest,
|
||||
FileProbeResponse,
|
||||
LLMInfo,
|
||||
Message,
|
||||
|
||||
@@ -3,32 +3,44 @@ import random
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pandas as pd
|
||||
from skopt import Optimizer
|
||||
from skopt.space import Real
|
||||
|
||||
from agentic_security.config import settings_var
|
||||
from agentic_security.http_spec import Modality
|
||||
from agentic_security.logutils import logger
|
||||
from agentic_security.primitives import Scan, ScanResult
|
||||
from agentic_security.probe_actor.cost_module import calculate_cost
|
||||
from agentic_security.probe_actor.refusal import refusal_heuristic
|
||||
from agentic_security.probe_actor.state import FuzzerState
|
||||
from agentic_security.probe_data import audio_generator, image_generator, msj_data
|
||||
from agentic_security.probe_data.data import prepare_prompts
|
||||
|
||||
# TODO: full log file
|
||||
|
||||
MAX_PROMPT_LENGTH = 2048
|
||||
BUDGET_MULTIPLIER = 100_000_000
|
||||
INITIAL_OPTIMIZER_POINTS = 25
|
||||
MIN_FAILURE_SAMPLES = 5
|
||||
FAILURE_RATE_THRESHOLD = 0.5
|
||||
MAX_PROMPT_LENGTH = settings_var("fuzzer.max_prompt_lenght", 2048)
|
||||
BUDGET_MULTIPLIER = settings_var("fuzzer.budget_multiplier", 100000000)
|
||||
INITIAL_OPTIMIZER_POINTS = settings_var("fuzzer.initial_optimizer_points", 25)
|
||||
MIN_FAILURE_SAMPLES = settings_var("min_failure_samples", 5)
|
||||
FAILURE_RATE_THRESHOLD = settings_var("failure_rate_threshold", 0.5)
|
||||
|
||||
|
||||
async def generate_prompts(
|
||||
prompts: list[str] | AsyncGenerator,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Asynchronously generates and yields individual prompts.
|
||||
|
||||
If the input is a list of strings, the function sequentially yields each string.
|
||||
If the input is an asynchronous generator, it forwards each generated prompt.
|
||||
|
||||
Args:
|
||||
prompts (list[str] | AsyncGenerator): A list of strings or an asynchronous generator of prompts.
|
||||
|
||||
Yields:
|
||||
str: An individual prompt from the list or the asynchronous generator.
|
||||
"""
|
||||
if isinstance(prompts, list):
|
||||
for prompt in prompts:
|
||||
yield prompt
|
||||
@@ -37,7 +49,21 @@ async def generate_prompts(
|
||||
yield prompt
|
||||
|
||||
|
||||
def multi_modality_spec(llm_spec):
|
||||
def get_modality_adapter(llm_spec):
|
||||
"""
|
||||
Returns the appropriate request adapter based on the modality of the LLM specification.
|
||||
|
||||
Depending on the modality of `llm_spec`, the function selects the corresponding request adapter.
|
||||
If the modality is IMAGE or AUDIO, it returns an adapter for handling the respective type.
|
||||
If the modality is TEXT or an unrecognized type, it returns `llm_spec` as is.
|
||||
|
||||
Args:
|
||||
llm_spec: An object containing modality information for the LLM.
|
||||
|
||||
Returns:
|
||||
RequestAdapter | llm_spec: An instance of the appropriate request adapter
|
||||
or the original `llm_spec` if no adaptation is needed.
|
||||
"""
|
||||
match llm_spec.modality:
|
||||
case Modality.IMAGE:
|
||||
return image_generator.RequestAdapter(llm_spec)
|
||||
@@ -50,40 +76,71 @@ def multi_modality_spec(llm_spec):
|
||||
|
||||
|
||||
async def process_prompt(
|
||||
request_factory, prompt, tokens, module_name, refusals, errors, outputs
|
||||
request_factory,
|
||||
prompt: str,
|
||||
tokens: int,
|
||||
module_name: str,
|
||||
fuzzer_state: FuzzerState,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Process a single prompt and update the token count and failure status.
|
||||
Processes a single prompt using the provided request factory and updates tracking lists.
|
||||
|
||||
This function sends the given `prompt` to the `request_factory`, checks for errors, and updates
|
||||
the `tokens`, `refusals`, `errors`, and `outputs` lists accordingly. If the request fails or
|
||||
the response indicates a refusal, the function records the issue and returns the updated token count
|
||||
along with a boolean indicating whether the prompt was refused.
|
||||
|
||||
Args:
|
||||
request_factory: An object with a `fn` method used to send the prompt.
|
||||
prompt (str): The input prompt to be processed.
|
||||
tokens (int): The current token count, which will be updated.
|
||||
module_name (str): The name of the module handling the request.
|
||||
fuzzer_state: State tracking object for the fuzzer
|
||||
|
||||
Returns:
|
||||
tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused.
|
||||
"""
|
||||
try:
|
||||
response = await request_factory.fn(prompt=prompt)
|
||||
|
||||
# Handle HTTP errors
|
||||
if response.status_code == 422:
|
||||
logger.error(f"Invalid prompt: {prompt}, error=422")
|
||||
errors.append((module_name, prompt, 422, "Invalid prompt"))
|
||||
fuzzer_state.add_error(module_name, prompt, 422, "Invalid prompt")
|
||||
return tokens, True
|
||||
|
||||
if response.status_code >= 400:
|
||||
logger.error(f"HTTP {response.status_code} {response.content=}")
|
||||
errors.append((module_name, prompt, response.status_code, response.text))
|
||||
fuzzer_state.add_error(
|
||||
module_name, prompt, response.status_code, response.text
|
||||
)
|
||||
return tokens, True
|
||||
|
||||
# Process successful response
|
||||
response_text = response.text
|
||||
tokens += len(response_text.split())
|
||||
|
||||
# Check if the response indicates a refusal
|
||||
refused = refusal_heuristic(response.json())
|
||||
if refused:
|
||||
refusals.append((module_name, prompt, response.status_code, response_text))
|
||||
fuzzer_state.add_refusal(
|
||||
module_name, prompt, response.status_code, response_text
|
||||
)
|
||||
|
||||
outputs.append((module_name, prompt, response_text, refused))
|
||||
fuzzer_state.add_output(module_name, prompt, response_text, refused)
|
||||
return tokens, refused
|
||||
|
||||
except httpx.RequestError as exc:
|
||||
logger.error(f"Request error: {exc}")
|
||||
errors.append((module_name, prompt, "?", str(exc)))
|
||||
fuzzer_state.add_error(module_name, prompt, "?", str(exc))
|
||||
return tokens, True
|
||||
except JSONDecodeError as json_decode_error:
|
||||
logger.error(f"Jason error: {json_decode_error}")
|
||||
errors.append((module_name, prompt, "?", str(json_decode_error)))
|
||||
logger.error(f"JSON error: {json_decode_error}")
|
||||
fuzzer_state.add_error(module_name, prompt, "?", str(json_decode_error))
|
||||
return tokens, True
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error: {e}")
|
||||
return tokens, False
|
||||
|
||||
|
||||
async def process_prompt_batch(
|
||||
@@ -91,14 +148,29 @@ async def process_prompt_batch(
|
||||
prompts: list[str],
|
||||
tokens: int,
|
||||
module_name: str,
|
||||
refusals,
|
||||
errors,
|
||||
outputs,
|
||||
fuzzer_state: FuzzerState,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Processes a batch of prompts asynchronously and aggregates the results.
|
||||
|
||||
This function sends multiple prompts concurrently using `process_prompt`,
|
||||
collects the token count and failure status for each prompt, and returns
|
||||
the total number of tokens processed and the number of failed prompts.
|
||||
|
||||
Args:
|
||||
request_factory: An object with a `fn` method used to send the prompts.
|
||||
prompts (list[str]): A list of input prompts to be processed.
|
||||
tokens (int): The initial token count, which will be updated.
|
||||
module_name (str): The name of the module handling the request.
|
||||
fuzzer_state: State tracking object for the fuzzer
|
||||
|
||||
Returns:
|
||||
tuple[int, int]:
|
||||
- Total number of tokens processed.
|
||||
- Number of failed prompts.
|
||||
"""
|
||||
tasks = [
|
||||
process_prompt(
|
||||
request_factory, p, tokens, module_name, refusals, errors, outputs
|
||||
)
|
||||
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
|
||||
for p in prompts
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
@@ -107,7 +179,154 @@ async def process_prompt_batch(
|
||||
return total_tokens, failures
|
||||
|
||||
|
||||
async def scan_module(
|
||||
request_factory,
|
||||
module,
|
||||
fuzzer_state: FuzzerState,
|
||||
processed_prompts: int = 0,
|
||||
total_prompts: int = 0,
|
||||
max_budget: int = 0,
|
||||
total_tokens: int = 0,
|
||||
optimize: bool = False,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""
|
||||
Scan a single module.
|
||||
|
||||
Args:
|
||||
request_factory: The factory for creating requests
|
||||
module: The prompt module to scan
|
||||
fuzzer_state: State tracking object for the fuzzer
|
||||
processed_prompts: Number of prompts processed so far
|
||||
total_prompts: Total number of prompts to process
|
||||
max_budget: Maximum token budget
|
||||
total_tokens: Current token count
|
||||
optimize: Whether to use optimization
|
||||
stop_event: Event to stop scanning
|
||||
|
||||
Yields:
|
||||
ScanResult objects as the scan progresses
|
||||
"""
|
||||
tokens = 0
|
||||
module_failures = 0
|
||||
module_prompts = 0
|
||||
failure_rates = []
|
||||
should_stop = False
|
||||
|
||||
# Initialize optimizer if optimization is enabled
|
||||
optimizer = (
|
||||
Optimizer(
|
||||
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
|
||||
)
|
||||
if optimize
|
||||
else None
|
||||
)
|
||||
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
yield ScanResult(
|
||||
module=module.dataset_name,
|
||||
tokens=0,
|
||||
cost=0,
|
||||
progress=0,
|
||||
failureRate=0,
|
||||
prompt="",
|
||||
latency=0,
|
||||
model="",
|
||||
).model_dump_json()
|
||||
|
||||
async for prompt in generate_prompts(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
logger.info("Scan stopped by user.")
|
||||
yield ScanResult.status_msg("Scan stopped by user.")
|
||||
return
|
||||
|
||||
processed_prompts += 1
|
||||
module_prompts += 1
|
||||
|
||||
# Calculate progress based on total processed prompts
|
||||
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
|
||||
progress = progress % 100
|
||||
|
||||
total_tokens -= tokens
|
||||
start = time.time()
|
||||
|
||||
tokens, failed = await process_prompt(
|
||||
request_factory,
|
||||
prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
end = time.time()
|
||||
total_tokens += tokens
|
||||
|
||||
if failed:
|
||||
module_failures += 1
|
||||
|
||||
failure_rate = module_failures / max(module_prompts, 1)
|
||||
failure_rates.append(failure_rate)
|
||||
cost = calculate_cost(tokens)
|
||||
|
||||
response_text = fuzzer_state.get_last_output(prompt) or ""
|
||||
|
||||
yield ScanResult(
|
||||
module=module.dataset_name,
|
||||
tokens=round(tokens / 1000, 1),
|
||||
cost=cost,
|
||||
progress=round(progress, 2),
|
||||
failureRate=round(failure_rate * 100, 2),
|
||||
prompt=prompt[:MAX_PROMPT_LENGTH],
|
||||
latency=end - start,
|
||||
model=response_text,
|
||||
).model_dump_json()
|
||||
|
||||
# Optimization logic
|
||||
if optimize and optimizer and len(failure_rates) >= MIN_FAILURE_SAMPLES:
|
||||
next_point = optimizer.ask()
|
||||
optimizer.tell(next_point, -failure_rate)
|
||||
best_failure_rate = -optimizer.get_result().fun
|
||||
if best_failure_rate > FAILURE_RATE_THRESHOLD:
|
||||
yield ScanResult.status_msg(
|
||||
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
|
||||
)
|
||||
should_stop = True
|
||||
break
|
||||
|
||||
# Budget check
|
||||
if total_tokens > max_budget:
|
||||
logger.info(
|
||||
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
|
||||
)
|
||||
yield ScanResult.status_msg(
|
||||
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
|
||||
)
|
||||
should_stop = True
|
||||
break
|
||||
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
return
|
||||
|
||||
|
||||
async def with_error_handling(agen):
|
||||
"""
|
||||
Wraps an asynchronous generator with error handling.
|
||||
|
||||
This function iterates over an asynchronous generator, yielding its values.
|
||||
If an exception occurs, it logs the error and yields a failure message.
|
||||
Finally, it ensures that a completion message is always yielded.
|
||||
|
||||
Args:
|
||||
agen: An asynchronous generator that produces scan results.
|
||||
|
||||
Yields:
|
||||
ScanResult: Either a successful result, an error message if an
|
||||
exception occurs, or a completion message at the end.
|
||||
"""
|
||||
try:
|
||||
async for t in agen:
|
||||
yield t
|
||||
@@ -123,14 +342,37 @@ async def perform_single_shot_scan(
|
||||
max_budget: int,
|
||||
datasets: list[dict[str, str]] = [],
|
||||
tools_inbox=None,
|
||||
optimize=False,
|
||||
stop_event: asyncio.Event = None,
|
||||
optimize: bool = False,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
secrets: dict[str, str] = {},
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Perform a standard security scan."""
|
||||
"""
|
||||
Perform a standard security scan using a given request factory.
|
||||
|
||||
This function processes security scan prompts from selected datasets while
|
||||
respecting a predefined token budget. It supports optimization, failure tracking,
|
||||
and early stopping based on budget constraints or user intervention.
|
||||
|
||||
Args:
|
||||
request_factory: A factory function that generates requests for processing prompts.
|
||||
max_budget (int): The maximum token budget for the scan.
|
||||
datasets (list[dict[str, str]], optional): A list of datasets containing security prompts.
|
||||
tools_inbox: Optional additional tools for processing (default: None).
|
||||
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
|
||||
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
|
||||
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
|
||||
|
||||
Yields:
|
||||
str: JSON-encoded scan results or status messages.
|
||||
|
||||
The function iterates over prompts, processes them asynchronously, and updates
|
||||
failure statistics and token usage. If the scan exceeds the budget or failure rate is too high,
|
||||
it stops execution. Results are saved to a CSV file upon completion.
|
||||
"""
|
||||
max_budget = max_budget * BUDGET_MULTIPLIER
|
||||
selected_datasets = [m for m in datasets if m["selected"]]
|
||||
request_factory = multi_modality_spec(request_factory)
|
||||
request_factory = get_modality_adapter(request_factory)
|
||||
|
||||
yield ScanResult.status_msg("Loading datasets...")
|
||||
prompt_modules = prepare_prompts(
|
||||
dataset_names=[m["dataset_name"] for m in selected_datasets],
|
||||
@@ -140,108 +382,35 @@ async def perform_single_shot_scan(
|
||||
)
|
||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||
|
||||
errors = []
|
||||
refusals = []
|
||||
outputs = []
|
||||
fuzzer_state = FuzzerState()
|
||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||
processed_prompts = 0
|
||||
|
||||
optimizer = (
|
||||
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
|
||||
if optimize
|
||||
else None
|
||||
)
|
||||
failure_rates = []
|
||||
|
||||
total_tokens = 0
|
||||
tokens = 0
|
||||
should_stop = False
|
||||
for module in prompt_modules:
|
||||
if should_stop:
|
||||
break
|
||||
tokens = 0
|
||||
module_failures = 0
|
||||
module_gen = scan_module(
|
||||
request_factory=request_factory,
|
||||
module=module,
|
||||
fuzzer_state=fuzzer_state,
|
||||
processed_prompts=processed_prompts,
|
||||
total_prompts=total_prompts,
|
||||
max_budget=max_budget,
|
||||
total_tokens=total_tokens,
|
||||
optimize=optimize,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
try:
|
||||
async for result in module_gen:
|
||||
yield result
|
||||
except Exception:
|
||||
logger.error("Module exception")
|
||||
continue
|
||||
# Update processed_prompts count
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
module_prompts = 0 # Reset for each module
|
||||
|
||||
async for prompt in generate_prompts(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
logger.info("Scan stopped by user.")
|
||||
yield ScanResult.status_msg("Scan stopped by user.")
|
||||
return
|
||||
|
||||
processed_prompts += 1
|
||||
module_prompts += 1 # Fixed increment syntax
|
||||
# Calculate progress based on total processed prompts
|
||||
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
|
||||
progress = progress % 100
|
||||
|
||||
total_tokens -= tokens
|
||||
start = time.time()
|
||||
tokens, failed = await process_prompt(
|
||||
request_factory,
|
||||
prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
refusals,
|
||||
errors,
|
||||
outputs,
|
||||
)
|
||||
end = time.time()
|
||||
total_tokens += tokens
|
||||
|
||||
if failed:
|
||||
module_failures += 1
|
||||
failure_rate = module_failures / max(module_prompts, 1)
|
||||
failure_rates.append(failure_rate)
|
||||
cost = calculate_cost(tokens)
|
||||
|
||||
last_output = outputs[-1] if outputs else None
|
||||
if last_output and last_output[1] == prompt:
|
||||
response_text = last_output[2]
|
||||
else:
|
||||
response_text = ""
|
||||
|
||||
yield ScanResult(
|
||||
module=module.dataset_name,
|
||||
tokens=round(tokens / 1000, 1),
|
||||
cost=cost,
|
||||
progress=round(progress, 2),
|
||||
failureRate=round(failure_rate * 100, 2),
|
||||
prompt=prompt[:MAX_PROMPT_LENGTH],
|
||||
latency=end - start,
|
||||
model=response_text,
|
||||
).model_dump_json()
|
||||
|
||||
if optimize and len(failure_rates) >= 5:
|
||||
next_point = optimizer.ask()
|
||||
optimizer.tell(next_point, -failure_rate)
|
||||
best_failure_rate = -optimizer.get_result().fun
|
||||
if best_failure_rate > 0.5:
|
||||
yield ScanResult.status_msg(
|
||||
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
|
||||
)
|
||||
should_stop = True
|
||||
break
|
||||
if total_tokens > max_budget:
|
||||
logger.info(
|
||||
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
|
||||
)
|
||||
yield ScanResult.status_msg(
|
||||
f"Scan ran out of budget and stopped. {total_tokens=} {max_budget=}"
|
||||
)
|
||||
should_stop = True
|
||||
break
|
||||
processed_prompts += module_size
|
||||
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
|
||||
failure_data = errors + refusals
|
||||
df = pd.DataFrame(
|
||||
failure_data, columns=["module", "prompt", "status_code", "content"]
|
||||
)
|
||||
df.to_csv("failures.csv", index=False)
|
||||
fuzzer_state.export_failures("failures.csv")
|
||||
|
||||
|
||||
async def perform_many_shot_scan(
|
||||
@@ -250,14 +419,39 @@ async def perform_many_shot_scan(
|
||||
datasets: list[dict[str, str]] = [],
|
||||
probe_datasets: list[dict[str, str]] = [],
|
||||
tools_inbox=None,
|
||||
optimize=False,
|
||||
stop_event: asyncio.Event = None,
|
||||
optimize: bool = False,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
probe_frequency: float = 0.2,
|
||||
max_ctx_length: int = 10_000,
|
||||
secrets: dict[str, str] = {},
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Perform a multi-step security scan with probe injection."""
|
||||
request_factory = multi_modality_spec(request_factory)
|
||||
"""
|
||||
Perform a multi-step security scan with probe injection.
|
||||
|
||||
This function executes a security scan while periodically injecting probe datasets
|
||||
to test system robustness. It tracks failures, optimizes scan efficiency,
|
||||
and ensures adherence to a predefined token budget.
|
||||
|
||||
Args:
|
||||
request_factory: A factory function that generates requests for processing prompts.
|
||||
max_budget (int): The maximum token budget for the scan.
|
||||
datasets (list[dict[str, str]], optional): The main datasets for scanning.
|
||||
probe_datasets (list[dict[str, str]], optional): Additional datasets for probe injection.
|
||||
tools_inbox: Optional tools for additional processing (default: None).
|
||||
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
|
||||
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
|
||||
probe_frequency (float, optional): The probability of probe injection (default: 0.2).
|
||||
max_ctx_length (int, optional): The maximum context length before resetting (default: 10,000 tokens).
|
||||
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
|
||||
|
||||
Yields:
|
||||
str: JSON-encoded scan results or status messages.
|
||||
|
||||
This function iterates over prompts, injects probe prompts at random intervals,
|
||||
processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold
|
||||
or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion.
|
||||
"""
|
||||
request_factory = get_modality_adapter(request_factory)
|
||||
# Load main and probe datasets
|
||||
yield ScanResult.status_msg("Loading datasets...")
|
||||
prompt_modules = prepare_prompts(
|
||||
@@ -269,17 +463,10 @@ async def perform_many_shot_scan(
|
||||
msj_modules = msj_data.prepare_prompts(probe_datasets)
|
||||
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
|
||||
|
||||
errors = []
|
||||
refusals = []
|
||||
outputs = []
|
||||
fuzzer_state = FuzzerState()
|
||||
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
|
||||
processed_prompts = 0
|
||||
|
||||
optimizer = (
|
||||
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=25)
|
||||
if optimize
|
||||
else None
|
||||
)
|
||||
failure_rates = []
|
||||
|
||||
for module in prompt_modules:
|
||||
@@ -293,6 +480,7 @@ async def perform_many_shot_scan(
|
||||
logger.info("Scan stopped by user.")
|
||||
yield ScanResult.status_msg("Scan stopped by user.")
|
||||
return
|
||||
|
||||
tokens = 0
|
||||
processed_prompts += 1
|
||||
progress = 100 * processed_prompts / total_prompts if total_prompts else 0
|
||||
@@ -320,9 +508,7 @@ async def perform_many_shot_scan(
|
||||
full_prompt,
|
||||
tokens,
|
||||
module.dataset_name,
|
||||
refusals,
|
||||
errors,
|
||||
outputs,
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
if failed:
|
||||
module_failures += 1
|
||||
@@ -343,30 +529,48 @@ async def perform_many_shot_scan(
|
||||
prompt=prompt[:MAX_PROMPT_LENGTH],
|
||||
).model_dump_json()
|
||||
|
||||
if optimize and len(failure_rates) >= 5:
|
||||
next_point = optimizer.ask()
|
||||
optimizer.tell(next_point, -failure_rate)
|
||||
best_failure_rate = -optimizer.get_result().fun
|
||||
if best_failure_rate > 0.5:
|
||||
yield ScanResult.status_msg(
|
||||
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
|
||||
)
|
||||
break
|
||||
if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES:
|
||||
yield ScanResult.status_msg(
|
||||
f"High failure rate detected ({failure_rate:.2%}). Stopping this module..."
|
||||
)
|
||||
break
|
||||
|
||||
yield ScanResult.status_msg("Scan completed.")
|
||||
|
||||
df = pd.DataFrame(
|
||||
errors + refusals, columns=["module", "prompt", "status_code", "content"]
|
||||
)
|
||||
df.to_csv("failures.csv", index=False)
|
||||
fuzzer_state.export_failures("failures.csv")
|
||||
|
||||
|
||||
def scan_router(
|
||||
request_factory,
|
||||
scan_parameters: Scan,
|
||||
tools_inbox=None,
|
||||
stop_event: asyncio.Event = None,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
):
|
||||
"""
|
||||
Route scan requests to the appropriate scanning function.
|
||||
|
||||
This function determines whether to perform a multi-step or single-shot
|
||||
security scan based on the provided scan parameters.
|
||||
|
||||
Args:
|
||||
request_factory: A factory function to generate requests for processing prompts.
|
||||
scan_parameters (Scan): An object containing the parameters for the scan, including:
|
||||
- enableMultiStepAttack (bool): Whether to perform a multi-step scan.
|
||||
- maxBudget (int): The maximum token budget for the scan.
|
||||
- datasets (list[dict[str, str]]): The datasets to scan.
|
||||
- probe_datasets (list[dict[str, str]], optional): Datasets for probe injection (multi-step only).
|
||||
- optimize (bool): Whether to enable optimization.
|
||||
- secrets (dict[str, str], optional): A dictionary of secrets for authentication.
|
||||
tools_inbox: Optional tools for additional processing (default: None).
|
||||
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
|
||||
|
||||
Returns:
|
||||
A function wrapped with `with_error_handling`, which executes either:
|
||||
- `perform_many_shot_scan` for multi-step scanning.
|
||||
- `perform_single_shot_scan` for single-shot scanning.
|
||||
|
||||
The function ensures that the appropriate scanning method is chosen based on
|
||||
the `enableMultiStepAttack` flag in `scan_parameters`.
|
||||
"""
|
||||
if scan_parameters.enableMultiStepAttack:
|
||||
return with_error_handling(
|
||||
perform_many_shot_scan(
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class FuzzerState:
|
||||
"""Container for tracking scan results"""
|
||||
|
||||
def __init__(self):
|
||||
self.errors = []
|
||||
self.refusals = []
|
||||
self.outputs = []
|
||||
|
||||
def add_error(
|
||||
self,
|
||||
module_name: str,
|
||||
prompt: str,
|
||||
status_code: int | str,
|
||||
error_msg: str,
|
||||
):
|
||||
"""Add an error to the state"""
|
||||
self.errors.append((module_name, prompt, status_code, error_msg))
|
||||
|
||||
def add_refusal(
|
||||
self, module_name: str, prompt: str, status_code: int, response_text: str
|
||||
):
|
||||
"""Add a refusal to the state"""
|
||||
self.refusals.append((module_name, prompt, status_code, response_text))
|
||||
|
||||
def add_output(
|
||||
self, module_name: str, prompt: str, response_text: str, refused: bool
|
||||
):
|
||||
"""Add an output to the state"""
|
||||
self.outputs.append((module_name, prompt, response_text, refused))
|
||||
|
||||
def get_last_output(self, prompt: str) -> str | None:
|
||||
"""Get the last output for a given prompt"""
|
||||
for output in reversed(self.outputs):
|
||||
if output[1] == prompt:
|
||||
return output[2]
|
||||
return None
|
||||
|
||||
def export_failures(self, filename: str = "failures.csv"):
|
||||
"""Export failures to a CSV file"""
|
||||
failure_data = self.errors + self.refusals
|
||||
df = pd.DataFrame(
|
||||
failure_data, columns=["module", "prompt", "status_code", "content"]
|
||||
)
|
||||
df.to_csv(filename, index=False)
|
||||
@@ -1,4 +1,4 @@
|
||||
from .data import load_local_csv
|
||||
from .data import load_local_csv, load_local_csv_files
|
||||
|
||||
REGISTRY_V0 = [
|
||||
{
|
||||
@@ -484,3 +484,18 @@ REGISTRY = REGISTRY_V0 + [
|
||||
"modality": "text",
|
||||
},
|
||||
]
|
||||
|
||||
for ds in load_local_csv_files():
|
||||
REGISTRY.append(
|
||||
{
|
||||
"dataset_name": ds.dataset_name,
|
||||
"num_prompts": len(ds.prompts),
|
||||
"tokens": ds.prompts,
|
||||
"approx_cost": 0.0,
|
||||
"is_active": True,
|
||||
"source": f"Local file dataset: {ds.metadata['src']}",
|
||||
"selected": False,
|
||||
"url": "",
|
||||
"modality": "text",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import random
|
||||
from collections.abc import Callable, Iterator
|
||||
from functools import partial
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import httpx
|
||||
import pandas as pd
|
||||
@@ -21,15 +22,18 @@ from agentic_security.probe_data.modules import (
|
||||
)
|
||||
|
||||
# Type aliases for clarity
|
||||
T = TypeVar("T")
|
||||
FilterFn = Callable[[pd.Series], bool]
|
||||
ColumnMappings = dict[str, str]
|
||||
DatasetLoader = Callable[[], ProbeDataset]
|
||||
TransformFn = Callable[[str], str]
|
||||
|
||||
|
||||
# Core data loading utilities
|
||||
def fetch_csv_content(url: str) -> str:
|
||||
"""Fetch CSV content from a URL."""
|
||||
response = httpx.get(url)
|
||||
response.raise_for_status() # Raise exception for bad responses
|
||||
return response.content.decode("utf-8")
|
||||
|
||||
|
||||
@@ -57,7 +61,7 @@ def transform_df(
|
||||
|
||||
|
||||
def create_probe_dataset(
|
||||
name: str, prompts: list[str], metadata: dict = None
|
||||
name: str, prompts: list[str], metadata: dict[str, Any] | None = None
|
||||
) -> ProbeDataset:
|
||||
"""Create a ProbeDataset from prompts."""
|
||||
metadata = metadata or {}
|
||||
@@ -77,14 +81,46 @@ def load_dataset_generic(
|
||||
mappings: ColumnMappings | None = None,
|
||||
filter_fn: FilterFn | None = None,
|
||||
url: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> ProbeDataset:
|
||||
"""Load and process a dataset with flexible configuration."""
|
||||
df = load_df_from_source(url or name, is_url=bool(url))
|
||||
transformed_df = transform_df(df, mappings, filter_fn)
|
||||
prompt_col = mappings.get("prompt", "prompt") if mappings else "prompt"
|
||||
prompts = transformed_df[prompt_col].tolist()
|
||||
return create_probe_dataset(name, prompts, metadata)
|
||||
try:
|
||||
df = load_df_from_source(url or name, is_url=bool(url))
|
||||
transformed_df = transform_df(df, mappings, filter_fn)
|
||||
|
||||
# Determine which column to use as the prompt source
|
||||
prompt_col = None
|
||||
if mappings and "prompt" in mappings:
|
||||
prompt_col = mappings["prompt"]
|
||||
elif "prompt" in transformed_df.columns:
|
||||
prompt_col = "prompt"
|
||||
else:
|
||||
# Try to find a suitable text column
|
||||
text_columns = [
|
||||
col
|
||||
for col in transformed_df.columns
|
||||
if any(
|
||||
keyword in col.lower()
|
||||
for keyword in ["prompt", "text", "query", "question"]
|
||||
)
|
||||
]
|
||||
if text_columns:
|
||||
prompt_col = text_columns[0]
|
||||
logger.info(f"Using column '{prompt_col}' as prompt source")
|
||||
else:
|
||||
logger.error(f"No suitable prompt column found in dataset {name}")
|
||||
return create_probe_dataset(name, [], metadata)
|
||||
|
||||
# Extract prompts and filter out empty ones
|
||||
prompts = [
|
||||
p
|
||||
for p in transformed_df[prompt_col].tolist()
|
||||
if p and isinstance(p, (str, int, float))
|
||||
]
|
||||
return create_probe_dataset(name, prompts, metadata)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading dataset {name}: {e}")
|
||||
return create_probe_dataset(name, [], {"error": str(e)})
|
||||
|
||||
|
||||
# Dataset-specific configurations
|
||||
@@ -159,7 +195,7 @@ DATASET_CONFIGS_GENERICS = {
|
||||
|
||||
|
||||
# Dataset factory
|
||||
def create_dataset_loader(name: str, config: dict) -> DatasetLoader:
|
||||
def create_dataset_loader(name: str, config: dict[str, Any]) -> DatasetLoader:
|
||||
"""Create a dataset loader from configuration."""
|
||||
return partial(
|
||||
load_dataset_generic,
|
||||
@@ -167,6 +203,7 @@ def create_dataset_loader(name: str, config: dict) -> DatasetLoader:
|
||||
mappings=config.get("mappings"),
|
||||
filter_fn=config.get("filter_fn"),
|
||||
url=config.get("url"),
|
||||
metadata={"source": name, "config": str(config)},
|
||||
)
|
||||
|
||||
|
||||
@@ -176,39 +213,82 @@ def load_multi_dataset(name: str, sub_datasets: list[str]) -> ProbeDataset:
|
||||
"""Load and combine multiple sub-datasets."""
|
||||
prompts = []
|
||||
for sub in sub_datasets:
|
||||
dataset = load_dataset(name, sub)
|
||||
prompts.extend(dataset["train"]["query"])
|
||||
return create_probe_dataset(f"{name}_combined", prompts)
|
||||
try:
|
||||
dataset = load_dataset(name, sub)
|
||||
if "query" in dataset["train"].features:
|
||||
prompts.extend(dataset["train"]["query"])
|
||||
else:
|
||||
logger.warning(f"No 'query' column in {name}/{sub}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {name}/{sub}: {e}")
|
||||
|
||||
return create_probe_dataset(
|
||||
f"{name}_combined", prompts, {"source": name, "sub_datasets": sub_datasets}
|
||||
)
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_jailbreak_v28k() -> ProbeDataset:
|
||||
"""Load JailBreakV-28K dataset."""
|
||||
df = pd.read_csv("hf://datasets/JailbreakV-28K/JailBreakV-28k/JailBreakV_28K.csv")
|
||||
prompts = df["jailbreak_query"].tolist()
|
||||
return create_probe_dataset("JailbreakV-28K/JailBreakV-28k", prompts)
|
||||
try:
|
||||
df = pd.read_csv(
|
||||
"hf://datasets/JailbreakV-28K/JailBreakV-28k/JailBreakV_28K.csv"
|
||||
)
|
||||
prompts = df["jailbreak_query"].tolist()
|
||||
return create_probe_dataset(
|
||||
"JailbreakV-28K/JailBreakV-28k",
|
||||
prompts,
|
||||
{"source": "JailbreakV-28K/JailBreakV-28k"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading JailbreakV-28K: {e}")
|
||||
return create_probe_dataset("JailbreakV-28K/JailBreakV-28k", [])
|
||||
|
||||
|
||||
@cache_to_disk(1)
|
||||
def file_dataset(file) -> list[str]:
|
||||
prompts = []
|
||||
try:
|
||||
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
|
||||
if "prompt" in df.columns:
|
||||
prompts = df["prompt"].tolist()
|
||||
else:
|
||||
logger.warning(f"File {file} lacks a suitable prompt column")
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading {file}: {e}")
|
||||
return prompts
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
def load_local_csv() -> ProbeDataset:
|
||||
"""Load prompts from local CSV files."""
|
||||
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
|
||||
os.makedirs("./datasets", exist_ok=True)
|
||||
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
|
||||
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
|
||||
|
||||
prompts = []
|
||||
for file in csv_files:
|
||||
try:
|
||||
df = pd.read_csv(file)
|
||||
if "prompt" in df.columns:
|
||||
prompts.extend(df["prompt"].tolist())
|
||||
else:
|
||||
logger.warning(f"File {file} lacks 'prompt' column")
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading {file}: {e}")
|
||||
|
||||
prompts.extend(file_dataset(file))
|
||||
return create_probe_dataset("Local CSV", prompts, {"src": str(csv_files)})
|
||||
|
||||
|
||||
def load_csv(file: str) -> ProbeDataset:
|
||||
"""Load prompts from local CSV files."""
|
||||
prompts = file_dataset(file)
|
||||
return create_probe_dataset(f"fs://{file}", prompts, {"src": str(file)})
|
||||
|
||||
|
||||
def load_local_csv_files() -> list[ProbeDataset]:
|
||||
"""Load prompts from local CSV files and return a list of ProbeDataset objects."""
|
||||
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
|
||||
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
|
||||
|
||||
datasets = []
|
||||
|
||||
for file in csv_files:
|
||||
datasets.append(create_probe_dataset(file, file_dataset(file), {"src": file}))
|
||||
return datasets
|
||||
|
||||
|
||||
# Stenography transformer
|
||||
class StenographyTransformer:
|
||||
"""Apply stenography transformations to datasets."""
|
||||
@@ -231,36 +311,67 @@ class StenographyTransformer:
|
||||
|
||||
def transform(self) -> Iterator[ProbeDataset]:
|
||||
"""Generate transformed datasets."""
|
||||
if not self.datasets:
|
||||
logger.warning("No datasets provided to StenographyTransformer")
|
||||
return
|
||||
|
||||
for dataset in self.datasets:
|
||||
size = len(dataset.prompts)
|
||||
sample_size = int(size * 0.25)
|
||||
if not dataset.prompts:
|
||||
logger.warning(
|
||||
f"Dataset {dataset.dataset_name} has no prompts to transform"
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert async generator (if any) to a list
|
||||
prompts = (
|
||||
list(dataset.prompts) if hasattr(dataset.prompts, "__iter__") else []
|
||||
)
|
||||
size = len(prompts)
|
||||
|
||||
if size == 0:
|
||||
continue
|
||||
|
||||
sample_size = max(1, int(size * 0.25)) # At least 1 sample
|
||||
|
||||
for name, fn in self.TRANSFORMATIONS.items():
|
||||
logger.info(f"Applying {name} to {dataset.dataset_name}")
|
||||
sampled = random.sample(dataset.prompts, sample_size)
|
||||
transformed = [
|
||||
f"Follow the instructions in {name}: {fn(p)}" for p in sampled
|
||||
]
|
||||
yield create_probe_dataset(
|
||||
f"stenography.{name}({dataset.dataset_name})", transformed
|
||||
)
|
||||
sampled = random.sample(prompts, min(sample_size, size))
|
||||
try:
|
||||
transformed = [
|
||||
f"Follow the instructions in {name}: {fn(str(p))}"
|
||||
for p in sampled
|
||||
]
|
||||
transformed_dataset = create_probe_dataset(
|
||||
f"stenography.{name}({dataset.dataset_name})",
|
||||
transformed,
|
||||
{"source": dataset.dataset_name, "transformation": name},
|
||||
)
|
||||
yield transformed_dataset
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error applying {name} to {dataset.dataset_name}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def dataset_from_iterator(
|
||||
name: str, iterator, lazy: bool = False
|
||||
name: str, iterator: Iterator[str], lazy: bool = False
|
||||
) -> list[ProbeDataset]:
|
||||
"""Convert an iterator into a list of ProbeDataset objects."""
|
||||
prompts = list(iterator) if not lazy else iterator
|
||||
tokens = sum(len(str(s).split()) for s in prompts) if not lazy else 0
|
||||
dataset = ProbeDataset(
|
||||
dataset_name=name,
|
||||
metadata={},
|
||||
prompts=prompts,
|
||||
tokens=tokens,
|
||||
approx_cost=0.0,
|
||||
lazy=lazy,
|
||||
)
|
||||
return [dataset]
|
||||
try:
|
||||
prompts = list(iterator) if not lazy else iterator
|
||||
tokens = sum(len(str(s).split()) for s in prompts) if not lazy else 0
|
||||
dataset = ProbeDataset(
|
||||
dataset_name=name,
|
||||
metadata={"source": name, "lazy": lazy},
|
||||
prompts=prompts,
|
||||
tokens=tokens,
|
||||
approx_cost=0.0,
|
||||
lazy=lazy,
|
||||
)
|
||||
return [dataset]
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating dataset from iterator {name}: {e}")
|
||||
return [create_probe_dataset(name, [], {"error": str(e)})]
|
||||
|
||||
|
||||
# Main dataset preparation
|
||||
@@ -272,6 +383,7 @@ def prepare_prompts(
|
||||
) -> list[ProbeDataset]:
|
||||
"""Prepare datasets based on names and options."""
|
||||
# Base dataset loaders
|
||||
logger.info(f"Preparing datasets: {dataset_names}")
|
||||
dataset_loaders = {
|
||||
**{k: create_dataset_loader(k, v) for k, v in DATASET_CONFIGS.items()},
|
||||
**{k: create_dataset_loader(k, v) for k, v in DATASET_CONFIGS_GENERICS.items()},
|
||||
@@ -288,28 +400,39 @@ def prepare_prompts(
|
||||
),
|
||||
"JailbreakV-28K/JailBreakV-28k": load_jailbreak_v28k,
|
||||
"Local CSV": load_local_csv,
|
||||
"Custom CSV": load_local_csv,
|
||||
}
|
||||
|
||||
# Dynamic dataset loaders
|
||||
dynamic_loaders = {
|
||||
"AgenticBackend": lambda opts: dataset_from_iterator(
|
||||
"AgenticBackend",
|
||||
fine_tuned.Module([], tools_inbox=tools_inbox, opts=opts).apply(),
|
||||
fine_tuned.Module(
|
||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
||||
).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"Steganography": lambda opts: list(StenographyTransformer([]).transform()),
|
||||
"Steganography": lambda opts: list(
|
||||
StenographyTransformer(opts["datasets"]).transform()
|
||||
),
|
||||
"llm-adaptive-attacks": lambda opts: dataset_from_iterator(
|
||||
"llm-adaptive-attacks",
|
||||
adaptive_attacks.Module([], tools_inbox=tools_inbox, opts=opts).apply(),
|
||||
adaptive_attacks.Module(
|
||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
||||
).apply(),
|
||||
),
|
||||
"Garak": lambda opts: dataset_from_iterator(
|
||||
"Garak",
|
||||
garak_tool.Module([], tools_inbox=tools_inbox, opts=opts).apply(),
|
||||
garak_tool.Module(
|
||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
||||
).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
|
||||
"Reinforcement Learning Optimization",
|
||||
rl_model.Module([], tools_inbox=tools_inbox, opts=opts).apply(),
|
||||
rl_model.Module(
|
||||
opts["datasets"], tools_inbox=tools_inbox, opts=opts
|
||||
).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"InspectAI": lambda opts: dataset_from_iterator(
|
||||
@@ -320,28 +443,35 @@ def prepare_prompts(
|
||||
"GPT fuzzer": lambda opts: [],
|
||||
}
|
||||
|
||||
options = options or [{} for _ in dataset_names]
|
||||
datasets = []
|
||||
options = options or [dict(datasets=datasets) for _ in dataset_names]
|
||||
|
||||
# Load base datasets
|
||||
for name, opts in zip(dataset_names, options):
|
||||
if name in dataset_loaders:
|
||||
logger.info(f"Loading base dataset {name}")
|
||||
try:
|
||||
datasets.append(dataset_loaders[name]())
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {name}: {e}")
|
||||
if name not in dataset_loaders:
|
||||
continue
|
||||
try:
|
||||
datasets.append(dataset_loaders[name]())
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {name}: {e}")
|
||||
|
||||
# Load dynamic datasets and apply transformations
|
||||
for name, opts in zip(dataset_names, options):
|
||||
if name in dynamic_loaders:
|
||||
logger.info(f"Loading dynamic dataset {name}")
|
||||
try:
|
||||
dynamic_result = dynamic_loaders[name](opts)
|
||||
datasets.extend(dynamic_result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading dynamic {name}: {e}")
|
||||
elif name == "Steganography":
|
||||
datasets.extend(list(StenographyTransformer(datasets).transform()))
|
||||
if name not in dynamic_loaders:
|
||||
continue
|
||||
logger.info(f"Loading dynamic dataset {name} {opts}")
|
||||
opts["datasets"] = datasets
|
||||
try:
|
||||
dynamic_result = dynamic_loaders[name](opts)
|
||||
datasets.extend(dynamic_result)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error loading dynamic {name}: {e}")
|
||||
|
||||
# Load csv datasets and apply transformations
|
||||
for name, opts in zip(dataset_names, options):
|
||||
if not name.endswith(".csv"):
|
||||
continue
|
||||
logger.info(f"Loading csv dataset {name} {opts}")
|
||||
datasets.append(load_csv(name))
|
||||
|
||||
return datasets
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from cache_to_disk import cache_to_disk
|
||||
from cache_to_disk import cache_to_disk # noqa
|
||||
|
||||
|
||||
# TODO: refactor this class to use from .data
|
||||
@@ -22,7 +22,7 @@ class ProbeDataset:
|
||||
}
|
||||
|
||||
|
||||
@cache_to_disk()
|
||||
# @cache_to_disk(n_days_to_cache=1)
|
||||
def load_dataset_generic(name, getter=lambda x: x["train"]["prompt"]):
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def plot_security_report(table: Table) -> io.BytesIO:
|
||||
try:
|
||||
return _plot_security_report(table=table)
|
||||
except (TypeError, ValueError, OverflowError, IndexError, Exception) as e:
|
||||
logger.error(f"Error in generating the security report: {e}")
|
||||
logger.error(f"Error in generating the security report: {e} {table}")
|
||||
return io.BytesIO()
|
||||
|
||||
|
||||
@@ -40,11 +40,7 @@ def generate_identifiers(data: pd.DataFrame) -> list[str]:
|
||||
Returns:
|
||||
list[str]: A list of generated identifiers. Returns a list with an empty string in case of an error.
|
||||
"""
|
||||
try:
|
||||
_generate_identifiers(data=data)
|
||||
except (TypeError, ValueError, Exception) as e:
|
||||
logger.error(f"Error in generate_identifiers: {e}")
|
||||
return [""]
|
||||
return _generate_identifiers(data=data)
|
||||
|
||||
|
||||
def _plot_security_report(table: Table) -> io.BytesIO:
|
||||
|
||||
@@ -0,0 +1,397 @@
|
||||
_SPECS = [
|
||||
"""POST ${SELF_URL}/v1/self-probe
|
||||
Authorization: Bearer XXXXX
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"prompt": "<<PROMPT>>"
|
||||
}
|
||||
|
||||
""",
|
||||
"""POST https://api.openai.com/v1/chat/completions
|
||||
Authorization: Bearer $OPENAI_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
|
||||
"temperature": 0.7
|
||||
}
|
||||
""",
|
||||
"""
|
||||
POST https://api.deepseek.com/chat/completions
|
||||
Authorization: Bearer $DEEPSEEK_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "deepseek-chat",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "<<PROMPT>>"}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
""",
|
||||
"""POST https://api.replicate.com/v1/models/mistralai/mixtral-8x7b-instruct-v0.1/predictions
|
||||
Authorization: Bearer $APIKEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"input": {
|
||||
"top_k": 50,
|
||||
"top_p": 0.9,
|
||||
"prompt": "Write a bedtime story about neural networks I can read to my toddler",
|
||||
"temperature": 0.6,
|
||||
"max_new_tokens": 1024,
|
||||
"prompt_template": "<s>[INST] <<PROMPT>> [/INST] ",
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
}
|
||||
""",
|
||||
"""POST https://api.groq.com/v1/request_manager/text_completion
|
||||
Authorization: Bearer $APIKEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model_id": "codellama-34b",
|
||||
"system_prompt": "You are helpful and concise coding assistant",
|
||||
"user_prompt": "<<PROMPT>>"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.together.xyz/v1/chat/completions
|
||||
Authorization: Bearer $TOGETHER_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are an expert travel guide"},
|
||||
{"role": "user", "content": "<<PROMPT>>"}
|
||||
]
|
||||
}
|
||||
""",
|
||||
"""POST ${SELF_URL}/v1/self-probe-image
|
||||
Authorization: Bearer XXXXX
|
||||
Content-Type: application/json
|
||||
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is in this image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{<<BASE64_IMAGE>>}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
""",
|
||||
"""POST ${SELF_URL}/v1/self-probe-file
|
||||
Authorization: Bearer $API_KEY
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
{
|
||||
"file": "@./sample_audio.m4a",
|
||||
"model": "whisper-large-v3"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.gemini.com/v1/generate
|
||||
Authorization: Bearer $GEMINI_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gemini-latest",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 150,
|
||||
"top_p": 1.0,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
"""POST https://api.anthropic.com/v1/complete
|
||||
Authorization: Bearer $ANTHROPIC_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "claude-v1.3",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.7,
|
||||
"max_tokens_to_sample": 256,
|
||||
"stop_sequences": ["\n\nHuman:"]
|
||||
}
|
||||
""",
|
||||
"""POST https://api.cohere.ai/generate
|
||||
Authorization: Bearer $COHERE_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "command-xlarge-nightly",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.75,
|
||||
"k": 0,
|
||||
"p": 0.75
|
||||
}
|
||||
""",
|
||||
"""POST https://<<RESOURCE_NAME>>.openai.azure.com/openai/deployments/<<DEPLOYMENT_NAME>>/completions?api-version=2023-06-01-preview
|
||||
Authorization: Bearer $AZURE_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"prompt": "<<PROMPT>>",
|
||||
"max_tokens": 150,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
"""POST https://api.assemblyai.com/v2/transcript
|
||||
Authorization: Bearer $ASSEMBLY_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"audio_url": "<<AUDIO_FILE_URL>>"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.openrouter.ai/v1/chat/completions
|
||||
Authorization: Bearer $OPENROUTER_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "openrouter-latest",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 150,
|
||||
"top_p": 0.9,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
LLM_SPECS = [
|
||||
"""POST ${SELF_URL}/v1/self-probe
|
||||
Authorization: Bearer XXXXX
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"prompt": "<<PROMPT>>"
|
||||
}
|
||||
|
||||
""",
|
||||
"""POST https://api.openai.com/v1/chat/completions
|
||||
Authorization: Bearer $OPENAI_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "<<PROMPT>>"}],
|
||||
"temperature": 0.7
|
||||
}
|
||||
""",
|
||||
"""
|
||||
POST https://api.deepseek.com/chat/completions
|
||||
Authorization: Bearer $DEEPSEEK_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "deepseek-chat",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "<<PROMPT>>"}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
""",
|
||||
"""POST https://api.replicate.com/v1/models/mistralai/mixtral-8x7b-instruct-v0.1/predictions
|
||||
Authorization: Bearer $APIKEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"input": {
|
||||
"top_k": 50,
|
||||
"top_p": 0.9,
|
||||
"prompt": "Write a bedtime story about neural networks I can read to my toddler",
|
||||
"temperature": 0.6,
|
||||
"max_new_tokens": 1024,
|
||||
"prompt_template": "<s>[INST] <<PROMPT>> [/INST] ",
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
}
|
||||
""",
|
||||
"""POST https://api.groq.com/v1/request_manager/text_completion
|
||||
Authorization: Bearer $APIKEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model_id": "codellama-34b",
|
||||
"system_prompt": "You are helpful and concise coding assistant",
|
||||
"user_prompt": "<<PROMPT>>"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.together.xyz/v1/chat/completions
|
||||
Authorization: Bearer $TOGETHER_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are an expert travel guide"},
|
||||
{"role": "user", "content": "<<PROMPT>>"}
|
||||
]
|
||||
}
|
||||
""",
|
||||
"""POST ${SELF_URL}/v1/self-probe-image
|
||||
Authorization: Bearer XXXXX
|
||||
Content-Type: application/json
|
||||
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is in this image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{<<BASE64_IMAGE>>}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
""",
|
||||
"""POST ${SELF_URL}/v1/self-probe-file
|
||||
Authorization: Bearer $API_KEY
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
{
|
||||
"file": "@./sample_audio.m4a",
|
||||
"model": "whisper-large-v3"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.gemini.com/v1/generate
|
||||
Authorization: Bearer $GEMINI_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gemini-latest",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 150,
|
||||
"top_p": 1.0,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
"""POST https://api.anthropic.com/v1/complete
|
||||
Authorization: Bearer $ANTHROPIC_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "claude-v1.3",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.7,
|
||||
"max_tokens_to_sample": 256,
|
||||
"stop_sequences": ["\n\nHuman:"]
|
||||
}
|
||||
""",
|
||||
"""POST https://api.cohere.ai/generate
|
||||
Authorization: Bearer $COHERE_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "command-xlarge-nightly",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.75,
|
||||
"k": 0,
|
||||
"p": 0.75
|
||||
}
|
||||
""",
|
||||
"""POST https://<<RESOURCE_NAME>>.openai.azure.com/openai/deployments/<<DEPLOYMENT_NAME>>/completions?api-version=2023-06-01-preview
|
||||
Authorization: Bearer $AZURE_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"prompt": "<<PROMPT>>",
|
||||
"max_tokens": 150,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
"""POST https://api.assemblyai.com/v2/transcript
|
||||
Authorization: Bearer $ASSEMBLY_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"audio_url": "<<AUDIO_FILE_URL>>"
|
||||
}
|
||||
""",
|
||||
"""POST https://api.openrouter.ai/v1/chat/completions
|
||||
Authorization: Bearer $OPENROUTER_API_KEY
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "openrouter-latest",
|
||||
"prompt": "<<PROMPT>>",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 150,
|
||||
"top_p": 0.9,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0
|
||||
}
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
LLM_CONFIGS = [
|
||||
{
|
||||
"name": "Custom API",
|
||||
"prompts": 40000,
|
||||
"customInstructions": "Requires api spec",
|
||||
"logo": "/icons/myshell.png",
|
||||
},
|
||||
{"name": "Open AI", "prompts": 24000, "logo": "/icons/openai.png"},
|
||||
{"name": "Deepseek v1", "prompts": 24000, "logo": "/icons/deepseek.png"},
|
||||
{"name": "Replicate", "prompts": 40000, "logo": "/icons/replicate.png"},
|
||||
{"name": "Groq", "prompts": 40000, "logo": "/icons/groq.png"},
|
||||
{"name": "Together.ai", "prompts": 40000, "logo": "/icons/together.png"},
|
||||
{
|
||||
"name": "Custom API Image",
|
||||
"prompts": 40000,
|
||||
"customInstructions": "Requires api spec",
|
||||
"modality": "Image",
|
||||
"logo": "/icons/myshell.png",
|
||||
},
|
||||
{
|
||||
"name": "Custom API Files",
|
||||
"prompts": 40000,
|
||||
"customInstructions": "Requires api spec",
|
||||
"modality": "Files",
|
||||
"logo": "/icons/myshell.png",
|
||||
},
|
||||
{"name": "Gemini", "prompts": 40000, "logo": "/icons/gemini.png"},
|
||||
{"name": "Claude", "prompts": 40000, "logo": "/icons/claude.png"},
|
||||
{"name": "Cohere", "prompts": 40000, "logo": "/icons/cohere.png"},
|
||||
{"name": "Azure OpenAI", "prompts": 40000, "logo": "/icons/azureai.png"},
|
||||
{"name": "assemblyai", "prompts": 40000, "logo": "/icons/myshell.png"},
|
||||
{"name": "OpenRouter.ai", "prompts": 40000, "logo": "/icons/openrouter.png"},
|
||||
]
|
||||
|
||||
LLM_SPECS = [dict(spec=spec, **d) for spec, d in zip(_SPECS, LLM_CONFIGS)]
|
||||
@@ -6,6 +6,7 @@ from fastapi.responses import JSONResponse
|
||||
from ..primitives import FileProbeResponse, Probe
|
||||
from ..probe_actor.refusal import REFUSAL_MARKS
|
||||
from ..probe_data import REGISTRY
|
||||
from ._specs import LLM_SPECS
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -73,6 +74,12 @@ async def data_config():
|
||||
return [m for m in REGISTRY]
|
||||
|
||||
|
||||
@router.get("/v1/llm-specs", response_model=list)
|
||||
def get_llm_specs():
|
||||
"""Returns the LLM API specifications."""
|
||||
return LLM_SPECS
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
|
||||
@@ -17,7 +17,7 @@ from agentic_security.logutils import logger
|
||||
|
||||
from ..core.app import get_stop_event, get_tools_inbox, set_current_run
|
||||
from ..dependencies import InMemorySecrets, get_in_memory_secrets
|
||||
from ..http_spec import LLMSpec
|
||||
from ..http_spec import InvalidHTTPSpecError, LLMSpec
|
||||
from ..primitives import LLMInfo, Scan
|
||||
from ..probe_actor import fuzzer
|
||||
|
||||
@@ -31,6 +31,8 @@ async def verify(
|
||||
spec = LLMSpec.from_string(info.spec)
|
||||
try:
|
||||
r = await spec.verify()
|
||||
except InvalidHTTPSpecError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@@ -110,19 +110,21 @@ var app = new Vue({
|
||||
},
|
||||
focusTextarea() {
|
||||
this.isFocused = true;
|
||||
self = this.$refs;
|
||||
// Remove 'self' assignment if not used elsewhere
|
||||
this.$nextTick(() => {
|
||||
// Focus the textarea after rendering
|
||||
self.textarea.focus();
|
||||
this.adjustHeight({ target: self.textarea });
|
||||
this.$refs.textarea.focus();
|
||||
this.adjustHeight({ target: this.$refs.textarea });
|
||||
});
|
||||
document.addEventListener("mousedown", this.handleClickOutside);
|
||||
|
||||
// Correct the event listener to use handleOutsideClick
|
||||
document.addEventListener("mousedown", this.handleOutsideClick);
|
||||
},
|
||||
handleOutsideClick(event) {
|
||||
if (!this.$refs.container.contains(event.target)) {
|
||||
if (!this.$refs.textarea) {
|
||||
return
|
||||
}
|
||||
if (!this.$refs.textarea.contains(event.target)) {
|
||||
this.isFocused = false;
|
||||
document.removeEventListener("mousedown", this.handleClickOutside);
|
||||
document.removeEventListener("mousedown", this.handleOutsideClick);
|
||||
}
|
||||
},
|
||||
unfocusTextarea() {
|
||||
@@ -130,7 +132,12 @@ var app = new Vue({
|
||||
},
|
||||
acceptConsent() {
|
||||
this.showConsentModal = false; // Close the modal
|
||||
localStorage.setItem('consentGiven', 'true'); // Save consent to local storage
|
||||
|
||||
try {
|
||||
localStorage.setItem('consentGiven', 'true'); // Save consent to local storage
|
||||
} catch (e) {
|
||||
this.showToast('Failed to save consent', 'error'); // Show error if saving fails
|
||||
}
|
||||
},
|
||||
|
||||
saveStateToLocalStorage() {
|
||||
@@ -171,6 +178,7 @@ var app = new Vue({
|
||||
this.integrationVerified = false;
|
||||
this.showResetConfirmation = false;
|
||||
this.enableMultiStepAttack = false;
|
||||
this.showToast('All settings have been reset to default', 'info');
|
||||
},
|
||||
confirmResetState() {
|
||||
this.showResetConfirmation = true;
|
||||
@@ -209,33 +217,39 @@ var app = new Vue({
|
||||
spec: this.modelSpec,
|
||||
};
|
||||
let startTime = performance.now(); // Capture start time
|
||||
const response = await fetch(`${SELF_URL}/verify`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
console.log(response);
|
||||
let r = await response.json();
|
||||
let endTime = performance.now(); // Capture end time
|
||||
let latency = endTime - startTime; // Calculate latency in milliseconds
|
||||
latency = latency.toFixed(3) / 1000; // Round to 2 decimal places
|
||||
this.latency = latency;
|
||||
if (!response.ok) {
|
||||
this.updateStatusDot(false);
|
||||
this.errorMsg = 'Integration verification failed:' + JSON.stringify(r);
|
||||
this.showToast('Integration verification failed', 'error');
|
||||
} else {
|
||||
this.errorMsg = '';
|
||||
this.updateStatusDot(true);
|
||||
this.okMsg = 'Integration verified';
|
||||
this.showToast('Integration verified successfully', 'success');
|
||||
this.integrationVerified = true;
|
||||
// console.log('Integration verified', this.integrationVerified);
|
||||
// this.$forceUpdate();
|
||||
|
||||
try {
|
||||
const response = await fetch(`${SELF_URL}/verify`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
|
||||
let r = await response.json();
|
||||
|
||||
let endTime = performance.now(); // Capture end time
|
||||
let latency = ((endTime - startTime) / 1000).toFixed(3); // Calculate latency in milliseconds
|
||||
this.latency = latency;
|
||||
|
||||
if (!response.ok) {
|
||||
this.updateStatusDot(false);
|
||||
this.errorMsg = 'Integration verification failed:' + JSON.stringify(r);
|
||||
this.showToast('Integration verification failed', 'error');
|
||||
} else {
|
||||
this.errorMsg = '';
|
||||
this.updateStatusDot(true);
|
||||
this.okMsg = 'Integration verified';
|
||||
this.showToast('Integration verified successfully', 'success');
|
||||
this.integrationVerified = true;
|
||||
}
|
||||
} catch (error) {
|
||||
this.updateStatusDot(true);
|
||||
this.errorMsg = 'Server unreachable';
|
||||
this.showToast('Network error', 'error');
|
||||
}
|
||||
|
||||
this.saveStateToLocalStorage();
|
||||
},
|
||||
loadConfigs: async function () {
|
||||
@@ -257,6 +271,7 @@ var app = new Vue({
|
||||
this.errorMsg = '';
|
||||
this.okMsg = '';
|
||||
this.integrationVerified = false;
|
||||
this.showToast(`Config ${index + 1} selected`, 'info');
|
||||
},
|
||||
toggleModules() {
|
||||
this.showModules = !this.showModules;
|
||||
@@ -344,6 +359,7 @@ var app = new Vue({
|
||||
return
|
||||
}
|
||||
console.log('New row');
|
||||
this.showToast('New module', 'success');
|
||||
let payload = {
|
||||
table: this.mainTable,
|
||||
};
|
||||
@@ -454,6 +470,8 @@ var app = new Vue({
|
||||
}
|
||||
});
|
||||
}
|
||||
this.scanRunning = false;
|
||||
this.showToast('Scan finished successfully', 'success');
|
||||
this.saveStateToLocalStorage();
|
||||
|
||||
}
|
||||
|
||||
Generated
+217
-301
File diff suppressed because it is too large
Load Diff
+9
-2
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "agentic_security"
|
||||
version = "0.7.0"
|
||||
version = "0.7.3"
|
||||
description = "Agentic LLM vulnerability scanner"
|
||||
authors = ["Alexander Miasoiedov <msoedov@gmail.com>"]
|
||||
maintainers = ["Alexander Miasoiedov <msoedov@gmail.com>"]
|
||||
@@ -52,6 +52,7 @@ sentry_sdk = "^2.22.0"
|
||||
orjson = "^3.10"
|
||||
pyfiglet = "^1.0.2"
|
||||
termcolor = "^2.4.0"
|
||||
mcp = "^1.4.1"
|
||||
|
||||
# garak = { version = "*", optional = true }
|
||||
pytest-xdist = "3.6.1"
|
||||
@@ -69,7 +70,7 @@ pytest-mock = "^3.14.0"
|
||||
black = ">=24.10,<26.0"
|
||||
mypy = "^1.12.0"
|
||||
pre-commit = "^4.0.1"
|
||||
huggingface-hub = ">=0.25.1,<0.29.0"
|
||||
huggingface-hub = ">=0.25.1,<0.30.0"
|
||||
|
||||
# Docs
|
||||
mkdocs = ">=1.4.2"
|
||||
@@ -91,3 +92,9 @@ addopts = "--durations=5 -m 'not slow' -n 3"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
markers = "slow: marks tests as slow"
|
||||
|
||||
[project]
|
||||
# MCP requires the following fields to be present in the pyproject.toml file
|
||||
name = "agentic_security"
|
||||
version = "1.0.0"
|
||||
requires-python = ">=3.11"
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from cache_to_disk import delete_old_disk_caches
|
||||
|
||||
from agentic_security.logutils import logger
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
if "slow" in item.keywords and not os.getenv("RUN_SLOW_TESTS"):
|
||||
pytest.skip("Skipping slow test")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
def setup_delete_old_disk_caches():
|
||||
logger.info("delete_old_disk_caches")
|
||||
delete_old_disk_caches()
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from agentic_security.primitives import Scan
|
||||
from agentic_security.probe_actor.fuzzer import (
|
||||
FuzzerState,
|
||||
generate_prompts,
|
||||
perform_many_shot_scan,
|
||||
perform_single_shot_scan,
|
||||
@@ -207,9 +208,7 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=[],
|
||||
errors=[],
|
||||
outputs=[],
|
||||
fuzzer_state=FuzzerState(),
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
|
||||
@@ -226,20 +225,17 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
refusals = []
|
||||
outputs = []
|
||||
fuzzer_state = FuzzerState()
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=refusals,
|
||||
errors=[],
|
||||
outputs=outputs,
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
|
||||
self.assertFalse(refusal)
|
||||
# self.assertFalse(fuzzer_state.refusals)
|
||||
|
||||
async def test_http_error_response(self):
|
||||
mock_request_factory = Mock()
|
||||
@@ -252,15 +248,13 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
refusals = []
|
||||
fuzzer_state = FuzzerState()
|
||||
await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=refusals,
|
||||
errors=[],
|
||||
outputs=[],
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
async def test_request_error(self):
|
||||
@@ -269,18 +263,14 @@ class TestProcessPrompt(unittest.IsolatedAsyncioTestCase):
|
||||
side_effect=httpx.RequestError("Connection error")
|
||||
)
|
||||
|
||||
errors = []
|
||||
fuzzer_state = FuzzerState()
|
||||
tokens, refusal = await process_prompt(
|
||||
request_factory=mock_request_factory,
|
||||
prompt="test prompt",
|
||||
tokens=0,
|
||||
module_name="module_a",
|
||||
refusals=[],
|
||||
errors=errors,
|
||||
outputs=[],
|
||||
fuzzer_state=fuzzer_state,
|
||||
)
|
||||
|
||||
self.assertEqual(tokens, 0)
|
||||
self.assertTrue(refusal)
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertIn("Connection error", errors[0][3])
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
import pytest
|
||||
|
||||
from agentic_security.mcp.client import run
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_echo_tool():
|
||||
"""Test the echo tool functionality"""
|
||||
prompts, resources, tools = await run()
|
||||
assert prompts
|
||||
assert resources
|
||||
assert tools
|
||||
Generated
+3
-3
@@ -6891,9 +6891,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/http-proxy-middleware": {
|
||||
"version": "2.0.7",
|
||||
"resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.7.tgz",
|
||||
"integrity": "sha512-fgVY8AV7qU7z/MmXJ/rxwbrtQH4jBQ9m7kp3llF0liB7glmFeVZFBepQb32T3y8n8k2+AEYuMPCpinYW+/CuRA==",
|
||||
"version": "2.0.9",
|
||||
"resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz",
|
||||
"integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
||||
Reference in New Issue
Block a user