more updates for reflexion

This commit is contained in:
Adam Wilson
2025-07-28 10:31:55 -06:00
parent 5bc9f480f9
commit 2659e6e43c
12 changed files with 439 additions and 340 deletions

View File

@@ -12,9 +12,16 @@ class GuardrailsResult:
user_prompt: str,
completion_text: str,
full_prompt: dict[str, Any],
llm_config: dict
llm_config: dict,
cosine_similarity_score: float = -1.0,
cosine_similarity_risk_threshold: float = 0.0
):
self.user_prompt = user_prompt
self.completion_text = completion_text
self.full_prompt = full_prompt
self.llm_config = llm_config
self.llm_config = llm_config
self.cosine_similarity_score = cosine_similarity_score
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
def is_completion_malicious(self) -> bool:
return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold
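
With the defaults above, an unscored result is never flagged: the sentinel score of -1.0 stays below the 0.0 default threshold until the similarity service assigns real values. A quick sketch of the semantics (field values illustrative):

from src.text_generation.domain.guardrails_result import GuardrailsResult

result = GuardrailsResult(
    user_prompt="example prompt",
    completion_text="example completion",
    full_prompt={},
    llm_config={}
)
assert not result.is_completion_malicious()  # -1.0 < 0.0 (sentinel score vs. default threshold)
result.cosine_similarity_score = 0.72
result.cosine_similarity_risk_threshold = 0.5
assert result.is_completion_malicious()      # 0.72 >= 0.5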

View File

@@ -17,11 +17,11 @@ class GuidelinesResult(AbstractGuidelinesProcessedCompletion):
cosine_similarity_risk_threshold: float = 0.0
):
self.user_prompt = user_prompt
self.completion_text = completion_text
self.guidelines_completion_text = completion_text
self.full_prompt = full_prompt
self.llm_config = llm_config
self.cosine_similarity_score = cosine_similarity_score
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
def is_original_completion_malicious(self) -> bool:
def is_completion_malicious(self) -> bool:
return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold

View File

@@ -17,38 +17,62 @@ class TextGenerationCompletionResult(AbstractTextGenerationCompletionResult):
self,
original_result: OriginalCompletionResult,
guidelines_result: Optional[GuidelinesResult] = None,
guardrails_result: Optional[GuardrailsResult] = None
guardrails_result: Optional[GuardrailsResult] = None,
alternate_result: Optional[AlternateCompletionResult] = None
):
self.original_result = original_result
self.guidelines_result = guidelines_result
self.guardrails_result = guardrails_result
self.alternate_result = alternate_result
self.final_completion_text = ''
def finalize(self) -> str:
def finalize_completion_text(self) -> None:
"""
Returns the final completion text based on priority order:
Sets self.final_completion_text based on priority order:
1. alternate_result.alterate_completion_text (if not empty)
2. guardrails_result.completion_text (if not empty)
3. guidelines_result.guidelines_completion_text (if not empty)
4. original_result.completion_text (if not empty)
"""
# Check guardrails_result.completion_text first
print(f'Finalized text was \"{self.final_completion_text}\"')
# Check alternate text first
if (self.alternate_result and
self.alternate_result.alterate_completion_text and
self.alternate_result.alterate_completion_text.strip()
):
self.final_completion_text = self.alternate_result.alterate_completion_text
print(f'Using alternate result. Finalized text is now \"{self.final_completion_text}\"')
return
# Check guardrails_result.completion_text next
if (self.guardrails_result and
self.guardrails_result.completion_text and
self.guardrails_result.completion_text.strip()):
return self.guardrails_result.completion_text
self.guardrails_result.completion_text.strip()
):
self.final_completion_text = self.guardrails_result.completion_text
print(f'Using guardrails result. Finalized text is now \"{self.final_completion_text}\"')
return
# Fall back to guidelines_result.completion_text
if (self.guidelines_result and
self.guidelines_result.completion_text and
self.guidelines_result.completion_text.strip()):
return self.guidelines_result.completion_text
self.guidelines_result.guidelines_completion_text and
self.guidelines_result.guidelines_completion_text.strip()
):
self.final_completion_text = self.guidelines_result.guidelines_completion_text
print(f'Using guidelines result. Finalized text is now \"{self.final_completion_text}\"')
return
# Fall back to original_result.completion_text
if (self.original_result and
self.original_result.completion_text and
self.original_result.completion_text.strip()):
return self.original_result.completion_text
self.original_result.completion_text.strip()
):
self.final_completion_text = self.original_result.completion_text
print(f'Using original. Finalized text is now \"{self.final_completion_text}\"')
return
# If all are empty, fall back to the empty string
return ""
self.final_completion_text = ""
print(f'Finalized text is now \"{self.final_completion_text}\"')
return
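
Taken together, the selection order is now alternate, then guardrails, then guidelines, then original. A minimal sketch of the fallback, assuming only the original and guidelines results are populated (constructor values illustrative):

from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult

result = TextGenerationCompletionResult(
    original_result=OriginalCompletionResult(
        user_prompt="p", completion_text="original text", full_prompt={}, llm_config={}
    ),
    guidelines_result=GuidelinesResult(
        user_prompt="p", completion_text="guidelines text", full_prompt={}, llm_config={}
    )
)
result.finalize_completion_text()
assert result.final_completion_text == "guidelines text"  # guidelines outranks original; alternate/guardrails unset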

View File

@@ -4,12 +4,16 @@ from langchain_core.prompts import PromptTemplate, StringPromptTemplate
from langchain_core.prompt_values import PromptValue
from src.text_generation.common.constants import Constants
from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult
from src.text_generation.domain.guardrails_result import GuardrailsResult
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
from src.text_generation.services.nlp.abstract_semantic_similarity_service import AbstractSemanticSimilarityService
from src.text_generation.services.prompt_injection.abstract_prompt_injection_example_service import AbstractPromptInjectionExampleService
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
@@ -25,6 +29,8 @@ class ReflexionSecurityGuardrailsService(
def __init__(
self,
foundation_model: AbstractFoundationModel,
semantic_similarity_service: AbstractSemanticSimilarityService,
prompt_injection_example_service: AbstractPromptInjectionExampleService,
response_processing_service: AbstractResponseProcessingService,
prompt_template_service: AbstractPromptTemplateService,
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService):
@@ -35,11 +41,21 @@ class ReflexionSecurityGuardrailsService(
self.prompt_template_service = prompt_template_service
self.llm_configuration_introspection_service = llm_configuration_introspection_service
# constants
self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.5
# set up semantic similarity service and supporting texts
self.example_prompt_injection_completions = prompt_injection_example_service.get_all_completions()
self.example_prompt_injection_prompts = prompt_injection_example_service.get_all_prompts()
self.semantic_similarity_service = semantic_similarity_service
self.semantic_similarity_service.use_comparison_texts(
self.example_prompt_injection_completions
)
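
The guardrails service is primed once with known prompt-injection completions, then scores each new completion against them. A minimal sketch of what a concrete similarity service might look like, assuming a hypothetical embed() helper that maps a string to a vector and assuming SemanticSimilarityResult accepts these kwargs (the real AbstractSemanticSimilarityService implementation lives elsewhere):

import numpy as np

class SketchSemanticSimilarityService:
    def use_comparison_texts(self, texts: list[str]) -> None:
        # pre-embed the known prompt-injection completions once
        self.reference_vectors = [embed(t) for t in texts]  # embed() is a stand-in for the project's embedding model

    def analyze(self, text: str):
        v = embed(text)
        scores = [
            float(np.dot(v, r) / (np.linalg.norm(v) * np.linalg.norm(r)))
            for r in self.reference_vectors
        ]
        # the result object exposes scores, mean, and max, per the tests in this commit
        return SemanticSimilarityResult(scores=scores, mean=float(np.mean(scores)), max=max(scores))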
def _create_context_from_rag(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> str:
result: TextGenerationCompletionResult = text_generation_completion_result
original_user_prompt = result.original_user_prompt
original_completion = result.original_completion
original_user_prompt = result.original_result.user_prompt
original_completion = result.original_result.completion_text
# Assemble the context showing the original prompt injection attack
context_parts = [
@@ -67,7 +83,13 @@ class ReflexionSecurityGuardrailsService(
return filled_template
def _create_chain(self, prompt_template: StringPromptTemplate):
return prompt_template | self.foundation_model_pipeline | StrOutputParser()
# return prompt_template | self.foundation_model_pipeline | StrOutputParser()
return (
prompt_template
| self.foundation_model_pipeline
| StrOutputParser()
| self.response_processing_service.process_text_generation_output
)
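
Piping the bound method straight onto the chain works because LangChain coerces callables composed with | into runnables; the post-processor only needs to accept the single string StrOutputParser emits. The same chain with the coercion made explicit:

from langchain_core.runnables import RunnableLambda

return (
    prompt_template
    | self.foundation_model_pipeline
    | StrOutputParser()
    | RunnableLambda(self.response_processing_service.process_text_generation_output)
)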
def apply_guardrails(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> AbstractTextGenerationCompletionResult:
"""
@@ -78,7 +100,7 @@ class ReflexionSecurityGuardrailsService(
try:
result: TextGenerationCompletionResult = text_generation_completion_result
original_user_prompt = result.original_user_prompt
original_user_prompt = result.original_result.user_prompt
prompt_template: StringPromptTemplate = self._get_template(text_generation_completion_result)
prompt_value: PromptValue = prompt_template.format_prompt(**{self.constants.INPUT_VARIABLE_TOKEN: original_user_prompt})
@@ -95,11 +117,30 @@ class ReflexionSecurityGuardrailsService(
completion_text = chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: original_user_prompt})
llm_config = self.llm_configuration_introspection_service.get_config(chain)
result.guardrails_processed_completion = GuardrailsResult(
result.guardrails_result = GuardrailsResult(
user_prompt=original_user_prompt,
completion_text=completion_text,
llm_config=llm_config,
full_prompt=prompt_dict
)
similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze(text=completion_text)
# update completion result with similarity scoring threshold and result
result.guardrails_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
result.guardrails_result.cosine_similarity_score = similarity_result.max
# return raw result if the completion comparison score didn't exceed threshold
if not result.guardrails_result.is_completion_malicious():
print(f'Guardrails-based completion was NOT malicious. Score: {result.guardrails_result.cosine_similarity_score}')
return result
# provide the finalized alternate (refuse to answer)
print(f'Guardrails-based completion was malicious. Score: {result.guardrails_result.cosine_similarity_score}')
result.alternate_result = AlternateCompletionResult(
alterate_completion_text=self.constants.ALT_COMPLETION_TEXT
)
result.finalize_completion_text()
return result
except Exception as e:

View File

@@ -8,6 +8,8 @@ from langchain.prompts import FewShotPromptTemplate
from src.text_generation.common.constants import Constants
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
@@ -74,13 +76,21 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
chain = self._create_chain(prompt_template)
completion_text = chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: user_prompt})
llm_config = self.llm_configuration_introspection_service.get_config(chain)
result = GuidelinesResult(
user_prompt=user_prompt,
completion_text=completion_text,
llm_config=llm_config,
full_prompt=prompt_dict
result = TextGenerationCompletionResult(
original_result=OriginalCompletionResult(
user_prompt=user_prompt,
completion_text=completion_text,
llm_config=llm_config,
full_prompt=prompt_dict
),
guidelines_result=GuidelinesResult(
user_prompt=user_prompt,
completion_text=completion_text,
llm_config=llm_config,
full_prompt=prompt_dict
)
)
return result
except Exception as e:

View File

@@ -36,12 +36,12 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
def log_results(
self,
id: str,
text_generation_result: str,
completion: str,
prompt: str,
final_completion: str,
is_rag_few_shot_enabled: bool,
is_cot_enabled: bool,
is_reflexion_enabled: bool,
llm_config: dict,
original_llm_config: dict,
scores: List[float],
mean: float,
max: float):
@@ -50,8 +50,8 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
log_entry = {
"id": id,
"timestamp": datetime.now().isoformat(),
"prompt": prompt,
"completion": completion,
"original_prompt": prompt,
"final_completion": final_completion,
"mitigations_enabled": {
"guidelines": {
"rag_with_few_shot_examples": is_rag_few_shot_enabled,
@@ -61,7 +61,7 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
"reflexion": is_reflexion_enabled
}
},
"llm_config": llm_config,
"original_llm_config": original_llm_config,
"cosine_similarity": {
"mean": mean,
"max": max,

View File

@@ -6,6 +6,7 @@ from langchain_huggingface import HuggingFacePipeline
from src.text_generation.common.constants import Constants
from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.services.guardrails.abstract_generated_text_guardrail_service import AbstractGeneratedTextGuardrailService
@@ -84,12 +85,15 @@ class TextGenerationCompletionService(
self._use_rag_context
)
guidelines_handler = self.guidelines_strategy_map.get(
guidelines_config,
guidelines_config,
# fall back to unfiltered LLM invocation
self._handle_without_guidelines
)
return guidelines_handler(user_prompt)
def _process_guidelines_result(self, guidelines_result: GuidelinesResult) -> TextGenerationCompletionResult:
def _process_completion_result(self, completion_result: TextGenerationCompletionResult) -> TextGenerationCompletionResult:
"""
Finalize the current completion text, run the semantic similarity check against known prompt-injection completions, and replace the completion with the alternate refusal text when the score meets the risk threshold.
@@ -100,26 +104,36 @@ class TextGenerationCompletionService(
TextGenerationCompletionResult with appropriate completion text
"""
# analyze the current version of the completion text against prompt injection completions;
# if guidelines applied, this is the result of completion using guidelines;
# otherwise it is the raw completion text without guidelines
completion_result.finalize_completion_text()
similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze(
text = guidelines_result.completion_text
text = completion_result.final_completion_text
)
guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
guidelines_result.cosine_similarity_score = similarity_result.mean
if not completion_result.guidelines_result:
completion_result.guidelines_result = GuidelinesResult(
user_prompt=completion_result.original_result.user_prompt,
completion_text=completion_result.original_result.completion_text,
llm_config=completion_result.original_result.llm_config
)
completion_result = TextGenerationCompletionResult(
llm_config = guidelines_result.llm_config,
original_completion = guidelines_result.completion_text,
original_user_prompt = guidelines_result.user_prompt,
guidelines_result = guidelines_result
)
# update completion result with similarity scoring threshold and result
completion_result.guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
completion_result.guidelines_result.cosine_similarity_score = similarity_result.max
if not guidelines_result.is_original_completion_malicious():
# return raw result if the completion comparison score didn't exceed threshold
if not completion_result.guidelines_result.is_completion_malicious():
print(f'Guidelines-based completion was NOT malicious. Score: {completion_result.guidelines_result.cosine_similarity_score}')
return completion_result
# provide the finalized alternate (refuse to answer)
print(f'Guidelines-based completion was malicious. Score: {completion_result.guidelines_result.cosine_similarity_score}')
completion_result.alternate_result = AlternateCompletionResult(
alterate_completion_text=self.constants.ALT_COMPLETION_TEXT
)
completion_result.finalize_completion_text()
return completion_result
@@ -127,28 +141,31 @@ class TextGenerationCompletionService(
def _handle_cot_and_rag(self, user_prompt: str) -> TextGenerationCompletionResult:
"""Handle: CoT=True, RAG=True"""
guidelines_result = self.rag_plus_cot_guidelines.apply_guidelines(user_prompt)
return self._process_guidelines_result(guidelines_result)
return self._process_completion_result(guidelines_result)
def _handle_cot_only(self, user_prompt: str) -> TextGenerationCompletionResult:
"""Handle: CoT=True, RAG=False"""
guidelines_result = self.chain_of_thought_guidelines.apply_guidelines(user_prompt)
return self._process_guidelines_result(guidelines_result)
return self._process_completion_result(guidelines_result)
def _handle_rag_only(self, user_prompt: str) -> TextGenerationCompletionResult:
"""Handle: CoT=False, RAG=True"""
guidelines_result = self.rag_context_guidelines.apply_guidelines(user_prompt)
return self._process_guidelines_result(guidelines_result)
return self._process_completion_result(guidelines_result)
def _handle_without_guidelines(self, user_prompt: str) -> TextGenerationCompletionResult:
"""Handle: CoT=False, RAG=False"""
try:
chain = self._create_chain_without_guidelines()
llm_config = self.llm_configuration_introspection_service.get_config(chain)
result = GuidelinesResult(
completion_text = chain.invoke(user_prompt),
llm_config = llm_config
)
return self._process_guidelines_result(result)
result = TextGenerationCompletionResult(
original_result=OriginalCompletionResult(
user_prompt=user_prompt,
completion_text=chain.invoke(user_prompt),
llm_config=llm_config
))
return self._process_completion_result(result)
except Exception as e:
raise e
@@ -216,6 +233,8 @@ class TextGenerationCompletionService(
raise ValueError(f"Parameter 'user_prompt' cannot be empty or None")
print(f'Using guidelines: {self.get_current_config()}')
completion_result: TextGenerationCompletionResult = self._process_prompt_with_guidelines_if_applicable(user_prompt)
if not self._use_reflexion_guardrails:
return completion_result
return self._handle_reflexion_guardrails(completion_result)
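
With that guard in place, callers get the guidelines-only result unless reflexion is switched on. A usage sketch matching how the tests below drive the service:

completion_result = text_generation_completion_service.invoke(user_prompt="Summarize this document")
completion_result.finalize_completion_text()
print(completion_result.final_completion_text)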

View File

@@ -178,11 +178,15 @@ def prompt_injection_example_service(prompt_injection_example_repository):
@pytest.fixture(scope="session")
def reflexion_guardrails(
foundation_model,
semantic_similarity_service,
prompt_injection_example_service,
response_processing_service,
prompt_template_service,
llm_configuration_introspection_service):
return ReflexionSecurityGuardrailsService(
foundation_model=foundation_model,
semantic_similarity_service=semantic_similarity_service,
prompt_injection_example_service=prompt_injection_example_service,
response_processing_service=response_processing_service,
prompt_template_service=prompt_template_service,
llm_configuration_introspection_service=llm_configuration_introspection_service

View File

@@ -1,270 +0,0 @@
"""
pytest -v tests/integration/test_domain.py
"""
import pytest
from unittest.mock import MagicMock
from src.text_generation.domain.guardrails_result import GuardrailsResult
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
class TestTextGenerationCompletionResult:
"""Test suite for TextGenerationCompletionResult and related classes."""
@pytest.fixture
def sample_llm_config(self):
"""Sample LLM configuration for testing."""
return {
"model": "test-model",
"temperature": 0.7,
"max_tokens": 100
}
@pytest.fixture
def sample_full_prompt(self):
"""Sample full prompt for testing."""
return {
"system": "You are a helpful assistant",
"user": "Test prompt"
}
def test_original_result_only(self, sample_llm_config, sample_full_prompt):
"""Test TextGenerationCompletionResult with only original result filled in."""
# Arrange
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.5
original_result.cosine_similarity_risk_threshold=0.7
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=None,
guardrails_result=None
)
# Act
final_text = completion_result.finalize()
# Assert
assert final_text == "Original completion text"
assert original_result.user_prompt == "Test user prompt"
assert original_result.llm_config == sample_llm_config
assert not original_result.is_completion_malicious() # 0.5 < 0.7
def test_guidelines_and_original_guidelines_precedence(self, sample_llm_config, sample_full_prompt):
"""Test that guidelines result takes precedence over original when both are present."""
# Arrange
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.8
original_result.cosine_similarity_risk_threshold=0.7
guidelines_result = GuidelinesResult(
user_prompt="Test user prompt",
completion_text="Guidelines processed completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
guidelines_result.cosine_similarity_score=0.6
guidelines_result.cosine_similarity_risk_threshold=0.7
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=guidelines_result,
guardrails_result=None
)
# Act
final_text = completion_result.finalize()
# Assert
assert final_text == "Guidelines processed completion text"
assert original_result.is_completion_malicious() # 0.8 >= 0.7
assert not guidelines_result.is_original_completion_malicious() # 0.6 < 0.7
def test_guardrails_guidelines_original_guardrails_precedence(self, sample_llm_config, sample_full_prompt):
"""Test that guardrails result takes precedence when all three are present."""
# Arrange
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.9
original_result.cosine_similarity_risk_threshold=0.7
guidelines_result = GuidelinesResult(
user_prompt="Test user prompt",
completion_text="Guidelines processed completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
guidelines_result.cosine_similarity_score=0.8
guidelines_result.cosine_similarity_risk_threshold=0.7
guardrails_result = GuardrailsResult(
user_prompt="Test user prompt",
completion_text="Guardrails processed completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=guidelines_result,
guardrails_result=guardrails_result
)
# Act
final_text = completion_result.finalize()
# Assert
assert final_text == "Guardrails processed completion text"
assert original_result.is_completion_malicious() # 0.9 >= 0.7
assert guidelines_result.is_original_completion_malicious() # 0.8 >= 0.7
def test_empty_completion_fallback_behavior(self, sample_llm_config, sample_full_prompt):
"""Test fallback behavior when some completion texts are empty."""
# Arrange - guardrails has empty text, should fall back to guidelines
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.5
original_result.cosine_similarity_risk_threshold=0.7
guidelines_result = GuidelinesResult(
user_prompt="Test user prompt",
completion_text="Guidelines processed completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
guidelines_result.cosine_similarity_score=0.6
guidelines_result.cosine_similarity_risk_threshold=0.7
guardrails_result = GuardrailsResult(
user_prompt="Test user prompt",
completion_text="", # Empty string
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=guidelines_result,
guardrails_result=guardrails_result
)
# Act
final_text = completion_result.finalize()
# Assert
assert final_text == "Guidelines processed completion text"
def test_whitespace_only_completion_fallback(self, sample_llm_config, sample_full_prompt):
"""Test fallback behavior when completion text is only whitespace."""
# Arrange
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.5
original_result.cosine_similarity_risk_threshold=0.7
guidelines_result = GuidelinesResult(
user_prompt="Test user prompt",
completion_text=" \n\t ", # Whitespace only
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
guidelines_result.cosine_similarity_score=0.6
guidelines_result.cosine_similarity_risk_threshold=0.7
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=guidelines_result,
guardrails_result=None
)
# Act
final_text = completion_result.finalize()
# Assert
assert final_text == "Original completion text"
def test_all_empty_completions_returns_empty_string(self, sample_llm_config, sample_full_prompt):
"""Test that empty string is returned when all completions are empty."""
# Arrange
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.5
original_result.cosine_similarity_risk_threshold=0.7
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=None,
guardrails_result=None
)
# Act
final_text = completion_result.finalize()
# Assert
assert final_text == ""
def test_risk_threshold_boundary_conditions(self, sample_llm_config, sample_full_prompt):
"""Test boundary conditions for risk threshold (exactly 0.7)."""
# Test score exactly at threshold
original_result_at_threshold = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result_at_threshold.cosine_similarity_score=0.7
original_result_at_threshold.cosine_similarity_risk_threshold=0.7
# Test score just below threshold
guidelines_result_below_threshold = GuidelinesResult(
user_prompt="Test user prompt",
completion_text="Guidelines completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config,
cosine_similarity_score=0.699999, # Just below threshold
cosine_similarity_risk_threshold=0.7
)
# Assert
assert original_result_at_threshold.is_completion_malicious() # 0.7 >= 0.7
assert not guidelines_result_below_threshold.is_original_completion_malicious() # 0.699999 < 0.7

View File

@@ -45,7 +45,8 @@ def run_prompt_analysis_test(
completion_result: TextGenerationCompletionResult = configured_service.invoke(user_prompt=prompt)
print(f'analyzing result')
result: SemanticSimilarityResult = semantic_similarity_service.analyze(completion_result.final)
completion_result.finalize_completion_text()
result: SemanticSimilarityResult = semantic_similarity_service.analyze(completion_result.final_completion_text)
print(f'{i}/{len(prompts)} Max Score: {result.max}')
print(f'{i}/{len(prompts)} Avg Score: {result.mean}')
@@ -53,11 +54,11 @@ def run_prompt_analysis_test(
TestRunLoggingService(test_id=test_id).log_results(
id=inspect.currentframe().f_back.f_code.co_name,
prompt=prompt,
completion=completion_result.original_completion,
final_completion=completion_result.final_completion_text,
is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(),
llm_config=completion_result.guidelines_processed_completion.llm_config,
original_llm_config=completion_result.original_result.llm_config,
scores=result.scores,
mean=result.mean,
max=result.max

View File

@@ -208,7 +208,7 @@ def test_example_with_fluent_service_call(
TestRunLoggingService(test_id=10).log_results(
id=inspect.currentframe().f_code.co_name,
prompt=prompt,
completion=completion_result.original_completion,
final_completion=completion_result.original_result.completion_text,
is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(),

View File

@@ -1,13 +1,276 @@
"""
pytest -v tests/integration/test_domain.py
"""
import pytest
from unittest.mock import MagicMock
from src.text_generation.domain.guardrails_result import GuardrailsResult
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
@pytest.mark.unit
def test_guardrail_analyzed_response():
response = GuardrailsResult(
user_prompt="test user prompt",
completion_text="I can't answer that",
full_prompt={},
llm_config={},
cosine_similarity_score=0.72839,
cosine_similarity_risk_threshold=0.5
)
assert response.is_completion_malicious()  # 0.72839 >= 0.5
class TestTextGenerationCompletionResult:
"""Test suite for TextGenerationCompletionResult and related classes."""
@pytest.fixture
def sample_llm_config(self):
"""Sample LLM configuration for testing."""
return {
"model": "test-model",
"temperature": 0.7,
"max_tokens": 100
}
@pytest.fixture
def sample_full_prompt(self):
"""Sample full prompt for testing."""
return {
"system": "You are a helpful assistant",
"user": "Test prompt"
}
def test_original_result_only(self, sample_llm_config, sample_full_prompt):
"""Test TextGenerationCompletionResult with only original result filled in."""
# Arrange
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.5
original_result.cosine_similarity_risk_threshold=0.7
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=None,
guardrails_result=None
)
# Act
completion_result.finalize_completion_text()
final_text = completion_result.final_completion_text
# Assert
assert final_text == "Original completion text"
assert original_result.user_prompt == "Test user prompt"
assert original_result.llm_config == sample_llm_config
assert not original_result.is_completion_malicious() # 0.5 < 0.7
def test_guidelines_and_original_guidelines_precedence(self, sample_llm_config, sample_full_prompt):
"""Test that guidelines result takes precedence over original when both are present."""
# Arrange
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.8
original_result.cosine_similarity_risk_threshold=0.7
guidelines_result = GuidelinesResult(
user_prompt="Test user prompt",
completion_text="Guidelines processed completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
guidelines_result.cosine_similarity_score=0.6
guidelines_result.cosine_similarity_risk_threshold=0.7
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=guidelines_result,
guardrails_result=None
)
# Act
completion_result.finalize_completion_text()
final_text = completion_result.final_completion_text
# Assert
assert final_text == "Guidelines processed completion text"
assert original_result.is_completion_malicious() # 0.8 >= 0.7
assert not guidelines_result.is_completion_malicious() # 0.6 < 0.7
def test_guardrails_guidelines_original_guardrails_precedence(self, sample_llm_config, sample_full_prompt):
"""Test that guardrails result takes precedence when all three are present."""
# Arrange
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.9
original_result.cosine_similarity_risk_threshold=0.7
guidelines_result = GuidelinesResult(
user_prompt="Test user prompt",
completion_text="Guidelines processed completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
guidelines_result.cosine_similarity_score=0.8
guidelines_result.cosine_similarity_risk_threshold=0.7
guardrails_result = GuardrailsResult(
user_prompt="Test user prompt",
completion_text="Guardrails processed completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=guidelines_result,
guardrails_result=guardrails_result
)
# Act
completion_result.finalize_completion_text()
final_text = completion_result.final_completion_text
# Assert
assert final_text == "Guardrails processed completion text"
assert original_result.is_completion_malicious() # 0.9 >= 0.7
assert guidelines_result.is_completion_malicious() # 0.8 >= 0.7
def test_empty_completion_fallback_behavior(self, sample_llm_config, sample_full_prompt):
"""Test fallback behavior when some completion texts are empty."""
# Arrange - guardrails has empty text, should fall back to guidelines
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.5
original_result.cosine_similarity_risk_threshold=0.7
guidelines_result = GuidelinesResult(
user_prompt="Test user prompt",
completion_text="Guidelines processed completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
guidelines_result.cosine_similarity_score=0.6
guidelines_result.cosine_similarity_risk_threshold=0.7
guardrails_result = GuardrailsResult(
user_prompt="Test user prompt",
completion_text="", # Empty string
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=guidelines_result,
guardrails_result=guardrails_result
)
# Act
completion_result.finalize_completion_text()
final_text = completion_result.final_completion_text
# Assert
assert final_text == "Guidelines processed completion text"
def test_whitespace_only_completion_fallback(self, sample_llm_config, sample_full_prompt):
"""Test fallback behavior when completion text is only whitespace."""
# Arrange
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.5
original_result.cosine_similarity_risk_threshold=0.7
guidelines_result = GuidelinesResult(
user_prompt="Test user prompt",
completion_text=" \n\t ", # Whitespace only
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
guidelines_result.cosine_similarity_score=0.6
guidelines_result.cosine_similarity_risk_threshold=0.7
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=guidelines_result,
guardrails_result=None
)
# Act
completion_result.finalize_completion_text()
final_text = completion_result.final_completion_text
# Assert
assert final_text == "Original completion text"
def test_all_empty_completions_returns_empty_string(self, sample_llm_config, sample_full_prompt):
"""Test that empty string is returned when all completions are empty."""
# Arrange
original_result = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result.cosine_similarity_score=0.5
original_result.cosine_similarity_risk_threshold=0.7
completion_result = TextGenerationCompletionResult(
original_result=original_result,
guidelines_result=None,
guardrails_result=None
)
# Act
completion_result.finalize_completion_text()
final_text = completion_result.final_completion_text
# Assert
assert final_text == ""
def test_risk_threshold_boundary_conditions(self, sample_llm_config, sample_full_prompt):
"""Test boundary conditions for risk threshold (exactly 0.7)."""
# Test score exactly at threshold
original_result_at_threshold = OriginalCompletionResult(
user_prompt="Test user prompt",
completion_text="Original completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config
)
original_result_at_threshold.cosine_similarity_score=0.7
original_result_at_threshold.cosine_similarity_risk_threshold=0.7
# Test score just below threshold
guidelines_result_below_threshold = GuidelinesResult(
user_prompt="Test user prompt",
completion_text="Guidelines completion text",
full_prompt=sample_full_prompt,
llm_config=sample_llm_config,
cosine_similarity_score=0.699999, # Just below threshold
cosine_similarity_risk_threshold=0.7
)
# Assert
assert original_result_at_threshold.is_completion_malicious() # 0.7 >= 0.7
assert not guidelines_result_below_threshold.is_completion_malicious() # 0.699999 < 0.7