feat(add more image generation variants):

This commit is contained in:
Alexander Myasoedov
2025-02-14 11:10:37 +02:00
parent 1ba6c588d7
commit 3ae4f34bdf
3 changed files with 57 additions and 15 deletions
+1
View File
@@ -15,3 +15,4 @@ garak_rest.json
2025.*.json
inv/
scripts/
docx/
+52 -13
View File
@@ -38,12 +38,13 @@ def generate_image_dataset(
@cache_to_disk()
def generate_image(prompt: str) -> bytes:
def generate_image(prompt: str, variant: int = 0) -> bytes:
"""
Generate an image based on the provided prompt and return it as bytes.
Parameters:
prompt (str): Text to display on the generated image.
variant (int): The variant style of the image.
Returns:
bytes: The image data in JPG format.
@@ -51,18 +52,56 @@ def generate_image(prompt: str) -> bytes:
# Create a matplotlib figure
fig, ax = plt.subplots(figsize=(6, 4))
# Customize the plot (background color, text, etc.)
ax.set_facecolor("lightblue")
ax.text(
0.5,
0.5,
prompt,
fontsize=16,
ha="center",
va="center",
wrap=True,
color="darkblue",
)
# Customize the plot based on the variant
if variant == 1:
# Dark Theme
ax.set_facecolor("darkgray")
text_color = "white"
fontsize = 18
elif variant == 2:
# Artistic Theme
ax.set_facecolor("lightpink")
text_color = "black"
fontsize = 20
# Add a border around the text
ax.text(
0.5,
0.5,
prompt,
fontsize=fontsize,
ha="center",
va="center",
wrap=True,
color=text_color,
bbox=dict(
facecolor="lightyellow", edgecolor="black", boxstyle="round,pad=0.5"
),
)
elif variant == 3:
# Minimalist Theme
ax.set_facecolor("white")
text_color = "black"
fontsize = 14
# Add a simple geometric shape (circle) behind the text
circle = plt.Circle((0.5, 0.5), 0.3, color="lightblue", fill=True)
ax.add_artist(circle)
else:
# Default Theme
ax.set_facecolor("lightblue")
text_color = "darkblue"
fontsize = 16
if variant != 2:
ax.text(
0.5,
0.5,
prompt,
fontsize=fontsize,
ha="center",
va="center",
wrap=True,
color=text_color,
)
# Remove axes for a cleaner look
ax.axis("off")
@@ -1,4 +1,5 @@
from unittest.mock import patch
import pytest
from agentic_security.probe_data.image_generator import (
generate_image,
@@ -7,9 +8,10 @@ from agentic_security.probe_data.image_generator import (
from agentic_security.probe_data.models import ImageProbeDataset, ProbeDataset
def test_generate_image():
@pytest.mark.parametrize("variant", [0, 1, 2, 3])
def test_generate_image(variant):
prompt = "Test prompt"
image_bytes = generate_image(prompt)
image_bytes = generate_image(prompt, variant)
assert isinstance(image_bytes, bytes)
assert len(image_bytes) > 0