From 4b851a173de5939a6bc3d00f1e2fdab6b8ce44c8 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Fri, 28 Mar 2025 16:32:44 +0530 Subject: [PATCH] changes --- face_swapper/README.md | 1 + face_swapper/config.ini | 1 + face_swapper/src/exporting.py | 22 ++++++++++++++++++++++ 3 files changed, 24 insertions(+) diff --git a/face_swapper/README.md b/face_swapper/README.md index 7fd5f9e..2ef47cb 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -112,6 +112,7 @@ target_path = .exports/face_swapper.onnx target_size = 256 ir_version = 10 opset_version = 15 +precision = fp16 ``` ``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 7377124..dade0ca 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -66,6 +66,7 @@ target_path = target_size = ir_version = opset_version = +precision = [inferencing] generator_path = diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py index 4b2f4f6..cdcbd29 100644 --- a/face_swapper/src/exporting.py +++ b/face_swapper/src/exporting.py @@ -1,14 +1,31 @@ import os from configparser import ConfigParser +from typing import Tuple import torch +from torch import Tensor, nn from .training import FaceSwapperTrainer +from .types import Embedding, Mask, Module CONFIG_PARSER = ConfigParser() CONFIG_PARSER.read('config.ini') +class HalfPrecisionModel(nn.Module): + def __init__(self, model : Module) -> None: + super().__init__() + self.model = model.half() + + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]: + source_embedding = source_embedding.half() + target_tensor = target_tensor.half() + output_tensor, output_mask = self.model(source_embedding, target_tensor) + output_tensor = output_tensor.float() + output_mask = output_mask.float() + return output_tensor, output_mask + + def export() -> None: config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path') config_source_path = CONFIG_PARSER.get('exporting', 'source_path') @@ -16,9 +33,14 @@ def export() -> None: config_target_size = CONFIG_PARSER.getint('exporting', 'target_size') config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version') config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version') + config_precision = CONFIG_PARSER.get('exporting', 'precision') os.makedirs(config_directory_path, exist_ok = True) model = FaceSwapperTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval() + + if config_precision == 'fp16': + model = HalfPrecisionModel(model).eval() + model.ir_version = torch.tensor(config_ir_version) source_tensor = torch.randn(1, 512) target_tensor = torch.randn(1, 3, config_target_size, config_target_size)