feat: add modality adapter

This commit is contained in:
Alexander Myasoedov
2025-01-14 11:54:51 +02:00
parent 7c0d6f7eae
commit 386ff2aa15
6 changed files with 75 additions and 9 deletions
+2
View File
@@ -10,6 +10,7 @@ class Modality(Enum):
IMAGE = 1
AUDIO = 2
FILES = 3
MIXED = 4
def encode_image_base64_by_url(url: str = "https://github.com/fluidicon.png") -> str:
@@ -107,6 +108,7 @@ class LLMSpec(BaseModel):
case LLMSpec(has_audio=True):
return await self.probe(
"test",
# TODO: fix url for mp3
encoded_audio=encode_audio_base64_by_url(
"https://www.example.com/audio.mp3"
),
+19 -7
View File
@@ -18,7 +18,7 @@ from agentic_security.probe_data.data import prepare_prompts
async def generate_prompts(
prompts: list[str] | AsyncGenerator, modality: Modality = Modality.TEXT
prompts: list[str] | AsyncGenerator,
) -> AsyncGenerator[str, None]:
if isinstance(prompts, list):
for prompt in prompts:
@@ -28,6 +28,20 @@ async def generate_prompts(
yield prompt
def multi_modality_spec(llm_spec):
    """Pick the request object to scan with, based on ``llm_spec.modality``.

    IMAGE and AUDIO specs are wrapped in the ``RequestAdapter`` of the
    matching generator module; for every other modality (TEXT included)
    the spec itself is returned unchanged.
    """
    modality = llm_spec.modality
    if modality == Modality.IMAGE:
        return image_generator.RequestAdapter(llm_spec)
    if modality == Modality.AUDIO:
        return audio_generator.RequestAdapter(llm_spec)
    # TEXT and any modality without an adapter fall through to the plain
    # spec rather than raising NotImplementedError.
    return llm_spec
async def process_prompt(
request_factory, prompt, tokens, module_name, refusals, errors
):
@@ -69,6 +83,7 @@ async def perform_single_shot_scan(
"""Perform a standard security scan."""
max_budget = max_budget * 100_000_000
selected_datasets = [m for m in datasets if m["selected"]]
request_factory = multi_modality_spec(request_factory)
try:
yield ScanResult.status_msg("Loading datasets...")
prompt_modules = prepare_prompts(
@@ -102,9 +117,7 @@ async def perform_single_shot_scan(
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
async for prompt in generate_prompts(
module.prompts, modality=request_factory.modality
):
async for prompt in generate_prompts(module.prompts):
if stop_event and stop_event.is_set():
stop_event.clear()
logger.info("Scan stopped by user.")
@@ -186,6 +199,7 @@ async def perform_many_shot_scan(
max_ctx_length: int = 10_000,
) -> AsyncGenerator[str, None]:
"""Perform a multi-step security scan with probe injection."""
request_factory = multi_modality_spec(request_factory)
try:
# Load main and probe datasets
yield ScanResult.status_msg("Loading datasets...")
@@ -215,9 +229,7 @@ async def perform_many_shot_scan(
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
async for prompt in generate_prompts(
module.prompts, modality=request_factory.modality
):
async for prompt in generate_prompts(module.prompts):
if stop_event and stop_event.is_set():
stop_event.clear()
logger.info("Scan stopped by user.")
@@ -1,11 +1,18 @@
import base64
import os
import platform
import subprocess
import uuid
import httpx
from cache_to_disk import cache_to_disk
def encode(content: bytes) -> str:
    """Return ``content`` base64-encoded as a ``data:audio/mpeg`` URI string."""
    payload = base64.b64encode(content).decode("utf-8")
    return f"data:audio/mpeg;base64,{payload}"
def generate_audio_mac_wav(prompt: str) -> bytes:
"""
Generate an audio file from the provided prompt using macOS 'say' command
@@ -64,3 +71,21 @@ def generate_audioform(prompt: str) -> bytes:
raise NotImplementedError(
"Audio generation is only supported on macOS for now."
)
class RequestAdapter:
    """Adapter of http_spec.LLMSpec that sends each prompt as audio.

    ``probe`` keeps the same parameters as the call it forwards to
    ``llm_spec.probe``, but always replaces ``encoded_audio`` with audio
    generated from the prompt.
    """

    def __init__(self, llm_spec):
        """Wrap ``llm_spec``.

        Raises:
            ValueError: if ``llm_spec.has_audio`` is falsy.
        """
        self.llm_spec = llm_spec
        if not llm_spec.has_audio:
            # The message previously said "image" (copied from the image
            # adapter), which pointed at the wrong placeholder.
            raise ValueError("LLMSpec must have audio")

    async def probe(
        self,
        prompt: str,
        encoded_image: str = "",
        encoded_audio: str = "",
        files=None,
    ) -> "httpx.Response":
        """Generate audio for ``prompt`` and forward the request.

        Any ``encoded_audio`` passed by the caller is ignored: it is
        overwritten with the output of ``generate_audioform(prompt)`` run
        through ``encode``. ``files`` defaults to a fresh empty dict per call
        so no dict instance is shared between requests.
        """
        audio_bytes = generate_audioform(prompt)
        encoded_audio = encode(audio_bytes)
        return await self.llm_spec.probe(
            prompt,
            encoded_image,
            encoded_audio,
            {} if files is None else files,
        )

    # Alias for ``probe``; presumably mirrors an ``fn`` entry point on
    # LLMSpec - confirm against the callers.
    fn = probe
@@ -1,5 +1,7 @@
import base64
import io
import httpx
import matplotlib.pyplot as plt
from cache_to_disk import cache_to_disk
from tqdm import tqdm
@@ -75,3 +77,26 @@ def generate_image(prompt: str) -> bytes:
# Return the image bytes
return buffer.getvalue()
def encode(image: bytes) -> str:
    """Return ``image`` base64-encoded as a ``data:image/jpeg`` URI string."""
    payload = base64.b64encode(image).decode("utf-8")
    return f"data:image/jpeg;base64,{payload}"
class RequestAdapter:
    """Adapter of http_spec.LLMSpec that sends each prompt as an image.

    ``probe`` keeps the same parameters as the call it forwards to
    ``llm_spec.probe``, but always replaces ``encoded_image`` with an image
    generated from the prompt.
    """

    def __init__(self, llm_spec):
        """Wrap ``llm_spec``.

        Raises:
            ValueError: if ``llm_spec.has_image`` is falsy.
        """
        self.llm_spec = llm_spec
        if not llm_spec.has_image:
            raise ValueError("LLMSpec must have an image")

    async def probe(
        self,
        prompt: str,
        encoded_image: str = "",
        encoded_audio: str = "",
        files=None,
    ) -> "httpx.Response":
        """Generate an image for ``prompt`` and forward the request.

        Any ``encoded_image`` passed by the caller is ignored: it is
        overwritten with the output of ``generate_image(prompt)`` run through
        ``encode``. ``files`` defaults to a fresh empty dict per call so no
        dict instance is shared between requests.
        """
        image_bytes = generate_image(prompt)
        encoded_image = encode(image_bytes)
        return await self.llm_spec.probe(
            prompt,
            encoded_image,
            encoded_audio,
            {} if files is None else files,
        )

    # Alias for ``probe``; presumably mirrors an ``fn`` entry point on
    # LLMSpec - confirm against the callers.
    fn = probe
@@ -20,7 +20,7 @@ class Module:
self.batch_size = self.opts.get("batch_size", 500)
async def apply(self):
for _ in range(self.max_prompts // self.batch_size):
for _ in range(max(self.max_prompts // self.batch_size, 1)):
# Fetch prompts from the API
prompts = await self.fetch_prompts()
+3 -1
View File
@@ -124,7 +124,7 @@ class TestAS:
print(result)
assert len(result) in [0, 1]
def _test_image_modality(self):
def test_image_modality(self):
llmSpec = test_spec_assets.IMAGE_SPEC
maxBudget = 2
max_th = 0.3
@@ -139,8 +139,10 @@ class TestAS:
"url": "",
"dynamic": True,
"opts": {
# "port": 8718,
"port": 9094,
"modules": ["encoding"],
"max_prompts": 2,
},
"modality": "text",
},