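Fix integration tests: run prompt analysis on a batch from get_prompt_batch (passing start/end) and import helpers from tests.integration.test_utils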

This commit is contained in:
Adam Wilson
2025-08-21 14:03:58 -06:00
parent c3b611fc7c
commit ea6bbbdc5a
4 changed files with 43 additions and 36 deletions
@@ -1,12 +1,12 @@
"""
Usage:
$ pytest -k test_00_benign_prompts_no_mitigation -s --disable-warnings
"""
import os
import pytest
from test_utils import run_prompt_analysis_test
from src.text_generation.common.model_id import ModelId
from src.text_generation.services.nlp.abstract_text_generation_completion_service import AbstractTextGenerationCompletionService
from tests.integration.test_utils import run_prompt_analysis_test, get_prompt_batch
@pytest.mark.mitigation
def test_00_benign_prompts_no_mitigation(
text_generation_completion_service: AbstractTextGenerationCompletionService,
semantic_similarity_service,
@@ -16,18 +16,16 @@ def test_00_benign_prompts_no_mitigation(
def configure_service(service):
return service.without_guidelines()
if len(prompt_injection_all_example_responses) == 0:
raise ValueError('Unable to load prompt injection examples')
prompt_batch, start, end = get_prompt_batch(math_prompts)
results = run_prompt_analysis_test(
test_id=0,
model_id=ModelId.MICROSOFT_PHI_3_MINI4K_INSTRUCT,
text_generation_completion_service=text_generation_completion_service,
semantic_similarity_service=semantic_similarity_service,
prompts=math_prompts,
prompts=prompt_batch,
start=start,
end=end,
comparison_texts=prompt_injection_all_example_responses,
service_configurator=configure_service
)
for result in results:
assert result.mean < 0.5
assert result.max < 0.5
)
@@ -1,11 +1,9 @@
"""
Usage:
$ pytest -k test_01_malicious_prompts_no_mitigation -s --disable-warnings
"""
import os
import pytest
from test_utils import run_prompt_analysis_test
from src.text_generation.common.model_id import ModelId
from src.text_generation.services.nlp.abstract_text_generation_completion_service import AbstractTextGenerationCompletionService
from tests.integration.test_utils import run_prompt_analysis_test, get_prompt_batch
@pytest.mark.mitigation
@@ -17,12 +15,17 @@ def test_01_malicious_prompts_no_mitigation(
):
def configure_service(service):
return service.without_guidelines()
prompt_batch, start, end = get_prompt_batch(malicious_prompts)
results = run_prompt_analysis_test(
test_id=1,
model_id=ModelId.MICROSOFT_PHI_3_MINI4K_INSTRUCT,
text_generation_completion_service=text_generation_completion_service,
semantic_similarity_service=semantic_similarity_service,
prompts=malicious_prompts,
prompts=prompt_batch,
start=start,
end=end,
comparison_texts=prompt_injection_all_example_responses,
service_configurator=configure_service
)
@@ -1,11 +1,9 @@
"""
Usage:
$ pytest -k test_02_malicious_prompts_cot -s --disable-warnings
"""
import os
import pytest
from test_utils import run_prompt_analysis_test
from src.text_generation.common.model_id import ModelId
from src.text_generation.services.nlp.abstract_text_generation_completion_service import AbstractTextGenerationCompletionService
from tests.integration.test_utils import run_prompt_analysis_test, get_prompt_batch
@pytest.mark.mitigation
@@ -18,12 +16,17 @@ def test_02_malicious_prompts_cot(
def configure_service(service):
return (service
.with_chain_of_thought_guidelines())
prompt_batch, start, end = get_prompt_batch(malicious_prompts)
results = run_prompt_analysis_test(
test_id=2,
model_id=ModelId.MICROSOFT_PHI_3_MINI4K_INSTRUCT,
text_generation_completion_service=text_generation_completion_service,
semantic_similarity_service=semantic_similarity_service,
prompts=malicious_prompts,
prompts=prompt_batch,
start=start,
end=end,
comparison_texts=prompt_injection_all_example_responses,
service_configurator=configure_service
)
@@ -1,11 +1,9 @@
"""
Usage:
$ pytest -k test_03_malicious_prompts_rag -s --disable-warnings
"""
import os
import pytest
from test_utils import run_prompt_analysis_test
from src.text_generation.common.model_id import ModelId
from src.text_generation.services.nlp.abstract_text_generation_completion_service import AbstractTextGenerationCompletionService
from tests.integration.test_utils import run_prompt_analysis_test, get_prompt_batch
@pytest.mark.mitigation
@@ -19,11 +17,16 @@ def test_03_malicious_prompts_rag(
return (service
.with_rag_context_guidelines())
prompt_batch, start, end = get_prompt_batch(malicious_prompts)
results = run_prompt_analysis_test(
test_id=3,
model_id=ModelId.MICROSOFT_PHI_3_MINI4K_INSTRUCT,
text_generation_completion_service=text_generation_completion_service,
semantic_similarity_service=semantic_similarity_service,
prompts=malicious_prompts,
prompts=prompt_batch,
start=start,
end=end,
comparison_texts=prompt_injection_all_example_responses,
service_configurator=configure_service
)