This commit is contained in:
harisreedhar
2025-03-05 22:21:26 +05:30
committed by henryruhs
parent 6388727262
commit 7cf5609c1f
3 changed files with 3 additions and 9 deletions
-1
View File
@@ -72,7 +72,6 @@ attribute_weight = 10.0
reconstruction_weight = 20.0
identity_weight = 20.0
gaze_weight = 0.0
gaze_scale_factor = 1.0
pose_weight = 0.0
expression_weight = 0.0
```
-1
View File
@@ -34,7 +34,6 @@ attribute_weight =
reconstruction_weight =
identity_weight =
gaze_weight =
gaze_scale_factor =
pose_weight =
expression_weight =
+3 -7
View File
@@ -169,13 +169,9 @@ class GazeLoss(nn.Module):
return gaze_loss, weighted_gaze_loss
def detect_gaze(self, input_tensor : Tensor) -> Gaze:
scale_factor = CONFIG.getfloat('training.losses', 'gaze_scale_factor')
y_min = int(60 * scale_factor)
y_max = int(224 * scale_factor)
x_min = int(16 * scale_factor)
x_max = int(205 * scale_factor)
crop_tensor = input_tensor[:, :, y_min:y_max, x_min:x_max]
transform_size = CONFIG.getint('training.dataset', 'transform_size')
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * transform_size).int()
crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]]
crop_tensor = (crop_tensor + 1) * 0.5
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')