From 579d3ef51cc33c2ac780b3288215cbfe43d070e6 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sat, 22 Feb 2025 23:48:50 +0100 Subject: [PATCH] Introduce new GazeLoss class (switched to mean) --- embedding_converter/src/training.py | 2 -- face_swapper/src/models/loss.py | 30 +++++++++++++++++++++++++---- face_swapper/src/training.py | 12 ++++++++---- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 09ebe54..93e3ac2 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -28,8 +28,6 @@ class EmbeddingConverterTrainer(lightning.LightningModule): self.embedding_converter = EmbeddingConverter() self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call] - self.source_embedder.eval() - self.target_embedder.eval() self.mse_loss = nn.MSELoss() def forward(self, source_embedding : Embedding) -> Embedding: diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index c47068a..801257e 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -22,9 +22,6 @@ class FaceSwapperLoss: self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] - self.embedder.eval() - self.landmarker.eval() - self.motion_extractor.eval() def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet: weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial') @@ -197,7 +194,6 @@ class IdentityLoss(torch.nn.Module): super(IdentityLoss, self).__init__() embedder_path = CONFIG.get('training.model', 'embedder_path') self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] - self.embedder.eval() def calc(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: identity_weight = CONFIG.getfloat('training.losses', 'identity_weight') @@ -234,3 +230,29 @@ class PoseLoss(torch.nn.Module): pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm) rotation = torch.cat([ pitch, yaw, roll ], dim = 1) return translation, scale, rotation + + +class GazeLoss(torch.nn.Module): + def __init__(self) -> None: + super(GazeLoss, self).__init__() + landmarker_path = CONFIG.get('training.model', 'landmarker_path') + self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.mse_loss = nn.MSELoss() + + def calc(self, target_tensor : VisionTensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: + gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight') + output_face_landmark = self.detect_face_landmark(output_tensor) + target_face_landmark = self.detect_face_landmark(target_tensor) + + left_gaze_loss = self.mse_loss(output_face_landmark[:, 198], target_face_landmark[:, 198]) + right_gaze_loss = self.mse_loss(output_face_landmark[:, 197], target_face_landmark[:, 197]) + + gaze_loss = left_gaze_loss + right_gaze_loss + weighted_gaze_loss = gaze_loss * gaze_weight + return gaze_loss, weighted_gaze_loss + + def detect_face_landmark(self, input_tensor : Tensor) -> FaceLandmark203: + input_tensor = (input_tensor + 1) * 0.5 + input_tensor = nn.functional.interpolate(input_tensor, size = (224, 224), mode = 'bilinear') + face_landmarks_203 = self.landmarker(input_tensor)[2].view(-1, 203, 2) + return face_landmarks_203 diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 1879549..0c2ab5e 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -16,7 +16,7 @@ from .dataset import DynamicDataset from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator -from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, PoseLoss, ReconstructionLoss +from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss from .types import Batch, Embedding, VisionTensor CONFIG = configparser.ConfigParser() @@ -36,6 +36,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.reconstruction_loss = ReconstructionLoss() self.identity_loss = IdentityLoss() self.pose_loss = PoseLoss() + self.gaze_loss = GazeLoss() self.automatic_optimization = automatic_optimization def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor: @@ -77,11 +78,12 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True) self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator')) - self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'), prog_bar = True) + self.log('loss_adversarial', generator_loss_set.get('loss_adversarial')) self.log('loss_attribute', generator_loss_set.get('loss_attribute')) self.log('loss_identity', generator_loss_set.get('loss_identity')) self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction')) - self.log('loss_pose', generator_loss_set.get('loss_pose'), prog_bar = True) + self.log('loss_pose', generator_loss_set.get('loss_pose')) + self.log('loss_gaze', generator_loss_set.get('loss_gaze'), prog_bar = True) ############################################### @@ -90,6 +92,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor) identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor) pose_loss, weighted_pose_loss = self.pose_loss.calc(target_tensor, generator_output_tensor) + gaze_loss, weighted_gaze_loss = self.gaze_loss.calc(target_tensor, generator_output_tensor) generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss self.log('generator_loss_new', generator_loss, prog_bar = True) @@ -97,7 +100,8 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.log('attribute_loss_new', attribute_loss) self.log('reconstruction_loss_new', reconstruction_loss) self.log('identity_loss_new', identity_loss) - self.log('pose_loss_new', pose_loss, prog_bar = True) + self.log('pose_loss_new', pose_loss) + self.log('gaze_loss_new', gaze_loss, prog_bar = True) return generator_loss_set.get('loss_generator') def validation_step(self, batch : Batch, batch_index : int) -> Tensor: