diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index fca3105..ab880fb 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -120,7 +120,7 @@ class FaceSwapperTrainer(LightningModule): generator_output_attribute = generator_output_attributes[-1] mask_tensor = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach()) - mask_loss = self.mask_loss(generator_output_tensor.detach(), mask_tensor) + mask_loss = self.mask_loss(target_tensor, mask_tensor) self.toggle_optimizer(generator_optimizer) self.manual_backward(generator_loss, retain_graph = True)