again again cleaning

This commit is contained in:
harisreedhar
2025-01-29 20:41:54 +05:30
committed by henryruhs
parent 3b8b6442fc
commit fcb3390796
11 changed files with 426 additions and 254 deletions
+107
View File
@@ -0,0 +1,107 @@
Face Swapper
=================
> Swap one face over another face.
![License](https://img.shields.io/badge/license-MIT-green)
Preview
-------
![Preview]()
Installation
------------
```
pip install -r requirements.txt
```
Example
-------
This example utilizes the MegaFace dataset to train an ArcFace Converter for SimSwap.
```
[preparing.dataset]
dataset_path = datasets/train
folder_pattern = {}/*
image_pattern = {}/*.*g
same_person_probability = 0.2
[training.loader]
batch_size = 24
num_workers = 12
[training.model]
id_embedder_path = assets/models/id_embedder.pt
landmarker_path = assets/models/landmarker.pt
motion_extractor_path = assets/models/motion_extractor.pt
[training.model.generator]
num_blocks = 2
id_channels = 512
[training.model.discriminator]
input_channels = 3
num_filters = 64
num_layers = 5
num_discriminators = 3
kernel_size = 4
[training.losses]
weight_adversarial = 1
weight_id = 20
weight_attribute = 10
weight_reconstruction = 10
weight_pose = 100
[training.trainer]
max_epochs = 50
learning_rate = 0.0004
precision = 16-mixed
automatic_optimization = false
[training.output]
checkpoint_path = outputs/last.ckpt
directory_path = outputs
file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}'
preview_frequency = 250
validation_frequency = 1000
[exporting]
directory_path = export
source_path = outputs/last.ckpt
target_path = export/face_swapper.onnx
opset_version = 15
[inference]
generator_path = outputs/last.ckpt
id_embedder_path = assets/models/id_embedder.pt
source_path = assets/images/source.jpg
target_path = assets/models/target.jpg
output_path = outputs/output.jpg
```
Training
--------
Train the Face swapper model.
```
python train.py
```
Exporting
---------
Export the model to ONNX.
```
python export.py
```
+27 -27
View File
@@ -1,12 +1,12 @@
[preparing.dataset]
dataset_path = /assets/VGGface2_None_norm_512_true_bygfpgan
folder_pattern = {}/*
image_pattern = {}/*.*g
same_person_probability = 0.2
dataset_path =
directory_pattern =
image_pattern =
same_person_probability =
[training.loader]
batch_size = 24
num_workers = 12
batch_size =
num_workers =
[training.model]
id_embedder_path =
@@ -14,32 +14,35 @@ landmarker_path =
motion_extractor_path =
[training.model.generator]
num_blocks = 2
id_channels = 512
num_blocks =
id_channels =
[training.model.discriminator]
input_channels = 3
num_filters = 64
num_layers = 5
num_discriminators = 3
input_channels =
num_filters =
num_layers =
num_discriminators =
kernel_size =
[training.losses]
weight_adversarial = 1
weight_id = 20
weight_attribute = 10
weight_reconstruction = 10
weight_tsr = 100
weight_adversarial =
weight_id =
weight_attribute =
weight_reconstruction =
weight_pose =
[training.trainer]
max_epochs = 50
learning_rate = 0.0004
max_epochs =
learning_rate =
precision =
automatic_optimization =
[training.output]
checkpoint_path = checkpoints/last.ckpt
directory_path = checkpoints
file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}'
preview_frequency = 250
validation_frequency = 1000
checkpoint_path =
directory_path =
file_pattern =
preview_frequency =
validation_frequency =
[exporting]
directory_path =
@@ -47,9 +50,6 @@ source_path =
target_path =
opset_version =
[execution]
providers =
[inference]
generator_path =
id_embedder_path =
+3 -25
View File
@@ -1,29 +1,7 @@
import configparser
#!/usr/bin/env python3
import cv2
import torch
from src.generator import AdaptiveEmbeddingIntegrationNetwork
from src.helper import infer, read_image
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
from face_swapper.src.inferencing import infer
if __name__ == '__main__':
generator_path = CONFIG.get('inference', 'generator_path')
id_embedder_path = CONFIG.get('inference', 'id_embedder_path')
source_path = CONFIG.get('inference', 'source_path')
target_path = CONFIG.get('inference', 'target_path')
output_path = CONFIG.get('inference', 'output_path')
state_dict = torch.load(generator_path, map_location = 'cpu')['state_dict']['generator']
generator = AdaptiveEmbeddingIntegrationNetwork(512, 2)
generator.load_state_dict(state_dict)
generator.eval()
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') #type:ignore[no-untyped-call]
id_embedder.eval()
source_vision_frame = read_image(source_path)
target_vision_frame = read_image(target_path)
output_vision_frame = infer(generator, id_embedder, source_vision_frame, target_vision_frame)
cv2.imwrite(output_path, output_vision_frame)
infer()
+26 -16
View File
@@ -1,43 +1,53 @@
import glob
import os.path
import random
from typing import Tuple
import torch
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset
from .helper import read_image
from .typing import Batch
from .typing import Batch, ImagePathList, ImagePathSet
class DataLoaderVGG(TensorDataset):
def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_folder_pattern : str, same_person_probability : float) -> None:
self.same_person_probability = same_person_probability
self.folder_paths = glob.glob(dataset_folder_pattern.format(dataset_path))
self.image_paths = []
self.image_path_set = {}
for folder_path in self.folder_paths:
image_paths = glob.glob(dataset_image_pattern.format(folder_path))
self.image_paths.extend(image_paths)
self.image_path_set[folder_path] = image_paths
self.directory_paths = glob.glob(dataset_folder_pattern.format(dataset_path))
self.image_paths, self.image_path_set = self.prepare_image_paths(dataset_image_pattern)
self.dataset_total = len(self.image_paths)
self.transforms = transforms.Compose(
self.transforms = self.compose_transforms()
def prepare_image_paths(self, dataset_image_pattern : str) -> Tuple[ImagePathList, ImagePathSet]:
image_paths = []
image_path_set = {}
for directory_path in self.directory_paths:
image_paths = glob.glob(dataset_image_pattern.format(directory_path))
image_paths.extend(image_paths)
image_path_set[directory_path] = image_paths
return image_paths, image_path_set
def compose_transforms(self) -> transforms:
transform = transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.RandomAffine(4, translate=(0.01, 0.01), scale=(0.98, 1.02), shear=(1, 1), fill=0),
transforms.ToTensor(),
transforms.Lambda(lambda img: img[[2, 1, 0], :, :]),
transforms.Lambda(lambda temp_tensor : temp_tensor[[2, 1, 0], :, :]),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
return transform
def __getitem__(self, item : int) -> Batch:
source_image_path = self.image_paths[item]
def __getitem__(self, index : int) -> Batch:
source_image_path = self.image_paths[index]
if random.random() > self.same_person_probability:
return self.prepare_same_person(source_image_path)
return self.prepare_different_person(source_image_path)
def prepare_different_person(self, source_image_path : str) -> Batch:
+20 -14
View File
@@ -1,4 +1,8 @@
from itertools import chain
from typing import List
import numpy
import torch.nn
import torch.nn as nn
from torch import Tensor
@@ -6,11 +10,15 @@ from .typing import DiscriminatorOutputs
class NLayerDiscriminator(nn.Module):
def __init__(self, input_channels : int, num_filters : int, num_layers : int) -> None:
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
super(NLayerDiscriminator, self).__init__()
self.num_layers = num_layers
kernel_size = 4
model_layers = self.prepare_model_layers(input_channels, num_filters, num_layers, kernel_size)
self.model = nn.Sequential(*list(chain(*model_layers)))
def prepare_model_layers(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> List[List[torch.nn.Module]]:
padding_size = int(numpy.ceil((kernel_size - 1.0) / 2))
model_layers =\
[
[
@@ -35,7 +43,7 @@ class NLayerDiscriminator(nn.Module):
model_layers +=\
[
[
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 1, padding = padding_size),
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding_size),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True)
]
@@ -43,38 +51,36 @@ class NLayerDiscriminator(nn.Module):
model_layers +=\
[
[
nn.Conv2d(current_filters, 1, kernel_size = kernel_size, stride = 1, padding = padding_size)
nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding_size)
]
]
combined_layers = []
for model_layer in model_layers:
combined_layers += model_layer
self.model = nn.Sequential(*combined_layers)
return model_layers
def forward(self, input_tensor : Tensor) -> Tensor:
return self.model(input_tensor)
class MultiscaleDiscriminator(nn.Module):
def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int):
def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int, kernel_size : int):
super(MultiscaleDiscriminator, self).__init__()
self.num_discriminators = num_discriminators
self.num_layers = num_layers
for discriminator_index in range(num_discriminators):
single_discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers)
single_discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers, kernel_size)
setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
discriminator_outputs = []
temp_downsampled_input = input_tensor
temp_tensor = input_tensor
for discriminator_index in range(self.num_discriminators):
model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
discriminator_outputs.append([ model_layers(temp_downsampled_input) ])
discriminator_outputs.append([ model_layers(temp_tensor) ])
if discriminator_index < (self.num_discriminators - 1):
temp_downsampled_input = self.downsample(temp_downsampled_input)
temp_tensor = self.downsample(temp_tensor)
return discriminator_outputs
+1 -1
View File
@@ -16,7 +16,7 @@ def export() -> None:
opset_version = CONFIG.getint('exporting', 'opset_version')
makedirs(directory_path, exist_ok = True)
state_dict = torch.load(source_path, map_location = 'cpu')['state_dict']['generator']
state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
model = AdaptiveEmbeddingIntegrationNetwork(512, 2)
model.load_state_dict(state_dict)
model.eval()
+23 -16
View File
@@ -34,7 +34,7 @@ class AADGenerator(nn.Module):
self.res_block_6 = AADResBlock(256, 128, 128, id_channels, num_blocks)
self.res_block_7 = AADResBlock(128, 64, 64, id_channels, num_blocks)
self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks)
self.apply(initialize_weight)
self.apply(init_weight)
def forward(self, target_attributes : TargetAttributes, source_embedding : SourceEmbedding) -> Tensor:
feature_map = self.upsample(source_embedding)
@@ -65,7 +65,7 @@ class UNet(nn.Module):
self.upsampler_4 = Upsample(512, 128)
self.upsampler_5 = Upsample(256, 64)
self.upsampler_6 = Upsample(128, 32)
self.apply(initialize_weight)
self.apply(init_weight)
def forward(self, target : VisionTensor) -> TargetAttributes:
downsample_feature_1 = self.downsampler_1(target)
@@ -93,7 +93,7 @@ class AADLayer(nn.Module):
self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1)
self.fc_beta = nn.Linear(id_channels, input_channels)
self.fc_gamma = nn.Linear(id_channels, input_channels)
self.instance_norm = nn.InstanceNorm2d(input_channels, affine = False)
self.instance_norm = nn.InstanceNorm2d(input_channels)
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1)
def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : SourceEmbedding) -> Tensor:
@@ -110,9 +110,10 @@ class AADLayer(nn.Module):
class AddBlocksSequential(nn.Sequential):
#todo: what are inputs? improve the name
def forward(self, *inputs : Tuple[Tensor, Tensor, SourceEmbedding]) -> Tuple[Tuple[Tensor, Tensor, SourceEmbedding], ...]:
_, attribute_embedding, id_embedding = inputs
modules = self._modules.values()
modules = self._modules.values() #todo: what kind of modules?
for module_index, module in enumerate(modules):
if module_index % 3 == 0 and module_index > 0:
@@ -122,7 +123,8 @@ class AddBlocksSequential(nn.Sequential):
inputs = module(inputs)
else:
inputs = module(*inputs)
return inputs
return inputs #todo: would be easier to read when you just return xxx_inputs, attribute_embedding, id_embedding ?
class AADResBlock(nn.Module):
@@ -130,33 +132,38 @@ class AADResBlock(nn.Module):
super(AADResBlock, self).__init__()
self.input_channels = input_channels
self.output_channels = output_channels
self.prepare_primary_add_blocks(input_channels, attribute_channels, id_channels, output_channels, num_blocks)
self.prepare_auxiliary_add_blocks(input_channels, attribute_channels, id_channels, output_channels)
def prepare_primary_add_blocks(self, input_channels : int, attribute_channels : int, id_channels : int, output_channels : int, num_blocks : int) -> None:
primary_add_blocks = []
for i in range(num_blocks):
intermediate_channels = input_channels if i < (num_blocks - 1) else output_channels
for index in range(num_blocks):
intermediate_channels = input_channels if index < (num_blocks - 1) else output_channels
primary_add_blocks.extend(
[
AADLayer(input_channels, attribute_channels, id_channels),
nn.ReLU(inplace = True),
nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, padding = 1, bias = False)
]
)
self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks)
if input_channels != output_channels:
auxiliary_add_blocks =\
[
def prepare_auxiliary_add_blocks(self, input_channels : int, attribute_channels : int, id_channels : int, output_channels : int) -> None:
if input_channels > output_channels:
auxiliary_add_blocks = AddBlocksSequential(
AADLayer(input_channels, attribute_channels, id_channels),
nn.ReLU(inplace = True),
nn.Conv2d(input_channels, output_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
]
self.auxiliary_add_blocks = AddBlocksSequential(*auxiliary_add_blocks)
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False)
)
self.auxiliary_add_blocks = auxiliary_add_blocks
def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : SourceEmbedding) -> Tensor:
primary_feature = self.primary_add_blocks(feature_map, attribute_embedding, id_embedding)
if self.input_channels != self.output_channels:
if self.input_channels > self.output_channels:
feature_map = self.auxiliary_add_blocks(feature_map, attribute_embedding, id_embedding)
output_feature = primary_feature + feature_map
return output_feature
@@ -201,7 +208,7 @@ class PixelShuffleUpsample(nn.Module):
return temp
def initialize_weight(module : nn.Module) -> None:
def init_weight(module : nn.Module) -> None:
if isinstance(module, nn.Linear):
module.weight.data.normal_(std = 0.001)
module.bias.data.zero_()
+19 -16
View File
@@ -1,13 +1,21 @@
import platform
import cv2
import numpy
import torch
from .typing import IdEmbedding, Padding, Tensor, VisionFrame, VisionTensor
from .typing import IdEmbedder, IdEmbedding, Padding, Tensor, VisionFrame, VisionTensor
def is_windows() -> bool:
return platform.system().lower() == 'windows'
def read_image(image_path : str) -> VisionFrame:
image = cv2.imread(image_path)
return image
if is_windows():
image_buffer = numpy.fromfile(image_path, dtype = numpy.uint8)
return cv2.imdecode(image_buffer, cv2.IMREAD_COLOR)
return cv2.imread(image_path)
def convert_to_vision_tensor(vision_frame : VisionFrame) -> VisionTensor:
@@ -28,14 +36,18 @@ def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame:
def hinge_real_loss(tensor : Tensor) -> Tensor:
return torch.relu(1 - tensor)
real_loss = torch.relu(1 - tensor)
real_loss = real_loss.mean(dim = [ 1, 2, 3 ])
return real_loss
def hinge_fake_loss(tensor : Tensor) -> Tensor:
return torch.relu(tensor + 1)
fake_loss = torch.relu(tensor + 1)
fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ])
return fake_loss
def calc_id_embedding(id_embedder : torch.nn.Module, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding:
def calc_id_embedding(id_embedder : IdEmbedder, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding:
crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241]
crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area')
crop_vision_tensor[:, :, :padding[0], :] = 0
@@ -43,14 +55,5 @@ def calc_id_embedding(id_embedder : torch.nn.Module, vision_tensor : VisionTenso
crop_vision_tensor[:, :, :, :padding[2]] = 0
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
source_embedding = id_embedder(crop_vision_tensor)
source_embedding = torch.nn.functional.normalize(source_embedding, p = 2, dim = 1)
source_embedding = torch.nn.functional.normalize(source_embedding, p = 2)
return source_embedding
def infer(generator : torch.nn.Module, id_embedder : torch.nn.Module, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
source_vision_tensor = convert_to_vision_tensor(source_vision_frame)
target_vision_tensor = convert_to_vision_tensor(target_vision_frame)
source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0))
output_vision_tensor = generator(source_embedding, target_vision_tensor)[0]
output_vision_frame = convert_to_vision_frame(output_vision_tensor)
return output_vision_frame
+40
View File
@@ -0,0 +1,40 @@
import configparser
import cv2
import torch
from .generator import AdaptiveEmbeddingIntegrationNetwork
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image
from .typing import Generator, IdEmbedder, VisionFrame
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def run_swap(generator : Generator, id_embedder : IdEmbedder, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
source_vision_tensor = convert_to_vision_tensor(source_vision_frame)
target_vision_tensor = convert_to_vision_tensor(target_vision_frame)
source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0))
output_vision_tensor = generator(source_embedding, target_vision_tensor)[0]
output_vision_frame = convert_to_vision_frame(output_vision_tensor)
return output_vision_frame
def infer() -> None:
generator_path = CONFIG.get('inference', 'generator_path')
id_embedder_path = CONFIG.get('inference', 'id_embedder_path')
source_path = CONFIG.get('inference', 'source_path')
target_path = CONFIG.get('inference', 'target_path')
output_path = CONFIG.get('inference', 'output_path')
state_dict = torch.load(generator_path, map_location='cpu').get('state_dict').get('generator')
generator = AdaptiveEmbeddingIntegrationNetwork(512, 2)
generator.load_state_dict(state_dict)
generator.eval()
id_embedder = torch.jit.load(id_embedder_path, map_location='cpu') # type:ignore[no-untyped-call]
id_embedder.eval()
source_vision_frame = read_image(source_path)
target_vision_frame = read_image(target_path)
output_vision_frame = run_swap(generator, id_embedder, source_vision_frame, target_vision_frame)
cv2.imwrite(output_path, output_vision_frame)
+144 -134
View File
@@ -16,13 +16,138 @@ from .data_loader import DataLoaderVGG
from .discriminator import MultiscaleDiscriminator
from .generator import AdaptiveEmbeddingIntegrationNetwork
from .helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SourceEmbedding, TargetAttributes, VisionTensor
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SourceEmbedding, SwapAttributes, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class FaceSwapper(pytorch_lightning.LightningModule):
class FaceSwapperLoss:
def __init__(self) -> None:
id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
self.mse_loss = torch.nn.MSELoss()
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.id_embedder.eval()
self.landmarker.eval()
self.motion_extractor.eval()
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
source_tensor, target_tensor, is_same_person = batch
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
weight_id = CONFIG.getfloat('training.losses', 'weight_id')
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
generator_loss_set = {}
generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs)
generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor)
generator_loss_set['loss_attribute'] = self.calc_attribute_loss(target_attributes, swap_attributes)
generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
if weight_pose > 0:
generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
else:
generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
if weight_gaze > 0:
generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
else:
generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
return generator_loss_set
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
discriminator_loss_set = {}
loss_fake = torch.Tensor(0)
for fake_discriminator_output in fake_discriminator_outputs:
loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
loss_true = torch.Tensor(0)
for true_discriminator_output in real_discriminator_outputs:
loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
return discriminator_loss_set
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
loss_adversarial = torch.Tensor(0)
for discriminator_output in discriminator_outputs:
loss_adversarial += hinge_real_loss(discriminator_output[0])
loss_adversarial = torch.mean(loss_adversarial)
return loss_adversarial
def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
loss_attribute = torch.Tensor(0)
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
loss_attribute *= 0.5
return loss_attribute
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
return loss_reconstruction
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
return loss_id
def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_motion_features = self.get_pose_features(swap_tensor)
target_motion_features = self.get_pose_features(target_tensor)
loss_pose = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
loss_pose += self.mse_loss(swap_motion_feature, target_motion_feature)
return loss_pose
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_landmark = self.get_face_landmarks(swap_tensor)
target_landmark = self.get_face_landmarks(target_tensor)
left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197])
gaze_loss = left_gaze_loss + right_gaze_loss
return gaze_loss
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
vision_tensor_norm = (vision_tensor + 1) * 0.5
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
return landmarks
def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
vision_tensor_norm = (vision_tensor + 1) * 0.5
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
return translation, scale, rotation
class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
def __init__(self) -> None:
super().__init__()
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
@@ -31,21 +156,10 @@ class FaceSwapper(pytorch_lightning.LightningModule):
num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
self.generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
self.id_embedder = torch.jit.load(id_embedder_path, map_location ='cpu') #type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') #type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') #type:ignore[no-untyped-call]
self.id_embedder.eval()
self.landmarker.eval()
self.motion_extractor.eval()
self.automatic_optimization = False
self.mse_loss = torch.nn.MSELoss()
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators, kernel_size)
self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
def forward(self, target_tensor : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
output = self.generator(target_tensor, source_embedding)
@@ -62,135 +176,32 @@ class FaceSwapper(pytorch_lightning.LightningModule):
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
discriminator_outputs = self.discriminator(swap_tensor)
swap_attributes = self.generator.get_attributes(swap_tensor)
real_discriminator_outputs = self.discriminator(source_tensor.detach())
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch)
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, swap_attributes, fake_discriminator_outputs, batch)
generator_optimizer.zero_grad()
self.manual_backward(generator_losses.get('loss_generator'))
generator_optimizer.step()
discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor)
discriminator_losses = self.calc_discriminator_loss(real_discriminator_outputs, fake_discriminator_outputs)
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_losses.get('loss_discriminator'))
discriminator_optimizer.step()
if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0:
self.log_generator_preview(source_tensor, target_tensor, swap_tensor)
self.generate_preview(source_tensor, target_tensor, swap_tensor)
self.log('l_G', generator_losses.get('loss_generator'), prog_bar = True)
self.log('l_D', discriminator_losses.get('loss_discriminator'), prog_bar = True)
self.log('l_ADV', generator_losses.get('loss_adversarial'), prog_bar = True)
self.log('l_ATTR', generator_losses.get('loss_attribute'), prog_bar = True)
self.log('l_ID', generator_losses.get('loss_id'), prog_bar=True)
self.log('l_ID', generator_losses.get('loss_id'), prog_bar = True)
self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True)
return generator_losses.get('loss_generator')
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
loss_adversarial = torch.Tensor(0)
for discriminator_output in discriminator_outputs:
loss_adversarial += hinge_real_loss(discriminator_output[0]).mean(dim = [ 1, 2, 3 ])
loss_adversarial = torch.mean(loss_adversarial)
return loss_adversarial
def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> LossTensor:
loss_attribute = torch.Tensor(0)
swap_attributes = self.generator.get_attributes(swap_tensor)
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
loss_attribute *= 0.5
return loss_attribute
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4)
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
return loss_reconstruction
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
return loss_id
def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_motion_features = self.get_pose_features(swap_tensor)
target_motion_features = self.get_pose_features(target_tensor)
loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
loss_tsr += self.mse_loss(swap_motion_feature, target_motion_feature)
return loss_tsr
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_landmark = self.get_face_landmarks(swap_tensor)
target_landmark = self.get_face_landmarks(target_tensor)
left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197])
gaze_loss = left_gaze_loss + right_gaze_loss
return gaze_loss
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
source_tensor, target_tensor, is_same_person = batch
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
weight_id = CONFIG.getfloat('training.losses', 'weight_id')
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
weight_tsr = CONFIG.getfloat('training.losses', 'weight_tsr')
weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
generator_loss_set = {}
generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs)
generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor)
generator_loss_set['loss_attribute'] = self.calc_attribute_loss(swap_tensor, target_attributes)
generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
if weight_tsr > 0:
generator_loss_set['loss_tsr'] = self.calc_tsr_loss(swap_tensor, target_tensor)
else:
generator_loss_set['loss_tsr'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
if weight_gaze > 0:
generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
else:
generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_tsr') * weight_tsr
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
return generator_loss_set
def calc_discriminator_loss(self, swap_tensor : VisionTensor, source_tensor : VisionTensor) -> DiscriminatorLossSet:
discriminator_loss_set = {}
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
loss_fake = torch.Tensor(0)
for fake_discriminator_output in fake_discriminator_outputs:
loss_fake += torch.mean(hinge_fake_loss(fake_discriminator_output[0]).mean(dim = [ 1, 2, 3 ]))
true_discriminator_outputs = self.discriminator(source_tensor)
loss_true = torch.Tensor(0)
for true_discriminator_output in true_discriminator_outputs:
loss_true += torch.mean(hinge_real_loss(true_discriminator_output[0]).mean(dim = [ 1, 2, 3 ]))
discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
return discriminator_loss_set
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
vision_tensor_norm = (vision_tensor + 1) * 0.5
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
return landmarks
def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
vision_tensor_norm = (vision_tensor + 1) * 0.5
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
return translation, scale, rotation
def log_generator_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None:
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None:
max_preview = 8
source_tensors = source_tensor[:max_preview]
target_tensors = target_tensor[:max_preview]
@@ -204,11 +215,12 @@ def create_trainer() -> Trainer:
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
output_directory_path = CONFIG.get('training.output', 'directory_path')
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
trainer_precision = CONFIG.get('training.trainer', 'precision')
os.makedirs(output_directory_path, exist_ok = True)
return Trainer(
max_epochs = trainer_max_epochs,
precision = '16-mixed',
precision = trainer_precision,
callbacks =
[
ModelCheckpoint(
@@ -217,12 +229,10 @@ def create_trainer() -> Trainer:
filename = output_file_pattern,
every_n_train_steps = 1000,
save_top_k = 5,
mode = 'min',
save_last = True
)
],
log_every_n_steps = 10,
accumulate_grad_batches = 1,
log_every_n_steps = 10
)
@@ -232,11 +242,11 @@ def train() -> None:
checkpoint_path = CONFIG.get('training.output', 'checkpoint_path')
dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
dataset_image_pattern = CONFIG.get('preparing.dataset', 'image_pattern')
dataset_folder_pattern = CONFIG.get('preparing.dataset', 'folder_pattern')
dataset_directory_pattern = CONFIG.get('preparing.dataset', 'directory_pattern')
same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability')
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_folder_pattern, same_person_probability)
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
face_swap_model = FaceSwapper()
face_swap_model = FaceSwapperTrain()
trainer = create_trainer()
trainer.fit(face_swap_model, data_loader, ckpt_path = checkpoint_path)
+16 -5
View File
@@ -1,22 +1,33 @@
from collections import OrderedDict
from typing import Any, Dict, List, Tuple
import torch.nn
from numpy.typing import NDArray
from torch import Tensor
from torch.utils.data import DataLoader
Batch = Tuple[Any, Any, Any]
Loader = DataLoader[Tuple[Tensor, ...]]
ImagePathList = List[str]
ImagePathSet = Dict[str, ImagePathList]
SwapAttributes = Tuple[Tensor, ...]
TargetAttributes = Tuple[Tensor, ...]
DiscriminatorOutputs = List[List[Tensor]]
IdEmbedding = Tensor
SourceEmbedding = IdEmbedding
StateDict = OrderedDict[str, Any]
Padding = Tuple[int, int, int, int]
FaceLandmark203 = Tensor
VisionTensor = Tensor
StateSet = OrderedDict[str, Any]
Padding = Tuple[int, int, int, int]
LossTensor = Tensor
VisionTensor = Tensor
VisionFrame = NDArray[Any]
GeneratorLossSet = Dict[str, Tensor]
DiscriminatorLossSet = Dict[str, Tensor]
VisionFrame = NDArray[Any]
Generator = torch.nn.Module
IdEmbedder = torch.nn.Module