mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
ugly training code
This commit is contained in:
@@ -4,4 +4,4 @@ plugins = flake8-import-order
|
||||
application_import_names = arcface_converter
|
||||
import-order-style = pycharm
|
||||
per-file-ignores = preparing.py:E402
|
||||
exclude = LivePortrait
|
||||
exclude = face_swapper
|
||||
|
||||
+29
-19
@@ -1,55 +1,65 @@
|
||||
[preparing.dataset]
|
||||
dataset_path =
|
||||
dataset_path = /assets/VGGface2_None_norm_512_true_bygfpgan
|
||||
|
||||
[preparing.dataloader]
|
||||
same_person_probability = 0.2
|
||||
|
||||
[preparing.augmentation]
|
||||
expression_augmentation = false
|
||||
expression = false
|
||||
|
||||
[training.loader]
|
||||
batch_size = 6
|
||||
batch_size = 4
|
||||
num_workers = 8
|
||||
|
||||
[training.generator]
|
||||
num_blocks = 2
|
||||
id_channels = 512
|
||||
learning_rate = 0.0004
|
||||
|
||||
[training.discriminator]
|
||||
input_channels = 3
|
||||
num_filters = 64
|
||||
num_layers = 5
|
||||
num_discriminators = 3
|
||||
learning_rate = 0.0004
|
||||
disable = false
|
||||
|
||||
[auxiliary_models.paths]
|
||||
arcface_path =
|
||||
landmarker_path =
|
||||
motion_extractor_path = /home/hari/Documents/Github/Phantom/assets/pretrained_models/liveportrait_motion_extractor.pth
|
||||
feature_extractor_path =
|
||||
warping_netwrk_path =
|
||||
spade_generator_path =
|
||||
arcface_path = /assets/pretrained_models/arcface_w600k_r50.pt
|
||||
landmarker_path = /assets/pretrained_models/landmark_203.pt
|
||||
motion_extractor_path = /assets/pretrained_models/liveportrait_motion_extractor.pth
|
||||
feature_extractor_path = /assets/pretrained_models/liveportrait_feature_extractor.pth
|
||||
warping_network_path = /assets/pretrained_models/liveportrait_warping_model.pth
|
||||
spade_generator_path = /assets/pretrained_models/liveportrait_spade_generator.pth
|
||||
|
||||
[training.losses]
|
||||
weight_adversarial = 1
|
||||
weight_identity = 20
|
||||
weight_attribute = 10
|
||||
weight_reconstruction = 10
|
||||
weight_tsr = 0
|
||||
weight_expression = 0
|
||||
weight_tsr = 100
|
||||
weight_eye_gaze = 5
|
||||
weight_eye_open = 5
|
||||
weight_lip_open = 5
|
||||
|
||||
[training.optimizers]
|
||||
scheduler_step = 5000
|
||||
scheduler_gamma = 0.2
|
||||
generator_learning_rate = 0.0004
|
||||
discriminator_learning_rate = 0.0004
|
||||
[training.schedulers]
|
||||
step = 5000
|
||||
gamma = 0.2
|
||||
|
||||
[training.trainer]
|
||||
epochs = 50
|
||||
max_epochs = 50
|
||||
disable_discriminator = false
|
||||
|
||||
[training.output]
|
||||
directory_path =
|
||||
file_pattern =
|
||||
checkpoint_path = checkpoints/last.ckpt
|
||||
directory_path = checkpoints
|
||||
file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}'
|
||||
preview_frequency = 250
|
||||
validation_frequency = 1000
|
||||
|
||||
[training.validation]
|
||||
sources = assets/test/front/sources
|
||||
targets = assets/test/front/targets
|
||||
|
||||
[exporting]
|
||||
directory_path =
|
||||
|
||||
@@ -9,7 +9,7 @@ from PIL import Image
|
||||
from torch.utils.data import TensorDataset
|
||||
|
||||
from .augmentations import apply_random_motion_blur
|
||||
from .sub_typing import Batch
|
||||
from .typing import Batch
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -27,6 +27,7 @@ class DataLoaderVGG(TensorDataset):
|
||||
self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path))
|
||||
self.folder_paths = glob.glob('{}/*'.format(dataset_path))
|
||||
self.image_path_dict = {}
|
||||
self._current_index = 0
|
||||
|
||||
for folder_path in tqdm.tqdm(self.folder_paths):
|
||||
image_paths = glob.glob('{}/*'.format(folder_path))
|
||||
@@ -50,12 +51,12 @@ class DataLoaderVGG(TensorDataset):
|
||||
[
|
||||
transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
transforms.RandomHorizontalFlip(p = 0.5),
|
||||
transforms.RandomApply([ apply_random_motion_blur ], p = 0.3),
|
||||
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation = 0.2, hue = 0.1),
|
||||
transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
])
|
||||
|
||||
def __getitem__(self, item : int) -> Batch:
|
||||
@@ -80,3 +81,11 @@ class DataLoaderVGG(TensorDataset):
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.dataset_total
|
||||
|
||||
|
||||
def state_dict(self):
|
||||
return {'current_index': self._current_index}
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._current_index = state_dict['current_index']
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import List
|
||||
import numpy
|
||||
import torch.nn as nn
|
||||
|
||||
from .sub_typing import Tensor
|
||||
from .typing import Tensor, DiscriminatorOutputs
|
||||
|
||||
|
||||
class NLayerDiscriminator(nn.Module):
|
||||
@@ -49,7 +49,6 @@ class NLayerDiscriminator(nn.Module):
|
||||
return self.model(input_tensor)
|
||||
|
||||
|
||||
# input_channels=3, num_filters=64, num_layers=5, num_discriminators=3
|
||||
class MultiscaleDiscriminator(nn.Module):
|
||||
def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int):
|
||||
super(MultiscaleDiscriminator, self).__init__()
|
||||
@@ -61,18 +60,12 @@ class MultiscaleDiscriminator(nn.Module):
|
||||
setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
|
||||
self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
|
||||
|
||||
|
||||
def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
|
||||
return [ model_layers(input_tensor) ]
|
||||
|
||||
if self.return_intermediate_features:
|
||||
feature_maps = [ input_tensor ]
|
||||
|
||||
for layer in model_layers:
|
||||
feature_maps.append(layer(feature_maps[-1]))
|
||||
return feature_maps[1:]
|
||||
else:
|
||||
return [ model_layers(input_tensor) ]
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> List[Tensor]:
|
||||
def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
|
||||
discriminator_outputs = []
|
||||
downsampled_input = input_tensor
|
||||
|
||||
|
||||
@@ -1,55 +1,52 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .sub_typing import Tensor, UNetAttributes
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
|
||||
def __init__(self, id_channels : int, num_blocks : int) -> None:
|
||||
def __init__(self, id_channels=512, num_blocks=2):
|
||||
super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
|
||||
self.encoder = UNet()
|
||||
self.generator = AdaptiveAttentionalDenorm_Generator(id_channels, num_blocks)
|
||||
self.generator = AADGenerator(id_channels, num_blocks)
|
||||
|
||||
def forward(self, target : Tensor, source_embedding : Tensor) -> Tuple[Tensor, UNetAttributes]:
|
||||
def forward(self, target, source_embedding):
|
||||
target_attributes = self.get_attributes(target)
|
||||
swap = self.generator(target_attributes, source_embedding)
|
||||
return swap, target_attributes
|
||||
|
||||
def get_attributes(self, target : Tensor) -> UNetAttributes:
|
||||
def get_attributes(self, target):
|
||||
return self.encoder(target)
|
||||
|
||||
|
||||
class AdaptiveAttentionalDenorm_Generator(nn.Module):
|
||||
def __init__(self, id_channels : int, num_blocks : int) -> None:
|
||||
super(AdaptiveAttentionalDenorm_Generator, self).__init__()
|
||||
class AADGenerator(nn.Module):
|
||||
def __init__(self, id_channels=512, num_blocks=2):
|
||||
super(AADGenerator, self).__init__()
|
||||
self.upsample = Upsample(id_channels, 1024 * 4)
|
||||
self.block_1 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 1024, id_channels, num_blocks)
|
||||
self.block_2 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 2048, id_channels, num_blocks)
|
||||
self.block_3 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 1024, id_channels, num_blocks)
|
||||
self.block_4 = AdaptiveAttentionalDenorm_ResBlock(1024, 512, 512, id_channels, num_blocks)
|
||||
self.block_5 = AdaptiveAttentionalDenorm_ResBlock(512, 256, 256, id_channels, num_blocks)
|
||||
self.block_6 = AdaptiveAttentionalDenorm_ResBlock(256, 128, 128, id_channels, num_blocks)
|
||||
self.block_7 = AdaptiveAttentionalDenorm_ResBlock(128, 64, 64, id_channels, num_blocks)
|
||||
self.block_7 = AdaptiveAttentionalDenorm_ResBlock(64, 3, 64, id_channels, num_blocks)
|
||||
self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, id_channels, num_blocks)
|
||||
self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, id_channels, num_blocks)
|
||||
self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, id_channels, num_blocks)
|
||||
self.AADBlk4 = AAD_ResBlk(1024, 512, 512, id_channels, num_blocks)
|
||||
self.AADBlk5 = AAD_ResBlk(512, 256, 256, id_channels, num_blocks)
|
||||
self.AADBlk6 = AAD_ResBlk(256, 128, 128, id_channels, num_blocks)
|
||||
self.AADBlk7 = AAD_ResBlk(128, 64, 64, id_channels, num_blocks)
|
||||
self.AADBlk8 = AAD_ResBlk(64, 3, 64, id_channels, num_blocks)
|
||||
self.apply(initialize_weight)
|
||||
|
||||
def forward(self, target_attributes : UNetAttributes, source_embedding : Tensor) -> Tensor:
|
||||
def forward(self, target_attributes, source_embedding):
|
||||
feature_map = self.upsample(source_embedding)
|
||||
feature_map_1 = nn.functional.interpolate(self.block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
|
||||
feature_map_2 = nn.functional.interpolate(self.block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
|
||||
feature_map_3 = nn.functional.interpolate(self.block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
|
||||
feature_map_4 = nn.functional.interpolate(self.block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
|
||||
feature_map_5 = nn.functional.interpolate(self.block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
|
||||
feature_map_6 = nn.functional.interpolate(self.block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
|
||||
feature_map_7 = nn.functional.interpolate(self.block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
|
||||
output = self.block_7(feature_map_7, target_attributes[7], source_embedding)
|
||||
feature_map_1 = F.interpolate(self.AADBlk1(feature_map, target_attributes[0], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
|
||||
feature_map_2 = F.interpolate(self.AADBlk2(feature_map_1, target_attributes[1], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
|
||||
feature_map_3 = F.interpolate(self.AADBlk3(feature_map_2, target_attributes[2], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
|
||||
feature_map_4 = F.interpolate(self.AADBlk4(feature_map_3, target_attributes[3], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
|
||||
feature_map_5 = F.interpolate(self.AADBlk5(feature_map_4, target_attributes[4], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
|
||||
feature_map_6 = F.interpolate(self.AADBlk6(feature_map_5, target_attributes[5], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
|
||||
feature_map_7 = F.interpolate(self.AADBlk7(feature_map_6, target_attributes[6], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
|
||||
output = self.AADBlk8(feature_map_7, target_attributes[7], source_embedding)
|
||||
return torch.tanh(output)
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
super(UNet, self).__init__()
|
||||
self.downsampler_1 = Conv4x4(3, 32)
|
||||
self.downsampler_2 = Conv4x4(32, 64)
|
||||
@@ -57,7 +54,9 @@ class UNet(nn.Module):
|
||||
self.downsampler_4 = Conv4x4(128, 256)
|
||||
self.downsampler_5 = Conv4x4(256, 512)
|
||||
self.downsampler_6 = Conv4x4(512, 1024)
|
||||
|
||||
self.bottleneck = Conv4x4(1024, 1024)
|
||||
|
||||
self.upsampler_1 = DeConv4x4(1024, 1024)
|
||||
self.upsampler_2 = DeConv4x4(2048, 512)
|
||||
self.upsampler_3 = DeConv4x4(1024, 256)
|
||||
@@ -66,53 +65,64 @@ class UNet(nn.Module):
|
||||
self.upsampler_6 = DeConv4x4(128, 32)
|
||||
self.apply(initialize_weight)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> UNetAttributes:
|
||||
def forward(self, input_tensor):
|
||||
downsample_feature_1 = self.downsampler_1(input_tensor)
|
||||
downsample_feature_2 = self.downsampler_2(downsample_feature_1)
|
||||
downsample_feature_3 = self.downsampler_3(downsample_feature_2)
|
||||
downsample_feature_4 = self.downsampler_4(downsample_feature_3)
|
||||
downsample_feature_5 = self.downsampler_5(downsample_feature_4)
|
||||
downsample_feature_6 = self.downsampler_6(downsample_feature_5)
|
||||
|
||||
bottleneck_output = self.bottleneck(downsample_feature_6)
|
||||
|
||||
upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
|
||||
upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
|
||||
upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
|
||||
upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
|
||||
upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
|
||||
upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
|
||||
output = nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
|
||||
output = F.interpolate(upsample_feature_6, scale_factor=2, mode='bilinear', align_corners=False)
|
||||
|
||||
return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
|
||||
|
||||
|
||||
class AdaptiveAttentionalDenorm_Layer(nn.Module):
|
||||
def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None:
|
||||
super(AdaptiveAttentionalDenorm_Layer, self).__init__()
|
||||
class AADLayer(nn.Module):
|
||||
def __init__(self, input_channels, attr_channels, id_channels):
|
||||
super(AADLayer, self).__init__()
|
||||
self.attr_channels = attr_channels
|
||||
self.id_channels = id_channels
|
||||
self.input_channels = input_channels
|
||||
self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
|
||||
self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
|
||||
|
||||
self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
self.fc_gamma = nn.Linear(id_channels, input_channels)
|
||||
self.fc_beta = nn.Linear(id_channels, input_channels)
|
||||
self.instance_norm = nn.InstanceNorm2d(input_channels, affine=False)
|
||||
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = True)
|
||||
|
||||
def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:
|
||||
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
|
||||
def forward(self, feature_map, attr_embedding, id_embedding):
|
||||
feature_map = self.instance_norm(feature_map)
|
||||
|
||||
attr_gamma = self.conv_gamma(attr_embedding)
|
||||
attr_beta = self.conv_beta(attr_embedding)
|
||||
attr_modulation = attr_gamma * feature_map + attr_beta
|
||||
id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
|
||||
id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
|
||||
|
||||
id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(
|
||||
feature_map)
|
||||
id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(
|
||||
feature_map)
|
||||
id_modulation = id_gamma * feature_map + id_beta
|
||||
|
||||
feature_mask = torch.sigmoid(self.conv_mask(feature_map))
|
||||
feature_blend = (1 - feature_mask) * attr_modulation + feature_mask * id_modulation
|
||||
return feature_blend
|
||||
|
||||
|
||||
class AddBlocksSequential(nn.Sequential):
|
||||
def forward(self, *inputs : Tuple[Tensor, ...]) -> Tensor:
|
||||
feature_map, attr_embedding, id_embedding = inputs
|
||||
def forward(self, *inputs):
|
||||
h, attr_embedding, id_embedding = inputs
|
||||
|
||||
for index, module in enumerate(self._modules.values()):
|
||||
if index % 3 == 0 and index > 0:
|
||||
@@ -124,9 +134,9 @@ class AddBlocksSequential(nn.Sequential):
|
||||
return inputs
|
||||
|
||||
|
||||
class AdaptiveAttentionalDenorm_ResBlock(nn.Module):
|
||||
def __init__(self, in_channels : int, out_channels : int, attr_channels : int, id_channels : int, num_blocks : int) -> None:
|
||||
super(AdaptiveAttentionalDenorm_ResBlock, self).__init__()
|
||||
class AAD_ResBlk(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, attr_channels, id_channels, num_blocks):
|
||||
super(AAD_ResBlk, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
primary_add_blocks = []
|
||||
@@ -135,22 +145,22 @@ class AdaptiveAttentionalDenorm_ResBlock(nn.Module):
|
||||
intermediate_channels = in_channels if i < (num_blocks - 1) else out_channels
|
||||
primary_add_blocks.extend(
|
||||
[
|
||||
AdaptiveAttentionalDenorm_Layer(in_channels, attr_channels, id_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(in_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
|
||||
AADLayer(in_channels, attr_channels, id_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_channels, intermediate_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
])
|
||||
self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks)
|
||||
|
||||
if in_channels != out_channels:
|
||||
auxiliary_add_blocks = \
|
||||
[
|
||||
AdaptiveAttentionalDenorm_Layer(in_channels, attr_channels, id_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
|
||||
AADLayer(in_channels, attr_channels, id_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
]
|
||||
self.auxiliary_add_blocks = AddBlocksSequential(*auxiliary_add_blocks)
|
||||
|
||||
def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:
|
||||
def forward(self, feature_map, attr_embedding, id_embedding):
|
||||
primary_feature = self.primary_add_blocks(feature_map, attr_embedding, id_embedding)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
@@ -160,47 +170,49 @@ class AdaptiveAttentionalDenorm_ResBlock(nn.Module):
|
||||
|
||||
|
||||
class Conv4x4(nn.Module):
|
||||
def __init__(self, in_channels : int, out_channels : int) -> None:
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(Conv4x4, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels=in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1,
|
||||
bias=False)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace=True)
|
||||
|
||||
def forward(self, input : Tensor) -> Tensor:
|
||||
output = self.conv(input)
|
||||
output = self.batch_norm(output)
|
||||
output = self.leaky_relu(output)
|
||||
return output
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.batch_norm(x)
|
||||
x = self.leaky_relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class DeConv4x4(nn.Module):
|
||||
def __init__(self, in_channels : int, out_channels : int) -> None:
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(DeConv4x4, self).__init__()
|
||||
self.deconv = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.deconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2,
|
||||
padding=1, bias=False)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace=True)
|
||||
|
||||
def forward(self, input : Tensor, skip_connection : Tensor) -> Tensor:
|
||||
output = self.deconv(input)
|
||||
output = self.batch_norm(output)
|
||||
output = self.leaky_relu(output)
|
||||
output = torch.cat((output, skip_connection), dim = 1)
|
||||
return output
|
||||
def forward(self, x, skip):
|
||||
x = self.deconv(x)
|
||||
x = self.batch_norm(x)
|
||||
x = self.leaky_relu(x)
|
||||
return torch.cat((x, skip), dim=1)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels : int, out_channels : int):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(Upsample, self).__init__()
|
||||
self.initial_conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1)
|
||||
self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
|
||||
self.initial_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1,
|
||||
padding=1)
|
||||
self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
|
||||
|
||||
def forward(self, input : Tensor) -> Tensor:
|
||||
output = self.initial_conv(input.view(input.shape[0], -1, 1, 1))
|
||||
output = self.pixel_shuffle(output)
|
||||
return output
|
||||
def forward(self, x):
|
||||
x = self.initial_conv(x.view(x.shape[0], -1, 1, 1))
|
||||
x = self.pixel_shuffle(x)
|
||||
return x
|
||||
|
||||
|
||||
def initialize_weight(module : nn.Module) -> None:
|
||||
def initialize_weight(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(0, 0.001)
|
||||
module.bias.data.zero_()
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
import configparser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from .typing import Tensor
|
||||
import numpy
|
||||
import torch.nn.functional as F
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
if CONFIG.getboolean('preparing.augmentation', 'expression'):
|
||||
from LivePortrait.src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
|
||||
|
||||
L2_loss = torch.nn.MSELoss()
|
||||
EXPRESSION_MIN = numpy.array(
|
||||
[
|
||||
[
|
||||
[-2.88067125e-02, -8.12731311e-02, -1.70541159e-03],
|
||||
[-4.88598682e-02, -3.32196616e-02, -1.67431499e-04],
|
||||
[-6.75425082e-02, -4.28681746e-02, -1.98950816e-04],
|
||||
[-7.23103955e-02, -3.28503326e-02, -7.31324719e-04],
|
||||
[-3.87073644e-02, -6.01546466e-02, -5.50269964e-04],
|
||||
[-6.38048723e-02, -2.23840728e-01, -7.13261834e-04],
|
||||
[-3.02710701e-02, -3.93195450e-02, -8.24086510e-06],
|
||||
[-2.95799859e-02, -5.39318882e-02, -1.74219604e-04],
|
||||
[-2.92359516e-02, -1.53050944e-02, -6.30460854e-05],
|
||||
[-5.56493877e-03, -2.34344602e-02, -1.26858242e-04],
|
||||
[-4.37593013e-02, -2.77768299e-02, -2.70503685e-02],
|
||||
[-1.76926646e-02, -1.91676542e-02, -1.15090821e-04],
|
||||
[-8.34268332e-03, -3.99775570e-03, -3.27481248e-05],
|
||||
[-3.40162888e-02, -2.81868968e-02, -1.96679524e-04],
|
||||
[-2.91855410e-02, -3.97511162e-02, -2.81230678e-05],
|
||||
[-1.50395725e-02, -2.49494594e-02, -9.42573533e-05],
|
||||
[-1.67938769e-02, -2.00953931e-02, -4.00750607e-04],
|
||||
[-1.86435618e-02, -2.48535164e-02, -2.74416432e-02],
|
||||
[-4.61211195e-03, -1.21660791e-02, -2.93173041e-04],
|
||||
[-4.10017073e-02, -7.43824020e-02, -4.42762971e-02],
|
||||
[-1.90370996e-02, -3.74363363e-02, -1.34740388e-02]
|
||||
]
|
||||
]).astype(numpy.float32)
|
||||
EXPRESSION_MAX = numpy.array(
|
||||
[
|
||||
[
|
||||
[4.46682945e-02, 7.08772913e-02, 4.08344204e-04],
|
||||
[2.14308221e-02, 6.15894832e-02, 4.85319615e-05],
|
||||
[3.02363783e-02, 4.45043296e-02, 1.28298725e-05],
|
||||
[3.05869691e-02, 3.79812494e-02, 6.57040102e-04],
|
||||
[4.45670523e-02, 3.97259220e-02, 7.10966764e-04],
|
||||
[9.43699256e-02, 9.85926315e-02, 2.02551950e-04],
|
||||
[1.61131397e-02, 2.92906128e-02, 3.44733417e-06],
|
||||
[5.23825921e-02, 1.07065082e-01, 6.61510974e-04],
|
||||
[2.85718683e-03, 8.32320191e-03, 2.39314613e-04],
|
||||
[2.57947259e-02, 1.60935968e-02, 2.41853559e-05],
|
||||
[4.90833223e-02, 3.43903080e-02, 3.22353356e-02],
|
||||
[1.44766076e-02, 3.39248963e-02, 1.42291479e-04],
|
||||
[8.75749043e-04, 6.82212645e-03, 2.76097053e-05],
|
||||
[1.86958015e-02, 3.84016186e-02, 7.33085908e-05],
|
||||
[2.01714113e-02, 4.90544215e-02, 2.34028921e-05],
|
||||
[2.46518422e-02, 3.29151377e-02, 3.48571630e-05],
|
||||
[2.22457591e-02, 1.21796541e-02, 1.56396593e-04],
|
||||
[1.72109623e-02, 3.01626958e-02, 1.36556877e-02],
|
||||
[1.83460284e-02, 1.61141958e-02, 2.87440169e-04],
|
||||
[3.57594155e-02, 1.80554688e-01, 2.75554154e-02],
|
||||
[2.17450950e-02, 8.66811201e-02, 3.34241726e-02]
|
||||
]
|
||||
]).astype(numpy.float32)
|
||||
|
||||
|
||||
def randomize_expression(face_tensor, feature_extractor, motion_extractor, warping_network, spade_generator):
|
||||
with torch.no_grad():
|
||||
face_tensor_norm = (face_tensor + 1) * 0.5
|
||||
input_device = face_tensor.device
|
||||
feature_volume = feature_extractor(face_tensor_norm)
|
||||
motion_extractor_dict = motion_extractor(face_tensor_norm)
|
||||
|
||||
translation = motion_extractor_dict.get('t')
|
||||
expression = motion_extractor_dict.get('exp')
|
||||
scale = motion_extractor_dict.get('scale')
|
||||
points = motion_extractor_dict.get('kp')
|
||||
|
||||
pitch = headpose_pred_to_degree(motion_extractor_dict.get('pitch'))[:, None]
|
||||
yaw = headpose_pred_to_degree(motion_extractor_dict.get('yaw'))[:, None]
|
||||
roll = headpose_pred_to_degree(motion_extractor_dict.get('roll'))[:, None]
|
||||
rotation_matrix = get_rotation_matrix(pitch, yaw, roll)
|
||||
random_expression = get_random_expression_blend(expression)
|
||||
|
||||
points_transformed = transform_points(points, rotation_matrix, expression, scale, translation)
|
||||
points_driv = transform_points(points, rotation_matrix, random_expression, scale, translation)
|
||||
|
||||
data = warping_network(feature_volume, points_driv, points_transformed).get('out')
|
||||
output = spade_generator(data)
|
||||
output = output.to(input_device)
|
||||
output = F.interpolate(output.clamp(0, 1), [256, 256], mode='bilinear', align_corners=False)
|
||||
output = (output - 0.5) * 2
|
||||
return output
|
||||
|
||||
|
||||
def get_random_expression_blend(expression : Tensor) -> Tensor:
|
||||
blend = 0.35
|
||||
expression = expression.view(-1, 21, 3)
|
||||
min_array = torch.from_numpy(EXPRESSION_MIN).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1)
|
||||
max_array = torch.from_numpy(EXPRESSION_MAX).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1)
|
||||
random_batch = torch.rand_like(min_array).to(expression.device) * (max_array - min_array) + min_array
|
||||
random_batch[:, [0, 1, 8, 6, 9, 4, 5, 10]] = expression[:, [0, 1, 8, 6, 9, 4, 5, 10]]
|
||||
random_batch[:, [3, 7]] = random_batch[:, [13, 16]] * 0.1 + expression[:, [13, 16]] * 0.9
|
||||
random_batch[:, [3, 7]] = random_batch[:, [3, 7]] * 0.5 + expression[:, [3, 7]] * 0.5
|
||||
return random_batch * 0.8 * blend + expression * (1 - blend)
|
||||
|
||||
|
||||
def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor):
|
||||
points_transformed = points.view(-1, 21, 3) @ rotation_matrix + expression.view(-1, 21, 3)
|
||||
points_transformed *= scale[..., None]
|
||||
points_transformed[:, :, 0:2] += translation[:, None, 0:2]
|
||||
return points_transformed
|
||||
|
||||
|
||||
def hinge_loss(tensor : Tensor, is_positive : bool) -> Tensor:
|
||||
if is_positive:
|
||||
return torch.relu(1 - tensor)
|
||||
else:
|
||||
return torch.relu(tensor + 1)
|
||||
|
||||
|
||||
def calc_distance_ratio(landmarks : Tensor, indices : Tuple[int, int, int, int]) -> Tensor:
|
||||
distance_horizontal = torch.norm(landmarks[:, indices[0]] - landmarks[:, indices[1]], p = 2, dim = 1, keepdim = True)
|
||||
distance_vertical = torch.norm(landmarks[:, indices[2]] - landmarks[:, indices[3]], p=2, dim = 1, keepdim = True)
|
||||
return distance_horizontal / (distance_vertical + 1e-4)
|
||||
@@ -1,50 +0,0 @@
|
||||
import configparser
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .discriminator import MultiscaleDiscriminator
|
||||
from .generator import AdaptiveEmbeddingIntegrationNetwork
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def load_generator() -> nn.Module:
|
||||
id_channels = CONFIG.getint('training.generator', 'id_channels')
|
||||
num_blocks = CONFIG.getint('training.generator', 'num_blocks')
|
||||
generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
|
||||
return generator
|
||||
|
||||
|
||||
def load_discriminator() -> nn.Module:
|
||||
input_channels = CONFIG.getint('training.discriminator', 'input_channels')
|
||||
num_filters = CONFIG.getint('training.discriminator', 'num_filters')
|
||||
num_layers = CONFIG.getint('training.discriminator', 'num_layers')
|
||||
num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
|
||||
discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
|
||||
return discriminator
|
||||
|
||||
|
||||
def load_arcface() -> nn.Module:
|
||||
model_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
|
||||
arcface = torch.load(model_path, map_location = 'cpu', weights_only = False)
|
||||
arcface.eval()
|
||||
return arcface
|
||||
|
||||
|
||||
def load_landmarker() -> nn.Module:
|
||||
model_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
|
||||
landmarker = torch.load(model_path, map_location = 'cpu', weights_only = False)
|
||||
landmarker.eval()
|
||||
return landmarker
|
||||
|
||||
|
||||
def load_motion_extractor() -> nn.Module:
|
||||
from LivePortrait.src.modules.motion_extractor import MotionExtractor
|
||||
|
||||
model_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
|
||||
motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
|
||||
motion_extractor.load_state_dict(torch.load(model_path, map_location = 'cpu', weights_only = True))
|
||||
motion_extractor.eval()
|
||||
return motion_extractor
|
||||
@@ -1,5 +1,389 @@
|
||||
from .model_loader import load_motion_extractor
|
||||
import configparser
|
||||
import random
|
||||
|
||||
from sympy.stats.sampling.sample_numpy import numpy
|
||||
|
||||
from typing import Tuple
|
||||
import os
|
||||
import cv2
|
||||
import torchvision
|
||||
|
||||
import pytorch_lightning
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from torch.utils.data import DataLoader
|
||||
from pytorch_lightning.utilities.types import OptimizerLRScheduler
|
||||
import torch
|
||||
|
||||
from .discriminator import MultiscaleDiscriminator
|
||||
from .generator import AdaptiveEmbeddingIntegrationNetwork
|
||||
from .data_loader import DataLoaderVGG, read_image
|
||||
|
||||
from .typing import Tensor, LossDict, TargetAttributes, DiscriminatorOutputs, Batch
|
||||
from .helper import hinge_loss, calc_distance_ratio, L2_loss, randomize_expression
|
||||
from pytorch_msssim import ssim
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def load_models():
|
||||
id_channels = CONFIG.getint('training.generator', 'id_channels')
|
||||
num_blocks = CONFIG.getint('training.generator', 'num_blocks')
|
||||
generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
|
||||
|
||||
input_channels = CONFIG.getint('training.discriminator', 'input_channels')
|
||||
num_filters = CONFIG.getint('training.discriminator', 'num_filters')
|
||||
num_layers = CONFIG.getint('training.discriminator', 'num_layers')
|
||||
num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
|
||||
discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
|
||||
|
||||
model_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
|
||||
arcface = torch.load(model_path, map_location = 'cpu', weights_only = False)
|
||||
arcface.eval()
|
||||
|
||||
if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0:
|
||||
model_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
|
||||
landmarker = torch.load(model_path, map_location = 'cpu', weights_only = False)
|
||||
landmarker.eval()
|
||||
else:
|
||||
landmarker = None
|
||||
|
||||
if CONFIG.getfloat('training.losses', 'weight_tsr') > 0 or CONFIG.getboolean('preparing.augmentation', 'expression'):
|
||||
from LivePortrait.src.modules.motion_extractor import MotionExtractor
|
||||
|
||||
model_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
|
||||
motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
|
||||
motion_extractor.load_state_dict(torch.load(model_path, map_location = 'cpu', weights_only = True))
|
||||
motion_extractor.eval()
|
||||
else:
|
||||
motion_extractor = None
|
||||
|
||||
if CONFIG.getboolean('preparing.augmentation', 'expression'):
|
||||
from LivePortrait.src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
|
||||
from LivePortrait.src.modules.warping_network import WarpingNetwork
|
||||
from LivePortrait.src.modules.spade_generator import SPADEDecoder
|
||||
|
||||
feature_extractor_path = CONFIG.get('auxiliary_models.paths', 'feature_extractor_path')
|
||||
feature_extractor = AppearanceFeatureExtractor(3, 64, 2, 512, 32, 16, 6)
|
||||
feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location = 'cpu', weights_only = True))
|
||||
feature_extractor.eval()
|
||||
|
||||
warping_network_path = CONFIG.get('auxiliary_models.paths', 'warping_network_path')
|
||||
dense_motion_params = {
|
||||
'block_expansion': 32,
|
||||
'max_features': 1024,
|
||||
'num_blocks': 5,
|
||||
'reshape_depth': 16,
|
||||
'compress': 4
|
||||
}
|
||||
warping_network = WarpingNetwork(num_kp = 21, block_expansion = 64, max_features = 512, num_down_blocks = 2, reshape_channel = 32, estimate_occlusion_map = True, dense_motion_params = dense_motion_params)
|
||||
warping_network.load_state_dict(torch.load(warping_network_path, map_location='cpu', weights_only=True))
|
||||
warping_network.eval()
|
||||
|
||||
spade_generator_path = CONFIG.get('auxiliary_models.paths', 'spade_generator_path')
|
||||
spade_generator = SPADEDecoder(upscale = 2, block_expansion = 64, max_features = 512, num_down_blocks = 2)
|
||||
spade_generator.load_state_dict(torch.load(spade_generator_path, map_location = 'cpu', weights_only = True))
|
||||
spade_generator.eval()
|
||||
else:
|
||||
feature_extractor = None
|
||||
warping_network = None
|
||||
spade_generator = None
|
||||
return generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
|
||||
output_directory_path = CONFIG.get('training.output', 'directory_path')
|
||||
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
|
||||
os.makedirs(output_directory_path, exist_ok = True)
|
||||
|
||||
return Trainer(
|
||||
max_epochs = trainer_max_epochs,
|
||||
precision = '16-mixed',
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
monitor = 'l_G',
|
||||
dirpath = output_directory_path,
|
||||
filename = output_file_pattern,
|
||||
# every_n_epochs = 1,
|
||||
every_n_train_steps = 1000,
|
||||
save_top_k = 5,
|
||||
mode = 'min',
|
||||
save_last = True
|
||||
)
|
||||
],
|
||||
log_every_n_steps = 10,
|
||||
accumulate_grad_batches = 1,
|
||||
)
|
||||
|
||||
|
||||
def train():
|
||||
return print(load_motion_extractor())
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
checkpoint_path = CONFIG.get('training.output', 'checkpoint_path')
|
||||
dataset = DataLoaderVGG(CONFIG.get('preparing.dataset', 'dataset_path'))
|
||||
|
||||
if not (checkpoint_path and os.path.exists(checkpoint_path)):
|
||||
checkpoint_path = None
|
||||
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
face_swap_model = FaceSwapper(*load_models())
|
||||
trainer = create_trainer()
|
||||
trainer.fit(face_swap_model, data_loader, ckpt_path = checkpoint_path)
|
||||
|
||||
|
||||
class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
def __init__(self, generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator) -> None:
|
||||
super().__init__()
|
||||
self.generator = generator
|
||||
self.discriminator = discriminator
|
||||
self.arcface = arcface
|
||||
self.landmarker = landmarker
|
||||
self.motion_extractor = motion_extractor
|
||||
self.feature_extractor = feature_extractor
|
||||
self.warping_network = warping_network
|
||||
self.spade_generator = spade_generator
|
||||
|
||||
self.loss_adversarial_accumulated = 20
|
||||
self.automatic_optimization = False
|
||||
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
|
||||
|
||||
def forward(self, target_tensor : Tensor, source_embedding : Tensor) -> Tensor:
|
||||
output = self.generator(target_tensor, source_embedding)
|
||||
return output
|
||||
|
||||
|
||||
def state_dict(self, *args, **kwargs):
|
||||
return {
|
||||
"generator": self.generator.state_dict(),
|
||||
"discriminator": self.discriminator.state_dict(),
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict, strict: bool = True):
|
||||
if "generator" in state_dict:
|
||||
self.generator.load_state_dict(state_dict["generator"], strict = strict)
|
||||
if "discriminator" in state_dict:
|
||||
self.discriminator.load_state_dict(state_dict["discriminator"], strict = strict)
|
||||
|
||||
|
||||
def configure_optimizers(self) -> OptimizerLRScheduler:
|
||||
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = CONFIG.getfloat('training.discriminator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
|
||||
return (
|
||||
{
|
||||
"optimizer": generator_optimizer,
|
||||
"lr_scheduler": generator_scheduler
|
||||
},
|
||||
{
|
||||
"optimizer": discriminator_optimizer,
|
||||
"lr_scheduler": discriminator_scheduler
|
||||
})
|
||||
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor, is_same_person = batch
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers()
|
||||
source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
|
||||
|
||||
if random.random() > 0.5 and CONFIG.getboolean('preparing.augmentation', 'expression'):
|
||||
target_tensor = randomize_expression(target_tensor, self.feature_extractor, self.motion_extractor, self.warping_network, self.spade_generator)
|
||||
|
||||
swap_tensor, target_attributes = self(target_tensor, source_embedding)
|
||||
discriminator_outputs = self.discriminator(swap_tensor)
|
||||
|
||||
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch)
|
||||
generator_optimizer.zero_grad()
|
||||
self.manual_backward(generator_losses.get('loss_generator'))
|
||||
generator_optimizer.step()
|
||||
|
||||
discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor)
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.manual_backward(discriminator_losses.get('loss_discriminator'))
|
||||
|
||||
if not CONFIG.getboolean('training.discriminator', 'disable') or self.loss_adversarial_accumulated < 0.4:
|
||||
discriminator_optimizer.step()
|
||||
|
||||
if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0:
|
||||
self.log_generator_preview(source_tensor, target_tensor, swap_tensor)
|
||||
|
||||
if self.global_step % CONFIG.getint('training.output', 'validation_frequency') == 0:
|
||||
self.log_validation_preview()
|
||||
self.log('l_G', generator_losses.get('loss_generator'), prog_bar = True)
|
||||
self.log('l_D', discriminator_losses.get('loss_discriminator'), prog_bar = True)
|
||||
self.log('l_ADV_A', self.loss_adversarial_accumulated, prog_bar = True)
|
||||
self.log('l_ADV', generator_losses.get('loss_adversarial'), prog_bar = False)
|
||||
self.log('l_id', generator_losses.get('loss_identity'), prog_bar = True)
|
||||
self.log('l_attr', generator_losses.get('loss_attribute'), prog_bar = True)
|
||||
self.log('l_rec', generator_losses.get('loss_reconstruction'), prog_bar = True)
|
||||
return generator_losses.get('loss_generator')
|
||||
|
||||
|
||||
def calc_generator_loss(self, swap_tensor : Tensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> LossDict:
|
||||
source_tensor, target_tensor, is_same_person = batch
|
||||
generator_losses = {}
|
||||
# adversarial loss
|
||||
loss_adversarial = 0
|
||||
|
||||
for discriminator_output in discriminator_outputs:
|
||||
loss_adversarial += hinge_loss(discriminator_output[0], True).mean(dim = [ 1, 2, 3 ])
|
||||
loss_adversarial = torch.mean(loss_adversarial)
|
||||
generator_losses['loss_adversarial'] = loss_adversarial
|
||||
generator_losses['loss_generator'] = loss_adversarial * CONFIG.getfloat('training.losses', 'weight_adversarial')
|
||||
self.loss_adversarial_accumulated = self.loss_adversarial_accumulated * 0.98 + loss_adversarial.item() * 0.02
|
||||
|
||||
# identity loss
|
||||
swap_embedding = self.get_arcface_embedding(swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = self.get_arcface_embedding(source_tensor, (30, 0, 10, 10))
|
||||
loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
|
||||
generator_losses['loss_identity'] = loss_identity
|
||||
generator_losses['loss_generator'] += loss_identity * CONFIG.getfloat('training.losses', 'weight_identity')
|
||||
|
||||
# attribute loss
|
||||
loss_attribute = 0
|
||||
swap_attributes = self.generator.get_attributes(swap_tensor)
|
||||
|
||||
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
|
||||
loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
|
||||
loss_attribute *= 0.5
|
||||
generator_losses['loss_attribute'] = loss_attribute
|
||||
generator_losses['loss_generator'] += loss_attribute * CONFIG.getfloat('training.losses', 'weight_attribute')
|
||||
|
||||
# reconstruction loss
|
||||
loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4)
|
||||
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
|
||||
loss_reconstruction = loss_reconstruction * 0.3 + loss_ssim * 0.7
|
||||
generator_losses['loss_reconstruction'] = loss_reconstruction
|
||||
generator_losses['loss_generator'] += CONFIG.getfloat('training.losses', 'weight_reconstruction')
|
||||
|
||||
if CONFIG.getfloat('training.losses', 'weight_tsr') > 0:
|
||||
# tsr loss
|
||||
swap_motion_features = self.get_motion_features(swap_tensor)
|
||||
target_motion_features = self.get_motion_features(target_tensor)
|
||||
loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
|
||||
|
||||
for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
|
||||
loss_tsr += L2_loss(swap_motion_feature, target_motion_feature)
|
||||
generator_losses['loss_tsr'] = loss_tsr
|
||||
generator_losses['loss_generator'] += loss_tsr * CONFIG.getfloat('training.losses', 'weight_tsr')
|
||||
|
||||
|
||||
if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0:
|
||||
swap_landmark_features = self.get_landmark_features(swap_tensor)
|
||||
target_landmark_features = self.get_landmark_features(target_tensor)
|
||||
|
||||
# eye gaze loss
|
||||
loss_left_eye_gaze = L2_loss(swap_landmark_features[3], target_landmark_features[3])
|
||||
loss_right_eye_gaze = L2_loss(swap_landmark_features[4], target_landmark_features[4])
|
||||
loss_eye_gaze = loss_left_eye_gaze + loss_right_eye_gaze
|
||||
generator_losses['loss_eye_gaze'] = loss_eye_gaze
|
||||
generator_losses['loss_generator'] += loss_eye_gaze * CONFIG.getfloat('training.losses', 'weight_eye_gaze')
|
||||
|
||||
# eye open loss
|
||||
loss_left_eye_open = L2_loss(swap_landmark_features[0], target_landmark_features[0])
|
||||
loss_right_eye_open = L2_loss(swap_landmark_features[1], target_landmark_features[1])
|
||||
loss_eye_open = loss_left_eye_open + loss_right_eye_open
|
||||
generator_losses['loss_eye_open'] = loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open')
|
||||
generator_losses['loss_generator'] += loss_eye_open
|
||||
|
||||
# lip open loss
|
||||
loss_lip_open = L2_loss(swap_landmark_features[2], target_landmark_features[2])
|
||||
generator_losses['loss_lip_open'] = loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open')
|
||||
generator_losses['loss_generator'] += loss_lip_open
|
||||
return generator_losses
|
||||
|
||||
|
||||
def calc_discriminator_loss(self, swap_tensor : Tensor, source_tensor : Tensor) -> LossDict:
|
||||
discriminator_losses = {}
|
||||
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
|
||||
loss_fake = 0
|
||||
|
||||
for fake_discriminator_output in fake_discriminator_outputs:
|
||||
loss_fake += torch.mean(hinge_loss(fake_discriminator_output[0], False).mean(dim=[1, 2, 3]))
|
||||
true_discriminator_outputs = self.discriminator(source_tensor)
|
||||
loss_true = 0
|
||||
|
||||
for true_discriminator_output in true_discriminator_outputs:
|
||||
loss_true += torch.mean(hinge_loss(true_discriminator_output[0], True).mean(dim=[1, 2, 3]))
|
||||
discriminator_losses['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
|
||||
return discriminator_losses
|
||||
|
||||
|
||||
def get_arcface_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor:
|
||||
_, _, height, width = vision_tensor.shape
|
||||
crop_height = int(height * 0.0586)
|
||||
crop_width = int(width * 0.0586)
|
||||
crop_vision_tensor = vision_tensor[:, :, crop_height : height - crop_height, crop_width : width - crop_width]
|
||||
crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'bilinear')
|
||||
crop_vision_tensor[:, :, :padding[0], :] = 0
|
||||
crop_vision_tensor[:, :, -padding[1]:, :] = 0
|
||||
crop_vision_tensor[:, :, :, :padding[2]] = 0
|
||||
crop_vision_tensor[:, :, :, -padding[3]:] = 0
|
||||
embedding = self.arcface(crop_vision_tensor)
|
||||
embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1)
|
||||
return embedding
|
||||
|
||||
|
||||
def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
vision_tensor_norm = (vision_tensor + 1) * 0.5
|
||||
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
|
||||
landmarks = self.landmarker(vision_tensor_norm)[2]
|
||||
landmarks = landmarks.view(-1, 203, 2) * 256
|
||||
left_eye_open_ratio = calc_distance_ratio(landmarks, (6, 18, 0, 12))
|
||||
right_eye_open_ratio = calc_distance_ratio(landmarks, (30, 42, 24, 36))
|
||||
lip_open_ratio = calc_distance_ratio(landmarks, (90, 102, 48, 66))
|
||||
left_eye_gaze = landmarks[:, 198]
|
||||
right_eye_gaze = landmarks[:, 197]
|
||||
return left_eye_open_ratio, right_eye_open_ratio, lip_open_ratio, left_eye_gaze, right_eye_gaze
|
||||
|
||||
|
||||
def get_motion_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
vision_tensor_norm = (vision_tensor + 1) * 0.5
|
||||
motion_dict = self.motion_extractor(vision_tensor_norm)
|
||||
translation = motion_dict.get('t')
|
||||
scale = motion_dict.get('scale')
|
||||
rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1)
|
||||
return translation, scale, rotation
|
||||
|
||||
|
||||
def log_generator_preview(self, source_tensor, target_tensor, swap_tensor):
|
||||
max_preview = 8
|
||||
source_tensor = source_tensor[:max_preview]
|
||||
target_tensor = target_tensor[:max_preview]
|
||||
swap_tensor = swap_tensor[:max_preview]
|
||||
rows = [torch.cat([src, tgt, swp], dim=2) for src, tgt, swp in zip(source_tensor, target_tensor, swap_tensor)]
|
||||
grid = torchvision.utils.make_grid(torch.cat(rows, dim=1).unsqueeze(0), nrow=1, normalize=True, scale_each=True)
|
||||
os.makedirs("previews", exist_ok=True)
|
||||
torchvision.utils.save_image(grid, f"previews/step_{self.global_step}.jpg")
|
||||
self.logger.experiment.add_image("Generator Preview", grid, self.global_step)
|
||||
|
||||
def log_validation_preview(self):
|
||||
validation_source_path = CONFIG.get('training.validation', 'sources')
|
||||
validation_target_path = CONFIG.get('training.validation', 'targets')
|
||||
sources = [read_image(os.path.join(validation_source_path, f)) for f in os.listdir(validation_source_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
|
||||
targets = [read_image(os.path.join(validation_target_path, f)) for f in os.listdir(validation_target_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
|
||||
transforms = torchvision.transforms.Compose(
|
||||
[
|
||||
torchvision.transforms.Resize((256, 256), interpolation = torchvision.transforms.InterpolationMode.BICUBIC),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
])
|
||||
to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:,:,::-1] + 1) * 127.5
|
||||
self.generator.eval()
|
||||
results = []
|
||||
|
||||
for source, target in zip(sources, targets):
|
||||
source_tensor = transforms(source).unsqueeze(0).to(self.device).half()
|
||||
target_tensor = transforms(target).unsqueeze(0).to(self.device).half()
|
||||
source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
|
||||
|
||||
with torch.no_grad():
|
||||
output, _ = self.generator(target_tensor, source_embedding)
|
||||
results.append(numpy.hstack([to_numpy(source_tensor), to_numpy(target_tensor), to_numpy(output)]))
|
||||
preview = numpy.vstack(results)
|
||||
os.makedirs("validation_previews", exist_ok=True)
|
||||
cv2.imwrite(f"validation_previews/step_{self.global_step}.jpg", preview)
|
||||
self.generator.train()
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, Tuple, List, Dict, Optional
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
Batch = Tuple[Tensor, Tensor, int]
|
||||
Batch = Tuple[Any, Any, Any]
|
||||
Loader = DataLoader[Tuple[Tensor, ...]]
|
||||
UNetAttributes = Tuple[Tensor, ...]
|
||||
TargetAttributes = Tuple[Tensor, ...]
|
||||
DiscriminatorOutputs = List[List[Tensor]]
|
||||
LossDict = Dict[str, Tensor]
|
||||
|
||||
Embedding = NDArray[Any]
|
||||
VisionFrame = NDArray[Any]
|
||||
Reference in New Issue
Block a user