Train MaskNet based on the output

This commit is contained in:
henryruhs
2025-03-12 21:42:08 +01:00
parent cf0bd93814
commit 72591fbed1
+1 -1
View File
@@ -120,7 +120,7 @@ class FaceSwapperTrainer(LightningModule):
generator_output_attribute = generator_output_attributes[-1]
mask_tensor = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach())
mask_loss = self.mask_loss(generator_output_tensor.detach(), mask_tensor)
mask_loss = self.mask_loss(target_tensor, mask_tensor)
self.toggle_optimizer(generator_optimizer)
self.manual_backward(generator_loss, retain_graph = True)