This commit is contained in:
harisreedhar
2025-03-04 20:07:41 +05:30
committed by henryruhs
parent 18a605e1a3
commit 176dced1f6
6 changed files with 19 additions and 5 deletions
+7 -1
View File
@@ -169,7 +169,13 @@ class GazeLoss(nn.Module):
return gaze_loss, weighted_gaze_loss
def detect_gaze(self, input_tensor : Tensor) -> Gaze:
crop_tensor = input_tensor[:, :, 60: 224, 16: 205]
resolution = CONFIG.getint('training.dataset', 'resolution')
scale_factor = resolution / 256
y_min = int(60 * scale_factor)
y_max = int(224 * scale_factor)
x_min = int(16 * scale_factor)
x_max = int(205 * scale_factor)
crop_tensor = input_tensor[:, :, y_min: y_max, x_min: x_max]
crop_tensor = (crop_tensor + 1) * 0.5
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
crop_tensor = nn.functional.interpolate(crop_tensor, size = (448, 448), mode = 'bicubic')