This commit is contained in:
harisreedhar
2025-02-26 14:01:41 +05:30
committed by henryruhs
parent 578b07a7f4
commit ab0a59fb74
+4 -6
View File
@@ -48,8 +48,8 @@ class FaceSwapperTrainer(lightning.LightningModule):
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer)
discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer)
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optimizer, T_max = 10, eta_min = 1e-6)
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(discriminator_optimizer, T_max = 10, eta_min = 1e-6)
generator_config =\
{
@@ -58,7 +58,6 @@ class FaceSwapperTrainer(lightning.LightningModule):
{
'scheduler': generator_scheduler,
'interval': 'step',
'frequency': 1000
}
}
discriminator_config =\
@@ -68,7 +67,6 @@ class FaceSwapperTrainer(lightning.LightningModule):
{
'scheduler': discriminator_scheduler,
'interval': 'step',
'frequency': 1000
}
}
return generator_config, discriminator_config
@@ -97,7 +95,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
generator_optimizer.step()
generator_scheduler = self.lr_schedulers()[0]
generator_scheduler.step(generator_loss)
generator_scheduler.step()
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
@@ -108,7 +106,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
discriminator_optimizer.step()
discriminator_scheduler = self.lr_schedulers()[1]
discriminator_scheduler.step(discriminator_loss)
discriminator_scheduler.step()
if self.global_step % preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)