From 2659e6e43c729e7644da24df4e47b495b6dd5fdf Mon Sep 17 00:00:00 2001 From: Adam Wilson Date: Mon, 28 Jul 2025 10:31:55 -0600 Subject: [PATCH] more updates for reflexion --- .../domain/guardrails_result.py | 11 +- .../domain/guidelines_result.py | 4 +- .../text_generation_completion_result.py | 54 +++- .../reflexion_security_guidelines_service.py | 51 +++- .../base_security_guidelines_service.py | 22 +- .../logging/test_run_logging_service.py | 12 +- .../nlp/text_generation_completion_service.py | 61 ++-- tests/conftest.py | 4 + tests/integration/test_domain.py | 270 ----------------- tests/integration/test_utils.py | 7 +- tests/integration/test_violation_rate.py | 2 +- tests/unit/test_domain.py | 281 +++++++++++++++++- 12 files changed, 439 insertions(+), 340 deletions(-) delete mode 100644 tests/integration/test_domain.py diff --git a/src/text_generation/domain/guardrails_result.py b/src/text_generation/domain/guardrails_result.py index ee4ccead0..5233d0c08 100644 --- a/src/text_generation/domain/guardrails_result.py +++ b/src/text_generation/domain/guardrails_result.py @@ -12,9 +12,16 @@ class GuardrailsResult: user_prompt: str, completion_text: str, full_prompt: dict[str, Any], - llm_config: dict + llm_config: dict, + cosine_similarity_score: float = -1.0, + cosine_similarity_risk_threshold: float = 0.0 ): self.user_prompt = user_prompt self.completion_text = completion_text self.full_prompt = full_prompt - self.llm_config = llm_config \ No newline at end of file + self.llm_config = llm_config + self.cosine_similarity_score = cosine_similarity_score + self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold + + def is_completion_malicious(self) -> bool: + return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold \ No newline at end of file diff --git a/src/text_generation/domain/guidelines_result.py b/src/text_generation/domain/guidelines_result.py index 75d77e9b6..a4e473d7e 100644 --- a/src/text_generation/domain/guidelines_result.py +++ b/src/text_generation/domain/guidelines_result.py @@ -17,11 +17,11 @@ class GuidelinesResult(AbstractGuidelinesProcessedCompletion): cosine_similarity_risk_threshold: float = 0.0 ): self.user_prompt = user_prompt - self.completion_text = completion_text + self.guidelines_completion_text = completion_text self.full_prompt = full_prompt self.llm_config = llm_config self.cosine_similarity_score = cosine_similarity_score self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold - def is_original_completion_malicious(self) -> bool: + def is_completion_malicious(self) -> bool: return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold \ No newline at end of file diff --git a/src/text_generation/domain/text_generation_completion_result.py b/src/text_generation/domain/text_generation_completion_result.py index e267b4628..d6d130b5f 100644 --- a/src/text_generation/domain/text_generation_completion_result.py +++ b/src/text_generation/domain/text_generation_completion_result.py @@ -17,38 +17,62 @@ class TextGenerationCompletionResult(AbstractTextGenerationCompletionResult): self, original_result: OriginalCompletionResult, guidelines_result: Optional[GuidelinesResult] = None, - guardrails_result: Optional[GuardrailsResult] = None + guardrails_result: Optional[GuardrailsResult] = None, + alternate_result: Optional[AlternateCompletionResult] = None ): self.original_result = original_result self.guidelines_result = guidelines_result self.guardrails_result = guardrails_result + self.alternate_result 
= alternate_result
+        self.final_completion_text = ''
 
-    def finalize(self) -> str:
+    def finalize_completion_text(self) -> None:
         """
-        Returns the final completion text based on priority order:
+        Sets final_completion_text based on priority order (alternate_result first, if set, then):
         1. guardrails_result.completion_text (if not empty)
         2. guidelines_result.completion_text (if not empty)
         3. original_result.completion_text (if not empty)
         """
-        
-        # Check guardrails_result.completion_text first
+        
+        print(f'Finalized text was \"{self.final_completion_text}\"')
+        
+        # Check alternate text first
+        if (self.alternate_result and
+            self.alternate_result.alterate_completion_text and
+            self.alternate_result.alterate_completion_text.strip()
+        ):
+            self.final_completion_text = self.alternate_result.alterate_completion_text
+            print(f'Using alternate result. Finalized text is now \"{self.final_completion_text}\"')
+            return
+
+        # Check guardrails_result.completion_text next
         if (self.guardrails_result and
             self.guardrails_result.completion_text and
-            self.guardrails_result.completion_text.strip()):
-            return self.guardrails_result.completion_text
+            self.guardrails_result.completion_text.strip()
+        ):
+            self.final_completion_text = self.guardrails_result.completion_text
+            print(f'Using guardrails result. Finalized text is now \"{self.final_completion_text}\"')
+            return
 
         # Fall back to guidelines_result.completion_text
         if (self.guidelines_result and
-            self.guidelines_result.completion_text and
-            self.guidelines_result.completion_text.strip()):
-            return self.guidelines_result.completion_text
-
+            self.guidelines_result.guidelines_completion_text and
+            self.guidelines_result.guidelines_completion_text.strip()
+        ):
+            self.final_completion_text = self.guidelines_result.guidelines_completion_text
+            print(f'Using guidelines result. Finalized text is now \"{self.final_completion_text}\"')
+            return
+        
         # Fall back to original_result.completion_text
         if (self.original_result and
             self.original_result.completion_text and
-            self.original_result.completion_text.strip()):
-            return self.original_result.completion_text
+            self.original_result.completion_text.strip()
+        ):
+            self.final_completion_text = self.original_result.completion_text
+            print(f'Using original. 
Finalized text is now \"{self.final_completion_text}\"') + return # If all are empty, return empty string - return "" - + self.final_completion_text = "" + print(f'Finalized text is now \"{self.final_completion_text}\"') + return \ No newline at end of file diff --git a/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py b/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py index ff79a1f39..488558343 100644 --- a/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py +++ b/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py @@ -4,12 +4,16 @@ from langchain_core.prompts import PromptTemplate, StringPromptTemplate from langchain_core.prompt_values import PromptValue from src.text_generation.common.constants import Constants +from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult from src.text_generation.domain.guardrails_result import GuardrailsResult from src.text_generation.domain.guidelines_result import GuidelinesResult +from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService +from src.text_generation.services.nlp.abstract_semantic_similarity_service import AbstractSemanticSimilarityService +from src.text_generation.services.prompt_injection.abstract_prompt_injection_example_service import AbstractPromptInjectionExampleService from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService @@ -25,6 +29,8 @@ class ReflexionSecurityGuardrailsService( def __init__( self, foundation_model: AbstractFoundationModel, + semantic_similarity_service: AbstractSemanticSimilarityService, + prompt_injection_example_service: AbstractPromptInjectionExampleService, response_processing_service: AbstractResponseProcessingService, prompt_template_service: AbstractPromptTemplateService, llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService): @@ -35,11 +41,21 @@ class ReflexionSecurityGuardrailsService( self.prompt_template_service = prompt_template_service self.llm_configuration_introspection_service = llm_configuration_introspection_service + # constants + self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.5 + + # set up semantic similarity service and supporting texts + self.example_prompt_injection_completions = prompt_injection_example_service.get_all_completions() + self.example_prompt_injection_prompts = prompt_injection_example_service.get_all_prompts() + self.semantic_similarity_service = semantic_similarity_service + self.semantic_similarity_service.use_comparison_texts( + self.example_prompt_injection_completions + ) def _create_context_from_rag(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> str: result: TextGenerationCompletionResult = text_generation_completion_result - original_user_prompt = result.original_user_prompt - 
original_completion = result.original_completion + original_user_prompt = result.original_result.user_prompt + original_completion = result.original_result.completion_text # Assemble the context showing the original prompt injection attack context_parts = [ @@ -67,7 +83,13 @@ class ReflexionSecurityGuardrailsService( return filled_template def _create_chain(self, prompt_template: StringPromptTemplate): - return prompt_template | self.foundation_model_pipeline | StrOutputParser() + # return prompt_template | self.foundation_model_pipeline | StrOutputParser() + return ( + prompt_template + | self.foundation_model_pipeline + | StrOutputParser() + | self.response_processing_service.process_text_generation_output + ) def apply_guardrails(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> AbstractTextGenerationCompletionResult: """ @@ -78,7 +100,7 @@ class ReflexionSecurityGuardrailsService( try: result: TextGenerationCompletionResult = text_generation_completion_result - original_user_prompt = result.original_user_prompt + original_user_prompt = result.original_result.user_prompt prompt_template: StringPromptTemplate = self._get_template(text_generation_completion_result) prompt_value: PromptValue = prompt_template.format_prompt(**{self.constants.INPUT_VARIABLE_TOKEN: original_user_prompt}) @@ -95,11 +117,30 @@ class ReflexionSecurityGuardrailsService( completion_text = chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: original_user_prompt}) llm_config = self.llm_configuration_introspection_service.get_config(chain) - result.guardrails_processed_completion = GuardrailsResult( + result.guardrails_result = GuardrailsResult( + user_prompt=original_user_prompt, completion_text=completion_text, llm_config=llm_config, full_prompt=prompt_dict ) + + similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze(text=completion_text) + + # update completion result with similarity scoring threshold and result + result.guardrails_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD + result.guardrails_result.cosine_similarity_score = similarity_result.max + + # return raw result if the completion comparison score didn't exceed threshold + if not result.guardrails_result.is_completion_malicious(): + print(f'Guardrails-based completion was NOT malicious. Score: {result.guardrails_result.cosine_similarity_score}') + return result + + # provide the finalized alternate (refuse to answer) + print(f'Guardrails-based completion was malicious. 
Score: {result.guardrails_result.cosine_similarity_score}') + result.alternate_result = AlternateCompletionResult( + alterate_completion_text = self.constants.ALT_COMPLETION_TEXT + ) + result.finalize_completion_text() return result except Exception as e: diff --git a/src/text_generation/services/guidelines/base_security_guidelines_service.py b/src/text_generation/services/guidelines/base_security_guidelines_service.py index e97516d5e..a95ff9b45 100644 --- a/src/text_generation/services/guidelines/base_security_guidelines_service.py +++ b/src/text_generation/services/guidelines/base_security_guidelines_service.py @@ -8,6 +8,8 @@ from langchain.prompts import FewShotPromptTemplate from src.text_generation.common.constants import Constants from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion from src.text_generation.domain.guidelines_result import GuidelinesResult +from src.text_generation.domain.original_completion_result import OriginalCompletionResult +from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService @@ -74,13 +76,21 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService): chain = self._create_chain(prompt_template) completion_text=chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: user_prompt}) - llm_config = self.llm_configuration_introspection_service.get_config(chain) - result = GuidelinesResult( - user_prompt=user_prompt, - completion_text=completion_text, - llm_config=llm_config, - full_prompt=prompt_dict + + result = TextGenerationCompletionResult( + original_result=OriginalCompletionResult( + user_prompt=user_prompt, + completion_text=completion_text, + llm_config=llm_config, + full_prompt=prompt_dict + ), + guidelines_result=GuidelinesResult( + user_prompt=user_prompt, + completion_text=completion_text, + llm_config=llm_config, + full_prompt=prompt_dict + ) ) return result except Exception as e: diff --git a/src/text_generation/services/logging/test_run_logging_service.py b/src/text_generation/services/logging/test_run_logging_service.py index 5df7205a5..ced82ee4c 100644 --- a/src/text_generation/services/logging/test_run_logging_service.py +++ b/src/text_generation/services/logging/test_run_logging_service.py @@ -36,12 +36,12 @@ class TestRunLoggingService(AbstractTestRunLoggingService): def log_results( self, id: str, - text_generation_result: str, - completion: str, + prompt: str, + final_completion: str, is_rag_few_shot_enabled: bool, is_cot_enabled: bool, is_reflexion_enabled: bool, - llm_config: dict, + original_llm_config: dict, scores: List[float], mean: float, max: float): @@ -50,8 +50,8 @@ class TestRunLoggingService(AbstractTestRunLoggingService): log_entry = { "id": id, "timestamp": datetime.now().isoformat(), - "prompt": prompt, - "completion": completion, + "original_prompt": prompt, + "final_completion": final_completion, "mitigations_enabled": { "guidelines": { "rag_with_few_shot_examples": is_rag_few_shot_enabled, @@ -61,7 +61,7 @@ class TestRunLoggingService(AbstractTestRunLoggingService): "reflexion": is_reflexion_enabled } }, - "llm_config": llm_config, + 
"original_llm_config": original_llm_config, "cosine_similarity": { "mean": mean, "max": max, diff --git a/src/text_generation/services/nlp/text_generation_completion_service.py b/src/text_generation/services/nlp/text_generation_completion_service.py index 1f044cded..9648c036e 100644 --- a/src/text_generation/services/nlp/text_generation_completion_service.py +++ b/src/text_generation/services/nlp/text_generation_completion_service.py @@ -6,6 +6,7 @@ from langchain_huggingface import HuggingFacePipeline from src.text_generation.common.constants import Constants from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult from src.text_generation.domain.guidelines_result import GuidelinesResult +from src.text_generation.domain.original_completion_result import OriginalCompletionResult from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult from src.text_generation.services.guardrails.abstract_generated_text_guardrail_service import AbstractGeneratedTextGuardrailService @@ -84,12 +85,15 @@ class TextGenerationCompletionService( self._use_rag_context ) guidelines_handler = self.guidelines_strategy_map.get( - guidelines_config, + guidelines_config, + + # fall back to unfiltered LLM invocation self._handle_without_guidelines ) return guidelines_handler(user_prompt) - - def _process_guidelines_result(self, guidelines_result: GuidelinesResult) -> TextGenerationCompletionResult: + + + def _process_completion_result(self, completion_result: TextGenerationCompletionResult) -> TextGenerationCompletionResult: """ Process guidelines result and create completion result with semantic similarity check. 
@@ -100,26 +104,36 @@ class TextGenerationCompletionService( TextGenerationCompletionResult with appropriate completion text """ + # analyze the current version of the completion text against prompt injection completions; + # if guidelines applied, this is the result of completion using guidelines; + # otherwise it is the raw completion text without guidelines + completion_result.finalize_completion_text() similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze( - text = guidelines_result.completion_text + text = completion_result.final_completion_text ) - guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD - guidelines_result.cosine_similarity_score = similarity_result.mean + if not completion_result.guidelines_result: + completion_result.guidelines_result = GuidelinesResult( + user_prompt=completion_result.original_result.user_prompt, + completion_text=completion_result.original_result.completion_text, + llm_config=completion_result.original_result.llm_config + ) - completion_result = TextGenerationCompletionResult( - llm_config = guidelines_result.llm_config, - original_completion = guidelines_result.completion_text, - original_user_prompt = guidelines_result.user_prompt, - guidelines_result = guidelines_result - ) + # update completion result with similarity scoring threshold and result + completion_result.guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD + completion_result.guidelines_result.cosine_similarity_score = similarity_result.max - if not guidelines_result.is_original_completion_malicious(): + # return raw result if the completion comparison score didn't exceed threshold + if not completion_result.guidelines_result.is_completion_malicious(): + print(f'Guidelines-based completion was NOT malicious. Score: {completion_result.guidelines_result.cosine_similarity_score}') return completion_result + # provide the finalized alternate (refuse to answer) + print(f'Guidelines-based completion was malicious. 
Score: {completion_result.guidelines_result.cosine_similarity_score}') completion_result.alternate_result = AlternateCompletionResult( alterate_completion_text = self.constants.ALT_COMPLETION_TEXT ) + completion_result.finalize_completion_text() return completion_result @@ -127,28 +141,31 @@ class TextGenerationCompletionService( def _handle_cot_and_rag(self, user_prompt: str) -> TextGenerationCompletionResult: """Handle: CoT=True, RAG=True""" guidelines_result = self.rag_plus_cot_guidelines.apply_guidelines(user_prompt) - return self._process_guidelines_result(guidelines_result) + return self._process_completion_result(guidelines_result) def _handle_cot_only(self, user_prompt: str) -> TextGenerationCompletionResult: """Handle: CoT=True, RAG=False""" guidelines_result = self.chain_of_thought_guidelines.apply_guidelines(user_prompt) - return self._process_guidelines_result(guidelines_result) + return self._process_completion_result(guidelines_result) def _handle_rag_only(self, user_prompt: str) -> TextGenerationCompletionResult: """Handle: CoT=False, RAG=True""" guidelines_result = self.rag_context_guidelines.apply_guidelines(user_prompt) - return self._process_guidelines_result(guidelines_result) + return self._process_completion_result(guidelines_result) def _handle_without_guidelines(self, user_prompt: str) -> TextGenerationCompletionResult: """Handle: CoT=False, RAG=False""" try: chain = self._create_chain_without_guidelines() llm_config = self.llm_configuration_introspection_service.get_config(chain) - result = GuidelinesResult( - completion_text = chain.invoke(user_prompt), - llm_config = llm_config - ) - return self._process_guidelines_result(result) + + result = TextGenerationCompletionResult( + original_result=OriginalCompletionResult( + user_prompt=user_prompt, + completion_text=chain.invoke(user_prompt), + llm_config=llm_config + )) + return self._process_completion_result(result) except Exception as e: raise e @@ -216,6 +233,8 @@ class TextGenerationCompletionService( raise ValueError(f"Parameter 'user_prompt' cannot be empty or None") print(f'Using guidelines: {self.get_current_config()}') completion_result: TextGenerationCompletionResult = self._process_prompt_with_guidelines_if_applicable(user_prompt) + if not self._use_reflexion_guardrails: return completion_result + return self._handle_reflexion_guardrails(completion_result) diff --git a/tests/conftest.py b/tests/conftest.py index 1c76afb3e..1f5b2b083 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -178,11 +178,15 @@ def prompt_injection_example_service(prompt_injection_example_repository): @pytest.fixture(scope="session") def reflexion_guardrails( foundation_model, + semantic_similarity_service, + prompt_injection_example_service, response_processing_service, prompt_template_service, llm_configuration_introspection_service): return ReflexionSecurityGuardrailsService( foundation_model=foundation_model, + semantic_similarity_service=semantic_similarity_service, + prompt_injection_example_service=prompt_injection_example_service, response_processing_service=response_processing_service, prompt_template_service=prompt_template_service, llm_configuration_introspection_service=llm_configuration_introspection_service diff --git a/tests/integration/test_domain.py b/tests/integration/test_domain.py deleted file mode 100644 index c2e5debe0..000000000 --- a/tests/integration/test_domain.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -pytest -v tests/integration/test_domain.py -""" - -import pytest -from unittest.mock import 
MagicMock - -from src.text_generation.domain.guardrails_result import GuardrailsResult -from src.text_generation.domain.guidelines_result import GuidelinesResult -from src.text_generation.domain.original_completion_result import OriginalCompletionResult -from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult - - -class TestTextGenerationCompletionResult: - """Test suite for TextGenerationCompletionResult and related classes.""" - - @pytest.fixture - def sample_llm_config(self): - """Sample LLM configuration for testing.""" - return { - "model": "test-model", - "temperature": 0.7, - "max_tokens": 100 - } - - @pytest.fixture - def sample_full_prompt(self): - """Sample full prompt for testing.""" - return { - "system": "You are a helpful assistant", - "user": "Test prompt" - } - - def test_original_result_only(self, sample_llm_config, sample_full_prompt): - """Test TextGenerationCompletionResult with only original result filled in.""" - # Arrange - original_result = OriginalCompletionResult( - user_prompt="Test user prompt", - completion_text="Original completion text", - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - original_result.cosine_similarity_score=0.5 - original_result.cosine_similarity_risk_threshold=0.7 - - completion_result = TextGenerationCompletionResult( - original_result=original_result, - guidelines_result=None, - guardrails_result=None - ) - - # Act - final_text = completion_result.finalize() - - # Assert - assert final_text == "Original completion text" - assert original_result.user_prompt == "Test user prompt" - assert original_result.llm_config == sample_llm_config - assert not original_result.is_completion_malicious() # 0.5 < 0.7 - - def test_guidelines_and_original_guidelines_precedence(self, sample_llm_config, sample_full_prompt): - """Test that guidelines result takes precedence over original when both are present.""" - # Arrange - original_result = OriginalCompletionResult( - user_prompt="Test user prompt", - completion_text="Original completion text", - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - original_result.cosine_similarity_score=0.8 - original_result.cosine_similarity_risk_threshold=0.7 - - guidelines_result = GuidelinesResult( - user_prompt="Test user prompt", - completion_text="Guidelines processed completion text", - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - guidelines_result.cosine_similarity_score=0.6 - guidelines_result.cosine_similarity_risk_threshold=0.7 - - completion_result = TextGenerationCompletionResult( - original_result=original_result, - guidelines_result=guidelines_result, - guardrails_result=None - ) - - # Act - final_text = completion_result.finalize() - - # Assert - assert final_text == "Guidelines processed completion text" - assert original_result.is_completion_malicious() # 0.8 >= 0.7 - assert not guidelines_result.is_original_completion_malicious() # 0.6 < 0.7 - - def test_guardrails_guidelines_original_guardrails_precedence(self, sample_llm_config, sample_full_prompt): - """Test that guardrails result takes precedence when all three are present.""" - # Arrange - original_result = OriginalCompletionResult( - user_prompt="Test user prompt", - completion_text="Original completion text", - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - original_result.cosine_similarity_score=0.9 - original_result.cosine_similarity_risk_threshold=0.7 - - guidelines_result = GuidelinesResult( - 
user_prompt="Test user prompt", - completion_text="Guidelines processed completion text", - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - guidelines_result.cosine_similarity_score=0.8 - guidelines_result.cosine_similarity_risk_threshold=0.7 - - guardrails_result = GuardrailsResult( - user_prompt="Test user prompt", - completion_text="Guardrails processed completion text", - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - completion_result = TextGenerationCompletionResult( - original_result=original_result, - guidelines_result=guidelines_result, - guardrails_result=guardrails_result - ) - - # Act - final_text = completion_result.finalize() - - # Assert - assert final_text == "Guardrails processed completion text" - assert original_result.is_completion_malicious() # 0.9 >= 0.7 - assert guidelines_result.is_original_completion_malicious() # 0.8 >= 0.7 - - def test_empty_completion_fallback_behavior(self, sample_llm_config, sample_full_prompt): - """Test fallback behavior when some completion texts are empty.""" - # Arrange - guardrails has empty text, should fall back to guidelines - original_result = OriginalCompletionResult( - user_prompt="Test user prompt", - completion_text="Original completion text", - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - original_result.cosine_similarity_score=0.5 - original_result.cosine_similarity_risk_threshold=0.7 - - guidelines_result = GuidelinesResult( - user_prompt="Test user prompt", - completion_text="Guidelines processed completion text", - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - guidelines_result.cosine_similarity_score=0.6 - guidelines_result.cosine_similarity_risk_threshold=0.7 - - guardrails_result = GuardrailsResult( - user_prompt="Test user prompt", - completion_text="", # Empty string - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - completion_result = TextGenerationCompletionResult( - original_result=original_result, - guidelines_result=guidelines_result, - guardrails_result=guardrails_result - ) - - # Act - final_text = completion_result.finalize() - - # Assert - assert final_text == "Guidelines processed completion text" - - def test_whitespace_only_completion_fallback(self, sample_llm_config, sample_full_prompt): - """Test fallback behavior when completion text is only whitespace.""" - # Arrange - original_result = OriginalCompletionResult( - user_prompt="Test user prompt", - completion_text="Original completion text", - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - original_result.cosine_similarity_score=0.5 - original_result.cosine_similarity_risk_threshold=0.7 - - guidelines_result = GuidelinesResult( - user_prompt="Test user prompt", - completion_text=" \n\t ", # Whitespace only - full_prompt=sample_full_prompt, - llm_config=sample_llm_config - ) - - guidelines_result.cosine_similarity_score=0.6 - guidelines_result.cosine_similarity_risk_threshold=0.7 - - completion_result = TextGenerationCompletionResult( - original_result=original_result, - guidelines_result=guidelines_result, - guardrails_result=None - ) - - # Act - final_text = completion_result.finalize() - - # Assert - assert final_text == "Original completion text" - - def test_all_empty_completions_returns_empty_string(self, sample_llm_config, sample_full_prompt): - """Test that empty string is returned when all completions are empty.""" - # Arrange - original_result = OriginalCompletionResult( - user_prompt="Test user prompt", 
-            completion_text="",
-            full_prompt=sample_full_prompt,
-            llm_config=sample_llm_config
-        )
-
-        original_result.cosine_similarity_score=0.5
-        original_result.cosine_similarity_risk_threshold=0.7
-
-        completion_result = TextGenerationCompletionResult(
-            original_result=original_result,
-            guidelines_result=None,
-            guardrails_result=None
-        )
-
-        # Act
-        final_text = completion_result.finalize()
-
-        # Assert
-        assert final_text == ""
-
-    def test_risk_threshold_boundary_conditions(self, sample_llm_config, sample_full_prompt):
-        """Test boundary conditions for risk threshold (exactly 0.7)."""
-        # Test score exactly at threshold
-        original_result_at_threshold = OriginalCompletionResult(
-            user_prompt="Test user prompt",
-            completion_text="Original completion text",
-            full_prompt=sample_full_prompt,
-            llm_config=sample_llm_config
-        )
-
-        original_result_at_threshold.cosine_similarity_score=0.7
-        original_result_at_threshold.cosine_similarity_risk_threshold=0.7
-
-        # Test score just below threshold
-        guidelines_result_below_threshold = GuidelinesResult(
-            user_prompt="Test user prompt",
-            completion_text="Guidelines completion text",
-            full_prompt=sample_full_prompt,
-            llm_config=sample_llm_config,
-            cosine_similarity_score=0.699999, # Just below threshold
-            cosine_similarity_risk_threshold=0.7
-        )
-
-        # Assert
-        assert original_result_at_threshold.is_completion_malicious() # 0.7 >= 0.7
-        assert not guidelines_result_below_threshold.is_original_completion_malicious() # 0.699999 < 0.7
\ No newline at end of file
diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py
index 945a2a91e..26dded558 100644
--- a/tests/integration/test_utils.py
+++ b/tests/integration/test_utils.py
@@ -45,7 +45,8 @@ def run_prompt_analysis_test(
     completion_result: TextGenerationCompletionResult = configured_service.invoke(user_prompt=prompt)
 
     print(f'analyzing result')
-    result: SemanticSimilarityResult = semantic_similarity_service.analyze(completion_result.final)
+    completion_result.finalize_completion_text()
+    result: SemanticSimilarityResult = semantic_similarity_service.analyze(completion_result.final_completion_text)
 
     print(f'{i}/{len(prompts)} Max Score: {result.max}')
     print(f'{i}/{len(prompts)} Avg Score: {result.mean}')
@@ -53,11 +54,11 @@ def run_prompt_analysis_test(
     TestRunLoggingService(test_id=test_id).log_results(
         id=inspect.currentframe().f_back.f_code.co_name,
         prompt=prompt,
-        completion=completion_result.original_completion,
+        final_completion=completion_result.final_completion_text,
         is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
         is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
         is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(),
-        llm_config=completion_result.guidelines_processed_completion.llm_config,
+        original_llm_config=completion_result.original_result.llm_config,
        scores=result.scores,
         mean=result.mean,
         max=result.max
diff --git a/tests/integration/test_violation_rate.py b/tests/integration/test_violation_rate.py
index bc7df7305..694fd1c00 100644
--- a/tests/integration/test_violation_rate.py
+++ b/tests/integration/test_violation_rate.py
@@ -208,7 +208,7 @@ def test_example_with_fluent_service_call(
     TestRunLoggingService(test_id=10).log_results(
         id=inspect.currentframe().f_code.co_name,
         prompt=prompt,
-        completion=completion_result.original_completion,
+        final_completion=completion_result.final_completion_text,
is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
         is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
         is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(),
diff --git a/tests/unit/test_domain.py b/tests/unit/test_domain.py
index 531f57d36..72971e765 100644
--- a/tests/unit/test_domain.py
+++ b/tests/unit/test_domain.py
@@ -1,13 +1,276 @@
+"""
+pytest -v tests/unit/test_domain.py
+"""
+
 import pytest
+from unittest.mock import MagicMock
+
 from src.text_generation.domain.guardrails_result import GuardrailsResult
+from src.text_generation.domain.guidelines_result import GuidelinesResult
+from src.text_generation.domain.original_completion_result import OriginalCompletionResult
+from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
 
-@pytest.mark.unit
-def test_guardrail_analyzed_response():
-    response = GuardrailsResult(
-        cosine_similarity_score=0.72839,
-        cosine_similarity_risk_threshold=0.5,
-        original_completion="compromised response",
-        guardrails_processed_completion_text="I can't answer that"
-    )
-    assert response.is_original_completion_malicious == True
\ No newline at end of file
+class TestTextGenerationCompletionResult:
+    """Test suite for TextGenerationCompletionResult and related classes."""
+
+    @pytest.fixture
+    def sample_llm_config(self):
+        """Sample LLM configuration for testing."""
+        return {
+            "model": "test-model",
+            "temperature": 0.7,
+            "max_tokens": 100
+        }
+
+    @pytest.fixture
+    def sample_full_prompt(self):
+        """Sample full prompt for testing."""
+        return {
+            "system": "You are a helpful assistant",
+            "user": "Test prompt"
+        }
+
+    def test_original_result_only(self, sample_llm_config, sample_full_prompt):
+        """Test TextGenerationCompletionResult with only original result filled in."""
+        # Arrange
+        original_result = OriginalCompletionResult(
+            user_prompt="Test user prompt",
+            completion_text="Original completion text",
+            full_prompt=sample_full_prompt,
+            llm_config=sample_llm_config
+        )
+
+        original_result.cosine_similarity_score=0.5
+        original_result.cosine_similarity_risk_threshold=0.7
+
+        completion_result = TextGenerationCompletionResult(
+            original_result=original_result,
+            guidelines_result=None,
+            guardrails_result=None
+        )
+
+        # Act
+        completion_result.finalize_completion_text()
+        final_text = completion_result.final_completion_text
+
+        # Assert
+        assert final_text == "Original completion text"
+        assert original_result.user_prompt == "Test user prompt"
+        assert original_result.llm_config == sample_llm_config
+        assert not original_result.is_completion_malicious() # 0.5 < 0.7
+
+    def test_guidelines_and_original_guidelines_precedence(self, sample_llm_config, sample_full_prompt):
+        """Test that guidelines result takes precedence over original when both are present."""
+        # Arrange
+        original_result = OriginalCompletionResult(
+            user_prompt="Test user prompt",
+            completion_text="Original completion text",
+            full_prompt=sample_full_prompt,
+            llm_config=sample_llm_config
+        )
+
+        original_result.cosine_similarity_score=0.8
+        original_result.cosine_similarity_risk_threshold=0.7
+
+        guidelines_result = GuidelinesResult(
+            user_prompt="Test user prompt",
+            completion_text="Guidelines processed completion text",
+            full_prompt=sample_full_prompt,
+            llm_config=sample_llm_config
+        )
+
+        guidelines_result.cosine_similarity_score=0.6
+        guidelines_result.cosine_similarity_risk_threshold=0.7
+
+        completion_result = 
TextGenerationCompletionResult( + original_result=original_result, + guidelines_result=guidelines_result, + guardrails_result=None + ) + + # Act + completion_result.finalize_completion_text() + final_text = completion_result.final_completion_text + + # Assert + assert final_text == "Guidelines processed completion text" + assert original_result.is_completion_malicious() # 0.8 >= 0.7 + assert not guidelines_result.is_completion_malicious() # 0.6 < 0.7 + + def test_guardrails_guidelines_original_guardrails_precedence(self, sample_llm_config, sample_full_prompt): + """Test that guardrails result takes precedence when all three are present.""" + # Arrange + original_result = OriginalCompletionResult( + user_prompt="Test user prompt", + completion_text="Original completion text", + full_prompt=sample_full_prompt, + llm_config=sample_llm_config + ) + + original_result.cosine_similarity_score=0.9 + original_result.cosine_similarity_risk_threshold=0.7 + + guidelines_result = GuidelinesResult( + user_prompt="Test user prompt", + completion_text="Guidelines processed completion text", + full_prompt=sample_full_prompt, + llm_config=sample_llm_config + ) + + guidelines_result.cosine_similarity_score=0.8 + guidelines_result.cosine_similarity_risk_threshold=0.7 + + guardrails_result = GuardrailsResult( + user_prompt="Test user prompt", + completion_text="Guardrails processed completion text", + full_prompt=sample_full_prompt, + llm_config=sample_llm_config + ) + + completion_result = TextGenerationCompletionResult( + original_result=original_result, + guidelines_result=guidelines_result, + guardrails_result=guardrails_result + ) + + # Act + completion_result.finalize_completion_text() + final_text = completion_result.final_completion_text + + # Assert + assert final_text == "Guardrails processed completion text" + assert original_result.is_completion_malicious() # 0.9 >= 0.7 + assert guidelines_result.is_completion_malicious() # 0.8 >= 0.7 + + def test_empty_completion_fallback_behavior(self, sample_llm_config, sample_full_prompt): + """Test fallback behavior when some completion texts are empty.""" + # Arrange - guardrails has empty text, should fall back to guidelines + original_result = OriginalCompletionResult( + user_prompt="Test user prompt", + completion_text="Original completion text", + full_prompt=sample_full_prompt, + llm_config=sample_llm_config + ) + + original_result.cosine_similarity_score=0.5 + original_result.cosine_similarity_risk_threshold=0.7 + + guidelines_result = GuidelinesResult( + user_prompt="Test user prompt", + completion_text="Guidelines processed completion text", + full_prompt=sample_full_prompt, + llm_config=sample_llm_config + ) + + guidelines_result.cosine_similarity_score=0.6 + guidelines_result.cosine_similarity_risk_threshold=0.7 + + guardrails_result = GuardrailsResult( + user_prompt="Test user prompt", + completion_text="", # Empty string + full_prompt=sample_full_prompt, + llm_config=sample_llm_config + ) + + completion_result = TextGenerationCompletionResult( + original_result=original_result, + guidelines_result=guidelines_result, + guardrails_result=guardrails_result + ) + + # Act + completion_result.finalize_completion_text() + final_text = completion_result.final_completion_text + + # Assert + assert final_text == "Guidelines processed completion text" + + def test_whitespace_only_completion_fallback(self, sample_llm_config, sample_full_prompt): + """Test fallback behavior when completion text is only whitespace.""" + # Arrange + original_result = 
OriginalCompletionResult( + user_prompt="Test user prompt", + completion_text="Original completion text", + full_prompt=sample_full_prompt, + llm_config=sample_llm_config + ) + + original_result.cosine_similarity_score=0.5 + original_result.cosine_similarity_risk_threshold=0.7 + + guidelines_result = GuidelinesResult( + user_prompt="Test user prompt", + completion_text=" \n\t ", # Whitespace only + full_prompt=sample_full_prompt, + llm_config=sample_llm_config + ) + + guidelines_result.cosine_similarity_score=0.6 + guidelines_result.cosine_similarity_risk_threshold=0.7 + + completion_result = TextGenerationCompletionResult( + original_result=original_result, + guidelines_result=guidelines_result, + guardrails_result=None + ) + + # Act + completion_result.finalize_completion_text() + final_text = completion_result.final_completion_text + + # Assert + assert final_text == "Original completion text" + + def test_all_empty_completions_returns_empty_string(self, sample_llm_config, sample_full_prompt): + """Test that empty string is returned when all completions are empty.""" + # Arrange + original_result = OriginalCompletionResult( + user_prompt="Test user prompt", + completion_text="", + full_prompt=sample_full_prompt, + llm_config=sample_llm_config + ) + + original_result.cosine_similarity_score=0.5 + original_result.cosine_similarity_risk_threshold=0.7 + + completion_result = TextGenerationCompletionResult( + original_result=original_result, + guidelines_result=None, + guardrails_result=None + ) + + # Act + completion_result.finalize_completion_text() + final_text = completion_result.final_completion_text + + # Assert + assert final_text == "" + + def test_risk_threshold_boundary_conditions(self, sample_llm_config, sample_full_prompt): + """Test boundary conditions for risk threshold (exactly 0.7).""" + # Test score exactly at threshold + original_result_at_threshold = OriginalCompletionResult( + user_prompt="Test user prompt", + completion_text="Original completion text", + full_prompt=sample_full_prompt, + llm_config=sample_llm_config + ) + + original_result_at_threshold.cosine_similarity_score=0.7 + original_result_at_threshold.cosine_similarity_risk_threshold=0.7 + + # Test score just below threshold + guidelines_result_below_threshold = GuidelinesResult( + user_prompt="Test user prompt", + completion_text="Guidelines completion text", + full_prompt=sample_full_prompt, + llm_config=sample_llm_config, + cosine_similarity_score=0.699999, # Just below threshold + cosine_similarity_risk_threshold=0.7 + ) + + # Assert + assert original_result_at_threshold.is_completion_malicious() # 0.7 >= 0.7 + assert not guidelines_result_below_threshold.is_completion_malicious() # 0.699999 < 0.7 \ No newline at end of file