More adjustments

This commit is contained in:
henryruhs
2025-03-06 13:08:29 +01:00
parent 57aad5204e
commit 847579f925
+23 -19
View File
@@ -95,9 +95,9 @@ def create_trainer() -> Trainer:
config =\
{
'max_epochs': CONFIG.getint('training.trainer', 'max_epochs'),
'precision': CONFIG.get('training.trainer', 'precision'),
'directory_path': CONFIG.get('training.output', 'directory_path'),
'file_pattern': CONFIG.get('training.output', 'file_pattern'),
'precision': CONFIG.get('training.trainer', 'precision')
'file_pattern': CONFIG.get('training.output', 'file_pattern')
}
logger = TensorBoardLogger('.logs', name = 'embedding_converter')
@@ -105,7 +105,7 @@ def create_trainer() -> Trainer:
logger = logger,
log_every_n_steps = 10,
max_epochs = config.get('max_epochs'),
precision = config.get('precision'), # type:ignore[arg-type]
precision = config.get('precision'),
callbacks =
[
ModelCheckpoint(
@@ -121,30 +121,34 @@ def create_trainer() -> Trainer:
def train() -> None:
config_dataset =\
config =\
{
'file_pattern': CONFIG.get('training.dataset', 'file_pattern'),
}
config_trainer =\
{
'source_path': CONFIG.get('training.model', 'source_path'),
'target_path': CONFIG.get('training.model', 'target_path'),
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
}
config_common =\
{
'resume_path': CONFIG.get('training.output', 'resume_path')
'dataset':
{
'file_pattern': CONFIG.get('training.dataset', 'file_pattern')
},
'trainer':
{
'source_path': CONFIG.get('training.model', 'source_path'),
'target_path': CONFIG.get('training.model', 'target_path'),
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
},
'common':
{
'resume_path': CONFIG.get('training.output', 'resume_path')
}
}
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = StaticDataset(config_dataset)
dataset = StaticDataset(config.get('dataset'))
training_loader, validation_loader = create_loaders(dataset)
embedding_converter_trainer = EmbeddingConverterTrainer(config_trainer)
embedding_converter_trainer = EmbeddingConverterTrainer(config.get('trainer'))
trainer = create_trainer()
if os.path.exists(config_common.get('resume_path')):
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_common.get('resume_path'))
if os.path.exists(config.get('common').get('resume_path')):
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('common').get('resume_path'))
else:
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)