some fixes

This commit is contained in:
harisreedhar
2025-02-15 19:04:43 +05:30
committed by henryruhs
parent 0e148845af
commit 030d912c1b
6 changed files with 28 additions and 21 deletions
+2 -2
View File
@@ -3,8 +3,8 @@ from typing import List
from torch import nn
from face_swapper.src.networks.nld import NLD
from face_swapper.src.types import VisionTensor
from ..networks.nld import NLD
from ..types import VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
+3 -3
View File
@@ -3,9 +3,9 @@ from typing import Tuple
from torch import nn
from face_swapper.src.networks.attribute_modulator import AADGenerator
from face_swapper.src.networks.unet import UNet
from face_swapper.src.types import Embedding, TargetAttributes, VisionTensor
from ..networks.attribute_modulator import AADGenerator
from ..networks.unet import UNet
from ..types import Embedding, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
+15 -13
View File
@@ -5,8 +5,8 @@ import torch
from pytorch_msssim import ssim
from torch import Tensor, nn
from face_swapper.src.helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from face_swapper.src.types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -62,35 +62,37 @@ class FaceSwapperLoss:
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
discriminator_loss_set = {}
loss_fake = torch.Tensor(0)
loss_fakes = []
for fake_discriminator_output in fake_discriminator_outputs:
loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
loss_fakes.append(hinge_fake_loss(fake_discriminator_output[0]))
loss_true = torch.Tensor(0)
loss_trues = []
for true_discriminator_output in real_discriminator_outputs:
loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
loss_trues.append(hinge_real_loss(true_discriminator_output[0]))
discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
loss_fake = torch.stack(loss_fakes).mean()
loss_true = torch.stack(loss_trues).mean()
discriminator_loss_set['loss_discriminator'] = (loss_true + loss_fake) * 0.5
return discriminator_loss_set
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
loss_adversarial = torch.Tensor(0)
loss_adversarials = []
for discriminator_output in discriminator_outputs:
loss_adversarial += hinge_real_loss(discriminator_output[0])
loss_adversarials.append(hinge_real_loss(discriminator_output[0]).mean())
loss_adversarial = torch.mean(loss_adversarial)
loss_adversarial = torch.stack(loss_adversarials).mean()
return loss_adversarial
def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
loss_attribute = torch.Tensor(0)
loss_attributes = []
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
loss_attributes.append(torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean())
loss_attribute *= 0.5
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
return loss_attribute
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor: