diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 2d38f53..e8dd59f 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -30,9 +30,9 @@ class FaceSwapperTrainer(lightning.LightningModule): landmarker_path = CONFIG.get('training.model', 'landmarker_path') motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') - self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] - self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] - self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.embedder = torch.jit.load(embedder_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call] + self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call] + self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call] self.generator = Generator() self.discriminator = Discriminator()