Restore map_location = 'cpu'

This commit is contained in:
henryruhs
2025-02-22 11:26:35 +01:00
parent ac41bab3a2
commit a6e1405c70
3 changed files with 6 additions and 6 deletions
+2 -2
View File
@@ -26,8 +26,8 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
target_path = CONFIG.get('training.model', 'target_path')
self.embedding_converter = EmbeddingConverter()
self.source_embedder = torch.jit.load(source_path) # type:ignore[no-untyped-call]
self.target_embedder = torch.jit.load(target_path) # type:ignore[no-untyped-call]
self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.source_embedder.eval()
self.target_embedder.eval()
self.mse_loss = nn.MSELoss()
+1 -1
View File
@@ -31,7 +31,7 @@ def infer() -> None:
generator = Generator()
generator.load_state_dict(state_dict)
generator.eval()
id_embedder = torch.jit.load(id_embedder_path) # type:ignore[no-untyped-call]
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
id_embedder.eval()
source_vision_frame = cv2.imread(source_path)
+3 -3
View File
@@ -19,9 +19,9 @@ class FaceSwapperLoss:
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
self.mse_loss = nn.MSELoss()
self.id_embedder = torch.jit.load(id_embedder_path) # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path) # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path) # type:ignore[no-untyped-call]
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.id_embedder.eval()
self.landmarker.eval()
self.motion_extractor.eval()