mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Fix export using Trainer
This commit is contained in:
@@ -20,5 +20,5 @@ def export() -> None:
|
||||
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
|
||||
model.eval()
|
||||
model.ir_version = torch.tensor(ir_version)
|
||||
input_tensor = (torch.randn(1, 512), )
|
||||
input_tensor = torch.randn(1, 512)
|
||||
torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
|
||||
|
||||
@@ -3,7 +3,7 @@ from os import makedirs
|
||||
|
||||
import torch
|
||||
|
||||
from .models.generator import Generator
|
||||
from .training import FaceSwapperTrainer
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -17,9 +17,7 @@ def export() -> None:
|
||||
opset_version = CONFIG.getint('exporting', 'opset_version')
|
||||
|
||||
makedirs(directory_path, exist_ok = True)
|
||||
state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
|
||||
model = Generator()
|
||||
model.load_state_dict(state_dict)
|
||||
model = FaceSwapperTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
|
||||
model.eval()
|
||||
model.ir_version = torch.tensor(ir_version)
|
||||
source_tensor = torch.randn(1, 512)
|
||||
|
||||
Reference in New Issue
Block a user