From bf696be097d13ef3f63507b282bd201215f497bf Mon Sep 17 00:00:00 2001 From: henryruhs Date: Tue, 18 Feb 2025 11:00:30 +0100 Subject: [PATCH] Improve generator namings, Flip args to source then target --- face_swapper/src/exporting.py | 4 ++-- face_swapper/src/inferencing.py | 4 ++-- face_swapper/src/models/generator.py | 28 ++++++++++++++-------------- face_swapper/src/training.py | 9 +++++---- face_swapper/src/types.py | 1 + 5 files changed, 24 insertions(+), 22 deletions(-) diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py index d4c12ec..3b0892d 100644 --- a/face_swapper/src/exporting.py +++ b/face_swapper/src/exporting.py @@ -3,7 +3,7 @@ from os import makedirs import torch -from .models.generator import AdaptiveEmbeddingIntegrationNetwork +from .models.generator import Generator CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -18,7 +18,7 @@ def export() -> None: makedirs(directory_path, exist_ok = True) state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator') - model = AdaptiveEmbeddingIntegrationNetwork() + model = Generator() model.load_state_dict(state_dict) model.eval() model.ir_version = ir_version diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index d915992..939cc37 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -4,7 +4,7 @@ import cv2 import torch from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image -from .models.generator import AdaptiveEmbeddingIntegrationNetwork +from .models.generator import Generator from .types import Embedder, Generator, VisionFrame CONFIG = configparser.ConfigParser() @@ -28,7 +28,7 @@ def infer() -> None: output_path = CONFIG.get('inferencing', 'output_path') state_dict = torch.load(generator_path, map_location = 'cpu').get('state_dict').get('generator') - generator = AdaptiveEmbeddingIntegrationNetwork() + generator = Generator() generator.load_state_dict(state_dict) generator.eval() id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 33d9c80..6308f44 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -1,34 +1,34 @@ import configparser from typing import Tuple -from torch import nn +from torch import Tensor, nn from ..networks.attribute_modulator import AADGenerator from ..networks.unet import UNet -from ..types import Embedding, TargetAttributes, VisionTensor +from ..types import Embedding, Attributes CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -class AdaptiveEmbeddingIntegrationNetwork(nn.Module): +class Generator(nn.Module): def __init__(self) -> None: - super(AdaptiveEmbeddingIntegrationNetwork, self).__init__() + super(Generator, self).__init__() id_channels = CONFIG.getint('training.model.generator', 'id_channels') num_blocks = CONFIG.getint('training.model.generator', 'num_blocks') - self.encoder = UNet() - self.generator = AADGenerator(id_channels, num_blocks) - self.encoder.apply(init_weight) - self.generator.apply(init_weight) + self.attribute_encoder = UNet() + self.attribute_generator = AADGenerator(id_channels, num_blocks) + self.attribute_encoder.apply(init_weight) + self.attribute_generator.apply(init_weight) - def forward(self, target : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]: - target_attributes = self.get_attributes(target) - swap_tensor = self.generator(target_attributes, source_embedding) - return swap_tensor, target_attributes + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: + target_attributes = self.get_attributes(target_tensor) + output_tensor = self.attribute_generator(target_attributes, source_embedding) + return output_tensor - def get_attributes(self, target : VisionTensor) -> TargetAttributes: - return self.encoder(target) + def get_attributes(self, input_tensor : Tensor) -> Attributes: + return self.attribute_encoder(input_tensor) def init_weight(module : nn.Module) -> None: diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 1b2852f..d7c14fd 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -15,7 +15,7 @@ from torch.utils.data import DataLoader from .data_loader import DataLoaderVGG from .helper import calc_id_embedding from .models.discriminator import Discriminator -from .models.generator import AdaptiveEmbeddingIntegrationNetwork +from .models.generator import Generator from .models.loss import FaceSwapperLoss from .types import Batch, Embedding, TargetAttributes, VisionTensor @@ -27,12 +27,12 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): def __init__(self) -> None: super().__init__() FaceSwapperLoss.__init__(self) - self.generator = AdaptiveEmbeddingIntegrationNetwork() + self.generator = Generator() self.discriminator = Discriminator() self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization') def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]: - output = self.generator(target_tensor, source_embedding) + output = self.generator(source_embedding, target_tensor) return output def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]: @@ -45,7 +45,8 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): source_tensor, target_tensor, is_same_person = batch generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) - swap_tensor, target_attributes = self.generator(target_tensor, source_embedding) + swap_tensor = self.generator(source_embedding, target_tensor) + target_attributes = self.generator.get_attributes(target_tensor) swap_attributes = self.generator.get_attributes(swap_tensor) fake_discriminator_outputs = self.discriminator(swap_tensor) diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index 3a7bbed..3e78bdd 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -14,6 +14,7 @@ SwapAttributes : TypeAlias = Tuple[Tensor, ...] TargetAttributes : TypeAlias = Tuple[Tensor, ...] DiscriminatorOutputs : TypeAlias = List[List[Tensor]] +Attributes : TypeAlias = Tuple[Tensor, ...] Embedding : TypeAlias = Tensor FaceLandmark203 : TypeAlias = Tensor