Introduce Loss type, Remove Gaze type

This commit is contained in:
henryruhs
2025-03-12 08:38:49 +01:00
parent 944096befc
commit 564cc7b127
2 changed files with 15 additions and 15 deletions
+13 -13
View File
@@ -7,7 +7,7 @@ from torch import Tensor, nn
from torchvision import transforms
from ..helper import calc_embedding
from ..types import Attribute, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule
from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, MotionExtractorModule
class DiscriminatorLoss(nn.Module):
@@ -55,7 +55,7 @@ class AttributeLoss(nn.Module):
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight')
def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Tensor, Tensor]:
def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
for target_attribute, output_attribute in zip(target_attributes, output_attributes):
@@ -74,7 +74,7 @@ class ReconstructionLoss(nn.Module):
self.embedder = embedder
self.mse_loss = nn.MSELoss()
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
with torch.no_grad():
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
@@ -97,7 +97,7 @@ class IdentityLoss(nn.Module):
self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight')
self.embedder = embedder
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
@@ -113,14 +113,14 @@ class MotionLoss(nn.Module):
self.motion_extractor = motion_extractor
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, ...]:
target_poses, target_expression = self.get_motions(target_tensor)
output_poses, output_expression = self.get_motions(output_tensor)
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss, Loss, Loss]:
target_poses, target_expression = self.detect_motions(target_tensor)
output_poses, output_expression = self.detect_motions(output_tensor)
pose_loss, weighted_pose_loss = self.calc_pose_loss(target_poses, output_poses)
expression_loss, weighted_expression_loss = self.calc_expression_loss(target_expression, output_expression)
return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Tensor, Tensor]:
def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
for target_pose, output_pose in zip(target_poses, output_poses):
@@ -131,12 +131,12 @@ class MotionLoss(nn.Module):
weighted_pose_loss = pose_loss * self.config_pose_weight
return pose_loss, weighted_pose_loss
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Loss, Loss]:
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
weighted_expression_loss = expression_loss * self.config_expression_weight
return expression_loss, weighted_expression_loss
def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
def detect_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
input_tensor = (input_tensor + 1) * 0.5
with torch.no_grad():
@@ -155,7 +155,7 @@ class GazeLoss(nn.Module):
self.gazer = gazer
self.l1_loss = nn.L1Loss()
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
output_pitch, output_yaw = self.detect_gaze(output_tensor)
target_pitch, target_yaw = self.detect_gaze(target_tensor)
@@ -166,7 +166,7 @@ class GazeLoss(nn.Module):
weighted_gaze_loss = gaze_loss * self.config_gaze_weight
return gaze_loss, weighted_gaze_loss
def detect_gaze(self, input_tensor : Tensor) -> Gaze:
def detect_gaze(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor]:
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config_output_size).int()
crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]]
crop_tensor = (crop_tensor + 1) * 0.5
@@ -186,7 +186,7 @@ class MaskLoss(nn.Module):
self.face_parser = face_parser
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Loss:
target_mask = self.calc_mask(target_tensor)
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
mask_tensor = mask_tensor.view(-1, self.config_output_size, self.config_output_size)
+2 -2
View File
@@ -4,11 +4,11 @@ from torch import Tensor
from torch.nn import Module
Batch : TypeAlias = Tuple[Tensor, Tensor]
BatchMode = Literal['equal', 'same']
BatchMode = Literal['equal', 'same', 'different']
Attribute : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
Gaze : TypeAlias = Tuple[Tensor, Tensor]
Loss : TypeAlias = Tensor
Padding : TypeAlias = Tuple[int, int, int, int]