Restore map_location = 'cpu'

This commit is contained in:
henryruhs
2025-02-22 11:21:34 +01:00
parent 206a1411d1
commit ac41bab3a2
2 changed files with 2 additions and 2 deletions
+1 -1
View File
@@ -17,7 +17,7 @@ def export() -> None:
opset_version = CONFIG.getint('exporting', 'opset_version')
makedirs(directory_path, exist_ok = True)
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path)
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
model.eval()
model.ir_version = torch.tensor(ir_version)
input_tensor = (torch.randn(1, 512), )
+1 -1
View File
@@ -17,7 +17,7 @@ def export() -> None:
opset_version = CONFIG.getint('exporting', 'opset_version')
makedirs(directory_path, exist_ok = True)
state_dict = torch.load(source_path).get('state_dict').get('generator')
state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
model = Generator()
model.load_state_dict(state_dict)
model.eval()