diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index a492df9..91aea51 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -121,7 +121,7 @@ def create_trainer() -> Trainer: def train() -> None: - config : ConfigSet =\ + config_set : ConfigSet =\ { 'dataset': { @@ -133,7 +133,7 @@ def train() -> None: 'target_path': CONFIG.get('training.model', 'target_path'), 'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate') }, - 'common': + 'output': { 'resume_path': CONFIG.get('training.output', 'resume_path') } @@ -142,12 +142,12 @@ def train() -> None: if torch.cuda.is_available(): torch.set_float32_matmul_precision('high') - dataset = StaticDataset(config.get('dataset')) + dataset = StaticDataset(config_set.get('dataset')) training_loader, validation_loader = create_loaders(dataset) - embedding_converter_trainer = EmbeddingConverterTrainer(config.get('trainer')) + embedding_converter_trainer = EmbeddingConverterTrainer(config_set.get('trainer')) trainer = create_trainer() - if os.path.exists(config.get('common').get('resume_path')): - trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('common').get('resume_path')) + if os.path.exists(config_set.get('output').get('resume_path')): + trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_set.get('output').get('resume_path')) else: trainer.fit(embedding_converter_trainer, training_loader, validation_loader)