Fix attribute loss

This commit is contained in:
henryruhs
2025-03-12 22:54:20 +01:00
parent 7c75b0d898
commit bdd7fd0d86
+2 -2
View File
@@ -104,8 +104,8 @@ class FaceSwapperTrainer(LightningModule):
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
generator_target_attributes = self.generator.encode_attributes(target_tensor)
generator_output_tensor, generator_output_attributes = self.generator(source_embedding, target_tensor)
generator_output_tensor, generator_target_attributes = self.generator(source_embedding, target_tensor)
generator_output_attributes = self.generator.encode_attributes(generator_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_target_attributes, generator_output_attributes)