From 008a221f5514e494b330f1cbc26cd64aba9e1bff Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Fri, 17 Jan 2025 16:09:16 +0530 Subject: [PATCH] cleaning --- face_swapper/config.ini | 7 +- face_swapper/src/data_loader.py | 16 +- face_swapper/src/discriminator.py | 44 ++-- face_swapper/src/generator.py | 97 ++++--- face_swapper/src/helper.py | 28 +- face_swapper/src/training.py | 424 ++++++++++++++---------------- face_swapper/src/typing.py | 13 +- 7 files changed, 288 insertions(+), 341 deletions(-) diff --git a/face_swapper/config.ini b/face_swapper/config.ini index b94f925..c20f445 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -24,7 +24,7 @@ num_discriminators = 3 learning_rate = 0.0004 disable = false -[auxiliary_models.paths] +[auxiliary_models.paths] #just model.trainer ... try to match the config of the arcface converter arcface_path = /assets/pretrained_models/arcface_w600k_r50.pt landmarker_path = /assets/pretrained_models/landmark_203.pt motion_extractor_path = /assets/pretrained_models/liveportrait_motion_extractor.pth @@ -34,13 +34,10 @@ spade_generator_path = /assets/pretrained_models/liveportrait_spade_generator.pt [training.losses] weight_adversarial = 1 -weight_identity = 20 +weight_id = 20 weight_attribute = 10 weight_reconstruction = 10 weight_tsr = 100 -weight_eye_gaze = 5 -weight_eye_open = 5 -weight_lip_open = 5 [training.schedulers] step = 5000 diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/data_loader.py index 34c7dca..cceae06 100644 --- a/face_swapper/src/data_loader.py +++ b/face_swapper/src/data_loader.py @@ -16,20 +16,20 @@ CONFIG.read('config.ini') def read_image(image_path: str) -> Image.Image: image = cv2.imread(image_path)[:, :, ::-1] - pil_image = Image.fromarray(image) + pil_image = Image.fromarray(image) # @todo like said, use the PIL transformator return pil_image class DataLoaderVGG(TensorDataset): def __init__(self, dataset_path : str) -> None: - self.same_person_probability = float(CONFIG.get('preparing.dataloader', 'same_person_probability')) - self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path)) + self.same_person_probability = float(CONFIG.get('preparing.dataloader', 'same_person_probability')) # @todo use CONFIG.getfloat() - also config block at the top + self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path)) # @todo globs belong to the config self.folder_paths = glob.glob('{}/*'.format(dataset_path)) - self.image_path_dict = {} + self.image_path_dict = {} # @todo we are not using dict as suffix... this image_path_set? self._current_index = 0 for folder_path in tqdm.tqdm(self.folder_paths): - image_paths = glob.glob('{}/*'.format(folder_path)) + image_paths = glob.glob('{}/*'.format(folder_path)) # @todo not sure about alls this globs being used here :-) self.image_path_dict[folder_path] = image_paths self.dataset_total = len(self.image_paths) self.transforms_basic = transforms.Compose( @@ -61,15 +61,15 @@ class DataLoaderVGG(TensorDataset): source_image_path = self.image_paths[item] source = read_image(source_image_path) - if random.random() > self.same_person_probability: + if random.random() > self.same_person_probability: # @todo if -> we_call_a_method_that_explains_what_we_do() is_same_person = 0 target_image_path = random.choice(self.image_paths) target = read_image(target_image_path) source_transform = self.transforms_moderate(source) target_transform = self.transforms_complex(target) - else: + else: # @todo else -> we_do_some_alternative_action() - in other words, move it to speaking methods :-) is_same_person = 1 - source_folder_path = '/'.join(source_image_path.split('/')[:-1]) + source_folder_path = '/'.join(source_image_path.split('/')[:-1]) # @todo use os.path.join() target_image_path = random.choice(self.image_path_dict[source_folder_path]) target = read_image(target_image_path) source_transform = self.transforms_basic(source) diff --git a/face_swapper/src/discriminator.py b/face_swapper/src/discriminator.py index e328b69..114f377 100644 --- a/face_swapper/src/discriminator.py +++ b/face_swapper/src/discriminator.py @@ -1,9 +1,8 @@ -from typing import List - import numpy import torch.nn as nn +from torch import Tensor -from .typing import DiscriminatorOutputs, Tensor +from .typing import DiscriminatorOutputs class NLayerDiscriminator(nn.Module): @@ -12,37 +11,45 @@ class NLayerDiscriminator(nn.Module): self.num_layers = num_layers kernel_size = 4 padding_size = int(numpy.ceil((kernel_size - 1.0) / 2)) - model_layers = [ + model_layers =\ + [ [ nn.Conv2d(input_channels, num_filters, kernel_size = kernel_size, stride = 2, padding = padding_size), nn.LeakyReLU(0.2, True) - ]] + ] + ] current_filters = num_filters for layer_index in range(1, num_layers): previous_filters = current_filters current_filters = min(current_filters * 2, 512) - model_layers += [ + model_layers +=\ + [ [ nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding_size), nn.InstanceNorm2d(current_filters), nn.LeakyReLU(0.2, True) - ]] + ] + ] previous_filters = current_filters current_filters = min(current_filters * 2, 512) - model_layers += [ + model_layers +=\ + [ [ nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 1, padding = padding_size), nn.InstanceNorm2d(current_filters), nn.LeakyReLU(0.2, True) - ]] - model_layers += [ + ] + ] + model_layers +=\ + [ [ nn.Conv2d(current_filters, 1, kernel_size = kernel_size, stride = 1, padding = padding_size) - ]] + ] + ] combined_layers = [] - for layer in model_layers: - combined_layers += layer + for model_layer in model_layers: + combined_layers += model_layer self.model = nn.Sequential(*combined_layers) def forward(self, input_tensor : Tensor) -> Tensor: @@ -60,17 +67,14 @@ class MultiscaleDiscriminator(nn.Module): setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model) self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type] - def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]: - return [ model_layers(input_tensor) ] - def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs: discriminator_outputs = [] - downsampled_input = input_tensor + temp_downsampled_input = input_tensor for discriminator_index in range(self.num_discriminators): model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index)) - discriminator_outputs.append(self.single_discriminator_forward(model_layers, downsampled_input)) + discriminator_outputs.append([ model_layers(temp_downsampled_input) ]) - if discriminator_index != (self.num_discriminators - 1): - downsampled_input = self.downsample(downsampled_input) + if discriminator_index < (self.num_discriminators - 1): + temp_downsampled_input = self.downsample(temp_downsampled_input) return discriminator_outputs diff --git a/face_swapper/src/generator.py b/face_swapper/src/generator.py index 2f92bc3..fca082b 100644 --- a/face_swapper/src/generator.py +++ b/face_swapper/src/generator.py @@ -2,8 +2,9 @@ from typing import Tuple import torch import torch.nn as nn +from torch import Tensor -from .typing import IDEmbedding, TargetAttributes, Tensor +from .typing import SourceEmbedding, TargetAttributes, VisionTensor class AdaptiveEmbeddingIntegrationNetwork(nn.Module): @@ -12,12 +13,12 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module): self.encoder = UNet() self.generator = AADGenerator(id_channels, num_blocks) - def forward(self, target : Tensor, source_embedding : IDEmbedding) -> Tuple[Tensor, TargetAttributes]: + def forward(self, target : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]: target_attributes = self.get_attributes(target) - swap = self.generator(target_attributes, source_embedding) - return swap, target_attributes + swap_tensor = self.generator(target_attributes, source_embedding) + return swap_tensor, target_attributes - def get_attributes(self, target : Tensor) -> TargetAttributes: + def get_attributes(self, target : VisionTensor) -> TargetAttributes: return self.encoder(target) @@ -35,7 +36,7 @@ class AADGenerator(nn.Module): self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks) self.apply(initialize_weight) - def forward(self, target_attributes : TargetAttributes, source_embedding : IDEmbedding) -> Tensor: + def forward(self, target_attributes : TargetAttributes, source_embedding : SourceEmbedding) -> Tensor: feature_map = self.upsample(source_embedding) feature_map_1 = torch.nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) feature_map_2 = torch.nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) @@ -66,7 +67,7 @@ class UNet(nn.Module): self.upsampler_6 = Upsample(128, 32) self.apply(initialize_weight) - def forward(self, target : Tensor) -> TargetAttributes: + def forward(self, target : VisionTensor) -> TargetAttributes: downsample_feature_1 = self.downsampler_1(target) downsample_feature_2 = self.downsampler_2(downsample_feature_1) downsample_feature_3 = self.downsampler_3(downsample_feature_2) @@ -88,34 +89,34 @@ class AADLayer(nn.Module): def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None: super(AADLayer, self).__init__() self.input_channels = input_channels - self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True) - self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True) + self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1) + self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1) self.fc_beta = nn.Linear(id_channels, input_channels) self.fc_gamma = nn.Linear(id_channels, input_channels) self.instance_norm = nn.InstanceNorm2d(input_channels, affine = False) - self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = True) + self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1) - def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : IDEmbedding) -> Tensor: + def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : SourceEmbedding) -> Tensor: feature_map = self.instance_norm(feature_map) - attr_gamma = self.conv_gamma(attr_embedding) - attr_beta = self.conv_beta(attr_embedding) - attr_modulation = attr_gamma * feature_map + attr_beta + gamma_attribute = self.conv_gamma(attribute_embedding) + beta_attribute = self.conv_beta(attribute_embedding) + attribute_modulation = gamma_attribute * feature_map + beta_attribute id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map) id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map) id_modulation = id_gamma * feature_map + id_beta feature_mask = torch.sigmoid(self.conv_mask(feature_map)) - feature_blend = (1 - feature_mask) * attr_modulation + feature_mask * id_modulation + feature_blend = (1 - feature_mask) * attribute_modulation + feature_mask * id_modulation return feature_blend class AddBlocksSequential(nn.Sequential): - def forward(self, *inputs : Tuple[Tensor, Tensor, IDEmbedding]) -> Tuple[Tuple[Tensor, Tensor, IDEmbedding], ...]: - _, attr_embedding, id_embedding = inputs + def forward(self, *inputs : Tuple[Tensor, Tensor, SourceEmbedding]) -> Tuple[Tuple[Tensor, Tensor, SourceEmbedding], ...]: + _, attr_embedding, id_embedding = inputs #@todo we are not using shortcuts, it is attribute_embedding - for index, module in enumerate(self._modules.values()): + for index, module in enumerate(self._modules.values()): #@todo refactor this to return values if index % 3 == 0 and index > 0: inputs = (inputs, attr_embedding, id_embedding) # type:ignore[assignment] - if type(inputs) == tuple: + if type(inputs) == tuple: #@todo my IDE complains about the type check inputs = module(*inputs) else: inputs = module(inputs) @@ -123,45 +124,45 @@ class AddBlocksSequential(nn.Sequential): class AADResBlock(nn.Module): - def __init__(self, in_channels : int, out_channels : int, attr_channels : int, id_channels : int, num_blocks : int) -> None: + def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, id_channels : int, num_blocks : int) -> None: super(AADResBlock, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels + self.input_channels = input_channels + self.output_channels = output_channels primary_add_blocks = [] for i in range(num_blocks): - intermediate_channels = in_channels if i < (num_blocks - 1) else out_channels + intermediate_channels = input_channels if i < (num_blocks - 1) else output_channels primary_add_blocks.extend( - [ - AADLayer(in_channels, attr_channels, id_channels), + [ #@todo indent + AADLayer(input_channels, attribute_channels, id_channels), nn.ReLU(inplace = True), - nn.Conv2d(in_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False) + nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False) ]) self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks) - if in_channels != out_channels: - auxiliary_add_blocks = \ - [ - AADLayer(in_channels, attr_channels, id_channels), + if input_channels != output_channels: + auxiliary_add_blocks =\ + [ #@todo indent + AADLayer(input_channels, attribute_channels, id_channels), nn.ReLU(inplace = True), - nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False) + nn.Conv2d(input_channels, output_channels, kernel_size = 3, stride = 1, padding = 1, bias = False) ] self.auxiliary_add_blocks = AddBlocksSequential(*auxiliary_add_blocks) - def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : IDEmbedding) -> Tensor: - primary_feature = self.primary_add_blocks(feature_map, attr_embedding, id_embedding) + def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : SourceEmbedding) -> Tensor: + primary_feature = self.primary_add_blocks(feature_map, attribute_embedding, id_embedding) - if self.in_channels != self.out_channels: - feature_map = self.auxiliary_add_blocks(feature_map, attr_embedding, id_embedding) + if self.input_channels != self.output_channels: + feature_map = self.auxiliary_add_blocks(feature_map, attribute_embedding, id_embedding) output_feature = primary_feature + feature_map return output_feature class DownSample(nn.Module): - def __init__(self, in_channels : int, out_channels : int) -> None: + def __init__(self, input_channels : int, output_channels : int) -> None: super(DownSample, self).__init__() - self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) - self.batch_norm = nn.BatchNorm2d(out_channels) + self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) + self.batch_norm = nn.BatchNorm2d(output_channels) self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) def forward(self, temp : Tensor) -> Tensor: @@ -172,10 +173,10 @@ class DownSample(nn.Module): class Upsample(nn.Module): - def __init__(self, in_channels : int, out_channels : int) -> None: + def __init__(self, input_channels : int, output_channels : int) -> None: super(Upsample, self).__init__() - self.deconv = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) - self.batch_norm = nn.BatchNorm2d(out_channels) + self.deconv = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) + self.batch_norm = nn.BatchNorm2d(output_channels) self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor: @@ -186,9 +187,9 @@ class Upsample(nn.Module): class PixelShuffleUpsample(nn.Module): - def __init__(self, in_channels : int, out_channels : int) -> None: + def __init__(self, input_channels : int, output_channels : int) -> None: super(PixelShuffleUpsample, self).__init__() - self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1) + self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 3, padding = 1) self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2) def forward(self, temp : Tensor) -> Tensor: @@ -199,7 +200,7 @@ class PixelShuffleUpsample(nn.Module): def initialize_weight(module : nn.Module) -> None: if isinstance(module, nn.Linear): - module.weight.data.normal_(0, 0.001) + module.weight.data.normal_(std = 0.001) module.bias.data.zero_() if isinstance(module, nn.Conv2d): @@ -207,11 +208,3 @@ def initialize_weight(module : nn.Module) -> None: if isinstance(module, nn.ConvTranspose2d): nn.init.xavier_normal_(module.weight.data) - - -if __name__ == '__main__': - model = AdaptiveEmbeddingIntegrationNetwork(512, 2) - src = torch.randn(1, 512) - trg = torch.randn(1, 3, 256, 256) - out = model(trg, src) - print(out[0].shape) diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 3358c5a..5abdb41 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -1,31 +1,11 @@ -import configparser -from typing import Tuple - import torch from .typing import Tensor -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') -L2_loss = torch.nn.MSELoss() +def hinge_real_loss(tensor : Tensor) -> Tensor: + return torch.relu(1 - tensor) -def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor) -> Tensor: - points_transformed = points.view(-1, 21, 3) @ rotation_matrix + expression.view(-1, 21, 3) - points_transformed *= scale[..., None] - points_transformed[:, :, 0:2] += translation[:, None, 0:2] - return points_transformed - - -def hinge_loss(tensor : Tensor, is_positive : bool) -> Tensor: - if is_positive: - return torch.relu(1 - tensor) - else: - return torch.relu(tensor + 1) - - -def calc_distance_ratio(landmarks : Tensor, indices : Tuple[int, int, int, int]) -> Tensor: - distance_horizontal = torch.norm(landmarks[:, indices[0]] - landmarks[:, indices[1]], p = 2, dim = 1, keepdim = True) - distance_vertical = torch.norm(landmarks[:, indices[2]] - landmarks[:, indices[3]], p=2, dim = 1, keepdim = True) - return distance_horizontal / (distance_vertical + 1e-4) +def hinge_fake_loss(tensor : Tensor) -> Tensor: + return torch.relu(tensor + 1) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index a2c6d65..87e151d 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -2,8 +2,6 @@ import configparser import os from typing import Tuple -import cv2 -import numpy import pytorch_lightning import torch import torchvision @@ -12,18 +10,212 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.utilities.types import Optimizer from pytorch_msssim import ssim +from torch import Tensor from torch.utils.data import DataLoader -from .data_loader import DataLoaderVGG, read_image +from .data_loader import DataLoaderVGG from .discriminator import MultiscaleDiscriminator from .generator import AdaptiveEmbeddingIntegrationNetwork -from .helper import L2_loss, hinge_loss -from .typing import Batch, DiscriminatorOutputs, IDEmbedding, LossDict, TargetAttributes, Tensor +from .helper import hinge_fake_loss, hinge_real_loss +from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, Loss, Padding, SourceEmbedding, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') +class FaceSwapper(pytorch_lightning.LightningModule): + def __init__(self) -> None: + super().__init__() + id_channels = CONFIG.getint('training.generator', 'id_channels') + num_blocks = CONFIG.getint('training.generator', 'num_blocks') + input_channels = CONFIG.getint('training.discriminator', 'input_channels') + num_filters = CONFIG.getint('training.discriminator', 'num_filters') + num_layers = CONFIG.getint('training.discriminator', 'num_layers') + num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators') + arcface_path = CONFIG.get('auxiliary_models.paths', 'arcface_path') + landmarker_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path') + motion_extractor_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path') + + self.generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks) + self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators) + self.arcface = torch.load(arcface_path, map_location = 'cpu', weights_only = False) + self.landmarker = torch.load(landmarker_path, map_location = 'cpu', weights_only = False) + self.motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny') + self.motion_extractor.load_state_dict(torch.load(motion_extractor_path, map_location = 'cpu', weights_only = True)) + self.arcface.eval() + self.landmarker.eval() + self.motion_extractor.eval() + self.automatic_optimization = False + self.mse_loss = torch.nn.MSELoss() + self.batch_size = CONFIG.getint('training.loader', 'batch_size') + + def forward(self, target_tensor : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]: + output = self.generator(target_tensor, source_embedding) + return output + + def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]: + generator_learning_rate = CONFIG.getfloat('training.generator', 'learning_rate') + discriminator_learning_rate = CONFIG.getfloat('training.discriminator', 'learning_rate') + generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = generator_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) + discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = discriminator_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) + return generator_optimizer, discriminator_optimizer + + def training_step(self, batch : Batch, batch_index : int) -> Tensor: + source_tensor, target_tensor, is_same_person = batch + generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] + source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0)) + swap_tensor, target_attributes = self.generator(target_tensor, source_embedding) + discriminator_outputs = self.discriminator(swap_tensor) + + generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch) + generator_optimizer.zero_grad() + self.manual_backward(generator_losses.get('loss_generator')) + generator_optimizer.step() + + discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor) + discriminator_optimizer.zero_grad() + self.manual_backward(discriminator_losses.get('loss_discriminator')) + discriminator_optimizer.step() + + if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0: + self.log_generator_preview(source_tensor, target_tensor, swap_tensor) + + self.log('l_G', generator_losses.get('loss_generator'), prog_bar = True) + self.log('l_D', discriminator_losses.get('loss_discriminator'), prog_bar = True) + self.log('l_ADV', generator_losses.get('loss_adversarial'), prog_bar = True) + self.log('l_ATTR', generator_losses.get('loss_attribute'), prog_bar = True) + self.log('l_ID', generator_losses.get('loss_id'), prog_bar=True) + self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True) + return generator_losses.get('loss_generator') + + def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> Loss: + loss_adversarial = torch.Tensor(0) + + for discriminator_output in discriminator_outputs: + loss_adversarial += hinge_real_loss(discriminator_output[0]).mean(dim = [ 1, 2, 3 ]) + loss_adversarial = torch.mean(loss_adversarial) + return loss_adversarial + + def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> Loss: + loss_attribute = torch.Tensor(0) + swap_attributes = self.generator.get_attributes(swap_tensor) + + for swap_attribute, target_attribute in zip(swap_attributes, target_attributes): + loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean() + loss_attribute *= 0.5 + return loss_attribute + + def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> Loss: + loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4) + loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean() + loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5 + return loss_reconstruction + + def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> Loss: + swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10)) + source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10)) + loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean() + return loss_id + + def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss: + swap_motion_features = self.get_pose_features(swap_tensor) + target_motion_features = self.get_pose_features(target_tensor) + loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) + + for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features): + loss_tsr += self.mse_loss(swap_motion_feature, target_motion_feature) + return loss_tsr + + def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss: + swap_landmark = self.get_face_landmarks(swap_tensor) + target_landmark = self.get_face_landmarks(target_tensor) + left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198]) + right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197]) + gaze_loss = left_gaze_loss + right_gaze_loss + return gaze_loss + + def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet: + source_tensor, target_tensor, is_same_person = batch + weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial') + weight_id = CONFIG.getfloat('training.losses', 'weight_id') + weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute') + weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction') + weight_tsr = CONFIG.getfloat('training.losses', 'weight_tsr') + weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze') + + generator_loss_set = {} + generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs) + generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor) + generator_loss_set['loss_attribute'] = self.calc_attribute_loss(swap_tensor, target_attributes) + generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person) + + if weight_tsr > 0: + generator_loss_set['loss_tsr'] = self.calc_tsr_loss(swap_tensor, target_tensor) + else: + generator_loss_set['loss_tsr'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) + + if weight_gaze > 0: + generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor) + else: + generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) + generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_tsr') * weight_tsr + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze + return generator_loss_set + + def calc_discriminator_loss(self, swap_tensor : VisionTensor, source_tensor : VisionTensor) -> DiscriminatorLossSet: + discriminator_loss_set = {} + fake_discriminator_outputs = self.discriminator(swap_tensor.detach()) + loss_fake = torch.Tensor(0) + + for fake_discriminator_output in fake_discriminator_outputs: + loss_fake += torch.mean(hinge_fake_loss(fake_discriminator_output[0]).mean(dim = [ 1, 2, 3 ])) + true_discriminator_outputs = self.discriminator(source_tensor) + loss_true = torch.Tensor(0) + + for true_discriminator_output in true_discriminator_outputs: + loss_true += torch.mean(hinge_real_loss(true_discriminator_output[0]).mean(dim = [ 1, 2, 3 ])) + discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5 + return discriminator_loss_set + + def get_id_embedding(self, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding: + crop_vision_tensor = torch.nn.functional.interpolate(vision_tensor, size = (112, 112), mode = 'area') + crop_vision_tensor = crop_vision_tensor[:, :, 0:112, 8:128] + crop_vision_tensor[:, :, :padding[0], :] = 0 + crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0 + crop_vision_tensor[:, :, :, :padding[2]] = 0 + crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0 + embedding = self.arcface(crop_vision_tensor) + embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1) + return embedding + + def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203: + vision_tensor_norm = (vision_tensor + 1) * 0.5 + vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear') + landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2) + return landmarks + + def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]: + vision_tensor_norm = (vision_tensor + 1) * 0.5 + motion_dict = self.motion_extractor(vision_tensor_norm) + translation = motion_dict.get('t') + scale = motion_dict.get('scale') + rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1) + return translation, scale, rotation + + def log_generator_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None: + max_preview = 8 + source_tensors = source_tensor[:max_preview] + target_tensors = target_tensor[:max_preview] + swap_tensors = swap_tensor[:max_preview] + rows = [ torch.cat([ source_tensor, target_tensor, swap_tensor ], dim = 2) for source_tensor, target_tensor, swap_tensor in zip(source_tensors, target_tensors, swap_tensors) ] + grid = torchvision.utils.make_grid(torch.cat(rows, dim = 1).unsqueeze(0), nrow = 1, normalize = True, scale_each = True) + self.logger.experiment.add_image("Generator Preview", grid, self.global_step) + + def create_trainer() -> Trainer: trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs') output_directory_path = CONFIG.get('training.output', 'directory_path') @@ -56,229 +248,7 @@ def train() -> None: checkpoint_path = CONFIG.get('training.output', 'checkpoint_path') dataset = DataLoaderVGG(CONFIG.get('preparing.dataset', 'dataset_path')) - if not (checkpoint_path and os.path.exists(checkpoint_path)): - checkpoint_path = None data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) face_swap_model = FaceSwapper() trainer = create_trainer() trainer.fit(face_swap_model, data_loader, ckpt_path = checkpoint_path) - - -class FaceSwapper(pytorch_lightning.LightningModule): - def __init__(self) -> None: - super().__init__() - self.generator = AdaptiveEmbeddingIntegrationNetwork(CONFIG.getint('training.generator', 'id_channels'), CONFIG.getint('training.generator', 'num_blocks')) - self.discriminator = MultiscaleDiscriminator(CONFIG.getint('training.discriminator', 'input_channels'), CONFIG.getint('training.discriminator', 'num_filters'), CONFIG.getint('training.discriminator', 'num_layers'), CONFIG.getint('training.discriminator', 'num_discriminators')) - self.arcface = torch.load(CONFIG.get('auxiliary_models.paths', 'arcface_path'), map_location = 'cpu', weights_only = False) - self.landmarker = torch.load(CONFIG.get('auxiliary_models.paths', 'landmarker_path'), map_location = 'cpu', weights_only = False) - self.motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny') - self.motion_extractor.load_state_dict(torch.load(CONFIG.get('auxiliary_models.paths', 'motion_extractor_path'), map_location = 'cpu', weights_only = True)) - self.arcface.eval() - self.landmarker.eval() - self.motion_extractor.eval() - self.automatic_optimization = False - self.batch_size = CONFIG.getint('training.loader', 'batch_size') - - def forward(self, target_tensor : Tensor, source_embedding : IDEmbedding) -> Tuple[Tensor, TargetAttributes]: - output = self.generator(target_tensor, source_embedding) - return output - - def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]: - generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4) - discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = CONFIG.getfloat('training.discriminator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4) - return generator_optimizer, discriminator_optimizer - - def training_step(self, batch : Batch, batch_index : int) -> Tensor: - source_tensor, target_tensor, is_same_person = batch - generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] - source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0)) - swap_tensor, target_attributes = self.generator(target_tensor, source_embedding) - discriminator_outputs = self.discriminator(swap_tensor) - generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch) - generator_optimizer.zero_grad() - self.manual_backward(generator_losses.get('loss_generator')) - generator_optimizer.step() - discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor) - discriminator_optimizer.zero_grad() - self.manual_backward(discriminator_losses.get('loss_discriminator')) - - if not CONFIG.getboolean('training.discriminator', 'disable'): - discriminator_optimizer.step() - - if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0: - self.log_generator_preview(source_tensor, target_tensor, swap_tensor) - - if self.global_step % CONFIG.getint('training.output', 'validation_frequency') == 0: - self.log_validation_preview() - self.log('l_G', generator_losses.get('loss_generator'), prog_bar = True) - self.log('l_D', discriminator_losses.get('loss_discriminator'), prog_bar = True) - self.log('l_ADV', generator_losses.get('loss_adversarial'), prog_bar = False) - self.log('l_ATTR', generator_losses.get('loss_attribute'), prog_bar = True) - self.log('l_ID', generator_losses.get('loss_identity'), prog_bar=True) - self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True) - return generator_losses.get('loss_generator') - - def calc_generator_loss(self, swap_tensor : Tensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> LossDict: - source_tensor, target_tensor, is_same_person = batch - generator_losses = {} - # adversarial loss - loss_adversarial = torch.Tensor(0) - - for discriminator_output in discriminator_outputs: - loss_adversarial += hinge_loss(discriminator_output[0], True).mean(dim = [ 1, 2, 3 ]) - loss_adversarial = torch.mean(loss_adversarial) - generator_losses['loss_adversarial'] = loss_adversarial - generator_losses['loss_generator'] = loss_adversarial * CONFIG.getfloat('training.losses', 'weight_adversarial') - - # identity loss - swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10)) - source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10)) - loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean() - generator_losses['loss_identity'] = loss_identity - generator_losses['loss_generator'] += loss_identity * CONFIG.getfloat('training.losses', 'weight_identity') - - # attribute loss - loss_attribute = torch.Tensor(0) - swap_attributes = self.generator.get_attributes(swap_tensor) - - for swap_attribute, target_attribute in zip(swap_attributes, target_attributes): - loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean() - loss_attribute *= 0.5 - generator_losses['loss_attribute'] = loss_attribute - generator_losses['loss_generator'] += loss_attribute * CONFIG.getfloat('training.losses', 'weight_attribute') - - # reconstruction loss - loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4) - loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean() - loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5 - generator_losses['loss_reconstruction'] = loss_reconstruction - generator_losses['loss_generator'] += loss_reconstruction * CONFIG.getfloat('training.losses', 'weight_reconstruction') - - if CONFIG.getfloat('training.losses', 'weight_tsr') > 0: - # tsr loss - swap_motion_features = self.get_motion_features(swap_tensor) - target_motion_features = self.get_motion_features(target_tensor) - loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) - - for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features): - loss_tsr += L2_loss(swap_motion_feature, target_motion_feature) - generator_losses['loss_tsr'] = loss_tsr - generator_losses['loss_generator'] += loss_tsr * CONFIG.getfloat('training.losses', 'weight_tsr') - - if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0: - swap_landmark_features = self.get_landmark_features(swap_tensor) - target_landmark_features = self.get_landmark_features(target_tensor) - loss_left_eye_gaze = L2_loss(swap_landmark_features[0], target_landmark_features[1]) - loss_right_eye_gaze = L2_loss(swap_landmark_features[0], target_landmark_features[1]) - loss_eye_gaze = loss_left_eye_gaze + loss_right_eye_gaze - generator_losses['loss_eye_gaze'] = loss_eye_gaze - generator_losses['loss_generator'] += loss_eye_gaze * CONFIG.getfloat('training.losses', 'weight_eye_gaze') - return generator_losses - - def calc_discriminator_loss(self, swap_tensor : Tensor, source_tensor : Tensor) -> LossDict: - discriminator_losses = {} - fake_discriminator_outputs = self.discriminator(swap_tensor.detach()) - loss_fake = torch.Tensor(0) - - for fake_discriminator_output in fake_discriminator_outputs: - loss_fake += torch.mean(hinge_loss(fake_discriminator_output[0], False).mean(dim=[1, 2, 3])) - true_discriminator_outputs = self.discriminator(source_tensor) - loss_true = torch.Tensor(0) - - for true_discriminator_output in true_discriminator_outputs: - loss_true += torch.mean(hinge_loss(true_discriminator_output[0], True).mean(dim=[1, 2, 3])) - discriminator_losses['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5 - return discriminator_losses - - def get_id_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor: - _, _, height, width = vision_tensor.shape - crop_height = int(height * 0.0586) - crop_width = int(width * 0.0586) - crop_vision_tensor = vision_tensor[:, :, crop_height : height - crop_height, crop_width : width - crop_width] - crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'bilinear') - crop_vision_tensor[:, :, :padding[0], :] = 0 - crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0 - crop_vision_tensor[:, :, :, :padding[2]] = 0 - crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0 - embedding = self.arcface(crop_vision_tensor) - embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1) - return embedding - - def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor]: - vision_tensor_norm = (vision_tensor + 1) * 0.5 - vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear') - landmarks = self.landmarker(vision_tensor_norm)[2] - landmarks = landmarks.view(-1, 203, 2) * 256 - return landmarks[:, 198], landmarks[:, 197] - - def get_motion_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]: - vision_tensor_norm = (vision_tensor + 1) * 0.5 - motion_dict = self.motion_extractor(vision_tensor_norm) - translation = motion_dict.get('t') - scale = motion_dict.get('scale') - rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1) - return translation, scale, rotation - - def log_generator_preview(self, source_tensor : Tensor, target_tensor : Tensor, swap_tensor : Tensor) -> None: - max_preview = 8 - source_tensor = source_tensor[:max_preview] - target_tensor = target_tensor[:max_preview] - swap_tensor = swap_tensor[:max_preview] - rows = [torch.cat([src, tgt, swp], dim = 2) for src, tgt, swp in zip(source_tensor, target_tensor, swap_tensor)] - grid = torchvision.utils.make_grid(torch.cat(rows, dim = 1).unsqueeze(0), nrow = 1, normalize = True, scale_each = True) - self.logger.experiment.add_image("Generator Preview", grid, self.global_step) - - def log_validation_preview(self) -> None: - read_images = lambda path : [read_image(os.path.join(path, f)) for f in sorted(os.listdir(path)) if f.lower().endswith('.jpg') or f.lower().endswith('.png')] - to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:, :, ::-1] + 1) * 127.5 - transforms = torchvision.transforms.Compose( - [ - torchvision.transforms.Resize((256, 256), interpolation = torchvision.transforms.InterpolationMode.BICUBIC), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ]) - sources = read_images(CONFIG.get('training.validation', 'sources')) - targets_front = read_images(CONFIG.get('training.validation', 'targets_front')) - targets_side = read_images(CONFIG.get('training.validation', 'targets_side')) - targets_makeup = read_images(CONFIG.get('training.validation', 'targets_makeup')) - targets_occlusion = read_images(CONFIG.get('training.validation', 'targets_occlusion')) - - self.generator.eval() - - results_source = [] - results_front = [] - results_side = [] - results_makeup = [] - results_occlusion = [] - - for source, target_front, target_side, target_makeup, target_occlusion in zip(sources, targets_front, targets_side, targets_makeup, targets_occlusion): - source_tensor = transforms(source).unsqueeze(0).to(self.device).half() - source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0)) - target_front_tensor = transforms(target_front).unsqueeze(0).to(self.device).half() - target_side_tensor = transforms(target_side).unsqueeze(0).to(self.device).half() - target_makeup_tensor = transforms(target_makeup).unsqueeze(0).to(self.device).half() - target_occlusion_tensor = transforms(target_occlusion).unsqueeze(0).to(self.device).half() - - with torch.no_grad(): - output_front, _ = self.generator(target_front_tensor, source_embedding) - output_side, _ = self.generator(target_side_tensor, source_embedding) - output_makeup, _ = self.generator(target_makeup_tensor, source_embedding) - output_occlusion, _ = self.generator(target_occlusion_tensor, source_embedding) - - results_source.append(to_numpy(source_tensor)) - results_front.append(numpy.hstack([to_numpy(target_front_tensor), to_numpy(output_front)])) - results_side.append(numpy.hstack([to_numpy(target_side_tensor), to_numpy(output_side)])) - results_makeup.append(numpy.hstack([to_numpy(target_makeup_tensor), to_numpy(output_makeup)])) - results_occlusion.append(numpy.hstack([to_numpy(target_occlusion_tensor), to_numpy(output_occlusion)])) - - sources_vertical = numpy.vstack(results_source) - results_front_vertical = numpy.vstack(results_front) - results_side_vertical = numpy.vstack(results_side) - results_makeup_vertical = numpy.vstack(results_makeup) - results_occlusion_vertical = numpy.vstack(results_occlusion) - pad = numpy.zeros((sources_vertical.shape[0], 10, 3), dtype = sources_vertical.dtype) - preview = numpy.hstack([sources_vertical, pad, results_front_vertical, pad, results_side_vertical, pad, results_makeup_vertical, pad, results_occlusion_vertical]) - - os.makedirs("validation_previews", exist_ok=True) - cv2.imwrite(f"validation_previews/step_{self.global_step}.jpg", preview) - self.generator.train() diff --git a/face_swapper/src/typing.py b/face_swapper/src/typing.py index 29ab130..f1366b5 100644 --- a/face_swapper/src/typing.py +++ b/face_swapper/src/typing.py @@ -1,7 +1,6 @@ from collections import OrderedDict from typing import Any, Dict, List, Tuple -from numpy.typing import NDArray from torch import Tensor from torch.utils.data import DataLoader @@ -9,8 +8,12 @@ Batch = Tuple[Any, Any, Any] Loader = DataLoader[Tuple[Tensor, ...]] TargetAttributes = Tuple[Tensor, ...] DiscriminatorOutputs = List[List[Tensor]] -LossDict = Dict[str, Tensor] -IDEmbedding = Tensor +IdEmbedding = Tensor +SourceEmbedding = IdEmbedding StateDict = OrderedDict[str, Any] -Embedding = NDArray[Any] -VisionFrame = NDArray[Any] +Padding = Tuple[int, int, int, int] +FaceLandmark203 = Tensor +VisionTensor = Tensor +Loss = Tensor +GeneratorLossSet = Dict[str, Loss] +DiscriminatorLossSet = Dict[str, Loss]