mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-02-12 22:52:51 +00:00
more updates for reflexion
This commit is contained in:
@@ -12,9 +12,16 @@ class GuardrailsResult:
|
||||
user_prompt: str,
|
||||
completion_text: str,
|
||||
full_prompt: dict[str, Any],
|
||||
llm_config: dict
|
||||
llm_config: dict,
|
||||
cosine_similarity_score: float = -1.0,
|
||||
cosine_similarity_risk_threshold: float = 0.0
|
||||
):
|
||||
self.user_prompt = user_prompt
|
||||
self.completion_text = completion_text
|
||||
self.full_prompt = full_prompt
|
||||
self.llm_config = llm_config
|
||||
self.llm_config = llm_config
|
||||
self.cosine_similarity_score = cosine_similarity_score
|
||||
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
|
||||
|
||||
def is_completion_malicious(self) -> bool:
|
||||
return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold
|
||||
@@ -17,11 +17,11 @@ class GuidelinesResult(AbstractGuidelinesProcessedCompletion):
|
||||
cosine_similarity_risk_threshold: float = 0.0
|
||||
):
|
||||
self.user_prompt = user_prompt
|
||||
self.completion_text = completion_text
|
||||
self.guidelines_completion_text = completion_text
|
||||
self.full_prompt = full_prompt
|
||||
self.llm_config = llm_config
|
||||
self.cosine_similarity_score = cosine_similarity_score
|
||||
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
|
||||
|
||||
def is_original_completion_malicious(self) -> bool:
|
||||
def is_completion_malicious(self) -> bool:
|
||||
return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold
|
||||
@@ -17,38 +17,62 @@ class TextGenerationCompletionResult(AbstractTextGenerationCompletionResult):
|
||||
self,
|
||||
original_result: OriginalCompletionResult,
|
||||
guidelines_result: Optional[GuidelinesResult] = None,
|
||||
guardrails_result: Optional[GuardrailsResult] = None
|
||||
guardrails_result: Optional[GuardrailsResult] = None,
|
||||
alternate_result: Optional[AlternateCompletionResult] = None
|
||||
):
|
||||
self.original_result = original_result
|
||||
self.guidelines_result = guidelines_result
|
||||
self.guardrails_result = guardrails_result
|
||||
self.alternate_result = alternate_result
|
||||
self.final_completion_text = ''
|
||||
|
||||
def finalize(self) -> str:
|
||||
def finalize_completion_text(self) -> str:
|
||||
"""
|
||||
Returns the final completion text based on priority order:
|
||||
Returns the current completion text based on priority order:
|
||||
1. guardrails_result.completion_text (if not empty)
|
||||
2. guidelines_result.completion_text (if not empty)
|
||||
3. original_result.completion_text (if not empty)
|
||||
"""
|
||||
|
||||
# Check guardrails_result.completion_text first
|
||||
|
||||
print(f'Finalized text was \"{self.final_completion_text}\"')
|
||||
|
||||
# Check alternate text first
|
||||
if (self.alternate_result and
|
||||
self.alternate_result.alterate_completion_text and
|
||||
self.alternate_result.alterate_completion_text.strip()
|
||||
):
|
||||
self.final_completion_text = self.alternate_result.alterate_completion_text
|
||||
print(f'Using alternate result. Finalized text is now \"{self.final_completion_text}\"')
|
||||
return
|
||||
|
||||
# Check guardrails_result.completion_text next
|
||||
if (self.guardrails_result and
|
||||
self.guardrails_result.completion_text and
|
||||
self.guardrails_result.completion_text.strip()):
|
||||
return self.guardrails_result.completion_text
|
||||
self.guardrails_result.completion_text.strip()
|
||||
):
|
||||
self.final_completion_text = self.guardrails_result.completion_text
|
||||
print(f'Using guardrails result. Finalized text is now \"{self.final_completion_text}\"')
|
||||
return
|
||||
|
||||
# Fall back to guidelines_result.completion_text
|
||||
if (self.guidelines_result and
|
||||
self.guidelines_result.completion_text and
|
||||
self.guidelines_result.completion_text.strip()):
|
||||
return self.guidelines_result.completion_text
|
||||
|
||||
self.guidelines_result.guidelines_completion_text and
|
||||
self.guidelines_result.guidelines_completion_text.strip()
|
||||
):
|
||||
self.final_completion_text = self.guidelines_result.guidelines_completion_text
|
||||
print(f'Using guidelines result. Finalized text is now \"{self.final_completion_text}\"')
|
||||
return
|
||||
|
||||
# Fall back to original_result.completion_text
|
||||
if (self.original_result and
|
||||
self.original_result.completion_text and
|
||||
self.original_result.completion_text.strip()):
|
||||
return self.original_result.completion_text
|
||||
self.original_result.completion_text.strip()
|
||||
):
|
||||
self.final_completion_text = self.original_result.completion_text
|
||||
print(f'Using original. Finalized text is now \"{self.final_completion_text}\"')
|
||||
return
|
||||
|
||||
# If all are empty, return empty string
|
||||
return ""
|
||||
|
||||
self.final_completion_text = ""
|
||||
print(f'Finalized text is now \"{self.final_completion_text}\"')
|
||||
return
|
||||
@@ -4,12 +4,16 @@ from langchain_core.prompts import PromptTemplate, StringPromptTemplate
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
from src.text_generation.common.constants import Constants
|
||||
from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult
|
||||
from src.text_generation.domain.guardrails_result import GuardrailsResult
|
||||
from src.text_generation.domain.guidelines_result import GuidelinesResult
|
||||
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
|
||||
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
|
||||
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
|
||||
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
|
||||
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
|
||||
from src.text_generation.services.nlp.abstract_semantic_similarity_service import AbstractSemanticSimilarityService
|
||||
from src.text_generation.services.prompt_injection.abstract_prompt_injection_example_service import AbstractPromptInjectionExampleService
|
||||
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
|
||||
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
|
||||
|
||||
@@ -25,6 +29,8 @@ class ReflexionSecurityGuardrailsService(
|
||||
def __init__(
|
||||
self,
|
||||
foundation_model: AbstractFoundationModel,
|
||||
semantic_similarity_service: AbstractSemanticSimilarityService,
|
||||
prompt_injection_example_service: AbstractPromptInjectionExampleService,
|
||||
response_processing_service: AbstractResponseProcessingService,
|
||||
prompt_template_service: AbstractPromptTemplateService,
|
||||
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService):
|
||||
@@ -35,11 +41,21 @@ class ReflexionSecurityGuardrailsService(
|
||||
self.prompt_template_service = prompt_template_service
|
||||
self.llm_configuration_introspection_service = llm_configuration_introspection_service
|
||||
|
||||
# constants
|
||||
self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.5
|
||||
|
||||
# set up semantic similarity service and supporting texts
|
||||
self.example_prompt_injection_completions = prompt_injection_example_service.get_all_completions()
|
||||
self.example_prompt_injection_prompts = prompt_injection_example_service.get_all_prompts()
|
||||
self.semantic_similarity_service = semantic_similarity_service
|
||||
self.semantic_similarity_service.use_comparison_texts(
|
||||
self.example_prompt_injection_completions
|
||||
)
|
||||
|
||||
def _create_context_from_rag(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> str:
|
||||
result: TextGenerationCompletionResult = text_generation_completion_result
|
||||
original_user_prompt = result.original_user_prompt
|
||||
original_completion = result.original_completion
|
||||
original_user_prompt = result.original_result.user_prompt
|
||||
original_completion = result.original_result.completion_text
|
||||
|
||||
# Assemble the context showing the original prompt injection attack
|
||||
context_parts = [
|
||||
@@ -67,7 +83,13 @@ class ReflexionSecurityGuardrailsService(
|
||||
return filled_template
|
||||
|
||||
def _create_chain(self, prompt_template: StringPromptTemplate):
|
||||
return prompt_template | self.foundation_model_pipeline | StrOutputParser()
|
||||
# return prompt_template | self.foundation_model_pipeline | StrOutputParser()
|
||||
return (
|
||||
prompt_template
|
||||
| self.foundation_model_pipeline
|
||||
| StrOutputParser()
|
||||
| self.response_processing_service.process_text_generation_output
|
||||
)
|
||||
|
||||
def apply_guardrails(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> AbstractTextGenerationCompletionResult:
|
||||
"""
|
||||
@@ -78,7 +100,7 @@ class ReflexionSecurityGuardrailsService(
|
||||
|
||||
try:
|
||||
result: TextGenerationCompletionResult = text_generation_completion_result
|
||||
original_user_prompt = result.original_user_prompt
|
||||
original_user_prompt = result.original_result.user_prompt
|
||||
|
||||
prompt_template: StringPromptTemplate = self._get_template(text_generation_completion_result)
|
||||
prompt_value: PromptValue = prompt_template.format_prompt(**{self.constants.INPUT_VARIABLE_TOKEN: original_user_prompt})
|
||||
@@ -95,11 +117,30 @@ class ReflexionSecurityGuardrailsService(
|
||||
completion_text = chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: original_user_prompt})
|
||||
llm_config = self.llm_configuration_introspection_service.get_config(chain)
|
||||
|
||||
result.guardrails_processed_completion = GuardrailsResult(
|
||||
result.guardrails_result = GuardrailsResult(
|
||||
user_prompt=original_user_prompt,
|
||||
completion_text=completion_text,
|
||||
llm_config=llm_config,
|
||||
full_prompt=prompt_dict
|
||||
)
|
||||
|
||||
similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze(text=completion_text)
|
||||
|
||||
# update completion result with similarity scoring threshold and result
|
||||
result.guardrails_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
|
||||
result.guardrails_result.cosine_similarity_score = similarity_result.max
|
||||
|
||||
# return raw result if the completion comparison score didn't exceed threshold
|
||||
if not result.guardrails_result.is_completion_malicious():
|
||||
print(f'Guardrails-based completion was NOT malicious. Score: {result.guardrails_result.cosine_similarity_score}')
|
||||
return result
|
||||
|
||||
# provide the finalized alternate (refuse to answer)
|
||||
print(f'Guardrails-based completion was malicious. Score: {result.guardrails_result.cosine_similarity_score}')
|
||||
result.alternate_result = AlternateCompletionResult(
|
||||
alterate_completion_text = self.constants.ALT_COMPLETION_TEXT
|
||||
)
|
||||
result.finalize_completion_text()
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -8,6 +8,8 @@ from langchain.prompts import FewShotPromptTemplate
|
||||
from src.text_generation.common.constants import Constants
|
||||
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
|
||||
from src.text_generation.domain.guidelines_result import GuidelinesResult
|
||||
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
|
||||
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
|
||||
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
|
||||
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService
|
||||
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
|
||||
@@ -74,13 +76,21 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
|
||||
|
||||
chain = self._create_chain(prompt_template)
|
||||
completion_text=chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: user_prompt})
|
||||
|
||||
llm_config = self.llm_configuration_introspection_service.get_config(chain)
|
||||
result = GuidelinesResult(
|
||||
user_prompt=user_prompt,
|
||||
completion_text=completion_text,
|
||||
llm_config=llm_config,
|
||||
full_prompt=prompt_dict
|
||||
|
||||
result = TextGenerationCompletionResult(
|
||||
original_result=OriginalCompletionResult(
|
||||
user_prompt=user_prompt,
|
||||
completion_text=completion_text,
|
||||
llm_config=llm_config,
|
||||
full_prompt=prompt_dict
|
||||
),
|
||||
guidelines_result=GuidelinesResult(
|
||||
user_prompt=user_prompt,
|
||||
completion_text=completion_text,
|
||||
llm_config=llm_config,
|
||||
full_prompt=prompt_dict
|
||||
)
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
@@ -36,12 +36,12 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
|
||||
def log_results(
|
||||
self,
|
||||
id: str,
|
||||
text_generation_result: str,
|
||||
completion: str,
|
||||
prompt: str,
|
||||
final_completion: str,
|
||||
is_rag_few_shot_enabled: bool,
|
||||
is_cot_enabled: bool,
|
||||
is_reflexion_enabled: bool,
|
||||
llm_config: dict,
|
||||
original_llm_config: dict,
|
||||
scores: List[float],
|
||||
mean: float,
|
||||
max: float):
|
||||
@@ -50,8 +50,8 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
|
||||
log_entry = {
|
||||
"id": id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"prompt": prompt,
|
||||
"completion": completion,
|
||||
"original_prompt": prompt,
|
||||
"final_completion": final_completion,
|
||||
"mitigations_enabled": {
|
||||
"guidelines": {
|
||||
"rag_with_few_shot_examples": is_rag_few_shot_enabled,
|
||||
@@ -61,7 +61,7 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
|
||||
"reflexion": is_reflexion_enabled
|
||||
}
|
||||
},
|
||||
"llm_config": llm_config,
|
||||
"original_llm_config": original_llm_config,
|
||||
"cosine_similarity": {
|
||||
"mean": mean,
|
||||
"max": max,
|
||||
|
||||
@@ -6,6 +6,7 @@ from langchain_huggingface import HuggingFacePipeline
|
||||
from src.text_generation.common.constants import Constants
|
||||
from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult
|
||||
from src.text_generation.domain.guidelines_result import GuidelinesResult
|
||||
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
|
||||
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
|
||||
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
|
||||
from src.text_generation.services.guardrails.abstract_generated_text_guardrail_service import AbstractGeneratedTextGuardrailService
|
||||
@@ -84,12 +85,15 @@ class TextGenerationCompletionService(
|
||||
self._use_rag_context
|
||||
)
|
||||
guidelines_handler = self.guidelines_strategy_map.get(
|
||||
guidelines_config,
|
||||
guidelines_config,
|
||||
|
||||
# fall back to unfiltered LLM invocation
|
||||
self._handle_without_guidelines
|
||||
)
|
||||
return guidelines_handler(user_prompt)
|
||||
|
||||
def _process_guidelines_result(self, guidelines_result: GuidelinesResult) -> TextGenerationCompletionResult:
|
||||
|
||||
|
||||
def _process_completion_result(self, completion_result: TextGenerationCompletionResult) -> TextGenerationCompletionResult:
|
||||
"""
|
||||
Process guidelines result and create completion result with semantic similarity check.
|
||||
|
||||
@@ -100,26 +104,36 @@ class TextGenerationCompletionService(
|
||||
TextGenerationCompletionResult with appropriate completion text
|
||||
"""
|
||||
|
||||
# analyze the current version of the completion text against prompt injection completions;
|
||||
# if guidelines applied, this is the result of completion using guidelines;
|
||||
# otherwise it is the raw completion text without guidelines
|
||||
completion_result.finalize_completion_text()
|
||||
similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze(
|
||||
text = guidelines_result.completion_text
|
||||
text = completion_result.final_completion_text
|
||||
)
|
||||
|
||||
guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
|
||||
guidelines_result.cosine_similarity_score = similarity_result.mean
|
||||
if not completion_result.guidelines_result:
|
||||
completion_result.guidelines_result = GuidelinesResult(
|
||||
user_prompt=completion_result.original_result.user_prompt,
|
||||
completion_text=completion_result.original_result.completion_text,
|
||||
llm_config=completion_result.original_result.llm_config
|
||||
)
|
||||
|
||||
completion_result = TextGenerationCompletionResult(
|
||||
llm_config = guidelines_result.llm_config,
|
||||
original_completion = guidelines_result.completion_text,
|
||||
original_user_prompt = guidelines_result.user_prompt,
|
||||
guidelines_result = guidelines_result
|
||||
)
|
||||
# update completion result with similarity scoring threshold and result
|
||||
completion_result.guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
|
||||
completion_result.guidelines_result.cosine_similarity_score = similarity_result.max
|
||||
|
||||
if not guidelines_result.is_original_completion_malicious():
|
||||
# return raw result if the completion comparison score didn't exceed threshold
|
||||
if not completion_result.guidelines_result.is_completion_malicious():
|
||||
print(f'Guidelines-based completion was NOT malicious. Score: {completion_result.guidelines_result.cosine_similarity_score}')
|
||||
return completion_result
|
||||
|
||||
# provide the finalized alternate (refuse to answer)
|
||||
print(f'Guidelines-based completion was malicious. Score: {completion_result.guidelines_result.cosine_similarity_score}')
|
||||
completion_result.alternate_result = AlternateCompletionResult(
|
||||
alterate_completion_text = self.constants.ALT_COMPLETION_TEXT
|
||||
)
|
||||
completion_result.finalize_completion_text()
|
||||
return completion_result
|
||||
|
||||
|
||||
@@ -127,28 +141,31 @@ class TextGenerationCompletionService(
|
||||
def _handle_cot_and_rag(self, user_prompt: str) -> TextGenerationCompletionResult:
|
||||
"""Handle: CoT=True, RAG=True"""
|
||||
guidelines_result = self.rag_plus_cot_guidelines.apply_guidelines(user_prompt)
|
||||
return self._process_guidelines_result(guidelines_result)
|
||||
return self._process_completion_result(guidelines_result)
|
||||
|
||||
def _handle_cot_only(self, user_prompt: str) -> TextGenerationCompletionResult:
|
||||
"""Handle: CoT=True, RAG=False"""
|
||||
guidelines_result = self.chain_of_thought_guidelines.apply_guidelines(user_prompt)
|
||||
return self._process_guidelines_result(guidelines_result)
|
||||
return self._process_completion_result(guidelines_result)
|
||||
|
||||
def _handle_rag_only(self, user_prompt: str) -> TextGenerationCompletionResult:
|
||||
"""Handle: CoT=False, RAG=True"""
|
||||
guidelines_result = self.rag_context_guidelines.apply_guidelines(user_prompt)
|
||||
return self._process_guidelines_result(guidelines_result)
|
||||
return self._process_completion_result(guidelines_result)
|
||||
|
||||
def _handle_without_guidelines(self, user_prompt: str) -> TextGenerationCompletionResult:
|
||||
"""Handle: CoT=False, RAG=False"""
|
||||
try:
|
||||
chain = self._create_chain_without_guidelines()
|
||||
llm_config = self.llm_configuration_introspection_service.get_config(chain)
|
||||
result = GuidelinesResult(
|
||||
completion_text = chain.invoke(user_prompt),
|
||||
llm_config = llm_config
|
||||
)
|
||||
return self._process_guidelines_result(result)
|
||||
|
||||
result = TextGenerationCompletionResult(
|
||||
original_result=OriginalCompletionResult(
|
||||
user_prompt=user_prompt,
|
||||
completion_text=chain.invoke(user_prompt),
|
||||
llm_config=llm_config
|
||||
))
|
||||
return self._process_completion_result(result)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -216,6 +233,8 @@ class TextGenerationCompletionService(
|
||||
raise ValueError(f"Parameter 'user_prompt' cannot be empty or None")
|
||||
print(f'Using guidelines: {self.get_current_config()}')
|
||||
completion_result: TextGenerationCompletionResult = self._process_prompt_with_guidelines_if_applicable(user_prompt)
|
||||
|
||||
if not self._use_reflexion_guardrails:
|
||||
return completion_result
|
||||
|
||||
return self._handle_reflexion_guardrails(completion_result)
|
||||
|
||||
@@ -178,11 +178,15 @@ def prompt_injection_example_service(prompt_injection_example_repository):
|
||||
@pytest.fixture(scope="session")
|
||||
def reflexion_guardrails(
|
||||
foundation_model,
|
||||
semantic_similarity_service,
|
||||
prompt_injection_example_service,
|
||||
response_processing_service,
|
||||
prompt_template_service,
|
||||
llm_configuration_introspection_service):
|
||||
return ReflexionSecurityGuardrailsService(
|
||||
foundation_model=foundation_model,
|
||||
semantic_similarity_service=semantic_similarity_service,
|
||||
prompt_injection_example_service=prompt_injection_example_service,
|
||||
response_processing_service=response_processing_service,
|
||||
prompt_template_service=prompt_template_service,
|
||||
llm_configuration_introspection_service=llm_configuration_introspection_service
|
||||
|
||||
@@ -1,270 +0,0 @@
|
||||
"""
|
||||
pytest -v tests/integration/test_domain.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from src.text_generation.domain.guardrails_result import GuardrailsResult
|
||||
from src.text_generation.domain.guidelines_result import GuidelinesResult
|
||||
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
|
||||
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
|
||||
|
||||
|
||||
class TestTextGenerationCompletionResult:
|
||||
"""Test suite for TextGenerationCompletionResult and related classes."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_llm_config(self):
|
||||
"""Sample LLM configuration for testing."""
|
||||
return {
|
||||
"model": "test-model",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_full_prompt(self):
|
||||
"""Sample full prompt for testing."""
|
||||
return {
|
||||
"system": "You are a helpful assistant",
|
||||
"user": "Test prompt"
|
||||
}
|
||||
|
||||
def test_original_result_only(self, sample_llm_config, sample_full_prompt):
|
||||
"""Test TextGenerationCompletionResult with only original result filled in."""
|
||||
# Arrange
|
||||
original_result = OriginalCompletionResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Original completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
original_result.cosine_similarity_score=0.5
|
||||
original_result.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
completion_result = TextGenerationCompletionResult(
|
||||
original_result=original_result,
|
||||
guidelines_result=None,
|
||||
guardrails_result=None
|
||||
)
|
||||
|
||||
# Act
|
||||
final_text = completion_result.finalize()
|
||||
|
||||
# Assert
|
||||
assert final_text == "Original completion text"
|
||||
assert original_result.user_prompt == "Test user prompt"
|
||||
assert original_result.llm_config == sample_llm_config
|
||||
assert not original_result.is_completion_malicious() # 0.5 < 0.7
|
||||
|
||||
def test_guidelines_and_original_guidelines_precedence(self, sample_llm_config, sample_full_prompt):
|
||||
"""Test that guidelines result takes precedence over original when both are present."""
|
||||
# Arrange
|
||||
original_result = OriginalCompletionResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Original completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
original_result.cosine_similarity_score=0.8
|
||||
original_result.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
guidelines_result = GuidelinesResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Guidelines processed completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
guidelines_result.cosine_similarity_score=0.6
|
||||
guidelines_result.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
completion_result = TextGenerationCompletionResult(
|
||||
original_result=original_result,
|
||||
guidelines_result=guidelines_result,
|
||||
guardrails_result=None
|
||||
)
|
||||
|
||||
# Act
|
||||
final_text = completion_result.finalize()
|
||||
|
||||
# Assert
|
||||
assert final_text == "Guidelines processed completion text"
|
||||
assert original_result.is_completion_malicious() # 0.8 >= 0.7
|
||||
assert not guidelines_result.is_original_completion_malicious() # 0.6 < 0.7
|
||||
|
||||
def test_guardrails_guidelines_original_guardrails_precedence(self, sample_llm_config, sample_full_prompt):
|
||||
"""Test that guardrails result takes precedence when all three are present."""
|
||||
# Arrange
|
||||
original_result = OriginalCompletionResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Original completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
original_result.cosine_similarity_score=0.9
|
||||
original_result.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
guidelines_result = GuidelinesResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Guidelines processed completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
guidelines_result.cosine_similarity_score=0.8
|
||||
guidelines_result.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
guardrails_result = GuardrailsResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Guardrails processed completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
completion_result = TextGenerationCompletionResult(
|
||||
original_result=original_result,
|
||||
guidelines_result=guidelines_result,
|
||||
guardrails_result=guardrails_result
|
||||
)
|
||||
|
||||
# Act
|
||||
final_text = completion_result.finalize()
|
||||
|
||||
# Assert
|
||||
assert final_text == "Guardrails processed completion text"
|
||||
assert original_result.is_completion_malicious() # 0.9 >= 0.7
|
||||
assert guidelines_result.is_original_completion_malicious() # 0.8 >= 0.7
|
||||
|
||||
def test_empty_completion_fallback_behavior(self, sample_llm_config, sample_full_prompt):
|
||||
"""Test fallback behavior when some completion texts are empty."""
|
||||
# Arrange - guardrails has empty text, should fall back to guidelines
|
||||
original_result = OriginalCompletionResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Original completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
original_result.cosine_similarity_score=0.5
|
||||
original_result.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
guidelines_result = GuidelinesResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Guidelines processed completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
guidelines_result.cosine_similarity_score=0.6
|
||||
guidelines_result.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
guardrails_result = GuardrailsResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="", # Empty string
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
completion_result = TextGenerationCompletionResult(
|
||||
original_result=original_result,
|
||||
guidelines_result=guidelines_result,
|
||||
guardrails_result=guardrails_result
|
||||
)
|
||||
|
||||
# Act
|
||||
final_text = completion_result.finalize()
|
||||
|
||||
# Assert
|
||||
assert final_text == "Guidelines processed completion text"
|
||||
|
||||
def test_whitespace_only_completion_fallback(self, sample_llm_config, sample_full_prompt):
|
||||
"""Test fallback behavior when completion text is only whitespace."""
|
||||
# Arrange
|
||||
original_result = OriginalCompletionResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Original completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
original_result.cosine_similarity_score=0.5
|
||||
original_result.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
guidelines_result = GuidelinesResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text=" \n\t ", # Whitespace only
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
guidelines_result.cosine_similarity_score=0.6
|
||||
guidelines_result.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
completion_result = TextGenerationCompletionResult(
|
||||
original_result=original_result,
|
||||
guidelines_result=guidelines_result,
|
||||
guardrails_result=None
|
||||
)
|
||||
|
||||
# Act
|
||||
final_text = completion_result.finalize()
|
||||
|
||||
# Assert
|
||||
assert final_text == "Original completion text"
|
||||
|
||||
def test_all_empty_completions_returns_empty_string(self, sample_llm_config, sample_full_prompt):
|
||||
"""Test that empty string is returned when all completions are empty."""
|
||||
# Arrange
|
||||
original_result = OriginalCompletionResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
original_result.cosine_similarity_score=0.5
|
||||
original_result.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
completion_result = TextGenerationCompletionResult(
|
||||
original_result=original_result,
|
||||
guidelines_result=None,
|
||||
guardrails_result=None
|
||||
)
|
||||
|
||||
# Act
|
||||
final_text = completion_result.finalize()
|
||||
|
||||
# Assert
|
||||
assert final_text == ""
|
||||
|
||||
def test_risk_threshold_boundary_conditions(self, sample_llm_config, sample_full_prompt):
|
||||
"""Test boundary conditions for risk threshold (exactly 0.7)."""
|
||||
# Test score exactly at threshold
|
||||
original_result_at_threshold = OriginalCompletionResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Original completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config
|
||||
)
|
||||
|
||||
original_result_at_threshold.cosine_similarity_score=0.7
|
||||
original_result_at_threshold.cosine_similarity_risk_threshold=0.7
|
||||
|
||||
# Test score just below threshold
|
||||
guidelines_result_below_threshold = GuidelinesResult(
|
||||
user_prompt="Test user prompt",
|
||||
completion_text="Guidelines completion text",
|
||||
full_prompt=sample_full_prompt,
|
||||
llm_config=sample_llm_config,
|
||||
cosine_similarity_score=0.699999, # Just below threshold
|
||||
cosine_similarity_risk_threshold=0.7
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert original_result_at_threshold.is_completion_malicious() # 0.7 >= 0.7
|
||||
assert not guidelines_result_below_threshold.is_original_completion_malicious() # 0.699999 < 0.7
|
||||
@@ -45,7 +45,8 @@ def run_prompt_analysis_test(
|
||||
completion_result: TextGenerationCompletionResult = configured_service.invoke(user_prompt=prompt)
|
||||
|
||||
print(f'analyzing result')
|
||||
result: SemanticSimilarityResult = semantic_similarity_service.analyze(completion_result.final)
|
||||
completion_result.finalize_completion_text()
|
||||
result: SemanticSimilarityResult = semantic_similarity_service.analyze(completion_result.final_completion_text)
|
||||
|
||||
print(f'{i}/{len(prompts)} Max Score: {result.max}')
|
||||
print(f'{i}/{len(prompts)} Avg Score: {result.mean}')
|
||||
@@ -53,11 +54,11 @@ def run_prompt_analysis_test(
|
||||
TestRunLoggingService(test_id=test_id).log_results(
|
||||
id=inspect.currentframe().f_back.f_code.co_name,
|
||||
prompt=prompt,
|
||||
completion=completion_result.original_completion,
|
||||
final_completion=completion_result.final_completion_text,
|
||||
is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
|
||||
is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
|
||||
is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(),
|
||||
llm_config=completion_result.guidelines_processed_completion.llm_config,
|
||||
original_llm_config=completion_result.original_result.llm_config,
|
||||
scores=result.scores,
|
||||
mean=result.mean,
|
||||
max=result.max
|
||||
|
||||
@@ -208,7 +208,7 @@ def test_example_with_fluent_service_call(
|
||||
TestRunLoggingService(test_id=10).log_results(
|
||||
id=inspect.currentframe().f_code.co_name,
|
||||
prompt=prompt,
|
||||
completion=completion_result.original_completion,
|
||||
final_completion=completion_result.original_completion,
|
||||
is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
|
||||
is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
|
||||
is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(),
|
||||
|
||||
@@ -1,13 +1,276 @@
|
||||
"""
|
||||
pytest -v tests/integration/test_domain.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from src.text_generation.domain.guardrails_result import GuardrailsResult
|
||||
from src.text_generation.domain.guidelines_result import GuidelinesResult
|
||||
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
|
||||
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_guardrail_analyzed_response():
    """A guardrails result whose similarity score exceeds the risk threshold is malicious."""
    # GuardrailsResult takes user_prompt / completion_text / full_prompt /
    # llm_config plus the similarity fields; the previous keyword arguments
    # (original_completion, guardrails_processed_completion_text) do not exist.
    response = GuardrailsResult(
        user_prompt="Test user prompt",
        completion_text="I can't answer that",
        full_prompt={"system": "You are a helpful assistant", "user": "Test prompt"},
        llm_config={"model": "test-model"},
        cosine_similarity_score=0.72839,
        cosine_similarity_risk_threshold=0.5
    )
    # is_completion_malicious is a method: call it. Comparing the bound method
    # object to True (as before) always evaluates False and the old name
    # is_original_completion_malicious no longer exists on this class.
    assert response.is_completion_malicious()  # 0.72839 >= 0.5
|
||||
class TestTextGenerationCompletionResult:
    """Test suite for TextGenerationCompletionResult and related classes."""

    @pytest.fixture
    def sample_llm_config(self):
        """Sample LLM configuration for testing."""
        return {
            "model": "test-model",
            "temperature": 0.7,
            "max_tokens": 100
        }

    @pytest.fixture
    def sample_full_prompt(self):
        """Sample full prompt for testing."""
        return {
            "system": "You are a helpful assistant",
            "user": "Test prompt"
        }

    @staticmethod
    def _with_scores(result, score, threshold):
        """Attach a cosine-similarity score and risk threshold to a result object."""
        result.cosine_similarity_score = score
        result.cosine_similarity_risk_threshold = threshold
        return result

    @staticmethod
    def _finalized_text(completion_result):
        """Finalize the completion and return the resulting final text."""
        completion_result.finalize_completion_text()
        return completion_result.final_completion_text

    def test_original_result_only(self, sample_llm_config, sample_full_prompt):
        """Test TextGenerationCompletionResult with only original result filled in."""
        # Arrange
        original_result = self._with_scores(
            OriginalCompletionResult(
                user_prompt="Test user prompt",
                completion_text="Original completion text",
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.5,
            threshold=0.7
        )
        completion_result = TextGenerationCompletionResult(
            original_result=original_result,
            guidelines_result=None,
            guardrails_result=None
        )

        # Act
        final_text = self._finalized_text(completion_result)

        # Assert
        assert final_text == "Original completion text"
        assert original_result.user_prompt == "Test user prompt"
        assert original_result.llm_config == sample_llm_config
        assert not original_result.is_completion_malicious()  # 0.5 < 0.7

    def test_guidelines_and_original_guidelines_precedence(self, sample_llm_config, sample_full_prompt):
        """Test that guidelines result takes precedence over original when both are present."""
        # Arrange
        original_result = self._with_scores(
            OriginalCompletionResult(
                user_prompt="Test user prompt",
                completion_text="Original completion text",
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.8,
            threshold=0.7
        )
        guidelines_result = self._with_scores(
            GuidelinesResult(
                user_prompt="Test user prompt",
                completion_text="Guidelines processed completion text",
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.6,
            threshold=0.7
        )
        completion_result = TextGenerationCompletionResult(
            original_result=original_result,
            guidelines_result=guidelines_result,
            guardrails_result=None
        )

        # Act
        final_text = self._finalized_text(completion_result)

        # Assert
        assert final_text == "Guidelines processed completion text"
        assert original_result.is_completion_malicious()  # 0.8 >= 0.7
        assert not guidelines_result.is_completion_malicious()  # 0.6 < 0.7

    def test_guardrails_guidelines_original_guardrails_precedence(self, sample_llm_config, sample_full_prompt):
        """Test that guardrails result takes precedence when all three are present."""
        # Arrange
        original_result = self._with_scores(
            OriginalCompletionResult(
                user_prompt="Test user prompt",
                completion_text="Original completion text",
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.9,
            threshold=0.7
        )
        guidelines_result = self._with_scores(
            GuidelinesResult(
                user_prompt="Test user prompt",
                completion_text="Guidelines processed completion text",
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.8,
            threshold=0.7
        )
        guardrails_result = GuardrailsResult(
            user_prompt="Test user prompt",
            completion_text="Guardrails processed completion text",
            full_prompt=sample_full_prompt,
            llm_config=sample_llm_config
        )
        completion_result = TextGenerationCompletionResult(
            original_result=original_result,
            guidelines_result=guidelines_result,
            guardrails_result=guardrails_result
        )

        # Act
        final_text = self._finalized_text(completion_result)

        # Assert
        assert final_text == "Guardrails processed completion text"
        assert original_result.is_completion_malicious()  # 0.9 >= 0.7
        assert guidelines_result.is_completion_malicious()  # 0.8 >= 0.7

    def test_empty_completion_fallback_behavior(self, sample_llm_config, sample_full_prompt):
        """Test fallback behavior when some completion texts are empty."""
        # Arrange - guardrails has empty text, should fall back to guidelines
        original_result = self._with_scores(
            OriginalCompletionResult(
                user_prompt="Test user prompt",
                completion_text="Original completion text",
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.5,
            threshold=0.7
        )
        guidelines_result = self._with_scores(
            GuidelinesResult(
                user_prompt="Test user prompt",
                completion_text="Guidelines processed completion text",
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.6,
            threshold=0.7
        )
        guardrails_result = GuardrailsResult(
            user_prompt="Test user prompt",
            completion_text="",  # Empty string
            full_prompt=sample_full_prompt,
            llm_config=sample_llm_config
        )
        completion_result = TextGenerationCompletionResult(
            original_result=original_result,
            guidelines_result=guidelines_result,
            guardrails_result=guardrails_result
        )

        # Act / Assert
        assert self._finalized_text(completion_result) == "Guidelines processed completion text"

    def test_whitespace_only_completion_fallback(self, sample_llm_config, sample_full_prompt):
        """Test fallback behavior when completion text is only whitespace."""
        # Arrange - guidelines text is whitespace-only, should fall back to original
        original_result = self._with_scores(
            OriginalCompletionResult(
                user_prompt="Test user prompt",
                completion_text="Original completion text",
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.5,
            threshold=0.7
        )
        guidelines_result = self._with_scores(
            GuidelinesResult(
                user_prompt="Test user prompt",
                completion_text=" \n\t ",  # Whitespace only
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.6,
            threshold=0.7
        )
        completion_result = TextGenerationCompletionResult(
            original_result=original_result,
            guidelines_result=guidelines_result,
            guardrails_result=None
        )

        # Act / Assert
        assert self._finalized_text(completion_result) == "Original completion text"

    def test_all_empty_completions_returns_empty_string(self, sample_llm_config, sample_full_prompt):
        """Test that empty string is returned when all completions are empty."""
        # Arrange
        original_result = self._with_scores(
            OriginalCompletionResult(
                user_prompt="Test user prompt",
                completion_text="",
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.5,
            threshold=0.7
        )
        completion_result = TextGenerationCompletionResult(
            original_result=original_result,
            guidelines_result=None,
            guardrails_result=None
        )

        # Act / Assert
        assert self._finalized_text(completion_result) == ""

    def test_risk_threshold_boundary_conditions(self, sample_llm_config, sample_full_prompt):
        """Test boundary conditions for risk threshold (exactly 0.7)."""
        # Score exactly at the threshold: flagged (>= comparison).
        original_result_at_threshold = self._with_scores(
            OriginalCompletionResult(
                user_prompt="Test user prompt",
                completion_text="Original completion text",
                full_prompt=sample_full_prompt,
                llm_config=sample_llm_config
            ),
            score=0.7,
            threshold=0.7
        )
        # Score just below the threshold: not flagged.
        guidelines_result_below_threshold = GuidelinesResult(
            user_prompt="Test user prompt",
            completion_text="Guidelines completion text",
            full_prompt=sample_full_prompt,
            llm_config=sample_llm_config,
            cosine_similarity_score=0.699999,  # Just below threshold
            cosine_similarity_risk_threshold=0.7
        )

        # Assert
        assert original_result_at_threshold.is_completion_malicious()  # 0.7 >= 0.7
        assert not guidelines_result_below_threshold.is_completion_malicious()  # 0.699999 < 0.7
|
||||
Reference in New Issue
Block a user