Remove map_location

This commit is contained in:
henryruhs
2025-02-22 09:40:47 +01:00
parent 575f215408
commit 83ef075b1d
5 changed files with 10 additions and 10 deletions
+1 -1
View File
@@ -17,7 +17,7 @@ def export() -> None:
opset_version = CONFIG.getint('exporting', 'opset_version')
makedirs(directory_path, exist_ok = True)
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path)
model.eval()
model.ir_version = torch.tensor(ir_version)
input_tensor = (torch.randn(1, 512), )
+3 -3
View File
@@ -62,10 +62,10 @@ kernel_size = 4
```
[training.losses]
weight_adversarial = 1
weight_identity = 20
weight_adversarial = 1.5
weight_identity = 15
weight_attribute = 10
weight_reconstruction = 10
weight_reconstruction = 15
weight_pose = 0
weight_gaze = 0
```
+1 -1
View File
@@ -17,7 +17,7 @@ def export() -> None:
opset_version = CONFIG.getint('exporting', 'opset_version')
makedirs(directory_path, exist_ok = True)
state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
state_dict = torch.load(source_path).get('state_dict').get('generator')
model = Generator()
model.load_state_dict(state_dict)
model.eval()
+2 -2
View File
@@ -27,11 +27,11 @@ def infer() -> None:
target_path = CONFIG.get('inferencing', 'target_path')
output_path = CONFIG.get('inferencing', 'output_path')
state_dict = torch.load(generator_path, map_location = 'cpu').get('state_dict').get('generator')
state_dict = torch.load(generator_path).get('state_dict').get('generator')
generator = Generator()
generator.load_state_dict(state_dict)
generator.eval()
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
id_embedder = torch.jit.load(id_embedder_path) # type:ignore[no-untyped-call]
id_embedder.eval()
source_vision_frame = cv2.imread(source_path)
+3 -3
View File
@@ -19,9 +19,9 @@ class FaceSwapperLoss:
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
self.mse_loss = nn.MSELoss()
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.id_embedder = torch.jit.load(id_embedder_path) # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path) # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path) # type:ignore[no-untyped-call]
self.id_embedder.eval()
self.landmarker.eval()
self.motion_extractor.eval()