diff --git a/face_swapper/README.md b/face_swapper/README.md new file mode 100644 index 0000000..541ca3b --- /dev/null +++ b/face_swapper/README.md @@ -0,0 +1,107 @@ +Face Swapper +================= + +> Swap one face over another face. + +![License](https://img.shields.io/badge/license-MIT-green) + + +Preview +------- + +![Preview]() + + +Installation +------------ + +``` +pip install -r requirements.txt +``` + + +Example +------- + +This example utilizes the MegaFace dataset to train an ArcFace Converter for SimSwap. + +``` +[preparing.dataset] +dataset_path = datasets/train +folder_pattern = {}/* +image_pattern = {}/*.*g +same_person_probability = 0.2 + +[training.loader] +batch_size = 24 +num_workers = 12 + +[training.model] +id_embedder_path = assets/models/id_embedder.pt +landmarker_path = assets/models/landmarker.pt +motion_extractor_path = assets/models/motion_extractor.pt + +[training.model.generator] +num_blocks = 2 +id_channels = 512 + +[training.model.discriminator] +input_channels = 3 +num_filters = 64 +num_layers = 5 +num_discriminators = 3 +kernel_size = 4 + +[training.losses] +weight_adversarial = 1 +weight_id = 20 +weight_attribute = 10 +weight_reconstruction = 10 +weight_pose = 100 + +[training.trainer] +max_epochs = 50 +learning_rate = 0.0004 +precision = 16-mixed +automatic_optimization = false + +[training.output] +checkpoint_path = outputs/last.ckpt +directory_path = outputs +file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}' +preview_frequency = 250 +validation_frequency = 1000 + +[exporting] +directory_path = export +source_path = outputs/last.ckpt +target_path = export/face_swapper.onnx +opset_version = 15 + +[inference] +generator_path = outputs/last.ckpt +id_embedder_path = assets/models/id_embedder.pt +source_path = assets/images/source.jpg +target_path = assets/models/target.jpg +output_path = outputs/output.jpg +``` + + +Training +-------- + +Train the Face swapper model. + +``` +python train.py +``` + + +Exporting +--------- + +Export the model to ONNX. + +``` +python export.py +``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 8422a1f..66dcdc1 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -1,12 +1,12 @@ [preparing.dataset] -dataset_path = /assets/VGGface2_None_norm_512_true_bygfpgan -folder_pattern = {}/* -image_pattern = {}/*.*g -same_person_probability = 0.2 +dataset_path = +directory_pattern = +image_pattern = +same_person_probability = [training.loader] -batch_size = 24 -num_workers = 12 +batch_size = +num_workers = [training.model] id_embedder_path = @@ -14,32 +14,35 @@ landmarker_path = motion_extractor_path = [training.model.generator] -num_blocks = 2 -id_channels = 512 +num_blocks = +id_channels = [training.model.discriminator] -input_channels = 3 -num_filters = 64 -num_layers = 5 -num_discriminators = 3 +input_channels = +num_filters = +num_layers = +num_discriminators = +kernel_size = [training.losses] -weight_adversarial = 1 -weight_id = 20 -weight_attribute = 10 -weight_reconstruction = 10 -weight_tsr = 100 +weight_adversarial = +weight_id = +weight_attribute = +weight_reconstruction = +weight_pose = [training.trainer] -max_epochs = 50 -learning_rate = 0.0004 +max_epochs = +learning_rate = +precision = +automatic_optimization = [training.output] -checkpoint_path = checkpoints/last.ckpt -directory_path = checkpoints -file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}' -preview_frequency = 250 -validation_frequency = 1000 +checkpoint_path = +directory_path = +file_pattern = +preview_frequency = +validation_frequency = [exporting] directory_path = @@ -47,9 +50,6 @@ source_path = target_path = opset_version = -[execution] -providers = - [inference] generator_path = id_embedder_path = diff --git a/face_swapper/infer.py b/face_swapper/infer.py index 1a2ac35..fc3eb29 100644 --- a/face_swapper/infer.py +++ b/face_swapper/infer.py @@ -1,29 +1,7 @@ -import configparser +#!/usr/bin/env python3 -import cv2 -import torch -from src.generator import AdaptiveEmbeddingIntegrationNetwork -from src.helper import infer, read_image - -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') +from face_swapper.src.inferencing import infer if __name__ == '__main__': - generator_path = CONFIG.get('inference', 'generator_path') - id_embedder_path = CONFIG.get('inference', 'id_embedder_path') - source_path = CONFIG.get('inference', 'source_path') - target_path = CONFIG.get('inference', 'target_path') - output_path = CONFIG.get('inference', 'output_path') - - state_dict = torch.load(generator_path, map_location = 'cpu')['state_dict']['generator'] - generator = AdaptiveEmbeddingIntegrationNetwork(512, 2) - generator.load_state_dict(state_dict) - generator.eval() - id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') #type:ignore[no-untyped-call] - id_embedder.eval() - - source_vision_frame = read_image(source_path) - target_vision_frame = read_image(target_path) - output_vision_frame = infer(generator, id_embedder, source_vision_frame, target_vision_frame) - cv2.imwrite(output_path, output_vision_frame) + infer() diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/data_loader.py index c8b6784..31df9ae 100644 --- a/face_swapper/src/data_loader.py +++ b/face_swapper/src/data_loader.py @@ -1,43 +1,53 @@ import glob import os.path import random +from typing import Tuple import torch import torchvision.transforms as transforms from torch.utils.data import TensorDataset from .helper import read_image -from .typing import Batch +from .typing import Batch, ImagePathList, ImagePathSet class DataLoaderVGG(TensorDataset): def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_folder_pattern : str, same_person_probability : float) -> None: self.same_person_probability = same_person_probability - self.folder_paths = glob.glob(dataset_folder_pattern.format(dataset_path)) - self.image_paths = [] - self.image_path_set = {} - - for folder_path in self.folder_paths: - image_paths = glob.glob(dataset_image_pattern.format(folder_path)) - self.image_paths.extend(image_paths) - self.image_path_set[folder_path] = image_paths + self.directory_paths = glob.glob(dataset_folder_pattern.format(dataset_path)) + self.image_paths, self.image_path_set = self.prepare_image_paths(dataset_image_pattern) self.dataset_total = len(self.image_paths) - self.transforms = transforms.Compose( + self.transforms = self.compose_transforms() + + def prepare_image_paths(self, dataset_image_pattern : str) -> Tuple[ImagePathList, ImagePathSet]: + image_paths = [] + image_path_set = {} + + for directory_path in self.directory_paths: + image_paths = glob.glob(dataset_image_pattern.format(directory_path)) + image_paths.extend(image_paths) + image_path_set[directory_path] = image_paths + return image_paths, image_path_set + + def compose_transforms(self) -> transforms: + transform = transforms.Compose( [ transforms.ToPILImage(), - transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC), - transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1), - transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1), fill = 0), + transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), + transforms.RandomAffine(4, translate=(0.01, 0.01), scale=(0.98, 1.02), shear=(1, 1), fill=0), transforms.ToTensor(), - transforms.Lambda(lambda img: img[[2, 1, 0], :, :]), + transforms.Lambda(lambda temp_tensor : temp_tensor[[2, 1, 0], :, :]), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) + return transform - def __getitem__(self, item : int) -> Batch: - source_image_path = self.image_paths[item] + def __getitem__(self, index : int) -> Batch: + source_image_path = self.image_paths[index] if random.random() > self.same_person_probability: return self.prepare_same_person(source_image_path) + return self.prepare_different_person(source_image_path) def prepare_different_person(self, source_image_path : str) -> Batch: diff --git a/face_swapper/src/discriminator.py b/face_swapper/src/discriminator.py index 114f377..9e15f37 100644 --- a/face_swapper/src/discriminator.py +++ b/face_swapper/src/discriminator.py @@ -1,4 +1,8 @@ +from itertools import chain +from typing import List + import numpy +import torch.nn import torch.nn as nn from torch import Tensor @@ -6,11 +10,15 @@ from .typing import DiscriminatorOutputs class NLayerDiscriminator(nn.Module): - def __init__(self, input_channels : int, num_filters : int, num_layers : int) -> None: + def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None: super(NLayerDiscriminator, self).__init__() self.num_layers = num_layers - kernel_size = 4 + model_layers = self.prepare_model_layers(input_channels, num_filters, num_layers, kernel_size) + self.model = nn.Sequential(*list(chain(*model_layers))) + + def prepare_model_layers(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> List[List[torch.nn.Module]]: padding_size = int(numpy.ceil((kernel_size - 1.0) / 2)) + model_layers =\ [ [ @@ -35,7 +43,7 @@ class NLayerDiscriminator(nn.Module): model_layers +=\ [ [ - nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 1, padding = padding_size), + nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding_size), nn.InstanceNorm2d(current_filters), nn.LeakyReLU(0.2, True) ] @@ -43,38 +51,36 @@ class NLayerDiscriminator(nn.Module): model_layers +=\ [ [ - nn.Conv2d(current_filters, 1, kernel_size = kernel_size, stride = 1, padding = padding_size) + nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding_size) ] ] - combined_layers = [] - - for model_layer in model_layers: - combined_layers += model_layer - self.model = nn.Sequential(*combined_layers) + return model_layers def forward(self, input_tensor : Tensor) -> Tensor: return self.model(input_tensor) class MultiscaleDiscriminator(nn.Module): - def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int): + def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int, kernel_size : int): super(MultiscaleDiscriminator, self).__init__() self.num_discriminators = num_discriminators self.num_layers = num_layers for discriminator_index in range(num_discriminators): - single_discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers) + single_discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers, kernel_size) setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model) + self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type] def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs: discriminator_outputs = [] - temp_downsampled_input = input_tensor + temp_tensor = input_tensor for discriminator_index in range(self.num_discriminators): model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index)) - discriminator_outputs.append([ model_layers(temp_downsampled_input) ]) + discriminator_outputs.append([ model_layers(temp_tensor) ]) if discriminator_index < (self.num_discriminators - 1): - temp_downsampled_input = self.downsample(temp_downsampled_input) + temp_tensor = self.downsample(temp_tensor) + return discriminator_outputs diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py index ce954de..033e26d 100644 --- a/face_swapper/src/exporting.py +++ b/face_swapper/src/exporting.py @@ -16,7 +16,7 @@ def export() -> None: opset_version = CONFIG.getint('exporting', 'opset_version') makedirs(directory_path, exist_ok = True) - state_dict = torch.load(source_path, map_location = 'cpu')['state_dict']['generator'] + state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator') model = AdaptiveEmbeddingIntegrationNetwork(512, 2) model.load_state_dict(state_dict) model.eval() diff --git a/face_swapper/src/generator.py b/face_swapper/src/generator.py index d2aef72..d48b5e1 100644 --- a/face_swapper/src/generator.py +++ b/face_swapper/src/generator.py @@ -34,7 +34,7 @@ class AADGenerator(nn.Module): self.res_block_6 = AADResBlock(256, 128, 128, id_channels, num_blocks) self.res_block_7 = AADResBlock(128, 64, 64, id_channels, num_blocks) self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks) - self.apply(initialize_weight) + self.apply(init_weight) def forward(self, target_attributes : TargetAttributes, source_embedding : SourceEmbedding) -> Tensor: feature_map = self.upsample(source_embedding) @@ -65,7 +65,7 @@ class UNet(nn.Module): self.upsampler_4 = Upsample(512, 128) self.upsampler_5 = Upsample(256, 64) self.upsampler_6 = Upsample(128, 32) - self.apply(initialize_weight) + self.apply(init_weight) def forward(self, target : VisionTensor) -> TargetAttributes: downsample_feature_1 = self.downsampler_1(target) @@ -93,7 +93,7 @@ class AADLayer(nn.Module): self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1) self.fc_beta = nn.Linear(id_channels, input_channels) self.fc_gamma = nn.Linear(id_channels, input_channels) - self.instance_norm = nn.InstanceNorm2d(input_channels, affine = False) + self.instance_norm = nn.InstanceNorm2d(input_channels) self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1) def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : SourceEmbedding) -> Tensor: @@ -110,9 +110,10 @@ class AADLayer(nn.Module): class AddBlocksSequential(nn.Sequential): + #todo: what are inputs? improve the name def forward(self, *inputs : Tuple[Tensor, Tensor, SourceEmbedding]) -> Tuple[Tuple[Tensor, Tensor, SourceEmbedding], ...]: _, attribute_embedding, id_embedding = inputs - modules = self._modules.values() + modules = self._modules.values() #todo: what kind of modules? for module_index, module in enumerate(modules): if module_index % 3 == 0 and module_index > 0: @@ -122,7 +123,8 @@ class AddBlocksSequential(nn.Sequential): inputs = module(inputs) else: inputs = module(*inputs) - return inputs + + return inputs #todo: would be easier to read when you just return xxx_inputs, attribute_embedding, id_embedding ? class AADResBlock(nn.Module): @@ -130,33 +132,38 @@ class AADResBlock(nn.Module): super(AADResBlock, self).__init__() self.input_channels = input_channels self.output_channels = output_channels + self.prepare_primary_add_blocks(input_channels, attribute_channels, id_channels, output_channels, num_blocks) + self.prepare_auxiliary_add_blocks(input_channels, attribute_channels, id_channels, output_channels) + + def prepare_primary_add_blocks(self, input_channels : int, attribute_channels : int, id_channels : int, output_channels : int, num_blocks : int) -> None: primary_add_blocks = [] - for i in range(num_blocks): - intermediate_channels = input_channels if i < (num_blocks - 1) else output_channels + for index in range(num_blocks): + intermediate_channels = input_channels if index < (num_blocks - 1) else output_channels primary_add_blocks.extend( [ AADLayer(input_channels, attribute_channels, id_channels), nn.ReLU(inplace = True), - nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False) + nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, padding = 1, bias = False) ] ) self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks) - if input_channels != output_channels: - auxiliary_add_blocks =\ - [ + def prepare_auxiliary_add_blocks(self, input_channels : int, attribute_channels : int, id_channels : int, output_channels : int) -> None: + if input_channels > output_channels: + auxiliary_add_blocks = AddBlocksSequential( AADLayer(input_channels, attribute_channels, id_channels), nn.ReLU(inplace = True), - nn.Conv2d(input_channels, output_channels, kernel_size = 3, stride = 1, padding = 1, bias = False) - ] - self.auxiliary_add_blocks = AddBlocksSequential(*auxiliary_add_blocks) + nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False) + ) + self.auxiliary_add_blocks = auxiliary_add_blocks def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : SourceEmbedding) -> Tensor: primary_feature = self.primary_add_blocks(feature_map, attribute_embedding, id_embedding) - if self.input_channels != self.output_channels: + if self.input_channels > self.output_channels: feature_map = self.auxiliary_add_blocks(feature_map, attribute_embedding, id_embedding) + output_feature = primary_feature + feature_map return output_feature @@ -201,7 +208,7 @@ class PixelShuffleUpsample(nn.Module): return temp -def initialize_weight(module : nn.Module) -> None: +def init_weight(module : nn.Module) -> None: if isinstance(module, nn.Linear): module.weight.data.normal_(std = 0.001) module.bias.data.zero_() diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 438e273..ce77c42 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -1,13 +1,21 @@ +import platform + import cv2 import numpy import torch -from .typing import IdEmbedding, Padding, Tensor, VisionFrame, VisionTensor +from .typing import IdEmbedder, IdEmbedding, Padding, Tensor, VisionFrame, VisionTensor + + +def is_windows() -> bool: + return platform.system().lower() == 'windows' def read_image(image_path : str) -> VisionFrame: - image = cv2.imread(image_path) - return image + if is_windows(): + image_buffer = numpy.fromfile(image_path, dtype = numpy.uint8) + return cv2.imdecode(image_buffer, cv2.IMREAD_COLOR) + return cv2.imread(image_path) def convert_to_vision_tensor(vision_frame : VisionFrame) -> VisionTensor: @@ -28,14 +36,18 @@ def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame: def hinge_real_loss(tensor : Tensor) -> Tensor: - return torch.relu(1 - tensor) + real_loss = torch.relu(1 - tensor) + real_loss = real_loss.mean(dim = [ 1, 2, 3 ]) + return real_loss def hinge_fake_loss(tensor : Tensor) -> Tensor: - return torch.relu(tensor + 1) + fake_loss = torch.relu(tensor + 1) + fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ]) + return fake_loss -def calc_id_embedding(id_embedder : torch.nn.Module, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding: +def calc_id_embedding(id_embedder : IdEmbedder, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding: crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241] crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area') crop_vision_tensor[:, :, :padding[0], :] = 0 @@ -43,14 +55,5 @@ def calc_id_embedding(id_embedder : torch.nn.Module, vision_tensor : VisionTenso crop_vision_tensor[:, :, :, :padding[2]] = 0 crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0 source_embedding = id_embedder(crop_vision_tensor) - source_embedding = torch.nn.functional.normalize(source_embedding, p = 2, dim = 1) + source_embedding = torch.nn.functional.normalize(source_embedding, p = 2) return source_embedding - - -def infer(generator : torch.nn.Module, id_embedder : torch.nn.Module, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: - source_vision_tensor = convert_to_vision_tensor(source_vision_frame) - target_vision_tensor = convert_to_vision_tensor(target_vision_frame) - source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0)) - output_vision_tensor = generator(source_embedding, target_vision_tensor)[0] - output_vision_frame = convert_to_vision_frame(output_vision_tensor) - return output_vision_frame diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py new file mode 100644 index 0000000..56ae4bc --- /dev/null +++ b/face_swapper/src/inferencing.py @@ -0,0 +1,40 @@ +import configparser + +import cv2 +import torch + +from .generator import AdaptiveEmbeddingIntegrationNetwork +from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image +from .typing import Generator, IdEmbedder, VisionFrame + +CONFIG = configparser.ConfigParser() +CONFIG.read('config.ini') + + +def run_swap(generator : Generator, id_embedder : IdEmbedder, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: + source_vision_tensor = convert_to_vision_tensor(source_vision_frame) + target_vision_tensor = convert_to_vision_tensor(target_vision_frame) + source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0)) + output_vision_tensor = generator(source_embedding, target_vision_tensor)[0] + output_vision_frame = convert_to_vision_frame(output_vision_tensor) + return output_vision_frame + + +def infer() -> None: + generator_path = CONFIG.get('inference', 'generator_path') + id_embedder_path = CONFIG.get('inference', 'id_embedder_path') + source_path = CONFIG.get('inference', 'source_path') + target_path = CONFIG.get('inference', 'target_path') + output_path = CONFIG.get('inference', 'output_path') + + state_dict = torch.load(generator_path, map_location='cpu').get('state_dict').get('generator') + generator = AdaptiveEmbeddingIntegrationNetwork(512, 2) + generator.load_state_dict(state_dict) + generator.eval() + id_embedder = torch.jit.load(id_embedder_path, map_location='cpu') # type:ignore[no-untyped-call] + id_embedder.eval() + + source_vision_frame = read_image(source_path) + target_vision_frame = read_image(target_path) + output_vision_frame = run_swap(generator, id_embedder, source_vision_frame, target_vision_frame) + cv2.imwrite(output_path, output_vision_frame) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index e248ee3..33534aa 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -16,13 +16,138 @@ from .data_loader import DataLoaderVGG from .discriminator import MultiscaleDiscriminator from .generator import AdaptiveEmbeddingIntegrationNetwork from .helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss -from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SourceEmbedding, TargetAttributes, VisionTensor +from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SourceEmbedding, SwapAttributes, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -class FaceSwapper(pytorch_lightning.LightningModule): +class FaceSwapperLoss: + def __init__(self) -> None: + id_embedder_path = CONFIG.get('training.model', 'id_embedder_path') + landmarker_path = CONFIG.get('training.model', 'landmarker_path') + motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') + self.batch_size = CONFIG.getint('training.loader', 'batch_size') + self.mse_loss = torch.nn.MSELoss() + self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.id_embedder.eval() + self.landmarker.eval() + self.motion_extractor.eval() + + def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet: + source_tensor, target_tensor, is_same_person = batch + weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial') + weight_id = CONFIG.getfloat('training.losses', 'weight_id') + weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute') + weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction') + weight_pose = CONFIG.getfloat('training.losses', 'weight_pose') + weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze') + generator_loss_set = {} + + generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs) + generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor) + generator_loss_set['loss_attribute'] = self.calc_attribute_loss(target_attributes, swap_attributes) + generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person) + + if weight_pose > 0: + generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor) + else: + generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) + + if weight_gaze > 0: + generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor) + else: + generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) + + generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze + return generator_loss_set + + def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet: + discriminator_loss_set = {} + loss_fake = torch.Tensor(0) + + for fake_discriminator_output in fake_discriminator_outputs: + loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean() + + loss_true = torch.Tensor(0) + + for true_discriminator_output in real_discriminator_outputs: + loss_true += hinge_real_loss(true_discriminator_output[0]).mean() + + discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5 + return discriminator_loss_set + + def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor: + loss_adversarial = torch.Tensor(0) + + for discriminator_output in discriminator_outputs: + loss_adversarial += hinge_real_loss(discriminator_output[0]) + + loss_adversarial = torch.mean(loss_adversarial) + return loss_adversarial + + def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor: + loss_attribute = torch.Tensor(0) + + for swap_attribute, target_attribute in zip(swap_attributes, target_attributes): + loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean() + + loss_attribute *= 0.5 + return loss_attribute + + def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor: + loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1) + loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5 + loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4) + loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean() + loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5 + return loss_reconstruction + + def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: + swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10)) + source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10)) + loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean() + return loss_id + + def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor: + swap_motion_features = self.get_pose_features(swap_tensor) + target_motion_features = self.get_pose_features(target_tensor) + loss_pose = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) + + for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features): + loss_pose += self.mse_loss(swap_motion_feature, target_motion_feature) + + return loss_pose + + def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor: + swap_landmark = self.get_face_landmarks(swap_tensor) + target_landmark = self.get_face_landmarks(target_tensor) + left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198]) + right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197]) + gaze_loss = left_gaze_loss + right_gaze_loss + return gaze_loss + + def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203: + vision_tensor_norm = (vision_tensor + 1) * 0.5 + vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear') + landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2) + return landmarks + + def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]: + vision_tensor_norm = (vision_tensor + 1) * 0.5 + pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm) + rotation = torch.cat([ pitch, yaw, roll ], dim = 1) + return translation, scale, rotation + + +class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss): def __init__(self) -> None: super().__init__() id_channels = CONFIG.getint('training.model.generator', 'id_channels') @@ -31,21 +156,10 @@ class FaceSwapper(pytorch_lightning.LightningModule): num_filters = CONFIG.getint('training.model.discriminator', 'num_filters') num_layers = CONFIG.getint('training.model.discriminator', 'num_layers') num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators') - id_embedder_path = CONFIG.get('training.model', 'id_embedder_path') - landmarker_path = CONFIG.get('training.model', 'landmarker_path') - motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') - + kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size') self.generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks) - self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators) - self.id_embedder = torch.jit.load(id_embedder_path, map_location ='cpu') #type:ignore[no-untyped-call] - self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') #type:ignore[no-untyped-call] - self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') #type:ignore[no-untyped-call] - self.id_embedder.eval() - self.landmarker.eval() - self.motion_extractor.eval() - self.automatic_optimization = False - self.mse_loss = torch.nn.MSELoss() - self.batch_size = CONFIG.getint('training.loader', 'batch_size') + self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators, kernel_size) + self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization') def forward(self, target_tensor : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]: output = self.generator(target_tensor, source_embedding) @@ -62,135 +176,32 @@ class FaceSwapper(pytorch_lightning.LightningModule): generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) swap_tensor, target_attributes = self.generator(target_tensor, source_embedding) - discriminator_outputs = self.discriminator(swap_tensor) + swap_attributes = self.generator.get_attributes(swap_tensor) + real_discriminator_outputs = self.discriminator(source_tensor.detach()) + fake_discriminator_outputs = self.discriminator(swap_tensor.detach()) - generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch) + generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, swap_attributes, fake_discriminator_outputs, batch) generator_optimizer.zero_grad() self.manual_backward(generator_losses.get('loss_generator')) generator_optimizer.step() - discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor) + discriminator_losses = self.calc_discriminator_loss(real_discriminator_outputs, fake_discriminator_outputs) discriminator_optimizer.zero_grad() self.manual_backward(discriminator_losses.get('loss_discriminator')) discriminator_optimizer.step() if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0: - self.log_generator_preview(source_tensor, target_tensor, swap_tensor) + self.generate_preview(source_tensor, target_tensor, swap_tensor) self.log('l_G', generator_losses.get('loss_generator'), prog_bar = True) self.log('l_D', discriminator_losses.get('loss_discriminator'), prog_bar = True) self.log('l_ADV', generator_losses.get('loss_adversarial'), prog_bar = True) self.log('l_ATTR', generator_losses.get('loss_attribute'), prog_bar = True) - self.log('l_ID', generator_losses.get('loss_id'), prog_bar=True) + self.log('l_ID', generator_losses.get('loss_id'), prog_bar = True) self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True) return generator_losses.get('loss_generator') - def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor: - loss_adversarial = torch.Tensor(0) - - for discriminator_output in discriminator_outputs: - loss_adversarial += hinge_real_loss(discriminator_output[0]).mean(dim = [ 1, 2, 3 ]) - loss_adversarial = torch.mean(loss_adversarial) - return loss_adversarial - - def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> LossTensor: - loss_attribute = torch.Tensor(0) - swap_attributes = self.generator.get_attributes(swap_tensor) - - for swap_attribute, target_attribute in zip(swap_attributes, target_attributes): - loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean() - loss_attribute *= 0.5 - return loss_attribute - - def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor: - loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4) - loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean() - loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5 - return loss_reconstruction - - def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: - swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10)) - source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10)) - loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean() - return loss_id - - def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor: - swap_motion_features = self.get_pose_features(swap_tensor) - target_motion_features = self.get_pose_features(target_tensor) - loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) - - for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features): - loss_tsr += self.mse_loss(swap_motion_feature, target_motion_feature) - return loss_tsr - - def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor: - swap_landmark = self.get_face_landmarks(swap_tensor) - target_landmark = self.get_face_landmarks(target_tensor) - left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198]) - right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197]) - gaze_loss = left_gaze_loss + right_gaze_loss - return gaze_loss - - def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet: - source_tensor, target_tensor, is_same_person = batch - weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial') - weight_id = CONFIG.getfloat('training.losses', 'weight_id') - weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute') - weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction') - weight_tsr = CONFIG.getfloat('training.losses', 'weight_tsr') - weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze') - - generator_loss_set = {} - generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs) - generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor) - generator_loss_set['loss_attribute'] = self.calc_attribute_loss(swap_tensor, target_attributes) - generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person) - - if weight_tsr > 0: - generator_loss_set['loss_tsr'] = self.calc_tsr_loss(swap_tensor, target_tensor) - else: - generator_loss_set['loss_tsr'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) - - if weight_gaze > 0: - generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor) - else: - generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) - generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_tsr') * weight_tsr - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze - return generator_loss_set - - def calc_discriminator_loss(self, swap_tensor : VisionTensor, source_tensor : VisionTensor) -> DiscriminatorLossSet: - discriminator_loss_set = {} - fake_discriminator_outputs = self.discriminator(swap_tensor.detach()) - loss_fake = torch.Tensor(0) - - for fake_discriminator_output in fake_discriminator_outputs: - loss_fake += torch.mean(hinge_fake_loss(fake_discriminator_output[0]).mean(dim = [ 1, 2, 3 ])) - true_discriminator_outputs = self.discriminator(source_tensor) - loss_true = torch.Tensor(0) - - for true_discriminator_output in true_discriminator_outputs: - loss_true += torch.mean(hinge_real_loss(true_discriminator_output[0]).mean(dim = [ 1, 2, 3 ])) - discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5 - return discriminator_loss_set - - def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203: - vision_tensor_norm = (vision_tensor + 1) * 0.5 - vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear') - landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2) - return landmarks - - def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]: - vision_tensor_norm = (vision_tensor + 1) * 0.5 - pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm) - rotation = torch.cat([ pitch, yaw, roll ], dim = 1) - return translation, scale, rotation - - def log_generator_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None: + def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None: max_preview = 8 source_tensors = source_tensor[:max_preview] target_tensors = target_tensor[:max_preview] @@ -204,11 +215,12 @@ def create_trainer() -> Trainer: trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs') output_directory_path = CONFIG.get('training.output', 'directory_path') output_file_pattern = CONFIG.get('training.output', 'file_pattern') + trainer_precision = CONFIG.get('training.trainer', 'precision') os.makedirs(output_directory_path, exist_ok = True) return Trainer( max_epochs = trainer_max_epochs, - precision = '16-mixed', + precision = trainer_precision, callbacks = [ ModelCheckpoint( @@ -217,12 +229,10 @@ def create_trainer() -> Trainer: filename = output_file_pattern, every_n_train_steps = 1000, save_top_k = 5, - mode = 'min', save_last = True ) ], - log_every_n_steps = 10, - accumulate_grad_batches = 1, + log_every_n_steps = 10 ) @@ -232,11 +242,11 @@ def train() -> None: checkpoint_path = CONFIG.get('training.output', 'checkpoint_path') dataset_path = CONFIG.get('preparing.dataset', 'dataset_path') dataset_image_pattern = CONFIG.get('preparing.dataset', 'image_pattern') - dataset_folder_pattern = CONFIG.get('preparing.dataset', 'folder_pattern') + dataset_directory_pattern = CONFIG.get('preparing.dataset', 'directory_pattern') same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability') - dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_folder_pattern, same_person_probability) + dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability) data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) - face_swap_model = FaceSwapper() + face_swap_model = FaceSwapperTrain() trainer = create_trainer() trainer.fit(face_swap_model, data_loader, ckpt_path = checkpoint_path) diff --git a/face_swapper/src/typing.py b/face_swapper/src/typing.py index 0154b09..96de434 100644 --- a/face_swapper/src/typing.py +++ b/face_swapper/src/typing.py @@ -1,22 +1,33 @@ from collections import OrderedDict from typing import Any, Dict, List, Tuple +import torch.nn from numpy.typing import NDArray from torch import Tensor from torch.utils.data import DataLoader - Batch = Tuple[Any, Any, Any] Loader = DataLoader[Tuple[Tensor, ...]] +ImagePathList = List[str] +ImagePathSet = Dict[str, ImagePathList] + +SwapAttributes = Tuple[Tensor, ...] TargetAttributes = Tuple[Tensor, ...] DiscriminatorOutputs = List[List[Tensor]] + IdEmbedding = Tensor SourceEmbedding = IdEmbedding -StateDict = OrderedDict[str, Any] -Padding = Tuple[int, int, int, int] FaceLandmark203 = Tensor -VisionTensor = Tensor + +StateSet = OrderedDict[str, Any] +Padding = Tuple[int, int, int, int] + LossTensor = Tensor +VisionTensor = Tensor +VisionFrame = NDArray[Any] + GeneratorLossSet = Dict[str, Tensor] DiscriminatorLossSet = Dict[str, Tensor] -VisionFrame = NDArray[Any] + +Generator = torch.nn.Module +IdEmbedder = torch.nn.Module