mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Clean GazeLoss
This commit is contained in:
@@ -149,22 +149,17 @@ class GazeLoss(nn.Module):
|
||||
def __init__(self, gazer : GazerModule) -> None:
|
||||
super().__init__()
|
||||
self.gazer = gazer
|
||||
self.mae_loss = nn.L1Loss()
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(448),
|
||||
transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])
|
||||
])
|
||||
self.l1_loss = nn.L1Loss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
|
||||
output_pitch_tensor, output_yaw_tensor = self.detect_gaze(output_tensor)
|
||||
target_pitch_tensor, target_yaw_tensor = self.detect_gaze(target_tensor)
|
||||
output_pitch, output_yaw = self.detect_gaze(output_tensor)
|
||||
target_pitch, target_yaw = self.detect_gaze(target_tensor)
|
||||
|
||||
pitch_gaze_loss = self.mae_loss(output_pitch_tensor, target_pitch_tensor)
|
||||
yaw_gaze_loss = self.mae_loss(output_yaw_tensor, target_yaw_tensor)
|
||||
pitch_loss = self.l1_loss(output_pitch, target_pitch)
|
||||
yaw_loss = self.l1_loss(output_yaw, target_yaw)
|
||||
|
||||
gaze_loss = (pitch_gaze_loss + yaw_gaze_loss) * 0.5
|
||||
gaze_loss = (pitch_loss + yaw_loss) * 0.5
|
||||
weighted_gaze_loss = gaze_loss * gaze_weight
|
||||
return gaze_loss, weighted_gaze_loss
|
||||
|
||||
@@ -175,5 +170,5 @@ class GazeLoss(nn.Module):
|
||||
crop_tensor = (crop_tensor + 1) * 0.5
|
||||
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
|
||||
crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')
|
||||
pitch_tensor, yaw_tensor = self.gazer(crop_tensor)
|
||||
return pitch_tensor, yaw_tensor
|
||||
pitch, yaw = self.gazer(crop_tensor)
|
||||
return pitch, yaw
|
||||
|
||||
Reference in New Issue
Block a user