updates for RAG + CoT tests

This commit is contained in:
Adam Wilson
2025-07-25 18:11:49 -06:00
parent 72785c6420
commit 741629908c
10 changed files with 6603 additions and 61 deletions
@@ -15,12 +15,13 @@ class PromptInjectionExampleRepository(AbstractPromptInjectionExampleRepository)
def _examples(self) -> List[Dict[str, str]]:
"""Get the list of prompt injection examples."""
base_path = os.environ.get('BASE_PATH', '.')
json_filename = 'process_random_requests_and_responses.output.50_count.json'
base_path = os.environ.get('INJECTION_DATA_DIR', '.')
json_filename = 'process_random_requests_and_responses.output.545_count.json'
file_path = os.path.join(base_path, json_filename)
result = []
try:
print(f'opening {file_path}')
with open(file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
@@ -33,6 +34,8 @@ class PromptInjectionExampleRepository(AbstractPromptInjectionExampleRepository)
print("Error: JSON file should contain an object or array of objects")
return []
print(f'data length: {len(data)}')
for item in data:
if isinstance(item, dict):
# Extract prompt and completion keys, ignore others
@@ -50,8 +53,7 @@ class PromptInjectionExampleRepository(AbstractPromptInjectionExampleRepository)
extracted['prompt_injection_completion'] = item[key]
break
# Only add if we found both prompt and completion
if 'prompt' in extracted and 'completion' in extracted:
if 'prompt_injection_prompt' in extracted and 'prompt_injection_completion' in extracted:
result.append(extracted)
except FileNotFoundError:
@@ -1,4 +1,4 @@
from typing import List
from typing import Any, List
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult
@@ -8,11 +8,13 @@ class GuidelinesResult(
def __init__(
self,
completion_text: str,
full_prompt: dict[str, Any],
llm_config: dict,
cosine_similarity_score: float = 0.0,
cosine_similarity_risk_threshold: float = 0.0):
self.completion_text = completion_text
self.full_prompt = full_prompt
self.llm_config = llm_config
self.cosine_similarity_score = cosine_similarity_score
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
@@ -1,7 +1,9 @@
from typing import Optional
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, StringPromptTemplate
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import RunnablePassthrough
from langchain.prompts import FewShotPromptTemplate
from src.text_generation.common.constants import Constants
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
@@ -11,6 +13,7 @@ from src.text_generation.services.guidelines.abstract_security_guidelines_servic
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
"""Base service for security guidelines implementations."""
@@ -54,9 +57,8 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
raise ValueError(f"Parameter 'user_prompt' cannot be empty or None")
try:
prompt_template = self._get_template(user_prompt=user_prompt)
prompt_value = prompt_template.format_prompt()
prompt_template: FewShotPromptTemplate = self._get_template(user_prompt=user_prompt)
prompt_value: PromptValue = prompt_template.format_prompt(input=user_prompt)
# Create a comprehensive dict
prompt_dict = {
@@ -67,12 +69,11 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
"string_representation": prompt_value.to_string(),
}
print(prompt_dict)
chain = self._create_chain(prompt_template)
result = GuidelinesResult(
completion_text=chain.invoke(user_prompt),
llm_config=chain.steps[1].model_dump()
completion_text=chain.invoke({"input": user_prompt}),
llm_config=chain.steps[1].model_dump(),
full_prompt=prompt_dict
)
return result
except Exception as e:
@@ -1,9 +1,9 @@
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import StringPromptTemplate
from langchain.prompts import FewShotPromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from src.text_generation.adapters.embedding_model import EmbeddingModel
from src.text_generation.common.constants import Constants
from src.text_generation.ports.abstract_prompt_injection_example_repository import AbstractPromptInjectionExampleRepository
@@ -11,25 +11,23 @@ from src.text_generation.ports.abstract_embedding_model import AbstractEmbedding
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
class RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
AbstractSecurityGuidelinesConfigurationBuilder):
AbstractSecurityGuidelinesConfigurationBuilder):
def __init__(
self,
embedding_model: AbstractEmbeddingModel,
prompt_template_service: AbstractPromptTemplateService,
prompt_injection_example_repository: AbstractPromptInjectionExampleRepository):
self,
embedding_model: AbstractEmbeddingModel,
prompt_template_service: AbstractPromptTemplateService,
prompt_injection_example_repository: AbstractPromptInjectionExampleRepository):
self.constants = Constants()
self.embedding_model: EmbeddingModel = embedding_model
self.prompt_template_service = prompt_template_service
self.prompt_injection_example_repository = prompt_injection_example_repository
self.prompt_template_id = self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__03_FEW_SHOT_EXAMPLES
self.vectorstore = self._setup_vectorstore()
def _setup_vectorstore(self):
documents = self._load_examples()
# Split documents into chunks
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
@@ -37,14 +35,12 @@ class RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
separators=["\n\n", "\n", ".", ",", " ", ""]
)
split_docs = text_splitter.split_documents(documents)
# Create FAISS vector store from chunks
return FAISS.from_documents(split_docs, self.embedding_model.embeddings)
def _load_examples(self) -> list[Document]:
data = self.prompt_injection_example_repository.get_all()
print(f'got {len(data)} prompt injection examples')
documents = []
for item in data:
# Create document content combining both fields for better retrieval
@@ -57,16 +53,31 @@ class RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
}
)
documents.append(doc)
return documents
def _create_few_shot_prompt(self, user_prompt: str) -> FewShotPromptTemplate:
prompt_template: FewShotPromptTemplate = self.prompt_template_service.get(id=self.prompt_template_id)
prompt_template.examples = self._load_examples()
def _get_relevant_examples(self, user_prompt: str, k: int = 3) -> list[dict[str, str]]:
    """Retrieve the few-shot examples most relevant to ``user_prompt``.

    Runs a similarity search against ``self.vectorstore`` and converts each
    hit back into the two-key mapping (``prompt_injection_prompt`` /
    ``prompt_injection_completion``) that is assigned to the few-shot
    template's ``examples``.

    The vector store is built from *chunks* of the example documents, so
    several hits can carry the metadata of one and the same example.
    Repeated (prompt, completion) pairs are dropped, keeping the first
    occurrence in search order; as a result fewer than ``k`` examples may
    be returned.

    Args:
        user_prompt: Query text for the similarity search.
        k: Number of hits requested from the vector store.

    Returns:
        Unique examples, in the order the search returned them.

    Raises:
        KeyError: If a hit's metadata lacks either example key.
    """
    relevant_docs = self.vectorstore.similarity_search(user_prompt, k=k)

    examples: list[dict[str, str]] = []
    seen_pairs: set[tuple[str, str]] = set()
    for doc in relevant_docs:
        injection_prompt = doc.metadata['prompt_injection_prompt']
        injection_completion = doc.metadata['prompt_injection_completion']

        # Chunks of one source document share its metadata; emitting the same
        # pair twice would only repeat an example inside the prompt.
        pair = (injection_prompt, injection_completion)
        if pair in seen_pairs:
            continue
        seen_pairs.add(pair)

        examples.append({
            'prompt_injection_prompt': injection_prompt,
            'prompt_injection_completion': injection_completion,
        })
    return examples
def _create_few_shot_prompt(self, template_id: str, user_prompt: str) -> FewShotPromptTemplate:
    """Build the few-shot template for ``template_id`` with retrieved examples.

    Fetches the template from ``self.prompt_template_service``, then sets its
    ``examples`` attribute to whatever ``self._get_relevant_examples`` returns
    for ``user_prompt``. The object handed back by the service is modified in
    place and returned.

    Args:
        template_id: Identifier passed to the prompt template service.
        user_prompt: Text used to select the examples.

    Returns:
        The template obtained from the service, with ``examples`` replaced.
    """
    few_shot_template: FewShotPromptTemplate = self.prompt_template_service.get(id=template_id)
    retrieved_examples = self._get_relevant_examples(user_prompt)
    few_shot_template.examples = retrieved_examples
    return few_shot_template
def get_prompt_template(self, user_prompt: str):
return self._create_few_shot_prompt(user_prompt)
def _get_template(self, template_id: str, user_prompt: str):
    """Return the few-shot prompt template for ``template_id`` and ``user_prompt``.

    Thin wrapper: forwards both arguments, by keyword, to
    ``self._create_few_shot_prompt`` and returns its result unchanged.
    """
    return self._create_few_shot_prompt(template_id=template_id, user_prompt=user_prompt)
def get_prompt_template(self, template_id: str, user_prompt: str) -> StringPromptTemplate:
    """Public entry point for obtaining the prompt template.

    Delegates to ``self._get_template`` with the same keyword arguments and
    returns what it produces.

    Args:
        template_id: Identifier of the template to build.
        user_prompt: Text the template is being prepared for.
    """
    template = self._get_template(template_id=template_id, user_prompt=user_prompt)
    return template
@@ -23,12 +23,6 @@ class RagContextSecurityGuidelinesService(BaseSecurityGuidelinesService):
)
def _get_template(self, user_prompt: str) -> StringPromptTemplate:
"""
Get RAG context security guidelines template.
Returns:
StringPromptTemplate: Template configured for RAG processing
"""
template_id = self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__03_FEW_SHOT_EXAMPLES
return self.config_builder.get_prompt_template(
template_id=template_id,
@@ -26,21 +26,8 @@ class RagPlusCotSecurityGuidelinesService(BaseSecurityGuidelinesService):
)
def _get_template(self, user_prompt: str) -> StringPromptTemplate:
"""
Get RAG context security guidelines template.
Returns:
StringPromptTemplate: Template configured for RAG processing
"""
return self.prompt_template_service.get(
id=self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__04_FEW_SHOT_RAG_PLUS_COT
)
def _get_template_id(self) -> str:
"""
Get template ID for combined RAG + CoT processing.
Returns:
str: Template ID for RAG + CoT security guidelines
"""
return self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__04_FEW_SHOT_RAG_PLUS_COT
template_id = self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__04_FEW_SHOT_RAG_PLUS_COT
return self.config_builder.get_prompt_template(
template_id=template_id,
user_prompt=user_prompt
)
+1 -1
View File
@@ -37,7 +37,7 @@ MAX_REQUEST_SAMPLE_COUNT = 1
MAX_RESPONSE_SAMPLE_COUNT = 50
# prompt_injection_sample_file_path = './tests/integration/samples.prompt_injection_true_positive.json'
INJECTION_PROMPT_SAMPLE_FILE_PATH = './tests/security/tests/results/01_garak_no_guidelines/process_random_requests.output.150_count.json'
INJECTION_PROMPT_SAMPLE_FILE_PATH = './tests/security/tests/results/01_garak_no_guidelines/process_random_requests.output.100_count.json'
PROMPT_INJECTION_PROMPTS_AND_RESPONSES_SAMPLE_FILE_PATH = './tests/security/tests/results/01_garak_no_guidelines/failed_entries.json'
+1
View File
@@ -33,6 +33,7 @@ def run_prompt_analysis_test(
"""
semantic_similarity_service.use_comparison_texts(comparison_texts)
results = []
print(f'using {len(prompts)} prompts for testing...')
for i, prompt in enumerate(prompts[:max_prompts], 1):
# Configure the service using the provided configurator function