From d3b0051912c8bbf57b72be917ccc41249739b35c Mon Sep 17 00:00:00 2001
From: henryruhs
Date: Thu, 6 Mar 2025 13:00:06 +0100
Subject: [PATCH] More adjustments

---
 embedding_converter/src/training.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py
index 4198673..9dc28c8 100644
--- a/embedding_converter/src/training.py
+++ b/embedding_converter/src/training.py
@@ -50,7 +50,7 @@ class EmbeddingConverterTrainer(LightningModule):
 	def configure_optimizers(self) -> OptimizerSet:
 		optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate'))
 		scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
-		config =\
+		optimizer_set =\
 		{
 			'optimizer': optimizer,
 			'lr_scheduler':
@@ -62,7 +62,7 @@ class EmbeddingConverterTrainer(LightningModule):
 			}
 		}
 
-		return config
+		return optimizer_set
 
 
 def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
@@ -124,13 +124,16 @@ def train() -> None:
 	config_dataset =\
 	{
 		'file_pattern': CONFIG.get('training.dataset', 'file_pattern'),
-		'resume_path': CONFIG.get('training.output', 'resume_path')
 	}
 	config_trainer =\
 	{
 		'source_path': CONFIG.get('training.model', 'source_path'),
 		'target_path': CONFIG.get('training.model', 'target_path'),
-		'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
+		'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate'),
+	}
+	config_common =\
+	{
+		'resume_path': CONFIG.get('training.output', 'resume_path')
 	}
 
 	if torch.cuda.is_available():
@@ -141,7 +144,7 @@ def train() -> None:
 	embedding_converter_trainer = EmbeddingConverterTrainer(config_trainer)
 	trainer = create_trainer()
 
-	if os.path.exists(config_dataset.get('resume_path')):
-		trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_dataset.get('resume_path'))
+	if os.path.exists(config_common.get('resume_path')):
+		trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_common.get('resume_path'))
 	else:
 		trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
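
For reference: the renamed optimizer_set keeps the mapping shape that PyTorch Lightning documents for configure_optimizers, and the new resume branch relies on Trainer.fit's ckpt_path argument. Below is a minimal, self-contained sketch of that contract, not the project's actual module; the ExampleConverter class, the 'monitor' key and the 'val_loss' metric name are illustrative assumptions, since the hunk at @@ -62 elides the scheduler entry.

	import torch
	from torch import nn
	from pytorch_lightning import LightningModule

	class ExampleConverter(LightningModule):
		def __init__(self) -> None:
			super().__init__()
			self.layer = nn.Linear(512, 512)

		def training_step(self, batch, batch_index):
			source, target = batch
			return nn.functional.mse_loss(self.layer(source), target)

		def validation_step(self, batch, batch_index):
			source, target = batch
			loss = nn.functional.mse_loss(self.layer(source), target)
			# log the metric that ReduceLROnPlateau will monitor
			self.log('val_loss', loss)
			return loss

		def configure_optimizers(self):
			optimizer = torch.optim.Adam(self.parameters(), lr = 1e-3)
			scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
			# same shape as optimizer_set in the patch: Lightning reads the
			# optimizer and the scheduler config from this mapping
			return\
			{
				'optimizer': optimizer,
				'lr_scheduler':
				{
					'scheduler': scheduler,
					'monitor': 'val_loss' # assumption: the monitored metric is not shown in this patch
				}
			}

Resuming follows the same pattern as the patched train(): calling trainer.fit(module, training_loader, validation_loader, ckpt_path = resume_path) restores the saved training state when resume_path points to an existing checkpoint, while a plain trainer.fit(module, training_loader, validation_loader) starts from scratch.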