This commit is contained in:
harisreedhar
2025-02-27 16:18:53 +05:30
committed by henryruhs
parent d87f6c0b15
commit b27b8663e5
+6 -8
View File
@@ -48,8 +48,8 @@ class FaceSwapperTrainer(lightning.LightningModule):
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optimizer, T_max = 10, eta_min = 1e-6)
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(discriminator_optimizer, T_max = 10, eta_min = 1e-6)
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
generator_config =\
{
@@ -82,6 +82,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
self.toggle_optimizer(generator_optimizer)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
@@ -93,10 +94,9 @@ class FaceSwapperTrainer(lightning.LightningModule):
generator_optimizer.zero_grad()
self.manual_backward(generator_loss)
generator_optimizer.step()
self.untoggle_optimizer(generator_optimizer)
generator_scheduler = self.lr_schedulers()[0]
generator_scheduler.step()
self.toggle_optimizer(discriminator_optimizer)
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
@@ -104,9 +104,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_loss)
discriminator_optimizer.step()
discriminator_scheduler = self.lr_schedulers()[1]
discriminator_scheduler.step()
self.untoggle_optimizer(discriminator_optimizer)
if self.global_step % preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)