Log LLM config in test run results

This commit is contained in:
Adam Wilson
2025-07-23 20:21:42 -06:00
parent cb92890bb9
commit ae279a512d
10 changed files with 31 additions and 21 deletions
@@ -1,3 +1,4 @@
from typing import List
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult
@@ -6,12 +7,12 @@ class GuidelinesResult(
AbstractGuidelinesProcessedCompletion):
def __init__(
self,
completion_text: str,
llm_config: dict,
cosine_similarity_score: float,
cosine_similarity_risk_threshold: float,
original_completion: str):
cosine_similarity_risk_threshold: float):
self.score = cosine_similarity_score
self.original_completion = original_completion
self.is_original_completion_malicious = (
cosine_similarity_score >= cosine_similarity_risk_threshold
)
self.completion_text = completion_text
self.llm_config = llm_config
self.cosine_similarity_score = cosine_similarity_score
self.cosine_similarity_risk_threshold = cosine_similarity_risk_threshold
@@ -1,10 +1,11 @@
from typing import Optional
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, StringPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnablePassthrough, RunnableConfig
from src.text_generation.common.constants import Constants
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
@@ -51,6 +52,10 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
try:
prompt_template = self._get_template()
chain = self._create_chain(prompt_template)
return chain.invoke(user_prompt)
result = GuidelinesResult(
completion_text=chain.invoke(user_prompt),
llm_config=chain.steps[1].model_dump()
)
return result
except Exception as e:
raise e
@@ -41,6 +41,7 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
is_rag_few_shot_enabled: bool,
is_cot_enabled: bool,
is_reflexion_enabled: bool,
llm_config: dict,
scores: List[float],
mean: float,
max: float):
@@ -54,6 +55,7 @@ class TestRunLoggingService(AbstractTestRunLoggingService):
"is_rag_few_shot_enabled": is_rag_few_shot_enabled,
"is_cot_enabled": is_cot_enabled,
"is_reflexion_enabled": is_reflexion_enabled,
"llm_config": llm_config,
"mean": mean,
"max": max,
"scores": scores
@@ -1,6 +1,6 @@
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnablePassthrough, RunnableConfig
from src.text_generation.common.constants import Constants
from src.text_generation.domain.alternate_completion_result import AlternateCompletionResult
@@ -93,17 +93,17 @@ class TextGenerationCompletionService(
self.example_prompt_injection_completions
)
similarity_result: SemanticSimilarityResult = self.semantic_similarity_service.analyze(
text=guidelines_result.original_completion
text=guidelines_result.completion_text
)
processed_guidelines_result = GuidelinesResult(
original_completion=guidelines_result.original_completion,
completion_text=guidelines_result.completion_text,
cosine_similarity_score=similarity_result.mean,
cosine_similarity_risk_threshold=self.COSINE_SIMILARITY_RISK_THRESHOLD
)
completion_result = TextGenerationCompletionResult(
original_completion=guidelines_result.original_completion,
original_completion=guidelines_result.completion_text,
guidelines_result=processed_guidelines_result
)
@@ -135,10 +135,11 @@ class TextGenerationCompletionService(
"""Handle: CoT=False, RAG=False"""
try:
chain = self._create_chain_without_guidelines()
completion = chain.invoke(user_prompt)
return TextGenerationCompletionResult(
original_completion=completion
result = GuidelinesResult(
completion_text=chain.invoke(user_prompt),
llm_config=chain.steps[1].model_dump()
)
return self._process_guidelines_result(result)
except Exception as e:
raise e
+1 -1
View File
@@ -33,7 +33,7 @@ from src.text_generation.services.prompt_injection.prompt_injection_example_serv
from src.text_generation.services.utilities.response_processing_service import ResponseProcessingService
MAX_REQUEST_SAMPLE_COUNT = 100
MAX_REQUEST_SAMPLE_COUNT = 1
MAX_RESPONSE_SAMPLE_COUNT = 50
def pytest_deselected(items):
@@ -1,6 +1,6 @@
"""
Usage:
$ pytest -k test_00_benign_prompts_no_mitigation -s
$ pytest -k test_01_malicious_prompts_no_mitigation -s
Data used:
1. Samples of compromised output (resulting from NVIDIA garak tests)
@@ -1,6 +1,6 @@
"""
Usage:
$ pytest -k test_00_benign_prompts_no_mitigation -s
$ pytest -k test_02_malicious_prompts_cot -s
Data used:
1. Samples of compromised output (resulting from NVIDIA garak tests)
@@ -1,6 +1,6 @@
"""
Usage:
$ pytest -k test_00_benign_prompts_no_mitigation -s
$ pytest -k test_03_malicious_prompts_rag -s
Data used:
1. Samples of compromised output (resulting from NVIDIA garak tests)
@@ -1,6 +1,6 @@
"""
Usage:
$ pytest -k test_00_benign_prompts_no_mitigation -s
$ pytest -k test_04_malicious_prompts_rag_and_cot -s
Data used:
1. Samples of compromised output (resulting from NVIDIA garak tests)
+1
View File
@@ -50,6 +50,7 @@ def run_prompt_analysis_test(
is_rag_few_shot_enabled=text_generation_completion_service.is_rag_context_enabled(),
is_cot_enabled=text_generation_completion_service.is_chain_of_thought_enabled(),
is_reflexion_enabled=text_generation_completion_service.is_reflexion_enabled(),
llm_config=completion_result.guidelines_processed_completion.llm_config,
scores=result.scores,
mean=result.mean,
max=result.max