From dfd9e99aed722f6dd7fdb1c090403670e6f54cfd Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Mon, 3 Mar 2025 16:41:26 +0530 Subject: [PATCH] changes --- face_swapper/config.ini | 2 +- face_swapper/src/models/loss.py | 36 ++++++++++++++++++++------------- face_swapper/src/training.py | 6 +++--- face_swapper/src/types.py | 4 ++-- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 72233e8..4cf9f80 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -10,7 +10,7 @@ split_ratio = [training.model] embedder_path = -landmarker_path = +gazer_path = motion_extractor_path = [training.model.generator] diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index c2be15e..8ab6533 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -4,9 +4,10 @@ from typing import List, Tuple import torch from pytorch_msssim import ssim from torch import Tensor, nn +from torchvision import transforms from ..helper import calc_embedding -from ..types import Attributes, EmbedderModule, FaceLandmark203, LandmarkerModule, MotionExtractorModule +from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -133,25 +134,32 @@ class PoseLoss(nn.Module): class GazeLoss(nn.Module): - def __init__(self, landmarker : LandmarkerModule) -> None: + def __init__(self, gazer : GazerModule) -> None: super().__init__() - self.landmarker = landmarker - self.mse_loss = nn.MSELoss() + self.gazer = gazer + self.mae_loss = nn.L1Loss() + self.transform = transforms.Compose( + [ + transforms.Resize(448), + transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ]) + ]) def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight') - output_face_landmark = self.detect_face_landmark(output_tensor) - target_face_landmark = self.detect_face_landmark(target_tensor) + output_pitch_tensor, output_yaw_tensor = self.detect_gaze(output_tensor) + target_pitch_tensor, target_yaw_tensor = self.detect_gaze(target_tensor) - left_gaze_loss = self.mse_loss(output_face_landmark[:, 198], target_face_landmark[:, 198]) - right_gaze_loss = self.mse_loss(output_face_landmark[:, 197], target_face_landmark[:, 197]) + pitch_gaze_loss = self.mae_loss(output_pitch_tensor, target_pitch_tensor) + yaw_gaze_loss = self.mae_loss(output_yaw_tensor, target_yaw_tensor) - gaze_loss = left_gaze_loss + right_gaze_loss + gaze_loss = (pitch_gaze_loss + yaw_gaze_loss) * 0.5 weighted_gaze_loss = gaze_loss * gaze_weight return gaze_loss, weighted_gaze_loss - def detect_face_landmark(self, input_tensor : Tensor) -> FaceLandmark203: - input_tensor = (input_tensor + 1) * 0.5 - input_tensor = nn.functional.interpolate(input_tensor, size = (224, 224), mode = 'bilinear') - face_landmarks_203 = self.landmarker(input_tensor)[2].view(-1, 203, 2) - return face_landmarks_203 + def detect_gaze(self, input_tensor : Tensor) -> Gaze: + crop_tensor = input_tensor[:, :, 60: 224, 16: 205] + crop_tensor = (crop_tensor + 1) * 0.5 + crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor) + crop_tensor = nn.functional.interpolate(crop_tensor, size = (448, 448), mode = 'bicubic') + pitch_tensor, yaw_tensor = self.gazer(crop_tensor) + return pitch_tensor, yaw_tensor diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index df97894..91f20bf 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -30,11 +30,11 @@ class FaceSwapperTrainer(lightning.LightningModule): def __init__(self) -> None: super().__init__() embedder_path = CONFIG.get('training.model', 'embedder_path') - landmarker_path = CONFIG.get('training.model', 'landmarker_path') + gazer_path = CONFIG.get('training.model', 'gazer_path') motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') self.embedder = torch.jit.load(embedder_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call] - self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call] + self.gazer = torch.jit.load(gazer_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call] self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call] self.generator = Generator() @@ -45,7 +45,7 @@ class FaceSwapperTrainer(lightning.LightningModule): self.reconstruction_loss = ReconstructionLoss(self.embedder) self.identity_loss = IdentityLoss(self.embedder) self.pose_loss = PoseLoss(self.motion_extractor) - self.gaze_loss = GazeLoss(self.landmarker) + self.gaze_loss = GazeLoss(self.gazer) self.automatic_optimization = False def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index 7742764..44a1e8a 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -7,13 +7,13 @@ Batch : TypeAlias = Tuple[Tensor, Tensor] Attributes : TypeAlias = Tuple[Tensor, ...] Embedding : TypeAlias = Tensor -FaceLandmark203 : TypeAlias = Tensor +Gaze : TypeAlias = Tuple[Tensor, Tensor] Padding : TypeAlias = Tuple[int, int, int, int] GeneratorModule : TypeAlias = Module EmbedderModule : TypeAlias = Module -LandmarkerModule : TypeAlias = Module +GazerModule : TypeAlias = Module MotionExtractorModule : TypeAlias = Module OptimizerConfig : TypeAlias = Any