support batch tests

This commit is contained in:
Adam Wilson
2025-08-19 20:09:34 -06:00
parent 3585f80414
commit cc124a91a3
12 changed files with 277 additions and 74 deletions
@@ -0,0 +1,83 @@
name: 'Reusable Test Runner | RAG + CoT | Generic'
on:
workflow_call:
inputs:
batch_offset:
description: 'Starting prompt index offset'
required: true
type: number
range_name:
description: 'Human readable range name (e.g., "1-20")'
required: true
type: string
test_file_path:
description: 'Path to the test file to run'
required: true
type: string
model_display_name:
description: 'Human readable model name for workflow title'
required: true
type: string
batch_size:
description: 'Number of prompts per batch'
required: false
type: number
default: 2
jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 55 # Set max runtime for each job
continue-on-error: true
strategy:
# Always 10 batches per workflow
matrix:
batch: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
fail-fast: false
steps:
- name: 'checkout'
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
- name: 'set up Python'
uses: actions/setup-python@v3
with:
python-version: '3.12'
- name: 'set up Python dependencies'
shell: bash
run: |
pip install -r ${{ github.workspace }}/requirements.txt
- name: 'set up Microsoft Phi-3 Mini 4k LLM from HuggingFace'
shell: bash
run: |
pip install huggingface-hub[cli]
huggingface-cli download microsoft/Phi-3-mini-4k-instruct-onnx --include cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/* --local-dir ${{ github.workspace }}/infrastructure/foundation_model
- name: 'run text generation tests - ${{ inputs.model_display_name }} - range ${{ inputs.range_name }} batch ${{ matrix.batch }}'
shell: bash
env:
PROMPT_BATCH: ${{ matrix.batch }}
BATCH_SIZE: ${{ inputs.batch_size }}
BATCH_OFFSET: ${{ inputs.batch_offset }}
run: pytest ${{ inputs.test_file_path }} -s --disable-warnings
- name: Check for changes
id: verify-changed-files
run: |
if [ -n "$(git status --porcelain)" ]; then
echo "changed=true" >> $GITHUB_OUTPUT
else
echo "changed=false" >> $GITHUB_OUTPUT
fi
- name: Commit and push changes
if: steps.verify-changed-files.outputs.changed == 'true'
run: |
git config --local user.email "42450907+lightbroker@users.noreply.github.com"
git config --local user.name "Adam Wilson"
git add .
git add .
git commit -m "Auto-generated files from microsoft/Phi-3-mini-4k-instruct range 1-20 batch 4 [skip ci]"
git checkout -b auto-generated-$(date +%Y%m%d-%H%M%S)
git push origin HEAD
@@ -0,0 +1,14 @@
name: '#4 (1-20) | RAG + CoT | microsoft/Phi-3-mini-4k-instruct'
on:
workflow_dispatch:
schedule:
# Workflow 1 - Starts at hours: 0, 5, 10, 15, 20 (every 5th hour starting from midnight)
- cron: '0 */5 * * *'
jobs:
test-prompts-1-20:
uses: ./.github/workflows/test_04_malicious_prompts_rag_and_cot_microsoft_phi_3_mini4k_instruct.base.yml
with:
batch_offset: 0
range_name: "1-20"
batch_size: 2
@@ -0,0 +1,14 @@
name: '#4 (21-40) | RAG + CoT | microsoft/Phi-3-mini-4k-instruct'
on:
workflow_dispatch:
schedule:
# Workflow 2 - Starts at hours: 1, 6, 11, 16, 21 (every 5th hour starting from 1 AM)
- cron: '0 1-23/5 * * *'
jobs:
test-prompts-21-40:
uses: ./.github/workflows/test_04_malicious_prompts_rag_and_cot_microsoft_phi_3_mini4k_instruct.base.yml
with:
batch_offset: 20
range_name: "21-40"
batch_size: 2
@@ -0,0 +1,14 @@
name: '#4 (41-60) | RAG + CoT | microsoft/Phi-3-mini-4k-instruct'
on:
workflow_dispatch:
schedule:
# Workflow 3 - Starts at hours: 2, 7, 12, 17, 22 (every 5th hour starting from 2 AM)
- cron: '0 2-23/5 * * *'
jobs:
test-prompts-41-60:
uses: ./.github/workflows/test_04_malicious_prompts_rag_and_cot_microsoft_phi_3_mini4k_instruct.base.yml
with:
batch_offset: 40
range_name: "41-60"
batch_size: 2
@@ -0,0 +1,14 @@
name: '#4 (61-80) | RAG + CoT | microsoft/Phi-3-mini-4k-instruct'
on:
workflow_dispatch:
schedule:
# Workflow 4 - Starts at hours: 3, 8, 13, 18, 23 (every 5th hour starting from 3 AM)
- cron: '0 3-23/5 * * *'
jobs:
test-prompts-61-80:
uses: ./.github/workflows/test_04_malicious_prompts_rag_and_cot_microsoft_phi_3_mini4k_instruct.base.yml
with:
batch_offset: 60
range_name: "61-80"
batch_size: 2
@@ -0,0 +1,14 @@
name: '#4 (81-100) | RAG + CoT | microsoft/Phi-3-mini-4k-instruct'
on:
workflow_dispatch:
schedule:
# Workflow 5 - Starts at hours: 4, 9, 14, 19 (every 5th hour starting from 4 AM)
- cron: '0 4-19/5 * * *'
jobs:
test-prompts-81-100:
uses: ./.github/workflows/test_04_malicious_prompts_rag_and_cot_microsoft_phi_3_mini4k_instruct.base.yml
with:
batch_offset: 80
range_name: "81-100"
batch_size: 2
@@ -0,0 +1,27 @@
name: 'Reusable Test #4 | RAG + CoT | microsoft/Phi-3-mini-4k-instruct'
on:
workflow_call:
inputs:
batch_offset:
description: 'Starting prompt index offset'
required: true
type: number
range_name:
description: 'Human readable range name (e.g., "1-20")'
required: true
type: string
batch_size:
description: 'Number of prompts per batch'
required: false
type: number
default: 2
jobs:
test:
uses: ./.github/workflows/test_04.abstract_base.yml
with:
batch_offset: ${{ inputs.batch_offset }}
range_name: ${{ inputs.range_name }}
batch_size: ${{ inputs.batch_size }}
test_file_path: tests/integration/test_04_malicious_prompts_rag_and_cot_microsoft_phi_3_mini4k_instruct.py
model_display_name: microsoft/Phi-3-mini-4k-instruct
+7
View File
@@ -0,0 +1,7 @@
from enum import Enum
class ModelId(Enum):
APPLE_OPENELM_1_1B_INSTRUCT = "apple/OpenELM-1_1B-Instruct"
META_LLAMA_3_2_3B_INSTRUCT = "meta-llama/Llama-3.2-3B-Instruct"
MICROSOFT_PHI_3_MINI4K_INSTRUCT = "microsoft/Phi-3-mini-4k-instruct"
@@ -7,17 +7,30 @@ import time
from datetime import datetime
from typing import Any, Dict, List
from src.text_generation.common.model_id import ModelId
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.services.logging.abstract_test_run_logging_service import AbstractTestRunLoggingService
class TestRunLoggingService(AbstractTestRunLoggingService):
def __init__(self, test_id: int):
def __init__(
self,
test_id: int,
model_id: ModelId,
start: int,
end: int
):
self._lock = threading.Lock()
timestamp = calendar.timegm(time.gmtime())
self.log_file_path = f"./tests/logs/test_{test_id}/test_{test_id}_logs_{timestamp}.json"
base_path = os.environ.get('TEST_RUNS', '.')
self.log_file_path = os.path.join(base_path, str(f"test_{test_id}/{str(model_id.value).replace("/", "_")}/{start}_{end}/test_{str(test_id).lower()}_logs_{timestamp}.json").lower())
# Ensure directory structure exists
os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
self._ensure_log_file_exists()
def _ensure_log_file_exists(self):
if not os.path.exists(self.log_file_path):
with open(self.log_file_path, 'w') as f:
+23 -59
View File
@@ -18,11 +18,12 @@ from src.text_generation import config
from src.text_generation.adapters.embedding_model import EmbeddingModel
from src.text_generation.adapters.prompt_injection_example_repository import PromptInjectionExampleRepository
from src.text_generation.adapters.prompt_template_repository import PromptTemplateRepository
from src.text_generation.adapters.text_generation_foundation_model import TextGenerationFoundationModel
from src.text_generation.adapters.foundation_models.text_generation_foundation_model import TextGenerationFoundationModel
from src.text_generation.common.constants import Constants
from src.text_generation.services.guardrails.generated_text_guardrail_service import GeneratedTextGuardrailService
from src.text_generation.services.guardrails.reflexion_security_guidelines_service import ReflexionSecurityGuardrailsService
from src.text_generation.services.guidelines.chain_of_thought_security_guidelines_service import ChainOfThoughtSecurityGuidelinesService
from src.text_generation.services.guidelines.guidelines_factory import GuidelinesFactory
from src.text_generation.services.guidelines.rag_context_security_guidelines_configuration_builder import RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder
from src.text_generation.services.guidelines.rag_context_security_guidelines_service import RagContextSecurityGuidelinesService
from src.text_generation.services.guidelines.rag_plus_cot_security_guidelines_service import RagPlusCotSecurityGuidelinesService
@@ -69,11 +70,14 @@ def pytest_deselected(items):
@pytest.fixture(scope="session", autouse=True)
def setup_test_environment():
"""Setup run before every test automatically."""
# Set test environment variables
os.environ["TESTING"] = "true"
os.environ["LOG_LEVEL"] = "DEBUG"
os.environ["PROMPT_TEMPLATES_DIR"] = "./infrastructure/prompt_templates"
os.environ["INJECTION_DATA_DIR"] = "./tests/security/tests/results/01_garak_no_guidelines"
os.environ["TEST_RUNS"] = "./tests/logs"
os.environ["MODEL_BASE_DIR"] = "./infrastructure/foundation_model"
os.environ["MODEL_CPU_DIR"] = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
os.environ["MODEL_DATA_FILENAME"] = "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data"
@@ -93,10 +97,6 @@ def setup_test_environment():
def constants():
return Constants()
@pytest.fixture(scope="session")
def foundation_model():
return TextGenerationFoundationModel()
@pytest.fixture(scope="session")
def embedding_model():
return EmbeddingModel()
@@ -128,48 +128,6 @@ def rag_config_builder(
def llm_configuration_introspection_service():
return LLMConfigurationIntrospectionService()
@pytest.fixture(scope="session")
def rag_context_guidelines(
foundation_model,
response_processing_service,
prompt_template_service,
llm_configuration_introspection_service,
rag_config_builder):
return RagContextSecurityGuidelinesService(
foundation_model=foundation_model,
response_processing_service=response_processing_service,
prompt_template_service=prompt_template_service,
llm_configuration_introspection_service=llm_configuration_introspection_service,
config_builder=rag_config_builder
)
@pytest.fixture(scope="session")
def chain_of_thought_guidelines(
foundation_model,
response_processing_service,
llm_configuration_introspection_service,
prompt_template_service):
return ChainOfThoughtSecurityGuidelinesService(
foundation_model=foundation_model,
response_processing_service=response_processing_service,
llm_configuration_introspection_service=llm_configuration_introspection_service,
prompt_template_service=prompt_template_service
)
@pytest.fixture(scope="session")
def rag_plus_cot_guidelines(
foundation_model,
response_processing_service,
prompt_template_service,
llm_configuration_introspection_service,
rag_config_builder):
return RagPlusCotSecurityGuidelinesService(
foundation_model=foundation_model,
response_processing_service=response_processing_service,
prompt_template_service=prompt_template_service,
llm_configuration_introspection_service=llm_configuration_introspection_service,
config_builder=rag_config_builder
)
@pytest.fixture(scope="session")
def prompt_injection_example_service(prompt_injection_example_repository):
@@ -196,30 +154,36 @@ def reflexion_guardrails(
def response_processing_service():
return ResponseProcessingService()
@pytest.fixture(scope="session")
def llm_configuration_introspection_service():
return LLMConfigurationIntrospectionService()
def guidelines_factory():
return GuidelinesFactory()
@pytest.fixture(scope="session")
def guidelines_config_builder(
embedding_model,
prompt_template_service,
prompt_injection_example_repository):
return RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
embedding_model=embedding_model,
prompt_template_service=prompt_template_service,
prompt_injection_example_repository=prompt_injection_example_repository
)
@pytest.fixture(scope="session")
def text_generation_completion_service(
foundation_model,
response_processing_service,
prompt_template_service,
chain_of_thought_guidelines,
rag_context_guidelines,
rag_plus_cot_guidelines,
reflexion_guardrails,
guidelines_factory,
guidelines_config_builder,
semantic_similarity_service,
prompt_injection_example_service,
llm_configuration_introspection_service):
return TextGenerationCompletionService(
foundation_model=foundation_model,
response_processing_service=response_processing_service,
prompt_template_service=prompt_template_service,
chain_of_thought_guidelines=chain_of_thought_guidelines,
rag_context_guidelines=rag_context_guidelines,
rag_plus_cot_guidelines=rag_plus_cot_guidelines,
reflexion_guardrails=reflexion_guardrails,
guidelines_factory=guidelines_factory,
guidelines_config_builder=guidelines_config_builder,
semantic_similarity_service=semantic_similarity_service,
prompt_injection_example_service=prompt_injection_example_service,
llm_configuration_introspection_service=llm_configuration_introspection_service
@@ -1,15 +1,13 @@
"""
Usage:
$ pytest -k test_04_malicious_prompts_rag_and_cot -s --disable-warnings
"""
import os
import pytest
from test_utils import run_prompt_analysis_test
from src.text_generation.common.model_id import ModelId
from src.text_generation.services.nlp.abstract_text_generation_completion_service import AbstractTextGenerationCompletionService
from tests.integration.test_utils import run_prompt_analysis_test, get_prompt_batch
@pytest.mark.mitigation
def test_04_malicious_prompts_rag_and_cot(
def test_04_malicious_prompts_rag_and_cot_microsoft_phi_3_mini4k_instruct(
text_generation_completion_service: AbstractTextGenerationCompletionService,
semantic_similarity_service,
malicious_prompts,
@@ -17,14 +15,19 @@ def test_04_malicious_prompts_rag_and_cot(
):
def configure_service(service):
return (service
.with_chain_of_thought_guidelines()
.with_rag_context_guidelines())
.with_chain_of_thought_guidelines()
.with_rag_context_guidelines())
prompt_batch, start, end = get_prompt_batch(malicious_prompts)
results = run_prompt_analysis_test(
test_id=4,
model_id=ModelId.MICROSOFT_PHI_3_MINI4K_INSTRUCT,
text_generation_completion_service=text_generation_completion_service,
semantic_similarity_service=semantic_similarity_service,
prompts=malicious_prompts,
prompts=prompt_batch,
start=start,
end=end,
comparison_texts=prompt_injection_all_example_responses,
service_configurator=configure_service
)
+39 -3
View File
@@ -1,5 +1,7 @@
import inspect
import os
from typing import List, Callable
from src.text_generation.common.model_id import ModelId
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.services.logging.test_run_logging_service import TestRunLoggingService
@@ -8,11 +10,40 @@ from src.text_generation.services.nlp.abstract_text_generation_completion_servic
from src.text_generation.services.nlp.text_generation_completion_service import TextGenerationCompletionService
def get_prompt_batch(prompts: List[str], batch_size=10, env_var='PROMPT_BATCH'):
batch_size = int(os.getenv('BATCH_SIZE', '2'))
batch_num = int(os.getenv('PROMPT_BATCH', '1'))
if 'BATCH_OFFSET' in os.environ:
# Option 1: Fixed offset per workflow
offset = int(os.getenv('BATCH_OFFSET', '0'))
else:
# Option 2: Configurable range
prompt_range = int(os.getenv('PROMPT_RANGE', '1'))
offset = (prompt_range - 1) * 20
# Calculate start and end indices
start_idx = offset + (batch_num - 1) * batch_size
end_idx = min(start_idx + batch_size, len(prompts))
# Get the subset of prompts for this batch
prompt_subset: List[str] = prompts[start_idx:end_idx]
print(f"Running batch {batch_num} (range offset {offset}): prompts {start_idx+1}-{end_idx} ({len(prompt_subset)} prompts)")
return prompt_subset, (start_idx+1), end_idx
def run_prompt_analysis_test(
test_id: int,
model_id: ModelId,
text_generation_completion_service: AbstractTextGenerationCompletionService,
semantic_similarity_service: AbstractSemanticSimilarityService,
prompts: List,
start: int,
end: int,
comparison_texts: List,
service_configurator: Callable,
max_prompts: int = 100
@@ -42,7 +73,7 @@ def run_prompt_analysis_test(
configured_service: TextGenerationCompletionService = service_configurator(text_generation_completion_service)
print(f'sending prompt {i} to LLM')
completion_result: TextGenerationCompletionResult = configured_service.invoke(user_prompt=prompt)
completion_result: TextGenerationCompletionResult = configured_service.invoke(user_prompt=prompt, model_id=model_id)
print(f'analyzing result')
completion_result.finalize_completion_text()
@@ -51,13 +82,18 @@ def run_prompt_analysis_test(
print(f'{i}/{len(prompts)} Max Score: {result.max}')
print(f'{i}/{len(prompts)} Avg Score: {result.mean}')
TestRunLoggingService(test_id=test_id).log_results(
TestRunLoggingService(
test_id=test_id,
model_id=model_id,
start=start,
end=end
).log_results(
id=inspect.currentframe().f_back.f_code.co_name,
text_generation_completion_result=completion_result,
final_completion_text_score=result.max,
is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(),
is_reflexion_enabled=False,
original_llm_config=completion_result.original_result.llm_config
)