diff --git a/agentic_security/http_spec.py b/agentic_security/http_spec.py index 5ac75b2..a41cf85 100644 --- a/agentic_security/http_spec.py +++ b/agentic_security/http_spec.py @@ -10,6 +10,7 @@ class Modality(Enum): IMAGE = 1 AUDIO = 2 FILES = 3 + MIXED = 4 def encode_image_base64_by_url(url: str = "https://github.com/fluidicon.png") -> str: @@ -107,6 +108,7 @@ class LLMSpec(BaseModel): case LLMSpec(has_audio=True): return await self.probe( "test", + # TODO: fix url for mp3 encoded_audio=encode_audio_base64_by_url( "https://www.example.com/audio.mp3" ), diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index a461759..cbbc697 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -18,7 +18,7 @@ from agentic_security.probe_data.data import prepare_prompts async def generate_prompts( - prompts: list[str] | AsyncGenerator, modality: Modality = Modality.TEXT + prompts: list[str] | AsyncGenerator, ) -> AsyncGenerator[str, None]: if isinstance(prompts, list): for prompt in prompts: @@ -28,6 +28,20 @@ async def generate_prompts( yield prompt +def multi_modality_spec(llm_spec): + match llm_spec.modality: + case Modality.IMAGE: + return image_generator.RequestAdapter(llm_spec) + case Modality.AUDIO: + return audio_generator.RequestAdapter(llm_spec) + case Modality.TEXT: + return llm_spec + case _: + return llm_spec + # case _: + # raise NotImplementedError(f"Modality {llm_spec.modality} not supported yet") + + async def process_prompt( request_factory, prompt, tokens, module_name, refusals, errors ): @@ -69,6 +83,7 @@ async def perform_single_shot_scan( """Perform a standard security scan.""" max_budget = max_budget * 100_000_000 selected_datasets = [m for m in datasets if m["selected"]] + request_factory = multi_modality_spec(request_factory) try: yield ScanResult.status_msg("Loading datasets...") prompt_modules = prepare_prompts( @@ -102,9 +117,7 @@ async def perform_single_shot_scan( module_size = 0 if module.lazy else len(module.prompts) logger.info(f"Scanning {module.dataset_name} {module_size}") - async for prompt in generate_prompts( - module.prompts, modality=request_factory.modality - ): + async for prompt in generate_prompts(module.prompts): if stop_event and stop_event.is_set(): stop_event.clear() logger.info("Scan stopped by user.") @@ -186,6 +199,7 @@ async def perform_many_shot_scan( max_ctx_length: int = 10_000, ) -> AsyncGenerator[str, None]: """Perform a multi-step security scan with probe injection.""" + request_factory = multi_modality_spec(request_factory) try: # Load main and probe datasets yield ScanResult.status_msg("Loading datasets...") @@ -215,9 +229,7 @@ async def perform_many_shot_scan( module_size = 0 if module.lazy else len(module.prompts) logger.info(f"Scanning {module.dataset_name} {module_size}") - async for prompt in generate_prompts( - module.prompts, modality=request_factory.modality - ): + async for prompt in generate_prompts(module.prompts): if stop_event and stop_event.is_set(): stop_event.clear() logger.info("Scan stopped by user.") diff --git a/agentic_security/probe_data/audio_generator.py b/agentic_security/probe_data/audio_generator.py index 64c883b..0278da9 100644 --- a/agentic_security/probe_data/audio_generator.py +++ b/agentic_security/probe_data/audio_generator.py @@ -1,11 +1,18 @@ +import base64 import os import platform import subprocess import uuid +import httpx from cache_to_disk import cache_to_disk +def encode(content: bytes) -> str: + encoded_content = base64.b64encode(content).decode("utf-8") + return "data:audio/mpeg;base64," + encoded_content + + def generate_audio_mac_wav(prompt: str) -> bytes: """ Generate an audio file from the provided prompt using macOS 'say' command @@ -64,3 +71,21 @@ def generate_audioform(prompt: str) -> bytes: raise NotImplementedError( "Audio generation is only supported on macOS for now." ) + + +class RequestAdapter: + # Adapter of http_spec.LLMSpec + + def __init__(self, llm_spec): + self.llm_spec = llm_spec + if not llm_spec.has_audio: + raise ValueError("LLMSpec must have an image") + + async def probe( + self, prompt: str, encoded_image: str = "", encoded_audio: str = "", files={} + ) -> httpx.Response: + encoded_audio = generate_audioform(prompt) + encoded_audio = encode(encoded_audio) + return await self.llm_spec.probe(prompt, encoded_image, encoded_audio, files) + + fn = probe diff --git a/agentic_security/probe_data/image_generator.py b/agentic_security/probe_data/image_generator.py index dd5ae00..f00ef46 100644 --- a/agentic_security/probe_data/image_generator.py +++ b/agentic_security/probe_data/image_generator.py @@ -1,5 +1,7 @@ +import base64 import io +import httpx import matplotlib.pyplot as plt from cache_to_disk import cache_to_disk from tqdm import tqdm @@ -75,3 +77,26 @@ def generate_image(prompt: str) -> bytes: # Return the image bytes return buffer.getvalue() + + +def encode(image: bytes) -> str: + encoded_content = base64.b64encode(image).decode("utf-8") + return "data:image/jpeg;base64," + encoded_content + + +class RequestAdapter: + # Adapter of http_spec.LLMSpec + + def __init__(self, llm_spec): + self.llm_spec = llm_spec + if not llm_spec.has_image: + raise ValueError("LLMSpec must have an image") + + async def probe( + self, prompt: str, encoded_image: str = "", encoded_audio: str = "", files={} + ) -> httpx.Response: + encoded_image = generate_image(prompt) + encoded_image = encode(encoded_image) + return await self.llm_spec.probe(prompt, encoded_image, encoded_audio, files) + + fn = probe diff --git a/agentic_security/probe_data/modules/fine_tuned.py b/agentic_security/probe_data/modules/fine_tuned.py index f098021..fa01951 100644 --- a/agentic_security/probe_data/modules/fine_tuned.py +++ b/agentic_security/probe_data/modules/fine_tuned.py @@ -20,7 +20,7 @@ class Module: self.batch_size = self.opts.get("batch_size", 500) async def apply(self): - for _ in range(self.max_prompts // self.batch_size): + for _ in range(max(self.max_prompts // self.batch_size, 1)): # Fetch prompts from the API prompts = await self.fetch_prompts() diff --git a/agentic_security/test_lib.py b/agentic_security/test_lib.py index 946ac57..e87de6d 100644 --- a/agentic_security/test_lib.py +++ b/agentic_security/test_lib.py @@ -124,7 +124,7 @@ class TestAS: print(result) assert len(result) in [0, 1] - def _test_image_modality(self): + def test_image_modality(self): llmSpec = test_spec_assets.IMAGE_SPEC maxBudget = 2 max_th = 0.3 @@ -139,8 +139,10 @@ class TestAS: "url": "", "dynamic": True, "opts": { + # "port": 8718, "port": 9094, "modules": ["encoding"], + "max_prompts": 2, }, "modality": "text", },