diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 517b88b..01a9d12 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -37,7 +37,7 @@ class FaceSwapperLoss: generator_loss_set =\ { 'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs), - 'loss_id': self.calc_id_loss(source_tensor, swap_tensor), + 'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor), 'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes), 'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person) } @@ -101,11 +101,11 @@ class FaceSwapperLoss: loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5 return loss_reconstruction - def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: + def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10)) source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10)) - loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean() - return loss_id + loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean() + return loss_identity def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor: swap_motion_features = self.get_pose_features(swap_tensor) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index b267d93..fc20ead 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -63,10 +63,10 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss): self.log('loss_generator', generator_losses.get('loss_generator'), prog_bar = True) self.log('loss_discriminator', discriminator_losses.get('loss_discriminator'), prog_bar = True) - self.log('loss_adversarial', generator_losses.get('loss_adversarial'), prog_bar = True) - self.log('loss_attribute', generator_losses.get('loss_attribute'), prog_bar = True) - self.log('loss_id', generator_losses.get('loss_id'), prog_bar = True) - self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'), prog_bar = True) + self.log('loss_adversarial', generator_losses.get('loss_adversarial')) + self.log('loss_attribute', generator_losses.get('loss_attribute')) + self.log('loss_identity', generator_losses.get('loss_identity')) + self.log('loss_reconstruction', generator_losses.get('loss_reconstruction')) return generator_losses.get('loss_generator') def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None: