feat(Add backend):

This commit is contained in:
Alexander Myasoedov
2025-01-03 00:07:10 +02:00
parent 9de34e2835
commit ffd7d710f1
7 changed files with 145 additions and 17 deletions
+16 -12
View File
@@ -1,6 +1,21 @@
from .data import load_local_csv
REGISTRY = [
{
"dataset_name": "AgenticBackend",
"num_prompts": 2000,
"tokens": 0,
"approx_cost": 0.0,
"source": "Fine-tuned cloud hosted model",
"selected": True,
"url": "Cloud",
"dynamic": False,
"opts": {
"port": 8718,
"modules": ["encoding"],
},
"modality": "text",
},
{
"dataset_name": "ShawnMenz/DAN_jailbreak",
"num_prompts": 666,
@@ -73,7 +88,7 @@ REGISTRY = [
"tokens": 1975800,
"approx_cost": 0.0,
"source": "Hugging Face Datasets",
"selected": True,
"selected": False,
"dynamic": False,
"url": "https://huggingface.co/JailbreakV-28K/JailBreakV-28k",
"modality": "text",
@@ -111,17 +126,6 @@ REGISTRY = [
"url": "",
"modality": "text",
},
{
"dataset_name": "Agentic Security",
"num_prompts": 0,
"tokens": 0,
"approx_cost": 0.0,
"source": "Local dataset",
"selected": False,
"dynamic": True,
"url": "",
"modality": "text",
},
{
"dataset_name": "jailbreak_llms/2023_05_07",
"num_prompts": 0,
+7 -1
View File
@@ -12,6 +12,7 @@ from agentic_security.probe_data import stenography_fn
from agentic_security.probe_data.models import ProbeDataset
from agentic_security.probe_data.modules import (
adaptive_attacks,
fine_tuned,
garak_tool,
inspect_ai_tool,
)
@@ -23,7 +24,7 @@ def count_words_in_list(str_list):
:param str_list: List of strings
:return: Total number of words across all strings in the list
"""
total_words = sum(len(s.split()) for s in str_list)
total_words = sum(len(str(s).split()) for s in str_list)
return total_words
@@ -237,6 +238,11 @@ def prepare_prompts(dataset_names, budget, tools_inbox=None, options=[]):
logger.error(f"Error loading {dataset_name}: {e}")
dynamic_datasets = {
"AgenticBackend": lambda opts: dataset_from_iterator(
"AgenticBackend",
fine_tuned.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
),
"Steganography": lambda opts: Stenography(group),
"llm-adaptive-attacks": lambda opts: dataset_from_iterator(
"llm-adaptive-attacks",
@@ -0,0 +1,84 @@
import asyncio
import os
import uuid as U
from typing import List, Optional

import httpx
from loguru import logger
class Module:
    """Dynamic dataset module that streams prompts from the hosted backend.

    Prompts are fetched in batches from ``API_URL`` and yielded one by one by
    :meth:`apply`. After each prompt, any messages waiting in ``tools_inbox``
    are yielded as well and their ``ready`` event is set.

    Recognised ``opts`` keys (all optional):
        max_prompts: upper bound on prompts yielded per run (default 2000).
        batch_size: number of prompts requested per API call (default 500).
        port: local proxy port used by :meth:`post_prompt` (default 8718).
    """

    # Endpoint that returns a batch of prompts for a given run id.
    API_URL = "https://msoedov--agesec-backend-fastapi-app.modal.run/infer"
    # Environment variable that overrides the bundled bearer token.
    AUTH_TOKEN_ENV = "AS_TOKEN"
    # SECURITY NOTE(review): a credential committed to source control should
    # be rotated and supplied via the environment only. The value is kept as
    # a fallback so existing deployments keep working unchanged.
    DEFAULT_AUTH_TOKEN = "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b"

    def __init__(
        self,
        prompt_groups: List[str],
        tools_inbox: Optional[asyncio.Queue],
        opts: Optional[dict] = None,
    ):
        """Store the collaborators and resolve batching options.

        Args:
            prompt_groups: Prompt groups handed over by the caller; stored
                as-is and not otherwise used by this class.
            tools_inbox: Queue of tool messages to interleave with prompts,
                or ``None`` when no inbox is attached.
            opts: Option mapping (see class docstring). ``None`` means
                "no options"; a fresh dict is created per instance so that
                instances never share a mutable default.
        """
        self.tools_inbox = tools_inbox
        self.opts = {} if opts is None else opts
        self.prompt_groups = prompt_groups
        self.max_prompts = self.opts.get("max_prompts", 2000)
        # One id per Module instance, sent with every fetch request.
        self.run_id = U.uuid4().hex
        self.batch_size = self.opts.get("batch_size", 500)

    def _batch_count(self) -> int:
        """Return how many fetches are needed to reach ``max_prompts``.

        Uses ceiling division so that ``max_prompts < batch_size`` still
        results in one fetch; returns 0 for non-positive sizes instead of
        dividing by zero.
        """
        if self.max_prompts <= 0 or self.batch_size <= 0:
            return 0
        return -(-self.max_prompts // self.batch_size)

    async def apply(self):
        """Asynchronously yield prompts, interleaved with inbox messages.

        Stops early when the API returns no prompts, and never yields more
        than ``max_prompts`` fetched prompts in total. Messages taken from
        ``tools_inbox`` do not count towards that limit.
        """
        remaining = self.max_prompts
        for _ in range(self._batch_count()):
            prompts = await self.fetch_prompts()
            if not prompts:
                logger.error("No prompts retrieved from the API.")
                return
            logger.info(f"Retrieved {len(prompts)} prompts.")

            # Cap against the budget for the whole run, not per batch.
            batch = prompts[:remaining]
            for i, prompt in enumerate(batch):
                logger.info(f"Processing prompt {i+1}/{len(batch)}: {prompt}")
                yield prompt

                # Drain tool messages queued while the consumer handled the
                # prompt. The inbox is optional, so guard against None.
                while self.tools_inbox is not None and not self.tools_inbox.empty():
                    ref = await self.tools_inbox.get()
                    message, ready = ref["message"], ref["ready"]
                    yield message
                    # Signal the producer only after the message was consumed.
                    ready.set()

            remaining -= len(batch)
            if remaining <= 0:
                return

    async def post_prompt(self, prompt: str):
        """POST ``prompt`` to the local proxy chat-completions endpoint.

        Not called by :meth:`apply`; kept as a helper for callers.

        Returns:
            The decoded JSON response, or ``{}`` when the request fails, the
            server answers with an error status, or the body is not JSON.
        """
        port = self.opts.get("port", 8718)
        # NOTE(review): 0.0.0.0 is a bind address; connecting to it is
        # platform-dependent. Left unchanged to preserve current behaviour.
        uri = f"http://0.0.0.0:{port}/proxy/chat/completions"
        headers = {"Content-Type": "application/json"}
        data = {
            "model": "gpt-4",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 1050,
            "temperature": 0.7,
        }
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(uri, headers=headers, json=data)
                response.raise_for_status()
                return response.json()
            # HTTPError covers transport failures (RequestError) and the
            # HTTPStatusError raised by raise_for_status(); ValueError covers
            # a body that is not valid JSON.
            except (httpx.HTTPError, ValueError) as e:
                logger.error(f"Failed to post prompt: {e}")
                return {}

    async def fetch_prompts(self) -> List[str]:
        """Request one batch of prompts from the backend.

        Returns:
            The ``"prompts"`` list from the JSON response, or ``[]`` when the
            request fails, the server answers with an error status, or the
            payload is not a JSON object containing prompts.
        """
        token = os.getenv(self.AUTH_TOKEN_ENV, self.DEFAULT_AUTH_TOKEN)
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        payload = {"batch_size": self.batch_size, "run_id": self.run_id}
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    self.API_URL, headers=headers, json=payload
                )
                response.raise_for_status()
                data = response.json()
            except (httpx.HTTPError, ValueError) as e:
                logger.error(f"Failed to fetch prompts: {e}")
                return []
        if not isinstance(data, dict):
            logger.error(f"Failed to fetch prompts: unexpected payload {type(data)}")
            return []
        return data.get("prompts") or []