From 0d73bcf918d85fee3df1c072f126a8c6fb8acc10 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 12 Mar 2025 22:37:06 +0100 Subject: [PATCH] Fix attribute loss --- face_swapper/src/models/generator.py | 9 ++++++--- face_swapper/src/training.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 6fd0097..2ddf748 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -16,10 +16,13 @@ class Generator(nn.Module): self.encoder.apply(init_weight) self.generator.apply(init_weight) - def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Attribute, ...]]: - target_attributes = self.encoder(target_tensor) + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: + target_attributes = self.encode_attributes(target_tensor) output_tensor = self.generator(source_embedding, target_attributes) - return output_tensor, target_attributes + return output_tensor + + def encode_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]: + return self.encoder(input_tensor) def init_weight(module : nn.Module) -> None: diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 5ebde71..eaacfa3 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -104,10 +104,11 @@ class FaceSwapperTrainer(LightningModule): generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) + generator_target_attributes = self.generator.encode_attributes(target_tensor) generator_output_tensor, generator_output_attributes = self.generator(source_embedding, target_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors) - attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_output_attributes, generator_output_attributes) + attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_target_attributes, generator_output_attributes) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor) identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor) pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)