From a6e1405c70d4271e14f3fbeaa5b0b65f665ad6fa Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sat, 22 Feb 2025 11:26:35 +0100 Subject: [PATCH] Restore map_location = 'cpu' --- embedding_converter/src/training.py | 4 ++-- face_swapper/src/inferencing.py | 2 +- face_swapper/src/models/loss.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index b323fe2..632ba1c 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -26,8 +26,8 @@ class EmbeddingConverterTrainer(lightning.LightningModule): target_path = CONFIG.get('training.model', 'target_path') self.embedding_converter = EmbeddingConverter() - self.source_embedder = torch.jit.load(source_path) # type:ignore[no-untyped-call] - self.target_embedder = torch.jit.load(target_path) # type:ignore[no-untyped-call] + self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.source_embedder.eval() self.target_embedder.eval() self.mse_loss = nn.MSELoss() diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index c77af45..658ec69 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -31,7 +31,7 @@ def infer() -> None: generator = Generator() generator.load_state_dict(state_dict) generator.eval() - id_embedder = torch.jit.load(id_embedder_path) # type:ignore[no-untyped-call] + id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] id_embedder.eval() source_vision_frame = cv2.imread(source_path) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 2d0e281..093cc43 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -19,9 +19,9 @@ class FaceSwapperLoss: motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') self.batch_size = CONFIG.getint('training.loader', 'batch_size') self.mse_loss = nn.MSELoss() - self.id_embedder = torch.jit.load(id_embedder_path) # type:ignore[no-untyped-call] - self.landmarker = torch.jit.load(landmarker_path) # type:ignore[no-untyped-call] - self.motion_extractor = torch.jit.load(motion_extractor_path) # type:ignore[no-untyped-call] + self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.id_embedder.eval() self.landmarker.eval() self.motion_extractor.eval()