import inspect
import os
from typing import Callable, List, Tuple

from src.text_generation.common.model_id import ModelId
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.services.logging.test_run_logging_service import TestRunLoggingService
from src.text_generation.services.nlp.abstract_semantic_similarity_service import AbstractSemanticSimilarityService
from src.text_generation.services.nlp.abstract_text_generation_completion_service import AbstractTextGenerationCompletionService
from src.text_generation.services.nlp.text_generation_completion_service import TextGenerationCompletionService


def get_prompt_batch(prompts: List[str], batch_size: int = 10, env_var: str = 'PROMPT_BATCH') -> Tuple[List[str], int, int]:
    """
    Select the environment-driven subset of prompts for the current batch.

    BATCH_SIZE overrides batch_size; the variable named by env_var holds the
    1-based batch number. BATCH_OFFSET, if set, fixes the starting offset;
    otherwise PROMPT_RANGE selects one of several 20-prompt ranges.

    Returns:
        Tuple of (prompt subset, 1-based start index, end index).
    """
    batch_size = int(os.getenv('BATCH_SIZE', str(batch_size)))
    batch_num = int(os.getenv(env_var, '1'))

    if 'BATCH_OFFSET' in os.environ:
        # Option 1: Fixed offset per workflow
        offset = int(os.getenv('BATCH_OFFSET', '0'))
    else:
        # Option 2: Configurable range; each range covers 20 prompts
        prompt_range = int(os.getenv('PROMPT_RANGE', '1'))
        offset = (prompt_range - 1) * 20

    # Calculate start and end indices
    start_idx = offset + (batch_num - 1) * batch_size
    end_idx = min(start_idx + batch_size, len(prompts))

    # Get the subset of prompts for this batch
    prompt_subset: List[str] = prompts[start_idx:end_idx]

    print(f"Running batch {batch_num} (range offset {offset}): prompts {start_idx + 1}-{end_idx} ({len(prompt_subset)} prompts)")

    return prompt_subset, start_idx + 1, end_idx
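
# Example usage (a minimal sketch; the environment variable values below are
# illustrative assumptions, not values the project requires):
#
#   os.environ['BATCH_SIZE'] = '2'
#   os.environ['PROMPT_BATCH'] = '3'
#   prompts = [f'prompt {n}' for n in range(1, 11)]
#   subset, start, end = get_prompt_batch(prompts)
#   # offset 0, batch 3 of size 2 -> (prompts[4:6], 5, 6), i.e. prompts 5-6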


def run_prompt_analysis_test(
    test_id: int,
    model_id: ModelId,
    text_generation_completion_service: AbstractTextGenerationCompletionService,
    semantic_similarity_service: AbstractSemanticSimilarityService,
    prompts: List[str],
    start: int,
    end: int,
    comparison_texts: List[str],
    service_configurator: Callable[[AbstractTextGenerationCompletionService], TextGenerationCompletionService],
    max_prompts: int = 100
) -> List[SemanticSimilarityResult]:
    """
    Common utility for running prompt analysis tests.

    Args:
        test_id: Unique identifier for the test
        model_id: Identifier of the model under test
        text_generation_completion_service: Service for text generation
        semantic_similarity_service: Service for semantic analysis
        prompts: List of prompts to test
        start: 1-based index of the first prompt in the batch (for logging)
        end: Index of the last prompt in the batch (for logging)
        comparison_texts: Reference texts for similarity comparison
        service_configurator: Function that configures the text generation service
        max_prompts: Maximum number of prompts to process (default: 100)

    Returns:
        List of SemanticSimilarityResult objects
    """
    semantic_similarity_service.use_comparison_texts(comparison_texts)

    total = min(len(prompts), max_prompts)
    print(f'using {total} prompts for testing...')

    results = []

    for i, prompt in enumerate(prompts[:max_prompts], 1):
        # Configure the service using the provided configurator function
        configured_service: TextGenerationCompletionService = service_configurator(text_generation_completion_service)

        print(f'sending prompt {i} to LLM')
        completion_result: TextGenerationCompletionResult = configured_service.invoke(user_prompt=prompt, model_id=model_id)

        print('analyzing result')
        completion_result.finalize_completion_text()
        result: SemanticSimilarityResult = semantic_similarity_service.analyze(completion_result.final_completion_text)

        print(f'{i}/{total} Max Score: {result.max}')
        print(f'{i}/{total} Avg Score: {result.mean}')

        TestRunLoggingService(
            test_id=test_id,
            model_id=model_id,
            start=start,
            end=end
        ).log_results(
            # Identify the log entry by the name of the calling test function
            id=inspect.currentframe().f_back.f_code.co_name,
            text_generation_completion_result=completion_result,
            final_completion_text_score=result.max,
            is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
            is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
            is_reflexion_enabled=False,
            original_llm_config=completion_result.original_result.llm_config
        )

        results.append(result)

    return results
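
# End-to-end usage (a minimal sketch; the prompt loader, the service
# instances, the ModelId member, and the pass-through configurator below are
# hypothetical stand-ins, shown only to illustrate how the two helpers
# compose):
#
#   all_prompts = load_prompts()                       # hypothetical loader
#   subset, start, end = get_prompt_batch(all_prompts)
#   results = run_prompt_analysis_test(
#       test_id=1,
#       model_id=ModelId.PHI_2,                        # assumed enum member
#       text_generation_completion_service=completion_service,
#       semantic_similarity_service=similarity_service,
#       prompts=subset,
#       start=start,
#       end=end,
#       comparison_texts=['reference text'],
#       service_configurator=lambda svc: svc,          # pass-through configurator
#       max_prompts=len(subset),
#   )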