Fix attribute loss

This commit is contained in:
henryruhs
2025-03-12 22:37:06 +01:00
parent 8f0ee4935b
commit 0d73bcf918
2 changed files with 8 additions and 4 deletions
+6 -3
View File
@@ -16,10 +16,13 @@ class Generator(nn.Module):
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Attribute, ...]]:
target_attributes = self.encoder(target_tensor)
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
target_attributes = self.encode_attributes(target_tensor)
output_tensor = self.generator(source_embedding, target_attributes)
return output_tensor, target_attributes
return output_tensor
def encode_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]:
return self.encoder(input_tensor)
def init_weight(module : nn.Module) -> None:
+2 -1
View File
@@ -104,10 +104,11 @@ class FaceSwapperTrainer(LightningModule):
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
generator_target_attributes = self.generator.encode_attributes(target_tensor)
generator_output_tensor, generator_output_attributes = self.generator(source_embedding, target_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_output_attributes, generator_output_attributes)
attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)