mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Remove map_location
This commit is contained in:
@@ -17,7 +17,7 @@ def export() -> None:
|
||||
opset_version = CONFIG.getint('exporting', 'opset_version')
|
||||
|
||||
makedirs(directory_path, exist_ok = True)
|
||||
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
|
||||
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path)
|
||||
model.eval()
|
||||
model.ir_version = torch.tensor(ir_version)
|
||||
input_tensor = (torch.randn(1, 512), )
|
||||
|
||||
@@ -62,10 +62,10 @@ kernel_size = 4
|
||||
|
||||
```
|
||||
[training.losses]
|
||||
weight_adversarial = 1
|
||||
weight_identity = 20
|
||||
weight_adversarial = 1.5
|
||||
weight_identity = 15
|
||||
weight_attribute = 10
|
||||
weight_reconstruction = 10
|
||||
weight_reconstruction = 15
|
||||
weight_pose = 0
|
||||
weight_gaze = 0
|
||||
```
|
||||
|
||||
@@ -17,7 +17,7 @@ def export() -> None:
|
||||
opset_version = CONFIG.getint('exporting', 'opset_version')
|
||||
|
||||
makedirs(directory_path, exist_ok = True)
|
||||
state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
|
||||
state_dict = torch.load(source_path).get('state_dict').get('generator')
|
||||
model = Generator()
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
@@ -27,11 +27,11 @@ def infer() -> None:
|
||||
target_path = CONFIG.get('inferencing', 'target_path')
|
||||
output_path = CONFIG.get('inferencing', 'output_path')
|
||||
|
||||
state_dict = torch.load(generator_path, map_location = 'cpu').get('state_dict').get('generator')
|
||||
state_dict = torch.load(generator_path).get('state_dict').get('generator')
|
||||
generator = Generator()
|
||||
generator.load_state_dict(state_dict)
|
||||
generator.eval()
|
||||
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
id_embedder = torch.jit.load(id_embedder_path) # type:ignore[no-untyped-call]
|
||||
id_embedder.eval()
|
||||
|
||||
source_vision_frame = cv2.imread(source_path)
|
||||
|
||||
@@ -19,9 +19,9 @@ class FaceSwapperLoss:
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
self.mse_loss = nn.MSELoss()
|
||||
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.id_embedder = torch.jit.load(id_embedder_path) # type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path) # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path) # type:ignore[no-untyped-call]
|
||||
self.id_embedder.eval()
|
||||
self.landmarker.eval()
|
||||
self.motion_extractor.eval()
|
||||
|
||||
Reference in New Issue
Block a user