mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-05-14 21:28:04 +02:00
LLM config output
This commit is contained in:
@@ -43,6 +43,7 @@ class TextGenerationFoundationModel(AbstractFoundationModel):
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
eos_token_id=self.tokenizer.eos_token_id
|
||||
)
|
||||
# TODO: resolve the transformers warning "Both `max_new_tokens` (=512) and `max_length`(=1024) seem to have been set. `max_new_tokens` will take precedence." by configuring only one of the two limits. See the documentation for more information: https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
|
||||
|
||||
return HuggingFacePipeline(
|
||||
pipeline=pipe,
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.text_generation.domain.guidelines_result import GuidelinesResult
|
||||
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
|
||||
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder, AbstractSecurityGuidelinesService
|
||||
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
|
||||
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
|
||||
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
|
||||
|
||||
|
||||
@@ -22,12 +23,14 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
|
||||
foundation_model: AbstractFoundationModel,
|
||||
response_processing_service: AbstractResponseProcessingService,
|
||||
prompt_template_service: AbstractPromptTemplateService,
|
||||
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService,
|
||||
config_builder: Optional[AbstractSecurityGuidelinesConfigurationBuilder] = None):
|
||||
super().__init__()
|
||||
self.constants = Constants()
|
||||
self.foundation_model_pipeline = foundation_model.create_pipeline()
|
||||
self.response_processing_service = response_processing_service
|
||||
self.prompt_template_service = prompt_template_service
|
||||
self.llm_configuration_introspection_service = llm_configuration_introspection_service
|
||||
self.config_builder = config_builder
|
||||
|
||||
def _create_chain(self, prompt_template: PromptTemplate):
|
||||
@@ -60,6 +63,7 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
|
||||
return None
|
||||
|
||||
def _extract_llm_config(self, llm_step):
|
||||
|
||||
if not llm_step:
|
||||
return {}
|
||||
|
||||
@@ -110,6 +114,7 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
|
||||
|
||||
print(f'creating chain...')
|
||||
chain = self._create_chain(prompt_template)
|
||||
|
||||
print(f'Chain type: {type(chain)}')
|
||||
print(f'Number of steps: {len(chain.steps) if hasattr(chain, "steps") else "No steps attribute"}')
|
||||
|
||||
@@ -120,7 +125,7 @@ class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
|
||||
print(f'generating completion...')
|
||||
completion_text=chain.invoke({"input": user_prompt})
|
||||
llm_step = self._find_llm_step(chain)
|
||||
llm_config = self._extract_llm_config(llm_step)
|
||||
llm_config = self.llm_configuration_introspection_service.get_config(chain)
|
||||
result = GuidelinesResult(
|
||||
completion_text=completion_text,
|
||||
llm_config=llm_config,
|
||||
|
||||
+3
@@ -5,6 +5,7 @@ from src.text_generation.ports.abstract_foundation_model import AbstractFoundati
|
||||
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
|
||||
from src.text_generation.services.guidelines.base_security_guidelines_service import BaseSecurityGuidelinesService
|
||||
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
|
||||
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
|
||||
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
|
||||
|
||||
class ChainOfThoughtSecurityGuidelinesService(BaseSecurityGuidelinesService):
|
||||
@@ -15,11 +16,13 @@ class ChainOfThoughtSecurityGuidelinesService(BaseSecurityGuidelinesService):
|
||||
foundation_model: AbstractFoundationModel,
|
||||
response_processing_service: AbstractResponseProcessingService,
|
||||
prompt_template_service: AbstractPromptTemplateService,
|
||||
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService,
|
||||
config_builder: Optional[AbstractSecurityGuidelinesConfigurationBuilder] = None):
|
||||
super().__init__(
|
||||
foundation_model=foundation_model,
|
||||
response_processing_service=response_processing_service,
|
||||
prompt_template_service=prompt_template_service,
|
||||
llm_configuration_introspection_service=llm_configuration_introspection_service,
|
||||
config_builder=config_builder
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from src.text_generation.ports.abstract_foundation_model import AbstractFoundati
|
||||
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
|
||||
from src.text_generation.services.guidelines.base_security_guidelines_service import BaseSecurityGuidelinesService
|
||||
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
|
||||
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
|
||||
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
|
||||
|
||||
class RagContextSecurityGuidelinesService(BaseSecurityGuidelinesService):
|
||||
@@ -14,11 +15,13 @@ class RagContextSecurityGuidelinesService(BaseSecurityGuidelinesService):
|
||||
foundation_model: AbstractFoundationModel,
|
||||
response_processing_service: AbstractResponseProcessingService,
|
||||
prompt_template_service: AbstractPromptTemplateService,
|
||||
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService,
|
||||
config_builder: AbstractSecurityGuidelinesConfigurationBuilder):
|
||||
super().__init__(
|
||||
foundation_model=foundation_model,
|
||||
response_processing_service=response_processing_service,
|
||||
prompt_template_service=prompt_template_service,
|
||||
llm_configuration_introspection_service=llm_configuration_introspection_service,
|
||||
config_builder=config_builder
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from src.text_generation.ports.abstract_foundation_model import AbstractFoundati
|
||||
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesConfigurationBuilder
|
||||
from src.text_generation.services.guidelines.base_security_guidelines_service import BaseSecurityGuidelinesService
|
||||
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
|
||||
from src.text_generation.services.utilities.abstract_llm_configuration_introspection_service import AbstractLLMConfigurationIntrospectionService
|
||||
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
|
||||
|
||||
class RagPlusCotSecurityGuidelinesService(BaseSecurityGuidelinesService):
|
||||
@@ -17,11 +18,13 @@ class RagPlusCotSecurityGuidelinesService(BaseSecurityGuidelinesService):
|
||||
foundation_model: AbstractFoundationModel,
|
||||
response_processing_service: AbstractResponseProcessingService,
|
||||
prompt_template_service: AbstractPromptTemplateService,
|
||||
llm_configuration_introspection_service: AbstractLLMConfigurationIntrospectionService,
|
||||
config_builder: AbstractSecurityGuidelinesConfigurationBuilder):
|
||||
super().__init__(
|
||||
foundation_model=foundation_model,
|
||||
response_processing_service=response_processing_service,
|
||||
prompt_template_service=prompt_template_service,
|
||||
llm_configuration_introspection_service=llm_configuration_introspection_service,
|
||||
config_builder=config_builder
|
||||
)
|
||||
|
||||
|
||||
+589
-252
@@ -6,278 +6,615 @@ from src.text_generation.services.utilities.abstract_llm_configuration_introspec
|
||||
class LLMConfigurationIntrospectionService(
|
||||
AbstractLLMConfigurationIntrospectionService):
|
||||
# llm_configuration_introspection_service
|
||||
|
||||
def get_config(llm_step):
|
||||
|
||||
|
||||
def get_config(self, lcel_chain, max_depth=10):
|
||||
"""
|
||||
Comprehensively extract all possible LLM configuration parameters
|
||||
from a HuggingFace pipeline step, checking all known locations.
|
||||
from a LangChain LCEL chain object, creating a multilayered dict structure
|
||||
that preserves the chain hierarchy.
|
||||
|
||||
Args:
|
||||
lcel_chain: A LangChain LCEL chain object (Runnable)
|
||||
max_depth: Maximum recursion depth to prevent infinite loops
|
||||
|
||||
Returns:
|
||||
dict: All found configuration parameters that are JSON serializable
|
||||
dict: Nested dictionary with full chain structure and all config parameters
|
||||
"""
|
||||
if not llm_step:
|
||||
if not lcel_chain or max_depth <= 0:
|
||||
return {}
|
||||
|
||||
config = {}
|
||||
def is_serializable(value):
|
||||
"""Check if a value is JSON serializable."""
|
||||
return isinstance(value, (str, int, float, bool, type(None), list, tuple, dict))
|
||||
|
||||
def safe_add_to_config(source_dict, source_name="unknown"):
|
||||
"""Safely add items from a dict to config if they're serializable."""
|
||||
if not isinstance(source_dict, dict):
|
||||
return
|
||||
def safe_serialize(value):
|
||||
"""Safely serialize a value, converting non-serializable objects to strings."""
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
return value
|
||||
elif isinstance(value, (list, tuple)):
|
||||
return [safe_serialize(item) for item in value]
|
||||
elif isinstance(value, dict):
|
||||
return {k: safe_serialize(v) for k, v in value.items() if k != '_type'}
|
||||
else:
|
||||
# Convert objects to string representation, but filter out some noise
|
||||
str_repr = str(value)
|
||||
if any(noise in str_repr for noise in ['<bound method', '<function', 'object at 0x']):
|
||||
return f"<{type(value).__name__}>"
|
||||
return str_repr
|
||||
|
||||
def extract_from_object(obj, path="root", visited=None, current_depth=0):
|
||||
"""
|
||||
Recursively extract configuration from any object, building a nested structure.
|
||||
"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
if current_depth >= max_depth or id(obj) in visited:
|
||||
return {}
|
||||
|
||||
visited.add(id(obj))
|
||||
result = {"_type": type(obj).__name__, "_path": path}
|
||||
|
||||
# === COMPREHENSIVE ATTRIBUTE EXTRACTION ===
|
||||
|
||||
# All possible LLM and chain configuration attributes
|
||||
all_config_attrs = [
|
||||
# Core generation parameters
|
||||
'temperature', 'top_p', 'top_k', 'max_tokens', 'max_new_tokens', 'max_length',
|
||||
'min_length', 'repetition_penalty', 'frequency_penalty', 'presence_penalty',
|
||||
'length_penalty', 'do_sample', 'early_stopping', 'num_beams', 'num_beam_groups',
|
||||
'diversity_penalty', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'seed',
|
||||
'stop', 'stop_sequences', 'suffix', 'logit_bias', 'user', 'n', 'best_of',
|
||||
'logprobs', 'echo', 'response_format', 'tool_choice', 'parallel_tool_calls',
|
||||
|
||||
for key, value in source_dict.items():
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
config[key] = value
|
||||
elif isinstance(value, (list, tuple)) and all(isinstance(x, (str, int, float, bool, type(None))) for x in value):
|
||||
config[key] = list(value)
|
||||
# Skip non-serializable objects
|
||||
|
||||
# === LOCATION 1: Direct attributes on llm_step ===
|
||||
direct_llm_attrs = [
|
||||
# Generation parameters
|
||||
'temperature', 'top_p', 'top_k', 'max_new_tokens', 'max_length', 'min_length',
|
||||
'repetition_penalty', 'length_penalty', 'do_sample', 'early_stopping',
|
||||
'num_beams', 'num_beam_groups', 'diversity_penalty', 'typical_p',
|
||||
'epsilon_cutoff', 'eta_cutoff', 'exponential_decay_length_penalty',
|
||||
|
||||
# Token IDs
|
||||
'pad_token_id', 'eos_token_id', 'bos_token_id', 'decoder_start_token_id',
|
||||
'forced_bos_token_id', 'forced_eos_token_id',
|
||||
|
||||
# Model identifiers
|
||||
'model_id', 'model_name', 'model_path', 'model_type',
|
||||
|
||||
# Task and device settings
|
||||
'task', 'device', 'device_map', 'torch_dtype',
|
||||
|
||||
# Pipeline settings
|
||||
'batch_size', 'max_batch_size', 'return_full_text', 'clean_up_tokenization_spaces',
|
||||
'truncation', 'padding', 'add_special_tokens',
|
||||
|
||||
# Performance settings
|
||||
'use_cache', 'cache_dir', 'revision', 'trust_remote_code',
|
||||
'low_cpu_mem_usage', 'load_in_8bit', 'load_in_4bit',
|
||||
|
||||
# Quantization settings
|
||||
'quantization_config', 'bnb_4bit_compute_dtype', 'bnb_4bit_quant_type',
|
||||
'bnb_4bit_use_double_quant',
|
||||
|
||||
# Other generation settings
|
||||
'seed', 'guidance_scale', 'negative_prompt', 'num_images_per_prompt',
|
||||
'eta', 'generator', 'latents', 'prompt_embeds', 'negative_prompt_embeds',
|
||||
'cross_attention_kwargs', 'guidance_rescale', 'clip_skip',
|
||||
|
||||
# Sampling parameters
|
||||
'top_a', 'tfs', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta',
|
||||
'penalty_alpha', 'use_mirostat_sampling',
|
||||
|
||||
# Stop conditions
|
||||
'stop_sequences', 'stop_token_ids', 'stopping_criteria',
|
||||
|
||||
# Memory and efficiency
|
||||
'offload_folder', 'cpu_offload', 'sequential_cpu_offload',
|
||||
'model_cpu_offload', 'disk_offload',
|
||||
|
||||
# Framework specific
|
||||
'framework', 'use_fast', 'use_auth_token', 'subfolder',
|
||||
]
|
||||
|
||||
for attr in direct_llm_attrs:
|
||||
if hasattr(llm_step, attr):
|
||||
value = getattr(llm_step, attr)
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
config[attr] = value
|
||||
elif isinstance(value, (list, tuple)) and all(isinstance(x, (str, int, float, bool, type(None))) for x in value):
|
||||
config[attr] = list(value)
|
||||
|
||||
# === LOCATION 2: model_kwargs ===
|
||||
if hasattr(llm_step, 'model_kwargs') and llm_step.model_kwargs:
|
||||
safe_add_to_config(llm_step.model_kwargs, "model_kwargs")
|
||||
|
||||
# === LOCATION 3: pipeline_kwargs ===
|
||||
if hasattr(llm_step, 'pipeline_kwargs') and llm_step.pipeline_kwargs:
|
||||
safe_add_to_config(llm_step.pipeline_kwargs, "pipeline_kwargs")
|
||||
|
||||
# === LOCATION 4: Pipeline object and its attributes ===
|
||||
if hasattr(llm_step, 'pipeline') and llm_step.pipeline:
|
||||
pipeline = llm_step.pipeline
|
||||
|
||||
# Direct pipeline attributes
|
||||
pipeline_attrs = [
|
||||
'temperature', 'top_p', 'top_k', 'max_new_tokens', 'max_length',
|
||||
'repetition_penalty', 'do_sample', 'pad_token_id', 'eos_token_id',
|
||||
'return_full_text', 'clean_up_tokenization_spaces', 'prefix',
|
||||
'handle_long_generation', 'batch_size'
|
||||
# Model and API configuration
|
||||
'model', 'model_name', 'model_id', 'model_path', 'model_type', 'engine',
|
||||
'deployment_name', 'deployment_id', 'model_version', 'model_revision',
|
||||
'api_key', 'api_base', 'api_version', 'api_type', 'organization', 'base_url',
|
||||
'endpoint', 'region', 'project_id', 'project', 'location', 'credentials',
|
||||
|
||||
# Provider-specific keys
|
||||
'openai_api_key', 'openai_organization', 'openai_api_base', 'openai_proxy',
|
||||
'anthropic_api_key', 'anthropic_api_url', 'max_tokens_to_sample',
|
||||
'cohere_api_key', 'huggingfacehub_api_token', 'repo_id', 'task',
|
||||
'google_api_key', 'vertex_ai_model', 'azure_endpoint', 'azure_deployment',
|
||||
'azure_api_version', 'azure_api_key', 'replicate_api_token',
|
||||
'together_api_key', 'fireworks_api_key', 'groq_api_key', 'mistral_api_key',
|
||||
|
||||
# Request and performance settings
|
||||
'max_retries', 'request_timeout', 'timeout', 'streaming', 'chunk_size',
|
||||
'max_concurrent_requests', 'rate_limit', 'batch_size', 'max_batch_size',
|
||||
'use_cache', 'cache_dir', 'cache_size', 'device', 'device_map', 'torch_dtype',
|
||||
'load_in_8bit', 'load_in_4bit', 'trust_remote_code', 'revision',
|
||||
|
||||
# Token handling
|
||||
'pad_token_id', 'eos_token_id', 'bos_token_id', 'unk_token_id',
|
||||
'sep_token_id', 'cls_token_id', 'mask_token_id', 'decoder_start_token_id',
|
||||
'forced_bos_token_id', 'forced_eos_token_id',
|
||||
|
||||
# Chain-specific attributes
|
||||
'verbose', 'name', 'tags', 'metadata', 'callbacks', 'memory', 'memory_key',
|
||||
'return_messages', 'input_key', 'output_key', 'prompt', 'llm_chain',
|
||||
'combine_documents_chain', 'question_generator', 'retriever',
|
||||
|
||||
# Pipeline and processing
|
||||
'return_full_text', 'clean_up_tokenization_spaces', 'truncation', 'padding',
|
||||
'add_special_tokens', 'handle_long_generation', 'prefix',
|
||||
|
||||
# Advanced parameters
|
||||
'penalty_alpha', 'use_mirostat_sampling', 'mirostat_mode', 'mirostat_tau',
|
||||
'mirostat_eta', 'tfs', 'top_a', 'k', 'p', 'include_stop_str_in_output',
|
||||
'ignore_eos', 'skip_special_tokens', 'spaces_between_special_tokens',
|
||||
]
|
||||
|
||||
for attr in pipeline_attrs:
|
||||
if hasattr(pipeline, attr):
|
||||
value = getattr(pipeline, attr)
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
config[attr] = value
|
||||
# === PRIORITY: Extract critical generation parameters first ===
|
||||
critical_params = ['temperature', 'top_k', 'top_p', 'max_length', 'max_new_tokens',
|
||||
'max_tokens', 'repetition_penalty', 'do_sample', 'num_beams']
|
||||
|
||||
# Check pipeline._preprocess_params
|
||||
if hasattr(pipeline, '_preprocess_params'):
|
||||
safe_add_to_config(pipeline._preprocess_params, "_preprocess_params")
|
||||
|
||||
# Check pipeline._forward_params
|
||||
if hasattr(pipeline, '_forward_params'):
|
||||
safe_add_to_config(pipeline._forward_params, "_forward_params")
|
||||
|
||||
# Check pipeline._postprocess_params
|
||||
if hasattr(pipeline, '_postprocess_params'):
|
||||
safe_add_to_config(pipeline._postprocess_params, "_postprocess_params")
|
||||
|
||||
# === LOCATION 5: Model's generation config ===
|
||||
if hasattr(llm_step, 'pipeline') and llm_step.pipeline:
|
||||
pipeline = llm_step.pipeline
|
||||
|
||||
# Try to access generation config through model
|
||||
try:
|
||||
if hasattr(pipeline, 'model') and hasattr(pipeline.model, 'generation_config'):
|
||||
gen_config = pipeline.model.generation_config
|
||||
if hasattr(gen_config, 'to_dict'):
|
||||
gen_dict = gen_config.to_dict()
|
||||
safe_add_to_config(gen_dict, "generation_config")
|
||||
elif hasattr(gen_config, '__dict__'):
|
||||
safe_add_to_config(gen_config.__dict__, "generation_config_dict")
|
||||
except Exception as e:
|
||||
# Silently continue if generation config access fails
|
||||
pass
|
||||
|
||||
# Try to access config through model.config
|
||||
try:
|
||||
if hasattr(pipeline, 'model') and hasattr(pipeline.model, 'config'):
|
||||
model_config = pipeline.model.config
|
||||
if hasattr(model_config, 'to_dict'):
|
||||
model_config_dict = model_config.to_dict()
|
||||
# Only extract generation-related config items
|
||||
generation_keys = [
|
||||
'max_length', 'max_new_tokens', 'min_length', 'do_sample',
|
||||
'temperature', 'top_k', 'top_p', 'repetition_penalty',
|
||||
'length_penalty', 'num_beams', 'early_stopping',
|
||||
'pad_token_id', 'eos_token_id', 'bos_token_id'
|
||||
]
|
||||
for key in generation_keys:
|
||||
if key in model_config_dict:
|
||||
value = model_config_dict[key]
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
config[key] = value
|
||||
except Exception as e:
|
||||
# Silently continue if model config access fails
|
||||
pass
|
||||
|
||||
# === LOCATION 6: Tokenizer config ===
|
||||
if hasattr(llm_step, 'pipeline') and llm_step.pipeline:
|
||||
try:
|
||||
if hasattr(llm_step.pipeline, 'tokenizer'):
|
||||
tokenizer = llm_step.pipeline.tokenizer
|
||||
tokenizer_attrs = [
|
||||
'pad_token_id', 'eos_token_id', 'bos_token_id', 'unk_token_id',
|
||||
'sep_token_id', 'cls_token_id', 'mask_token_id',
|
||||
'padding_side', 'truncation_side', 'model_max_length'
|
||||
]
|
||||
for param in critical_params:
|
||||
# Check multiple possible locations for each critical parameter
|
||||
found_value = None
|
||||
locations_to_check = [
|
||||
# Direct attribute
|
||||
(lambda: getattr(obj, param) if hasattr(obj, param) else None, f"direct.{param}"),
|
||||
|
||||
for attr in tokenizer_attrs:
|
||||
if hasattr(tokenizer, attr):
|
||||
value = getattr(tokenizer, attr)
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
config[f"tokenizer_{attr}"] = value
|
||||
except Exception as e:
|
||||
# Silently continue if tokenizer access fails
|
||||
pass
|
||||
|
||||
# === LOCATION 7: Try model_dump with filtering ===
|
||||
try:
|
||||
full_dump = llm_step.model_dump()
|
||||
if isinstance(full_dump, dict):
|
||||
# List of keys we definitely want to try to extract
|
||||
priority_keys = [
|
||||
'temperature', 'top_p', 'top_k', 'max_new_tokens', 'max_length',
|
||||
'repetition_penalty', 'do_sample', 'pad_token_id', 'eos_token_id',
|
||||
'model_id', 'task', 'device', 'batch_size', 'return_full_text',
|
||||
'model_kwargs', 'pipeline_kwargs'
|
||||
# In various config containers
|
||||
(lambda: getattr(obj, 'model_kwargs', {}).get(param) if hasattr(obj, 'model_kwargs') else None, f"model_kwargs.{param}"),
|
||||
(lambda: getattr(obj, 'pipeline_kwargs', {}).get(param) if hasattr(obj, 'pipeline_kwargs') else None, f"pipeline_kwargs.{param}"),
|
||||
(lambda: getattr(obj, 'generation_config', {}).get(param) if hasattr(obj, 'generation_config') else None, f"generation_config.{param}"),
|
||||
(lambda: getattr(obj, 'kwargs', {}).get(param) if hasattr(obj, 'kwargs') else None, f"kwargs.{param}"),
|
||||
(lambda: getattr(obj, '_config', {}).get(param) if hasattr(obj, '_config') else None, f"_config.{param}"),
|
||||
|
||||
# In nested pipeline object
|
||||
(lambda: getattr(getattr(obj, 'pipeline', None), param, None) if hasattr(obj, 'pipeline') else None, f"pipeline.{param}"),
|
||||
(lambda: getattr(getattr(obj, 'pipeline', None), '_preprocess_params', {}).get(param) if hasattr(obj, 'pipeline') else None, f"pipeline._preprocess_params.{param}"),
|
||||
(lambda: getattr(getattr(obj, 'pipeline', None), '_forward_params', {}).get(param) if hasattr(obj, 'pipeline') else None, f"pipeline._forward_params.{param}"),
|
||||
(lambda: getattr(getattr(obj, 'pipeline', None), '_postprocess_params', {}).get(param) if hasattr(obj, 'pipeline') else None, f"pipeline._postprocess_params.{param}"),
|
||||
|
||||
# In model's generation config
|
||||
(lambda: getattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config', None).__dict__.get(param) if hasattr(obj, 'pipeline') and hasattr(getattr(obj, 'pipeline', None), 'model') and hasattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config') else None, f"pipeline.model.generation_config.{param}"),
|
||||
|
||||
# Try generation_config.to_dict()
|
||||
(lambda: getattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config', None).to_dict().get(param) if hasattr(obj, 'pipeline') and hasattr(getattr(obj, 'pipeline', None), 'model') and hasattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config') and hasattr(getattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'generation_config', None), 'to_dict') else None, f"pipeline.model.generation_config.to_dict().{param}"),
|
||||
|
||||
# Check in model config
|
||||
(lambda: getattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'config', None).__dict__.get(param) if hasattr(obj, 'pipeline') and hasattr(getattr(obj, 'pipeline', None), 'model') and hasattr(getattr(getattr(obj, 'pipeline', None), 'model', None), 'config') else None, f"pipeline.model.config.{param}"),
|
||||
|
||||
# Check bound parameters
|
||||
(lambda: getattr(obj, 'bound', {}).get(param) if hasattr(obj, 'bound') else None, f"bound.{param}"),
|
||||
|
||||
# Check __dict__ directly
|
||||
(lambda: obj.__dict__.get(param) if hasattr(obj, '__dict__') else None, f"__dict__.{param}"),
|
||||
]
|
||||
|
||||
for key in priority_keys:
|
||||
if key in full_dump:
|
||||
value = full_dump[key]
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
config[key] = value
|
||||
elif isinstance(value, dict):
|
||||
# If it's a nested dict, try to extract from it
|
||||
safe_add_to_config(value, f"model_dump_{key}")
|
||||
except Exception as e:
|
||||
# model_dump might fail due to non-serializable objects
|
||||
pass
|
||||
|
||||
# === LOCATION 8: Check for any additional generation parameters ===
|
||||
# Look for any attributes ending in common parameter suffixes
|
||||
if hasattr(llm_step, '__dict__'):
|
||||
for attr_name, attr_value in llm_step.__dict__.items():
|
||||
if isinstance(attr_value, (str, int, float, bool, type(None))):
|
||||
# Add if it looks like a generation parameter
|
||||
if any(suffix in attr_name.lower() for suffix in [
|
||||
'temperature', 'top_', 'max_', 'min_', 'penalty', 'token_id',
|
||||
'sample', 'beam', 'length', 'config', 'param'
|
||||
]):
|
||||
config[attr_name] = attr_value
|
||||
|
||||
# === CLEANUP: Remove duplicates and None values (optional) ===
|
||||
# Remove None values if desired
|
||||
# config = {k: v for k, v in config.items() if v is not None}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
# Helper function to pretty print the config for debugging
|
||||
def print_llm_config_debug(llm_step):
|
||||
"""Debug helper to print all found configuration in organized format."""
|
||||
config = extract_all_llm_config(llm_step)
|
||||
|
||||
if not config:
|
||||
print("No LLM configuration found")
|
||||
return config
|
||||
|
||||
print("=== EXTRACTED LLM CONFIGURATION ===")
|
||||
|
||||
# Group by category for better readability
|
||||
categories = {
|
||||
'Generation Parameters': [
|
||||
'temperature', 'top_p', 'top_k', 'max_new_tokens', 'max_length', 'min_length',
|
||||
'repetition_penalty', 'length_penalty', 'do_sample', 'num_beams', 'early_stopping'
|
||||
],
|
||||
'Token IDs': [
|
||||
'pad_token_id', 'eos_token_id', 'bos_token_id', 'decoder_start_token_id'
|
||||
],
|
||||
'Model Info': [
|
||||
'model_id', 'model_name', 'model_path', 'model_type', 'task'
|
||||
],
|
||||
'Device & Performance': [
|
||||
'device', 'device_map', 'batch_size', 'use_cache', 'torch_dtype'
|
||||
],
|
||||
'Pipeline Settings': [
|
||||
'return_full_text', 'clean_up_tokenization_spaces', 'truncation', 'padding'
|
||||
for getter, location in locations_to_check:
|
||||
try:
|
||||
value = getter()
|
||||
if value is not None:
|
||||
found_value = value
|
||||
result[param] = safe_serialize(value)
|
||||
result[f"{param}_source"] = location # Track where we found it
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# If still not found, do a deeper search in __dict__
|
||||
if found_value is None and hasattr(obj, '__dict__'):
|
||||
for key, value in obj.__dict__.items():
|
||||
if param in key.lower() and value is not None:
|
||||
result[f"{param}_from_{key}"] = safe_serialize(value)
|
||||
break
|
||||
|
||||
# Extract all other attributes
|
||||
for attr in all_config_attrs:
|
||||
if attr not in critical_params and hasattr(obj, attr):
|
||||
try:
|
||||
value = getattr(obj, attr)
|
||||
if value is not None:
|
||||
result[attr] = safe_serialize(value)
|
||||
except Exception as e:
|
||||
result[f"{attr}_error"] = str(e)
|
||||
|
||||
# === EXTRACT FROM COMMON CONFIG CONTAINERS ===
|
||||
config_containers = [
|
||||
'kwargs', 'model_kwargs', 'pipeline_kwargs', 'llm_kwargs', 'generation_config',
|
||||
'config', '_config', 'params', '_params', 'bound', 'default_params',
|
||||
'_preprocess_params', '_forward_params', '_postprocess_params'
|
||||
]
|
||||
|
||||
for container_name in config_containers:
|
||||
if hasattr(obj, container_name):
|
||||
try:
|
||||
container = getattr(obj, container_name)
|
||||
if isinstance(container, dict) and container:
|
||||
result[container_name] = safe_serialize(container)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# === EXTRACT FROM __DICT__ ===
|
||||
if hasattr(obj, '__dict__'):
|
||||
obj_dict = {}
|
||||
for key, value in obj.__dict__.items():
|
||||
# Skip private/internal attributes and known non-config items
|
||||
if (not key.startswith('_') or key in ['_config', '_params']) and \
|
||||
key not in ['callbacks'] and \
|
||||
not callable(value):
|
||||
try:
|
||||
if is_serializable(value) or isinstance(value, (dict, list)):
|
||||
obj_dict[key] = safe_serialize(value)
|
||||
elif hasattr(value, '__dict__') or hasattr(value, 'dict'):
|
||||
# This might be a nested config object
|
||||
nested_config = extract_from_object(
|
||||
value, f"{path}.{key}", visited.copy(), current_depth + 1
|
||||
)
|
||||
if nested_config and len(nested_config) > 2: # More than just _type and _path
|
||||
obj_dict[key] = nested_config
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if obj_dict:
|
||||
result['_attributes'] = obj_dict
|
||||
|
||||
# === HANDLE SPECIFIC CHAIN STRUCTURES ===
|
||||
|
||||
# Sequential chains (RunnableSequence)
|
||||
if hasattr(obj, 'steps') and obj.steps:
|
||||
steps_config = {}
|
||||
for i, step in enumerate(obj.steps):
|
||||
step_config = extract_from_object(
|
||||
step, f"{path}.steps[{i}]", visited.copy(), current_depth + 1
|
||||
)
|
||||
if step_config:
|
||||
steps_config[f"step_{i}"] = step_config
|
||||
if steps_config:
|
||||
result['steps'] = steps_config
|
||||
|
||||
# Parallel chains (RunnableParallel)
|
||||
if hasattr(obj, 'mapping') and isinstance(obj.mapping, dict):
|
||||
mapping_config = {}
|
||||
for key, component in obj.mapping.items():
|
||||
comp_config = extract_from_object(
|
||||
component, f"{path}.mapping[{key}]", visited.copy(), current_depth + 1
|
||||
)
|
||||
if comp_config:
|
||||
mapping_config[key] = comp_config
|
||||
if mapping_config:
|
||||
result['mapping'] = mapping_config
|
||||
|
||||
# Conditional chains (RunnableBranch)
|
||||
if hasattr(obj, 'branches') and obj.branches:
|
||||
branches_config = {}
|
||||
for i, (condition, branch) in enumerate(obj.branches):
|
||||
branch_config = extract_from_object(
|
||||
branch, f"{path}.branches[{i}]", visited.copy(), current_depth + 1
|
||||
)
|
||||
if branch_config:
|
||||
branches_config[f"branch_{i}"] = branch_config
|
||||
if branches_config:
|
||||
result['branches'] = branches_config
|
||||
|
||||
if hasattr(obj, 'default') and obj.default:
|
||||
default_config = extract_from_object(
|
||||
obj.default, f"{path}.default", visited.copy(), current_depth + 1
|
||||
)
|
||||
if default_config:
|
||||
result['default'] = default_config
|
||||
|
||||
# Chain components
|
||||
component_attrs = [
|
||||
'llm', 'model', 'language_model', 'chat_model', 'completion_model',
|
||||
'first', 'last', 'middle', 'chain', 'inner_chain', 'base_chain',
|
||||
'retrieval_chain', 'combine_documents_chain', 'question_generator',
|
||||
'memory', 'retriever', 'prompt', 'output_parser', 'parser'
|
||||
]
|
||||
|
||||
for comp_attr in component_attrs:
|
||||
if hasattr(obj, comp_attr):
|
||||
try:
|
||||
component = getattr(obj, comp_attr)
|
||||
if component and not callable(component):
|
||||
if isinstance(component, list):
|
||||
comp_configs = {}
|
||||
for i, item in enumerate(component):
|
||||
item_config = extract_from_object(
|
||||
item, f"{path}.{comp_attr}[{i}]", visited.copy(), current_depth + 1
|
||||
)
|
||||
if item_config:
|
||||
comp_configs[f"{comp_attr}_{i}"] = item_config
|
||||
if comp_configs:
|
||||
result[comp_attr] = comp_configs
|
||||
else:
|
||||
comp_config = extract_from_object(
|
||||
component, f"{path}.{comp_attr}", visited.copy(), current_depth + 1
|
||||
)
|
||||
if comp_config and len(comp_config) > 2:
|
||||
result[comp_attr] = comp_config
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try model.dict() or similar serialization methods
|
||||
for method_name in ['dict', 'model_dump', 'to_dict', 'serialize']:
|
||||
if hasattr(obj, method_name):
|
||||
try:
|
||||
method = getattr(obj, method_name)
|
||||
if callable(method):
|
||||
serialized = method()
|
||||
if isinstance(serialized, dict) and serialized:
|
||||
result[f'_{method_name}'] = safe_serialize(serialized)
|
||||
break # Only use the first successful method
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
# Start extraction from the root chain
|
||||
return extract_from_object(lcel_chain)
|
||||
|
||||
|
||||
def print_nested_config(self, config, indent=0, max_items_per_level=50):
    """
    Pretty print the nested configuration structure.

    Every visible key is printed on its own line, followed by its value one
    indent level deeper. Keys that start with an underscore are treated as
    internal and hidden, except ``_type``, ``_path`` and ``_attributes``.

    Args:
        config: Value to print. Anything that is not a dict is printed with
            ``str()`` at the current indent.
        indent: Current nesting level (one space per level).
        max_items_per_level: Maximum number of visible keys printed for a
            dict before the rest is summarised in a single line.
    """
    pad = " " * indent
    if not isinstance(config, dict):
        print(pad + str(config))
        return

    # '_attributes' must pass the filter, otherwise the flattening branch
    # below can never be reached.
    shown_internal = ('_type', '_path', '_attributes')
    max_attributes = 10  # Limit attribute display

    visible = []
    for key, value in config.items():
        is_internal = isinstance(key, str) and key.startswith('_')
        if is_internal and key not in shown_internal:
            continue  # Skip most internal fields in main display
        visible.append((key, value))

    child_pad = " " * (indent + 1)
    for position, (key, value) in enumerate(visible):
        if position >= max_items_per_level:
            # Count only keys that would actually have been displayed.
            print(pad + f"... ({len(visible) - position} more items)")
            break

        print(pad + f"{key}:")

        if not isinstance(value, dict):
            print(child_pad + str(value))
        elif key == '_attributes' and indent > 0:
            # Flatten attributes for readability; nested dict values are
            # counted towards the limit but not printed.
            for attr_index, (attr_key, attr_val) in enumerate(value.items()):
                if attr_index >= max_attributes:
                    print(child_pad + f"... ({len(value) - attr_index} more attributes)")
                    break
                if not isinstance(attr_val, dict):
                    print(child_pad + f"{attr_key}: {attr_val}")
        else:
            self.print_nested_config(value, indent + 1, max_items_per_level)
|
||||
|
||||
|
||||
def extract_flattened_config(self, lcel_chain):
    """
    Collapse the nested configuration of a chain into one flat dictionary.

    The nested structure comes from ``self.extract_all_llm_config``. Every
    leaf value is stored under a dotted key built from the keys on the way
    down to it. Keys starting with an underscore (and everything beneath
    them) are left out.

    Args:
        lcel_chain: The chain handed to ``extract_all_llm_config``.

    Returns:
        A single-level dict mapping dotted paths to leaf values.
    """
    def flatten_dict(d, parent_key='', sep='.'):
        flat = {}
        for name, child in d.items():
            if name.startswith('_'):
                continue  # Skip metadata
            dotted = name if not parent_key else f"{parent_key}{sep}{name}"
            if isinstance(child, dict):
                flat.update(flatten_dict(child, dotted, sep=sep))
            else:
                flat[dotted] = child
        return flat

    return flatten_dict(self.extract_all_llm_config(lcel_chain))
|
||||
|
||||
|
||||
def find_critical_generation_params(self, lcel_chain):
    """
    Specifically hunt for the most critical generation parameters that are often missing.

    For every parameter the chain is searched depth-first: first a fixed set
    of locations on the current object, then its ``steps``, its ``mapping``
    values and finally its ``llm`` / ``model`` / ``pipeline`` / ``chain``
    attributes. The first non-``None`` value wins. Progress is printed to
    stdout.

    Args:
        lcel_chain: Root object to search.

    Returns:
        A dict keyed by parameter name. The value is ``None`` when the
        parameter was not found, otherwise
        ``{"value": <value>, "location": "<path> -> <where on the object>"}``.
    """
    critical_params = {
        'temperature': None,
        'top_k': None,
        'top_p': None,
        'max_length': None,
        'max_new_tokens': None,
        'max_tokens': None,
        'repetition_penalty': None,
        'do_sample': None
    }

    # Attribute chains, relative to the inspected object, that may hold the
    # parameter. The empty chain is the object itself (direct attribute).
    search_paths = (
        (),
        # In common config dicts
        ('model_kwargs',),
        ('pipeline_kwargs',),
        ('kwargs',),
        ('generation_config',),
        ('_config',),
        ('bound',),
        # In pipeline
        ('pipeline',),
        # In pipeline config dicts
        ('pipeline', '_preprocess_params'),
        ('pipeline', '_forward_params'),
        ('pipeline', '_postprocess_params'),
        # In model generation config
        ('pipeline', 'model', 'generation_config'),
    )

    def read_param(root, attr_chain, param_name):
        """Follow attr_chain from root and read param_name from what is found there."""
        holder = root
        for attr in attr_chain:
            holder = getattr(holder, attr, None)
            if holder is None:
                return None
        # Containers may be plain dicts or config objects; support both.
        if isinstance(holder, dict):
            return holder.get(param_name)
        return getattr(holder, param_name, None)

    def deep_search_for_param(obj, param_name, visited=None, path=""):
        if visited is None:
            visited = set()
        if id(obj) in visited:
            return None
        visited.add(id(obj))

        for attr_chain in search_paths:
            try:
                value = read_param(obj, attr_chain, param_name)
            except Exception:
                # Best effort: a property that raises must not abort the hunt.
                continue
            if value is not None:
                label = ".".join(attr_chain) if attr_chain else "attribute"
                return {"value": value, "location": f"{path} -> {label}"}

        # Recurse into sub-objects
        if hasattr(obj, 'steps'):
            for i, step in enumerate(obj.steps):
                result = deep_search_for_param(step, param_name, visited.copy(), f"{path}.steps[{i}]")
                if result:
                    return result

        if hasattr(obj, 'mapping') and isinstance(obj.mapping, dict):
            for key, component in obj.mapping.items():
                result = deep_search_for_param(component, param_name, visited.copy(), f"{path}.mapping[{key}]")
                if result:
                    return result

        # Check common component attributes
        for attr_name in ['llm', 'model', 'pipeline', 'chain']:
            if hasattr(obj, attr_name):
                component = getattr(obj, attr_name)
                if component:
                    result = deep_search_for_param(component, param_name, visited.copy(), f"{path}.{attr_name}")
                    if result:
                        return result

        return None

    print("=== HUNTING FOR CRITICAL GENERATION PARAMETERS ===")
    for param in critical_params:
        result = deep_search_for_param(lcel_chain, param)
        if result:
            critical_params[param] = result
            print(f"✓ Found {param}: {result['value']} (from: {result['location']})")
        else:
            print(f"✗ Missing {param}")

    return critical_params
|
||||
|
||||
|
||||
def print_llm_config_debug(self, lcel_chain):
    """
    Debug helper that shows both nested and flattened views of the configuration.

    Prints, in order: the critical-parameter search, the nested
    configuration, and the flattened configuration grouped by category. A
    flattened key goes into the first category whose keyword occurs anywhere
    in the lower-cased key; keys matching nothing go to ``Other``.

    Args:
        lcel_chain: The chain to inspect.

    Returns:
        The nested configuration produced by ``extract_all_llm_config``.
    """
    print("=== CRITICAL PARAMETERS SEARCH ===")
    # Called for its printed report; the returned dict is not needed here.
    self.find_critical_generation_params(lcel_chain)

    print("\n" + "="*50)
    print("=== NESTED LCEL CHAIN CONFIGURATION ===")
    nested_config = self.extract_all_llm_config(lcel_chain)
    self.print_nested_config(nested_config)

    print("\n" + "="*50)
    print("=== FLATTENED CONFIGURATION ===")
    flattened = self.extract_flattened_config(lcel_chain)

    if not flattened:
        print("No configuration parameters found")
        return nested_config

    # Group by category with priority for generation params. Order matters:
    # the first rule with a matching keyword claims the key.
    category_rules = (
        ('🔥 CRITICAL Generation Parameters',
         ('temperature', 'top_k', 'top_p', 'max_length', 'max_new_tokens', 'max_tokens')),
        ('Other Generation Parameters', ('penalty', 'sample', 'beam', 'length')),
        ('Model Configuration', ('model', 'engine', 'deployment')),
        ('API Settings', ('api', 'key', 'endpoint', 'url', 'timeout')),
        ('Chain Structure', ('step', 'chain', 'mapping', 'branch')),
    )
    categories = {title: [] for title, _ in category_rules}
    categories['Other'] = []

    for key, value in flattened.items():
        lowered = key.lower()
        title = next(
            (name for name, keywords in category_rules
             if any(keyword in lowered for keyword in keywords)),
            'Other',
        )
        categories[title].append((key, value))

    for category, items in categories.items():
        if items:
            print(f"\n{category}:")
            for key, value in items:
                print(f" {key}: {value}")

    print(f"\nTotal parameters found: {len(flattened)}")
    return nested_config
|
||||
|
||||
|
||||
# Example usage with detailed iteration
|
||||
def iterate_chain_components(self, lcel_chain):
    """
    Example function showing how to iterate through all chain components
    and extract configuration from each.

    Walks the chain depth-first through ``steps``, ``mapping`` values,
    ``llm`` and ``model``, printing one "Visiting" line per component and a
    "Config found" line when any of the inspected attributes is set.

    Args:
        lcel_chain: Root component to start from.

    Returns:
        None. All results are printed to stdout.
    """
    print("=== ITERATING THROUGH CHAIN COMPONENTS ===")

    max_depth = 5
    # Check for common LLM attributes
    # NOTE: 'api_key' is among them, so its value is printed when present.
    llm_attrs = ['temperature', 'top_p', 'model', 'model_id', 'max_tokens', 'api_key']

    def visit_component(component, path="root", depth=0):
        if depth > max_depth:  # Prevent infinite recursion
            return

        print(" " * depth + f"Visiting: {path} ({type(component).__name__})")

        # Extract config from this component
        config = {}
        for attr in llm_attrs:
            if hasattr(component, attr):
                value = getattr(component, attr)
                if value is not None:
                    config[attr] = value

        if config:
            print(" " * depth + f" Config found: {config}")

        # Recurse into sub-components
        if hasattr(component, 'steps'):
            for i, step in enumerate(component.steps):
                visit_component(step, f"{path}.steps[{i}]", depth + 1)

        if hasattr(component, 'mapping') and isinstance(component.mapping, dict):
            for key, subcomp in component.mapping.items():
                visit_component(subcomp, f"{path}.mapping[{key}]", depth + 1)

        if hasattr(component, 'llm') and component.llm:
            visit_component(component.llm, f"{path}.llm", depth + 1)

        if hasattr(component, 'model') and component.model:
            visit_component(component.model, f"{path}.model", depth + 1)

    visit_component(lcel_chain)
|
||||
|
||||
|
||||
# Complete usage example
|
||||
def example_usage(self):
    """
    Print a short guide to the extraction helpers of this class.

    Lists the four entry points with a one-line description each, then
    prints a hard-coded sample of the nested structure through
    ``self.print_nested_config``.
    """
    guide_lines = (
        "=== LANGCHAIN LCEL CONFIG EXTRACTOR USAGE ===\n",
        "1. NESTED STRUCTURE EXTRACTION:",
        " nested_config = extract_all_llm_config(chain)",
        " # Returns: Full nested dict preserving chain hierarchy",
        "\n2. FLATTENED EXTRACTION:",
        " flat_config = extract_flattened_config(chain)",
        " # Returns: Single-level dict with dotted path keys",
        "\n3. DEBUG OUTPUT:",
        " print_llm_config_debug(chain)",
        " # Prints: Both nested and categorized flat views",
        "\n4. COMPONENT ITERATION:",
        " iterate_chain_components(chain)",
        " # Shows: Step-by-step traversal of all components",
        "\nExample output structure:",
    )
    for text in guide_lines:
        print(text)

    # Sample chain: a prompt template followed by a chat model.
    prompt_step = {
        "_type": "ChatPromptTemplate",
        "template": "You are a helpful assistant",
    }
    model_step = {
        "_type": "ChatOpenAI",
        "model": "gpt-3.5-turbo",
        "temperature": 0.7,
        "max_tokens": 1000,
        "openai_api_key": "sk-...",
        "_attributes": {
            "streaming": False,
            "verbose": False,
        },
    }
    example_structure = {
        "_type": "RunnableSequence",
        "steps": {
            "step_0": prompt_step,
            "step_1": model_step,
        },
    }

    self.print_nested_config(example_structure)
|
||||
|
||||
|
||||
+11
-1
@@ -124,16 +124,22 @@ def rag_config_builder(
|
||||
prompt_injection_example_repository=prompt_injection_example_repository
|
||||
)
|
||||
|
||||
@pytest.fixture(scope="session")
def llm_configuration_introspection_service():
    """Session-scoped fixture yielding an LLMConfigurationIntrospectionService built with no arguments."""
    introspection_service = LLMConfigurationIntrospectionService()
    return introspection_service
|
||||
|
||||
@pytest.fixture(scope="session")
def rag_context_guidelines(
        foundation_model,
        response_processing_service,
        prompt_template_service,
        llm_configuration_introspection_service,
        rag_config_builder):
    """Session-scoped fixture wiring the injected fixtures into a RagContextSecurityGuidelinesService."""
    # Collected by keyword so each collaborator is visibly matched to its parameter.
    dependencies = {
        "foundation_model": foundation_model,
        "response_processing_service": response_processing_service,
        "prompt_template_service": prompt_template_service,
        "llm_configuration_introspection_service": llm_configuration_introspection_service,
        "config_builder": rag_config_builder,
    }
    return RagContextSecurityGuidelinesService(**dependencies)
|
||||
|
||||
@@ -141,10 +147,12 @@ def rag_context_guidelines(
|
||||
def chain_of_thought_guidelines(
        foundation_model,
        response_processing_service,
        llm_configuration_introspection_service,
        prompt_template_service):
    """Build a ChainOfThoughtSecurityGuidelinesService from the injected fixtures.

    Each collaborator is passed exactly once by keyword; repeating a keyword
    argument in a call is rejected by Python at compile time.
    """
    return ChainOfThoughtSecurityGuidelinesService(
        foundation_model=foundation_model,
        response_processing_service=response_processing_service,
        llm_configuration_introspection_service=llm_configuration_introspection_service,
        prompt_template_service=prompt_template_service
    )
|
||||
|
||||
@@ -153,11 +161,13 @@ def rag_plus_cot_guidelines(
|
||||
foundation_model,
|
||||
response_processing_service,
|
||||
prompt_template_service,
|
||||
llm_configuration_introspection_service,
|
||||
rag_config_builder):
|
||||
return RagPlusCotSecurityGuidelinesService(
|
||||
foundation_model=foundation_model,
|
||||
response_processing_service=response_processing_service,
|
||||
prompt_template_service=prompt_template_service,
|
||||
llm_configuration_introspection_service=llm_configuration_introspection_service,
|
||||
config_builder=rag_config_builder
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user