From e9ea9cd9e57fc0239d56de52b828b40b668ff175 Mon Sep 17 00:00:00 2001
From: henryruhs
Date: Wed, 5 Mar 2025 19:19:09 +0100
Subject: [PATCH] Clean GazeLoss

---
 face_swapper/src/models/loss.py | 21 ++++++++-------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py
index bf6ba01..9c8e192 100644
--- a/face_swapper/src/models/loss.py
+++ b/face_swapper/src/models/loss.py
@@ -149,22 +149,17 @@ class GazeLoss(nn.Module):
 	def __init__(self, gazer : GazerModule) -> None:
 		super().__init__()
 		self.gazer = gazer
-		self.mae_loss = nn.L1Loss()
-		self.transform = transforms.Compose(
-		[
-			transforms.Resize(448),
-			transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])
-		])
+		self.l1_loss = nn.L1Loss()
 
 	def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
 		gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
-		output_pitch_tensor, output_yaw_tensor = self.detect_gaze(output_tensor)
-		target_pitch_tensor, target_yaw_tensor = self.detect_gaze(target_tensor)
+		output_pitch, output_yaw = self.detect_gaze(output_tensor)
+		target_pitch, target_yaw = self.detect_gaze(target_tensor)
 
-		pitch_gaze_loss = self.mae_loss(output_pitch_tensor, target_pitch_tensor)
-		yaw_gaze_loss = self.mae_loss(output_yaw_tensor, target_yaw_tensor)
+		pitch_loss = self.l1_loss(output_pitch, target_pitch)
+		yaw_loss = self.l1_loss(output_yaw, target_yaw)
 
-		gaze_loss = (pitch_gaze_loss + yaw_gaze_loss) * 0.5
+		gaze_loss = (pitch_loss + yaw_loss) * 0.5
 		weighted_gaze_loss = gaze_loss * gaze_weight
 		return gaze_loss, weighted_gaze_loss
 
@@ -175,5 +170,5 @@ class GazeLoss(nn.Module):
 		crop_tensor = (crop_tensor + 1) * 0.5
 		crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
 		crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')
-		pitch_tensor, yaw_tensor = self.gazer(crop_tensor)
-		return pitch_tensor, yaw_tensor
+		pitch, yaw = self.gazer(crop_tensor)
+		return pitch, yaw