diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 093e260..98f1941 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -110,52 +110,40 @@ class MotionLoss(nn.Module): def __init__(self, motion_extractor : MotionExtractorModule): super().__init__() self.motion_extractor = motion_extractor - self.pose_loss = PoseLoss() - self.expression_loss = ExpressionLoss() - - def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - target_pose_features, target_expression = self.get_motion_features(target_tensor) - output_pose_features, output_expression = self.get_motion_features(output_tensor) - pose_loss, weighted_pose_loss = self.pose_loss(target_pose_features, output_pose_features) - expression_loss, weighted_expression_loss = self.expression_loss(target_expression, output_expression) - return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss - - def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor, Tensor], Tensor]: - input_tensor = (input_tensor + 1) * 0.5 - pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor) - rotation = torch.cat([ pitch, yaw, roll ], dim = 1) - pose_features = translation, scale, rotation, motion_points - return pose_features, expression - - -class ExpressionLoss(nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, target_expression : Tensor, output_expression : Tensor, ) -> Tuple[Tensor, Tensor]: - expression_weight = CONFIG.getfloat('training.losses', 'expression_weight') - expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean() - weighted_expression_loss = expression_loss * expression_weight - return expression_loss, weighted_expression_loss - - -class PoseLoss(nn.Module): - def __init__(self) -> None: - super().__init__() self.mse_loss = nn.MSELoss() - def forward(self, target_pose_features : Tensor, output_pose_features : Tensor, ) -> Tuple[Tensor, Tensor]: + def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, ...]: + target_poses, target_expression = self.get_motions(target_tensor) + output_poses, output_expression = self.get_motions(output_tensor) + pose_loss, weighted_pose_loss = self.calc_pose_loss(target_poses, output_poses) + expression_loss, weighted_expression_loss = self.calc_expression_loss(target_expression, output_expression) + return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss + + def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Tensor, Tensor]: pose_weight = CONFIG.getfloat('training.losses', 'pose_weight') temp_tensors = [] - for target_pose_feature, output_pose_feature in zip(target_pose_features, output_pose_features): - temp_tensor = self.mse_loss(target_pose_feature, output_pose_feature) + for target_pose, output_pose in zip(target_poses, output_poses): + temp_tensor = self.mse_loss(target_pose, output_pose) temp_tensors.append(temp_tensor) pose_loss = torch.stack(temp_tensors).mean() weighted_pose_loss = pose_loss * pose_weight return pose_loss, weighted_pose_loss + def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]: + expression_weight = CONFIG.getfloat('training.losses', 'expression_weight') + expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean() + weighted_expression_loss = expression_loss * expression_weight + return expression_loss, weighted_expression_loss + + def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]: + input_tensor = (input_tensor + 1) * 0.5 + pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor) + rotation = torch.cat([ pitch, yaw, roll ], dim = 1) + pose = translation, scale, rotation, motion_points + return pose, expression + class GazeLoss(nn.Module): def __init__(self, gazer : GazerModule) -> None: @@ -168,7 +156,7 @@ class GazeLoss(nn.Module): transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ]) ]) - def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: + def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight') output_pitch_tensor, output_yaw_tensor = self.detect_gaze(output_tensor) target_pitch_tensor, target_yaw_tensor = self.detect_gaze(target_tensor)