feat: move optimizer initialization to module level (inside scan_module)

This commit is contained in:
Alexander Myasoedov
2025-03-12 19:45:27 +02:00
parent bd6d2f3db1
commit ac3f2f803c
+13 -26
View File
@@ -153,7 +153,6 @@ async def scan_module(
max_budget: int = 0,
total_tokens: int = 0,
optimize: bool = False,
optimizer=None,
stop_event: asyncio.Event | None = None,
) -> AsyncGenerator[dict[str, Any], None]:
"""
@@ -168,7 +167,6 @@ async def scan_module(
max_budget: Maximum token budget
total_tokens: Current token count
optimize: Whether to use optimization
optimizer: The optimizer to use
stop_event: Event to stop scanning
Yields:
@@ -180,6 +178,15 @@ async def scan_module(
failure_rates = []
should_stop = False
# Initialize optimizer if optimization is enabled
optimizer = (
Optimizer(
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
)
if optimize
else None
)
module_size = 0 if module.lazy else len(module.prompts)
logger.info(f"Scanning {module.dataset_name} {module_size}")
@@ -313,14 +320,6 @@ async def perform_single_shot_scan(
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
optimizer = (
Optimizer(
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
)
if optimize
else None
)
total_tokens = 0
for module in prompt_modules:
module_gen = scan_module(
@@ -332,7 +331,6 @@ async def perform_single_shot_scan(
max_budget=max_budget,
total_tokens=total_tokens,
optimize=optimize,
optimizer=optimizer,
stop_event=stop_event,
)
try:
@@ -396,13 +394,6 @@ async def perform_many_shot_scan(
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
processed_prompts = 0
optimizer = (
Optimizer(
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
)
if optimize
else None
)
failure_rates = []
for module in prompt_modules:
@@ -466,14 +457,10 @@ async def perform_many_shot_scan(
).model_dump_json()
if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES:
next_point = optimizer.ask()
optimizer.tell(next_point, -failure_rate)
best_failure_rate = -optimizer.get_result().fun
if best_failure_rate > FAILURE_RATE_THRESHOLD:
yield ScanResult.status_msg(
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
)
break
yield ScanResult.status_msg(
f"High failure rate detected ({failure_rate:.2%}). Stopping this module..."
)
break
yield ScanResult.status_msg("Scan completed.")
fuzzer_state.export_failures("failures.csv")