mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Revert loss for the moment
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import numpy
|
||||
import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
|
||||
@@ -45,9 +44,3 @@ def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor
|
||||
source_embedding = id_embedder(crop_vision_tensor)
|
||||
source_embedding = nn.functional.normalize(source_embedding, p = 2)
|
||||
return source_embedding
|
||||
|
||||
|
||||
def calc_structural_similarity(swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Tensor:
|
||||
swap_data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))
|
||||
structural_similarity = 1 - ssim(swap_tensor, target_tensor, data_range = swap_data_range).mean()
|
||||
return structural_similarity
|
||||
|
||||
@@ -2,9 +2,10 @@ import configparser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..helper import calc_id_embedding, calc_structural_similarity, hinge_fake_loss, hinge_real_loss
|
||||
from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -26,19 +27,20 @@ class FaceSwapperLoss:
|
||||
self.motion_extractor.eval()
|
||||
|
||||
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
|
||||
source_tensor, target_tensor = batch
|
||||
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
|
||||
weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')
|
||||
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
|
||||
weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
|
||||
weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
|
||||
weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
|
||||
source_tensor, target_tensor = batch
|
||||
is_same_person = torch.tensor(0) if source_tensor == target_tensor else torch.tensor(1)
|
||||
generator_loss_set =\
|
||||
{
|
||||
'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs),
|
||||
'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor),
|
||||
'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes),
|
||||
'loss_reconstruction': self.calc_reconstruction_loss(source_tensor, target_tensor, swap_tensor)
|
||||
'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
|
||||
}
|
||||
|
||||
if weight_pose > 0:
|
||||
@@ -94,23 +96,12 @@ class FaceSwapperLoss:
|
||||
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
|
||||
return loss_attribute
|
||||
|
||||
def calc_reconstruction_loss(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
with torch.no_grad():
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0))
|
||||
face_similarities = (torch.cosine_similarity(source_embedding, target_embedding) + 1) * 0.5
|
||||
loss_reconstructions = []
|
||||
|
||||
for index, face_similarity in enumerate(face_similarities):
|
||||
if face_similarity.item() > 0.9:
|
||||
loss_mse = self.mse_loss(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0))
|
||||
loss_ssim = calc_structural_similarity(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0))
|
||||
loss_reconstruction = (loss_mse + loss_ssim) * 0.5
|
||||
loss_reconstructions.append(loss_reconstruction)
|
||||
else:
|
||||
loss_reconstructions.append(torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype))
|
||||
|
||||
loss_reconstruction = torch.stack(loss_reconstructions).mean()
|
||||
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
|
||||
loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
|
||||
loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
|
||||
loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
|
||||
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
|
||||
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
|
||||
return loss_reconstruction
|
||||
|
||||
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
|
||||
Reference in New Issue
Block a user