mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
cleaning
This commit is contained in:
@@ -24,7 +24,7 @@ num_discriminators = 3
|
||||
learning_rate = 0.0004
|
||||
disable = false
|
||||
|
||||
[auxiliary_models.paths]
|
||||
[auxiliary_models.paths] #just model.trainer ... try to match the config of the arcface converter
|
||||
arcface_path = /assets/pretrained_models/arcface_w600k_r50.pt
|
||||
landmarker_path = /assets/pretrained_models/landmark_203.pt
|
||||
motion_extractor_path = /assets/pretrained_models/liveportrait_motion_extractor.pth
|
||||
@@ -34,13 +34,10 @@ spade_generator_path = /assets/pretrained_models/liveportrait_spade_generator.pt
|
||||
|
||||
[training.losses]
|
||||
weight_adversarial = 1
|
||||
weight_identity = 20
|
||||
weight_id = 20
|
||||
weight_attribute = 10
|
||||
weight_reconstruction = 10
|
||||
weight_tsr = 100
|
||||
weight_eye_gaze = 5
|
||||
weight_eye_open = 5
|
||||
weight_lip_open = 5
|
||||
|
||||
[training.schedulers]
|
||||
step = 5000
|
||||
|
||||
@@ -16,20 +16,20 @@ CONFIG.read('config.ini')
|
||||
|
||||
def read_image(image_path: str) -> Image.Image:
|
||||
image = cv2.imread(image_path)[:, :, ::-1]
|
||||
pil_image = Image.fromarray(image)
|
||||
pil_image = Image.fromarray(image) # @todo like said, use the PIL transformator
|
||||
return pil_image
|
||||
|
||||
|
||||
class DataLoaderVGG(TensorDataset):
|
||||
def __init__(self, dataset_path : str) -> None:
|
||||
self.same_person_probability = float(CONFIG.get('preparing.dataloader', 'same_person_probability'))
|
||||
self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path))
|
||||
self.same_person_probability = float(CONFIG.get('preparing.dataloader', 'same_person_probability')) # @todo use CONFIG.getfloat() - also config block at the top
|
||||
self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path)) # @todo globs belong to the config
|
||||
self.folder_paths = glob.glob('{}/*'.format(dataset_path))
|
||||
self.image_path_dict = {}
|
||||
self.image_path_dict = {} # @todo we are not using dict as suffix... this image_path_set?
|
||||
self._current_index = 0
|
||||
|
||||
for folder_path in tqdm.tqdm(self.folder_paths):
|
||||
image_paths = glob.glob('{}/*'.format(folder_path))
|
||||
image_paths = glob.glob('{}/*'.format(folder_path)) # @todo not sure about all these globs being used here :-)
|
||||
self.image_path_dict[folder_path] = image_paths
|
||||
self.dataset_total = len(self.image_paths)
|
||||
self.transforms_basic = transforms.Compose(
|
||||
@@ -61,15 +61,15 @@ class DataLoaderVGG(TensorDataset):
|
||||
source_image_path = self.image_paths[item]
|
||||
source = read_image(source_image_path)
|
||||
|
||||
if random.random() > self.same_person_probability:
|
||||
if random.random() > self.same_person_probability: # @todo if -> we_call_a_method_that_explains_what_we_do()
|
||||
is_same_person = 0
|
||||
target_image_path = random.choice(self.image_paths)
|
||||
target = read_image(target_image_path)
|
||||
source_transform = self.transforms_moderate(source)
|
||||
target_transform = self.transforms_complex(target)
|
||||
else:
|
||||
else: # @todo else -> we_do_some_alternative_action() - in other words, move it to speaking methods :-)
|
||||
is_same_person = 1
|
||||
source_folder_path = '/'.join(source_image_path.split('/')[:-1])
|
||||
source_folder_path = '/'.join(source_image_path.split('/')[:-1]) # @todo use os.path.join()
|
||||
target_image_path = random.choice(self.image_path_dict[source_folder_path])
|
||||
target = read_image(target_image_path)
|
||||
source_transform = self.transforms_basic(source)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from .typing import DiscriminatorOutputs, Tensor
|
||||
from .typing import DiscriminatorOutputs
|
||||
|
||||
|
||||
class NLayerDiscriminator(nn.Module):
|
||||
@@ -12,37 +11,45 @@ class NLayerDiscriminator(nn.Module):
|
||||
self.num_layers = num_layers
|
||||
kernel_size = 4
|
||||
padding_size = int(numpy.ceil((kernel_size - 1.0) / 2))
|
||||
model_layers = [
|
||||
model_layers =\
|
||||
[
|
||||
[
|
||||
nn.Conv2d(input_channels, num_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]]
|
||||
]
|
||||
]
|
||||
current_filters = num_filters
|
||||
|
||||
for layer_index in range(1, num_layers):
|
||||
previous_filters = current_filters
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
model_layers += [
|
||||
model_layers +=\
|
||||
[
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
|
||||
nn.InstanceNorm2d(current_filters), nn.LeakyReLU(0.2, True)
|
||||
]]
|
||||
]
|
||||
]
|
||||
previous_filters = current_filters
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
model_layers += [
|
||||
model_layers +=\
|
||||
[
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 1, padding = padding_size),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]]
|
||||
model_layers += [
|
||||
]
|
||||
]
|
||||
model_layers +=\
|
||||
[
|
||||
[
|
||||
nn.Conv2d(current_filters, 1, kernel_size = kernel_size, stride = 1, padding = padding_size)
|
||||
]]
|
||||
]
|
||||
]
|
||||
combined_layers = []
|
||||
|
||||
for layer in model_layers:
|
||||
combined_layers += layer
|
||||
for model_layer in model_layers:
|
||||
combined_layers += model_layer
|
||||
self.model = nn.Sequential(*combined_layers)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
@@ -60,17 +67,14 @@ class MultiscaleDiscriminator(nn.Module):
|
||||
setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
|
||||
self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
|
||||
|
||||
def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
|
||||
return [ model_layers(input_tensor) ]
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
|
||||
discriminator_outputs = []
|
||||
downsampled_input = input_tensor
|
||||
temp_downsampled_input = input_tensor
|
||||
|
||||
for discriminator_index in range(self.num_discriminators):
|
||||
model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
|
||||
discriminator_outputs.append(self.single_discriminator_forward(model_layers, downsampled_input))
|
||||
discriminator_outputs.append([ model_layers(temp_downsampled_input) ])
|
||||
|
||||
if discriminator_index != (self.num_discriminators - 1):
|
||||
downsampled_input = self.downsample(downsampled_input)
|
||||
if discriminator_index < (self.num_discriminators - 1):
|
||||
temp_downsampled_input = self.downsample(temp_downsampled_input)
|
||||
return discriminator_outputs
|
||||
|
||||
@@ -2,8 +2,9 @@ from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from .typing import IDEmbedding, TargetAttributes, Tensor
|
||||
from .typing import SourceEmbedding, TargetAttributes, VisionTensor
|
||||
|
||||
|
||||
class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
|
||||
@@ -12,12 +13,12 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
|
||||
self.encoder = UNet()
|
||||
self.generator = AADGenerator(id_channels, num_blocks)
|
||||
|
||||
def forward(self, target : Tensor, source_embedding : IDEmbedding) -> Tuple[Tensor, TargetAttributes]:
|
||||
def forward(self, target : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
|
||||
target_attributes = self.get_attributes(target)
|
||||
swap = self.generator(target_attributes, source_embedding)
|
||||
return swap, target_attributes
|
||||
swap_tensor = self.generator(target_attributes, source_embedding)
|
||||
return swap_tensor, target_attributes
|
||||
|
||||
def get_attributes(self, target : Tensor) -> TargetAttributes:
|
||||
def get_attributes(self, target : VisionTensor) -> TargetAttributes:
|
||||
return self.encoder(target)
|
||||
|
||||
|
||||
@@ -35,7 +36,7 @@ class AADGenerator(nn.Module):
|
||||
self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks)
|
||||
self.apply(initialize_weight)
|
||||
|
||||
def forward(self, target_attributes : TargetAttributes, source_embedding : IDEmbedding) -> Tensor:
|
||||
def forward(self, target_attributes : TargetAttributes, source_embedding : SourceEmbedding) -> Tensor:
|
||||
feature_map = self.upsample(source_embedding)
|
||||
feature_map_1 = torch.nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_2 = torch.nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
@@ -66,7 +67,7 @@ class UNet(nn.Module):
|
||||
self.upsampler_6 = Upsample(128, 32)
|
||||
self.apply(initialize_weight)
|
||||
|
||||
def forward(self, target : Tensor) -> TargetAttributes:
|
||||
def forward(self, target : VisionTensor) -> TargetAttributes:
|
||||
downsample_feature_1 = self.downsampler_1(target)
|
||||
downsample_feature_2 = self.downsampler_2(downsample_feature_1)
|
||||
downsample_feature_3 = self.downsampler_3(downsample_feature_2)
|
||||
@@ -88,34 +89,34 @@ class AADLayer(nn.Module):
|
||||
def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None:
|
||||
super(AADLayer, self).__init__()
|
||||
self.input_channels = input_channels
|
||||
self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
|
||||
self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
|
||||
self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1)
|
||||
self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1)
|
||||
self.fc_beta = nn.Linear(id_channels, input_channels)
|
||||
self.fc_gamma = nn.Linear(id_channels, input_channels)
|
||||
self.instance_norm = nn.InstanceNorm2d(input_channels, affine = False)
|
||||
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = True)
|
||||
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1)
|
||||
|
||||
def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : IDEmbedding) -> Tensor:
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : SourceEmbedding) -> Tensor:
|
||||
feature_map = self.instance_norm(feature_map)
|
||||
attr_gamma = self.conv_gamma(attr_embedding)
|
||||
attr_beta = self.conv_beta(attr_embedding)
|
||||
attr_modulation = attr_gamma * feature_map + attr_beta
|
||||
gamma_attribute = self.conv_gamma(attribute_embedding)
|
||||
beta_attribute = self.conv_beta(attribute_embedding)
|
||||
attribute_modulation = gamma_attribute * feature_map + beta_attribute
|
||||
id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
|
||||
id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
|
||||
id_modulation = id_gamma * feature_map + id_beta
|
||||
feature_mask = torch.sigmoid(self.conv_mask(feature_map))
|
||||
feature_blend = (1 - feature_mask) * attr_modulation + feature_mask * id_modulation
|
||||
feature_blend = (1 - feature_mask) * attribute_modulation + feature_mask * id_modulation
|
||||
return feature_blend
|
||||
|
||||
|
||||
class AddBlocksSequential(nn.Sequential):
|
||||
def forward(self, *inputs : Tuple[Tensor, Tensor, IDEmbedding]) -> Tuple[Tuple[Tensor, Tensor, IDEmbedding], ...]:
|
||||
_, attr_embedding, id_embedding = inputs
|
||||
def forward(self, *inputs : Tuple[Tensor, Tensor, SourceEmbedding]) -> Tuple[Tuple[Tensor, Tensor, SourceEmbedding], ...]:
|
||||
_, attr_embedding, id_embedding = inputs #@todo we are not using shortcuts, it is attribute_embedding
|
||||
|
||||
for index, module in enumerate(self._modules.values()):
|
||||
for index, module in enumerate(self._modules.values()): #@todo refactor this to return values
|
||||
if index % 3 == 0 and index > 0:
|
||||
inputs = (inputs, attr_embedding, id_embedding) # type:ignore[assignment]
|
||||
if type(inputs) == tuple:
|
||||
if type(inputs) == tuple: #@todo my IDE complains about the type check
|
||||
inputs = module(*inputs)
|
||||
else:
|
||||
inputs = module(inputs)
|
||||
@@ -123,45 +124,45 @@ class AddBlocksSequential(nn.Sequential):
|
||||
|
||||
|
||||
class AADResBlock(nn.Module):
|
||||
def __init__(self, in_channels : int, out_channels : int, attr_channels : int, id_channels : int, num_blocks : int) -> None:
|
||||
def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, id_channels : int, num_blocks : int) -> None:
|
||||
super(AADResBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.input_channels = input_channels
|
||||
self.output_channels = output_channels
|
||||
primary_add_blocks = []
|
||||
|
||||
for i in range(num_blocks):
|
||||
intermediate_channels = in_channels if i < (num_blocks - 1) else out_channels
|
||||
intermediate_channels = input_channels if i < (num_blocks - 1) else output_channels
|
||||
primary_add_blocks.extend(
|
||||
[
|
||||
AADLayer(in_channels, attr_channels, id_channels),
|
||||
[ #@todo indent
|
||||
AADLayer(input_channels, attribute_channels, id_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(in_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
|
||||
nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
|
||||
])
|
||||
self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks)
|
||||
|
||||
if in_channels != out_channels:
|
||||
auxiliary_add_blocks = \
|
||||
[
|
||||
AADLayer(in_channels, attr_channels, id_channels),
|
||||
if input_channels != output_channels:
|
||||
auxiliary_add_blocks =\
|
||||
[ #@todo indent
|
||||
AADLayer(input_channels, attribute_channels, id_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
|
||||
nn.Conv2d(input_channels, output_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
|
||||
]
|
||||
self.auxiliary_add_blocks = AddBlocksSequential(*auxiliary_add_blocks)
|
||||
|
||||
def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : IDEmbedding) -> Tensor:
|
||||
primary_feature = self.primary_add_blocks(feature_map, attr_embedding, id_embedding)
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : SourceEmbedding) -> Tensor:
|
||||
primary_feature = self.primary_add_blocks(feature_map, attribute_embedding, id_embedding)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
feature_map = self.auxiliary_add_blocks(feature_map, attr_embedding, id_embedding)
|
||||
if self.input_channels != self.output_channels:
|
||||
feature_map = self.auxiliary_add_blocks(feature_map, attribute_embedding, id_embedding)
|
||||
output_feature = primary_feature + feature_map
|
||||
return output_feature
|
||||
|
||||
|
||||
class DownSample(nn.Module):
|
||||
def __init__(self, in_channels : int, out_channels : int) -> None:
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(DownSample, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
def forward(self, temp : Tensor) -> Tensor:
|
||||
@@ -172,10 +173,10 @@ class DownSample(nn.Module):
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels : int, out_channels : int) -> None:
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(Upsample, self).__init__()
|
||||
self.deconv = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
self.deconv = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor:
|
||||
@@ -186,9 +187,9 @@ class Upsample(nn.Module):
|
||||
|
||||
|
||||
class PixelShuffleUpsample(nn.Module):
|
||||
def __init__(self, in_channels : int, out_channels : int) -> None:
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(PixelShuffleUpsample, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1)
|
||||
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 3, padding = 1)
|
||||
self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
|
||||
|
||||
def forward(self, temp : Tensor) -> Tensor:
|
||||
@@ -199,7 +200,7 @@ class PixelShuffleUpsample(nn.Module):
|
||||
|
||||
def initialize_weight(module : nn.Module) -> None:
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(0, 0.001)
|
||||
module.weight.data.normal_(std = 0.001)
|
||||
module.bias.data.zero_()
|
||||
|
||||
if isinstance(module, nn.Conv2d):
|
||||
@@ -207,11 +208,3 @@ def initialize_weight(module : nn.Module) -> None:
|
||||
|
||||
if isinstance(module, nn.ConvTranspose2d):
|
||||
nn.init.xavier_normal_(module.weight.data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: push one random target image and one random
    # 512-value source embedding through the network and print the
    # shape of the first returned element.
    network = AdaptiveEmbeddingIntegrationNetwork(512, 2)
    random_source_embedding = torch.randn(1, 512)
    random_target = torch.randn(1, 3, 256, 256)
    outputs = network(random_target, random_source_embedding)
    print(outputs[0].shape)
|
||||
|
||||
@@ -1,31 +1,11 @@
|
||||
import configparser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .typing import Tensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
L2_loss = torch.nn.MSELoss()
|
||||
def hinge_real_loss(tensor : Tensor) -> Tensor:
    """Return the element-wise hinge term max(0, 1 - tensor)."""
    margin_shortfall = 1 - tensor
    return torch.relu(margin_shortfall)
|
||||
|
||||
|
||||
def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor) -> Tensor:
    """
    Rotate, offset, scale and translate a set of 21 three-dimensional points.

    Both points and expression are viewed as (-1, 21, 3). The points are
    multiplied by rotation_matrix, the expression offsets are added, the
    result is multiplied by scale (with one trailing axis appended) and
    finally the first two coordinates are shifted by the first two
    translation components.
    """
    keypoints = points.view(-1, 21, 3)
    expression_offsets = expression.view(-1, 21, 3)
    points_transformed = torch.matmul(keypoints, rotation_matrix) + expression_offsets
    # in-place on purpose: the freshly created result keeps its dtype
    points_transformed *= scale[..., None]
    points_transformed[:, :, 0:2] += translation[:, None, 0:2]
    return points_transformed
|
||||
|
||||
|
||||
def hinge_loss(tensor : Tensor, is_positive : bool) -> Tensor:
    """
    Element-wise hinge term.

    Returns max(0, 1 - tensor) when is_positive is truthy and
    max(0, tensor + 1) otherwise.
    """
    signed_margin = 1 - tensor if is_positive else tensor + 1
    return torch.relu(signed_margin)
|
||||
|
||||
|
||||
def calc_distance_ratio(landmarks : Tensor, indices : Tuple[int, int, int, int]) -> Tensor:
    """
    Ratio between two landmark distances.

    The numerator is the L2 distance between the landmarks at indices[0]
    and indices[1], the denominator the L2 distance between the landmarks
    at indices[2] and indices[3] plus 1e-4 to avoid a division by zero.
    The reduced axis is kept, so the result has one value per batch row.
    """
    horizontal_offset = landmarks[:, indices[0]] - landmarks[:, indices[1]]
    vertical_offset = landmarks[:, indices[2]] - landmarks[:, indices[3]]
    distance_horizontal = torch.norm(horizontal_offset, p = 2, dim = 1, keepdim = True)
    distance_vertical = torch.norm(vertical_offset, p = 2, dim = 1, keepdim = True)
    return distance_horizontal / (distance_vertical + 1e-4)
|
||||
def hinge_fake_loss(tensor : Tensor) -> Tensor:
    """Return the element-wise hinge term max(0, tensor + 1)."""
    margin_excess = tensor + 1
    return torch.relu(margin_excess)
|
||||
|
||||
+197
-227
@@ -2,8 +2,6 @@ import configparser
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
import pytorch_lightning
|
||||
import torch
|
||||
import torchvision
|
||||
@@ -12,18 +10,212 @@ from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.utilities.types import Optimizer
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .data_loader import DataLoaderVGG, read_image
|
||||
from .data_loader import DataLoaderVGG
|
||||
from .discriminator import MultiscaleDiscriminator
|
||||
from .generator import AdaptiveEmbeddingIntegrationNetwork
|
||||
from .helper import L2_loss, hinge_loss
|
||||
from .typing import Batch, DiscriminatorOutputs, IDEmbedding, LossDict, TargetAttributes, Tensor
|
||||
from .helper import hinge_fake_loss, hinge_real_loss
|
||||
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, Loss, Padding, SourceEmbedding, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
def __init__(self) -> None:
    """
    Build the generator and discriminator from config.ini and load the
    auxiliary models (arcface, landmarker, motion extractor) on the CPU
    in evaluation mode. Optimization is handled manually.
    """
    super().__init__()
    generator_arguments = [ CONFIG.getint('training.generator', option) for option in ('id_channels', 'num_blocks') ]
    discriminator_arguments = [ CONFIG.getint('training.discriminator', option) for option in ('input_channels', 'num_filters', 'num_layers', 'num_discriminators') ]
    model_paths = { option: CONFIG.get('auxiliary_models.paths', option) for option in ('arcface_path', 'landmarker_path', 'motion_extractor_path') }

    self.generator = AdaptiveEmbeddingIntegrationNetwork(*generator_arguments)
    self.discriminator = MultiscaleDiscriminator(*discriminator_arguments)
    # NOTE(review): weights_only = False unpickles arbitrary objects - only load trusted checkpoints
    self.arcface = torch.load(model_paths['arcface_path'], map_location = 'cpu', weights_only = False)
    self.landmarker = torch.load(model_paths['landmarker_path'], map_location = 'cpu', weights_only = False)
    # NOTE(review): MotionExtractor is not among the imports visible here - confirm it is imported
    self.motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
    motion_extractor_state = torch.load(model_paths['motion_extractor_path'], map_location = 'cpu', weights_only = True)
    self.motion_extractor.load_state_dict(motion_extractor_state)

    # the auxiliary models only provide loss signals, so switch them to evaluation mode
    for auxiliary_model in (self.arcface, self.landmarker, self.motion_extractor):
        auxiliary_model.eval()

    # both optimizers are stepped by hand in training_step()
    self.automatic_optimization = False
    self.mse_loss = torch.nn.MSELoss()
    self.batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
|
||||
def forward(self, target_tensor : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
    """Delegate to the generator and return its result unchanged."""
    return self.generator(target_tensor, source_embedding)
|
||||
|
||||
def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]:
    """
    Create one Adam optimizer for the generator and one for the
    discriminator. The learning rates come from config.ini, betas and
    weight decay are shared.
    """
    learning_rates = [ CONFIG.getfloat(section, 'learning_rate') for section in ('training.generator', 'training.discriminator') ]
    networks = [ self.generator, self.discriminator ]
    generator_optimizer, discriminator_optimizer =\
    [
        torch.optim.Adam(network.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
        for network, learning_rate in zip(networks, learning_rates)
    ]
    return generator_optimizer, discriminator_optimizer
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
    """
    Run one manual optimization step: update the generator first, then the
    discriminator on the detached swap, log the losses and return the
    total generator loss.
    """
    source_tensor, target_tensor, _ = batch
    generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
    # the embedding fed to the generator is computed without any padding mask
    source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0))
    swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
    discriminator_outputs = self.discriminator(swap_tensor)

    # generator update
    generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch)
    loss_generator = generator_losses.get('loss_generator')
    generator_optimizer.zero_grad()
    self.manual_backward(loss_generator)
    generator_optimizer.step()

    # discriminator update
    discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor)
    loss_discriminator = discriminator_losses.get('loss_discriminator')
    discriminator_optimizer.zero_grad()
    self.manual_backward(loss_discriminator)
    discriminator_optimizer.step()

    preview_frequency = CONFIG.getint('training.output', 'preview_frequency')

    if self.global_step % preview_frequency == 0:
        self.log_generator_preview(source_tensor, target_tensor, swap_tensor)

    progress_losses =\
    {
        'l_G': loss_generator,
        'l_D': loss_discriminator,
        'l_ADV': generator_losses.get('loss_adversarial'),
        'l_ATTR': generator_losses.get('loss_attribute'),
        'l_ID': generator_losses.get('loss_id'),
        'l_REC': generator_losses.get('loss_reconstruction')
    }

    for loss_name, loss_value in progress_losses.items():
        self.log(loss_name, loss_value, prog_bar = True)
    return loss_generator
|
||||
|
||||
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> Loss:
    """
    Adversarial loss of the generator.

    For every discriminator the hinge "real" term of its first output is
    averaged over dimensions 1 to 3, the per-sample values are summed over
    all discriminators and finally averaged over the batch.

    :param discriminator_outputs: one entry per discriminator, the score map is element 0 of each entry
    :return: scalar loss on the device of the discriminator outputs
    """
    # Build the sum from the actual terms: torch.Tensor(0) would be an EMPTY
    # CPU tensor (shape [0]), not a scalar zero, and cannot accumulate them.
    per_sample_losses = [ hinge_real_loss(discriminator_output[0]).mean(dim = [ 1, 2, 3 ]) for discriminator_output in discriminator_outputs ]
    loss_adversarial = torch.stack(per_sample_losses).sum(dim = 0)
    return torch.mean(loss_adversarial)
|
||||
|
||||
def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> Loss:
    """
    Attribute preservation loss.

    The attributes of the swap are extracted with the generator's encoder
    and compared level by level with the target attributes: half of the
    summed mean squared differences is returned.

    :param swap_tensor: generated swap
    :param target_attributes: attributes extracted from the target, paired one to one with the swap attributes
    :return: scalar loss on the device and with the dtype of swap_tensor
    """
    swap_attributes = self.generator.get_attributes(swap_tensor)
    # a real scalar zero on the right device - torch.Tensor(0) is an empty CPU tensor
    loss_attribute = torch.zeros((), device = swap_tensor.device, dtype = swap_tensor.dtype)

    for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
        squared_difference = torch.pow(swap_attribute - target_attribute, 2)
        # flatten per sample instead of reshaping to the configured batch size,
        # so a smaller final batch does not break the reshape
        loss_attribute = loss_attribute + squared_difference.flatten(start_dim = 1).mean(dim = 1).mean()
    return loss_attribute * 0.5
|
||||
|
||||
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> Loss:
    """
    Reconstruction loss between swap and target.

    Combines a pixel term (half the per-sample mean squared error, weighted
    by is_same_person and normalised by its sum) with an SSIM term
    (1 - ssim) and returns the average of both.

    :param swap_tensor: generated swap
    :param target_tensor: target image batch
    :param is_same_person: per-sample weight that multiplies the pixel error
    :return: scalar loss
    """
    # flatten per sample instead of reshaping to the configured batch size,
    # so a smaller final batch does not break the reshape
    per_sample_error = torch.pow(swap_tensor - target_tensor, 2).flatten(start_dim = 1).mean(dim = 1)
    # 1e-4 keeps the division defined when no sample is weighted
    loss_pixel = torch.sum(0.5 * per_sample_error * is_same_person) / (is_same_person.sum() + 1e-4)
    data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))
    loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = data_range).mean()
    loss_reconstruction = (loss_pixel + loss_ssim) * 0.5
    return loss_reconstruction
|
||||
|
||||
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> Loss:
    """
    Identity loss: one minus the cosine similarity between the id
    embeddings of source and swap, averaged over the batch. Both
    embeddings are computed with the same (30, 0, 10, 10) padding mask.
    """
    embedding_padding = (30, 0, 10, 10)
    swap_embedding = self.get_id_embedding(swap_tensor, embedding_padding)
    source_embedding = self.get_id_embedding(source_tensor, embedding_padding)
    similarity = torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)
    return (1 - similarity).mean()
|
||||
|
||||
def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss:
    """
    Sum of the mean squared errors between the pose features of swap and
    target, one term per feature returned by get_pose_features().
    """
    swap_motion_features = self.get_pose_features(swap_tensor)
    target_motion_features = self.get_pose_features(target_tensor)
    # scalar zero with the device and dtype of the swap
    loss_tsr = swap_tensor.new_zeros(())
    feature_pairs = zip(swap_motion_features, target_motion_features)

    for swap_motion_feature, target_motion_feature in feature_pairs:
        loss_tsr += self.mse_loss(swap_motion_feature, target_motion_feature)
    return loss_tsr
|
||||
|
||||
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss:
    """
    Mean squared error between swap and target at landmark 198 plus the
    one at landmark 197 (named left and right gaze by the original code).
    """
    swap_landmark = self.get_face_landmarks(swap_tensor)
    target_landmark = self.get_face_landmarks(target_tensor)
    left_index, right_index = 198, 197
    left_gaze_loss, right_gaze_loss =\
    [
        self.mse_loss(swap_landmark[:, landmark_index], target_landmark[:, landmark_index])
        for landmark_index in (left_index, right_index)
    ]
    return left_gaze_loss + right_gaze_loss
|
||||
|
||||
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
    """
    Compute every generator loss term and their weighted total.

    The weights are read from the [training.losses] section of config.ini.
    The tsr and gaze terms are only computed when their weight is positive,
    otherwise a scalar zero is stored for them.

    :param swap_tensor: generated swap
    :param target_attributes: attributes the generator extracted from the target
    :param discriminator_outputs: discriminator outputs for the swap
    :param batch: (source_tensor, target_tensor, is_same_person)
    :return: dict with loss_adversarial, loss_id, loss_attribute, loss_reconstruction, loss_tsr, loss_gaze and the weighted total loss_generator
    """
    source_tensor, target_tensor, is_same_person = batch
    loss_weights =\
    {
        'loss_adversarial': CONFIG.getfloat('training.losses', 'weight_adversarial'),
        'loss_id': CONFIG.getfloat('training.losses', 'weight_id'),
        'loss_attribute': CONFIG.getfloat('training.losses', 'weight_attribute'),
        'loss_reconstruction': CONFIG.getfloat('training.losses', 'weight_reconstruction'),
        'loss_tsr': CONFIG.getfloat('training.losses', 'weight_tsr'),
        # weight_gaze may be missing from config.ini: treat it as disabled instead of raising NoOptionError
        'loss_gaze': CONFIG.getfloat('training.losses', 'weight_gaze', fallback = 0.0)
    }

    generator_loss_set = {}
    generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs)
    generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor)
    generator_loss_set['loss_attribute'] = self.calc_attribute_loss(swap_tensor, target_attributes)
    generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)

    if loss_weights.get('loss_tsr') > 0:
        generator_loss_set['loss_tsr'] = self.calc_tsr_loss(swap_tensor, target_tensor)
    else:
        generator_loss_set['loss_tsr'] = swap_tensor.new_zeros(())

    if loss_weights.get('loss_gaze') > 0:
        generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
    else:
        generator_loss_set['loss_gaze'] = swap_tensor.new_zeros(())

    # weighted sum over all terms, evaluated before the total is stored in the same dict
    generator_loss_set['loss_generator'] = sum(generator_loss_set.get(loss_name) * loss_weight for loss_name, loss_weight in loss_weights.items())
    return generator_loss_set
|
||||
|
||||
def calc_discriminator_loss(self, swap_tensor : VisionTensor, source_tensor : VisionTensor) -> DiscriminatorLossSet:
    """
    Hinge loss of the discriminator.

    The detached swap is scored as fake and the source as real. For each
    discriminator the hinge term of its first output is averaged, the
    averages are summed per side, and half of the sum of both sides is
    returned.

    :param swap_tensor: generated swap, detached here so no gradient reaches the generator
    :param source_tensor: batch that is scored as real
    :return: dict with the single key loss_discriminator
    """
    discriminator_loss_set = {}
    fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
    # real scalar zeros on the right device - torch.Tensor(0) is an empty CPU
    # tensor whose mean() is NaN
    loss_fake = swap_tensor.new_zeros(())

    for fake_discriminator_output in fake_discriminator_outputs:
        loss_fake = loss_fake + torch.mean(hinge_fake_loss(fake_discriminator_output[0]).mean(dim = [ 1, 2, 3 ]))

    true_discriminator_outputs = self.discriminator(source_tensor)
    loss_true = source_tensor.new_zeros(())

    for true_discriminator_output in true_discriminator_outputs:
        loss_true = loss_true + torch.mean(hinge_real_loss(true_discriminator_output[0]).mean(dim = [ 1, 2, 3 ]))

    discriminator_loss_set['loss_discriminator'] = (loss_true + loss_fake) * 0.5
    return discriminator_loss_set
|
||||
|
||||
def get_id_embedding(self, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding:
    """
    Return the L2-normalized ArcFace embedding of a batch of images.

    The batch is resized to 112x112 with area interpolation, sliced, and
    then bands along its borders are zeroed according to `padding` before
    it is passed to self.arcface.

    padding holds four band sizes: leading rows, trailing rows (counted
    from index 112), leading columns, trailing columns (counted from
    index 112).
    """
    resized_tensor = torch.nn.functional.interpolate(vision_tensor, size = (112, 112), mode = 'area')
    # NOTE(review): the column slice 8:128 is clamped to 8:112 on a 112 wide
    # tensor, which leaves 104 columns - confirm this is intended
    masked_tensor = resized_tensor[:, :, 0:112, 8:128]
    lead_rows = padding[0]
    trail_rows_start = 112 - padding[1]
    lead_columns = padding[2]
    trail_columns_start = 112 - padding[3]

    masked_tensor[:, :, :lead_rows, :] = 0
    masked_tensor[:, :, trail_rows_start:, :] = 0
    masked_tensor[:, :, :, :lead_columns] = 0
    masked_tensor[:, :, :, trail_columns_start:] = 0
    return torch.nn.functional.normalize(self.arcface(masked_tensor), p = 2, dim = 1)
|
||||
|
||||
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
    """
    Return the landmarker prediction reshaped to (-1, 203, 2).

    The input is shifted and scaled with (x + 1) * 0.5, resized to 224x224
    with bilinear interpolation, and fed to self.landmarker; the element at
    index 2 of its output is used.
    """
    # presumably maps images from [-1, 1] to [0, 1] - confirm against the data pipeline
    landmarker_input = (vision_tensor + 1) * 0.5
    landmarker_input = torch.nn.functional.interpolate(landmarker_input, size = (224, 224), mode = 'bilinear')
    landmarker_outputs = self.landmarker(landmarker_input)
    return landmarker_outputs[2].view(-1, 203, 2)
|
||||
|
||||
def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Return (translation, scale, rotation) as predicted by self.motion_extractor.

    The input is shifted and scaled with (x + 1) * 0.5 first. Translation and
    scale are read from the 't' and 'scale' entries of the extractor output;
    rotation is the 'pitch', 'yaw' and 'roll' entries concatenated on dim 1.
    """
    motion_dict = self.motion_extractor((vision_tensor + 1) * 0.5)
    rotation_parts = [ motion_dict.get(angle_name) for angle_name in ('pitch', 'yaw', 'roll') ]
    rotation = torch.cat(rotation_parts, dim = 1)
    return motion_dict.get('t'), motion_dict.get('scale'), rotation
|
||||
|
||||
def log_generator_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None:
    """
    Log a preview image of up to eight samples to the experiment logger.

    For every sample the source, target and swap items are concatenated on
    dim 2; the resulting rows are concatenated on dim 1 and passed through
    torchvision.utils.make_grid before being logged as "Generator Preview"
    at the current global step.
    """
    max_preview = 8
    preview_rows = []

    for source_item, target_item, swap_item in zip(source_tensor[:max_preview], target_tensor[:max_preview], swap_tensor[:max_preview]):
        preview_rows.append(torch.cat([ source_item, target_item, swap_item ], dim = 2))

    # a leading batch dim of one is added so make_grid receives a single image
    preview_batch = torch.cat(preview_rows, dim = 1).unsqueeze(0)
    grid = torchvision.utils.make_grid(preview_batch, nrow = 1, normalize = True, scale_each = True)
    self.logger.experiment.add_image("Generator Preview", grid, self.global_step)
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
|
||||
output_directory_path = CONFIG.get('training.output', 'directory_path')
|
||||
@@ -56,229 +248,7 @@ def train() -> None:
|
||||
checkpoint_path = CONFIG.get('training.output', 'checkpoint_path')
|
||||
dataset = DataLoaderVGG(CONFIG.get('preparing.dataset', 'dataset_path'))
|
||||
|
||||
if not (checkpoint_path and os.path.exists(checkpoint_path)):
|
||||
checkpoint_path = None
|
||||
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
face_swap_model = FaceSwapper()
|
||||
trainer = create_trainer()
|
||||
trainer.fit(face_swap_model, data_loader, ckpt_path = checkpoint_path)
|
||||
|
||||
|
||||
class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
def __init__(self) -> None:
    """
    Build the generator and discriminator from CONFIG and load the
    auxiliary models (arcface, landmarker, motion extractor) in eval mode.

    Optimization is manual (automatic_optimization is disabled) and the
    configured loader batch size is cached on the instance.
    """
    super().__init__()
    generator_section = 'training.generator'
    discriminator_section = 'training.discriminator'
    paths_section = 'auxiliary_models.paths'

    self.generator = AdaptiveEmbeddingIntegrationNetwork(
        CONFIG.getint(generator_section, 'id_channels'),
        CONFIG.getint(generator_section, 'num_blocks')
    )
    self.discriminator = MultiscaleDiscriminator(
        CONFIG.getint(discriminator_section, 'input_channels'),
        CONFIG.getint(discriminator_section, 'num_filters'),
        CONFIG.getint(discriminator_section, 'num_layers'),
        CONFIG.getint(discriminator_section, 'num_discriminators')
    )
    # NOTE(review): weights_only = False unpickles whole objects, so these
    # two files must come from a trusted source
    self.arcface = torch.load(CONFIG.get(paths_section, 'arcface_path'), map_location = 'cpu', weights_only = False)
    self.landmarker = torch.load(CONFIG.get(paths_section, 'landmarker_path'), map_location = 'cpu', weights_only = False)
    self.motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
    motion_extractor_state = torch.load(CONFIG.get(paths_section, 'motion_extractor_path'), map_location = 'cpu', weights_only = True)
    self.motion_extractor.load_state_dict(motion_extractor_state)

    for auxiliary_model in (self.arcface, self.landmarker, self.motion_extractor):
        auxiliary_model.eval()

    self.automatic_optimization = False
    self.batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
|
||||
def forward(self, target_tensor : Tensor, source_embedding : IDEmbedding) -> Tuple[Tensor, TargetAttributes]:
    """
    Delegate to self.generator and return its output unchanged.
    """
    return self.generator(target_tensor, source_embedding)
|
||||
|
||||
def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]:
    """
    Create one Adam optimizer for the generator and one for the discriminator.

    Both share betas of (0.0, 0.999) and a weight decay of 1e-4; the learning
    rates come from the 'training.generator' and 'training.discriminator'
    sections of CONFIG. The generator optimizer is returned first.
    """
    shared_options =\
    {
        'betas': (0.0, 0.999),
        'weight_decay': 1e-4
    }
    generator_learning_rate = CONFIG.getfloat('training.generator', 'learning_rate')
    generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = generator_learning_rate, **shared_options)
    discriminator_learning_rate = CONFIG.getfloat('training.discriminator', 'learning_rate')
    discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = discriminator_learning_rate, **shared_options)
    return generator_optimizer, discriminator_optimizer
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
    """
    Run one manually optimized training step and return the generator loss.

    The generator is updated first, then the discriminator loss is computed
    and back-propagated; the discriminator optimizer only steps when the
    'disable' option of 'training.discriminator' is false. Previews are
    written whenever the global step is a multiple of the configured
    preview / validation frequency, and the individual losses are logged.
    """
    source_tensor, target_tensor, _ = batch
    generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]

    # generator update
    source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0))
    swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
    generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, self.discriminator(swap_tensor), batch)
    loss_generator = generator_losses.get('loss_generator')
    generator_optimizer.zero_grad()
    self.manual_backward(loss_generator)
    generator_optimizer.step()

    # discriminator update - the backward pass runs even when the step is disabled
    discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor)
    loss_discriminator = discriminator_losses.get('loss_discriminator')
    discriminator_optimizer.zero_grad()
    self.manual_backward(loss_discriminator)

    if not CONFIG.getboolean('training.discriminator', 'disable'):
        discriminator_optimizer.step()

    if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0:
        self.log_generator_preview(source_tensor, target_tensor, swap_tensor)

    if self.global_step % CONFIG.getint('training.output', 'validation_frequency') == 0:
        self.log_validation_preview()

    # (label, value, show in progress bar)
    log_entries =\
    [
        ('l_G', loss_generator, True),
        ('l_D', loss_discriminator, True),
        ('l_ADV', generator_losses.get('loss_adversarial'), False),
        ('l_ATTR', generator_losses.get('loss_attribute'), True),
        ('l_ID', generator_losses.get('loss_identity'), True),
        ('l_REC', generator_losses.get('loss_reconstruction'), True)
    ]

    for label, value, show_in_progress_bar in log_entries:
        self.log(label, value, prog_bar = show_in_progress_bar)

    return loss_generator
|
||||
|
||||
def calc_generator_loss(self, swap_tensor : Tensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> LossDict:
    """
    Compute every generator loss term and their weighted sum.

    batch is unpacked as (source_tensor, target_tensor, is_same_person).
    The returned dict always holds 'loss_adversarial', 'loss_identity',
    'loss_attribute', 'loss_reconstruction' and the weighted total
    'loss_generator'; 'loss_tsr' and 'loss_eye_gaze' are only present when
    their weight in the 'training.losses' section of CONFIG is above zero.
    """
    source_tensor, target_tensor, is_same_person = batch
    generator_losses = {}

    # adversarial loss
    # torch.Tensor(0) allocates an EMPTY tensor of shape (0,), not a scalar
    # zero, so the accumulators below start from a true 0-dim zero instead
    loss_adversarial = torch.zeros((), device = swap_tensor.device, dtype = swap_tensor.dtype)

    for discriminator_output in discriminator_outputs:
        loss_adversarial = loss_adversarial + hinge_loss(discriminator_output[0], True).mean(dim = [ 1, 2, 3 ])

    loss_adversarial = torch.mean(loss_adversarial)
    generator_losses['loss_adversarial'] = loss_adversarial
    generator_losses['loss_generator'] = loss_adversarial * CONFIG.getfloat('training.losses', 'weight_adversarial')

    # identity loss
    swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10))
    source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10))
    loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
    generator_losses['loss_identity'] = loss_identity
    generator_losses['loss_generator'] += loss_identity * CONFIG.getfloat('training.losses', 'weight_identity')

    # attribute loss
    loss_attribute = torch.zeros((), device = swap_tensor.device, dtype = swap_tensor.dtype)
    swap_attributes = self.generator.get_attributes(swap_tensor)

    for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
        loss_attribute = loss_attribute + torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()

    loss_attribute = loss_attribute * 0.5
    generator_losses['loss_attribute'] = loss_attribute
    generator_losses['loss_generator'] += loss_attribute * CONFIG.getfloat('training.losses', 'weight_attribute')

    # reconstruction loss - the pixel term only counts samples flagged by is_same_person
    loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4)
    loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
    loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
    generator_losses['loss_reconstruction'] = loss_reconstruction
    generator_losses['loss_generator'] += loss_reconstruction * CONFIG.getfloat('training.losses', 'weight_reconstruction')

    # tsr loss
    weight_tsr = CONFIG.getfloat('training.losses', 'weight_tsr')

    if weight_tsr > 0:
        swap_motion_features = self.get_motion_features(swap_tensor)
        target_motion_features = self.get_motion_features(target_tensor)
        loss_tsr = torch.zeros((), device = swap_tensor.device, dtype = swap_tensor.dtype)

        for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
            loss_tsr = loss_tsr + L2_loss(swap_motion_feature, target_motion_feature)

        generator_losses['loss_tsr'] = loss_tsr
        generator_losses['loss_generator'] += loss_tsr * weight_tsr

    # eye gaze loss
    weight_eye_gaze = CONFIG.getfloat('training.losses', 'weight_eye_gaze')

    if weight_eye_gaze > 0:
        swap_landmark_features = self.get_landmark_features(swap_tensor)
        target_landmark_features = self.get_landmark_features(target_tensor)
        # each landmark is compared with the SAME landmark of the target;
        # previously both terms compared swap index 0 with target index 1
        loss_left_eye_gaze = L2_loss(swap_landmark_features[0], target_landmark_features[0])
        loss_right_eye_gaze = L2_loss(swap_landmark_features[1], target_landmark_features[1])
        loss_eye_gaze = loss_left_eye_gaze + loss_right_eye_gaze
        generator_losses['loss_eye_gaze'] = loss_eye_gaze
        generator_losses['loss_generator'] += loss_eye_gaze * weight_eye_gaze

    return generator_losses
|
||||
|
||||
def calc_discriminator_loss(self, swap_tensor : Tensor, source_tensor : Tensor) -> LossDict:
    """
    Compute the discriminator loss for one batch.

    The swap tensor is detached before it is scored, so no gradient flows
    back into whatever produced it. Every discriminator output contributes
    the mean of hinge_loss computed on its first element (flag False for
    the swap, True for the source); the two sums are then averaged.

    Returns a dict with the scalar loss stored under 'loss_discriminator'.
    """
    discriminator_losses = {}
    # torch.Tensor(0) allocates an EMPTY tensor of shape (0,), not a scalar
    # zero: adding into it keeps it empty and mean() then returns NaN.
    # Start from a true 0-dim zero on the matching device and dtype instead.
    loss_fake = torch.zeros((), device = swap_tensor.device, dtype = swap_tensor.dtype)
    loss_true = torch.zeros((), device = source_tensor.device, dtype = source_tensor.dtype)

    for fake_discriminator_output in self.discriminator(swap_tensor.detach()):
        loss_fake = loss_fake + torch.mean(hinge_loss(fake_discriminator_output[0], False).mean(dim = [ 1, 2, 3 ]))

    for true_discriminator_output in self.discriminator(source_tensor):
        loss_true = loss_true + torch.mean(hinge_loss(true_discriminator_output[0], True).mean(dim = [ 1, 2, 3 ]))

    discriminator_losses['loss_discriminator'] = (loss_true + loss_fake) * 0.5
    return discriminator_losses
|
||||
|
||||
def get_id_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor:
|
||||
_, _, height, width = vision_tensor.shape
|
||||
crop_height = int(height * 0.0586)
|
||||
crop_width = int(width * 0.0586)
|
||||
crop_vision_tensor = vision_tensor[:, :, crop_height : height - crop_height, crop_width : width - crop_width]
|
||||
crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'bilinear')
|
||||
crop_vision_tensor[:, :, :padding[0], :] = 0
|
||||
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
|
||||
crop_vision_tensor[:, :, :, :padding[2]] = 0
|
||||
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
|
||||
embedding = self.arcface(crop_vision_tensor)
|
||||
embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1)
|
||||
return embedding
|
||||
|
||||
def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor]:
    """
    Return landmarks 198 and 197 (in that order) predicted by self.landmarker.

    The input is shifted and scaled with (x + 1) * 0.5 and resized to
    224x224 with bilinear interpolation. The element at index 2 of the
    landmarker output is reshaped to (-1, 203, 2) and multiplied by 256.
    """
    # presumably maps images from [-1, 1] to [0, 1] - confirm against the data pipeline
    landmarker_input = torch.nn.functional.interpolate((vision_tensor + 1) * 0.5, size = (224, 224), mode = 'bilinear')
    raw_landmarks = self.landmarker(landmarker_input)[2]
    scaled_landmarks = raw_landmarks.view(-1, 203, 2) * 256
    first_landmark = scaled_landmarks[:, 198]
    second_landmark = scaled_landmarks[:, 197]
    return first_landmark, second_landmark
|
||||
|
||||
def get_motion_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Return (translation, scale, rotation) as predicted by self.motion_extractor.

    The input is shifted and scaled with (x + 1) * 0.5 first. Translation and
    scale are read from the 't' and 'scale' entries of the extractor output;
    rotation is the 'pitch', 'yaw' and 'roll' entries concatenated on dim 1.
    """
    motion_dict = self.motion_extractor((vision_tensor + 1) * 0.5)
    rotation_parts = [ motion_dict.get(angle_name) for angle_name in ('pitch', 'yaw', 'roll') ]
    rotation = torch.cat(rotation_parts, dim = 1)
    return motion_dict.get('t'), motion_dict.get('scale'), rotation
|
||||
|
||||
def log_generator_preview(self, source_tensor : Tensor, target_tensor : Tensor, swap_tensor : Tensor) -> None:
    """
    Log a preview image of up to eight samples to the experiment logger.

    For every sample the source, target and swap items are concatenated on
    dim 2; the resulting rows are concatenated on dim 1 and passed through
    torchvision.utils.make_grid before being logged as "Generator Preview"
    at the current global step.
    """
    max_preview = 8
    preview_rows = []

    for source_item, target_item, swap_item in zip(source_tensor[:max_preview], target_tensor[:max_preview], swap_tensor[:max_preview]):
        preview_rows.append(torch.cat([ source_item, target_item, swap_item ], dim = 2))

    # a leading batch dim of one is added so make_grid receives a single image
    preview_batch = torch.cat(preview_rows, dim = 1).unsqueeze(0)
    grid = torchvision.utils.make_grid(preview_batch, nrow = 1, normalize = True, scale_each = True)
    self.logger.experiment.add_image("Generator Preview", grid, self.global_step)
|
||||
|
||||
def log_validation_preview(self) -> None:
    """
    Render the validation images through the generator and save a preview.

    Images are read from the directories configured in the
    'training.validation' section of CONFIG (sources, targets_front,
    targets_side, targets_makeup, targets_occlusion) and paired by sorted
    file name; pairing stops at the shortest directory. Each source is
    swapped onto its four targets and the result is written to
    validation_previews/step_<global_step>.jpg.

    The generator is switched to eval mode while rendering and is put back
    into train mode afterwards, even when rendering fails.
    """
    target_options = ('targets_front', 'targets_side', 'targets_makeup', 'targets_occlusion')
    transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((256, 256), interpolation = torchvision.transforms.InterpolationMode.BICUBIC),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    def read_images(directory_path):
        # only .jpg and .png files are used (case-insensitive), in sorted order
        return [ read_image(os.path.join(directory_path, file_name)) for file_name in sorted(os.listdir(directory_path)) if file_name.lower().endswith(('.jpg', '.png')) ]

    def to_tensor(image):
        # NOTE(review): .half() assumes the models run in float16 - confirm against the trainer precision
        return transforms(image).unsqueeze(0).to(self.device).half()

    def to_numpy(vision_tensor):
        # first batch item, channels moved last, clipped to [-1, 1],
        # channel order reversed, then rescaled to [0, 255]
        frame = vision_tensor.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)
        return (frame[:, :, ::-1] + 1) * 127.5

    sources = read_images(CONFIG.get('training.validation', 'sources'))
    target_sets = [ read_images(CONFIG.get('training.validation', target_option)) for target_option in target_options ]

    self.generator.eval()

    try:
        source_frames = []
        result_frames = [ [] for _ in target_sets ]

        for source, *targets in zip(sources, *target_sets):
            source_tensor = to_tensor(source)
            source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0))
            target_tensors = [ to_tensor(target) for target in targets ]
            output_tensors = []

            with torch.no_grad():
                for target_tensor in target_tensors:
                    output_tensor, _ = self.generator(target_tensor, source_embedding)
                    output_tensors.append(output_tensor)

            source_frames.append(to_numpy(source_tensor))

            for frames, target_tensor, output_tensor in zip(result_frames, target_tensors, output_tensors):
                frames.append(numpy.hstack([ to_numpy(target_tensor), to_numpy(output_tensor) ]))

        sources_vertical = numpy.vstack(source_frames)
        # ten pixel wide blank strip placed between the preview columns
        pad = numpy.zeros((sources_vertical.shape[0], 10, 3), dtype = sources_vertical.dtype)
        columns = [ sources_vertical ]

        for frames in result_frames:
            columns.append(pad)
            columns.append(numpy.vstack(frames))

        preview = numpy.hstack(columns)
        os.makedirs("validation_previews", exist_ok=True)
        cv2.imwrite(f"validation_previews/step_{self.global_step}.jpg", preview)
    finally:
        # restore train mode even when rendering raised, so a failed
        # preview cannot silently leave the generator in eval mode
        self.generator.train()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
@@ -9,8 +8,12 @@ Batch = Tuple[Any, Any, Any]
|
||||
# data loading
Loader = DataLoader[Tuple[Tensor, ...]]

# model inputs and outputs
TargetAttributes = Tuple[Tensor, ...]
DiscriminatorOutputs = List[List[Tensor]]
VisionTensor = Tensor
FaceLandmark203 = Tensor
Padding = Tuple[int, int, int, int]

# identity embeddings (IDEmbedding and IdEmbedding are both kept as aliases of Tensor)
IDEmbedding = Tensor
IdEmbedding = Tensor
SourceEmbedding = IdEmbedding

# numpy based types
Embedding = NDArray[Any]
VisionFrame = NDArray[Any]

# checkpoints
StateDict = OrderedDict[str, Any]

# losses
Loss = Tensor
LossDict = Dict[str, Tensor]
GeneratorLossSet = Dict[str, Loss]
DiscriminatorLossSet = Dict[str, Loss]
|
||||
|
||||
Reference in New Issue
Block a user