diff --git a/embedding_converter/src/exporting.py b/embedding_converter/src/exporting.py index 4ede026..9102c71 100644 --- a/embedding_converter/src/exporting.py +++ b/embedding_converter/src/exporting.py @@ -20,5 +20,5 @@ def export() -> None: model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu') model.eval() model.ir_version = torch.tensor(ir_version) - input_tensor = (torch.randn(1, 512), ) + input_tensor = torch.randn(1, 512) torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version) diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py index 6631fe1..c0c74e1 100644 --- a/face_swapper/src/exporting.py +++ b/face_swapper/src/exporting.py @@ -3,7 +3,7 @@ from os import makedirs import torch -from .models.generator import Generator +from .training import FaceSwapperTrainer CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -17,9 +17,7 @@ def export() -> None: opset_version = CONFIG.getint('exporting', 'opset_version') makedirs(directory_path, exist_ok = True) - state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator') - model = Generator() - model.load_state_dict(state_dict) + model = FaceSwapperTrainer.load_from_checkpoint(source_path, map_location = 'cpu') model.eval() model.ir_version = torch.tensor(ir_version) source_tensor = torch.randn(1, 512)