diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index f5a62b0..4b9a261 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -124,8 +124,8 @@ class FaceSwapperTrainer(LightningModule): self.untoggle_optimizer(generator_optimizer) self.toggle_optimizer(masker_optimizer) - target_attribute = generator_output_attributes[-1].detach() - mask_tensor = self.masker(target_tensor, target_attribute) + target_attribute = generator_output_attributes[-1] + mask_tensor = self.masker(target_tensor, target_attribute.detach()) mask_loss = self.mask_loss(target_tensor, mask_tensor) self.manual_backward(mask_loss)