mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-05-22 00:06:58 +02:00
WIP reflexion service
This commit is contained in:
@@ -7,12 +7,14 @@ class GuidelinesResult(
|
||||
AbstractGuidelinesProcessedCompletion):
|
||||
def __init__(
|
||||
self,
|
||||
user_prompt: str,
|
||||
completion_text: str,
|
||||
full_prompt: dict[str, Any],
|
||||
llm_config: dict,
|
||||
cosine_similarity_score: float = 0.0,
|
||||
cosine_similarity_risk_threshold: float = 0.0):
|
||||
|
||||
self.user_prompt = user_prompt
|
||||
self.completion_text = completion_text
|
||||
self.full_prompt = full_prompt
|
||||
self.llm_config = llm_config
|
||||
|
||||
@@ -16,11 +16,13 @@ class TextGenerationCompletionResult(
|
||||
def __init__(
|
||||
self,
|
||||
llm_config: dict,
|
||||
original_user_prompt: str,
|
||||
original_completion: str,
|
||||
guidelines_result: Optional[GuidelinesResult] = None,
|
||||
guardrails_result: Optional[GuardrailsResult] = None):
|
||||
|
||||
self.llm_config = llm_config
|
||||
self.original_user_prompt = original_user_prompt
|
||||
self.original_completion = original_completion
|
||||
self.guidelines_processed_completion = guidelines_result
|
||||
self.guardrails_processed_completion = guardrails_result
|
||||
|
||||
+1
-2
@@ -1,9 +1,8 @@
|
||||
import abc
|
||||
from src.text_generation.domain.abstract_guardrails_processed_completion import AbstractGuardrailsProcessedCompletion
|
||||
from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult
|
||||
|
||||
|
||||
class AbstractGeneratedTextGuardrailService(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def process_generated_text(self, model_generated_text: AbstractTextGenerationCompletionResult) -> AbstractGuardrailsProcessedCompletion:
|
||||
def apply_guardrails(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> AbstractTextGenerationCompletionResult:
|
||||
raise NotImplementedError
|
||||
@@ -17,7 +17,7 @@ class GeneratedTextGuardrailService(
|
||||
def use_comparison_texts(self, comparison_texts: list[str]):
|
||||
self.semantic_similarity_service.use_comparison_texts(comparison_texts)
|
||||
|
||||
def process_generated_text(self, completion: AbstractTextGenerationCompletionResult) -> AbstractGuardrailsProcessedCompletion:
|
||||
def apply_guardrails(self, completion: AbstractTextGenerationCompletionResult) -> AbstractGuardrailsProcessedCompletion:
|
||||
score: float = self.semantic_similarity_service.analyze(text=completion)
|
||||
response = GuardrailsResult(
|
||||
cosine_similarity_score=score,
|
||||
|
||||
@@ -1,15 +1,101 @@
|
||||
from typing import Optional
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import PromptTemplate, StringPromptTemplate
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain.prompts import FewShotPromptTemplate
|
||||
|
||||
from src.text_generation.common.constants import Constants
|
||||
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
|
||||
from src.text_generation.domain.guidelines_result import GuidelinesResult
|
||||
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
|
||||
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
|
||||
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService
|
||||
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
|
||||
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
|
||||
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
|
||||
|
||||
|
||||
|
||||
from src.text_generation.domain.abstract_guardrails_processed_completion import AbstractGuardrailsProcessedCompletion
|
||||
from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult
|
||||
from src.text_generation.services.guardrails.abstract_generated_text_guardrail_service import AbstractGeneratedTextGuardrailService
|
||||
|
||||
|
||||
class ReflexionSecurityGuardrailsService(
|
||||
AbstractGeneratedTextGuardrailService):
|
||||
"""Basic implementation of reflexion security guidelines service."""
|
||||
|
||||
def process_generated_text(self, model_generated_text: str) -> AbstractGuardrailsProcessedCompletion:
|
||||
"""
|
||||
Apply basic reflexion security guidelines
|
||||
"""
|
||||
|
||||
return ""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
foundation_model: AbstractFoundationModel,
|
||||
response_processing_service: AbstractResponseProcessingService,
|
||||
prompt_template_service: AbstractPromptTemplateService,
|
||||
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService,
|
||||
config_builder: Optional[AbstractSecurityGuidelinesConfigurationBuilder] = None):
|
||||
super().__init__()
|
||||
self.constants = Constants()
|
||||
self.foundation_model_pipeline = foundation_model.create_pipeline()
|
||||
self.response_processing_service = response_processing_service
|
||||
self.prompt_template_service = prompt_template_service
|
||||
self.llm_configuration_introspection_service = llm_configuration_introspection_service
|
||||
self.config_builder = config_builder
|
||||
|
||||
|
||||
def _create_context_from_rag(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> str:
|
||||
if not self.vectorstore:
|
||||
raise ValueError("Vector store not initialized")
|
||||
|
||||
result: TextGenerationCompletionResult = text_generation_completion_result
|
||||
original_user_prompt = result.original_user_prompt
|
||||
original_completion = result.original_completion
|
||||
|
||||
# context_docs = self.vectorstore.as_retriever().invoke(?)
|
||||
# context_parts = []
|
||||
# for i, doc in enumerate(context_docs, 1):
|
||||
# context_parts.append(f"Example {i}:\n{doc.page_content}")
|
||||
|
||||
# return "\n\n".join(context_parts)
|
||||
|
||||
|
||||
def _get_template(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> StringPromptTemplate:
|
||||
prompt_template: StringPromptTemplate = self.prompt_template_service.get(id=self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__05_REFLEXION)
|
||||
context = self._create_context_from_rag(text_generation_completion_result)
|
||||
|
||||
# Create a new template with the context filled in
|
||||
filled_template = PromptTemplate(
|
||||
input_variables=[self.constants.INPUT_VARIABLE_TOKEN],
|
||||
template=prompt_template.template.replace("{context}", context)
|
||||
)
|
||||
|
||||
return filled_template
|
||||
|
||||
|
||||
def apply_guardrails(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> AbstractTextGenerationCompletionResult:
|
||||
|
||||
if not text_generation_completion_result:
|
||||
raise ValueError(f"Parameter 'user_prompt' cannot be empty or None")
|
||||
|
||||
try:
|
||||
prompt_template: StringPromptTemplate = self._get_template(user_prompt=)
|
||||
prompt_value: PromptValue = prompt_template.format_prompt(input=user_prompt)
|
||||
prompt_dict = {
|
||||
"messages": [
|
||||
{"role": msg.type, "content": msg.content, "additional_kwargs": msg.additional_kwargs}
|
||||
for msg in prompt_value.to_messages()
|
||||
],
|
||||
"string_representation": prompt_value.to_string(),
|
||||
}
|
||||
|
||||
chain = self._create_chain(prompt_template)
|
||||
completion_text=chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: user_prompt})
|
||||
|
||||
llm_config = self.llm_configuration_introspection_service.get_config(chain)
|
||||
result = GuidelinesResult(
|
||||
completion_text=completion_text,
|
||||
llm_config=llm_config,
|
||||
full_prompt=prompt_dict
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -55,55 +55,15 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement _get_template()")
|
||||
|
||||
def _find_llm_step(self, chain):
|
||||
if hasattr(chain, 'steps'):
|
||||
for i, step in enumerate(chain.steps):
|
||||
if step.__class__.__name__ == 'HuggingFacePipeline':
|
||||
return step
|
||||
return None
|
||||
|
||||
def _extract_llm_config(self, llm_step):
|
||||
|
||||
if not llm_step:
|
||||
return {}
|
||||
|
||||
full_config = llm_step.model_dump()
|
||||
|
||||
serializable_keys = [
|
||||
'batch_size',
|
||||
'device',
|
||||
'do_sample',
|
||||
'temperature',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'max_new_tokens',
|
||||
'max_length',
|
||||
'repetition_penalty',
|
||||
'pad_token_id',
|
||||
'eos_token_id',
|
||||
'model_id',
|
||||
'task',
|
||||
'return_full_text'
|
||||
]
|
||||
|
||||
config = {}
|
||||
for key, value in full_config.items():
|
||||
if key in serializable_keys and isinstance(value, (str, int, float, bool, type(None))):
|
||||
config[key] = value
|
||||
return config
|
||||
|
||||
|
||||
def apply_guidelines(self, user_prompt: str) -> AbstractGuidelinesProcessedCompletion:
|
||||
print(f'applying guidelines (if any set)')
|
||||
|
||||
if not user_prompt:
|
||||
raise ValueError(f"Parameter 'user_prompt' cannot be empty or None")
|
||||
|
||||
try:
|
||||
prompt_template: StringPromptTemplate = self._get_template(user_prompt=user_prompt)
|
||||
print(f'got prompt template')
|
||||
prompt_value: PromptValue = prompt_template.format_prompt(input=user_prompt)
|
||||
|
||||
# Create a comprehensive dict
|
||||
prompt_dict = {
|
||||
"messages": [
|
||||
{"role": msg.type, "content": msg.content, "additional_kwargs": msg.additional_kwargs}
|
||||
@@ -112,21 +72,12 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
|
||||
"string_representation": prompt_value.to_string(),
|
||||
}
|
||||
|
||||
print(f'creating chain...')
|
||||
chain = self._create_chain(prompt_template)
|
||||
|
||||
print(f'Chain type: {type(chain)}')
|
||||
print(f'Number of steps: {len(chain.steps) if hasattr(chain, "steps") else "No steps attribute"}')
|
||||
completion_text=chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: user_prompt})
|
||||
|
||||
# Print each step to see what's at each position
|
||||
if hasattr(chain, 'steps'):
|
||||
for i, step in enumerate(chain.steps):
|
||||
print(f'Step {i}: {type(step)} - {step.__class__.__name__}')
|
||||
print(f'generating completion...')
|
||||
completion_text=chain.invoke({"input": user_prompt})
|
||||
llm_step = self._find_llm_step(chain)
|
||||
llm_config = self.llm_configuration_introspection_service.get_config(chain)
|
||||
result = GuidelinesResult(
|
||||
user_prompt=user_prompt,
|
||||
completion_text=completion_text,
|
||||
llm_config=llm_config,
|
||||
full_prompt=prompt_dict
|
||||
|
||||
@@ -77,7 +77,6 @@ class TextGenerationCompletionService(
|
||||
# introspection for logging
|
||||
self.llm_configuration_introspection_service = llm_configuration_introspection_service
|
||||
|
||||
|
||||
|
||||
def _process_prompt_with_guidelines_if_applicable(self, user_prompt: str):
|
||||
guidelines_config = (
|
||||
@@ -111,6 +110,7 @@ class TextGenerationCompletionService(
|
||||
completion_result = TextGenerationCompletionResult(
|
||||
llm_config = guidelines_result.llm_config,
|
||||
original_completion = guidelines_result.completion_text,
|
||||
original_user_prompt = guidelines_result.user_prompt,
|
||||
guidelines_result = guidelines_result
|
||||
)
|
||||
|
||||
@@ -123,30 +123,6 @@ class TextGenerationCompletionService(
|
||||
return completion_result
|
||||
|
||||
|
||||
def _get_active_model_configuration_params(self, hf_pipeline_component):
|
||||
|
||||
pipeline_obj = hf_pipeline_component.pipeline
|
||||
|
||||
# model defaults
|
||||
active_params = {}
|
||||
if hasattr(pipeline_obj.model, 'generation_config'):
|
||||
gen_config = pipeline_obj.model.generation_config
|
||||
active_params.update({
|
||||
'temperature': getattr(gen_config, 'temperature', None),
|
||||
'top_p': getattr(gen_config, 'top_p', None),
|
||||
'top_k': getattr(gen_config, 'top_k', None),
|
||||
'max_new_tokens': getattr(gen_config, 'max_new_tokens', None),
|
||||
'max_length': getattr(gen_config, 'max_length', None),
|
||||
'repetition_penalty': getattr(gen_config, 'repetition_penalty', None),
|
||||
'do_sample': getattr(gen_config, 'do_sample', None),
|
||||
})
|
||||
|
||||
# get pipeline-specific override parameters
|
||||
if hasattr(pipeline_obj, '_forward_params'):
|
||||
forward_params = pipeline_obj._forward_params
|
||||
active_params.update(forward_params)
|
||||
return active_params
|
||||
|
||||
# Handler methods for each guidelines combination
|
||||
def _handle_cot_and_rag(self, user_prompt: str) -> TextGenerationCompletionResult:
|
||||
"""Handle: CoT=True, RAG=True"""
|
||||
@@ -167,7 +143,7 @@ class TextGenerationCompletionService(
|
||||
"""Handle: CoT=False, RAG=False"""
|
||||
try:
|
||||
chain = self._create_chain_without_guidelines()
|
||||
llm_config = self._get_active_model_configuration_params(chain.steps[2])
|
||||
llm_config = self.llm_configuration_introspection_service.get_config(chain)
|
||||
result = GuidelinesResult(
|
||||
completion_text = chain.invoke(user_prompt),
|
||||
llm_config = llm_config
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import abc
|
||||
|
||||
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
|
||||
|
||||
|
||||
class LLMConfigurationIntrospectionService(
|
||||
AbstractLLMConfigurationIntrospectionService):
|
||||
# llm_configuration_introspection_service
|
||||
|
||||
|
||||
def get_config(self, lcel_chain, max_depth=10):
|
||||
"""
|
||||
@@ -482,7 +478,7 @@ class LLMConfigurationIntrospectionService(
|
||||
|
||||
# Group by category with priority for generation params
|
||||
categories = {
|
||||
'🔥 CRITICAL Generation Parameters': [],
|
||||
'CRITICAL Generation Parameters': [],
|
||||
'Other Generation Parameters': [],
|
||||
'Model Configuration': [],
|
||||
'API Settings': [],
|
||||
@@ -497,7 +493,7 @@ class LLMConfigurationIntrospectionService(
|
||||
|
||||
# Check if it's a critical parameter
|
||||
if any(param in key.lower() for param in critical_param_names):
|
||||
categories['🔥 CRITICAL Generation Parameters'].append((key, value))
|
||||
categories['CRITICAL Generation Parameters'].append((key, value))
|
||||
categorized = True
|
||||
elif any(param in key.lower() for param in ['penalty', 'sample', 'beam', 'length']):
|
||||
categories['Other Generation Parameters'].append((key, value))
|
||||
|
||||
Reference in New Issue
Block a user