mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
More adjustments
This commit is contained in:
@@ -50,7 +50,7 @@ class EmbeddingConverterTrainer(LightningModule):
|
||||
def configure_optimizers(self) -> OptimizerSet:
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate'))
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||
config =\
|
||||
optimizer_set =\
|
||||
{
|
||||
'optimizer': optimizer,
|
||||
'lr_scheduler':
|
||||
@@ -62,7 +62,7 @@ class EmbeddingConverterTrainer(LightningModule):
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
return optimizer_set
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
|
||||
@@ -124,13 +124,16 @@ def train() -> None:
|
||||
config_dataset =\
|
||||
{
|
||||
'file_pattern': CONFIG.get('training.dataset', 'file_pattern'),
|
||||
'resume_path': CONFIG.get('training.output', 'resume_path')
|
||||
}
|
||||
config_trainer =\
|
||||
{
|
||||
'source_path': CONFIG.get('training.model', 'source_path'),
|
||||
'target_path': CONFIG.get('training.model', 'target_path'),
|
||||
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate'),
|
||||
}
|
||||
config_common =\
|
||||
{
|
||||
'resume_path': CONFIG.get('training.output', 'resume_path')
|
||||
}
|
||||
|
||||
if torch.cuda.is_available():
|
||||
@@ -141,7 +144,7 @@ def train() -> None:
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer(config_trainer)
|
||||
trainer = create_trainer()
|
||||
|
||||
if os.path.exists(config_dataset.get('resume_path')):
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_dataset.get('resume_path'))
|
||||
if os.path.exists(config_common.get('resume_path')):
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_common.get('resume_path'))
|
||||
else:
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
|
||||
|
||||
Reference in New Issue
Block a user