mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
feat(add modality adapter):
This commit is contained in:
@@ -10,6 +10,7 @@ class Modality(Enum):
|
||||
IMAGE = 1
|
||||
AUDIO = 2
|
||||
FILES = 3
|
||||
MIXED = 4
|
||||
|
||||
|
||||
def encode_image_base64_by_url(url: str = "https://github.com/fluidicon.png") -> str:
|
||||
@@ -107,6 +108,7 @@ class LLMSpec(BaseModel):
|
||||
case LLMSpec(has_audio=True):
|
||||
return await self.probe(
|
||||
"test",
|
||||
# TODO: fix url for mp3
|
||||
encoded_audio=encode_audio_base64_by_url(
|
||||
"https://www.example.com/audio.mp3"
|
||||
),
|
||||
|
||||
@@ -18,7 +18,7 @@ from agentic_security.probe_data.data import prepare_prompts
|
||||
|
||||
|
||||
async def generate_prompts(
|
||||
prompts: list[str] | AsyncGenerator, modality: Modality = Modality.TEXT
|
||||
prompts: list[str] | AsyncGenerator,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if isinstance(prompts, list):
|
||||
for prompt in prompts:
|
||||
@@ -28,6 +28,20 @@ async def generate_prompts(
|
||||
yield prompt
|
||||
|
||||
|
||||
def multi_modality_spec(llm_spec):
|
||||
match llm_spec.modality:
|
||||
case Modality.IMAGE:
|
||||
return image_generator.RequestAdapter(llm_spec)
|
||||
case Modality.AUDIO:
|
||||
return audio_generator.RequestAdapter(llm_spec)
|
||||
case Modality.TEXT:
|
||||
return llm_spec
|
||||
case _:
|
||||
return llm_spec
|
||||
# case _:
|
||||
# raise NotImplementedError(f"Modality {llm_spec.modality} not supported yet")
|
||||
|
||||
|
||||
async def process_prompt(
|
||||
request_factory, prompt, tokens, module_name, refusals, errors
|
||||
):
|
||||
@@ -69,6 +83,7 @@ async def perform_single_shot_scan(
|
||||
"""Perform a standard security scan."""
|
||||
max_budget = max_budget * 100_000_000
|
||||
selected_datasets = [m for m in datasets if m["selected"]]
|
||||
request_factory = multi_modality_spec(request_factory)
|
||||
try:
|
||||
yield ScanResult.status_msg("Loading datasets...")
|
||||
prompt_modules = prepare_prompts(
|
||||
@@ -102,9 +117,7 @@ async def perform_single_shot_scan(
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
|
||||
async for prompt in generate_prompts(
|
||||
module.prompts, modality=request_factory.modality
|
||||
):
|
||||
async for prompt in generate_prompts(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
logger.info("Scan stopped by user.")
|
||||
@@ -186,6 +199,7 @@ async def perform_many_shot_scan(
|
||||
max_ctx_length: int = 10_000,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Perform a multi-step security scan with probe injection."""
|
||||
request_factory = multi_modality_spec(request_factory)
|
||||
try:
|
||||
# Load main and probe datasets
|
||||
yield ScanResult.status_msg("Loading datasets...")
|
||||
@@ -215,9 +229,7 @@ async def perform_many_shot_scan(
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
|
||||
async for prompt in generate_prompts(
|
||||
module.prompts, modality=request_factory.modality
|
||||
):
|
||||
async for prompt in generate_prompts(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
logger.info("Scan stopped by user.")
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
import base64
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
from cache_to_disk import cache_to_disk
|
||||
|
||||
|
||||
def encode(content: bytes) -> str:
|
||||
encoded_content = base64.b64encode(content).decode("utf-8")
|
||||
return "data:audio/mpeg;base64," + encoded_content
|
||||
|
||||
|
||||
def generate_audio_mac_wav(prompt: str) -> bytes:
|
||||
"""
|
||||
Generate an audio file from the provided prompt using macOS 'say' command
|
||||
@@ -64,3 +71,21 @@ def generate_audioform(prompt: str) -> bytes:
|
||||
raise NotImplementedError(
|
||||
"Audio generation is only supported on macOS for now."
|
||||
)
|
||||
|
||||
|
||||
class RequestAdapter:
|
||||
# Adapter of http_spec.LLMSpec
|
||||
|
||||
def __init__(self, llm_spec):
|
||||
self.llm_spec = llm_spec
|
||||
if not llm_spec.has_audio:
|
||||
raise ValueError("LLMSpec must have an image")
|
||||
|
||||
async def probe(
|
||||
self, prompt: str, encoded_image: str = "", encoded_audio: str = "", files={}
|
||||
) -> httpx.Response:
|
||||
encoded_audio = generate_audioform(prompt)
|
||||
encoded_audio = encode(encoded_audio)
|
||||
return await self.llm_spec.probe(prompt, encoded_image, encoded_audio, files)
|
||||
|
||||
fn = probe
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import base64
|
||||
import io
|
||||
|
||||
import httpx
|
||||
import matplotlib.pyplot as plt
|
||||
from cache_to_disk import cache_to_disk
|
||||
from tqdm import tqdm
|
||||
@@ -75,3 +77,26 @@ def generate_image(prompt: str) -> bytes:
|
||||
|
||||
# Return the image bytes
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def encode(image: bytes) -> str:
|
||||
encoded_content = base64.b64encode(image).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + encoded_content
|
||||
|
||||
|
||||
class RequestAdapter:
|
||||
# Adapter of http_spec.LLMSpec
|
||||
|
||||
def __init__(self, llm_spec):
|
||||
self.llm_spec = llm_spec
|
||||
if not llm_spec.has_image:
|
||||
raise ValueError("LLMSpec must have an image")
|
||||
|
||||
async def probe(
|
||||
self, prompt: str, encoded_image: str = "", encoded_audio: str = "", files={}
|
||||
) -> httpx.Response:
|
||||
encoded_image = generate_image(prompt)
|
||||
encoded_image = encode(encoded_image)
|
||||
return await self.llm_spec.probe(prompt, encoded_image, encoded_audio, files)
|
||||
|
||||
fn = probe
|
||||
|
||||
@@ -20,7 +20,7 @@ class Module:
|
||||
self.batch_size = self.opts.get("batch_size", 500)
|
||||
|
||||
async def apply(self):
|
||||
for _ in range(self.max_prompts // self.batch_size):
|
||||
for _ in range(max(self.max_prompts // self.batch_size, 1)):
|
||||
# Fetch prompts from the API
|
||||
prompts = await self.fetch_prompts()
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ class TestAS:
|
||||
print(result)
|
||||
assert len(result) in [0, 1]
|
||||
|
||||
def _test_image_modality(self):
|
||||
def test_image_modality(self):
|
||||
llmSpec = test_spec_assets.IMAGE_SPEC
|
||||
maxBudget = 2
|
||||
max_th = 0.3
|
||||
@@ -139,8 +139,10 @@ class TestAS:
|
||||
"url": "",
|
||||
"dynamic": True,
|
||||
"opts": {
|
||||
# "port": 8718,
|
||||
"port": 9094,
|
||||
"modules": ["encoding"],
|
||||
"max_prompts": 2,
|
||||
},
|
||||
"modality": "text",
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user