Introduce new AdversarialLoss class

This commit is contained in:
henryruhs
2025-02-22 16:28:00 +01:00
parent 38211f0340
commit 30e787129a
-4
View File
@@ -1,6 +1,5 @@
import configparser
from typing import List, Tuple
from warnings import deprecated
import torch
from pytorch_msssim import ssim
@@ -79,7 +78,6 @@ class FaceSwapperLoss:
discriminator_loss_set['loss_discriminator'] = (loss_true + loss_fake) * 0.5
return discriminator_loss_set
@deprecated
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
loss_adversarials = []
@@ -98,7 +96,6 @@ class FaceSwapperLoss:
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
return loss_attribute
@deprecated
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
@@ -107,7 +104,6 @@ class FaceSwapperLoss:
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
return loss_reconstruction
@deprecated
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
swap_embedding = calc_embedding(self.embedder, swap_tensor, (30, 0, 10, 10))
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))