mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
changes
This commit is contained in:
@@ -48,8 +48,8 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer)
|
||||
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optimizer, T_max = 10, eta_min = 1e-6)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(discriminator_optimizer, T_max = 10, eta_min = 1e-6)
|
||||
|
||||
generator_config =\
|
||||
{
|
||||
@@ -58,7 +58,6 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
{
|
||||
'scheduler': generator_scheduler,
|
||||
'interval': 'step',
|
||||
'frequency': 1000
|
||||
}
|
||||
}
|
||||
discriminator_config =\
|
||||
@@ -68,7 +67,6 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
{
|
||||
'scheduler': discriminator_scheduler,
|
||||
'interval': 'step',
|
||||
'frequency': 1000
|
||||
}
|
||||
}
|
||||
return generator_config, discriminator_config
|
||||
@@ -97,7 +95,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
generator_optimizer.step()
|
||||
|
||||
generator_scheduler = self.lr_schedulers()[0]
|
||||
generator_scheduler.step(generator_loss)
|
||||
generator_scheduler.step()
|
||||
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
@@ -108,7 +106,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
discriminator_optimizer.step()
|
||||
|
||||
discriminator_scheduler = self.lr_schedulers()[1]
|
||||
discriminator_scheduler.step(discriminator_loss)
|
||||
discriminator_scheduler.step()
|
||||
|
||||
if self.global_step % preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
|
||||
Reference in New Issue
Block a user