This commit is contained in:
harisreedhar
2025-03-28 16:32:44 +05:30
parent f99c73495c
commit 4b851a173d
3 changed files with 24 additions and 0 deletions
+1
View File
@@ -112,6 +112,7 @@ target_path = .exports/face_swapper.onnx
target_size = 256
ir_version = 10
opset_version = 15
precision = fp16
```
```
+1
View File
@@ -66,6 +66,7 @@ target_path =
target_size =
ir_version =
opset_version =
precision =
[inferencing]
generator_path =
+22
View File
@@ -1,14 +1,31 @@
import os
from configparser import ConfigParser
from typing import Tuple
import torch
from torch import Tensor, nn
from .training import FaceSwapperTrainer
from .types import Embedding, Mask, Module
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
class HalfPrecisionModel(nn.Module):
def __init__(self, model : Module) -> None:
super().__init__()
self.model = model.half()
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
source_embedding = source_embedding.half()
target_tensor = target_tensor.half()
output_tensor, output_mask = self.model(source_embedding, target_tensor)
output_tensor = output_tensor.float()
output_mask = output_mask.float()
return output_tensor, output_mask
def export() -> None:
config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path')
config_source_path = CONFIG_PARSER.get('exporting', 'source_path')
@@ -16,9 +33,14 @@ def export() -> None:
config_target_size = CONFIG_PARSER.getint('exporting', 'target_size')
config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version')
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
config_precision = CONFIG_PARSER.get('exporting', 'precision')
os.makedirs(config_directory_path, exist_ok = True)
model = FaceSwapperTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
if config_precision == 'fp16':
model = HalfPrecisionModel(model).eval()
model.ir_version = torch.tensor(config_ir_version)
source_tensor = torch.randn(1, 512)
target_tensor = torch.randn(1, 3, config_target_size, config_target_size)