From 847579f92556e69d2244b24f449be2cbf04b59a5 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Thu, 6 Mar 2025 13:08:29 +0100 Subject: [PATCH] More adjustments --- embedding_converter/src/training.py | 42 ++++++++++++++++------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 5ee8dde..4bf5b78 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -95,9 +95,9 @@ def create_trainer() -> Trainer: config =\ { 'max_epochs': CONFIG.getint('training.trainer', 'max_epochs'), + 'precision': CONFIG.get('training.trainer', 'precision'), 'directory_path': CONFIG.get('training.output', 'directory_path'), - 'file_pattern': CONFIG.get('training.output', 'file_pattern'), - 'precision': CONFIG.get('training.trainer', 'precision') + 'file_pattern': CONFIG.get('training.output', 'file_pattern') } logger = TensorBoardLogger('.logs', name = 'embedding_converter') @@ -105,7 +105,7 @@ def create_trainer() -> Trainer: logger = logger, log_every_n_steps = 10, max_epochs = config.get('max_epochs'), - precision = config.get('precision'), # type:ignore[arg-type] + precision = config.get('precision'), callbacks = [ ModelCheckpoint( @@ -121,30 +121,34 @@ def create_trainer() -> Trainer: def train() -> None: - config_dataset =\ + config =\ { - 'file_pattern': CONFIG.get('training.dataset', 'file_pattern'), - } - config_trainer =\ - { - 'source_path': CONFIG.get('training.model', 'source_path'), - 'target_path': CONFIG.get('training.model', 'target_path'), - 'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate') - } - config_common =\ - { - 'resume_path': CONFIG.get('training.output', 'resume_path') + 'dataset': + { + 'file_pattern': CONFIG.get('training.dataset', 'file_pattern') + }, + 'trainer': + { + 'source_path': CONFIG.get('training.model', 'source_path'), + 'target_path': CONFIG.get('training.model', 'target_path'), + 'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate') + }, + 'common': + { + 'resume_path': CONFIG.get('training.output', 'resume_path') + } } + if torch.cuda.is_available(): torch.set_float32_matmul_precision('high') - dataset = StaticDataset(config_dataset) + dataset = StaticDataset(config.get('dataset')) training_loader, validation_loader = create_loaders(dataset) - embedding_converter_trainer = EmbeddingConverterTrainer(config_trainer) + embedding_converter_trainer = EmbeddingConverterTrainer(config.get('trainer')) trainer = create_trainer() - if os.path.exists(config_common.get('resume_path')): - trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_common.get('resume_path')) + if os.path.exists(config.get('common').get('resume_path')): + trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('common').get('resume_path')) else: trainer.fit(embedding_converter_trainer, training_loader, validation_loader)