remove grad for inference only models

This commit is contained in:
harisreedhar
2025-03-06 15:24:53 +05:30
committed by henryruhs
parent 61f48d9246
commit a89e51c2f8
2 changed files with 9 additions and 3 deletions
+3 -1
View File
@@ -32,6 +32,8 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
crop_tensor[:, :, 112 - padding[1]:, :] = 0
crop_tensor[:, :, :, :padding[2]] = 0
crop_tensor[:, :, :, 112 - padding[3]:] = 0
embedding = embedder(crop_tensor)
with torch.no_grad():
embedding = embedder(crop_tensor)
embedding = nn.functional.normalize(embedding, p = 2)
return embedding
+6 -2
View File
@@ -139,7 +139,9 @@ class MotionLoss(nn.Module):
def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
input_tensor = (input_tensor + 1) * 0.5
pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
with torch.no_grad():
pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
pose = translation, scale, rotation, motion_points
return pose, expression
@@ -170,5 +172,7 @@ class GazeLoss(nn.Module):
crop_tensor = (crop_tensor + 1) * 0.5
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')
pitch, yaw = self.gazer(crop_tensor)
with torch.no_grad():
pitch, yaw = self.gazer(crop_tensor)
return pitch, yaw