More adjustments

This commit is contained in:
henryruhs
2025-03-06 13:20:04 +01:00
parent 7ab7efbbf4
commit f4a1e18ca9
+6 -6
View File
@@ -121,7 +121,7 @@ def create_trainer() -> Trainer:
def train() -> None:
config : ConfigSet =\
config_set : ConfigSet =\
{
'dataset':
{
@@ -133,7 +133,7 @@ def train() -> None:
'target_path': CONFIG.get('training.model', 'target_path'),
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
},
'common':
'output':
{
'resume_path': CONFIG.get('training.output', 'resume_path')
}
@@ -142,12 +142,12 @@ def train() -> None:
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = StaticDataset(config.get('dataset'))
dataset = StaticDataset(config_set.get('dataset'))
training_loader, validation_loader = create_loaders(dataset)
embedding_converter_trainer = EmbeddingConverterTrainer(config.get('trainer'))
embedding_converter_trainer = EmbeddingConverterTrainer(config_set.get('trainer'))
trainer = create_trainer()
if os.path.exists(config.get('common').get('resume_path')):
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('common').get('resume_path'))
if os.path.exists(config_set.get('output').get('resume_path')):
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_set.get('output').get('resume_path'))
else:
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)