diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 3811118..cd9a41a 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -1,6 +1,5 @@ import numpy import torch -from pytorch_msssim import ssim from torch import Tensor, nn from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor @@ -45,9 +44,3 @@ def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor source_embedding = id_embedder(crop_vision_tensor) source_embedding = nn.functional.normalize(source_embedding, p = 2) return source_embedding - - -def calc_structural_similarity(swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Tensor: - swap_data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor)) - structural_similarity = 1 - ssim(swap_tensor, target_tensor, data_range = swap_data_range).mean() - return structural_similarity diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 9ef34b9..73db988 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -2,9 +2,10 @@ import configparser from typing import Tuple import torch +from pytorch_msssim import ssim from torch import Tensor, nn -from ..helper import calc_id_embedding, calc_structural_similarity, hinge_fake_loss, hinge_real_loss +from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() @@ -26,19 +27,20 @@ class FaceSwapperLoss: self.motion_extractor.eval() def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet: - source_tensor, target_tensor = batch weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial') weight_identity = CONFIG.getfloat('training.losses', 'weight_identity') weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute') weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction') weight_pose = CONFIG.getfloat('training.losses', 'weight_pose') weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze') + source_tensor, target_tensor = batch + is_same_person = torch.tensor(0) if source_tensor == target_tensor else torch.tensor(1) generator_loss_set =\ { 'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs), 'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor), 'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes), - 'loss_reconstruction': self.calc_reconstruction_loss(source_tensor, target_tensor, swap_tensor) + 'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person) } if weight_pose > 0: @@ -94,23 +96,12 @@ class FaceSwapperLoss: loss_attribute = torch.stack(loss_attributes).mean() * 0.5 return loss_attribute - def calc_reconstruction_loss(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: - with torch.no_grad(): - source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) - target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0)) - face_similarities = (torch.cosine_similarity(source_embedding, target_embedding) + 1) * 0.5 - loss_reconstructions = [] - - for index, face_similarity in enumerate(face_similarities): - if face_similarity.item() > 0.9: - loss_mse = self.mse_loss(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0)) - loss_ssim = calc_structural_similarity(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0)) - loss_reconstruction = (loss_mse + loss_ssim) * 0.5 - loss_reconstructions.append(loss_reconstruction) - else: - loss_reconstructions.append(torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)) - - loss_reconstruction = torch.stack(loss_reconstructions).mean() + def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor: + loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1) + loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5 + loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4) + loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean() + loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5 return loss_reconstruction def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: