Revert loss for the moment

This commit is contained in:
henryruhs
2025-02-21 16:53:32 +01:00
parent c17378f3c7
commit 5a6e3393e2
2 changed files with 11 additions and 27 deletions
-7
View File
@@ -1,6 +1,5 @@
import numpy
import torch
from pytorch_msssim import ssim
from torch import Tensor, nn
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
@@ -45,9 +44,3 @@ def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor
source_embedding = id_embedder(crop_vision_tensor)
source_embedding = nn.functional.normalize(source_embedding, p = 2)
return source_embedding
def calc_structural_similarity(swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Tensor:
	"""
	Return the structural dissimilarity between the swap and target tensors.

	The result is 1 minus the mean of ssim(swap_tensor, target_tensor). The
	data range handed to ssim is the value spread (max - min) of swap_tensor
	only; target_tensor does not influence it.
	"""
	value_spread = torch.max(swap_tensor) - torch.min(swap_tensor)
	similarity = ssim(swap_tensor, target_tensor, data_range = float(value_spread))
	return 1 - similarity.mean()
+11 -20
View File
@@ -2,9 +2,10 @@ import configparser
from typing import Tuple
import torch
from pytorch_msssim import ssim
from torch import Tensor, nn
from ..helper import calc_id_embedding, calc_structural_similarity, hinge_fake_loss, hinge_real_loss
from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
@@ -26,19 +27,20 @@ class FaceSwapperLoss:
self.motion_extractor.eval()
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
source_tensor, target_tensor = batch
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
source_tensor, target_tensor = batch
is_same_person = torch.tensor(0) if source_tensor == target_tensor else torch.tensor(1)
generator_loss_set =\
{
'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs),
'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor),
'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes),
'loss_reconstruction': self.calc_reconstruction_loss(source_tensor, target_tensor, swap_tensor)
'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
}
if weight_pose > 0:
@@ -94,23 +96,12 @@ class FaceSwapperLoss:
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
return loss_attribute
def calc_reconstruction_loss(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
	"""
	Reconstruction loss that only applies to samples whose source and target faces look alike.

	For every sample the identity embeddings of source and target are compared.
	When the rescaled cosine similarity exceeds 0.9, the sample contributes the
	average of its MSE and structural-similarity losses between swap and target;
	otherwise it contributes zero. The per-sample values are averaged over the batch.

	:param source_tensor: batch of source images
	:param target_tensor: batch of target images
	:param swap_tensor: batch of generated images, compared against target_tensor
	:return: scalar loss averaged over all samples of the batch
	"""
	# The embeddings only gate which samples are penalised, so no gradient is needed.
	with torch.no_grad():
		source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
		target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0))
		# Rescale the cosine similarity from [-1, 1] to [0, 1].
		face_similarities = (torch.cosine_similarity(source_embedding, target_embedding) + 1) * 0.5

	loss_reconstructions = []

	for index, face_similarity in enumerate(face_similarities):
		if face_similarity.item() > 0.9:
			# Keep the batch dimension so the loss helpers receive a batch of one.
			swap_sample = swap_tensor[index].unsqueeze(0)
			target_sample = target_tensor[index].unsqueeze(0)
			loss_mse = self.mse_loss(swap_sample, target_sample)
			loss_ssim = calc_structural_similarity(swap_sample, target_sample)
			loss_reconstructions.append((loss_mse + loss_ssim) * 0.5)
		else:
			# Zero placeholder keeps the batch mean normalised by the full batch size.
			loss_reconstructions.append(torch.zeros((), device = swap_tensor.device, dtype = swap_tensor.dtype))

	loss_reconstruction = torch.stack(loss_reconstructions).mean()
	# The computed loss was previously never handed back to the caller.
	return loss_reconstruction
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
	"""
	Combine a masked pixel loss with a structural-similarity loss between swap and target.

	The pixel term is the halved per-sample mean squared error, multiplied by
	is_same_person and normalised by the sum of that mask (plus 1e-4 to avoid a
	division by zero). The SSIM term is 1 minus the mean SSIM over the whole
	batch and is not masked. Both terms are averaged.

	:param swap_tensor: batch of generated images
	:param target_tensor: batch of target images, same shape as swap_tensor
	:param is_same_person: mask multiplied onto the per-sample pixel losses;
		must broadcast against a tensor with one entry per sample
	:return: scalar reconstruction loss
	"""
	# Use the real batch dimension rather than self.batch_size: a batch of a
	# different size (e.g. a final partial batch) would otherwise fail to
	# reshape or silently regroup pixels across samples.
	sample_count = swap_tensor.shape[0]
	loss_mse = torch.pow(swap_tensor - target_tensor, 2).reshape(sample_count, -1)
	loss_mse = torch.mean(loss_mse, dim = 1) * 0.5
	# Keep the mask on the same device as the losses it is multiplied with.
	is_same_person = is_same_person.to(swap_tensor.device)
	loss_mse = torch.sum(loss_mse * is_same_person) / (is_same_person.sum() + 1e-4)
	# The SSIM data range is taken from the value spread of the swap tensor only.
	data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))
	loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = data_range).mean()
	loss_reconstruction = (loss_mse + loss_ssim) * 0.5
	return loss_reconstruction
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: