feat: update modules interface

This commit is contained in:
Alexander Myasoedov
2024-12-24 23:29:20 +02:00
parent 79cbdf6c4d
commit 1335be9b0b
12 changed files with 69 additions and 53 deletions
+13
View File
@@ -4,6 +4,7 @@ from fastapi import FastAPI
# Process-wide state shared through the get_*/set_* accessors in this module.
tools_inbox: Queue = Queue()  # queue handed out by get_tools_inbox()
stop_event: Event = Event()  # global stop signal handed out by get_stop_event()
# Record of the active run: "spec" holds the object given to set_current_run()
# and "id" an identifier derived from it. Both are empty strings until a run
# has been registered.
current_run: dict = {"spec": "", "id": ""}
def create_app() -> FastAPI:
@@ -20,3 +21,15 @@ def get_tools_inbox() -> Queue:
def get_stop_event() -> Event:
    """Return the module-level ``stop_event`` object.

    The same instance is returned on every call, so all callers share one
    stop signal.
    """
    return stop_event
def get_current_run() -> dict:
    """Return the global current-run record.

    The returned object is the module-level ``current_run`` dict itself, not
    a copy, so later calls to ``set_current_run`` are visible through it.

    Returns:
        The dict with the ``"spec"`` and ``"id"`` keys of the active run.
    """
    return current_run
def set_current_run(spec):
    """Register ``spec`` as the current run and return the shared run record.

    The module-level ``current_run`` dict is updated in place: ``"id"`` is
    set to ``hash(id(spec))`` and ``"spec"`` to the given object.

    Args:
        spec: The object describing the run to register.

    Returns:
        The module-level ``current_run`` dict after the update.
    """
    run_id = hash(id(spec))
    current_run.update(id=run_id, spec=spec)
    return current_run
+3 -1
View File
@@ -62,12 +62,14 @@ async def perform_single_shot_scan(
) -> AsyncGenerator[str, None]:
"""Perform a standard security scan."""
selected_datasets = [m for m in datasets if m["selected"]]
try:
yield ScanResult.status_msg("Loading datasets...")
prompt_modules = prepare_prompts(
dataset_names=[m["dataset_name"] for m in datasets if m["selected"]],
dataset_names=[m["dataset_name"] for m in selected_datasets],
budget=max_budget,
tools_inbox=tools_inbox,
options=[m.get("opts", {}) for m in selected_datasets],
)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
+4
View File
@@ -180,6 +180,10 @@ REGISTRY = [
"selected": False,
"url": "https://github.com/leondz/garak2",
"dynamic": True,
"opts": {
"port": 8718,
"modules": ["encoding"],
},
},
{
"dataset_name": "InspectAI",
+13 -10
View File
@@ -213,7 +213,7 @@ def load_generic_csv(url, name, column="prompt", predicator=None):
)
def prepare_prompts(dataset_names, budget, tools_inbox=None):
def prepare_prompts(dataset_names, budget, tools_inbox=None, options=[]):
# ## Datasets used and cleaned:
# markush1/LLM-Jailbreak-Classifier
# 1. Open-Orca/OpenOrca
@@ -255,28 +255,31 @@ def prepare_prompts(dataset_names, budget, tools_inbox=None):
logger.error(f"Error loading {dataset_name}: {e}")
dynamic_datasets = {
"Steganography": lambda: Stenography(group),
"llm-adaptive-attacks": lambda: dataset_from_iterator(
"llm-adaptive-attacks", adaptive_attacks.Module(group).apply()
"Steganography": lambda opts: Stenography(group),
"llm-adaptive-attacks": lambda opts: dataset_from_iterator(
"llm-adaptive-attacks",
adaptive_attacks.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
),
"Garak": lambda: dataset_from_iterator(
"Garak": lambda opts: dataset_from_iterator(
"Garak",
garak_tool.Module(group, tools_inbox=tools_inbox).apply(),
garak_tool.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
),
"InspectAI": lambda: dataset_from_iterator(
"InspectAI": lambda opts: dataset_from_iterator(
"InspectAI",
inspect_ai_tool.Module(group, tools_inbox=tools_inbox).apply(),
lazy=True,
),
"GPT fuzzer": lambda: [],
"GPT fuzzer": lambda opts: [],
}
dynamic_groups = []
for dataset_name in dataset_names:
options = options or [{} for _ in dataset_names]
for dataset_name, opts in zip(dataset_names, options):
if dataset_name in dynamic_datasets:
logger.info(f"Loading {dataset_name}")
ds = dynamic_datasets[dataset_name]()
ds = dynamic_datasets[dataset_name](opts)
for g in ds:
dynamic_groups.append(g)
@@ -3,12 +3,13 @@ import io
import httpx
import pandas as pd
from loguru import logger
import asyncio
url = "https://raw.githubusercontent.com/tml-epfl/llm-adaptive-attacks/main/harmful_behaviors/harmful_behaviors_pair.csv"
class Module:
def __init__(self, prompt_groups: []):
def __init__(self, prompt_groups: [], tools_inbox: asyncio.Queue, opts: dict = {}):
r = httpx.get(url)
content = r.content
@@ -11,21 +11,22 @@ from loguru import logger
GARAK_CONFIG = "garak_rest.json"
def write_garak_config_json(port=8718):
    """Write the Garak REST generator configuration to ``GARAK_CONFIG``.

    ``SPEC`` is serialized as indented JSON and every ``$PORT`` placeholder in
    the serialized text is replaced with ``port`` before the file is written.
    The file is written exactly once, so it holds a single JSON document.

    Args:
        port: Port number substituted for ``$PORT``. Defaults to 8718, the
            same fallback the Garak module uses when no port option is given,
            so zero-argument calls keep working.
    """
    # Render the full text first so a serialization error cannot leave a
    # truncated config file behind.
    config_text = json.dumps(SPEC, indent=4).replace("$PORT", str(port))
    with open(GARAK_CONFIG, "w") as f:
        f.write(config_text)
# TODO: add config params to data registry
class Module:
def __init__(self, prompt_groups: list, tools_inbox: asyncio.Queue, opts: dict = None):
    """Set up the Garak tool module.

    Args:
        prompt_groups: Prompt groups passed by the caller; not used by this
            constructor.
        tools_inbox: Queue stored on the instance as ``self.tools_inbox``.
        opts: Optional per-dataset options (for example ``"port"``). When
            omitted or ``None`` an empty dict is used.
    """
    self.tools_inbox = tools_inbox
    if not self.is_garak_installed():
        # Only logged, not raised: construction still succeeds without garak.
        logger.error(
            "Garak module is not installed. Please install it using 'pip install garak'"
        )
    # Build a fresh dict per instance instead of sharing one mutable default
    # object across every Module created without explicit options.
    self.opts = {} if opts is None else opts
def is_garak_installed(self) -> bool:
garak_spec = importlib.util.find_spec("garak")
@@ -46,7 +47,7 @@ class Module:
"encoding",
]
logger.info("Starting Garak tool. Writing config file.")
write_garak_config_json()
write_garak_config_json(port=self.opts.get("port", 8718))
logger.info(f"Executing command: {command}")
# Execute the command with the specific environment
process = subprocess.Popen(
@@ -63,7 +64,6 @@ class Module:
ready.set()
logger.info("Garak tool finished.")
logger.info(f"stdout: {out}")
logger.error(f"exit code: {process.returncode}")
if process.returncode != 0:
logger.error(f"Error executing command: {command}")
logger.error(f"err: {err}")
@@ -73,8 +73,8 @@ class Module:
SPEC = {
"rest": {
"RestGenerator": {
"name": "OpenAI GPT-4 Service",
"uri": "http://0.0.0.0:8718/proxy/chat/completions",
"name": "Agentic Security Proxy Service",
"uri": "http://0.0.0.0:$PORT/proxy/chat/completions",
"method": "POST",
"headers": {
"Authorization": "Bearer $OPENAI_API_KEY",
@@ -1,13 +0,0 @@
from inspect_ai import Task, eval, task
from inspect_ai.dataset import example_dataset
from inspect_ai.scorer import model_graded_fact
from inspect_ai.solver import chain_of_thought, generate, self_critique
@task
def theory_of_mind():
    """Build the task that runs over the ``theory_of_mind`` example dataset.

    The plan applies ``chain_of_thought``, ``generate`` and ``self_critique``
    in that order, and results are scored with ``model_graded_fact``.
    """
    # Components are created in the same order the Task arguments are listed.
    dataset = example_dataset("theory_of_mind")
    plan = [chain_of_thought(), generate(), self_critique()]
    scorer = model_graded_fact()
    return Task(dataset=dataset, plan=plan, scorer=scorer)
@@ -14,7 +14,7 @@ inspect_ai_task = (
class Module:
name = "Inspect AI"
def __init__(self, prompt_groups: [], tools_inbox: asyncio.Queue):
def __init__(self, prompt_groups: [], tools_inbox: asyncio.Queue, opts: dict = {}):
self.tools_inbox = tools_inbox
if not self.is_tool_installed():
logger.error(
@@ -7,7 +7,7 @@ class TestModule:
# Module can be initialized with a list of prompt groups.
def test_initialize_with_prompt_groups(self):
prompt_groups = []
module = Module(prompt_groups)
module = Module(prompt_groups, None, {})
assert module is not None
assert isinstance(module, Module)
assert len(module.goals) == snapshot(50)
+22 -18
View File
@@ -2,8 +2,9 @@ import random
from asyncio import Event
from fastapi import APIRouter
from loguru import logger
from ..core.app import get_tools_inbox
from ..core.app import get_current_run, get_tools_inbox
from ..models.schemas import CompletionRequest, Settings
from ..probe_actor.refusal import REFUSAL_MARKS
@@ -18,6 +19,7 @@ async def proxy_completions(request: CompletionRequest):
[msg.content for msg in request.messages if msg.role == "user"]
)
# TODO: get the current LLM spec for proper proxying
request_factory = get_current_run()["spec"]
message = prompt_content + " " + message
ready = Event()
ref = dict(message=message, reply="", ready=ready)
@@ -29,20 +31,22 @@ async def proxy_completions(request: CompletionRequest):
await ready.wait()
reply = ref["reply"]
return reply
# Simulate a completion response
return {
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1677858242,
"model": "gpt-3.5-turbo-0613",
"usage": {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20},
"choices": [
{
"message": {"role": "assistant", "content": message},
"logprobs": None,
"finish_reason": "stop",
"index": 0,
}
],
}
elif not request_factory:
logger.debug("No request factory found. Using mock response.")
return {
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1677858242,
"model": "gpt-3.5-turbo-0613",
"usage": {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20},
"choices": [
{
"message": {"role": "assistant", "content": message},
"logprobs": None,
"finish_reason": "stop",
"index": 0,
}
],
}
else:
return await request_factory.fn(prompt_content)
+2 -1
View File
@@ -3,7 +3,7 @@ from datetime import datetime
from fastapi import APIRouter, BackgroundTasks, HTTPException
from fastapi.responses import StreamingResponse
from ..core.app import get_stop_event, get_tools_inbox
from ..core.app import get_stop_event, get_tools_inbox, set_current_run
from ..http_spec import LLMSpec
from ..models.schemas import LLMInfo, Scan
from ..probe_actor import fuzzer
@@ -27,6 +27,7 @@ async def verify(info: LLMInfo):
def streaming_response_generator(scan_parameters: Scan):
request_factory = LLMSpec.from_string(scan_parameters.llmSpec)
set_current_run(request_factory)
async def _gen():
async for scan_result in fuzzer.scan_router(
+1
View File
@@ -88,6 +88,7 @@ class TestAS:
"selected": True,
"url": "https://github.com/leondz/garak2",
"dynamic": True,
"opts": {"port": 9094},
},
]
result = AgenticSecurity.scan(llmSpec, maxBudget, datasets, max_th)