Log full completion result with semantic similarity comparison results

This commit is contained in:
Adam Wilson
2025-07-28 11:49:07 -06:00
parent b971df0a7a
commit df14a01fe9
7 changed files with 29 additions and 41 deletions
@@ -1,27 +1,26 @@
from typing import Any
from src.text_generation.domain.abstract_guardrails_processed_completion import AbstractGuardrailsProcessedCompletion
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
class GuardrailsResult:
    """
    Result of guardrails processing of a completion.

    Holds the original prompt/completion pair plus the similarity data used to
    decide whether the completion resembles a known prompt-injection example.
    """

    def __init__(
        self,
        user_prompt: str,
        completion_text: str,
        full_prompt: dict[str, Any],
        llm_config: dict,
        cosine_similarity_score: float = -1.0,
        cosine_similarity_risk_threshold: float = 0.0
    ):
        self.user_prompt = user_prompt
        self.completion_text = completion_text
        # Same text exposed under a guardrails-specific name so downstream
        # consumers (e.g. final-completion selection) can refer to it explicitly.
        self.guardrails_completion_text = completion_text
        self.full_prompt = full_prompt
        self.llm_config = llm_config
        # Legacy scalar score; retained for backward compatibility with callers
        # that still assign it directly instead of attaching a full result.
        self.cosine_similarity_score = cosine_similarity_score
        self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
        # Full comparison result, populated via append_semantic_similarity_result().
        # Pre-initialized so is_completion_malicious() never raises AttributeError.
        self.semantic_similarity_result = None

    def append_semantic_similarity_result(self, semantic_similarity_result: "SemanticSimilarityResult"):
        """Attach the full semantic similarity comparison result."""
        self.semantic_similarity_result = semantic_similarity_result

    def is_completion_malicious(self) -> bool:
        """
        Return True when the similarity score meets or exceeds the risk threshold.

        Prefers the attached semantic similarity result's max score when present;
        falls back to the legacy scalar cosine_similarity_score otherwise.
        """
        if self.semantic_similarity_result is not None:
            return self.semantic_similarity_result.max >= self.cosine_similarity_risk_threshold
        return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold
@@ -1,11 +1,9 @@
from typing import Any, List
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
class GuidelinesResult(AbstractGuidelinesProcessedCompletion):
    """
    Result of guidelines processing of a completion, with similarity scoring.
    """

    def __init__(
        self,
        # NOTE(review): this first parameter is reconstructed from the
        # `self.user_prompt = user_prompt` assignment below — confirm against
        # the full file.
        user_prompt: str,
        completion_text: str,
        full_prompt: dict[str, Any],
        llm_config: dict,
        cosine_similarity_score: float = -1.0,
        cosine_similarity_risk_threshold: float = 0.0
    ):
        self.user_prompt = user_prompt
        self.guidelines_completion_text = completion_text
        self.full_prompt = full_prompt
        self.llm_config = llm_config
        # Legacy scalar score; retained for backward compatibility with callers
        # that still assign it directly instead of attaching a full result.
        self.cosine_similarity_score = cosine_similarity_score
        self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
        # Full comparison result, populated via append_semantic_similarity_result().
        # Pre-initialized so is_completion_malicious() never raises AttributeError.
        self.semantic_similarity_result = None

    def append_semantic_similarity_result(self, semantic_similarity_result: "SemanticSimilarityResult"):
        """Attach the full semantic similarity comparison result."""
        self.semantic_similarity_result = semantic_similarity_result

    def is_completion_malicious(self) -> bool:
        """
        Return True when the similarity score meets or exceeds the risk threshold.

        Prefers the attached semantic similarity result's max score when present;
        falls back to the legacy scalar cosine_similarity_score otherwise.
        """
        if self.semantic_similarity_result is not None:
            return self.semantic_similarity_result.max >= self.cosine_similarity_risk_threshold
        return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold
@@ -47,10 +47,10 @@ class TextGenerationCompletionResult(AbstractTextGenerationCompletionResult):
# Check guardrails_result.completion_text next
if (self.guardrails_result and
self.guardrails_result.completion_text and
self.guardrails_result.completion_text.strip()
self.guardrails_result.guardrails_completion_text and
self.guardrails_result.guardrails_completion_text.strip()
):
self.final_completion_text = self.guardrails_result.completion_text
self.final_completion_text = self.guardrails_result.guardrails_completion_text
print(f'Using guardrails result. Finalized text is now \"{self.final_completion_text}\"')
return
@@ -42,7 +42,7 @@ class ReflexionSecurityGuardrailsService(
self.llm_configuration_introspection_service = llm_configuration_introspection_service
# constants
self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.75
self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.8
# set up semantic similarity service and supporting texts
self.example_prompt_injection_completions = prompt_injection_example_service.get_all_completions()
@@ -128,7 +128,7 @@ class ReflexionSecurityGuardrailsService(
# update completion result with similarity scoring threshold and result
result.guardrails_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
result.guardrails_result.cosine_similarity_score = similarity_result.max
result.guardrails_result.append_semantic_similarity_result(semantic_similarity_result=similarity_result)
# return raw result if the completion comparison score didn't exceed threshold
if not result.guardrails_result.is_completion_malicious():
@@ -7,6 +7,7 @@ import time
from datetime import datetime
from typing import Any, Dict, List
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.services.logging.abstract_test_run_logging_service import AbstractTestRunLoggingService
@@ -36,22 +37,18 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
def log_results(
self,
id: str,
prompt: str,
final_completion: str,
text_generation_completion_result: TextGenerationCompletionResult,
final_completion_text_score: float,
is_rag_few_shot_enabled: bool,
is_cot_enabled: bool,
is_reflexion_enabled: bool,
original_llm_config: dict,
scores: List[float],
mean: float,
max: float):
original_llm_config: dict):
with self._lock:
logs = self._read_logs()
log_entry = {
"id": id,
"timestamp": datetime.now().isoformat(),
"original_prompt": prompt,
"final_completion": final_completion,
"final_completion_text_score": final_completion_text_score,
"mitigations_enabled": {
"guidelines": {
"rag_with_few_shot_examples": is_rag_few_shot_enabled,
@@ -62,11 +59,7 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
}
},
"original_llm_config": original_llm_config,
"cosine_similarity": {
"mean": mean,
"max": max,
"all_scores": scores
}
"text_generation_completion_result": text_generation_completion_result.__dict__
}
logs.append(log_entry)
self._write_logs(logs)
@@ -57,7 +57,7 @@ class TextGenerationCompletionService(
self.reflexion_guardrails: AbstractGeneratedTextGuardrailService = reflexion_guardrails
# constants
self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.75
self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.8
# default guidelines settings
self._use_guidelines = False
@@ -121,8 +121,8 @@ class TextGenerationCompletionService(
# update completion result with similarity scoring threshold and result
completion_result.guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
completion_result.guidelines_result.cosine_similarity_score = similarity_result.max
completion_result.guidelines_result.append_semantic_similarity_result(semantic_similarity_result=similarity_result)
# return raw result if the completion comparison score didn't exceed threshold
if not completion_result.guidelines_result.is_completion_malicious():
print(f'Guidelines-based completion was NOT malicious. Score: {completion_result.guidelines_result.cosine_similarity_score}')
+3 -6
View File
@@ -53,15 +53,12 @@ def run_prompt_analysis_test(
TestRunLoggingService(test_id=test_id).log_results(
id=inspect.currentframe().f_back.f_code.co_name,
prompt=prompt,
final_completion=completion_result.final_completion_text,
text_generation_completion_result=completion_result,
final_completion_text_score=result.max,
is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(),
original_llm_config=completion_result.original_result.llm_config,
scores=result.scores,
mean=result.mean,
max=result.max
original_llm_config=completion_result.original_result.llm_config
)
results.append(result)