This commit is contained in:
harisreedhar
2025-03-04 14:23:20 +05:30
committed by henryruhs
parent ceb3c0cfdf
commit f2d3f8a19f
5 changed files with 45 additions and 19 deletions
+37 -13
View File
@@ -106,32 +106,56 @@ class IdentityLoss(nn.Module):
return identity_loss, weighted_identity_loss
class PoseLoss(nn.Module):
def __init__(self, motion_extractor : MotionExtractorModule) -> None:
class MotionLoss(nn.Module):
    """
    Compute the pose and expression losses between a target and an output tensor.

    Both tensors are passed through the same motion extractor; the pose related
    features are compared by PoseLoss and the expression by ExpressionLoss.
    """

    def __init__(self, motion_extractor : MotionExtractorModule) -> None:
        super().__init__()
        self.motion_extractor = motion_extractor
        self.pose_loss = PoseLoss()
        self.expression_loss = ExpressionLoss()

    def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Return (pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss).
        """
        target_pose_features, target_expression = self.get_motion_features(target_tensor)
        output_pose_features, output_expression = self.get_motion_features(output_tensor)
        pose_loss, weighted_pose_loss = self.pose_loss(target_pose_features, output_pose_features)
        expression_loss, weighted_expression_loss = self.expression_loss(target_expression, output_expression)
        return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss

    def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor, Tensor], Tensor]:
        """
        Run the motion extractor and split its result into pose features and expression.

        The pose features are the tuple (translation, scale, rotation, motion_points),
        where rotation is pitch, yaw and roll concatenated along dim 1.
        """
        # rescale via (x + 1) * 0.5 - presumably maps [-1, 1] inputs to [0, 1], confirm against the extractor
        input_tensor = (input_tensor + 1) * 0.5
        pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
        rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
        pose_features = translation, scale, rotation, motion_points
        return pose_features, expression
class ExpressionLoss(nn.Module):
    """
    Cosine distance between the target and the output expression.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
        """
        Return (expression_loss, weighted_expression_loss).

        The weight is read from the 'expression_weight' option of the
        'training.losses' config section on every call.
        """
        weight = CONFIG.getfloat('training.losses', 'expression_weight')
        similarity = torch.cosine_similarity(target_expression, output_expression)
        # one minus the similarity turns it into a distance before averaging
        expression_loss = (1 - similarity).mean()
        weighted_expression_loss = expression_loss * weight
        return expression_loss, weighted_expression_loss
class PoseLoss(nn.Module):
    """
    Mean of the per-feature MSE losses between target and output pose features.
    """

    def __init__(self) -> None:
        super().__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, target_pose_features : Tuple[Tensor, ...], output_pose_features : Tuple[Tensor, ...]) -> Tuple[Tensor, Tensor]:
        """
        Return (pose_loss, weighted_pose_loss).

        Both arguments are sequences of tensors that are compared pairwise in order;
        the weight is read from the 'pose_weight' option of the 'training.losses'
        config section on every call.
        """
        pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
        # one scalar loss per feature pair, averaged so every feature counts equally
        feature_losses =\
        [
            self.mse_loss(target_pose_feature, output_pose_feature)
            for target_pose_feature, output_pose_feature in zip(target_pose_features, output_pose_features)
        ]
        pose_loss = torch.stack(feature_losses).mean()
        weighted_pose_loss = pose_loss * pose_weight
        return pose_loss, weighted_pose_loss
def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Run the motion extractor and return (translation, scale, rotation).

    Rotation is pitch, yaw and roll concatenated along dim 1; the expression and
    the last extractor output are discarded.
    """
    # NOTE(review): reads self.motion_extractor, which the PoseLoss.__init__ shown in this view does not assign
    # this looks like the pre-commit method superseded by MotionLoss.get_motion_features - confirm before keeping
    normalized_tensor = (input_tensor + 1) * 0.5
    pitch, yaw, roll, translation, _expression, scale, _ = self.motion_extractor(normalized_tensor)
    return translation, scale, torch.cat([ pitch, yaw, roll ], dim = 1)
class GazeLoss(nn.Module):
def __init__(self, gazer : GazerModule) -> None: