feat(audio): improve audio modality generation

This commit is contained in:
Alexander Myasoedov
2025-02-14 11:15:11 +02:00
parent 3ae4f34bdf
commit 05021e59f1
3 changed files with 41 additions and 4 deletions
+30 -2
View File
@@ -52,11 +52,37 @@ def generate_audio_mac_wav(prompt: str) -> bytes:
return audio_bytes
def generate_audio_cross_platform(prompt: str) -> bytes:
    """
    Generate speech audio from the provided prompt using gTTS.

    The audio is synthesized straight into an in-memory buffer, so no
    temporary file is created in the current working directory and nothing
    is left on disk if synthesis fails part-way through.

    Parameters:
        prompt (str): Text to convert into audio.

    Returns:
        bytes: The audio data in MP3 format.

    Raises:
        Any exception raised by gTTS is propagated to the caller unchanged.
    """
    import io

    from gtts import gTTS  # Import gTTS for cross-platform support

    mp3_buffer = io.BytesIO()
    # write_to_fp issues the same TTS request(s) as save(), but writes to any
    # binary file-like object instead of a path on disk.
    gTTS(text=prompt, lang="en").write_to_fp(mp3_buffer)
    return mp3_buffer.getvalue()
@cache_to_disk()
def generate_audioform(prompt: str) -> bytes:
"""
Generate an audio file from the provided prompt in WAV format.
Uses macOS 'say' command if the operating system is macOS.
Uses macOS 'say' command if the operating system is macOS, otherwise uses gTTS.
Parameters:
prompt (str): Text to convert into audio.
@@ -67,9 +93,11 @@ def generate_audioform(prompt: str) -> bytes:
current_os = platform.system()
if current_os == "Darwin": # macOS
return generate_audio_mac_wav(prompt)
elif current_os in ["Windows", "Linux"]:
return generate_audio_cross_platform(prompt)
else:
raise NotImplementedError(
"Audio generation is only supported on macOS for now."
"Audio generation is only supported on macOS, Windows, and Linux for now."
)
@@ -3,6 +3,7 @@ import platform
import pytest
from agentic_security.probe_data.audio_generator import (
generate_audio_cross_platform,
generate_audio_mac_wav,
generate_audioform,
)
@@ -24,6 +25,13 @@ def test_generate_audioform_mac():
audio_bytes = generate_audioform(prompt)
assert isinstance(audio_bytes, bytes)
assert len(audio_bytes) > 0
def test_generate_audio_cross_platform():
    """Check that the gTTS backend returns non-empty bytes on Windows/Linux."""
    # Skip up front on every other OS. Expecting NotImplementedError here
    # would be wrong: macOS is a supported platform and does not raise.
    if platform.system() not in ("Windows", "Linux"):
        pytest.skip("Test is only applicable on Windows and Linux.")

    prompt = "This is a cross-platform test."
    audio_bytes = generate_audio_cross_platform(prompt)
    assert isinstance(audio_bytes, bytes)
    assert len(audio_bytes) > 0
@@ -1,4 +1,5 @@
from unittest.mock import patch
import pytest
from agentic_security.probe_data.image_generator import (