Rename loss_id to loss_identity

This commit is contained in:
henryruhs
2025-02-12 12:38:07 +01:00
parent 67ad9badac
commit 4d2038d4ce
2 changed files with 8 additions and 8 deletions
+4 -4
View File
@@ -37,7 +37,7 @@ class FaceSwapperLoss:
generator_loss_set =\
{
'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs),
'loss_id': self.calc_id_loss(source_tensor, swap_tensor),
'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor),
'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes),
'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
}
@@ -101,11 +101,11 @@ class FaceSwapperLoss:
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
return loss_reconstruction
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
return loss_id
loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
return loss_identity
def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_motion_features = self.get_pose_features(swap_tensor)
+4 -4
View File
@@ -63,10 +63,10 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
self.log('loss_generator', generator_losses.get('loss_generator'), prog_bar = True)
self.log('loss_discriminator', discriminator_losses.get('loss_discriminator'), prog_bar = True)
self.log('loss_adversarial', generator_losses.get('loss_adversarial'), prog_bar = True)
self.log('loss_attribute', generator_losses.get('loss_attribute'), prog_bar = True)
self.log('loss_id', generator_losses.get('loss_id'), prog_bar = True)
self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'), prog_bar = True)
self.log('loss_adversarial', generator_losses.get('loss_adversarial'))
self.log('loss_attribute', generator_losses.get('loss_attribute'))
self.log('loss_identity', generator_losses.get('loss_identity'))
self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'))
return generator_losses.get('loss_generator')
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None: