mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
refactor(generate_prompts):
This commit is contained in:
@@ -14,13 +14,15 @@ from agentic_security.probe_data import msj_data
|
||||
from agentic_security.probe_data.data import prepare_prompts
|
||||
|
||||
|
||||
async def prompt_iter(prompts: list[str] | AsyncGenerator) -> AsyncGenerator[str, None]:
|
||||
async def generate_prompts(
|
||||
prompts: list[str] | AsyncGenerator,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if isinstance(prompts, list):
|
||||
for p in prompts:
|
||||
yield p
|
||||
for prompt in prompts:
|
||||
yield prompt
|
||||
else:
|
||||
async for p in prompts:
|
||||
yield p
|
||||
async for prompt in prompts:
|
||||
yield prompt
|
||||
|
||||
|
||||
async def perform_single_shot_scan(
|
||||
@@ -60,7 +62,7 @@ async def perform_single_shot_scan(
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
|
||||
async for prompt in prompt_iter(module.prompts):
|
||||
async for prompt in generate_prompts(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
logger.info("Scan stopped by user.")
|
||||
@@ -175,7 +177,7 @@ async def perform_many_shot_scan(
|
||||
module_size = 0 if module.lazy else len(module.prompts)
|
||||
logger.info(f"Scanning {module.dataset_name} {module_size}")
|
||||
|
||||
async for prompt in prompt_iter(module.prompts):
|
||||
async for prompt in generate_prompts(module.prompts):
|
||||
if stop_event and stop_event.is_set():
|
||||
stop_event.clear()
|
||||
logger.info("Scan stopped by user.")
|
||||
|
||||
@@ -7,25 +7,25 @@ from agentic_security.models.schemas import Scan
|
||||
from agentic_security.probe_actor.fuzzer import (
|
||||
perform_many_shot_scan,
|
||||
perform_single_shot_scan,
|
||||
prompt_iter,
|
||||
generate_prompts,
|
||||
scan_router,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_iter_with_list():
|
||||
async def test_generate_prompts_with_list():
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
results = [p async for p in prompt_iter(prompts)]
|
||||
results = [p async for p in generate_prompts(prompts)]
|
||||
assert results == prompts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_iter_with_async_generator():
|
||||
async def test_generate_prompts_with_async_generator():
|
||||
async def async_gen():
|
||||
for i in range(3):
|
||||
yield f"prompt{i}"
|
||||
|
||||
results = [p async for p in prompt_iter(async_gen())]
|
||||
results = [p async for p in generate_prompts(async_gen())]
|
||||
assert results == ["prompt0", "prompt1", "prompt2"]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user