This commit is contained in:
harisreedhar
2025-03-03 16:41:26 +05:30
committed by henryruhs
parent 2fb0b4289d
commit dfd9e99aed
4 changed files with 28 additions and 20 deletions
+1 -1
View File
@@ -10,7 +10,7 @@ split_ratio =
[training.model]
embedder_path =
landmarker_path =
gazer_path =
motion_extractor_path =
[training.model.generator]
+22 -14
View File
@@ -4,9 +4,10 @@ from typing import List, Tuple
import torch
from pytorch_msssim import ssim
from torch import Tensor, nn
from torchvision import transforms
from ..helper import calc_embedding
from ..types import Attributes, EmbedderModule, FaceLandmark203, LandmarkerModule, MotionExtractorModule
from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -133,25 +134,32 @@ class PoseLoss(nn.Module):
class GazeLoss(nn.Module):
    """Gaze-consistency loss between a generated image and its target.

    Both images are passed through the supplied gazer module, and the loss
    is the mean absolute error of the predicted pitch and yaw tensors,
    averaged over the two angles.
    """

    def __init__(self, gazer : GazerModule) -> None:
        """Store the gaze estimator and build the reusable helpers.

        Args:
            gazer: module called on a preprocessed crop; it must return a
                ``(pitch, yaw)`` pair of tensors.
        """
        super().__init__()
        self.gazer = gazer
        self.mae_loss = nn.L1Loss()
        # Built once here so detect_gaze does not re-create it on every call.
        # The values are the standard ImageNet channel mean / std.
        self.normalize = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])

    def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
        """Compute the raw and weighted gaze loss.

        Args:
            target_tensor: batch of target images.
            output_tensor: batch of generated images.

        Returns:
            ``(gaze_loss, weighted_gaze_loss)`` where the second value is the
            first one scaled by ``gaze_weight`` from the ``training.losses``
            config section.
        """
        gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
        output_pitch_tensor, output_yaw_tensor = self.detect_gaze(output_tensor)
        target_pitch_tensor, target_yaw_tensor = self.detect_gaze(target_tensor)

        pitch_gaze_loss = self.mae_loss(output_pitch_tensor, target_pitch_tensor)
        yaw_gaze_loss = self.mae_loss(output_yaw_tensor, target_yaw_tensor)

        # Average the two angular errors so the loss scale does not depend
        # on the number of angles compared.
        gaze_loss = (pitch_gaze_loss + yaw_gaze_loss) * 0.5
        weighted_gaze_loss = gaze_loss * gaze_weight
        return gaze_loss, weighted_gaze_loss

    def detect_gaze(self, input_tensor : Tensor) -> Gaze:
        """Run the gazer on a fixed crop of ``input_tensor``.

        Args:
            input_tensor: image batch indexed as (batch, channel, row, column).

        Returns:
            The ``(pitch_tensor, yaw_tensor)`` pair produced by the gazer.
        """
        # NOTE(review): fixed pixel window; presumably tuned for 256x256
        # inputs - confirm before feeding other resolutions.
        crop_tensor = input_tensor[:, :, 60: 224, 16: 205]
        # Assumes the input is in [-1, 1]; shifts it to [0, 1] before the
        # channel normalisation - TODO confirm against the data pipeline.
        crop_tensor = (crop_tensor + 1) * 0.5
        crop_tensor = self.normalize(crop_tensor)
        crop_tensor = nn.functional.interpolate(crop_tensor, size = (448, 448), mode = 'bicubic')
        pitch_tensor, yaw_tensor = self.gazer(crop_tensor)
        return pitch_tensor, yaw_tensor
+3 -3
View File
@@ -30,11 +30,11 @@ class FaceSwapperTrainer(lightning.LightningModule):
def __init__(self) -> None:
super().__init__()
embedder_path = CONFIG.get('training.model', 'embedder_path')
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
gazer_path = CONFIG.get('training.model', 'gazer_path')
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
self.gazer = torch.jit.load(gazer_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
self.generator = Generator()
@@ -45,7 +45,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
self.reconstruction_loss = ReconstructionLoss(self.embedder)
self.identity_loss = IdentityLoss(self.embedder)
self.pose_loss = PoseLoss(self.motion_extractor)
self.gaze_loss = GazeLoss(self.landmarker)
self.gaze_loss = GazeLoss(self.gazer)
self.automatic_optimization = False
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
+2 -2
View File
@@ -7,13 +7,13 @@ Batch : TypeAlias = Tuple[Tensor, Tensor]
Attributes : TypeAlias = Tuple[Tensor, ...]
Embedding : TypeAlias = Tensor
FaceLandmark203 : TypeAlias = Tensor
Gaze : TypeAlias = Tuple[Tensor, Tensor]
Padding : TypeAlias = Tuple[int, int, int, int]
GeneratorModule : TypeAlias = Module
EmbedderModule : TypeAlias = Module
LandmarkerModule : TypeAlias = Module
GazerModule : TypeAlias = Module
MotionExtractorModule : TypeAlias = Module
OptimizerConfig : TypeAlias = Any