From 564cc7b127df17d5120fb280b8561c003f649c79 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 12 Mar 2025 08:38:49 +0100 Subject: [PATCH] Introduce Loss type, Remove Gaze type --- face_swapper/src/models/loss.py | 26 +++++++++++++------------- face_swapper/src/types.py | 4 ++-- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 4584535..8520a61 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from torchvision import transforms from ..helper import calc_embedding -from ..types import Attribute, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule +from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, MotionExtractorModule class DiscriminatorLoss(nn.Module): @@ -55,7 +55,7 @@ class AttributeLoss(nn.Module): self.config_batch_size = config_parser.getint('training.loader', 'batch_size') self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight') - def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Tensor, Tensor]: + def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Loss, Loss]: temp_tensors = [] for target_attribute, output_attribute in zip(target_attributes, output_attributes): @@ -74,7 +74,7 @@ class ReconstructionLoss(nn.Module): self.embedder = embedder self.mse_loss = nn.MSELoss() - def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]: with torch.no_grad(): source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) @@ -97,7 +97,7 @@ class IdentityLoss(nn.Module): self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight') self.embedder = embedder - def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]: output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10)) source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean() @@ -113,14 +113,14 @@ class MotionLoss(nn.Module): self.motion_extractor = motion_extractor self.mse_loss = nn.MSELoss() - def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, ...]: - target_poses, target_expression = self.get_motions(target_tensor) - output_poses, output_expression = self.get_motions(output_tensor) + def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss, Loss, Loss]: + target_poses, target_expression = self.detect_motions(target_tensor) + output_poses, output_expression = self.detect_motions(output_tensor) pose_loss, weighted_pose_loss = self.calc_pose_loss(target_poses, output_poses) expression_loss, weighted_expression_loss = self.calc_expression_loss(target_expression, output_expression) return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss - def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Tensor, Tensor]: + def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Loss, Loss]: temp_tensors = [] for target_pose, output_pose in zip(target_poses, output_poses): @@ -131,12 +131,12 @@ class MotionLoss(nn.Module): weighted_pose_loss = pose_loss * self.config_pose_weight return pose_loss, weighted_pose_loss - def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]: + def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Loss, Loss]: expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean() weighted_expression_loss = expression_loss * self.config_expression_weight return expression_loss, weighted_expression_loss - def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]: + def detect_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]: input_tensor = (input_tensor + 1) * 0.5 with torch.no_grad(): @@ -155,7 +155,7 @@ class GazeLoss(nn.Module): self.gazer = gazer self.l1_loss = nn.L1Loss() - def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]: output_pitch, output_yaw = self.detect_gaze(output_tensor) target_pitch, target_yaw = self.detect_gaze(target_tensor) @@ -166,7 +166,7 @@ class GazeLoss(nn.Module): weighted_gaze_loss = gaze_loss * self.config_gaze_weight return gaze_loss, weighted_gaze_loss - def detect_gaze(self, input_tensor : Tensor) -> Gaze: + def detect_gaze(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor]: crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config_output_size).int() crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]] crop_tensor = (crop_tensor + 1) * 0.5 @@ -186,7 +186,7 @@ class MaskLoss(nn.Module): self.face_parser = face_parser self.mse_loss = nn.MSELoss() - def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Tensor: + def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Loss: target_mask = self.calc_mask(target_tensor) target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size) mask_tensor = mask_tensor.view(-1, self.config_output_size, self.config_output_size) diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index 2b39c79..a794ee1 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -4,11 +4,11 @@ from torch import Tensor from torch.nn import Module Batch : TypeAlias = Tuple[Tensor, Tensor] -BatchMode = Literal['equal', 'same'] +BatchMode = Literal['equal', 'same', 'different'] Attribute : TypeAlias = Tensor Embedding : TypeAlias = Tensor -Gaze : TypeAlias = Tuple[Tensor, Tensor] +Loss : TypeAlias = Tensor Padding : TypeAlias = Tuple[int, int, int, int]