diff --git a/src/text_generation/domain/guidelines_result.py b/src/text_generation/domain/guidelines_result.py index 9c1a9d121..9beac9735 100644 --- a/src/text_generation/domain/guidelines_result.py +++ b/src/text_generation/domain/guidelines_result.py @@ -1,3 +1,4 @@ +from typing import List from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult @@ -6,12 +7,12 @@ class GuidelinesResult( AbstractGuidelinesProcessedCompletion): def __init__( self, + completion_text: str, + llm_config: dict, cosine_similarity_score: float, - cosine_similarity_risk_threshold: float, - original_completion: str): + cosine_similarity_risk_threshold: float): - self.score = cosine_similarity_score - self.original_completion = original_completion - self.is_original_completion_malicious = ( - cosine_similarity_score >= cosine_similarity_risk_threshold - ) \ No newline at end of file + self.completion_text = completion_text + self.llm_config = llm_config + self.cosine_similarity_score = cosine_similarity_score + self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold \ No newline at end of file diff --git a/src/text_generation/services/guidelines/base_security_guidelines_service.py b/src/text_generation/services/guidelines/base_security_guidelines_service.py index afd0bdf5a..217dc7fca 100644 --- a/src/text_generation/services/guidelines/base_security_guidelines_service.py +++ b/src/text_generation/services/guidelines/base_security_guidelines_service.py @@ -1,10 +1,11 @@ from typing import Optional from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import PromptTemplate, StringPromptTemplate -from langchain_core.runnables import RunnablePassthrough +from langchain_core.runnables import RunnablePassthrough, RunnableConfig from src.text_generation.common.constants import Constants from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion +from src.text_generation.domain.guidelines_result import GuidelinesResult from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService @@ -51,6 +52,10 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService): try: prompt_template = self._get_template() chain = self._create_chain(prompt_template) - return chain.invoke(user_prompt) + result = GuidelinesResult( + completion_text=chain.invoke(user_prompt), + llm_config=chain.steps[1].model_dump() + ) + return result except Exception as e: raise e \ No newline at end of file diff --git a/src/text_generation/services/logging/test_run_logging_service.py b/src/text_generation/services/logging/test_run_logging_service.py index 8426a17f8..c153c3b8c 100644 --- a/src/text_generation/services/logging/test_run_logging_service.py +++ b/src/text_generation/services/logging/test_run_logging_service.py @@ -41,6 +41,7 @@ class TestRunLoggingService(AbstractTestRunLoggingService): is_rag_few_shot_enabled: bool, is_cot_enabled: bool, is_reflexion_enabled: bool, + llm_config: dict, scores: List[float], mean: float, max: float): @@ -54,6 +55,7 @@ class TestRunLoggingService(AbstractTestRunLoggingService): "is_rag_few_shot_enabled": is_rag_few_shot_enabled, "is_cot_enabled": is_cot_enabled, "is_reflexion_enabled": is_reflexion_enabled, + "llm_config": llm_config, "mean": mean, "max": max, "scores": scores diff --git a/src/text_generation/services/nlp/text_generation_completion_service.py b/src/text_generation/services/nlp/text_generation_completion_service.py index 833ab016f..95f1618d6 100644 --- a/src/text_generation/services/nlp/text_generation_completion_service.py +++ b/src/text_generation/services/nlp/text_generation_completion_service.py @@ -1,6 +1,6 @@ from langchain.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables import RunnablePassthrough +from langchain_core.runnables import RunnablePassthrough, RunnableConfig from src.text_generation.common.constants import Constants from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult @@ -93,17 +93,17 @@ class TextGenerationCompletionService( self.example_prompt_injection_completions ) similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze( - text=guidelines_result.original_completion + text=guidelines_result.completion_text ) processed_guidelines_result = GuidelinesResult( - original_completion=guidelines_result.original_completion, + completion_text=guidelines_result.completion_text, cosine_similarity_score=similarity_result.mean, cosine_similarity_risk_threshold=self.COSINE_SIMILARITY_RISK_THRESHOLD ) completion_result = TextGenerationCompletionResult( - original_completion=guidelines_result.original_completion, + original_completion=guidelines_result.completion_text, guidelines_result=processed_guidelines_result ) @@ -135,10 +135,11 @@ class TextGenerationCompletionService( """Handle: CoT=False, RAG=False""" try: chain = self._create_chain_without_guidelines() - completion = chain.invoke(user_prompt) - return TextGenerationCompletionResult( - original_completion=completion + result = GuidelinesResult( + completion_text=chain.invoke(user_prompt), + llm_config=chain.steps[1].model_dump() ) + return self._process_guidelines_result(result) except Exception as e: raise e diff --git a/tests/conftest.py b/tests/conftest.py index 619396b1b..e717e4b26 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,7 +33,7 @@ from src.text_generation.services.prompt_injection.prompt_injection_example_serv from src.text_generation.services.utilities.response_processing_service import ResponseProcessingService -MAX_REQUEST_SAMPLE_COUNT = 100 +MAX_REQUEST_SAMPLE_COUNT = 1 MAX_RESPONSE_SAMPLE_COUNT = 50 def pytest_deselected(items): diff --git a/tests/integration/test_01_malicious_prompts_no_mitigation.py b/tests/integration/test_01_malicious_prompts_no_mitigation.py index bf1be4d5e..ac2138166 100644 --- a/tests/integration/test_01_malicious_prompts_no_mitigation.py +++ b/tests/integration/test_01_malicious_prompts_no_mitigation.py @@ -1,6 +1,6 @@ """ Usage: - $ pytest -k test_00_benign_prompts_no_mitigation -s + $ pytest -k test_01_malicious_prompts_no_mitigation -s Data used: 1. Samples of compromised output (resulting from NVIDIA garak tests) diff --git a/tests/integration/test_02_malicious_prompts_cot.py b/tests/integration/test_02_malicious_prompts_cot.py index 1dac6d2b3..15ab37383 100644 --- a/tests/integration/test_02_malicious_prompts_cot.py +++ b/tests/integration/test_02_malicious_prompts_cot.py @@ -1,6 +1,6 @@ """ Usage: - $ pytest -k test_00_benign_prompts_no_mitigation -s + $ pytest -k test_02_malicious_prompts_cot -s Data used: 1. Samples of compromised output (resulting from NVIDIA garak tests) diff --git a/tests/integration/test_03_malicious_prompts_rag.py b/tests/integration/test_03_malicious_prompts_rag.py index a6c583bc1..18d9d4e6b 100644 --- a/tests/integration/test_03_malicious_prompts_rag.py +++ b/tests/integration/test_03_malicious_prompts_rag.py @@ -1,6 +1,6 @@ """ Usage: - $ pytest -k test_00_benign_prompts_no_mitigation -s + $ pytest -k test_03_malicious_prompts_rag -s Data used: 1. Samples of compromised output (resulting from NVIDIA garak tests) diff --git a/tests/integration/test_04_malicious_prompts_rag_and_cot.py b/tests/integration/test_04_malicious_prompts_rag_and_cot.py index 90a8f4916..60c9d118e 100644 --- a/tests/integration/test_04_malicious_prompts_rag_and_cot.py +++ b/tests/integration/test_04_malicious_prompts_rag_and_cot.py @@ -1,6 +1,6 @@ """ Usage: - $ pytest -k test_00_benign_prompts_no_mitigation -s + $ pytest -k test_04_malicious_prompts_rag_and_cot -s Data used: 1. Samples of compromised output (resulting from NVIDIA garak tests) diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py index 7188cd99c..0ef31f8a5 100644 --- a/tests/integration/test_utils.py +++ b/tests/integration/test_utils.py @@ -50,6 +50,7 @@ def run_prompt_analysis_test( is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(), is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(), is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(), + llm_config=completion_result.guidelines_processed_completion.llm_config, scores=result.scores, mean=result.mean, max=result.max