This commit is contained in:
Adam Wilson
2025-07-24 18:51:18 -06:00
parent ae279a512d
commit 4f2e539996
27 changed files with 10560 additions and 24 deletions
@@ -9,10 +9,13 @@ class GuidelinesResult(
self,
completion_text: str,
llm_config: dict,
cosine_similarity_score: float,
cosine_similarity_risk_threshold: float):
cosine_similarity_score: float = 0.0,
cosine_similarity_risk_threshold: float = 0.0):
self.completion_text = completion_text
self.llm_config = llm_config
self.cosine_similarity_score = cosine_similarity_score
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
def is_original_completion_malicious(self) -> bool:
return self.cosine_similarity_score >= self.cosine_similarity_risk_threshold
@@ -15,10 +15,12 @@ class TextGenerationCompletionResult(
def __init__(
self,
llm_config: dict,
original_completion: str,
guidelines_result: Optional[GuidelinesResult] = None,
guardrails_result: Optional[GuardrailsResult] = None):
self.llm_config = llm_config
self.original_completion = original_completion
self.guidelines_processed_completion = guidelines_result
self.guardrails_processed_completion = guardrails_result
@@ -52,13 +52,21 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
"timestamp": datetime.now().isoformat(),
"prompt": prompt,
"completion": completion,
"is_rag_few_shot_enabled": is_rag_few_shot_enabled,
"is_cot_enabled": is_cot_enabled,
"is_reflexion_enabled": is_reflexion_enabled,
"mitigations_enabled": {
"guidelines": {
"rag_with_few_shot_examples": is_rag_few_shot_enabled,
"chain_of_thought": is_cot_enabled
},
"guardrails": {
"reflexion": is_reflexion_enabled
}
},
"llm_config": llm_config,
"mean": mean,
"max": max,
"scores": scores
"cosine_similarity": {
"mean": mean,
"max": max,
"all_scores": scores
}
}
logs.append(log_entry)
self._write_logs(logs)
@@ -1,6 +1,7 @@
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableConfig
from langchain_huggingface import HuggingFacePipeline
from src.text_generation.common.constants import Constants
from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult
@@ -40,6 +41,9 @@ class TextGenerationCompletionService(
self.example_prompt_injection_completions = prompt_injection_example_service.get_all_completions()
self.example_prompt_injection_prompts = prompt_injection_example_service.get_all_prompts()
self.semantic_similarity_service = semantic_similarity_service
self.semantic_similarity_service.use_comparison_texts(
self.example_prompt_injection_completions
)
# guidelines services
self.chain_of_thought_guidelines: AbstractSecurityGuidelinesService = chain_of_thought_guidelines
@@ -89,32 +93,54 @@ class TextGenerationCompletionService(
Returns:
TextGenerationCompletionResult with appropriate completion text
"""
self.semantic_similarity_service.use_comparison_texts(
self.example_prompt_injection_completions
)
similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze(
text=guidelines_result.completion_text
text = guidelines_result.completion_text
)
processed_guidelines_result = GuidelinesResult(
completion_text=guidelines_result.completion_text,
cosine_similarity_score=similarity_result.mean,
cosine_similarity_risk_threshold=self.COSINE_SIMILARITY_RISK_THRESHOLD
)
guidelines_result.cosine_similarity_risk_threshold = self.COSINE_SIMILARITY_RISK_THRESHOLD
guidelines_result.cosine_similarity_score = similarity_result.mean
completion_result = TextGenerationCompletionResult(
original_completion=guidelines_result.completion_text,
guidelines_result=processed_guidelines_result
llm_config = guidelines_result.llm_config,
original_completion = guidelines_result.completion_text,
guidelines_result = guidelines_result
)
if not processed_guidelines_result.is_original_completion_malicious:
if not guidelines_result.is_original_completion_malicious():
return completion_result
completion_result.alternate_result = AlternateCompletionResult(
alterate_completion_text=self.constants.ALT_COMPLETION_TEXT
alterate_completion_text = self.constants.ALT_COMPLETION_TEXT
)
return completion_result
def _get_active_model_configuration_params(self, hf_pipeline_component):
pipeline_obj = hf_pipeline_component.pipeline
# model defaults
active_params = {}
if hasattr(pipeline_obj.model, 'generation_config'):
gen_config = pipeline_obj.model.generation_config
print(gen_config)
active_params.update({
'temperature': getattr(gen_config, 'temperature', None),
'top_p': getattr(gen_config, 'top_p', None),
'top_k': getattr(gen_config, 'top_k', None),
'max_new_tokens': getattr(gen_config, 'max_new_tokens', None),
'max_length': getattr(gen_config, 'max_length', None),
'repetition_penalty': getattr(gen_config, 'repetition_penalty', None),
'do_sample': getattr(gen_config, 'do_sample', None),
})
# get pipeline-specific override parameters
if hasattr(pipeline_obj, '_forward_params'):
forward_params = pipeline_obj._forward_params
active_params.update(forward_params)
return active_params
# Handler methods for each guidelines combination
def _handle_cot_and_rag(self, user_prompt: str) -> TextGenerationCompletionResult:
"""Handle: CoT=True, RAG=True"""
@@ -135,9 +161,10 @@ class TextGenerationCompletionService(
"""Handle: CoT=False, RAG=False"""
try:
chain = self._create_chain_without_guidelines()
llm_config = self._get_active_model_configuration_params(chain.steps[2])
result = GuidelinesResult(
completion_text=chain.invoke(user_prompt),
llm_config=chain.steps[1].model_dump()
completion_text = chain.invoke(user_prompt),
llm_config = llm_config
)
return self._process_guidelines_result(result)
except Exception as e: