refactor(generate_prompts):

This commit is contained in:
Alexander Myasoedov
2024-12-11 17:46:37 +02:00
parent 6df0ba5d52
commit b5ecc28ab6
2 changed files with 14 additions and 12 deletions
+9 -7
View File
@@ -14,13 +14,15 @@ from agentic_security.probe_data import msj_data
from agentic_security.probe_data.data import prepare_prompts
async def prompt_iter(prompts: list[str] | AsyncGenerator) -> AsyncGenerator[str, None]:
async def generate_prompts(
prompts: list[str] | AsyncGenerator,
) -> AsyncGenerator[str, None]:
if isinstance(prompts, list):
for p in prompts:
yield p
for prompt in prompts:
yield prompt
else:
async for p in prompts:
yield p
async for prompt in prompts:
yield prompt
async def perform_single_shot_scan(
@@ -60,7 +62,7 @@ async def perform_single_shot_scan(
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
async for prompt in prompt_iter(module.prompts):
async for prompt in generate_prompts(module.prompts):
if stop_event and stop_event.is_set():
stop_event.clear()
logger.info("Scan stopped by user.")
@@ -175,7 +177,7 @@ async def perform_many_shot_scan(
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
async for prompt in prompt_iter(module.prompts):
async for prompt in generate_prompts(module.prompts):
if stop_event and stop_event.is_set():
stop_event.clear()
logger.info("Scan stopped by user.")
+5 -5
View File
@@ -7,25 +7,25 @@ from agentic_security.models.schemas import Scan
from agentic_security.probe_actor.fuzzer import (
perform_many_shot_scan,
perform_single_shot_scan,
prompt_iter,
generate_prompts,
scan_router,
)
@pytest.mark.asyncio
async def test_prompt_iter_with_list():
async def test_generate_prompts_with_list():
prompts = ["prompt1", "prompt2", "prompt3"]
results = [p async for p in prompt_iter(prompts)]
results = [p async for p in generate_prompts(prompts)]
assert results == prompts
@pytest.mark.asyncio
async def test_prompt_iter_with_async_generator():
async def test_generate_prompts_with_async_generator():
async def async_gen():
for i in range(3):
yield f"prompt{i}"
results = [p async for p in prompt_iter(async_gen())]
results = [p async for p in generate_prompts(async_gen())]
assert results == ["prompt0", "prompt1", "prompt2"]