LLM config output

Adam Wilson
2025-07-27 11:21:12 -06:00
parent a7a6873e73
commit eddacd87fa
7 changed files with 616 additions and 254 deletions
@@ -43,6 +43,7 @@ class TextGenerationFoundationModel(AbstractFoundationModel):
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
# TODO: fix Both `max_new_tokens` (=512) and `max_length`(=1024) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
return HuggingFacePipeline(
pipeline=pipe,
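The TODO above refers to the transformers warning emitted when both max_new_tokens and max_length reach generate(). A minimal sketch of the usual fix, assuming the pipeline is built with transformers.pipeline (model and tokenizer are placeholders for objects this class already holds), is to keep a single output-length control:

from transformers import pipeline

# Sketch only: pass max_new_tokens alone so generate() never sees a
# conflicting max_length alongside it.
pipe = pipeline(
    "text-generation",
    model=model,          # assumed: the loaded causal LM
    tokenizer=tokenizer,  # assumed: its matching tokenizer
    max_new_tokens=512,   # do not also set max_length
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
# If an inherited max_length is the culprit, clearing it on the model's
# generation_config may also silence the warning:
# pipe.model.generation_config.max_length = None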
@@ -11,6 +11,7 @@ from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
@@ -22,12 +23,14 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
foundation_model: AbstractFoundationModel,
response_processing_service: AbstractResponseProcessingService,
prompt_template_service: AbstractPromptTemplateService,
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService,
config_builder: Optional[AbstractSecurityGuidelinesConfigurationBuilder] = None):
super().__init__()
self.constants = Constants()
self.foundation_model_pipeline = foundation_model.create_pipeline()
self.response_processing_service = response_processing_service
self.prompt_template_service = prompt_template_service
self.llm_configuration_introspection_service = llm_configuration_introspection_service
self.config_builder = config_builder
def _create_chain(self, prompt_template: PromptTemplate):
@@ -60,6 +63,7 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
return None
def _extract_llm_config(self, llm_step):
if not llm_step:
return {}
@@ -110,6 +114,7 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
print(f'creating chain...')
chain = self._create_chain(prompt_template)
print(f'Chain type: {type(chain)}')
print(f'Number of steps: {len(chain.steps) if hasattr(chain, "steps") else "No steps attribute"}')
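These prints rely on the fact that piping runnables together yields a RunnableSequence that exposes its components via .steps, which is also what the introspection service below walks. A quick self-contained illustration (assumes langchain-core):

from langchain_core.runnables import RunnableLambda

# Two trivial runnables piped together form a RunnableSequence.
chain = RunnableLambda(lambda x: x.upper()) | RunnableLambda(lambda x: x + "!")
print(type(chain).__name__)  # RunnableSequence
print(len(chain.steps))      # 2
print(chain.invoke("hi"))    # HI!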
@@ -120,7 +125,7 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
print(f'generating completion...')
completion_text=chain.invoke({"input": user_prompt})
llm_step = self._find_llm_step(chain)
llm_config = self.llm_configuration_introspection_service.get_config(chain)
result = GuidelinesResult(
completion_text=completion_text,
llm_config=llm_config,
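For context, the llm_config captured here is the nested structure documented in the introspection service below. An illustrative shape for a simple prompt | llm sequence; the _type, _path, steps, and *_source keys come from the service, while the values are invented examples:

llm_config = {
    "_type": "RunnableSequence",
    "_path": "root",
    "steps": {
        "step_1": {
            "_type": "HuggingFacePipeline",
            "_path": "root.steps[1]",
            "temperature": 0.7,  # example value
            "temperature_source": "pipeline._forward_params.temperature",
            "max_new_tokens": 512,  # example value
        },
    },
}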
@@ -5,6 +5,7 @@ from src.text_generation.ports.abstract_foundation_model import AbstractFoundati
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
from src.text_generation.services.guidelines.base_security_guidelines_service import BaseSecurityGuidelinesService
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
class ChainOfThoughtSecurityGuidelinesService(BaseSecurityGuidelinesService):
@@ -15,11 +16,13 @@ class ChainOfThoughtSecurityGuidelinesService(BaseSecurityGuidelinesService):
foundation_model: AbstractFoundationModel,
response_processing_service: AbstractResponseProcessingService,
prompt_template_service: AbstractPromptTemplateService,
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService,
config_builder: Optional[AbstractSecurityGuidelinesConfigurationBuilder] = None):
super().__init__(
foundation_model=foundation_model,
response_processing_service=response_processing_service,
prompt_template_service=prompt_template_service,
llm_configuration_introspection_service=llm_configuration_introspection_service,
config_builder=config_builder
)
@@ -4,6 +4,7 @@ from src.text_generation.ports.abstract_foundation_model import AbstractFoundati
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
from src.text_generation.services.guidelines.base_security_guidelines_service import BaseSecurityGuidelinesService
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
class RagContextSecurityGuidelinesService(BaseSecurityGuidelinesService):
@@ -14,11 +15,13 @@ class RagContextSecurityGuidelinesService(BaseSecurityGuidelinesService):
foundation_model: AbstractFoundationModel,
response_processing_service: AbstractResponseProcessingService,
prompt_template_service: AbstractPromptTemplateService,
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService,
config_builder: AbstractSecurityGuidelinesConfigurationBuilder):
super().__init__(
foundation_model=foundation_model,
response_processing_service=response_processing_service,
prompt_template_service=prompt_template_service,
llm_configuration_introspection_service=llm_configuration_introspection_service,
config_builder=config_builder
)
@@ -4,6 +4,7 @@ from src.text_generation.ports.abstract_foundation_model import AbstractFoundati
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
from src.text_generation.services.guidelines.base_security_guidelines_service import BaseSecurityGuidelinesService
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
class RagPlusCotSecurityGuidelinesService(BaseSecurityGuidelinesService):
@@ -17,11 +18,13 @@ class RagPlusCotSecurityGuidelinesService(BaseSecurityGuidelinesService):
foundation_model: AbstractFoundationModel,
response_processing_service: AbstractResponseProcessingService,
prompt_template_service: AbstractPromptTemplateService,
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService,
config_builder: AbstractSecurityGuidelinesConfigurationBuilder):
super().__init__(
foundation_model=foundation_model,
response_processing_service=response_processing_service,
prompt_template_service=prompt_template_service,
llm_configuration_introspection_service=llm_configuration_introspection_service,
config_builder=config_builder
)
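The abstract port imported by each service above is not part of this diff. Given the call sites (get_config(chain)) and the concrete class that follows, it presumably amounts to this minimal sketch:

from abc import ABC, abstractmethod

class AbstractLLMConfigurationIntrospectionService(ABC):
    """Port for extracting LLM configuration from an LCEL chain (sketch)."""

    @abstractmethod
    def get_config(self, lcel_chain, max_depth=10):
        """Return a nested dict describing the chain and its configuration."""
        raise NotImplementedError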
@@ -6,278 +6,615 @@ from src.text_generation.services.utilities.abstract_llm_configuration_introspec
class LLMConfigurationIntrospectionService(
AbstractLLMConfigurationIntrospectionService):
# llm_configuration_introspection_service
def get_config(self, lcel_chain, max_depth=10):
"""
Comprehensively extract all possible LLM configuration parameters
from a LangChain LCEL chain object, creating a multilayered dict structure
that preserves the chain hierarchy.
Args:
lcel_chain: A LangChain LCEL chain object (Runnable)
max_depth: Maximum recursion depth to prevent infinite loops
Returns:
dict: Nested dictionary with full chain structure and all config parameters
"""
if not lcel_chain or max_depth <= 0:
return {}
def is_serializable(value):
"""Check if a value is JSON serializable."""
return isinstance(value, (str, int, float, bool, type(None), list, tuple, dict))
def safe_serialize(value):
"""Safely serialize a value, converting non-serializable objects to strings."""
if isinstance(value, (str, int, float, bool, type(None))):
return value
elif isinstance(value, (list, tuple)):
return [safe_serialize(item) for item in value]
elif isinstance(value, dict):
return {k: safe_serialize(v) for k, v in value.items() if k != '_type'}
else:
# Convert objects to string representation, but filter out some noise
str_repr = str(value)
if any(noise in str_repr for noise in ['<bound method', '<function', 'object at 0x']):
return f"<{type(value).__name__}>"
return str_repr
def extract_from_object(obj, path="root", visited=None, current_depth=0):
"""
Recursively extract configuration from any object, building a nested structure.
"""
if visited is None:
visited = set()
if current_depth >= max_depth or id(obj) in visited:
return {}
visited.add(id(obj))
result = {"_type": type(obj).__name__, "_path": path}
# === COMPREHENSIVE ATTRIBUTE EXTRACTION ===
# All possible LLM and chain configuration attributes
all_config_attrs = [
# Core generation parameters
'temperature', 'top_p', 'top_k', 'max_tokens', 'max_new_tokens', 'max_length',
'min_length', 'repetition_penalty', 'frequency_penalty', 'presence_penalty',
'length_penalty', 'do_sample', 'early_stopping', 'num_beams', 'num_beam_groups',
'diversity_penalty', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'seed',
'stop', 'stop_sequences', 'suffix', 'logit_bias', 'user', 'n', 'best_of',
'logprobs', 'echo', 'response_format', 'tool_choice', 'parallel_tool_calls',
# Model and API configuration
'model', 'model_name', 'model_id', 'model_path', 'model_type', 'engine',
'deployment_name', 'deployment_id', 'model_version', 'model_revision',
'api_key', 'api_base', 'api_version', 'api_type', 'organization', 'base_url',
'endpoint', 'region', 'project_id', 'project', 'location', 'credentials',
# Provider-specific keys
'openai_api_key', 'openai_organization', 'openai_api_base', 'openai_proxy',
'anthropic_api_key', 'anthropic_api_url', 'max_tokens_to_sample',
'cohere_api_key', 'huggingfacehub_api_token', 'repo_id', 'task',
'google_api_key', 'vertex_ai_model', 'azure_endpoint', 'azure_deployment',
'azure_api_version', 'azure_api_key', 'replicate_api_token',
'together_api_key', 'fireworks_api_key', 'groq_api_key', 'mistral_api_key',
# Request and performance settings
'max_retries', 'request_timeout', 'timeout', 'streaming', 'chunk_size',
'max_concurrent_requests', 'rate_limit', 'batch_size', 'max_batch_size',
'use_cache', 'cache_dir', 'cache_size', 'device', 'device_map', 'torch_dtype',
'load_in_8bit', 'load_in_4bit', 'trust_remote_code', 'revision',
# Token handling
'pad_token_id', 'eos_token_id', 'bos_token_id', 'unk_token_id',
'sep_token_id', 'cls_token_id', 'mask_token_id', 'decoder_start_token_id',
'forced_bos_token_id', 'forced_eos_token_id',
# Chain-specific attributes
'verbose', 'name', 'tags', 'metadata', 'callbacks', 'memory', 'memory_key',
'return_messages', 'input_key', 'output_key', 'prompt', 'llm_chain',
'combine_documents_chain', 'question_generator', 'retriever',
# Pipeline and processing
'return_full_text', 'clean_up_tokenization_spaces', 'truncation', 'padding',
'add_special_tokens', 'handle_long_generation', 'prefix',
# Advanced parameters
'penalty_alpha', 'use_mirostat_sampling', 'mirostat_mode', 'mirostat_tau',
'mirostat_eta', 'tfs', 'top_a', 'k', 'p', 'include_stop_str_in_output',
'ignore_eos', 'skip_special_tokens', 'spaces_between_special_tokens',
]
# === PRIORITY: Extract critical generation parameters first ===
critical_params = ['temperature', 'top_k', 'top_p', 'max_length', 'max_new_tokens',
'max_tokens', 'repetition_penalty', 'do_sample', 'num_beams']
for param in critical_params:
# Check multiple possible locations for each critical parameter
found_value = None
locations_to_check = [
# Direct attribute
(lambda: getattr(obj, param) if hasattr(obj, param) else None, f"direct.{param}"),
# In various config containers
(lambda: getattr(obj, 'model_kwargs', {}).get(param) if hasattr(obj, 'model_kwargs') else None, f"model_kwargs.{param}"),
(lambda: getattr(obj, 'pipeline_kwargs', {}).get(param) if hasattr(obj, 'pipeline_kwargs') else None, f"pipeline_kwargs.{param}"),
(lambda: getattr(obj, 'generation_config', {}).get(param) if hasattr(obj, 'generation_config') else None, f"generation_config.{param}"),
(lambda: getattr(obj, 'kwargs', {}).get(param) if hasattr(obj, 'kwargs') else None, f"kwargs.{param}"),
(lambda: getattr(obj, '_config', {}).get(param) if hasattr(obj, '_config') else None, f"_config.{param}"),
# In nested pipeline object
(lambda: getattr(getattr(obj, 'pipeline', None), param, None) if hasattr(obj, 'pipeline') else None, f"pipeline.{param}"),
(lambda: getattr(getattr(obj, 'pipeline', None), '_preprocess_params', {}).get(param) if hasattr(obj, 'pipeline') else None, f"pipeline._preprocess_params.{param}"),
(lambda: getattr(getattr(obj, 'pipeline', None), '_forward_params', {}).get(param) if hasattr(obj, 'pipeline') else None, f"pipeline._forward_params.{param}"),
(lambda: getattr(getattr(obj, 'pipeline', None), '_postprocess_params', {}).get(param) if hasattr(obj, 'pipeline') else None, f"pipeline._postprocess_params.{param}"),
# In model's generation config
(lambda: getattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config', None).__dict__.get(param) if hasattr(obj, 'pipeline') and hasattr(getattr(obj, 'pipeline', None), 'model') and hasattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config') else None, f"pipeline.model.generation_config.{param}"),
# Try generation_config.to_dict()
(lambda: getattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config', None).to_dict().get(param) if hasattr(obj, 'pipeline') and hasattr(getattr(obj, 'pipeline', None), 'model') and hasattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config') and hasattr(getattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config', None), 'to_dict') else None, f"pipeline.model.generation_config.to_dict().{param}"),
# Check in model config
(lambda: getattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'config', None).__dict__.get(param) if hasattr(obj, 'pipeline') and hasattr(getattr(obj, 'pipeline', None), 'model') and hasattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'config') else None, f"pipeline.model.config.{param}"),
# Check bound parameters
(lambda: getattr(obj, 'bound', {}).get(param) if hasattr(obj, 'bound') else None, f"bound.{param}"),
# Check __dict__ directly
(lambda: obj.__dict__.get(param) if hasattr(obj, '__dict__') else None, f"__dict__.{param}"),
]
for getter, location in locations_to_check:
try:
value = getter()
if value is not None:
found_value = value
result[param] = safe_serialize(value)
result[f"{param}_source"] = location # Track where we found it
break
except Exception:
continue
# If still not found, do a deeper search in __dict__
if found_value is None and hasattr(obj, '__dict__'):
for key, value in obj.__dict__.items():
if param in key.lower() and value is not None:
result[f"{param}_from_{key}"] = safe_serialize(value)
break
# Extract all other attributes
for attr in all_config_attrs:
if attr not in critical_params and hasattr(obj, attr):
try:
value = getattr(obj, attr)
if value is not None:
result[attr] = safe_serialize(value)
except Exception as e:
result[f"{attr}_error"] = str(e)
# === EXTRACT FROM COMMON CONFIG CONTAINERS ===
config_containers = [
'kwargs', 'model_kwargs', 'pipeline_kwargs', 'llm_kwargs', 'generation_config',
'config', '_config', 'params', '_params', 'bound', 'default_params',
'_preprocess_params', '_forward_params', '_postprocess_params'
]
for container_name in config_containers:
if hasattr(obj, container_name):
try:
container = getattr(obj, container_name)
if isinstance(container, dict) and container:
result[container_name] = safe_serialize(container)
except Exception:
pass
# === EXTRACT FROM __DICT__ ===
if hasattr(obj, '__dict__'):
obj_dict = {}
for key, value in obj.__dict__.items():
# Skip private/internal attributes and known non-config items
if (not key.startswith('_') or key in ['_config', '_params']) and \
key not in ['callbacks'] and \
not callable(value):
try:
if is_serializable(value) or isinstance(value, (dict, list)):
obj_dict[key] = safe_serialize(value)
elif hasattr(value, '__dict__') or hasattr(value, 'dict'):
# This might be a nested config object
nested_config = extract_from_object(
value, f"{path}.{key}", visited.copy(), current_depth + 1
)
if nested_config and len(nested_config) > 2: # More than just _type and _path
obj_dict[key] = nested_config
except Exception:
pass
if obj_dict:
result['_attributes'] = obj_dict
# === HANDLE SPECIFIC CHAIN STRUCTURES ===
# Sequential chains (RunnableSequence)
if hasattr(obj, 'steps') and obj.steps:
steps_config = {}
for i, step in enumerate(obj.steps):
step_config = extract_from_object(
step, f"{path}.steps[{i}]", visited.copy(), current_depth + 1
)
if step_config:
steps_config[f"step_{i}"] = step_config
if steps_config:
result['steps'] = steps_config
# Parallel chains (RunnableParallel)
if hasattr(obj, 'mapping') and isinstance(obj.mapping, dict):
mapping_config = {}
for key, component in obj.mapping.items():
comp_config = extract_from_object(
component, f"{path}.mapping[{key}]", visited.copy(), current_depth + 1
)
if comp_config:
mapping_config[key] = comp_config
if mapping_config:
result['mapping'] = mapping_config
# Conditional chains (RunnableBranch)
if hasattr(obj, 'branches') and obj.branches:
branches_config = {}
for i, (condition, branch) in enumerate(obj.branches):
branch_config = extract_from_object(
branch, f"{path}.branches[{i}]", visited.copy(), current_depth + 1
)
if branch_config:
branches_config[f"branch_{i}"] = branch_config
if branches_config:
result['branches'] = branches_config
if hasattr(obj, 'default') and obj.default:
default_config = extract_from_object(
obj.default, f"{path}.default", visited.copy(), current_depth + 1
)
if default_config:
result['default'] = default_config
# Chain components
component_attrs = [
'llm', 'model', 'language_model', 'chat_model', 'completion_model',
'first', 'last', 'middle', 'chain', 'inner_chain', 'base_chain',
'retrieval_chain', 'combine_documents_chain', 'question_generator',
'memory', 'retriever', 'prompt', 'output_parser', 'parser'
]
for comp_attr in component_attrs:
if hasattr(obj, comp_attr):
try:
component = getattr(obj, comp_attr)
if component and not callable(component):
if isinstance(component, list):
comp_configs = {}
for i, item in enumerate(component):
item_config = extract_from_object(
item, f"{path}.{comp_attr}[{i}]", visited.copy(), current_depth + 1
)
if item_config:
comp_configs[f"{comp_attr}_{i}"] = item_config
if comp_configs:
result[comp_attr] = comp_configs
else:
comp_config = extract_from_object(
component, f"{path}.{comp_attr}", visited.copy(), current_depth + 1
)
if comp_config and len(comp_config) > 2:
result[comp_attr] = comp_config
except Exception:
pass
# Try model.dict() or similar serialization methods
for method_name in ['dict', 'model_dump', 'to_dict', 'serialize']:
if hasattr(obj, method_name):
try:
method = getattr(obj, method_name)
if callable(method):
serialized = method()
if isinstance(serialized, dict) and serialized:
result[f'_{method_name}'] = safe_serialize(serialized)
break # Only use the first successful method
except Exception:
pass
return result
# Start extraction from the root chain
return extract_from_object(lcel_chain)
def print_nested_config(self, config, indent=0, max_items_per_level=50):
"""
Pretty print the nested configuration structure.
"""
if not isinstance(config, dict):
print(" " * indent + str(config))
return
items_shown = 0
for key, value in config.items():
if items_shown >= max_items_per_level:
print(" " * indent + f"... ({len(config) - items_shown} more items)")
break
if key.startswith('_') and key not in ['_type', '_path']:
continue # Skip most internal fields in main display
print(" " * indent + f"{key}:")
if isinstance(value, dict):
if key == '_attributes' and indent > 0:
# Flatten attributes for readability
attr_count = 0
for attr_key, attr_val in value.items():
if attr_count >= 10: # Limit attribute display
print(" " * (indent + 1) + f"... ({len(value) - attr_count} more attributes)")
break
if not isinstance(attr_val, dict):
print(" " * (indent + 1) + f"{attr_key}: {attr_val}")
attr_count += 1
else:
self.print_nested_config(value, indent + 1, max_items_per_level)
else:
print(" " * (indent + 1) + str(value))
items_shown += 1
def extract_flattened_config(self, lcel_chain):
"""
Extract and flatten all configuration into a single-level dictionary
with dotted paths showing the source hierarchy.
"""
nested = self.get_config(lcel_chain)
def flatten_dict(d, parent_key='', sep='.'):
items = []
for k, v in d.items():
if k.startswith('_'):
continue # Skip metadata
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
return flatten_dict(nested)
def find_critical_generation_params(self, lcel_chain):
"""
Specifically hunt for the most critical generation parameters that are often missing.
Returns a focused dict with just the essential params and where they were found.
"""
critical_params = {
'temperature': None,
'top_k': None,
'top_p': None,
'max_length': None,
'max_new_tokens': None,
'max_tokens': None,
'repetition_penalty': None,
'do_sample': None
}
def deep_search_for_param(obj, param_name, visited=None, path=""):
if visited is None:
visited = set()
if id(obj) in visited:
return None
visited.add(id(obj))
# All possible locations to check
search_locations = [
# Direct attribute
lambda: getattr(obj, param_name, None),
# In common config dicts
lambda: getattr(obj, 'model_kwargs', {}).get(param_name),
lambda: getattr(obj, 'pipeline_kwargs', {}).get(param_name),
lambda: getattr(obj, 'kwargs', {}).get(param_name),
lambda: getattr(obj, 'generation_config', {}).get(param_name),
lambda: getattr(obj, '_config', {}).get(param_name),
lambda: getattr(obj, 'bound', {}).get(param_name),
# In pipeline
lambda: getattr(getattr(obj, 'pipeline', None), param_name, None),
# In pipeline config dicts
lambda: getattr(getattr(obj, 'pipeline', None), '_preprocess_params', {}).get(param_name) if hasattr(obj, 'pipeline') else None,
lambda: getattr(getattr(obj, 'pipeline', None), '_forward_params', {}).get(param_name) if hasattr(obj, 'pipeline') else None,
lambda: getattr(getattr(obj, 'pipeline', None), '_postprocess_params', {}).get(param_name) if hasattr(obj, 'pipeline') else None,
# In model generation config
lambda: getattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config', None).__dict__.get(param_name) if hasattr(obj, 'pipeline') and hasattr(getattr(obj, 'pipeline', None), 'model') and hasattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config') else None,
]
for search_func in search_locations:
try:
value = search_func()
if value is not None:
return {"value": value, "location": f"{path} -> {search_func.__name__}"}
except Exception:
continue
# Recurse into sub-objects
if hasattr(obj, 'steps'):
for i, step in enumerate(obj.steps):
result = deep_search_for_param(step, param_name, visited.copy(), f"{path}.steps[{i}]")
if result:
return result
if hasattr(obj, 'mapping') and isinstance(obj.mapping, dict):
for key, component in obj.mapping.items():
result = deep_search_for_param(component, param_name, visited.copy(), f"{path}.mapping[{key}]")
if result:
return result
# Check common component attributes
for attr_name in ['llm', 'model', 'pipeline', 'chain']:
if hasattr(obj, attr_name):
component = getattr(obj, attr_name)
if component:
result = deep_search_for_param(component, param_name, visited.copy(), f"{path}.{attr_name}")
if result:
return result
return None
print("=== HUNTING FOR CRITICAL GENERATION PARAMETERS ===")
for param in critical_params:
result = deep_search_for_param(lcel_chain, param)
if result:
critical_params[param] = result
print(f"✓ Found {param}: {result['value']} (from: {result['location']})")
else:
print(f"✗ Missing {param}")
return critical_params
def print_llm_config_debug(self, lcel_chain):
"""
Debug helper that shows both nested and flattened views of the configuration.
"""
print("=== CRITICAL PARAMETERS SEARCH ===")
critical = self.find_critical_generation_params(lcel_chain)
print("\n" + "="*50)
print("=== NESTED LCEL CHAIN CONFIGURATION ===")
nested_config = self.get_config(lcel_chain)
self.print_nested_config(nested_config)
print("\n" + "="*50)
print("=== FLATTENED CONFIGURATION ===")
flattened = self.extract_flattened_config(lcel_chain)
if not flattened:
print("No configuration parameters found")
return nested_config
# Group by category with priority for generation params
categories = {
'🔥 CRITICAL Generation Parameters': [],
'Other Generation Parameters': [],
'Model Configuration': [],
'API Settings': [],
'Chain Structure': [],
'Other': []
}
critical_param_names = ['temperature', 'top_k', 'top_p', 'max_length', 'max_new_tokens', 'max_tokens']
for key, value in flattened.items():
categorized = False
# Check if it's a critical parameter
if any(param in key.lower() for param in critical_param_names):
categories['🔥 CRITICAL Generation Parameters'].append((key, value))
categorized = True
elif any(param in key.lower() for param in ['penalty', 'sample', 'beam', 'length']):
categories['Other Generation Parameters'].append((key, value))
categorized = True
elif any(param in key.lower() for param in ['model', 'engine', 'deployment']):
categories['Model Configuration'].append((key, value))
categorized = True
elif any(param in key.lower() for param in ['api', 'key', 'endpoint', 'url', 'timeout']):
categories['API Settings'].append((key, value))
categorized = True
elif any(param in key.lower() for param in ['step', 'chain', 'mapping', 'branch']):
categories['Chain Structure'].append((key, value))
categorized = True
if not categorized:
categories['Other'].append((key, value))
for category, items in categories.items():
if items:
print(f"\n{category}:")
for key, value in items:
print(f" {key}: {value}")
print(f"\nTotal parameters found: {len(flattened)}")
return nested_config
# Example usage with detailed iteration
def iterate_chain_components(self, lcel_chain):
"""
Example function showing how to iterate through all chain components
and extract configuration from each.
"""
print("=== ITERATING THROUGH CHAIN COMPONENTS ===")
def visit_component(component, path="root", depth=0):
if depth > 5: # Prevent infinite recursion
return
print(" " * depth + f"Visiting: {path} ({type(component).__name__})")
# Extract config from this component
config = {}
# Check for common LLM attributes
llm_attrs = ['temperature', 'top_p', 'model', 'model_id', 'max_tokens', 'api_key']
for attr in llm_attrs:
if hasattr(component, attr):
value = getattr(component, attr)
if value is not None:
config[attr] = value
if config:
print(" " * depth + f" Config found: {config}")
# Recurse into sub-components
if hasattr(component, 'steps'):
for i, step in enumerate(component.steps):
visit_component(step, f"{path}.steps[{i}]", depth + 1)
if hasattr(component, 'mapping') and isinstance(component.mapping, dict):
for key, subcomp in component.mapping.items():
visit_component(subcomp, f"{path}.mapping[{key}]", depth + 1)
if hasattr(component, 'llm') and component.llm:
visit_component(component.llm, f"{path}.llm", depth + 1)
if hasattr(component, 'model') and component.model:
visit_component(component.model, f"{path}.model", depth + 1)
print(f"\nTotal parameters found: {len(config)}")
return config
visit_component(lcel_chain)
# Complete usage example
def example_usage(self):
"""
Complete example showing all extraction methods.
"""
print("=== LANGCHAIN LCEL CONFIG EXTRACTOR USAGE ===\n")
print("1. NESTED STRUCTURE EXTRACTION:")
print(" nested_config = extract_all_llm_config(chain)")
print(" # Returns: Full nested dict preserving chain hierarchy")
print("\n2. FLATTENED EXTRACTION:")
print(" flat_config = extract_flattened_config(chain)")
print(" # Returns: Single-level dict with dotted path keys")
print("\n3. DEBUG OUTPUT:")
print(" print_llm_config_debug(chain)")
print(" # Prints: Both nested and categorized flat views")
print("\n4. COMPONENT ITERATION:")
print(" iterate_chain_components(chain)")
print(" # Shows: Step-by-step traversal of all components")
print("\nExample output structure:")
example_structure = {
"_type": "RunnableSequence",
"steps": {
"step_0": {
"_type": "ChatPromptTemplate",
"template": "You are a helpful assistant"
},
"step_1": {
"_type": "ChatOpenAI",
"model": "gpt-3.5-turbo",
"temperature": 0.7,
"max_tokens": 1000,
"openai_api_key": "sk-...",
"_attributes": {
"streaming": False,
"verbose": False
}
}
}
}
self.print_nested_config(example_structure)
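Taken together, a minimal usage sketch of the service outside the guidelines classes; the chain construction is assumed, while every method name appears in the file above:

service = LLMConfigurationIntrospectionService()
chain = prompt_template | llm  # assumed LCEL chain, e.g. prompt | HuggingFacePipeline

nested = service.get_config(chain)                         # full nested structure
flat = service.extract_flattened_config(chain)             # dotted-path keys
critical = service.find_critical_generation_params(chain)  # focused param hunt
service.print_llm_config_debug(chain)                      # combined debug report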
@@ -124,16 +124,22 @@ def rag_config_builder(
prompt_injection_example_repository=prompt_injection_example_repository
)
@pytest.fixture(scope="session")
def llm_configuration_introspection_service():
return LLMConfigurationIntrospectionService()
@pytest.fixture(scope="session")
def rag_context_guidelines(
foundation_model,
response_processing_service,
prompt_template_service,
llm_configuration_introspection_service,
rag_config_builder):
return RagContextSecurityGuidelinesService(
foundation_model=foundation_model,
response_processing_service=response_processing_service,
prompt_template_service=prompt_template_service,
llm_configuration_introspection_service=llm_configuration_introspection_service,
config_builder=rag_config_builder
)
@@ -141,10 +147,12 @@ def rag_context_guidelines(
def chain_of_thought_guidelines(
foundation_model,
response_processing_service,
llm_configuration_introspection_service,
prompt_template_service):
return ChainOfThoughtSecurityGuidelinesService(
foundation_model=foundation_model,
response_processing_service=response_processing_service,
llm_configuration_introspection_service=llm_configuration_introspection_service,
prompt_template_service=prompt_template_service
)
@@ -153,11 +161,13 @@ def rag_plus_cot_guidelines(
foundation_model,
response_processing_service,
prompt_template_service,
llm_configuration_introspection_service,
rag_config_builder):
return RagPlusCotSecurityGuidelinesService(
foundation_model=foundation_model,
response_processing_service=response_processing_service,
prompt_template_service=prompt_template_service,
llm_configuration_introspection_service=llm_configuration_introspection_service,
config_builder=rag_config_builder
)
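A hypothetical test exercising the new fixture could look like the following; the chain is a stand-in and the assertions mirror what get_config records (a _type entry at the root and a steps sub-dict for sequences):

from langchain_core.runnables import RunnableLambda

def test_get_config_returns_nested_structure(llm_configuration_introspection_service):
    # Stand-in chain; any Runnable can be walked by the introspection service.
    chain = RunnableLambda(lambda x: x) | RunnableLambda(lambda x: x)
    config = llm_configuration_introspection_service.get_config(chain)
    assert config["_type"] == "RunnableSequence"
    assert "steps" in config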