WIP reflexion service

This commit is contained in:
Adam Wilson
2025-07-27 13:59:01 -06:00
parent eddacd87fa
commit 99ec0ddf98
8 changed files with 106 additions and 94 deletions
@@ -7,12 +7,14 @@ class GuidelinesResult(
AbstractGuidelinesProcessedCompletion):
def __init__(
self,
user_prompt: str,
completion_text: str,
full_prompt: dict[str, Any],
llm_config: dict,
cosine_similarity_score: float = 0.0,
cosine_similarity_risk_threshold: float = 0.0):
self.user_prompt = user_prompt
self.completion_text = completion_text
self.full_prompt = full_prompt
self.llm_config = llm_config
@@ -16,11 +16,13 @@ class TextGenerationCompletionResult(
def __init__(
self,
llm_config: dict,
original_user_prompt: str,
original_completion: str,
guidelines_result: Optional[GuidelinesResult] = None,
guardrails_result: Optional[GuardrailsResult] = None):
self.llm_config = llm_config
self.original_user_prompt = original_user_prompt
self.original_completion = original_completion
self.guidelines_processed_completion = guidelines_result
self.guardrails_processed_completion = guardrails_result
@@ -1,9 +1,8 @@
import abc
from src.text_generation.domain.abstract_guardrails_processed_completion import AbstractGuardrailsProcessedCompletion
from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult
class AbstractGeneratedTextGuardrailService(abc.ABC):
    """Interface for services that apply guardrails to a text generation result.

    The single abstract entry point is ``apply_guardrails``; the earlier
    ``process_generated_text`` name is no longer part of the interface.
    """

    @abc.abstractmethod
    def apply_guardrails(
            self,
            text_generation_completion_result: AbstractTextGenerationCompletionResult,
    ) -> AbstractTextGenerationCompletionResult:
        """Apply guardrails to a completion result.

        Args:
            text_generation_completion_result: The completion result to
                process.

        Returns:
            A completion result, as declared by the return annotation.

        Raises:
            NotImplementedError: Always, in this base implementation;
                concrete subclasses must override the method.
        """
        raise NotImplementedError
@@ -17,7 +17,7 @@ class GeneratedTextGuardrailService(
def use_comparison_texts(self, comparison_texts: list[str]):
    """Forward the comparison texts to the semantic similarity service.

    Args:
        comparison_texts: Texts passed unchanged to
            ``self.semantic_similarity_service.use_comparison_texts``.
    """
    similarity_service = self.semantic_similarity_service
    similarity_service.use_comparison_texts(comparison_texts)
def process_generated_text(self, completion: AbstractTextGenerationCompletionResult) -> AbstractGuardrailsProcessedCompletion:
def apply_guardrails(self, completion: AbstractTextGenerationCompletionResult) -> AbstractGuardrailsProcessedCompletion:
score: float = self.semantic_similarity_service.analyze(text=completion)
response = GuardrailsResult(
cosine_similarity_score=score,
@@ -1,15 +1,101 @@
from typing import Optional
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, StringPromptTemplate
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import RunnablePassthrough
from langchain.prompts import FewShotPromptTemplate
from src.text_generation.common.constants import Constants
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
from src.text_generation.domain.abstract_guardrails_processed_completion import AbstractGuardrailsProcessedCompletion
from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult
from src.text_generation.services.guardrails.abstract_generated_text_guardrail_service import AbstractGeneratedTextGuardrailService
class ReflexionSecurityGuardrailsService(
        AbstractGeneratedTextGuardrailService):
    """Guardrail service that re-prompts the foundation model with a reflexion template.

    ``apply_guardrails`` fetches the reflexion prompt template, fills its
    ``{context}`` placeholder with documents retrieved from the vector
    store, runs the resulting prompt through the foundation model pipeline
    and returns a new completion result that carries the reflexion output.
    """

    def __init__(
            self,
            foundation_model: AbstractFoundationModel,
            response_processing_service: AbstractResponseProcessingService,
            prompt_template_service: AbstractPromptTemplateService,
            llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService,
            config_builder: Optional[AbstractSecurityGuidelinesConfigurationBuilder] = None,
            vectorstore=None):
        """Wire up the collaborators used to build and run the reflexion chain.

        Args:
            foundation_model: Its ``create_pipeline()`` result is stored and
                used as the model step of the chain.
            response_processing_service: Stored for use by the service.
            prompt_template_service: Source of the reflexion prompt template.
            llm_configuration_introspection_service: Used to read the LLM
                configuration off the chain that produced the completion.
            config_builder: Optional configuration builder; stored as-is.
            vectorstore: Optional vector store queried for RAG context.
                Added with a default so existing call sites keep working;
                without it ``apply_guardrails`` raises ``ValueError``.
        """
        super().__init__()
        self.constants = Constants()
        self.foundation_model_pipeline = foundation_model.create_pipeline()
        self.response_processing_service = response_processing_service
        self.prompt_template_service = prompt_template_service
        self.llm_configuration_introspection_service = llm_configuration_introspection_service
        self.config_builder = config_builder
        # Previously read by _create_context_from_rag but never assigned,
        # which raised AttributeError before any validation could run.
        self.vectorstore = vectorstore

    def process_generated_text(
            self,
            model_generated_text: AbstractTextGenerationCompletionResult,
    ) -> AbstractTextGenerationCompletionResult:
        """Deprecated alias of ``apply_guardrails`` kept for older call sites."""
        return self.apply_guardrails(model_generated_text)

    def _create_context_from_rag(
            self,
            text_generation_completion_result: AbstractTextGenerationCompletionResult,
    ) -> str:
        """Build the context string from documents retrieved for the original prompt.

        Raises:
            ValueError: If no vector store was supplied.
        """
        if self.vectorstore is None:
            raise ValueError("Vector store not initialized")

        result: TextGenerationCompletionResult = text_generation_completion_result
        # NOTE(review): the retriever query was left undecided in the original
        # sketch; the original user prompt is used here - confirm whether the
        # original completion should be part of the query as well.
        context_docs = self.vectorstore.as_retriever().invoke(result.original_user_prompt)
        return "\n\n".join(
            f"Example {i}:\n{doc.page_content}"
            for i, doc in enumerate(context_docs, 1)
        )

    def _get_template(
            self,
            text_generation_completion_result: AbstractTextGenerationCompletionResult,
    ) -> StringPromptTemplate:
        """Return the reflexion prompt template with ``{context}`` filled in."""
        prompt_template: StringPromptTemplate = self.prompt_template_service.get(
            id=self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__05_REFLEXION)
        context = self._create_context_from_rag(text_generation_completion_result)

        # The context is spliced into a template that is formatted again later,
        # so literal braces in retrieved text must be escaped or they would be
        # parsed as template variables.
        escaped_context = context.replace("{", "{{").replace("}", "}}")
        return PromptTemplate(
            input_variables=[self.constants.INPUT_VARIABLE_TOKEN],
            template=prompt_template.template.replace("{context}", escaped_context),
        )

    def _create_chain(self, prompt_template: StringPromptTemplate):
        """Compose prompt -> foundation model pipeline -> string output.

        NOTE(review): assumes ``foundation_model.create_pipeline()`` returns a
        LangChain runnable; ``response_processing_service`` is not applied
        here because its interface is not visible from this class - confirm.
        """
        return prompt_template | self.foundation_model_pipeline | StrOutputParser()

    def apply_guardrails(
            self,
            text_generation_completion_result: AbstractTextGenerationCompletionResult,
    ) -> AbstractTextGenerationCompletionResult:
        """Run the reflexion prompt for a completion result and attach the outcome.

        Args:
            text_generation_completion_result: Result whose
                ``original_user_prompt`` is fed to the reflexion chain.

        Returns:
            A new ``TextGenerationCompletionResult`` that keeps the original
            prompt, completion and LLM configuration and carries the
            reflexion output as its guardrails result.

        Raises:
            ValueError: If the argument is empty/None or no vector store was
                supplied.
        """
        if not text_generation_completion_result:
            raise ValueError("Parameter 'text_generation_completion_result' cannot be empty or None")

        original: TextGenerationCompletionResult = text_generation_completion_result
        user_prompt = original.original_user_prompt
        input_token = self.constants.INPUT_VARIABLE_TOKEN

        prompt_template: StringPromptTemplate = self._get_template(original)
        # Use the same variable name for formatting and invocation so both
        # stay in step with Constants.INPUT_VARIABLE_TOKEN.
        prompt_value: PromptValue = prompt_template.format_prompt(**{input_token: user_prompt})
        prompt_dict = {
            "messages": [
                {"role": msg.type, "content": msg.content, "additional_kwargs": msg.additional_kwargs}
                for msg in prompt_value.to_messages()
            ],
            "string_representation": prompt_value.to_string(),
        }

        chain = self._create_chain(prompt_template)
        completion_text = chain.invoke({input_token: user_prompt})
        llm_config = self.llm_configuration_introspection_service.get_config(chain)

        # NOTE(review): GuidelinesResult is the carrier the original code
        # built; confirm whether a GuardrailsResult should be used instead.
        reflexion_result = GuidelinesResult(
            user_prompt=user_prompt,
            completion_text=completion_text,
            llm_config=llm_config,
            full_prompt=prompt_dict,
        )
        return TextGenerationCompletionResult(
            llm_config=original.llm_config,
            original_user_prompt=user_prompt,
            original_completion=original.original_completion,
            guidelines_result=getattr(original, "guidelines_processed_completion", None),
            guardrails_result=reflexion_result,
        )
@@ -55,55 +55,15 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
"""
raise NotImplementedError("Subclasses must implement _get_template()")
def _find_llm_step(self, chain):
if hasattr(chain, 'steps'):
for i, step in enumerate(chain.steps):
if step.__class__.__name__ == 'HuggingFacePipeline':
return step
return None
def _extract_llm_config(self, llm_step):
if not llm_step:
return {}
full_config = llm_step.model_dump()
serializable_keys = [
'batch_size',
'device',
'do_sample',
'temperature',
'top_p',
'top_k',
'max_new_tokens',
'max_length',
'repetition_penalty',
'pad_token_id',
'eos_token_id',
'model_id',
'task',
'return_full_text'
]
config = {}
for key, value in full_config.items():
if key in serializable_keys and isinstance(value, (str, int, float, bool, type(None))):
config[key] = value
return config
def apply_guidelines(self, user_prompt: str) -> AbstractGuidelinesProcessedCompletion:
print(f'applying guidelines (if any set)')
if not user_prompt:
raise ValueError(f"Parameter 'user_prompt' cannot be empty or None")
try:
prompt_template: StringPromptTemplate = self._get_template(user_prompt=user_prompt)
print(f'got prompt template')
prompt_value: PromptValue = prompt_template.format_prompt(input=user_prompt)
# Create a comprehensive dict
prompt_dict = {
"messages": [
{"role": msg.type, "content": msg.content, "additional_kwargs": msg.additional_kwargs}
@@ -112,21 +72,12 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
"string_representation": prompt_value.to_string(),
}
print(f'creating chain...')
chain = self._create_chain(prompt_template)
print(f'Chain type: {type(chain)}')
print(f'Number of steps: {len(chain.steps) if hasattr(chain, "steps") else "No steps attribute"}')
completion_text=chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: user_prompt})
# Print each step to see what's at each position
if hasattr(chain, 'steps'):
for i, step in enumerate(chain.steps):
print(f'Step {i}: {type(step)} - {step.__class__.__name__}')
print(f'generating completion...')
completion_text=chain.invoke({"input": user_prompt})
llm_step = self._find_llm_step(chain)
llm_config = self.llm_configuration_introspection_service.get_config(chain)
result = GuidelinesResult(
user_prompt=user_prompt,
completion_text=completion_text,
llm_config=llm_config,
full_prompt=prompt_dict
@@ -77,7 +77,6 @@ class TextGenerationCompletionService(
# introspection for logging
self.llm_configuration_introspection_service = llm_configuration_introspection_service
def _process_prompt_with_guidelines_if_applicable(self, user_prompt: str):
guidelines_config = (
@@ -111,6 +110,7 @@ class TextGenerationCompletionService(
completion_result = TextGenerationCompletionResult(
llm_config = guidelines_result.llm_config,
original_completion = guidelines_result.completion_text,
original_user_prompt = guidelines_result.user_prompt,
guidelines_result = guidelines_result
)
@@ -123,30 +123,6 @@ class TextGenerationCompletionService(
return completion_result
def _get_active_model_configuration_params(self, hf_pipeline_component):
pipeline_obj = hf_pipeline_component.pipeline
# model defaults
active_params = {}
if hasattr(pipeline_obj.model, 'generation_config'):
gen_config = pipeline_obj.model.generation_config
active_params.update({
'temperature': getattr(gen_config, 'temperature', None),
'top_p': getattr(gen_config, 'top_p', None),
'top_k': getattr(gen_config, 'top_k', None),
'max_new_tokens': getattr(gen_config, 'max_new_tokens', None),
'max_length': getattr(gen_config, 'max_length', None),
'repetition_penalty': getattr(gen_config, 'repetition_penalty', None),
'do_sample': getattr(gen_config, 'do_sample', None),
})
# get pipeline-specific override parameters
if hasattr(pipeline_obj, '_forward_params'):
forward_params = pipeline_obj._forward_params
active_params.update(forward_params)
return active_params
# Handler methods for each guidelines combination
def _handle_cot_and_rag(self, user_prompt: str) -> TextGenerationCompletionResult:
"""Handle: CoT=True, RAG=True"""
@@ -167,7 +143,7 @@ class TextGenerationCompletionService(
"""Handle: CoT=False, RAG=False"""
try:
chain = self._create_chain_without_guidelines()
llm_config = self._get_active_model_configuration_params(chain.steps[2])
llm_config = self.llm_configuration_introspection_service.get_config(chain)
result = GuidelinesResult(
completion_text = chain.invoke(user_prompt),
llm_config = llm_config
@@ -1,12 +1,8 @@
import abc
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
class LLMConfigurationIntrospectionService(
AbstractLLMConfigurationIntrospectionService):
# llm_configuration_introspection_service
def get_config(self, lcel_chain, max_depth=10):
"""
@@ -482,7 +478,7 @@ class LLMConfigurationIntrospectionService(
# Group by category with priority for generation params
categories = {
'🔥 CRITICAL Generation Parameters': [],
'CRITICAL Generation Parameters': [],
'Other Generation Parameters': [],
'Model Configuration': [],
'API Settings': [],
@@ -497,7 +493,7 @@ class LLMConfigurationIntrospectionService(
# Check if it's a critical parameter
if any(param in key.lower() for param in critical_param_names):
categories['🔥 CRITICAL Generation Parameters'].append((key, value))
categories['CRITICAL Generation Parameters'].append((key, value))
categorized = True
elif any(param in key.lower() for param in ['penalty', 'sample', 'beam', 'length']):
categories['Other Generation Parameters'].append((key, value))