From 6bc44ad3d8f15da3b3d9b8fb95a366bf04c7ae9b Mon Sep 17 00:00:00 2001 From: henryruhs Date: Thu, 6 Mar 2025 18:32:20 +0100 Subject: [PATCH] Fix CI --- face_swapper/src/dataset.py | 2 +- face_swapper/src/inferencing.py | 2 +- face_swapper/src/models/discriminator.py | 1 - face_swapper/src/training.py | 6 +++--- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py index 99232be..ff438b8 100644 --- a/face_swapper/src/dataset.py +++ b/face_swapper/src/dataset.py @@ -10,7 +10,7 @@ from torch.utils.data import Dataset from torchvision import io, transforms from .helper import warp_tensor -from .types import Batch, WarpTemplate, BatchMode +from .types import Batch, BatchMode, WarpTemplate class DynamicDataset(Dataset[Tensor]): diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index 3f718b3..b20535b 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -24,7 +24,7 @@ def infer() -> None: generator = Generator(CONFIG_PARSER) generator.load_state_dict(state_dict) generator.eval() - embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] + embedder = torch.jit.load(config.get('embedder_path'), map_location = 'cpu') # type:ignore[no-untyped-call] embedder.eval() source_tensor = io.read_image(config.get('source_path')) diff --git a/face_swapper/src/models/discriminator.py b/face_swapper/src/models/discriminator.py index 4ccb88f..f60b07f 100644 --- a/face_swapper/src/models/discriminator.py +++ b/face_swapper/src/models/discriminator.py @@ -17,7 +17,6 @@ class Discriminator(nn.Module): self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False) self.discriminators = self.create_discriminators() - def create_discriminators(self) -> nn.ModuleList: discriminators = nn.ModuleList() diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index c9ffde4..f08d65d 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -1,7 +1,7 @@ import os -from configparser import ConfigParser import warnings -from typing import Tuple, cast +from configparser import ConfigParser +from typing import Tuple import torch import torchvision @@ -17,7 +17,7 @@ from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MotionLoss, ReconstructionLoss -from .types import Batch, BatchMode, Embedding, OptimizerSet, WarpTemplate +from .types import Batch, Embedding, OptimizerSet warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')