mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-05-26 01:47:55 +02:00
updates for RAG + CoT tests
This commit is contained in:
@@ -15,12 +15,13 @@ class PromptInjectionExampleRepository(AbstractPromptInjectionExampleRepository)
|
||||
def _examples(self) -> List[Dict[str, str]]:
|
||||
"""Get the list of prompt injection examples."""
|
||||
|
||||
base_path = os.environ.get('BASE_PATH', '.')
|
||||
json_filename = 'process_random_requests_and_responses.output.50_count.json'
|
||||
base_path = os.environ.get('INJECTION_DATA_DIR', '.')
|
||||
json_filename = 'process_random_requests_and_responses.output.545_count.json'
|
||||
file_path = os.path.join(base_path, json_filename)
|
||||
result = []
|
||||
|
||||
try:
|
||||
print(f'opening {file_path}')
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
data = json.load(file)
|
||||
|
||||
@@ -33,6 +34,8 @@ class PromptInjectionExampleRepository(AbstractPromptInjectionExampleRepository)
|
||||
print("Error: JSON file should contain an object or array of objects")
|
||||
return []
|
||||
|
||||
print(f'data length: {len(data)}')
|
||||
|
||||
for item in data:
|
||||
if isinstance(item, dict):
|
||||
# Extract prompt and completion keys, ignore others
|
||||
@@ -50,8 +53,7 @@ class PromptInjectionExampleRepository(AbstractPromptInjectionExampleRepository)
|
||||
extracted['prompt_injection_completion'] = item[key]
|
||||
break
|
||||
|
||||
# Only add if we found both prompt and completion
|
||||
if 'prompt' in extracted and 'completion' in extracted:
|
||||
if 'prompt_injection_prompt' in extracted and 'prompt_injection_completion' in extracted:
|
||||
result.append(extracted)
|
||||
|
||||
except FileNotFoundError:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
|
||||
from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult
|
||||
|
||||
@@ -8,11 +8,13 @@ class GuidelinesResult(
|
||||
def __init__(
|
||||
self,
|
||||
completion_text: str,
|
||||
full_prompt: dict[str, Any],
|
||||
llm_config: dict,
|
||||
cosine_similarity_score: float = 0.0,
|
||||
cosine_similarity_risk_threshold: float = 0.0):
|
||||
|
||||
self.completion_text = completion_text
|
||||
self.full_prompt = full_prompt
|
||||
self.llm_config = llm_config
|
||||
self.cosine_similarity_score = cosine_similarity_score
|
||||
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Optional
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import PromptTemplate, StringPromptTemplate
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain.prompts import FewShotPromptTemplate
|
||||
|
||||
from src.text_generation.common.constants import Constants
|
||||
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
|
||||
@@ -11,6 +13,7 @@ from src.text_generation.services.guidelines.abstract_security_guidelines_servic
|
||||
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
|
||||
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
|
||||
|
||||
|
||||
class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
|
||||
"""Base service for security guidelines implementations."""
|
||||
|
||||
@@ -54,9 +57,8 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
|
||||
raise ValueError(f"Parameter 'user_prompt' cannot be empty or None")
|
||||
|
||||
try:
|
||||
prompt_template = self._get_template(user_prompt=user_prompt)
|
||||
|
||||
prompt_value = prompt_template.format_prompt()
|
||||
prompt_template: FewShotPromptTemplate = self._get_template(user_prompt=user_prompt)
|
||||
prompt_value: PromptValue = prompt_template.format_prompt(input=user_prompt)
|
||||
|
||||
# Create a comprehensive dict
|
||||
prompt_dict = {
|
||||
@@ -67,12 +69,11 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
|
||||
"string_representation": prompt_value.to_string(),
|
||||
}
|
||||
|
||||
print(prompt_dict)
|
||||
|
||||
chain = self._create_chain(prompt_template)
|
||||
result = GuidelinesResult(
|
||||
completion_text=chain.invoke(user_prompt),
|
||||
llm_config=chain.steps[1].model_dump()
|
||||
completion_text=chain.invoke({"input": user_prompt}),
|
||||
llm_config=chain.steps[1].model_dump(),
|
||||
full_prompt=prompt_dict
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
+35
-24
@@ -1,9 +1,9 @@
|
||||
from langchain_community.document_loaders import WebBaseLoader
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_core.prompts import StringPromptTemplate
|
||||
from langchain.prompts import FewShotPromptTemplate
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from src.text_generation.adapters.embedding_model import EmbeddingModel
|
||||
from src.text_generation.common.constants import Constants
|
||||
from src.text_generation.ports.abstract_prompt_injection_example_repository import AbstractPromptInjectionExampleRepository
|
||||
@@ -11,25 +11,23 @@ from src.text_generation.ports.abstract_embedding_model import AbstractEmbedding
|
||||
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
|
||||
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
|
||||
|
||||
|
||||
class RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
|
||||
AbstractSecurityGuidelinesConfigurationBuilder):
|
||||
|
||||
AbstractSecurityGuidelinesConfigurationBuilder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_model: AbstractEmbeddingModel,
|
||||
prompt_template_service: AbstractPromptTemplateService,
|
||||
prompt_injection_example_repository: AbstractPromptInjectionExampleRepository):
|
||||
self,
|
||||
embedding_model: AbstractEmbeddingModel,
|
||||
prompt_template_service: AbstractPromptTemplateService,
|
||||
prompt_injection_example_repository: AbstractPromptInjectionExampleRepository):
|
||||
|
||||
self.constants = Constants()
|
||||
self.embedding_model: EmbeddingModel = embedding_model
|
||||
self.prompt_template_service = prompt_template_service
|
||||
self.prompt_injection_example_repository = prompt_injection_example_repository
|
||||
self.prompt_template_id = self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__03_FEW_SHOT_EXAMPLES
|
||||
self.vectorstore = self._setup_vectorstore()
|
||||
|
||||
|
||||
def _setup_vectorstore(self):
|
||||
documents = self._load_examples()
|
||||
|
||||
# Split documents into chunks
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=500,
|
||||
@@ -37,14 +35,12 @@ class RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
|
||||
separators=["\n\n", "\n", ".", ",", " ", ""]
|
||||
)
|
||||
split_docs = text_splitter.split_documents(documents)
|
||||
|
||||
# Create FAISS vector store from chunks
|
||||
return FAISS.from_documents(split_docs, self.embedding_model.embeddings)
|
||||
|
||||
|
||||
def _load_examples(self) -> list[Document]:
|
||||
data = self.prompt_injection_example_repository.get_all()
|
||||
print(f'got {len(data)} prompt injection examples')
|
||||
|
||||
documents = []
|
||||
for item in data:
|
||||
# Create document content combining both fields for better retrieval
|
||||
@@ -57,16 +53,31 @@ class RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
|
||||
}
|
||||
)
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
def _create_few_shot_prompt(self, user_prompt: str) -> FewShotPromptTemplate:
|
||||
prompt_template: FewShotPromptTemplate = self.prompt_template_service.get(id=self.prompt_template_id)
|
||||
prompt_template.examples = self._load_examples()
|
||||
def _get_relevant_examples(self, user_prompt: str, k: int = 3):
|
||||
"""Retrieve the most relevant examples based on the user prompt using RAG"""
|
||||
# Use similarity search to find relevant examples
|
||||
relevant_docs = self.vectorstore.similarity_search(user_prompt, k=k)
|
||||
|
||||
# Convert documents back to the format expected by the few-shot template
|
||||
examples = []
|
||||
for doc in relevant_docs:
|
||||
examples.append({
|
||||
'prompt_injection_prompt': doc.metadata['prompt_injection_prompt'],
|
||||
'prompt_injection_completion': doc.metadata['prompt_injection_completion']
|
||||
})
|
||||
|
||||
return examples
|
||||
|
||||
def _create_few_shot_prompt(self, template_id: str, user_prompt: str) -> FewShotPromptTemplate:
|
||||
prompt_template: FewShotPromptTemplate = self.prompt_template_service.get(id=template_id)
|
||||
prompt_template.examples = self._get_relevant_examples(user_prompt)
|
||||
return prompt_template
|
||||
|
||||
|
||||
def get_prompt_template(self, user_prompt: str):
|
||||
return self._create_few_shot_prompt(user_prompt)
|
||||
|
||||
|
||||
def _get_template(self, template_id: str, user_prompt: str):
|
||||
prompt_template = self._create_few_shot_prompt(template_id=template_id, user_prompt=user_prompt)
|
||||
return prompt_template
|
||||
|
||||
def get_prompt_template(self, template_id: str, user_prompt: str) -> StringPromptTemplate:
|
||||
return self._get_template(template_id=template_id, user_prompt=user_prompt)
|
||||
@@ -23,12 +23,6 @@ class RagContextSecurityGuidelinesService(BaseSecurityGuidelinesService):
|
||||
)
|
||||
|
||||
def _get_template(self, user_prompt: str) -> StringPromptTemplate:
|
||||
"""
|
||||
Get RAG context security guidelines template.
|
||||
|
||||
Returns:
|
||||
StringPromptTemplate: Template configured for RAG processing
|
||||
"""
|
||||
template_id = self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__03_FEW_SHOT_EXAMPLES
|
||||
return self.config_builder.get_prompt_template(
|
||||
template_id=template_id,
|
||||
|
||||
+5
-18
@@ -26,21 +26,8 @@ class RagPlusCotSecurityGuidelinesService(BaseSecurityGuidelinesService):
|
||||
)
|
||||
|
||||
def _get_template(self, user_prompt: str) -> StringPromptTemplate:
|
||||
"""
|
||||
Get RAG context security guidelines template.
|
||||
|
||||
Returns:
|
||||
StringPromptTemplate: Template configured for RAG processing
|
||||
"""
|
||||
return self.prompt_template_service.get(
|
||||
id=self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__04_FEW_SHOT_RAG_PLUS_COT
|
||||
)
|
||||
|
||||
def _get_template_id(self) -> str:
|
||||
"""
|
||||
Get template ID for combined RAG + CoT processing.
|
||||
|
||||
Returns:
|
||||
str: Template ID for RAG + CoT security guidelines
|
||||
"""
|
||||
return self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__04_FEW_SHOT_RAG_PLUS_COT
|
||||
template_id = self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__04_FEW_SHOT_RAG_PLUS_COT
|
||||
return self.config_builder.get_prompt_template(
|
||||
template_id=template_id,
|
||||
user_prompt=user_prompt
|
||||
)
|
||||
+1
-1
@@ -37,7 +37,7 @@ MAX_REQUEST_SAMPLE_COUNT = 1
|
||||
MAX_RESPONSE_SAMPLE_COUNT = 50
|
||||
|
||||
# prompt_injection_sample_file_path = './tests/integration/samples.prompt_injection_true_positive.json'
|
||||
INJECTION_PROMPT_SAMPLE_FILE_PATH = './tests/security/tests/results/01_garak_no_guidelines/process_random_requests.output.150_count.json'
|
||||
INJECTION_PROMPT_SAMPLE_FILE_PATH = './tests/security/tests/results/01_garak_no_guidelines/process_random_requests.output.100_count.json'
|
||||
PROMPT_INJECTION_PROMPTS_AND_RESPONSES_SAMPLE_FILE_PATH = './tests/security/tests/results/01_garak_no_guidelines/failed_entries.json'
|
||||
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ def run_prompt_analysis_test(
|
||||
"""
|
||||
semantic_similarity_service.use_comparison_texts(comparison_texts)
|
||||
results = []
|
||||
print(f'using {len(prompts)} prompts for testing...')
|
||||
|
||||
for i, prompt in enumerate(prompts[:max_prompts], 1):
|
||||
# Configure the service using the provided configurator function
|
||||
|
||||
+3272
File diff suppressed because one or more lines are too long
+3272
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user