mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from .training import FaceSwapperTrainer
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
from torchvision import io
|
||||
|
||||
from .helper import calc_embedding
|
||||
from .models.generator import Generator
|
||||
from .training import FaceSwapperTrainer
|
||||
|
||||
CONFIG_PARSER = configparser.ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
@@ -17,10 +17,7 @@ def infer() -> None:
|
||||
config_target_path = CONFIG_PARSER.get('inferencing', 'target_path')
|
||||
config_output_path = CONFIG_PARSER.get('inferencing', 'output_path')
|
||||
|
||||
state_dict = torch.load(config_generator_path).get('state_dict').get('generator')
|
||||
generator = Generator(CONFIG_PARSER)
|
||||
generator.load_state_dict(state_dict)
|
||||
generator.eval()
|
||||
generator = FaceSwapperTrainer.load_from_checkpoint(config_generator_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
|
||||
embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval()
|
||||
|
||||
source_tensor = io.read_image(config_source_path)
|
||||
|
||||
Reference in New Issue
Block a user