diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index a36b58d..0eac0e2 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -85,7 +85,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor) identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor) - generator_loss = weighted_adversarial_loss+ weighted_reconstruction_loss + weighted_identity_loss + generator_loss = weighted_adversarial_loss + weighted_reconstruction_loss + weighted_identity_loss self.log('generator_loss_new', generator_loss, prog_bar = True) self.log('adversarial_loss_new', adversarial_loss, prog_bar = True)