diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index f62cd34..6a62ef1 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -47,14 +47,16 @@ class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss): source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) swap_tensor, target_attributes = self.generator(target_tensor, source_embedding) swap_attributes = self.generator.get_attributes(swap_tensor) - real_discriminator_outputs = self.discriminator(source_tensor) - fake_discriminator_outputs = self.discriminator(swap_tensor.detach()) + fake_discriminator_outputs = self.discriminator(swap_tensor) generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, swap_attributes, fake_discriminator_outputs, batch) generator_optimizer.zero_grad() - self.manual_backward(generator_losses.get('loss_generator'), retain_graph = True) + self.manual_backward(generator_losses.get('loss_generator')) generator_optimizer.step() + real_discriminator_outputs = self.discriminator(source_tensor) + fake_discriminator_outputs = self.discriminator(swap_tensor.detach()) + discriminator_losses = self.calc_discriminator_loss(real_discriminator_outputs, fake_discriminator_outputs) discriminator_optimizer.zero_grad() self.manual_backward(discriminator_losses.get('loss_discriminator'))