diff --git a/src/text_generation/domain/original_completion_result.py b/src/text_generation/domain/original_completion_result.py index 23f8aeeb9..14c240ccb 100644 --- a/src/text_generation/domain/original_completion_result.py +++ b/src/text_generation/domain/original_completion_result.py @@ -1,5 +1,6 @@ from typing import Any from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult +from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult class OriginalCompletionResult(AbstractTextGenerationCompletionResult): @@ -18,8 +19,10 @@ class OriginalCompletionResult(AbstractTextGenerationCompletionResult): self.completion_text = completion_text self.full_prompt = full_prompt self.llm_config = llm_config - self.cosine_similarity_score: float = -1.0 self.cosine_similarity_risk_threshold: float = 0.0 + def append_semantic_similarity_result(self, semantic_similarity_result: SemanticSimilarityResult): + self.semantic_similarity_result = semantic_similarity_result + def is_completion_malicious(self) -> bool: - return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold \ No newline at end of file + return self.semantic_similarity_result.max >= self.cosine_similarity_risk_threshold \ No newline at end of file diff --git a/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py b/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py index cb9e3074b..22fadf367 100644 --- a/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py +++ b/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py @@ -100,6 +100,11 @@ class ReflexionSecurityGuardrailsService( try: result: TextGenerationCompletionResult = text_generation_completion_result + + # if previous completions were scored below risk threshold, return as-is (don't apply guardrails) + if not result.original_result.is_completion_malicious() and not (result.guidelines_result and result.guidelines_result.is_completion_malicious()): + return result + original_user_prompt = result.original_result.user_prompt prompt_template: StringPromptTemplate = self._get_template(text_generation_completion_result) @@ -133,6 +138,8 @@ class ReflexionSecurityGuardrailsService( # return raw result if the completion comparison score didn't exceed threshold if not result.guardrails_result.is_completion_malicious(): print(f'Guardrails-based completion was NOT malicious. Score: {result.guardrails_result.semantic_similarity_result.max}') + if result.alternate_result: + result.alternate_result = None return result # provide the finalized alternate (refuse to answer) diff --git a/src/text_generation/services/nlp/text_generation_completion_service.py b/src/text_generation/services/nlp/text_generation_completion_service.py index fb98c53f0..bae834feb 100644 --- a/src/text_generation/services/nlp/text_generation_completion_service.py +++ b/src/text_generation/services/nlp/text_generation_completion_service.py @@ -112,13 +112,13 @@ class TextGenerationCompletionService( text = completion_result.final_completion_text ) + # the completion is a result of no guidelines applied if not completion_result.guidelines_result: - completion_result.guidelines_result = GuidelinesResult( - user_prompt=completion_result.original_result.user_prompt, - completion_text=completion_result.original_result.completion_text, - llm_config=completion_result.original_result.llm_config - ) + # just return the original + completion_result.original_result.append_semantic_similarity_result(semantic_similarity_result=similarity_result) + return completion_result + # completion came from guidelines-enabled service: # update completion result with similarity scoring threshold and result completion_result.guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD completion_result.guidelines_result.append_semantic_similarity_result(semantic_similarity_result=similarity_result)