This commit is contained in:
harisreedhar
2025-01-17 16:14:57 +05:30
committed by henryruhs
parent 008a221f55
commit e45f46d355
2 changed files with 10 additions and 10 deletions
+7 -7
View File
@@ -17,7 +17,7 @@ from .data_loader import DataLoaderVGG
from .discriminator import MultiscaleDiscriminator
from .generator import AdaptiveEmbeddingIntegrationNetwork
from .helper import hinge_fake_loss, hinge_real_loss
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, Loss, Padding, SourceEmbedding, TargetAttributes, VisionTensor
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, LossTensor, Padding, SourceEmbedding, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -88,7 +88,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True)
return generator_losses.get('loss_generator')
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> Loss:
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
loss_adversarial = torch.Tensor(0)
for discriminator_output in discriminator_outputs:
@@ -96,7 +96,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
loss_adversarial = torch.mean(loss_adversarial)
return loss_adversarial
def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> Loss:
def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> LossTensor:
loss_attribute = torch.Tensor(0)
swap_attributes = self.generator.get_attributes(swap_tensor)
@@ -105,19 +105,19 @@ class FaceSwapper(pytorch_lightning.LightningModule):
loss_attribute *= 0.5
return loss_attribute
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> Loss:
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4)
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
return loss_reconstruction
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> Loss:
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10))
source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10))
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
return loss_id
def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss:
def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_motion_features = self.get_pose_features(swap_tensor)
target_motion_features = self.get_pose_features(target_tensor)
loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
@@ -126,7 +126,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
loss_tsr += self.mse_loss(swap_motion_feature, target_motion_feature)
return loss_tsr
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss:
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_landmark = self.get_face_landmarks(swap_tensor)
target_landmark = self.get_face_landmarks(target_tensor)
left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
+3 -3
View File
@@ -14,6 +14,6 @@ StateDict = OrderedDict[str, Any]
Padding = Tuple[int, int, int, int]
FaceLandmark203 = Tensor
VisionTensor = Tensor
Loss = Tensor
GeneratorLossSet = Dict[str, Loss]
DiscriminatorLossSet = Dict[str, Loss]
LossTensor = Tensor
GeneratorLossSet = Dict[str, LossTensor]
DiscriminatorLossSet = Dict[str, LossTensor]