Fix export using Trainer

This commit is contained in:
henryruhs
2025-02-26 16:50:40 +01:00
parent cadbe9cf76
commit 84b4451366
2 changed files with 3 additions and 5 deletions
+1 -1
View File
@@ -20,5 +20,5 @@ def export() -> None:
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
model.eval()
model.ir_version = torch.tensor(ir_version)
input_tensor = (torch.randn(1, 512), )
input_tensor = torch.randn(1, 512)
torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
+2 -4
View File
@@ -3,7 +3,7 @@ from os import makedirs
import torch
from .models.generator import Generator
from .training import FaceSwapperTrainer
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -17,9 +17,7 @@ def export() -> None:
opset_version = CONFIG.getint('exporting', 'opset_version')
makedirs(directory_path, exist_ok = True)
state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
model = Generator()
model.load_state_dict(state_dict)
model = FaceSwapperTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
model.eval()
model.ir_version = torch.tensor(ir_version)
source_tensor = torch.randn(1, 512)