mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-02-27 06:14:12 +00:00
don't overprocess results
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult
|
||||
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
|
||||
|
||||
|
||||
class OriginalCompletionResult(AbstractTextGenerationCompletionResult):
|
||||
@@ -18,8 +19,10 @@ class OriginalCompletionResult(AbstractTextGenerationCompletionResult):
|
||||
self.completion_text = completion_text
|
||||
self.full_prompt = full_prompt
|
||||
self.llm_config = llm_config
|
||||
self.cosine_similarity_score: float = -1.0
|
||||
self.cosine_similarity_risk_threshold: float = 0.0
|
||||
|
||||
def append_semantic_similarity_result(self, semantic_similarity_result: SemanticSimilarityResult):
|
||||
self.semantic_similarity_result = semantic_similarity_result
|
||||
|
||||
def is_completion_malicious(self) -> bool:
|
||||
return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold
|
||||
return self.semantic_similarity_result.max >= self.cosine_similarity_risk_threshold
|
||||
@@ -100,6 +100,11 @@ class ReflexionSecurityGuardrailsService(
|
||||
|
||||
try:
|
||||
result: TextGenerationCompletionResult = text_generation_completion_result
|
||||
|
||||
# if previous completions were scored below risk threshold, return as-is (don't apply guardrails)
|
||||
if not result.original_result.is_completion_malicious() and not (result.guidelines_result and result.guidelines_result.is_completion_malicious()):
|
||||
return result
|
||||
|
||||
original_user_prompt = result.original_result.user_prompt
|
||||
|
||||
prompt_template: StringPromptTemplate = self._get_template(text_generation_completion_result)
|
||||
@@ -133,6 +138,8 @@ class ReflexionSecurityGuardrailsService(
|
||||
# return raw result if the completion comparison score didn't exceed threshold
|
||||
if not result.guardrails_result.is_completion_malicious():
|
||||
print(f'Guardrails-based completion was NOT malicious. Score: {result.guardrails_result.semantic_similarity_result.max}')
|
||||
if result.alternate_result:
|
||||
result.alternate_result = None
|
||||
return result
|
||||
|
||||
# provide the finalized alternate (refuse to answer)
|
||||
|
||||
@@ -112,13 +112,13 @@ class TextGenerationCompletionService(
|
||||
text = completion_result.final_completion_text
|
||||
)
|
||||
|
||||
# the completion is a result of no guidelines applied
|
||||
if not completion_result.guidelines_result:
|
||||
completion_result.guidelines_result = GuidelinesResult(
|
||||
user_prompt=completion_result.original_result.user_prompt,
|
||||
completion_text=completion_result.original_result.completion_text,
|
||||
llm_config=completion_result.original_result.llm_config
|
||||
)
|
||||
# just return the original
|
||||
completion_result.original_result.append_semantic_similarity_result(semantic_similarity_result=similarity_result)
|
||||
return completion_result
|
||||
|
||||
# completion came from guidelines-enabled service:
|
||||
# update completion result with similarity scoring threshold and result
|
||||
completion_result.guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
|
||||
completion_result.guidelines_result.append_semantic_similarity_result(semantic_similarity_result=similarity_result)
|
||||
|
||||
Reference in New Issue
Block a user