mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Introduce new AdversarialLoss class
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import configparser
|
||||
from typing import List, Tuple
|
||||
from warnings import deprecated
|
||||
|
||||
import torch
|
||||
from pytorch_msssim import ssim
|
||||
@@ -79,7 +78,6 @@ class FaceSwapperLoss:
|
||||
discriminator_loss_set['loss_discriminator'] = (loss_true + loss_fake) * 0.5
|
||||
return discriminator_loss_set
|
||||
|
||||
@deprecated
|
||||
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
|
||||
loss_adversarials = []
|
||||
|
||||
@@ -98,7 +96,6 @@ class FaceSwapperLoss:
|
||||
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
|
||||
return loss_attribute
|
||||
|
||||
@deprecated
|
||||
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
|
||||
loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
|
||||
loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
|
||||
@@ -107,7 +104,6 @@ class FaceSwapperLoss:
|
||||
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
|
||||
return loss_reconstruction
|
||||
|
||||
@deprecated
|
||||
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
swap_embedding = calc_embedding(self.embedder, swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
|
||||
Reference in New Issue
Block a user