mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Rename loss_id to loss_identity
This commit is contained in:
@@ -37,7 +37,7 @@ class FaceSwapperLoss:
|
||||
generator_loss_set =\
|
||||
{
|
||||
'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs),
|
||||
'loss_id': self.calc_id_loss(source_tensor, swap_tensor),
|
||||
'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor),
|
||||
'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes),
|
||||
'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
|
||||
}
|
||||
@@ -101,11 +101,11 @@ class FaceSwapperLoss:
|
||||
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
|
||||
return loss_reconstruction
|
||||
|
||||
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
|
||||
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
|
||||
return loss_id
|
||||
loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
|
||||
return loss_identity
|
||||
|
||||
def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
|
||||
swap_motion_features = self.get_pose_features(swap_tensor)
|
||||
|
||||
@@ -63,10 +63,10 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
|
||||
|
||||
self.log('loss_generator', generator_losses.get('loss_generator'), prog_bar = True)
|
||||
self.log('loss_discriminator', discriminator_losses.get('loss_discriminator'), prog_bar = True)
|
||||
self.log('loss_adversarial', generator_losses.get('loss_adversarial'), prog_bar = True)
|
||||
self.log('loss_attribute', generator_losses.get('loss_attribute'), prog_bar = True)
|
||||
self.log('loss_id', generator_losses.get('loss_id'), prog_bar = True)
|
||||
self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'), prog_bar = True)
|
||||
self.log('loss_adversarial', generator_losses.get('loss_adversarial'))
|
||||
self.log('loss_attribute', generator_losses.get('loss_attribute'))
|
||||
self.log('loss_identity', generator_losses.get('loss_identity'))
|
||||
self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'))
|
||||
return generator_losses.get('loss_generator')
|
||||
|
||||
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None:
|
||||
|
||||
Reference in New Issue
Block a user