From 99a8527e247bbdd5042fc7ecbe03133ab31a532b Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Sun, 23 Mar 2025 18:48:45 +0530 Subject: [PATCH] changes --- face_swapper/src/exporting.py | 1 - face_swapper/src/inferencing.py | 7 ++----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py index 31fc078..4b2f4f6 100644 --- a/face_swapper/src/exporting.py +++ b/face_swapper/src/exporting.py @@ -1,7 +1,6 @@ import os from configparser import ConfigParser - import torch from .training import FaceSwapperTrainer diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index 2643195..6f81b2d 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -4,7 +4,7 @@ import torch from torchvision import io from .helper import calc_embedding -from .models.generator import Generator +from .training import FaceSwapperTrainer CONFIG_PARSER = configparser.ConfigParser() CONFIG_PARSER.read('config.ini') @@ -17,10 +17,7 @@ def infer() -> None: config_target_path = CONFIG_PARSER.get('inferencing', 'target_path') config_output_path = CONFIG_PARSER.get('inferencing', 'output_path') - state_dict = torch.load(config_generator_path).get('state_dict').get('generator') - generator = Generator(CONFIG_PARSER) - generator.load_state_dict(state_dict) - generator.eval() + generator = FaceSwapperTrainer.load_from_checkpoint(config_generator_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval() embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval() source_tensor = io.read_image(config_source_path)