diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 97d4cfc..a51fada 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -32,6 +32,8 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P crop_tensor[:, :, 112 - padding[1]:, :] = 0 crop_tensor[:, :, :, :padding[2]] = 0 crop_tensor[:, :, :, 112 - padding[3]:] = 0 - embedding = embedder(crop_tensor) + + with torch.no_grad(): + embedding = embedder(crop_tensor) embedding = nn.functional.normalize(embedding, p = 2) return embedding diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 0ded689..71a20ea 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -139,7 +139,9 @@ class MotionLoss(nn.Module): def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]: input_tensor = (input_tensor + 1) * 0.5 - pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor) + + with torch.no_grad(): + pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor) rotation = torch.cat([ pitch, yaw, roll ], dim = 1) pose = translation, scale, rotation, motion_points return pose, expression @@ -170,5 +172,7 @@ class GazeLoss(nn.Module): crop_tensor = (crop_tensor + 1) * 0.5 crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor) crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic') - pitch, yaw = self.gazer(crop_tensor) + + with torch.no_grad(): + pitch, yaw = self.gazer(crop_tensor) return pitch, yaw