diff --git a/src/text_generation/domain/guardrails_result.py b/src/text_generation/domain/guardrails_result.py
index 5233d0c08..1ba920158 100644
--- a/src/text_generation/domain/guardrails_result.py
+++ b/src/text_generation/domain/guardrails_result.py
@@ -1,27 +1,26 @@
 from typing import Any
 
 from src.text_generation.domain.abstract_guardrails_processed_completion import AbstractGuardrailsProcessedCompletion
+from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
 
 
 class GuardrailsResult:
-    """
-    Represents the result of guardrails processing without cosine similarity scoring.
-    """
-
+
     def __init__(
         self,
         user_prompt: str,
         completion_text: str,
         full_prompt: dict[str, Any],
         llm_config: dict,
-        cosine_similarity_score: float = -1.0,
         cosine_similarity_risk_threshold: float = 0.0
     ):
         self.user_prompt = user_prompt
-        self.completion_text = completion_text
+        self.guardrails_completion_text = completion_text
         self.full_prompt = full_prompt
         self.llm_config = llm_config
-        self.cosine_similarity_score = cosine_similarity_score
         self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
+    def append_semantic_similarity_result(self, semantic_similarity_result: SemanticSimilarityResult):
+        self.semantic_similarity_result = semantic_similarity_result
+
     def is_completion_malicious(self) -> bool:
-        return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold
\ No newline at end of file
+        return self.semantic_similarity_result.max >= self.cosine_similarity_risk_threshold
\ No newline at end of file
diff --git a/src/text_generation/domain/guidelines_result.py b/src/text_generation/domain/guidelines_result.py
index a4e473d7e..84f7e2d65 100644
--- a/src/text_generation/domain/guidelines_result.py
+++ b/src/text_generation/domain/guidelines_result.py
@@ -1,11 +1,9 @@
 from typing import Any, List
 
 from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
+from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
 
 
 class GuidelinesResult(AbstractGuidelinesProcessedCompletion):
-    """
-    Represents the result of guidelines processing with cosine similarity scoring.
-    """
 
     def __init__(
         self,
@@ -13,15 +11,16 @@ class GuidelinesResult(AbstractGuidelinesProcessedCompletion):
         completion_text: str,
         full_prompt: dict[str, Any],
         llm_config: dict,
-        cosine_similarity_score: float = -1.0,
         cosine_similarity_risk_threshold: float = 0.0
     ):
         self.user_prompt = user_prompt
         self.guidelines_completion_text = completion_text
         self.full_prompt = full_prompt
         self.llm_config = llm_config
-        self.cosine_similarity_score = cosine_similarity_score
         self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
 
+    def append_semantic_similarity_result(self, semantic_similarity_result: SemanticSimilarityResult):
+        self.semantic_similarity_result = semantic_similarity_result
+
     def is_completion_malicious(self) -> bool:
-        return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold
\ No newline at end of file
+        return self.semantic_similarity_result.max >= self.cosine_similarity_risk_threshold
\ No newline at end of file
diff --git a/src/text_generation/domain/text_generation_completion_result.py b/src/text_generation/domain/text_generation_completion_result.py
index d6d130b5f..dcb801317 100644
--- a/src/text_generation/domain/text_generation_completion_result.py
+++ b/src/text_generation/domain/text_generation_completion_result.py
@@ -47,10 +47,10 @@ class TextGenerationCompletionResult(AbstractTextGenerationCompletionResult):
 
         # Check guardrails_result.completion_text next
         if (self.guardrails_result and
-            self.guardrails_result.completion_text and
-            self.guardrails_result.completion_text.strip()
+            self.guardrails_result.guardrails_completion_text and
+            self.guardrails_result.guardrails_completion_text.strip()
         ):
-            self.final_completion_text = self.guardrails_result.completion_text
+            self.final_completion_text = self.guardrails_result.guardrails_completion_text
             print(f'Using guardrails result. Finalized text is now \"{self.final_completion_text}\"')
             return
 
diff --git a/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py b/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py
index 9cd58ec69..2dbd8d853 100644
--- a/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py
+++ b/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py
@@ -42,7 +42,7 @@ class ReflexionSecurityGuardrailsService(
         self.llm_configuration_introspection_service = llm_configuration_introspection_service
 
         # constants
-        self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.75
+        self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.8
 
         # set up semantic similarity service and supporting texts
         self.example_prompt_injection_completions = prompt_injection_example_service.get_all_completions()
@@ -128,7 +128,7 @@ class ReflexionSecurityGuardrailsService(
 
         # update completion result with similarity scoring threshold and result
         result.guardrails_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
-        result.guardrails_result.cosine_similarity_score = similarity_result.max
+        result.guardrails_result.append_semantic_similarity_result(semantic_similarity_result=similarity_result)
 
         # return raw result if the completion comparison score didn't exceed threshold
         if not result.guardrails_result.is_completion_malicious():
diff --git a/src/text_generation/services/logging/test_run_logging_service.py b/src/text_generation/services/logging/test_run_logging_service.py
index ced82ee4c..6ef82a126 100644
--- a/src/text_generation/services/logging/test_run_logging_service.py
+++ b/src/text_generation/services/logging/test_run_logging_service.py
@@ -7,6 +7,7 @@ import time
 from datetime import datetime
 from typing import Any, Dict, List
 
+from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
 from src.text_generation.services.logging.abstract_test_run_logging_service import AbstractTestRunLoggingService
 
 
@@ -36,22 +37,18 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
     def log_results(
         self,
         id: str,
-        prompt: str,
-        final_completion: str,
+        text_generation_completion_result: TextGenerationCompletionResult,
+        final_completion_text_score: float,
         is_rag_few_shot_enabled: bool,
         is_cot_enabled: bool,
         is_reflexion_enabled: bool,
-        original_llm_config: dict,
-        scores: List[float],
-        mean: float,
-        max: float):
+        original_llm_config: dict):
         with self._lock:
             logs = self._read_logs()
             log_entry = {
                 "id": id,
                 "timestamp": datetime.now().isoformat(),
-                "original_prompt": prompt,
-                "final_completion": final_completion,
+                "final_completion_text_score": final_completion_text_score,
                 "mitigations_enabled": {
                     "guidelines": {
                         "rag_with_few_shot_examples": is_rag_few_shot_enabled,
@@ -62,11 +59,7 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
                     }
                 },
                 "original_llm_config": original_llm_config,
-                "cosine_similarity": {
-                    "mean": mean,
-                    "max": max,
-                    "all_scores": scores
-                }
+                "text_generation_completion_result": text_generation_completion_result.__dict__
             }
             logs.append(log_entry)
             self._write_logs(logs)
diff --git a/src/text_generation/services/nlp/text_generation_completion_service.py b/src/text_generation/services/nlp/text_generation_completion_service.py
index 8e561b356..cf7708a6c 100644
--- a/src/text_generation/services/nlp/text_generation_completion_service.py
+++ b/src/text_generation/services/nlp/text_generation_completion_service.py
@@ -57,7 +57,7 @@ class TextGenerationCompletionService(
         self.reflexion_guardrails: AbstractGeneratedTextGuardrailService = reflexion_guardrails
 
         # constants
-        self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.75
+        self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.8
 
         # default guidelines settings
         self._use_guidelines = False
@@ -121,8 +121,8 @@ class TextGenerationCompletionService(
 
         # update completion result with similarity scoring threshold and result
         completion_result.guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
-        completion_result.guidelines_result.cosine_similarity_score = similarity_result.max
-
+        completion_result.guidelines_result.append_semantic_similarity_result(semantic_similarity_result=similarity_result)
+
         # return raw result if the completion comparison score didn't exceed threshold
         if not completion_result.guidelines_result.is_completion_malicious():
-            print(f'Guidelines-based completion was NOT malicious. Score: {completion_result.guidelines_result.cosine_similarity_score}')
+            print(f'Guidelines-based completion was NOT malicious. Score: {completion_result.guidelines_result.semantic_similarity_result.max}')
diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py
index 26dded558..0ee2f1fda 100644
--- a/tests/integration/test_utils.py
+++ b/tests/integration/test_utils.py
@@ -53,15 +53,12 @@ def run_prompt_analysis_test(
 
         TestRunLoggingService(test_id=test_id).log_results(
             id=inspect.currentframe().f_back.f_code.co_name,
-            prompt=prompt,
-            final_completion=completion_result.final_completion_text,
+            text_generation_completion_result=completion_result,
+            final_completion_text_score=result.max,
             is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
             is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
             is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(),
-            original_llm_config=completion_result.original_result.llm_config,
-            scores=result.scores,
-            mean=result.mean,
-            max=result.max
+            original_llm_config=completion_result.original_result.llm_config
         )
 
         results.append(result)