ugly training code

This commit is contained in:
harisreedhar
2024-12-10 22:19:03 +05:30
committed by henryruhs
parent 7bef17b551
commit e6c2a64256
10 changed files with 655 additions and 167 deletions
+1 -1
View File
@@ -4,4 +4,4 @@ plugins = flake8-import-order
application_import_names = arcface_converter
import-order-style = pycharm
per-file-ignores = preparing.py:E402
exclude = LivePortrait
exclude = face_swapper
+29 -19
View File
@@ -1,55 +1,65 @@
[preparing.dataset]
dataset_path =
dataset_path = /assets/VGGface2_None_norm_512_true_bygfpgan
[preparing.dataloader]
same_person_probability = 0.2
[preparing.augmentation]
expression_augmentation = false
expression = false
[training.loader]
batch_size = 6
batch_size = 4
num_workers = 8
[training.generator]
num_blocks = 2
id_channels = 512
learning_rate = 0.0004
[training.discriminator]
input_channels = 3
num_filters = 64
num_layers = 5
num_discriminators = 3
learning_rate = 0.0004
disable = false
[auxiliary_models.paths]
arcface_path =
landmarker_path =
motion_extractor_path = /home/hari/Documents/Github/Phantom/assets/pretrained_models/liveportrait_motion_extractor.pth
feature_extractor_path =
warping_netwrk_path =
spade_generator_path =
arcface_path = /assets/pretrained_models/arcface_w600k_r50.pt
landmarker_path = /assets/pretrained_models/landmark_203.pt
motion_extractor_path = /assets/pretrained_models/liveportrait_motion_extractor.pth
feature_extractor_path = /assets/pretrained_models/liveportrait_feature_extractor.pth
warping_network_path = /assets/pretrained_models/liveportrait_warping_model.pth
spade_generator_path = /assets/pretrained_models/liveportrait_spade_generator.pth
[training.losses]
weight_adversarial = 1
weight_identity = 20
weight_attribute = 10
weight_reconstruction = 10
weight_tsr = 0
weight_expression = 0
weight_tsr = 100
weight_eye_gaze = 5
weight_eye_open = 5
weight_lip_open = 5
[training.optimizers]
scheduler_step = 5000
scheduler_gamma = 0.2
generator_learning_rate = 0.0004
discriminator_learning_rate = 0.0004
[training.schedulers]
step = 5000
gamma = 0.2
[training.trainer]
epochs = 50
max_epochs = 50
disable_discriminator = false
[training.output]
directory_path =
file_pattern =
checkpoint_path = checkpoints/last.ckpt
directory_path = checkpoints
file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}'
preview_frequency = 250
validation_frequency = 1000
[training.validation]
sources = assets/test/front/sources
targets = assets/test/front/targets
[exporting]
directory_path =
+12 -3
View File
@@ -9,7 +9,7 @@ from PIL import Image
from torch.utils.data import TensorDataset
from .augmentations import apply_random_motion_blur
from .sub_typing import Batch
from .typing import Batch
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -27,6 +27,7 @@ class DataLoaderVGG(TensorDataset):
self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path))
self.folder_paths = glob.glob('{}/*'.format(dataset_path))
self.image_path_dict = {}
self._current_index = 0
for folder_path in tqdm.tqdm(self.folder_paths):
image_paths = glob.glob('{}/*'.format(folder_path))
@@ -50,12 +51,12 @@ class DataLoaderVGG(TensorDataset):
[
transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(p = 0.5),
transforms.RandomApply([ apply_random_motion_blur ], p = 0.3),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation = 0.2, hue = 0.1),
transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
])
def __getitem__(self, item : int) -> Batch:
@@ -80,3 +81,11 @@ class DataLoaderVGG(TensorDataset):
def __len__(self) -> int:
return self.dataset_total
def state_dict(self):
return {'current_index': self._current_index}
def load_state_dict(self, state_dict):
self._current_index = state_dict['current_index']
+4 -11
View File
@@ -3,7 +3,7 @@ from typing import List
import numpy
import torch.nn as nn
from .sub_typing import Tensor
from .typing import Tensor, DiscriminatorOutputs
class NLayerDiscriminator(nn.Module):
@@ -49,7 +49,6 @@ class NLayerDiscriminator(nn.Module):
return self.model(input_tensor)
# input_channels=3, num_filters=64, num_layers=5, num_discriminators=3
class MultiscaleDiscriminator(nn.Module):
def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int):
super(MultiscaleDiscriminator, self).__init__()
@@ -61,18 +60,12 @@ class MultiscaleDiscriminator(nn.Module):
setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
return [ model_layers(input_tensor) ]
if self.return_intermediate_features:
feature_maps = [ input_tensor ]
for layer in model_layers:
feature_maps.append(layer(feature_maps[-1]))
return feature_maps[1:]
else:
return [ model_layers(input_tensor) ]
def forward(self, input_tensor : Tensor) -> List[Tensor]:
def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
discriminator_outputs = []
downsampled_input = input_tensor
+89 -77
View File
@@ -1,55 +1,52 @@
from typing import Tuple
import torch
import torch.nn as nn
from .sub_typing import Tensor, UNetAttributes
import torch.nn.functional as F
class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
def __init__(self, id_channels : int, num_blocks : int) -> None:
def __init__(self, id_channels=512, num_blocks=2):
super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
self.encoder = UNet()
self.generator = AdaptiveAttentionalDenorm_Generator(id_channels, num_blocks)
self.generator = AADGenerator(id_channels, num_blocks)
def forward(self, target : Tensor, source_embedding : Tensor) -> Tuple[Tensor, UNetAttributes]:
def forward(self, target, source_embedding):
target_attributes = self.get_attributes(target)
swap = self.generator(target_attributes, source_embedding)
return swap, target_attributes
def get_attributes(self, target : Tensor) -> UNetAttributes:
def get_attributes(self, target):
return self.encoder(target)
class AdaptiveAttentionalDenorm_Generator(nn.Module):
def __init__(self, id_channels : int, num_blocks : int) -> None:
super(AdaptiveAttentionalDenorm_Generator, self).__init__()
class AADGenerator(nn.Module):
def __init__(self, id_channels=512, num_blocks=2):
super(AADGenerator, self).__init__()
self.upsample = Upsample(id_channels, 1024 * 4)
self.block_1 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 1024, id_channels, num_blocks)
self.block_2 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 2048, id_channels, num_blocks)
self.block_3 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 1024, id_channels, num_blocks)
self.block_4 = AdaptiveAttentionalDenorm_ResBlock(1024, 512, 512, id_channels, num_blocks)
self.block_5 = AdaptiveAttentionalDenorm_ResBlock(512, 256, 256, id_channels, num_blocks)
self.block_6 = AdaptiveAttentionalDenorm_ResBlock(256, 128, 128, id_channels, num_blocks)
self.block_7 = AdaptiveAttentionalDenorm_ResBlock(128, 64, 64, id_channels, num_blocks)
self.block_7 = AdaptiveAttentionalDenorm_ResBlock(64, 3, 64, id_channels, num_blocks)
self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, id_channels, num_blocks)
self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, id_channels, num_blocks)
self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, id_channels, num_blocks)
self.AADBlk4 = AAD_ResBlk(1024, 512, 512, id_channels, num_blocks)
self.AADBlk5 = AAD_ResBlk(512, 256, 256, id_channels, num_blocks)
self.AADBlk6 = AAD_ResBlk(256, 128, 128, id_channels, num_blocks)
self.AADBlk7 = AAD_ResBlk(128, 64, 64, id_channels, num_blocks)
self.AADBlk8 = AAD_ResBlk(64, 3, 64, id_channels, num_blocks)
self.apply(initialize_weight)
def forward(self, target_attributes : UNetAttributes, source_embedding : Tensor) -> Tensor:
def forward(self, target_attributes, source_embedding):
feature_map = self.upsample(source_embedding)
feature_map_1 = nn.functional.interpolate(self.block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
feature_map_2 = nn.functional.interpolate(self.block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
feature_map_3 = nn.functional.interpolate(self.block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
feature_map_4 = nn.functional.interpolate(self.block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
feature_map_5 = nn.functional.interpolate(self.block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
feature_map_6 = nn.functional.interpolate(self.block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
feature_map_7 = nn.functional.interpolate(self.block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
output = self.block_7(feature_map_7, target_attributes[7], source_embedding)
feature_map_1 = F.interpolate(self.AADBlk1(feature_map, target_attributes[0], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
feature_map_2 = F.interpolate(self.AADBlk2(feature_map_1, target_attributes[1], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
feature_map_3 = F.interpolate(self.AADBlk3(feature_map_2, target_attributes[2], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
feature_map_4 = F.interpolate(self.AADBlk4(feature_map_3, target_attributes[3], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
feature_map_5 = F.interpolate(self.AADBlk5(feature_map_4, target_attributes[4], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
feature_map_6 = F.interpolate(self.AADBlk6(feature_map_5, target_attributes[5], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
feature_map_7 = F.interpolate(self.AADBlk7(feature_map_6, target_attributes[6], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
output = self.AADBlk8(feature_map_7, target_attributes[7], source_embedding)
return torch.tanh(output)
class UNet(nn.Module):
def __init__(self) -> None:
def __init__(self):
super(UNet, self).__init__()
self.downsampler_1 = Conv4x4(3, 32)
self.downsampler_2 = Conv4x4(32, 64)
@@ -57,7 +54,9 @@ class UNet(nn.Module):
self.downsampler_4 = Conv4x4(128, 256)
self.downsampler_5 = Conv4x4(256, 512)
self.downsampler_6 = Conv4x4(512, 1024)
self.bottleneck = Conv4x4(1024, 1024)
self.upsampler_1 = DeConv4x4(1024, 1024)
self.upsampler_2 = DeConv4x4(2048, 512)
self.upsampler_3 = DeConv4x4(1024, 256)
@@ -66,53 +65,64 @@ class UNet(nn.Module):
self.upsampler_6 = DeConv4x4(128, 32)
self.apply(initialize_weight)
def forward(self, input_tensor : Tensor) -> UNetAttributes:
def forward(self, input_tensor):
downsample_feature_1 = self.downsampler_1(input_tensor)
downsample_feature_2 = self.downsampler_2(downsample_feature_1)
downsample_feature_3 = self.downsampler_3(downsample_feature_2)
downsample_feature_4 = self.downsampler_4(downsample_feature_3)
downsample_feature_5 = self.downsampler_5(downsample_feature_4)
downsample_feature_6 = self.downsampler_6(downsample_feature_5)
bottleneck_output = self.bottleneck(downsample_feature_6)
upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
output = nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
output = F.interpolate(upsample_feature_6, scale_factor=2, mode='bilinear', align_corners=False)
return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
class AdaptiveAttentionalDenorm_Layer(nn.Module):
def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None:
super(AdaptiveAttentionalDenorm_Layer, self).__init__()
class AADLayer(nn.Module):
def __init__(self, input_channels, attr_channels, id_channels):
super(AADLayer, self).__init__()
self.attr_channels = attr_channels
self.id_channels = id_channels
self.input_channels = input_channels
self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size=1, stride=1, padding=0, bias=True)
self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size=1, stride=1, padding=0, bias=True)
self.fc_gamma = nn.Linear(id_channels, input_channels)
self.fc_beta = nn.Linear(id_channels, input_channels)
self.instance_norm = nn.InstanceNorm2d(input_channels, affine=False)
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = True)
def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size=1, stride=1, padding=0, bias=True)
def forward(self, feature_map, attr_embedding, id_embedding):
feature_map = self.instance_norm(feature_map)
attr_gamma = self.conv_gamma(attr_embedding)
attr_beta = self.conv_beta(attr_embedding)
attr_modulation = attr_gamma * feature_map + attr_beta
id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(
feature_map)
id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(
feature_map)
id_modulation = id_gamma * feature_map + id_beta
feature_mask = torch.sigmoid(self.conv_mask(feature_map))
feature_blend = (1 - feature_mask) * attr_modulation + feature_mask * id_modulation
return feature_blend
class AddBlocksSequential(nn.Sequential):
def forward(self, *inputs : Tuple[Tensor, ...]) -> Tensor:
feature_map, attr_embedding, id_embedding = inputs
def forward(self, *inputs):
h, attr_embedding, id_embedding = inputs
for index, module in enumerate(self._modules.values()):
if index % 3 == 0 and index > 0:
@@ -124,9 +134,9 @@ class AddBlocksSequential(nn.Sequential):
return inputs
class AdaptiveAttentionalDenorm_ResBlock(nn.Module):
def __init__(self, in_channels : int, out_channels : int, attr_channels : int, id_channels : int, num_blocks : int) -> None:
super(AdaptiveAttentionalDenorm_ResBlock, self).__init__()
class AAD_ResBlk(nn.Module):
def __init__(self, in_channels, out_channels, attr_channels, id_channels, num_blocks):
super(AAD_ResBlk, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
primary_add_blocks = []
@@ -135,22 +145,22 @@ class AdaptiveAttentionalDenorm_ResBlock(nn.Module):
intermediate_channels = in_channels if i < (num_blocks - 1) else out_channels
primary_add_blocks.extend(
[
AdaptiveAttentionalDenorm_Layer(in_channels, attr_channels, id_channels),
nn.ReLU(inplace = True),
nn.Conv2d(in_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
AADLayer(in_channels, attr_channels, id_channels),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels, intermediate_channels, kernel_size=3, stride=1, padding=1, bias=False)
])
self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks)
if in_channels != out_channels:
auxiliary_add_blocks = \
[
AdaptiveAttentionalDenorm_Layer(in_channels, attr_channels, id_channels),
nn.ReLU(inplace = True),
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
AADLayer(in_channels, attr_channels, id_channels),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
]
self.auxiliary_add_blocks = AddBlocksSequential(*auxiliary_add_blocks)
def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:
def forward(self, feature_map, attr_embedding, id_embedding):
primary_feature = self.primary_add_blocks(feature_map, attr_embedding, id_embedding)
if self.in_channels != self.out_channels:
@@ -160,47 +170,49 @@ class AdaptiveAttentionalDenorm_ResBlock(nn.Module):
class Conv4x4(nn.Module):
def __init__(self, in_channels : int, out_channels : int) -> None:
def __init__(self, in_channels, out_channels):
super(Conv4x4, self).__init__()
self.conv = nn.Conv2d(in_channels=in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1,
bias=False)
self.batch_norm = nn.BatchNorm2d(out_channels)
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
self.leaky_relu = nn.LeakyReLU(0.1, inplace=True)
def forward(self, input : Tensor) -> Tensor:
output = self.conv(input)
output = self.batch_norm(output)
output = self.leaky_relu(output)
return output
def forward(self, x):
x = self.conv(x)
x = self.batch_norm(x)
x = self.leaky_relu(x)
return x
class DeConv4x4(nn.Module):
def __init__(self, in_channels : int, out_channels : int) -> None:
def __init__(self, in_channels, out_channels):
super(DeConv4x4, self).__init__()
self.deconv = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
self.deconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2,
padding=1, bias=False)
self.batch_norm = nn.BatchNorm2d(out_channels)
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
self.leaky_relu = nn.LeakyReLU(0.1, inplace=True)
def forward(self, input : Tensor, skip_connection : Tensor) -> Tensor:
output = self.deconv(input)
output = self.batch_norm(output)
output = self.leaky_relu(output)
output = torch.cat((output, skip_connection), dim = 1)
return output
def forward(self, x, skip):
x = self.deconv(x)
x = self.batch_norm(x)
x = self.leaky_relu(x)
return torch.cat((x, skip), dim=1)
class Upsample(nn.Module):
def __init__(self, in_channels : int, out_channels : int):
def __init__(self, in_channels, out_channels):
super(Upsample, self).__init__()
self.initial_conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
self.initial_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1,
padding=1)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
def forward(self, input : Tensor) -> Tensor:
output = self.initial_conv(input.view(input.shape[0], -1, 1, 1))
output = self.pixel_shuffle(output)
return output
def forward(self, x):
x = self.initial_conv(x.view(x.shape[0], -1, 1, 1))
x = self.pixel_shuffle(x)
return x
def initialize_weight(module : nn.Module) -> None:
def initialize_weight(module):
if isinstance(module, nn.Linear):
module.weight.data.normal_(0, 0.001)
module.bias.data.zero_()
+128
View File
@@ -0,0 +1,128 @@
import configparser
from typing import Tuple
import torch
from .typing import Tensor
import numpy
import torch.nn.functional as F
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
if CONFIG.getboolean('preparing.augmentation', 'expression'):
from LivePortrait.src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
L2_loss = torch.nn.MSELoss()
EXPRESSION_MIN = numpy.array(
[
[
[-2.88067125e-02, -8.12731311e-02, -1.70541159e-03],
[-4.88598682e-02, -3.32196616e-02, -1.67431499e-04],
[-6.75425082e-02, -4.28681746e-02, -1.98950816e-04],
[-7.23103955e-02, -3.28503326e-02, -7.31324719e-04],
[-3.87073644e-02, -6.01546466e-02, -5.50269964e-04],
[-6.38048723e-02, -2.23840728e-01, -7.13261834e-04],
[-3.02710701e-02, -3.93195450e-02, -8.24086510e-06],
[-2.95799859e-02, -5.39318882e-02, -1.74219604e-04],
[-2.92359516e-02, -1.53050944e-02, -6.30460854e-05],
[-5.56493877e-03, -2.34344602e-02, -1.26858242e-04],
[-4.37593013e-02, -2.77768299e-02, -2.70503685e-02],
[-1.76926646e-02, -1.91676542e-02, -1.15090821e-04],
[-8.34268332e-03, -3.99775570e-03, -3.27481248e-05],
[-3.40162888e-02, -2.81868968e-02, -1.96679524e-04],
[-2.91855410e-02, -3.97511162e-02, -2.81230678e-05],
[-1.50395725e-02, -2.49494594e-02, -9.42573533e-05],
[-1.67938769e-02, -2.00953931e-02, -4.00750607e-04],
[-1.86435618e-02, -2.48535164e-02, -2.74416432e-02],
[-4.61211195e-03, -1.21660791e-02, -2.93173041e-04],
[-4.10017073e-02, -7.43824020e-02, -4.42762971e-02],
[-1.90370996e-02, -3.74363363e-02, -1.34740388e-02]
]
]).astype(numpy.float32)
EXPRESSION_MAX = numpy.array(
[
[
[4.46682945e-02, 7.08772913e-02, 4.08344204e-04],
[2.14308221e-02, 6.15894832e-02, 4.85319615e-05],
[3.02363783e-02, 4.45043296e-02, 1.28298725e-05],
[3.05869691e-02, 3.79812494e-02, 6.57040102e-04],
[4.45670523e-02, 3.97259220e-02, 7.10966764e-04],
[9.43699256e-02, 9.85926315e-02, 2.02551950e-04],
[1.61131397e-02, 2.92906128e-02, 3.44733417e-06],
[5.23825921e-02, 1.07065082e-01, 6.61510974e-04],
[2.85718683e-03, 8.32320191e-03, 2.39314613e-04],
[2.57947259e-02, 1.60935968e-02, 2.41853559e-05],
[4.90833223e-02, 3.43903080e-02, 3.22353356e-02],
[1.44766076e-02, 3.39248963e-02, 1.42291479e-04],
[8.75749043e-04, 6.82212645e-03, 2.76097053e-05],
[1.86958015e-02, 3.84016186e-02, 7.33085908e-05],
[2.01714113e-02, 4.90544215e-02, 2.34028921e-05],
[2.46518422e-02, 3.29151377e-02, 3.48571630e-05],
[2.22457591e-02, 1.21796541e-02, 1.56396593e-04],
[1.72109623e-02, 3.01626958e-02, 1.36556877e-02],
[1.83460284e-02, 1.61141958e-02, 2.87440169e-04],
[3.57594155e-02, 1.80554688e-01, 2.75554154e-02],
[2.17450950e-02, 8.66811201e-02, 3.34241726e-02]
]
]).astype(numpy.float32)
def randomize_expression(face_tensor, feature_extractor, motion_extractor, warping_network, spade_generator):
with torch.no_grad():
face_tensor_norm = (face_tensor + 1) * 0.5
input_device = face_tensor.device
feature_volume = feature_extractor(face_tensor_norm)
motion_extractor_dict = motion_extractor(face_tensor_norm)
translation = motion_extractor_dict.get('t')
expression = motion_extractor_dict.get('exp')
scale = motion_extractor_dict.get('scale')
points = motion_extractor_dict.get('kp')
pitch = headpose_pred_to_degree(motion_extractor_dict.get('pitch'))[:, None]
yaw = headpose_pred_to_degree(motion_extractor_dict.get('yaw'))[:, None]
roll = headpose_pred_to_degree(motion_extractor_dict.get('roll'))[:, None]
rotation_matrix = get_rotation_matrix(pitch, yaw, roll)
random_expression = get_random_expression_blend(expression)
points_transformed = transform_points(points, rotation_matrix, expression, scale, translation)
points_driv = transform_points(points, rotation_matrix, random_expression, scale, translation)
data = warping_network(feature_volume, points_driv, points_transformed).get('out')
output = spade_generator(data)
output = output.to(input_device)
output = F.interpolate(output.clamp(0, 1), [256, 256], mode='bilinear', align_corners=False)
output = (output - 0.5) * 2
return output
def get_random_expression_blend(expression : Tensor) -> Tensor:
blend = 0.35
expression = expression.view(-1, 21, 3)
min_array = torch.from_numpy(EXPRESSION_MIN).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1)
max_array = torch.from_numpy(EXPRESSION_MAX).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1)
random_batch = torch.rand_like(min_array).to(expression.device) * (max_array - min_array) + min_array
random_batch[:, [0, 1, 8, 6, 9, 4, 5, 10]] = expression[:, [0, 1, 8, 6, 9, 4, 5, 10]]
random_batch[:, [3, 7]] = random_batch[:, [13, 16]] * 0.1 + expression[:, [13, 16]] * 0.9
random_batch[:, [3, 7]] = random_batch[:, [3, 7]] * 0.5 + expression[:, [3, 7]] * 0.5
return random_batch * 0.8 * blend + expression * (1 - blend)
def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor):
points_transformed = points.view(-1, 21, 3) @ rotation_matrix + expression.view(-1, 21, 3)
points_transformed *= scale[..., None]
points_transformed[:, :, 0:2] += translation[:, None, 0:2]
return points_transformed
def hinge_loss(tensor : Tensor, is_positive : bool) -> Tensor:
if is_positive:
return torch.relu(1 - tensor)
else:
return torch.relu(tensor + 1)
def calc_distance_ratio(landmarks : Tensor, indices : Tuple[int, int, int, int]) -> Tensor:
distance_horizontal = torch.norm(landmarks[:, indices[0]] - landmarks[:, indices[1]], p = 2, dim = 1, keepdim = True)
distance_vertical = torch.norm(landmarks[:, indices[2]] - landmarks[:, indices[3]], p=2, dim = 1, keepdim = True)
return distance_horizontal / (distance_vertical + 1e-4)
-50
View File
@@ -1,50 +0,0 @@
import configparser
import torch
import torch.nn as nn
from .discriminator import MultiscaleDiscriminator
from .generator import AdaptiveEmbeddingIntegrationNetwork
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def load_generator() -> nn.Module:
id_channels = CONFIG.getint('training.generator', 'id_channels')
num_blocks = CONFIG.getint('training.generator', 'num_blocks')
generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
return generator
def load_discriminator() -> nn.Module:
input_channels = CONFIG.getint('training.discriminator', 'input_channels')
num_filters = CONFIG.getint('training.discriminator', 'num_filters')
num_layers = CONFIG.getint('training.discriminator', 'num_layers')
num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
return discriminator
def load_arcface() -> nn.Module:
model_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
arcface = torch.load(model_path, map_location = 'cpu', weights_only = False)
arcface.eval()
return arcface
def load_landmarker() -> nn.Module:
model_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
landmarker = torch.load(model_path, map_location = 'cpu', weights_only = False)
landmarker.eval()
return landmarker
def load_motion_extractor() -> nn.Module:
from LivePortrait.src.modules.motion_extractor import MotionExtractor
model_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
motion_extractor.load_state_dict(torch.load(model_path, map_location = 'cpu', weights_only = True))
motion_extractor.eval()
return motion_extractor
+386 -2
View File
@@ -1,5 +1,389 @@
from .model_loader import load_motion_extractor
import configparser
import random
from sympy.stats.sampling.sample_numpy import numpy
from typing import Tuple
import os
import cv2
import torchvision
import pytorch_lightning
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.types import OptimizerLRScheduler
import torch
from .discriminator import MultiscaleDiscriminator
from .generator import AdaptiveEmbeddingIntegrationNetwork
from .data_loader import DataLoaderVGG, read_image
from .typing import Tensor, LossDict, TargetAttributes, DiscriminatorOutputs, Batch
from .helper import hinge_loss, calc_distance_ratio, L2_loss, randomize_expression
from pytorch_msssim import ssim
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def load_models():
id_channels = CONFIG.getint('training.generator', 'id_channels')
num_blocks = CONFIG.getint('training.generator', 'num_blocks')
generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
input_channels = CONFIG.getint('training.discriminator', 'input_channels')
num_filters = CONFIG.getint('training.discriminator', 'num_filters')
num_layers = CONFIG.getint('training.discriminator', 'num_layers')
num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
model_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
arcface = torch.load(model_path, map_location = 'cpu', weights_only = False)
arcface.eval()
if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0:
model_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
landmarker = torch.load(model_path, map_location = 'cpu', weights_only = False)
landmarker.eval()
else:
landmarker = None
if CONFIG.getfloat('training.losses', 'weight_tsr') > 0 or CONFIG.getboolean('preparing.augmentation', 'expression'):
from LivePortrait.src.modules.motion_extractor import MotionExtractor
model_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
motion_extractor.load_state_dict(torch.load(model_path, map_location = 'cpu', weights_only = True))
motion_extractor.eval()
else:
motion_extractor = None
if CONFIG.getboolean('preparing.augmentation', 'expression'):
from LivePortrait.src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
from LivePortrait.src.modules.warping_network import WarpingNetwork
from LivePortrait.src.modules.spade_generator import SPADEDecoder
feature_extractor_path = CONFIG.get('auxiliary_models.paths', 'feature_extractor_path')
feature_extractor = AppearanceFeatureExtractor(3, 64, 2, 512, 32, 16, 6)
feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location = 'cpu', weights_only = True))
feature_extractor.eval()
warping_network_path = CONFIG.get('auxiliary_models.paths', 'warping_network_path')
dense_motion_params = {
'block_expansion': 32,
'max_features': 1024,
'num_blocks': 5,
'reshape_depth': 16,
'compress': 4
}
warping_network = WarpingNetwork(num_kp = 21, block_expansion = 64, max_features = 512, num_down_blocks = 2, reshape_channel = 32, estimate_occlusion_map = True, dense_motion_params = dense_motion_params)
warping_network.load_state_dict(torch.load(warping_network_path, map_location='cpu', weights_only=True))
warping_network.eval()
spade_generator_path = CONFIG.get('auxiliary_models.paths', 'spade_generator_path')
spade_generator = SPADEDecoder(upscale = 2, block_expansion = 64, max_features = 512, num_down_blocks = 2)
spade_generator.load_state_dict(torch.load(spade_generator_path, map_location = 'cpu', weights_only = True))
spade_generator.eval()
else:
feature_extractor = None
warping_network = None
spade_generator = None
return generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator
def create_trainer() -> Trainer:
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
output_directory_path = CONFIG.get('training.output', 'directory_path')
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
os.makedirs(output_directory_path, exist_ok = True)
return Trainer(
max_epochs = trainer_max_epochs,
precision = '16-mixed',
callbacks =
[
ModelCheckpoint(
monitor = 'l_G',
dirpath = output_directory_path,
filename = output_file_pattern,
# every_n_epochs = 1,
every_n_train_steps = 1000,
save_top_k = 5,
mode = 'min',
save_last = True
)
],
log_every_n_steps = 10,
accumulate_grad_batches = 1,
)
def train():
return print(load_motion_extractor())
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
checkpoint_path = CONFIG.get('training.output', 'checkpoint_path')
dataset = DataLoaderVGG(CONFIG.get('preparing.dataset', 'dataset_path'))
if not (checkpoint_path and os.path.exists(checkpoint_path)):
checkpoint_path = None
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
face_swap_model = FaceSwapper(*load_models())
trainer = create_trainer()
trainer.fit(face_swap_model, data_loader, ckpt_path = checkpoint_path)
class FaceSwapper(pytorch_lightning.LightningModule):
def __init__(self, generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator) -> None:
super().__init__()
self.generator = generator
self.discriminator = discriminator
self.arcface = arcface
self.landmarker = landmarker
self.motion_extractor = motion_extractor
self.feature_extractor = feature_extractor
self.warping_network = warping_network
self.spade_generator = spade_generator
self.loss_adversarial_accumulated = 20
self.automatic_optimization = False
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
def forward(self, target_tensor : Tensor, source_embedding : Tensor) -> Tensor:
output = self.generator(target_tensor, source_embedding)
return output
def state_dict(self, *args, **kwargs):
return {
"generator": self.generator.state_dict(),
"discriminator": self.discriminator.state_dict(),
}
def load_state_dict(self, state_dict, strict: bool = True):
if "generator" in state_dict:
self.generator.load_state_dict(state_dict["generator"], strict = strict)
if "discriminator" in state_dict:
self.discriminator.load_state_dict(state_dict["discriminator"], strict = strict)
def configure_optimizers(self) -> OptimizerLRScheduler:
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = CONFIG.getfloat('training.discriminator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
return (
{
"optimizer": generator_optimizer,
"lr_scheduler": generator_scheduler
},
{
"optimizer": discriminator_optimizer,
"lr_scheduler": discriminator_scheduler
})
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor, is_same_person = batch
generator_optimizer, discriminator_optimizer = self.optimizers()
source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
if random.random() > 0.5 and CONFIG.getboolean('preparing.augmentation', 'expression'):
target_tensor = randomize_expression(target_tensor, self.feature_extractor, self.motion_extractor, self.warping_network, self.spade_generator)
swap_tensor, target_attributes = self(target_tensor, source_embedding)
discriminator_outputs = self.discriminator(swap_tensor)
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch)
generator_optimizer.zero_grad()
self.manual_backward(generator_losses.get('loss_generator'))
generator_optimizer.step()
discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor)
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_losses.get('loss_discriminator'))
if not CONFIG.getboolean('training.discriminator', 'disable') or self.loss_adversarial_accumulated < 0.4:
discriminator_optimizer.step()
if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0:
self.log_generator_preview(source_tensor, target_tensor, swap_tensor)
if self.global_step % CONFIG.getint('training.output', 'validation_frequency') == 0:
self.log_validation_preview()
self.log('l_G', generator_losses.get('loss_generator'), prog_bar = True)
self.log('l_D', discriminator_losses.get('loss_discriminator'), prog_bar = True)
self.log('l_ADV_A', self.loss_adversarial_accumulated, prog_bar = True)
self.log('l_ADV', generator_losses.get('loss_adversarial'), prog_bar = False)
self.log('l_id', generator_losses.get('loss_identity'), prog_bar = True)
self.log('l_attr', generator_losses.get('loss_attribute'), prog_bar = True)
self.log('l_rec', generator_losses.get('loss_reconstruction'), prog_bar = True)
return generator_losses.get('loss_generator')
def calc_generator_loss(self, swap_tensor : Tensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> LossDict:
source_tensor, target_tensor, is_same_person = batch
generator_losses = {}
# adversarial loss
loss_adversarial = 0
for discriminator_output in discriminator_outputs:
loss_adversarial += hinge_loss(discriminator_output[0], True).mean(dim = [ 1, 2, 3 ])
loss_adversarial = torch.mean(loss_adversarial)
generator_losses['loss_adversarial'] = loss_adversarial
generator_losses['loss_generator'] = loss_adversarial * CONFIG.getfloat('training.losses', 'weight_adversarial')
self.loss_adversarial_accumulated = self.loss_adversarial_accumulated * 0.98 + loss_adversarial.item() * 0.02
# identity loss
swap_embedding = self.get_arcface_embedding(swap_tensor, (30, 0, 10, 10))
source_embedding = self.get_arcface_embedding(source_tensor, (30, 0, 10, 10))
loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
generator_losses['loss_identity'] = loss_identity
generator_losses['loss_generator'] += loss_identity * CONFIG.getfloat('training.losses', 'weight_identity')
# attribute loss
loss_attribute = 0
swap_attributes = self.generator.get_attributes(swap_tensor)
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
loss_attribute *= 0.5
generator_losses['loss_attribute'] = loss_attribute
generator_losses['loss_generator'] += loss_attribute * CONFIG.getfloat('training.losses', 'weight_attribute')
# reconstruction loss
loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4)
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
loss_reconstruction = loss_reconstruction * 0.3 + loss_ssim * 0.7
generator_losses['loss_reconstruction'] = loss_reconstruction
generator_losses['loss_generator'] += CONFIG.getfloat('training.losses', 'weight_reconstruction')
if CONFIG.getfloat('training.losses', 'weight_tsr') > 0:
# tsr loss
swap_motion_features = self.get_motion_features(swap_tensor)
target_motion_features = self.get_motion_features(target_tensor)
loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
loss_tsr += L2_loss(swap_motion_feature, target_motion_feature)
generator_losses['loss_tsr'] = loss_tsr
generator_losses['loss_generator'] += loss_tsr * CONFIG.getfloat('training.losses', 'weight_tsr')
if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0:
swap_landmark_features = self.get_landmark_features(swap_tensor)
target_landmark_features = self.get_landmark_features(target_tensor)
# eye gaze loss
loss_left_eye_gaze = L2_loss(swap_landmark_features[3], target_landmark_features[3])
loss_right_eye_gaze = L2_loss(swap_landmark_features[4], target_landmark_features[4])
loss_eye_gaze = loss_left_eye_gaze + loss_right_eye_gaze
generator_losses['loss_eye_gaze'] = loss_eye_gaze
generator_losses['loss_generator'] += loss_eye_gaze * CONFIG.getfloat('training.losses', 'weight_eye_gaze')
# eye open loss
loss_left_eye_open = L2_loss(swap_landmark_features[0], target_landmark_features[0])
loss_right_eye_open = L2_loss(swap_landmark_features[1], target_landmark_features[1])
loss_eye_open = loss_left_eye_open + loss_right_eye_open
generator_losses['loss_eye_open'] = loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open')
generator_losses['loss_generator'] += loss_eye_open
# lip open loss
loss_lip_open = L2_loss(swap_landmark_features[2], target_landmark_features[2])
generator_losses['loss_lip_open'] = loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open')
generator_losses['loss_generator'] += loss_lip_open
return generator_losses
def calc_discriminator_loss(self, swap_tensor : Tensor, source_tensor : Tensor) -> LossDict:
discriminator_losses = {}
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
loss_fake = 0
for fake_discriminator_output in fake_discriminator_outputs:
loss_fake += torch.mean(hinge_loss(fake_discriminator_output[0], False).mean(dim=[1, 2, 3]))
true_discriminator_outputs = self.discriminator(source_tensor)
loss_true = 0
for true_discriminator_output in true_discriminator_outputs:
loss_true += torch.mean(hinge_loss(true_discriminator_output[0], True).mean(dim=[1, 2, 3]))
discriminator_losses['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
return discriminator_losses
def get_arcface_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor:
_, _, height, width = vision_tensor.shape
crop_height = int(height * 0.0586)
crop_width = int(width * 0.0586)
crop_vision_tensor = vision_tensor[:, :, crop_height : height - crop_height, crop_width : width - crop_width]
crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'bilinear')
crop_vision_tensor[:, :, :padding[0], :] = 0
crop_vision_tensor[:, :, -padding[1]:, :] = 0
crop_vision_tensor[:, :, :, :padding[2]] = 0
crop_vision_tensor[:, :, :, -padding[3]:] = 0
embedding = self.arcface(crop_vision_tensor)
embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1)
return embedding
def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
vision_tensor_norm = (vision_tensor + 1) * 0.5
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
landmarks = self.landmarker(vision_tensor_norm)[2]
landmarks = landmarks.view(-1, 203, 2) * 256
left_eye_open_ratio = calc_distance_ratio(landmarks, (6, 18, 0, 12))
right_eye_open_ratio = calc_distance_ratio(landmarks, (30, 42, 24, 36))
lip_open_ratio = calc_distance_ratio(landmarks, (90, 102, 48, 66))
left_eye_gaze = landmarks[:, 198]
right_eye_gaze = landmarks[:, 197]
return left_eye_open_ratio, right_eye_open_ratio, lip_open_ratio, left_eye_gaze, right_eye_gaze
def get_motion_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
vision_tensor_norm = (vision_tensor + 1) * 0.5
motion_dict = self.motion_extractor(vision_tensor_norm)
translation = motion_dict.get('t')
scale = motion_dict.get('scale')
rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1)
return translation, scale, rotation
def log_generator_preview(self, source_tensor, target_tensor, swap_tensor):
max_preview = 8
source_tensor = source_tensor[:max_preview]
target_tensor = target_tensor[:max_preview]
swap_tensor = swap_tensor[:max_preview]
rows = [torch.cat([src, tgt, swp], dim=2) for src, tgt, swp in zip(source_tensor, target_tensor, swap_tensor)]
grid = torchvision.utils.make_grid(torch.cat(rows, dim=1).unsqueeze(0), nrow=1, normalize=True, scale_each=True)
os.makedirs("previews", exist_ok=True)
torchvision.utils.save_image(grid, f"previews/step_{self.global_step}.jpg")
self.logger.experiment.add_image("Generator Preview", grid, self.global_step)
def log_validation_preview(self):
validation_source_path = CONFIG.get('training.validation', 'sources')
validation_target_path = CONFIG.get('training.validation', 'targets')
sources = [read_image(os.path.join(validation_source_path, f)) for f in os.listdir(validation_source_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
targets = [read_image(os.path.join(validation_target_path, f)) for f in os.listdir(validation_target_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
transforms = torchvision.transforms.Compose(
[
torchvision.transforms.Resize((256, 256), interpolation = torchvision.transforms.InterpolationMode.BICUBIC),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:,:,::-1] + 1) * 127.5
self.generator.eval()
results = []
for source, target in zip(sources, targets):
source_tensor = transforms(source).unsqueeze(0).to(self.device).half()
target_tensor = transforms(target).unsqueeze(0).to(self.device).half()
source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
with torch.no_grad():
output, _ = self.generator(target_tensor, source_embedding)
results.append(numpy.hstack([to_numpy(source_tensor), to_numpy(target_tensor), to_numpy(output)]))
preview = numpy.vstack(results)
os.makedirs("validation_previews", exist_ok=True)
cv2.imwrite(f"validation_previews/step_{self.global_step}.jpg", preview)
self.generator.train()
@@ -1,12 +1,14 @@
from typing import Any, Tuple
from typing import Any, Tuple, List, Dict, Optional
from numpy.typing import NDArray
from torch import Tensor
from torch.utils.data import DataLoader
Batch = Tuple[Tensor, Tensor, int]
Batch = Tuple[Any, Any, Any]
Loader = DataLoader[Tuple[Tensor, ...]]
UNetAttributes = Tuple[Tensor, ...]
TargetAttributes = Tuple[Tensor, ...]
DiscriminatorOutputs = List[List[Tensor]]
LossDict = Dict[str, Tensor]
Embedding = NDArray[Any]
VisionFrame = NDArray[Any]
+1 -1
View File
@@ -6,4 +6,4 @@ disallow_untyped_defs = True
ignore_missing_imports = True
strict_optional = False
explicit_package_bases = True
exclude = face_swapper/LivePortrait
exclude = face_swapper