From 99ec0ddf9833e1a5ded99d8f8c862870f84ffbf0 Mon Sep 17 00:00:00 2001 From: Adam Wilson Date: Sun, 27 Jul 2025 13:59:01 -0600 Subject: [PATCH] WIP reflexion service --- .../domain/guidelines_result.py | 2 + .../text_generation_completion_result.py | 2 + ...stract_generated_text_guardrail_service.py | 3 +- .../generated_text_guardrail_service.py | 2 +- .../reflexion_security_guidelines_service.py | 100 ++++++++++++++++-- .../base_security_guidelines_service.py | 55 +--------- .../nlp/text_generation_completion_service.py | 28 +---- ...llm_configuration_introspection_service.py | 8 +- 8 files changed, 106 insertions(+), 94 deletions(-) diff --git a/src/text_generation/domain/guidelines_result.py b/src/text_generation/domain/guidelines_result.py index da7034ce5..6a7f1e0c0 100644 --- a/src/text_generation/domain/guidelines_result.py +++ b/src/text_generation/domain/guidelines_result.py @@ -7,12 +7,14 @@ class GuidelinesResult( AbstractGuidelinesProcessedCompletion): def __init__( self, + user_prompt: str, completion_text: str, full_prompt: dict[str, Any], llm_config: dict, cosine_similarity_score: float = 0.0, cosine_similarity_risk_threshold: float = 0.0): + self.user_prompt = user_prompt self.completion_text = completion_text self.full_prompt = full_prompt self.llm_config = llm_config diff --git a/src/text_generation/domain/text_generation_completion_result.py b/src/text_generation/domain/text_generation_completion_result.py index 08081dadd..9ca94be0c 100644 --- a/src/text_generation/domain/text_generation_completion_result.py +++ b/src/text_generation/domain/text_generation_completion_result.py @@ -16,11 +16,13 @@ class TextGenerationCompletionResult( def __init__( self, llm_config: dict, + original_user_prompt: str, original_completion: str, guidelines_result: Optional[GuidelinesResult] = None, guardrails_result: Optional[GuardrailsResult] = None): self.llm_config = llm_config + self.original_user_prompt = original_user_prompt self.original_completion = original_completion self.guidelines_processed_completion = guidelines_result self.guardrails_processed_completion = guardrails_result diff --git a/src/text_generation/services/guardrails/abstract_generated_text_guardrail_service.py b/src/text_generation/services/guardrails/abstract_generated_text_guardrail_service.py index 9fcdea20b..6ddd66d86 100644 --- a/src/text_generation/services/guardrails/abstract_generated_text_guardrail_service.py +++ b/src/text_generation/services/guardrails/abstract_generated_text_guardrail_service.py @@ -1,9 +1,8 @@ import abc -from src.text_generation.domain.abstract_guardrails_processed_completion import AbstractGuardrailsProcessedCompletion from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult class AbstractGeneratedTextGuardrailService(abc.ABC): @abc.abstractmethod - def process_generated_text(self, model_generated_text: AbstractTextGenerationCompletionResult) -> AbstractGuardrailsProcessedCompletion: + def apply_guardrails(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> AbstractTextGenerationCompletionResult: raise NotImplementedError \ No newline at end of file diff --git a/src/text_generation/services/guardrails/generated_text_guardrail_service.py b/src/text_generation/services/guardrails/generated_text_guardrail_service.py index c40b74cd9..66a8645d0 100644 --- a/src/text_generation/services/guardrails/generated_text_guardrail_service.py +++ b/src/text_generation/services/guardrails/generated_text_guardrail_service.py @@ -17,7 +17,7 @@ class GeneratedTextGuardrailService( def use_comparison_texts(self, comparison_texts: list[str]): self.semantic_similarity_service.use_comparison_texts(comparison_texts) - def process_generated_text(self, completion: AbstractTextGenerationCompletionResult) -> AbstractGuardrailsProcessedCompletion: + def apply_guardrails(self, completion: AbstractTextGenerationCompletionResult) -> AbstractGuardrailsProcessedCompletion: score: float = self.semantic_similarity_service.analyze(text=completion) response = GuardrailsResult( cosine_similarity_score=score, diff --git a/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py b/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py index aeec1aa72..21f9cc1d6 100644 --- a/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py +++ b/src/text_generation/services/guardrails/reflexion_security_guidelines_service.py @@ -1,15 +1,101 @@ +from typing import Optional +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate, StringPromptTemplate +from langchain_core.prompt_values import PromptValue +from langchain_core.runnables import RunnablePassthrough +from langchain.prompts import FewShotPromptTemplate + +from src.text_generation.common.constants import Constants +from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion +from src.text_generation.domain.guidelines_result import GuidelinesResult +from src.text_generation.domain.text_generation_completion_result import TextGenerationCompletionResult +from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel +from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService +from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService +from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService +from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService + + from src.text_generation.domain.abstract_guardrails_processed_completion import AbstractGuardrailsProcessedCompletion +from src.text_generation.domain.abstract_text_generation_completion_result import AbstractTextGenerationCompletionResult from src.text_generation.services.guardrails.abstract_generated_text_guardrail_service import AbstractGeneratedTextGuardrailService class ReflexionSecurityGuardrailsService( AbstractGeneratedTextGuardrailService): """Basic implementation of reflexion security guidelines service.""" - - def process_generated_text(self, model_generated_text: str) -> AbstractGuardrailsProcessedCompletion: - """ - Apply basic reflexion security guidelines - """ - - return "" \ No newline at end of file + + def __init__( + self, + foundation_model: AbstractFoundationModel, + response_processing_service: AbstractResponseProcessingService, + prompt_template_service: AbstractPromptTemplateService, + llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService, + config_builder: Optional[AbstractSecurityGuidelinesConfigurationBuilder] = None): + super().__init__() + self.constants = Constants() + self.foundation_model_pipeline = foundation_model.create_pipeline() + self.response_processing_service = response_processing_service + self.prompt_template_service = prompt_template_service + self.llm_configuration_introspection_service = llm_configuration_introspection_service + self.config_builder = config_builder + + + def _create_context_from_rag(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> str: + if not self.vectorstore: + raise ValueError("Vector store not initialized") + + result: TextGenerationCompletionResult = text_generation_completion_result + original_user_prompt = result.original_user_prompt + original_completion = result.original_completion + + # context_docs = self.vectorstore.as_retriever().invoke(?) + # context_parts = [] + # for i, doc in enumerate(context_docs, 1): + # context_parts.append(f"Example {i}:\n{doc.page_content}") + + # return "\n\n".join(context_parts) + + + def _get_template(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> StringPromptTemplate: + prompt_template: StringPromptTemplate = self.prompt_template_service.get(id=self.constants.PromptTemplateIds.PHI_3_MINI_4K_INSTRUCT__05_REFLEXION) + context = self._create_context_from_rag(text_generation_completion_result) + + # Create a new template with the context filled in + filled_template = PromptTemplate( + input_variables=[self.constants.INPUT_VARIABLE_TOKEN], + template=prompt_template.template.replace("{context}", context) + ) + + return filled_template + + + def apply_guardrails(self, text_generation_completion_result: AbstractTextGenerationCompletionResult) -> AbstractTextGenerationCompletionResult: + + if not text_generation_completion_result: + raise ValueError(f"Parameter 'user_prompt' cannot be empty or None") + + try: + prompt_template: StringPromptTemplate = self._get_template(user_prompt=) + prompt_value: PromptValue = prompt_template.format_prompt(input=user_prompt) + prompt_dict = { + "messages": [ + {"role": msg.type, "content": msg.content, "additional_kwargs": msg.additional_kwargs} + for msg in prompt_value.to_messages() + ], + "string_representation": prompt_value.to_string(), + } + + chain = self._create_chain(prompt_template) + completion_text=chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: user_prompt}) + + llm_config = self.llm_configuration_introspection_service.get_config(chain) + result = GuidelinesResult( + completion_text=completion_text, + llm_config=llm_config, + full_prompt=prompt_dict + ) + return result + except Exception as e: + raise e \ No newline at end of file diff --git a/src/text_generation/services/guidelines/base_security_guidelines_service.py b/src/text_generation/services/guidelines/base_security_guidelines_service.py index bba7b5039..e97516d5e 100644 --- a/src/text_generation/services/guidelines/base_security_guidelines_service.py +++ b/src/text_generation/services/guidelines/base_security_guidelines_service.py @@ -55,55 +55,15 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService): """ raise NotImplementedError("Subclasses must implement _get_template()") - def _find_llm_step(self, chain): - if hasattr(chain, 'steps'): - for i, step in enumerate(chain.steps): - if step.__class__.__name__ == 'HuggingFacePipeline': - return step - return None - - def _extract_llm_config(self, llm_step): - - if not llm_step: - return {} - - full_config = llm_step.model_dump() - - serializable_keys = [ - 'batch_size', - 'device', - 'do_sample', - 'temperature', - 'top_p', - 'top_k', - 'max_new_tokens', - 'max_length', - 'repetition_penalty', - 'pad_token_id', - 'eos_token_id', - 'model_id', - 'task', - 'return_full_text' - ] - - config = {} - for key, value in full_config.items(): - if key in serializable_keys and isinstance(value, (str, int, float, bool, type(None))): - config[key] = value - return config - def apply_guidelines(self, user_prompt: str) -> AbstractGuidelinesProcessedCompletion: - print(f'applying guidelines (if any set)') + if not user_prompt: raise ValueError(f"Parameter 'user_prompt' cannot be empty or None") try: prompt_template: StringPromptTemplate = self._get_template(user_prompt=user_prompt) - print(f'got prompt template') prompt_value: PromptValue = prompt_template.format_prompt(input=user_prompt) - - # Create a comprehensive dict prompt_dict = { "messages": [ {"role": msg.type, "content": msg.content, "additional_kwargs": msg.additional_kwargs} @@ -112,21 +72,12 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService): "string_representation": prompt_value.to_string(), } - print(f'creating chain...') chain = self._create_chain(prompt_template) - - print(f'Chain type: {type(chain)}') - print(f'Number of steps: {len(chain.steps) if hasattr(chain, "steps") else "No steps attribute"}') + completion_text=chain.invoke({self.constants.INPUT_VARIABLE_TOKEN: user_prompt}) - # Print each step to see what's at each position - if hasattr(chain, 'steps'): - for i, step in enumerate(chain.steps): - print(f'Step {i}: {type(step)} - {step.__class__.__name__}') - print(f'generating completion...') - completion_text=chain.invoke({"input": user_prompt}) - llm_step = self._find_llm_step(chain) llm_config = self.llm_configuration_introspection_service.get_config(chain) result = GuidelinesResult( + user_prompt=user_prompt, completion_text=completion_text, llm_config=llm_config, full_prompt=prompt_dict diff --git a/src/text_generation/services/nlp/text_generation_completion_service.py b/src/text_generation/services/nlp/text_generation_completion_service.py index 780b4ac91..59d1d0966 100644 --- a/src/text_generation/services/nlp/text_generation_completion_service.py +++ b/src/text_generation/services/nlp/text_generation_completion_service.py @@ -77,7 +77,6 @@ class TextGenerationCompletionService( # introspection for logging self.llm_configuration_introspection_service = llm_configuration_introspection_service - def _process_prompt_with_guidelines_if_applicable(self, user_prompt: str): guidelines_config = ( @@ -111,6 +110,7 @@ class TextGenerationCompletionService( completion_result = TextGenerationCompletionResult( llm_config = guidelines_result.llm_config, original_completion = guidelines_result.completion_text, + original_user_prompt = guidelines_result.user_prompt, guidelines_result = guidelines_result ) @@ -123,30 +123,6 @@ class TextGenerationCompletionService( return completion_result - def _get_active_model_configuration_params(self, hf_pipeline_component): - - pipeline_obj = hf_pipeline_component.pipeline - - # model defaults - active_params = {} - if hasattr(pipeline_obj.model, 'generation_config'): - gen_config = pipeline_obj.model.generation_config - active_params.update({ - 'temperature': getattr(gen_config, 'temperature', None), - 'top_p': getattr(gen_config, 'top_p', None), - 'top_k': getattr(gen_config, 'top_k', None), - 'max_new_tokens': getattr(gen_config, 'max_new_tokens', None), - 'max_length': getattr(gen_config, 'max_length', None), - 'repetition_penalty': getattr(gen_config, 'repetition_penalty', None), - 'do_sample': getattr(gen_config, 'do_sample', None), - }) - - # get pipeline-specific override parameters - if hasattr(pipeline_obj, '_forward_params'): - forward_params = pipeline_obj._forward_params - active_params.update(forward_params) - return active_params - # Handler methods for each guidelines combination def _handle_cot_and_rag(self, user_prompt: str) -> TextGenerationCompletionResult: """Handle: CoT=True, RAG=True""" @@ -167,7 +143,7 @@ class TextGenerationCompletionService( """Handle: CoT=False, RAG=False""" try: chain = self._create_chain_without_guidelines() - llm_config = self._get_active_model_configuration_params(chain.steps[2]) + llm_config = self.llm_configuration_introspection_service.get_config(chain) result = GuidelinesResult( completion_text = chain.invoke(user_prompt), llm_config = llm_config diff --git a/src/text_generation/services/utilities/llm_configuration_introspection_service.py b/src/text_generation/services/utilities/llm_configuration_introspection_service.py index 31285229c..a1185a314 100644 --- a/src/text_generation/services/utilities/llm_configuration_introspection_service.py +++ b/src/text_generation/services/utilities/llm_configuration_introspection_service.py @@ -1,12 +1,8 @@ -import abc - from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService class LLMConfigurationIntrospectionService( AbstractLLMConfigurationIntrospectionService): - # llm_configuration_introspection_service - def get_config(self, lcel_chain, max_depth=10): """ @@ -482,7 +478,7 @@ class LLMConfigurationIntrospectionService( # Group by category with priority for generation params categories = { - '🔥 CRITICAL Generation Parameters': [], + 'CRITICAL Generation Parameters': [], 'Other Generation Parameters': [], 'Model Configuration': [], 'API Settings': [], @@ -497,7 +493,7 @@ class LLMConfigurationIntrospectionService( # Check if it's a critical parameter if any(param in key.lower() for param in critical_param_names): - categories['🔥 CRITICAL Generation Parameters'].append((key, value)) + categories['CRITICAL Generation Parameters'].append((key, value)) categorized = True elif any(param in key.lower() for param in ['penalty', 'sample', 'beam', 'length']): categories['Other Generation Parameters'].append((key, value))