mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -48,8 +48,8 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optimizer, T_max = 10, eta_min = 1e-6)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(discriminator_optimizer, T_max = 10, eta_min = 1e-6)
|
||||
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
|
||||
|
||||
generator_config =\
|
||||
{
|
||||
@@ -82,6 +82,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
||||
|
||||
self.toggle_optimizer(generator_optimizer)
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
|
||||
@@ -93,10 +94,9 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
generator_optimizer.zero_grad()
|
||||
self.manual_backward(generator_loss)
|
||||
generator_optimizer.step()
|
||||
self.untoggle_optimizer(generator_optimizer)
|
||||
|
||||
generator_scheduler = self.lr_schedulers()[0]
|
||||
generator_scheduler.step()
|
||||
|
||||
self.toggle_optimizer(discriminator_optimizer)
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
@@ -104,9 +104,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.manual_backward(discriminator_loss)
|
||||
discriminator_optimizer.step()
|
||||
|
||||
discriminator_scheduler = self.lr_schedulers()[1]
|
||||
discriminator_scheduler.step()
|
||||
self.untoggle_optimizer(discriminator_optimizer)
|
||||
|
||||
if self.global_step % preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
|
||||
Reference in New Issue
Block a user