From a797548329fe2bb1c0bb7783eba3dfe11e453ac9 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sat, 22 Feb 2025 23:15:39 +0100 Subject: [PATCH] Introduce new PoseLoss class (switched to mean) --- face_swapper/src/models/loss.py | 40 ++++++++++++++++++++++++--------- face_swapper/src/training.py | 20 +++++++++++------ 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 4f5c621..c47068a 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -3,7 +3,6 @@ from typing import List, Tuple import torch from pytorch_msssim import ssim -from sqlalchemy.dialects.mssql.information_schema import identity_columns from torch import Tensor, nn from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss @@ -44,15 +43,8 @@ class FaceSwapperLoss: 'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person) } - if weight_pose > 0: - generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor) - else: - generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) - - if weight_gaze > 0: - generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor) - else: - generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) + generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor) + generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor) generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial generator_loss_set['loss_generator'] += generator_loss_set.get('loss_identity') * weight_identity @@ -214,3 +206,31 @@ class IdentityLoss(torch.nn.Module): identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean() weighted_identity_loss = identity_loss * identity_weight return identity_loss, weighted_identity_loss + + +class PoseLoss(torch.nn.Module): + def __init__(self) -> None: + super(PoseLoss, self).__init__() + motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') + self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.mse_loss = nn.MSELoss() + + def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: + pose_weight = CONFIG.getfloat('training.losses', 'pose_weight') + output_motion_features = self.get_motion_features(output_tensor) + target_motion_features = self.get_motion_features(target_tensor) + temp_tensors = [] + + for target_motion_feature, output_motion_feature in zip(target_motion_features, output_motion_features): + temp_tensor = self.mse_loss(target_motion_feature, output_motion_feature) + temp_tensors.append(temp_tensor) + + pose_loss = torch.stack(temp_tensors).mean() + weighted_pose_loss = pose_loss * pose_weight + return pose_loss, weighted_pose_loss + + def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]: + vision_tensor_norm = (input_tensor + 1) * 0.5 + pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm) + rotation = torch.cat([ pitch, yaw, roll ], dim = 1) + return translation, scale, rotation diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index e52c6cd..1879549 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -16,7 +16,7 @@ from .dataset import DynamicDataset from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator -from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, ReconstructionLoss +from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, PoseLoss, ReconstructionLoss from .types import Batch, Embedding, VisionTensor CONFIG = configparser.ConfigParser() @@ -35,6 +35,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.attribute_loss = AttributeLoss() self.reconstruction_loss = ReconstructionLoss() self.identity_loss = IdentityLoss() + self.pose_loss = PoseLoss() self.automatic_optimization = automatic_optimization def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor: @@ -77,9 +78,10 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True) self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator')) self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'), prog_bar = True) - self.log('loss_attribute', generator_loss_set.get('loss_attribute'), prog_bar = True) + self.log('loss_attribute', generator_loss_set.get('loss_attribute')) self.log('loss_identity', generator_loss_set.get('loss_identity')) self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction')) + self.log('loss_pose', generator_loss_set.get('loss_pose'), prog_bar = True) ############################################### @@ -87,13 +89,15 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor) identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor) - generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + pose_loss, weighted_pose_loss = self.pose_loss.calc(target_tensor, generator_output_tensor) + generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss self.log('generator_loss_new', generator_loss, prog_bar = True) self.log('adversarial_loss_new', adversarial_loss) - self.log('attribute_loss_new', attribute_loss, prog_bar = True) + self.log('attribute_loss_new', attribute_loss) self.log('reconstruction_loss_new', reconstruction_loss) self.log('identity_loss_new', identity_loss) + self.log('pose_loss_new', pose_loss, prog_bar = True) return generator_loss_set.get('loss_generator') def validation_step(self, batch : Batch, batch_index : int) -> Tensor: @@ -107,12 +111,14 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None: preview_limit = 8 - preview_items = [] + preview_cells = [] for source_tensor, target_tensor, output_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit]): - preview_items.append(torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2)) + preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2) + preview_cells.append(preview_cell) - preview_grid = torchvision.utils.make_grid(torch.cat(preview_items, dim = 1).unsqueeze(0), normalize = True, scale_each = True) + preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0) + preview_grid = torchvision.utils.make_grid(preview_cells, normalize = True, scale_each = True) self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]