mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
fix discriminator training
This commit is contained in:
@@ -47,14 +47,16 @@ class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss):
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
|
||||
swap_attributes = self.generator.get_attributes(swap_tensor)
|
||||
real_discriminator_outputs = self.discriminator(source_tensor)
|
||||
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
|
||||
fake_discriminator_outputs = self.discriminator(swap_tensor)
|
||||
|
||||
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, swap_attributes, fake_discriminator_outputs, batch)
|
||||
generator_optimizer.zero_grad()
|
||||
self.manual_backward(generator_losses.get('loss_generator'), retain_graph = True)
|
||||
self.manual_backward(generator_losses.get('loss_generator'))
|
||||
generator_optimizer.step()
|
||||
|
||||
real_discriminator_outputs = self.discriminator(source_tensor)
|
||||
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
|
||||
|
||||
discriminator_losses = self.calc_discriminator_loss(real_discriminator_outputs, fake_discriminator_outputs)
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.manual_backward(discriminator_losses.get('loss_discriminator'))
|
||||
|
||||
Reference in New Issue
Block a user