mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
some fixes
This commit is contained in:
@@ -3,8 +3,8 @@ from typing import List
|
||||
|
||||
from torch import nn
|
||||
|
||||
from face_swapper.src.networks.nld import NLD
|
||||
from face_swapper.src.types import VisionTensor
|
||||
from ..networks.nld import NLD
|
||||
from ..types import VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
@@ -3,9 +3,9 @@ from typing import Tuple
|
||||
|
||||
from torch import nn
|
||||
|
||||
from face_swapper.src.networks.attribute_modulator import AADGenerator
|
||||
from face_swapper.src.networks.unet import UNet
|
||||
from face_swapper.src.types import Embedding, TargetAttributes, VisionTensor
|
||||
from ..networks.attribute_modulator import AADGenerator
|
||||
from ..networks.unet import UNet
|
||||
from ..types import Embedding, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
@@ -5,8 +5,8 @@ import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from face_swapper.src.types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -62,35 +62,37 @@ class FaceSwapperLoss:
|
||||
|
||||
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
|
||||
discriminator_loss_set = {}
|
||||
loss_fake = torch.Tensor(0)
|
||||
loss_fakes = []
|
||||
|
||||
for fake_discriminator_output in fake_discriminator_outputs:
|
||||
loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
|
||||
loss_fakes.append(hinge_fake_loss(fake_discriminator_output[0]))
|
||||
|
||||
loss_true = torch.Tensor(0)
|
||||
loss_trues = []
|
||||
|
||||
for true_discriminator_output in real_discriminator_outputs:
|
||||
loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
|
||||
loss_trues.append(hinge_real_loss(true_discriminator_output[0]))
|
||||
|
||||
discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
|
||||
loss_fake = torch.stack(loss_fakes).mean()
|
||||
loss_true = torch.stack(loss_trues).mean()
|
||||
discriminator_loss_set['loss_discriminator'] = (loss_true + loss_fake) * 0.5
|
||||
return discriminator_loss_set
|
||||
|
||||
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
|
||||
loss_adversarial = torch.Tensor(0)
|
||||
loss_adversarials = []
|
||||
|
||||
for discriminator_output in discriminator_outputs:
|
||||
loss_adversarial += hinge_real_loss(discriminator_output[0])
|
||||
loss_adversarials.append(hinge_real_loss(discriminator_output[0]).mean())
|
||||
|
||||
loss_adversarial = torch.mean(loss_adversarial)
|
||||
loss_adversarial = torch.stack(loss_adversarials).mean()
|
||||
return loss_adversarial
|
||||
|
||||
def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
|
||||
loss_attribute = torch.Tensor(0)
|
||||
loss_attributes = []
|
||||
|
||||
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
|
||||
loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
|
||||
loss_attributes.append(torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean())
|
||||
|
||||
loss_attribute *= 0.5
|
||||
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
|
||||
return loss_attribute
|
||||
|
||||
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
|
||||
|
||||
Reference in New Issue
Block a user