more updates for reflexion

This commit is contained in:
Adam Wilson
2025-07-28 10:31:55 -06:00
parent 5bc9f480f9
commit 2659e6e43c
12 changed files with 439 additions and 340 deletions
@@ -12,9 +12,16 @@ class GuardrailsResult:
user_prompt: str,
completion_text: str,
full_prompt: dict[str, Any],
llm_config: dict
llm_config: dict,
cosine_similarity_score: float = -1.0,
cosine_similarity_risk_threshold: float = 0.0
):
self.user_prompt = user_prompt
self.completion_text = completion_text
self.full_prompt = full_prompt
self.llm_config = llm_config
self.llm_config = llm_config
self.cosine_similarity_score = cosine_similarity_score
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
def is_completion_malicious(self) -> bool:
    """
    Report whether the stored similarity score reaches the risk threshold.

    Returns:
        True when `cosine_similarity_score` is greater than or equal to
        `cosine_similarity_risk_threshold`, otherwise False.
    """
    score = self.cosine_similarity_score
    risk_threshold = self.cosine_similarity_risk_threshold
    return score >= risk_threshold
@@ -17,11 +17,11 @@ class GuidelinesResult(AbstractGuidelinesProcessedCompletion):
cosine_similarity_risk_threshold: float = 0.0
):
self.user_prompt = user_prompt
self.completion_text = completion_text
self.guidelines_completion_text = completion_text
self.full_prompt = full_prompt
self.llm_config = llm_config
self.cosine_similarity_score = cosine_similarity_score
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
def is_completion_malicious(self) -> bool:
    """
    Report whether the stored similarity score reaches the risk threshold.

    Returns:
        True when `cosine_similarity_score` is greater than or equal to
        `cosine_similarity_risk_threshold`, otherwise False.
    """
    return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold
@@ -17,38 +17,62 @@ class TextGenerationCompletionResult(AbstractTextGenerationCompletionResult):
self,
original_result: OriginalCompletionResult,
guidelines_result: Optional[GuidelinesResult] = None,
guardrails_result: Optional[GuardrailsResult] = None
guardrails_result: Optional[GuardrailsResult] = None,
alternate_result: Optional[AlternateCompletionResult] = None
):
self.original_result = original_result
self.guidelines_result = guidelines_result
self.guardrails_result = guardrails_result
self.alternate_result = alternate_result
self.final_completion_text = ''
def finalize_completion_text(self) -> str:
    """
    Select the completion text to expose and store it on the instance.

    Candidates are checked in priority order; the first one whose text is
    non-empty after stripping whitespace wins:
        1. alternate_result.alterate_completion_text
        2. guardrails_result.completion_text
        3. guidelines_result.guidelines_completion_text
        4. original_result.completion_text

    If no candidate qualifies, the final text is set to an empty string.
    The previous and the newly selected text are printed for tracing.

    Returns:
        The value stored in `self.final_completion_text`.
    """
    print(f'Finalized text was "{self.final_completion_text}"')

    # (log label, result object, name of the attribute holding its text)
    candidates = (
        ('Using alternate result.', self.alternate_result, 'alterate_completion_text'),
        ('Using guardrails result.', self.guardrails_result, 'completion_text'),
        ('Using guidelines result.', self.guidelines_result, 'guidelines_completion_text'),
        ('Using original.', self.original_result, 'completion_text'),
    )

    for label, source, attribute in candidates:
        if not source:
            continue
        text = getattr(source, attribute)
        # Whitespace-only text counts as empty, so fall through to the next candidate.
        if text and text.strip():
            self.final_completion_text = text
            print(f'{label} Finalized text is now "{self.final_completion_text}"')
            return self.final_completion_text

    # Nothing usable was found on any result object.
    self.final_completion_text = ""
    print(f'Finalized text is now "{self.final_completion_text}"')
    return self.final_completion_text
@@ -4,12 +4,16 @@ from langchain_core.prompts import PromptTemplate, StringPromptTemplate
from langchain_core.prompt_values import PromptValue
from src.text_generation.common.constants import Constants
from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult
from src.text_generation.domain.guardrails_result import GuardrailsResult
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
from src.text_generation.services.nlp.abstract_semantic_similarity_service import AbstractSemanticSimilarityService
from src.text_generation.services.prompt_injection.abstract_prompt_injection_example_service import AbstractPromptInjectionExampleService
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
@@ -25,6 +29,8 @@ class ReflexionSecurityGuardrailsService(
def __init__(
self,
foundation_model: AbstractFoundationModel,
semantic_similarity_service: AbstractSemanticSimilarityService,
prompt_injection_example_service: AbstractPromptInjectionExampleService,
response_processing_service: AbstractResponseProcessingService,
prompt_template_service: AbstractPromptTemplateService,
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService):
@@ -35,11 +41,21 @@ class ReflexionSecurityGuardrailsService(
self.prompt_template_service = prompt_template_service
self.llm_configuration_introspection_service = llm_configuration_introspection_service
# constants
self.COSINE_SIMILARITY_RISK_THRESHOLD = 0.5
# set up semantic similarity service and supporting texts
self.example_prompt_injection_completions = prompt_injection_example_service.get_all_completions()
self.example_prompt_injection_prompts = prompt_injection_example_service.get_all_prompts()
self.semantic_similarity_service = semantic_similarity_service
self.semantic_similarity_service.use_comparison_texts(
self.example_prompt_injection_completions
)
def _create_context_from_rag(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> str:
result: TextGenerationCompletionResult = text_generation_completion_result
original_user_prompt = result.original_user_prompt
original_completion = result.original_completion
original_user_prompt = result.original_result.user_prompt
original_completion = result.original_result.completion_text
# Assemble the context showing the original prompt injection attack
context_parts = [
@@ -67,7 +83,13 @@ class ReflexionSecurityGuardrailsService(
return filled_template
def _create_chain(self, prompt_template: StringPromptTemplate):
    """
    Compose the chain used to produce a guardrails completion.

    The stages are piped in this order: the prompt template, the foundation
    model pipeline, a `StrOutputParser`, and finally the response processing
    service's `process_text_generation_output` callable, so callers receive
    the post-processed text rather than the raw parser output.

    Args:
        prompt_template: Template that forms the first stage of the chain.

    Returns:
        The composed chain object.
    """
    return (
        prompt_template
        | self.foundation_model_pipeline
        | StrOutputParser()
        | self.response_processing_service.process_text_generation_output
    )
def apply_guardrails(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> AbstractTextGenerationCompletionResult:
"""
@@ -78,7 +100,7 @@ class ReflexionSecurityGuardrailsService(
try:
result: TextGenerationCompletionResult = text_generation_completion_result
original_user_prompt = result.original_user_prompt
original_user_prompt = result.original_result.user_prompt
prompt_template: StringPromptTemplate = self._get_template(text_generation_completion_result)
prompt_value: PromptValue = prompt_template.format_prompt(**{self.constants.INPUT_VARIABLE_TOKEN: original_user_prompt})
@@ -95,11 +117,30 @@ class ReflexionSecurityGuardrailsService(
completion_text = chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: original_user_prompt})
llm_config = self.llm_configuration_introspection_service.get_config(chain)
result.guardrails_processed_completion = GuardrailsResult(
result.guardrails_result = GuardrailsResult(
user_prompt=original_user_prompt,
completion_text=completion_text,
llm_config=llm_config,
full_prompt=prompt_dict
)
similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze(text=completion_text)
# update completion result with similarity scoring threshold and result
result.guardrails_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
result.guardrails_result.cosine_similarity_score = similarity_result.max
# return raw result if the completion comparison score didn't exceed threshold
if not result.guardrails_result.is_completion_malicious():
print(f'Guardrails-based completion was NOT malicious. Score: {result.guardrails_result.cosine_similarity_score}')
return result
# provide the finalized alternate (refuse to answer)
print(f'Guardrails-based completion was malicious. Score: {result.guardrails_result.cosine_similarity_score}')
result.alternate_result = AlternateCompletionResult(
alterate_completion_text = self.constants.ALT_COMPLETION_TEXT
)
result.finalize_completion_text()
return result
except Exception as e:
@@ -8,6 +8,8 @@ from langchain.prompts import FewShotPromptTemplate
from src.text_generation.common.constants import Constants
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
@@ -74,13 +76,21 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
chain = self._create_chain(prompt_template)
completion_text=chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: user_prompt})
llm_config = self.llm_configuration_introspection_service.get_config(chain)
result = GuidelinesResult(
user_prompt=user_prompt,
completion_text=completion_text,
llm_config=llm_config,
full_prompt=prompt_dict
result = TextGenerationCompletionResult(
original_result=OriginalCompletionResult(
user_prompt=user_prompt,
completion_text=completion_text,
llm_config=llm_config,
full_prompt=prompt_dict
),
guidelines_result=GuidelinesResult(
user_prompt=user_prompt,
completion_text=completion_text,
llm_config=llm_config,
full_prompt=prompt_dict
)
)
return result
except Exception as e:
@@ -36,12 +36,12 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
def log_results(
self,
id: str,
text_generation_result: str,
completion: str,
prompt: str,
final_completion: str,
is_rag_few_shot_enabled: bool,
is_cot_enabled: bool,
is_reflexion_enabled: bool,
llm_config: dict,
original_llm_config: dict,
scores: List[float],
mean: float,
max: float):
@@ -50,8 +50,8 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
log_entry = {
"id": id,
"timestamp": datetime.now().isoformat(),
"prompt": prompt,
"completion": completion,
"original_prompt": prompt,
"final_completion": final_completion,
"mitigations_enabled": {
"guidelines": {
"rag_with_few_shot_examples": is_rag_few_shot_enabled,
@@ -61,7 +61,7 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
"reflexion": is_reflexion_enabled
}
},
"llm_config": llm_config,
"original_llm_config": original_llm_config,
"cosine_similarity": {
"mean": mean,
"max": max,
@@ -6,6 +6,7 @@ from langchain_huggingface import HuggingFacePipeline
from src.text_generation.common.constants import Constants
from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.domain.original_completion_result import OriginalCompletionResult
from src.text_generation.domain.semantic_similarity_result import SemanticSimilarityResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.services.guardrails.abstract_generated_text_guardrail_service import AbstractGeneratedTextGuardrailService
@@ -84,12 +85,15 @@ class TextGenerationCompletionService(
self._use_rag_context
)
guidelines_handler = self.guidelines_strategy_map.get(
guidelines_config,
guidelines_config,
# fall back to unfiltered LLM invocation
self._handle_without_guidelines
)
return guidelines_handler(user_prompt)
def _process_guidelines_result(self, guidelines_result: GuidelinesResult) -> TextGenerationCompletionResult:
def _process_completion_result(self, completion_result: TextGenerationCompletionResult) -> TextGenerationCompletionResult:
"""
Process guidelines result and create completion result with semantic similarity check.
@@ -100,26 +104,36 @@ class TextGenerationCompletionService(
TextGenerationCompletionResult with appropriate completion text
"""
# analyze the current version of the completion text against prompt injection completions;
# if guidelines applied, this is the result of completion using guidelines;
# otherwise it is the raw completion text without guidelines
completion_result.finalize_completion_text()
similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze(
text = guidelines_result.completion_text
text = completion_result.final_completion_text
)
guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
guidelines_result.cosine_similarity_score = similarity_result.mean
if not completion_result.guidelines_result:
completion_result.guidelines_result = GuidelinesResult(
user_prompt=completion_result.original_result.user_prompt,
completion_text=completion_result.original_result.completion_text,
llm_config=completion_result.original_result.llm_config
)
completion_result = TextGenerationCompletionResult(
llm_config = guidelines_result.llm_config,
original_completion = guidelines_result.completion_text,
original_user_prompt = guidelines_result.user_prompt,
guidelines_result = guidelines_result
)
# update completion result with similarity scoring threshold and result
completion_result.guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
completion_result.guidelines_result.cosine_similarity_score = similarity_result.max
if not guidelines_result.is_original_completion_malicious():
# return raw result if the completion comparison score didn't exceed threshold
if not completion_result.guidelines_result.is_completion_malicious():
print(f'Guidelines-based completion was NOT malicious. Score: {completion_result.guidelines_result.cosine_similarity_score}')
return completion_result
# provide the finalized alternate (refuse to answer)
print(f'Guidelines-based completion was malicious. Score: {completion_result.guidelines_result.cosine_similarity_score}')
completion_result.alternate_result = AlternateCompletionResult(
alterate_completion_text = self.constants.ALT_COMPLETION_TEXT
)
completion_result.finalize_completion_text()
return completion_result
@@ -127,28 +141,31 @@ class TextGenerationCompletionService(
def _handle_cot_and_rag(self, user_prompt: str) -> TextGenerationCompletionResult:
"""Handle: CoT=True, RAG=True"""
guidelines_result = self.rag_plus_cot_guidelines.apply_guidelines(user_prompt)
return self._process_guidelines_result(guidelines_result)
return self._process_completion_result(guidelines_result)
def _handle_cot_only(self, user_prompt: str) -> TextGenerationCompletionResult:
"""Handle: CoT=True, RAG=False"""
guidelines_result = self.chain_of_thought_guidelines.apply_guidelines(user_prompt)
return self._process_guidelines_result(guidelines_result)
return self._process_completion_result(guidelines_result)
def _handle_rag_only(self, user_prompt: str) -> TextGenerationCompletionResult:
"""Handle: CoT=False, RAG=True"""
guidelines_result = self.rag_context_guidelines.apply_guidelines(user_prompt)
return self._process_guidelines_result(guidelines_result)
return self._process_completion_result(guidelines_result)
def _handle_without_guidelines(self, user_prompt: str) -> TextGenerationCompletionResult:
    """
    Handle: CoT=False, RAG=False

    Invokes the chain built by `_create_chain_without_guidelines` directly on
    the prompt, wraps the output in an OriginalCompletionResult, and passes
    the resulting TextGenerationCompletionResult to
    `_process_completion_result`.

    Args:
        user_prompt: The prompt text passed to the chain.

    Returns:
        Whatever `_process_completion_result` returns for the wrapped
        completion (annotated as a TextGenerationCompletionResult).

    Raises:
        Any exception raised while building or invoking the chain, or while
        processing the result, propagates to the caller unchanged.
    """
    chain = self._create_chain_without_guidelines()
    # Config is read before the chain is invoked, matching the original order.
    llm_config = self.llm_configuration_introspection_service.get_config(chain)
    original_result = OriginalCompletionResult(
        user_prompt=user_prompt,
        completion_text=chain.invoke(user_prompt),
        llm_config=llm_config,
    )
    return self._process_completion_result(
        TextGenerationCompletionResult(original_result=original_result)
    )
@@ -216,6 +233,8 @@ class TextGenerationCompletionService(
raise ValueError(f"Parameter 'user_prompt' cannot be empty or None")
print(f'Using guidelines: {self.get_current_config()}')
completion_result: TextGenerationCompletionResult = self._process_prompt_with_guidelines_if_applicable(user_prompt)
if not self._use_reflexion_guardrails:
return completion_result
return self._handle_reflexion_guardrails(completion_result)