mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -110,52 +110,40 @@ class MotionLoss(nn.Module):
|
||||
def __init__(self, motion_extractor : MotionExtractorModule):
|
||||
super().__init__()
|
||||
self.motion_extractor = motion_extractor
|
||||
self.pose_loss = PoseLoss()
|
||||
self.expression_loss = ExpressionLoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
target_pose_features, target_expression = self.get_motion_features(target_tensor)
|
||||
output_pose_features, output_expression = self.get_motion_features(output_tensor)
|
||||
pose_loss, weighted_pose_loss = self.pose_loss(target_pose_features, output_pose_features)
|
||||
expression_loss, weighted_expression_loss = self.expression_loss(target_expression, output_expression)
|
||||
return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
|
||||
|
||||
def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor, Tensor], Tensor]:
|
||||
input_tensor = (input_tensor + 1) * 0.5
|
||||
pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
pose_features = translation, scale, rotation, motion_points
|
||||
return pose_features, expression
|
||||
|
||||
|
||||
class ExpressionLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, target_expression : Tensor, output_expression : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
expression_weight = CONFIG.getfloat('training.losses', 'expression_weight')
|
||||
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
|
||||
weighted_expression_loss = expression_loss * expression_weight
|
||||
return expression_loss, weighted_expression_loss
|
||||
|
||||
|
||||
class PoseLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_pose_features : Tensor, output_pose_features : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, ...]:
|
||||
target_poses, target_expression = self.get_motions(target_tensor)
|
||||
output_poses, output_expression = self.get_motions(output_tensor)
|
||||
pose_loss, weighted_pose_loss = self.calc_pose_loss(target_poses, output_poses)
|
||||
expression_loss, weighted_expression_loss = self.calc_expression_loss(target_expression, output_expression)
|
||||
return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
|
||||
|
||||
def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Tensor, Tensor]:
|
||||
pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
|
||||
temp_tensors = []
|
||||
|
||||
for target_pose_feature, output_pose_feature in zip(target_pose_features, output_pose_features):
|
||||
temp_tensor = self.mse_loss(target_pose_feature, output_pose_feature)
|
||||
for target_pose, output_pose in zip(target_poses, output_poses):
|
||||
temp_tensor = self.mse_loss(target_pose, output_pose)
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
pose_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_pose_loss = pose_loss * pose_weight
|
||||
return pose_loss, weighted_pose_loss
|
||||
|
||||
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
expression_weight = CONFIG.getfloat('training.losses', 'expression_weight')
|
||||
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
|
||||
weighted_expression_loss = expression_loss * expression_weight
|
||||
return expression_loss, weighted_expression_loss
|
||||
|
||||
def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
|
||||
input_tensor = (input_tensor + 1) * 0.5
|
||||
pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
pose = translation, scale, rotation, motion_points
|
||||
return pose, expression
|
||||
|
||||
|
||||
class GazeLoss(nn.Module):
|
||||
def __init__(self, gazer : GazerModule) -> None:
|
||||
@@ -168,7 +156,7 @@ class GazeLoss(nn.Module):
|
||||
transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])
|
||||
])
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
|
||||
output_pitch_tensor, output_yaw_tensor = self.detect_gaze(output_tensor)
|
||||
target_pitch_tensor, target_yaw_tensor = self.detect_gaze(target_tensor)
|
||||
|
||||
Reference in New Issue
Block a user