More adjustments

Author: henryruhs
Date: 2025-03-06 13:00:06 +01:00
Parent: d944d95bfd
Commit: d3b0051912
Changes: +9 -6
@@ -50,7 +50,7 @@ class EmbeddingConverterTrainer(LightningModule):
 	def configure_optimizers(self) -> OptimizerSet:
 		optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate'))
 		scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
-		config =\
+		optimizer_set =\
 		{
 			'optimizer': optimizer,
 			'lr_scheduler':
@@ -62,7 +62,7 @@ class EmbeddingConverterTrainer(LightningModule):
 			}
 		}
-		return config
+		return optimizer_set

 def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
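
The body of the renamed optimizer_set dict falls between these two hunks and is not shown. Below is a minimal sketch of how configure_optimizers() with ReduceLROnPlateau is typically wired up in PyTorch Lightning; the 'scheduler' and 'monitor' keys and the 'validation_loss' metric name are assumptions, not repository code.

import torch
from pytorch_lightning import LightningModule

class ExampleModule(LightningModule):
	def __init__(self, config : dict) -> None:
		super().__init__()
		self.config = config
		self.layer = torch.nn.Linear(512, 512) # dummy layer so parameters() is non-empty

	def configure_optimizers(self) -> dict:
		# Adam driven by the configured learning rate.
		optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate'))
		# Lower the learning rate once the monitored metric stops improving.
		scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
		optimizer_set =\
		{
			'optimizer': optimizer,
			'lr_scheduler':
			{
				'scheduler': scheduler,
				'monitor': 'validation_loss' # assumed metric name
			}
		}
		return optimizer_set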
@@ -124,13 +124,16 @@ def train() -> None:
 	config_dataset =\
 	{
 		'file_pattern': CONFIG.get('training.dataset', 'file_pattern'),
-		'resume_path': CONFIG.get('training.output', 'resume_path')
 	}
 	config_trainer =\
 	{
 		'source_path': CONFIG.get('training.model', 'source_path'),
 		'target_path': CONFIG.get('training.model', 'target_path'),
-		'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
+		'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate'),
 	}
+	config_common =\
+	{
+		'resume_path': CONFIG.get('training.output', 'resume_path')
+	}
 	if torch.cuda.is_available():
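
The CONFIG.get(section, option) and CONFIG.getfloat(section, option) calls follow the configparser API, so the split into config_dataset, config_trainer and config_common could be assembled as in the sketch below; the 'training.ini' file name and the ConfigParser setup are assumptions, while the section and option names come from the diff.

from configparser import ConfigParser

CONFIG = ConfigParser()
CONFIG.read('training.ini') # hypothetical file name

config_dataset =\
{
	'file_pattern': CONFIG.get('training.dataset', 'file_pattern')
}
config_trainer =\
{
	'source_path': CONFIG.get('training.model', 'source_path'),
	'target_path': CONFIG.get('training.model', 'target_path'),
	'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
}
config_common =\
{
	'resume_path': CONFIG.get('training.output', 'resume_path')
}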
@@ -141,7 +144,7 @@ def train() -> None:
 	embedding_converter_trainer = EmbeddingConverterTrainer(config_trainer)
 	trainer = create_trainer()
-	if os.path.exists(config_dataset.get('resume_path')):
-		trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_dataset.get('resume_path'))
+	if os.path.exists(config_common.get('resume_path')):
+		trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_common.get('resume_path'))
 	else:
 		trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
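
The new branch keys the resume decision on config_common instead of config_dataset. A small, self-contained sketch of that pattern follows; the fit_with_optional_resume helper and its arguments are hypothetical, and only the existence check plus the ckpt_path forwarding mirror the commit. Lightning's Trainer.fit() accepts ckpt_path to restore model, optimizer and scheduler state from a checkpoint.

import os

from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader

def fit_with_optional_resume(trainer : Trainer, module : LightningModule, training_loader : DataLoader, validation_loader : DataLoader, resume_path : str) -> None:
	if os.path.exists(resume_path):
		# Resume from the stored checkpoint when one is present.
		trainer.fit(module, training_loader, validation_loader, ckpt_path = resume_path)
	else:
		# No checkpoint yet, start a fresh run.
		trainer.fit(module, training_loader, validation_loader)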