mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Enforce IR version for older onnxruntime
This commit is contained in:
@@ -69,6 +69,7 @@ file_pattern = arcface_converter_simswap_{epoch:02d}_{val_loss:.4f}
|
||||
directory_path = .exports
|
||||
source_path = .outputs/last.ckpt
|
||||
target_path = .exports/arcface_converter_simswap.onnx
|
||||
ir_version = 10
|
||||
opset_version = 15
|
||||
```
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ file_pattern =
|
||||
directory_path =
|
||||
source_path =
|
||||
target_path =
|
||||
ir_version =
|
||||
opset_version =
|
||||
|
||||
[execution]
|
||||
|
||||
@@ -13,10 +13,12 @@ def export() -> None:
|
||||
directory_path = CONFIG.get('exporting', 'directory_path')
|
||||
source_path = CONFIG.get('exporting', 'source_path')
|
||||
target_path = CONFIG.get('exporting', 'target_path')
|
||||
ir_version = CONFIG.getint('exporting', 'ir_version')
|
||||
opset_version = CONFIG.getint('exporting', 'opset_version')
|
||||
|
||||
makedirs(directory_path, exist_ok = True)
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
|
||||
embedding_converter_trainer.eval()
|
||||
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
|
||||
model.eval()
|
||||
model.ir_version = ir_version
|
||||
input_tensor = torch.randn(1, 512)
|
||||
torch.onnx.export(embedding_converter_trainer, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
|
||||
torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
|
||||
|
||||
@@ -92,6 +92,7 @@ validation_frequency = 1000
|
||||
directory_path = .exports
|
||||
source_path = .outputs/last.ckpt
|
||||
target_path = .exports/face_swapper.onnx
|
||||
ir_version = 10
|
||||
opset_version = 15
|
||||
```
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ validation_frequency =
|
||||
directory_path =
|
||||
source_path =
|
||||
target_path =
|
||||
ir_version =
|
||||
opset_version =
|
||||
|
||||
[inferencing]
|
||||
|
||||
@@ -13,6 +13,7 @@ def export() -> None:
|
||||
directory_path = CONFIG.get('exporting', 'directory_path')
|
||||
source_path = CONFIG.get('exporting', 'source_path')
|
||||
target_path = CONFIG.get('exporting', 'target_path')
|
||||
ir_version = CONFIG.getint('exporting', 'ir_version')
|
||||
opset_version = CONFIG.getint('exporting', 'opset_version')
|
||||
|
||||
makedirs(directory_path, exist_ok = True)
|
||||
@@ -20,6 +21,7 @@ def export() -> None:
|
||||
model = AdaptiveEmbeddingIntegrationNetwork()
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
model.ir_version = ir_version
|
||||
source_tensor = torch.randn(1, 512)
|
||||
target_tensor = torch.randn(1, 3, 256, 256)
|
||||
torch.onnx.export(model, (target_tensor, source_tensor), target_path, input_names = [ 'target', 'source' ], output_names = [ 'output' ], opset_version = opset_version)
|
||||
torch.onnx.export(model, (source_tensor, target_tensor), target_path, input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = opset_version)
|
||||
|
||||
Reference in New Issue
Block a user