From 2ed558a873b756a1eea56d52acccbbcd119422e2 Mon Sep 17 00:00:00 2001
From: harisreedhar
Date: Mon, 10 Feb 2025 22:43:15 +0530
Subject: [PATCH] cleanup

---
 face_swapper/__init__.py                      |   0
 face_swapper/src/data_loader.py               |   3 +-
 face_swapper/src/exporting.py                 |   4 +-
 face_swapper/src/inferencing.py               |   4 +-
 .../src/{ => models}/discriminator.py         |  63 ++++----
 face_swapper/src/models/generator.py          |  43 +++++
 face_swapper/src/models/loss.py               | 137 ++++++++++++++++
 .../attribute_modulator.py}                   | 127 ++-------------
 face_swapper/src/networks/encoder.py          |  67 ++++++++
 face_swapper/src/training.py                  | 148 +-----------------
 10 files changed, 310 insertions(+), 286 deletions(-)
 create mode 100644 face_swapper/__init__.py
 rename face_swapper/src/{ => models}/discriminator.py (75%)
 create mode 100644 face_swapper/src/models/generator.py
 create mode 100644 face_swapper/src/models/loss.py
 rename face_swapper/src/{generator.py => networks/attribute_modulator.py} (54%)
 create mode 100644 face_swapper/src/networks/encoder.py

diff --git a/face_swapper/__init__.py b/face_swapper/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/data_loader.py
index c7a5c6c..49f1c21 100644
--- a/face_swapper/src/data_loader.py
+++ b/face_swapper/src/data_loader.py
@@ -24,8 +24,7 @@ class DataLoaderVGG(TensorDataset):
 		image_path_set = {}
 
 		for directory_path in self.directory_paths:
-			image_paths = glob.glob(dataset_image_pattern.format(directory_path))
-			image_paths.extend(image_paths)
+			image_paths.extend(glob.glob(dataset_image_pattern.format(directory_path)))
 			image_path_set[directory_path] = image_paths
 		return image_paths, image_path_set
 
diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py
index 033e26d..311fdc8 100644
--- a/face_swapper/src/exporting.py
+++ b/face_swapper/src/exporting.py
@@ -3,7 +3,7 @@ from os import makedirs
 
 import torch
 
-from .generator import AdaptiveEmbeddingIntegrationNetwork
+from .models.generator import AdaptiveEmbeddingIntegrationNetwork
 
 CONFIG = configparser.ConfigParser()
 CONFIG.read('config.ini')
@@ -17,7 +17,7 @@ def export() -> None:
 	makedirs(directory_path, exist_ok = True)
 
 	state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
-	model = AdaptiveEmbeddingIntegrationNetwork(512, 2)
+	model = AdaptiveEmbeddingIntegrationNetwork()
 	model.load_state_dict(state_dict)
 	model.eval()
 	source_tensor = torch.randn(1, 512)
diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py
index 570c277..3da1573 100644
--- a/face_swapper/src/inferencing.py
+++ b/face_swapper/src/inferencing.py
@@ -3,8 +3,8 @@ import configparser
 import cv2
 import torch
 
-from .generator import AdaptiveEmbeddingIntegrationNetwork
 from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image
+from .models.generator import AdaptiveEmbeddingIntegrationNetwork
 from .types import Generator, IdEmbedder, VisionFrame
 
 CONFIG = configparser.ConfigParser()
@@ -28,7 +28,7 @@ def infer() -> None:
 	output_path = CONFIG.get('inferencing', 'output_path')
 
 	state_dict = torch.load(generator_path, map_location = 'cpu').get('state_dict').get('generator')
-	generator = AdaptiveEmbeddingIntegrationNetwork(512, 2)
+	generator = AdaptiveEmbeddingIntegrationNetwork()
 	generator.load_state_dict(state_dict)
 	generator.eval()
 	id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
diff --git a/face_swapper/src/discriminator.py b/face_swapper/src/models/discriminator.py
similarity index 75%
rename from face_swapper/src/discriminator.py
rename to face_swapper/src/models/discriminator.py
index 8d1a2ab..475d66f 100644
--- a/face_swapper/src/discriminator.py
+++ b/face_swapper/src/models/discriminator.py
@@ -1,3 +1,4 @@
+import configparser
 from itertools import chain
 from typing import List
 
@@ -6,7 +7,41 @@ import torch.nn
 import torch.nn as nn
 from torch import Tensor
 
-from .types import DiscriminatorOutputs
+from face_swapper.src.types import DiscriminatorOutputs
+
+CONFIG = configparser.ConfigParser()
+CONFIG.read('config.ini')
+
+
+class MultiscaleDiscriminator(nn.Module):
+	def __init__(self) -> None:
+		super(MultiscaleDiscriminator, self).__init__()
+		self.input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
+		self.num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
+		self.kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
+		self.num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
+		self.num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
+
+		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
+		self.prepare_discriminators()
+
+	def prepare_discriminators(self) -> None:
+		for discriminator_index in range(self.num_discriminators):
+			single_discriminator = NLayerDiscriminator(self.input_channels, self.num_filters, self.num_layers, self.kernel_size)
+			setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
+
+	def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
+		discriminator_outputs = []
+		temp_tensor = input_tensor
+
+		for discriminator_index in range(self.num_discriminators):
+			model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
+			discriminator_outputs.append([ model_layers(temp_tensor) ])
+
+			if discriminator_index < (self.num_discriminators - 1):
+				temp_tensor = self.downsample(temp_tensor)
+
+		return discriminator_outputs
 
 
 class NLayerDiscriminator(nn.Module):
@@ -58,29 +93,3 @@
 
 	def forward(self, input_tensor : Tensor) -> Tensor:
 		return self.model(input_tensor)
-
-
-class MultiscaleDiscriminator(nn.Module):
-	def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int, kernel_size : int):
-		super(MultiscaleDiscriminator, self).__init__()
-		self.num_discriminators = num_discriminators
-		self.num_layers = num_layers
-
-		for discriminator_index in range(num_discriminators):
-			single_discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers, kernel_size)
-			setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
-
-		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
-
-	def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
-		discriminator_outputs = []
-		temp_tensor = input_tensor
-
-		for discriminator_index in range(self.num_discriminators):
-			model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
-			discriminator_outputs.append([ model_layers(temp_tensor) ])
-
-			if discriminator_index < (self.num_discriminators - 1):
-				temp_tensor = self.downsample(temp_tensor)
-
-		return discriminator_outputs
diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py
new file mode 100644
index 0000000..12a4f76
--- /dev/null
+++ b/face_swapper/src/models/generator.py
@@ -0,0 +1,43 @@
+import configparser
+from typing import Tuple
+
+import torch.nn as nn
+
+from face_swapper.src.networks.attribute_modulator import AADGenerator
+from face_swapper.src.networks.encoder import UNet
+from face_swapper.src.types import SourceEmbedding, TargetAttributes, VisionTensor
+
+CONFIG = configparser.ConfigParser()
+CONFIG.read('config.ini')
+
+
+class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
+	def __init__(self) -> None:
+		super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
+		id_channels = CONFIG.getint('training.model.generator', 'id_channels')
+		num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
+
+		self.encoder = UNet()
+		self.generator = AADGenerator(id_channels, num_blocks)
+		self.encoder.apply(init_weight)
+		self.generator.apply(init_weight)
+
+	def forward(self, target : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
+		target_attributes = self.get_attributes(target)
+		swap_tensor = self.generator(target_attributes, source_embedding)
+		return swap_tensor, target_attributes
+
+	def get_attributes(self, target : VisionTensor) -> TargetAttributes:
+		return self.encoder(target)
+
+
+def init_weight(module : nn.Module) -> None:
+	if isinstance(module, nn.Linear):
+		module.weight.data.normal_(std = 0.001)
+		module.bias.data.zero_()
+
+	if isinstance(module, nn.Conv2d):
+		nn.init.xavier_normal_(module.weight.data)
+
+	if isinstance(module, nn.ConvTranspose2d):
+		nn.init.xavier_normal_(module.weight.data)
diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py
new file mode 100644
index 0000000..2cd4bcb
--- /dev/null
+++ b/face_swapper/src/models/loss.py
@@ -0,0 +1,137 @@
+import configparser
+from typing import Tuple
+
+import torch
+from pytorch_msssim import ssim
+from torch import Tensor
+
+from face_swapper.src.helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
+from face_swapper.src.types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
+
+CONFIG = configparser.ConfigParser()
+CONFIG.read('config.ini')
+
+
+class FaceSwapperLoss:
+	def __init__(self) -> None:
+		id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
+		landmarker_path = CONFIG.get('training.model', 'landmarker_path')
+		motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
+		self.batch_size = CONFIG.getint('training.loader', 'batch_size')
+		self.mse_loss = torch.nn.MSELoss()
+		self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
+		self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
+		self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
+		self.id_embedder.eval()
+		self.landmarker.eval()
+		self.motion_extractor.eval()
+
+	def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
+		source_tensor, target_tensor, is_same_person = batch
+		weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
+		weight_id = CONFIG.getfloat('training.losses', 'weight_id')
+		weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
+		weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
+		weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
+		weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
+		generator_loss_set = {}
+
+		generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs)
+		generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor)
+		generator_loss_set['loss_attribute'] = self.calc_attribute_loss(target_attributes, swap_attributes)
+		generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
+
+		if weight_pose > 0:
+			generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
+		else:
+			generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
+
+		if weight_gaze > 0:
+			generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
+		else:
+			generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
+
+		generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
+		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id
+		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
+		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
+		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose
+		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
+		return generator_loss_set
+
+	def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
+		discriminator_loss_set = {}
+		loss_fake = torch.tensor(0.0)
+
+		for fake_discriminator_output in fake_discriminator_outputs:
+			loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
+
+		loss_true = torch.tensor(0.0)
+
+		for true_discriminator_output in real_discriminator_outputs:
+			loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
+
+		discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
+		return discriminator_loss_set
+
+	def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
+		loss_adversarial = torch.tensor(0.0)
+
+		for discriminator_output in discriminator_outputs:
+			loss_adversarial += hinge_real_loss(discriminator_output[0])
+
+		loss_adversarial = torch.mean(loss_adversarial)
+		return loss_adversarial
+
+	def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
+		loss_attribute = torch.tensor(0.0)
+
+		for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
+			loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
+
+		loss_attribute *= 0.5
+		return loss_attribute
+
+	def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
+		loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
+		loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
+		loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
+		loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
+		loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
+		return loss_reconstruction
+
+	def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
+		swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
+		source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
+		loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
+		return loss_id
+
+	def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
+		swap_motion_features = self.get_pose_features(swap_tensor)
+		target_motion_features = self.get_pose_features(target_tensor)
+		loss_pose = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
+
+		for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
+			loss_pose += self.mse_loss(swap_motion_feature, target_motion_feature)
+
+		return loss_pose
+
+	def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
+		swap_landmark = self.get_face_landmarks(swap_tensor)
+		target_landmark = self.get_face_landmarks(target_tensor)
+		left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
+		right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197])
+		gaze_loss = left_gaze_loss + right_gaze_loss
+		return gaze_loss
+
+	def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
+		vision_tensor_norm = (vision_tensor + 1) * 0.5
+		vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
+		landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
+		return landmarks
+
+	def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+		vision_tensor_norm = (vision_tensor + 1) * 0.5
+		pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
+		rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
+		return translation, scale, rotation
diff --git a/face_swapper/src/generator.py b/face_swapper/src/networks/attribute_modulator.py
similarity index 54%
rename from face_swapper/src/generator.py
rename to face_swapper/src/networks/attribute_modulator.py
index 12f82e6..cc34463 100644
--- a/face_swapper/src/generator.py
+++ b/face_swapper/src/networks/attribute_modulator.py
@@ -1,25 +1,7 @@
-from typing import Tuple
-
 import torch
-import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, nn
 
-from .types import SourceEmbedding, TargetAttributes, VisionTensor
-
-
-class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
-	def __init__(self, id_channels : int, num_blocks : int) -> None:
-		super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
-		self.encoder = UNet()
-		self.generator = AADGenerator(id_channels, num_blocks)
-
-	def forward(self, target : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
-		target_attributes = self.get_attributes(target)
-		swap_tensor = self.generator(target_attributes, source_embedding)
-		return swap_tensor, target_attributes
-
-	def get_attributes(self, target : VisionTensor) -> TargetAttributes:
-		return self.encoder(target)
+from face_swapper.src.types import SourceEmbedding, TargetAttributes
 
 
 class AADGenerator(nn.Module):
@@ -34,7 +16,6 @@ class AADGenerator(nn.Module):
 		self.res_block_6 = AADResBlock(256, 128, 128, id_channels, num_blocks)
 		self.res_block_7 = AADResBlock(128, 64, 64, id_channels, num_blocks)
 		self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks)
-		self.apply(init_weight)
 
 	def forward(self, target_attributes : TargetAttributes, source_embedding : SourceEmbedding) -> Tensor:
 		feature_map = self.upsample(source_embedding)
@@ -49,42 +30,6 @@ class AADGenerator(nn.Module):
 		return torch.tanh(output)
 
 
-class UNet(nn.Module):
-	def __init__(self) -> None:
-		super(UNet, self).__init__()
-		self.downsampler_1 = DownSample(3, 32)
-		self.downsampler_2 = DownSample(32, 64)
-		self.downsampler_3 = DownSample(64, 128)
-		self.downsampler_4 = DownSample(128, 256)
-		self.downsampler_5 = DownSample(256, 512)
-		self.downsampler_6 = DownSample(512, 1024)
-		self.bottleneck = DownSample(1024, 1024)
-		self.upsampler_1 = Upsample(1024, 1024)
-		self.upsampler_2 = Upsample(2048, 512)
-		self.upsampler_3 = Upsample(1024, 256)
-		self.upsampler_4 = Upsample(512, 128)
-		self.upsampler_5 = Upsample(256, 64)
-		self.upsampler_6 = Upsample(128, 32)
-		self.apply(init_weight)
-
-	def forward(self, target : VisionTensor) -> TargetAttributes:
-		downsample_feature_1 = self.downsampler_1(target)
-		downsample_feature_2 = self.downsampler_2(downsample_feature_1)
-		downsample_feature_3 = self.downsampler_3(downsample_feature_2)
-		downsample_feature_4 = self.downsampler_4(downsample_feature_3)
-		downsample_feature_5 = self.downsampler_5(downsample_feature_4)
-		downsample_feature_6 = self.downsampler_6(downsample_feature_5)
-		bottleneck_output = self.bottleneck(downsample_feature_6)
-		upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
-		upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
-		upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
-		upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
-		upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
-		upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
-		output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
-		return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
-
-
 class AADLayer(nn.Module):
 	def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None:
 		super(AADLayer, self).__init__()
@@ -109,22 +54,18 @@ class AADLayer(nn.Module):
 		return feature_blend
 
 
-class AddBlocksSequential(nn.Sequential):
-	#todo: what are inputs? improve the name
-	def forward(self, *inputs : Tuple[Tensor, Tensor, SourceEmbedding]) -> Tuple[Tuple[Tensor, Tensor, SourceEmbedding], ...]:
-		_, attribute_embedding, id_embedding = inputs
-		modules = self._modules.values() #todo: what kind of modules?
+class AADSequential(nn.Module):
+	def __init__(self, *args : nn.Module) -> None:
+		super(AADSequential, self).__init__()
+		self.layers = nn.ModuleList(args)
 
-		for module_index, module in enumerate(modules):
-			if module_index % 3 == 0 and module_index > 0:
-				inputs = (inputs, attribute_embedding, id_embedding) # type:ignore[assignment]
-
-			if isinstance(inputs, torch.Tensor):
-				inputs = module(inputs)
+	def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : SourceEmbedding) -> Tensor:
+		for layer in self.layers:
+			if isinstance(layer, AADLayer):
+				feature_map = layer(feature_map, attribute_embedding, id_embedding)
 			else:
-				inputs = module(*inputs)
-
-		return inputs #todo: would be easier to read when you just return xxx_inputs, attribute_embedding, id_embedding ?
+				feature_map = layer(feature_map)
+		return feature_map
 
 
 class AADResBlock(nn.Module):
@@ -147,11 +88,11 @@
 			nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, padding = 1, bias = False)
 		]
 		)
-		self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks)
+		self.primary_add_blocks = AADSequential(*primary_add_blocks)
 
 	def prepare_auxiliary_add_blocks(self, input_channels : int, attribute_channels : int, id_channels : int, output_channels : int) -> None:
 		if input_channels > output_channels:
-			auxiliary_add_blocks = AddBlocksSequential(
+			auxiliary_add_blocks = AADSequential(
 				AADLayer(input_channels, attribute_channels, id_channels),
 				nn.ReLU(inplace = True),
 				nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False)
@@ -168,34 +109,6 @@
 		return output_feature
 
 
-class DownSample(nn.Module):
-	def __init__(self, input_channels : int, output_channels : int) -> None:
-		super(DownSample, self).__init__()
-		self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
-		self.batch_norm = nn.BatchNorm2d(output_channels)
-		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
-
-	def forward(self, temp : Tensor) -> Tensor:
-		temp = self.conv(temp)
-		temp = self.batch_norm(temp)
-		temp = self.leaky_relu(temp)
-		return temp
-
-
-class Upsample(nn.Module):
-	def __init__(self, input_channels : int, output_channels : int) -> None:
-		super(Upsample, self).__init__()
-		self.deconv = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
-		self.batch_norm = nn.BatchNorm2d(output_channels)
-		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
-
-	def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor:
-		temp = self.deconv(temp)
-		temp = self.batch_norm(temp)
-		temp = self.leaky_relu(temp)
-		return torch.cat((temp, skip_tensor), dim = 1)
-
-
 class PixelShuffleUpsample(nn.Module):
 	def __init__(self, input_channels : int, output_channels : int) -> None:
 		super(PixelShuffleUpsample, self).__init__()
@@ -206,15 +119,3 @@ class PixelShuffleUpsample(nn.Module):
 		temp = self.conv(temp.view(temp.shape[0], -1, 1, 1))
 		temp = self.pixel_shuffle(temp)
 		return temp
-
-
-def init_weight(module : nn.Module) -> None:
-	if isinstance(module, nn.Linear):
-		module.weight.data.normal_(std = 0.001)
-		module.bias.data.zero_()
-
-	if isinstance(module, nn.Conv2d):
-		nn.init.xavier_normal_(module.weight.data)
-
-	if isinstance(module, nn.ConvTranspose2d):
-		nn.init.xavier_normal_(module.weight.data)
diff --git a/face_swapper/src/networks/encoder.py b/face_swapper/src/networks/encoder.py
new file mode 100644
index 0000000..e380ab5
--- /dev/null
+++ b/face_swapper/src/networks/encoder.py
@@ -0,0 +1,67 @@
+import torch
+from torch import Tensor, nn
+
+from face_swapper.src.types import TargetAttributes, VisionTensor
+
+
+class Upsample(nn.Module):
+	def __init__(self, input_channels : int, output_channels : int) -> None:
+		super(Upsample, self).__init__()
+		self.deconv = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
+		self.batch_norm = nn.BatchNorm2d(output_channels)
+		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+
+	def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor:
+		temp = self.deconv(temp)
+		temp = self.batch_norm(temp)
+		temp = self.leaky_relu(temp)
+		return torch.cat((temp, skip_tensor), dim = 1)
+
+
+class DownSample(nn.Module):
+	def __init__(self, input_channels : int, output_channels : int) -> None:
+		super(DownSample, self).__init__()
+		self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
+		self.batch_norm = nn.BatchNorm2d(output_channels)
+		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+
+	def forward(self, temp : Tensor) -> Tensor:
+		temp = self.conv(temp)
+		temp = self.batch_norm(temp)
+		temp = self.leaky_relu(temp)
+		return temp
+
+
+class UNet(nn.Module):
+	def __init__(self) -> None:
+		super(UNet, self).__init__()
+		self.downsampler_1 = DownSample(3, 32)
+		self.downsampler_2 = DownSample(32, 64)
+		self.downsampler_3 = DownSample(64, 128)
+		self.downsampler_4 = DownSample(128, 256)
+		self.downsampler_5 = DownSample(256, 512)
+		self.downsampler_6 = DownSample(512, 1024)
+		self.bottleneck = DownSample(1024, 1024)
+		self.upsampler_1 = Upsample(1024, 1024)
+		self.upsampler_2 = Upsample(2048, 512)
+		self.upsampler_3 = Upsample(1024, 256)
+		self.upsampler_4 = Upsample(512, 128)
+		self.upsampler_5 = Upsample(256, 64)
+		self.upsampler_6 = Upsample(128, 32)
+
+	def forward(self, target : VisionTensor) -> TargetAttributes:
+		downsample_feature_1 = self.downsampler_1(target)
+		downsample_feature_2 = self.downsampler_2(downsample_feature_1)
+		downsample_feature_3 = self.downsampler_3(downsample_feature_2)
+		downsample_feature_4 = self.downsampler_4(downsample_feature_3)
+		downsample_feature_5 = self.downsampler_5(downsample_feature_4)
+		downsample_feature_6 = self.downsampler_6(downsample_feature_5)
+		bottleneck_output = self.bottleneck(downsample_feature_6)
+		upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
+		upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
+		upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
+		upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
+		upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
+		upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
+		output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
+		return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py
index e6ad370..ef094f0 100644
--- a/face_swapper/src/training.py
+++ b/face_swapper/src/training.py
@@ -8,157 +8,25 @@ import torchvision
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.types import Optimizer
-from pytorch_msssim import ssim
 from torch import Tensor
 from torch.utils.data import DataLoader
 
 from .data_loader import DataLoaderVGG
-from .discriminator import MultiscaleDiscriminator
-from .generator import AdaptiveEmbeddingIntegrationNetwork
-from .helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
-from .types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SourceEmbedding, SwapAttributes, TargetAttributes, VisionTensor
+from .helper import calc_id_embedding
+from .models.discriminator import MultiscaleDiscriminator
+from .models.generator import AdaptiveEmbeddingIntegrationNetwork
+from .models.loss import FaceSwapperLoss
+from .types import Batch, SourceEmbedding, TargetAttributes, VisionTensor
 
 CONFIG = configparser.ConfigParser()
 CONFIG.read('config.ini')
 
 
-class FaceSwapperLoss:
-	def __init__(self) -> None:
-		id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
-		landmarker_path = CONFIG.get('training.model', 'landmarker_path')
-		motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
-		self.batch_size = CONFIG.getint('training.loader', 'batch_size')
-		self.mse_loss = torch.nn.MSELoss()
-		self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
-		self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
-		self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
-		self.id_embedder.eval()
-		self.landmarker.eval()
-		self.motion_extractor.eval()
-
-	def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
-		source_tensor, target_tensor, is_same_person = batch
-		weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
-		weight_id = CONFIG.getfloat('training.losses', 'weight_id')
-		weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
-		weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
-		weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
-		weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
-		generator_loss_set = {}
-
-		generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs)
-		generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor)
-		generator_loss_set['loss_attribute'] = self.calc_attribute_loss(target_attributes, swap_attributes)
-		generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
-
-		if weight_pose > 0:
-			generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
-		else:
-			generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
-
-		if weight_gaze > 0:
-			generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
-		else:
-			generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
-
-		generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
-		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id
-		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
-		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
-		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose
-		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
-		return generator_loss_set
-
-	def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
-		discriminator_loss_set = {}
-		loss_fake = torch.Tensor(0)
-
-		for fake_discriminator_output in fake_discriminator_outputs:
-			loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
-
-		loss_true = torch.Tensor(0)
-
-		for true_discriminator_output in real_discriminator_outputs:
-			loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
-
-		discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
-		return discriminator_loss_set
-
-	def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
-		loss_adversarial = torch.Tensor(0)
-
-		for discriminator_output in discriminator_outputs:
-			loss_adversarial += hinge_real_loss(discriminator_output[0])
-
-		loss_adversarial = torch.mean(loss_adversarial)
-		return loss_adversarial
-
-	def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
-		loss_attribute = torch.Tensor(0)
-
-		for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
-			loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
-
-		loss_attribute *= 0.5
-		return loss_attribute
-
-	def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
-		loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
-		loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
-		loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
-		loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
-		loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
-		return loss_reconstruction
-
-	def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
-		swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
-		source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
-		loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
-		return loss_id
-
-	def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
-		swap_motion_features = self.get_pose_features(swap_tensor)
-		target_motion_features = self.get_pose_features(target_tensor)
-		loss_pose = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
-
-		for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
-			loss_pose += self.mse_loss(swap_motion_feature, target_motion_feature)
-
-		return loss_pose
-
-	def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
-		swap_landmark = self.get_face_landmarks(swap_tensor)
-		target_landmark = self.get_face_landmarks(target_tensor)
-		left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
-		right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197])
-		gaze_loss = left_gaze_loss + right_gaze_loss
-		return gaze_loss
-
-	def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
-		vision_tensor_norm = (vision_tensor + 1) * 0.5
-		vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
-		landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
-		return landmarks
-
-	def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
-		vision_tensor_norm = (vision_tensor + 1) * 0.5
-		pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
-		rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
-		return translation, scale, rotation
-
-
 class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
 	def __init__(self) -> None:
 		super().__init__()
-		id_channels = CONFIG.getint('training.model.generator', 'id_channels')
-		num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
-		input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
-		num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
-		num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
-		num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
-		kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
-		self.generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
-		self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators, kernel_size)
+		self.generator = AdaptiveEmbeddingIntegrationNetwork()
+		self.discriminator = MultiscaleDiscriminator()
 		self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
 
 	def forward(self, target_tensor : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
@@ -244,8 +112,8 @@ def train() -> None:
 	batch_size = CONFIG.getint('training.loader', 'batch_size')
 	num_workers = CONFIG.getint('training.loader', 'num_workers')
 	file_path = CONFIG.get('training.output', 'file_path')
 
-	dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
+	dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
 	data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
 	face_swap_model = FaceSwapperTrain()
 	trainer = create_trainer()
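
--
Notes:

The zero-argument constructors introduced here resolve their hyperparameters
from config.ini. The generator values are pinned by the old call sites
(exporting.py and inferencing.py previously passed 512 and 2); the
discriminator values in the sketch below are illustrative assumptions, not
taken from this patch:

    # Hypothetical bootstrap for the config.ini sections read by
    # AdaptiveEmbeddingIntegrationNetwork and MultiscaleDiscriminator.
    import configparser

    config = configparser.ConfigParser()
    config['training.model.generator'] = \
    {
        'id_channels': '512', # matches the removed AdaptiveEmbeddingIntegrationNetwork(512, 2) call
        'num_blocks': '2'
    }
    config['training.model.discriminator'] = \
    {
        'input_channels': '3', # assumed RGB input
        'num_filters': '64', # assumed PatchGAN-style width
        'kernel_size': '4', # assumed
        'num_layers': '4', # assumed
        'num_discriminators': '3' # assumed number of scales
    }

    with open('config.ini', 'w') as config_file:
        config.write(config_file)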
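The relocated UNet returns the attribute pyramid that AADGenerator consumes:
the bottleneck plus each skip-concatenated upsampling stage, finishing with a
2x bilinear upsample. For an assumed 1x3x256x256 input (the patch does not
state the training resolution), the shapes work out as below, inferred from
the DownSample/Upsample definitions:

    import torch

    from face_swapper.src.networks.encoder import UNet

    encoder = UNet()
    target_attributes = encoder(torch.randn(1, 3, 256, 256))

    for target_attribute in target_attributes:
        print(tuple(target_attribute.shape))
    # Expected: (1, 1024, 2, 2), (1, 2048, 4, 4), (1, 1024, 8, 8), (1, 512, 16, 16),
    # (1, 256, 32, 32), (1, 128, 64, 64), (1, 64, 128, 128), (1, 64, 256, 256)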
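AADSequential replaces the index arithmetic of AddBlocksSequential with a
plain isinstance dispatch: AADLayer instances receive the feature map plus
both conditioning embeddings, while the interleaved ReLU/Conv2d layers only
see the feature map. A self-contained sketch of the same pattern with
hypothetical stand-in modules (ToyConditionalLayer and ToySequential are not
part of the patch):

    import torch
    from torch import Tensor, nn

    class ToyConditionalLayer(nn.Module):
        # Stand-in for AADLayer: consumes the feature map and both embeddings.
        def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : Tensor) -> Tensor:
            return feature_map + attribute_embedding.mean() + id_embedding.mean()

    class ToySequential(nn.Module):
        def __init__(self, *args : nn.Module) -> None:
            super().__init__()
            self.layers = nn.ModuleList(args)

        def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : Tensor) -> Tensor:
            for layer in self.layers:
                if isinstance(layer, ToyConditionalLayer):
                    feature_map = layer(feature_map, attribute_embedding, id_embedding)
                else:
                    feature_map = layer(feature_map)
            return feature_map

    blocks = ToySequential(ToyConditionalLayer(), nn.ReLU(), nn.Conv2d(8, 8, kernel_size = 3, padding = 1))
    output = blocks(torch.randn(1, 8, 16, 16), torch.randn(1, 8, 16, 16), torch.randn(1, 512))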
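calc_discriminator_loss and calc_adversarial_loss accumulate hinge losses
over the discriminator scales: each entry of DiscriminatorOutputs is a
single-element list holding one scale's patch map, and the accumulator starts
from the zero-dimensional torch.tensor(0.0) (an empty torch.Tensor(0) would
make the final mean NaN). The helper bodies are not part of this patch; the
sketch assumes the standard hinge-GAN formulation:

    import torch
    from torch import Tensor

    def hinge_real_loss(discriminator_output : Tensor) -> Tensor:
        # Real samples are pushed above +1 (assumed standard definition).
        return torch.relu(1 - discriminator_output)

    def hinge_fake_loss(discriminator_output : Tensor) -> Tensor:
        # Fake samples are pushed below -1 (assumed standard definition).
        return torch.relu(1 + discriminator_output)

    # Two assumed scales of patch outputs, mirroring MultiscaleDiscriminator.
    fake_discriminator_outputs = [ [ torch.randn(1, 1, 8, 8) ], [ torch.randn(1, 1, 4, 4) ] ]
    loss_fake = torch.tensor(0.0)

    for fake_discriminator_output in fake_discriminator_outputs:
        loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()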
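exporting.py and inferencing.py share one checkpoint layout: the Lightning
checkpoint nests the generator weights under ['state_dict']['generator'], and
the model is now constructed without arguments, so config.ini must be
readable from the working directory. A minimal loading sketch;
'checkpoint.ckpt' is a placeholder path and the 1x3x256x256 target shape is
an assumption consistent with the UNet's seven downsampling stages:

    import torch

    from face_swapper.src.models.generator import AdaptiveEmbeddingIntegrationNetwork

    state_dict = torch.load('checkpoint.ckpt', map_location = 'cpu').get('state_dict').get('generator')
    generator = AdaptiveEmbeddingIntegrationNetwork()
    generator.load_state_dict(state_dict)
    generator.eval()

    with torch.no_grad():
        swap_tensor, target_attributes = generator(torch.randn(1, 3, 256, 256), torch.randn(1, 512))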