From 39818a16df492eb21c2d2ddec0e4039a4dfc056f Mon Sep 17 00:00:00 2001 From: henryruhs Date: Tue, 18 Feb 2025 11:05:58 +0100 Subject: [PATCH] Fix CI --- face_swapper/src/helper.py | 4 ++-- face_swapper/src/inferencing.py | 4 ++-- face_swapper/src/types.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 00faa66..6f14dae 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -5,7 +5,7 @@ import numpy import torch from torch import Tensor, nn -from .types import Embedder, Embedding, Padding, VisionFrame, VisionTensor +from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor def is_windows() -> bool: @@ -48,7 +48,7 @@ def hinge_fake_loss(input_tensor : Tensor) -> Tensor: return fake_loss -def calc_id_embedding(id_embedder : Embedder, vision_tensor : VisionTensor, padding : Padding) -> Embedding: +def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor, padding : Padding) -> Embedding: crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241] crop_vision_tensor = nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area') crop_vision_tensor[:, :, :padding[0], :] = 0 diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index 939cc37..c6e7f5f 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -5,13 +5,13 @@ import torch from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image from .models.generator import Generator -from .types import Embedder, Generator, VisionFrame +from .types import EmbedderModule, GeneratorModule, VisionFrame CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -def run_swap(generator : Generator, id_embedder : Embedder, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: +def run_swap(generator : GeneratorModule, id_embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: source_vision_tensor = convert_to_vision_tensor(source_vision_frame) target_vision_tensor = convert_to_vision_tensor(target_vision_frame) source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0)) diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index 3e78bdd..bf2fbff 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -28,5 +28,5 @@ VisionTensor : TypeAlias = Tensor GeneratorLossSet : TypeAlias = Dict[str, Tensor] DiscriminatorLossSet : TypeAlias = Dict[str, Tensor] -Generator : TypeAlias = Module -Embedder : TypeAlias = Module +GeneratorModule : TypeAlias = Module +EmbedderModule : TypeAlias = Module