Clean GazeLoss

This commit is contained in:
henryruhs
2025-03-05 19:19:09 +01:00
parent abdc770892
commit e9ea9cd9e5
+8 -13
View File
@@ -149,22 +149,17 @@ class GazeLoss(nn.Module):
def __init__(self, gazer : GazerModule) -> None:
    """
    Keep a reference to the gaze estimator and create the L1 criterion.

    :param gazer: module invoked by ``detect_gaze``; its result is unpacked
        there into a ``(pitch, yaw)`` pair.
    """
    super().__init__()
    self.gazer = gazer
    # One L1 criterion serves both the pitch and the yaw term in forward().
    self.l1_loss = nn.L1Loss()

def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
    """
    Compare the gaze detected on the output with the gaze detected on the target.

    :param target_tensor: reference input passed to ``detect_gaze``.
    :param output_tensor: generated input passed to ``detect_gaze``.
    :return: ``(gaze_loss, weighted_gaze_loss)`` where ``gaze_loss`` is the
        mean of the pitch and yaw L1 losses and ``weighted_gaze_loss`` is
        that value multiplied by the ``gaze_weight`` option of the
        ``training.losses`` config section.
    """
    gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')

    # Run the gaze estimator exactly once per input.
    output_pitch, output_yaw = self.detect_gaze(output_tensor)
    target_pitch, target_yaw = self.detect_gaze(target_tensor)

    pitch_loss = self.l1_loss(output_pitch, target_pitch)
    yaw_loss = self.l1_loss(output_yaw, target_yaw)

    # Average both angles so neither pitch nor yaw dominates the loss.
    gaze_loss = (pitch_loss + yaw_loss) * 0.5
    weighted_gaze_loss = gaze_loss * gaze_weight
    return gaze_loss, weighted_gaze_loss
@@ -175,5 +170,5 @@ class GazeLoss(nn.Module):
crop_tensor = (crop_tensor + 1) * 0.5
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')
pitch_tensor, yaw_tensor = self.gazer(crop_tensor)
return pitch_tensor, yaw_tensor
pitch, yaw = self.gazer(crop_tensor)
return pitch, yaw