mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Restore map_location = 'cpu'
This commit is contained in:
@@ -26,8 +26,8 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
target_path = CONFIG.get('training.model', 'target_path')
|
||||
|
||||
self.embedding_converter = EmbeddingConverter()
|
||||
self.source_embedder = torch.jit.load(source_path) # type:ignore[no-untyped-call]
|
||||
self.target_embedder = torch.jit.load(target_path) # type:ignore[no-untyped-call]
|
||||
self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.source_embedder.eval()
|
||||
self.target_embedder.eval()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
@@ -31,7 +31,7 @@ def infer() -> None:
|
||||
generator = Generator()
|
||||
generator.load_state_dict(state_dict)
|
||||
generator.eval()
|
||||
id_embedder = torch.jit.load(id_embedder_path) # type:ignore[no-untyped-call]
|
||||
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
id_embedder.eval()
|
||||
|
||||
source_vision_frame = cv2.imread(source_path)
|
||||
|
||||
@@ -19,9 +19,9 @@ class FaceSwapperLoss:
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
self.mse_loss = nn.MSELoss()
|
||||
self.id_embedder = torch.jit.load(id_embedder_path) # type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path) # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path) # type:ignore[no-untyped-call]
|
||||
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.id_embedder.eval()
|
||||
self.landmarker.eval()
|
||||
self.motion_extractor.eval()
|
||||
|
||||
Reference in New Issue
Block a user