new swapper

This commit is contained in:
harisreedhar
2024-12-09 21:42:46 +05:30
committed by henryruhs
parent a461d9c389
commit 8e53c6bc9f
14 changed files with 550 additions and 1 deletions
+1 -1
View File
@@ -4,4 +4,4 @@ plugins = flake8-import-order
application_import_names = arcface_converter
import-order-style = pycharm
per-file-ignores = preparing.py:E402
exclude = LivePortrait
+3
View File
@@ -0,0 +1,3 @@
[submodule "face_swapper/LivePortrait"]
path = face_swapper/LivePortrait
url = https://github.com/KwaiVGI/LivePortrait
+3
View File
@@ -0,0 +1,3 @@
Non-Commercial license
Copyright (c) 2024 Henry Ruhs
+61
View File
@@ -0,0 +1,61 @@
[preparing.dataset]
dataset_path =
[preparing.dataloader]
same_person_probability = 0.2
[preparing.augmentation]
expression_augmentation = false
[training.loader]
batch_size = 6
num_workers = 8
[training.generator]
num_blocks = 2
id_channels = 512
[training.discriminator]
input_channels = 3
num_filters = 64
num_layers = 5
num_discriminators = 3
[auxiliary_models.paths]
arcface_path =
landmarker_path =
motion_extractor_path = /home/hari/Documents/Github/Phantom/assets/pretrained_models/liveportrait_motion_extractor.pth
feature_extractor_path =
warping_network_path =
spade_generator_path =
[training.losses]
weight_adversarial = 1
weight_identity = 20
weight_attribute = 10
weight_reconstruction = 10
weight_tsr = 0
weight_expression = 0
[training.optimizers]
scheduler_step = 5000
scheduler_gamma = 0.2
generator_learning_rate = 0.0004
discriminator_learning_rate = 0.0004
[training.trainer]
epochs = 50
disable_discriminator = false
[training.output]
directory_path =
file_pattern =
[exporting]
directory_path =
source_path =
target_path =
opset_version =
[execution]
providers =
+27
View File
@@ -0,0 +1,27 @@
import torch
from torch import Tensor
def apply_random_motion_blur(tensor_image : Tensor) -> Tensor:
    """
    Blur an image with a straight-line motion kernel of random orientation.

    A 9x9 kernel is rasterised along a line through its centre whose angle is
    drawn uniformly from [-2*pi, 2*pi], normalised to sum to one and convolved
    with every channel independently using zero padding, so the spatial size
    is preserved.

    :param tensor_image: floating point image laid out as (channels, height, width).
    :return: blurred image with the same shape, dtype and device as the input.
    """
    kernel_size = 9
    center = kernel_size // 2
    random_angle = torch.empty(1).uniform_(-2 * torch.pi, 2 * torch.pi)
    dx = torch.cos(random_angle)
    dy = torch.sin(random_angle)
    # allocate the kernel next to the image so other dtypes and devices convolve without a mismatch
    kernel = torch.zeros((kernel_size, kernel_size), dtype = tensor_image.dtype, device = tensor_image.device)

    for step in range(-center, kernel_size - center):
        x = int(center + step * dx)
        y = int(center + step * dy)
        if 0 <= x < kernel_size and 0 <= y < kernel_size:
            kernel[y, x] = 1
    # the centre tap (step 0) is always set, so the sum is never zero
    kernel /= kernel.sum()
    kernel = kernel.unsqueeze(0).unsqueeze(0)
    # treat the channels as the batch dimension to blur all of them in a single convolution
    blurred_image = torch.nn.functional.conv2d(tensor_image.unsqueeze(1), kernel, padding = center)
    return blurred_image.squeeze(1)
+82
View File
@@ -0,0 +1,82 @@
import configparser
import glob
import os
import random

import cv2
import torchvision.transforms as transforms
import tqdm
from PIL import Image
from torch.utils.data import TensorDataset

from .augmentations import apply_random_motion_blur
from .sub_typing import Batch
# module-level configuration, read once at import time from 'config.ini' (resolved against the current working directory)
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def read_image(image_path: str) -> Image.Image:
    """
    Load an image file from disk as a PIL image.

    The file is decoded with OpenCV and the channel axis is reversed (OpenCV's
    BGR order becomes RGB) before the array is handed to PIL.

    :param image_path: path of the image file to read.
    :return: the decoded image.
    """
    decoded_frame = cv2.imread(image_path)
    reordered_frame = decoded_frame[:, :, ::-1]
    return Image.fromarray(reordered_frame)
class DataLoaderVGG(TensorDataset):
    """
    Dataset yielding (source, target, is_same_person) training samples.

    Images are discovered as ``<dataset_path>/<folder>/<name>.<ext>`` where the
    extension ends in ``g``; every first-level folder is treated as one person.
    With probability ``same_person_probability`` (read from the
    ``preparing.dataloader`` config section) the target is drawn from the
    folder of the source and both images only get the basic transform;
    otherwise the target is drawn from the whole dataset and both images are
    augmented.
    """
    def __init__(self, dataset_path : str) -> None:
        """
        Index the dataset and build the three transform pipelines.

        :param dataset_path: root directory holding one sub-folder per person.
        """
        self.same_person_probability = CONFIG.getfloat('preparing.dataloader', 'same_person_probability')
        self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path))
        self.folder_paths = glob.glob('{}/*'.format(dataset_path))
        self.image_path_dict = {}

        for folder_path in tqdm.tqdm(self.folder_paths):
            # same file pattern as image_paths, so a same-person target is never a non-image file
            self.image_path_dict[folder_path] = glob.glob('{}/*.*g'.format(folder_path))
        self.dataset_total = len(self.image_paths)
        # resize and scale to [-1, 1] only
        self.transforms_basic = transforms.Compose(
        [
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # light colour and geometry jitter
        self.transforms_moderate = transforms.Compose(
        [
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
            transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # flip, optional motion blur and stronger geometry jitter
        self.transforms_complex = transforms.Compose(
        [
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p = 0.5),
            transforms.RandomApply([ apply_random_motion_blur ], p = 0.3),
            transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
            transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __getitem__(self, item : int) -> Batch:
        """
        Build one training sample.

        :param item: index into the list of discovered image paths.
        :return: transformed source, transformed target and 1 when the target
            was drawn from the source's folder, otherwise 0.
        """
        source_image_path = self.image_paths[item]
        source = read_image(source_image_path)

        if random.random() > self.same_person_probability:
            is_same_person = 0
            target_image_path = random.choice(self.image_paths)
            target = read_image(target_image_path)
            source_transform = self.transforms_moderate(source)
            target_transform = self.transforms_complex(target)
        else:
            is_same_person = 1
            # dirname instead of splitting on '/', so the key matches the glob results on every platform
            source_folder_path = os.path.dirname(source_image_path)
            target_image_path = random.choice(self.image_path_dict[source_folder_path])
            target = read_image(target_image_path)
            source_transform = self.transforms_basic(source)
            target_transform = self.transforms_basic(target)
        return source_transform, target_transform, is_same_person

    def __len__(self) -> int:
        """Return the number of discovered images."""
        return self.dataset_total
+85
View File
@@ -0,0 +1,85 @@
from typing import List
import numpy
import torch.nn as nn
from .sub_typing import Tensor
class NLayerDiscriminator(nn.Module):
    """
    Strided convolutional discriminator producing a single-channel score map.

    The network is a flat ``nn.Sequential`` stored in ``self.model``: one
    stride-2 convolution, ``num_layers - 1`` further stride-2 stages with
    instance normalisation, one stride-1 stage and a final one-channel
    convolution. The filter count doubles per stage and is capped at 512.
    """
    def __init__(self, input_channels : int, num_filters : int, num_layers : int) -> None:
        """
        :param input_channels: channel count of the input image.
        :param num_filters: filter count of the first convolution.
        :param num_layers: number of stride-2 convolutions.
        """
        super().__init__()
        self.num_layers = num_layers
        kernel_size = 4
        padding_size = int(numpy.ceil((kernel_size - 1.0) / 2))
        layers = [
            nn.Conv2d(input_channels, num_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
            nn.LeakyReLU(0.2, True)
        ]
        output_filters = num_filters
        # the downsampling stages followed by the single stride-1 stage
        stage_strides = [ 2 ] * (num_layers - 1) + [ 1 ]

        for stage_stride in stage_strides:
            input_filters = output_filters
            output_filters = min(input_filters * 2, 512)
            layers.append(nn.Conv2d(input_filters, output_filters, kernel_size = kernel_size, stride = stage_stride, padding = padding_size))
            layers.append(nn.InstanceNorm2d(output_filters))
            layers.append(nn.LeakyReLU(0.2, True))
        layers.append(nn.Conv2d(output_filters, 1, kernel_size = kernel_size, stride = 1, padding = padding_size))
        self.model = nn.Sequential(*layers)

    def forward(self, input_tensor : Tensor) -> Tensor:
        """Run the input through the sequential stack and return the score map."""
        score_map = self.model(input_tensor)
        return score_map
# input_channels=3, num_filters=64, num_layers=5, num_discriminators=3
class MultiscaleDiscriminator(nn.Module):
    """
    Set of identical discriminators applied to progressively downsampled inputs.

    Each discriminator's sequential stack is registered as
    ``discriminator_layer_<index>``. The highest index sees the full resolution
    input; every following one sees the input after another average pooling.
    """
    def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int, return_intermediate_features : bool = False) -> None:
        """
        :param input_channels: channel count of the input image.
        :param num_filters: filter count of each discriminator's first convolution.
        :param num_layers: number of stride-2 convolutions per discriminator.
        :param num_discriminators: how many scales are evaluated.
        :param return_intermediate_features: when true, every layer output is
            returned per scale instead of only the final score map.
        """
        super().__init__()
        self.num_discriminators = num_discriminators
        self.num_layers = num_layers
        # read by single_discriminator_forward; it has to exist before the first forward pass
        self.return_intermediate_features = return_intermediate_features

        for discriminator_index in range(num_discriminators):
            single_discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers)
            setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
        self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]

    def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
        """
        Evaluate one discriminator stack.

        :return: a one-element list with the final output, or the output of
            every layer in order when ``return_intermediate_features`` is set.
        """
        if self.return_intermediate_features:
            feature_maps = [ input_tensor ]

            for layer in model_layers:
                feature_maps.append(layer(feature_maps[-1]))
            # drop the input itself, keep only what the layers produced
            return feature_maps[1:]
        return [ model_layers(input_tensor) ]

    def forward(self, input_tensor : Tensor) -> List[List[Tensor]]:
        """
        Evaluate every scale.

        :return: one list of tensors per discriminator, ordered from the full
            resolution input to the most downsampled one.
        """
        discriminator_outputs = []
        downsampled_input = input_tensor

        for discriminator_index in range(self.num_discriminators):
            model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
            discriminator_outputs.append(self.single_discriminator_forward(model_layers, downsampled_input))
            # the last scale needs no further pooling
            if discriminator_index != (self.num_discriminators - 1):
                downsampled_input = self.downsample(downsampled_input)
        return discriminator_outputs
+212
View File
@@ -0,0 +1,212 @@
from typing import Tuple
import torch
import torch.nn as nn
from .sub_typing import Tensor, UNetAttributes
class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
    """
    Swap network combining a UNet attribute encoder with an AAD generator.

    The encoder turns the target image into multi-scale attribute maps; the
    generator renders the output from those maps and the source embedding.
    """
    def __init__(self, id_channels : int, num_blocks : int) -> None:
        """
        :param id_channels: length of the source embedding.
        :param num_blocks: number of AAD stages inside every residual block.
        """
        super().__init__()
        self.encoder = UNet()
        self.generator = AdaptiveAttentionalDenorm_Generator(id_channels, num_blocks)

    def forward(self, target : Tensor, source_embedding : Tensor) -> Tuple[Tensor, UNetAttributes]:
        """Return the generated image together with the target attributes it was built from."""
        attributes = self.get_attributes(target)
        return self.generator(attributes, source_embedding), attributes

    def get_attributes(self, target : Tensor) -> UNetAttributes:
        """Encode the target image into its attribute maps."""
        attributes = self.encoder(target)
        return attributes
class AdaptiveAttentionalDenorm_Generator(nn.Module):
    """
    Decoder that renders an image from a source embedding and target attributes.

    The embedding is expanded to a 1024 channel 2x2 map and passed through
    eight AAD residual blocks, each conditioned on one attribute map. The
    output of the first seven blocks is bilinearly upsampled by two; the
    eighth block maps to three channels and is squashed with tanh.
    """
    def __init__(self, id_channels : int, num_blocks : int) -> None:
        """
        :param id_channels: length of the source embedding.
        :param num_blocks: number of AAD stages inside every residual block.
        """
        super().__init__()
        self.upsample = Upsample(id_channels, 1024 * 4)
        self.block_1 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 1024, id_channels, num_blocks)
        self.block_2 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 2048, id_channels, num_blocks)
        self.block_3 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 1024, id_channels, num_blocks)
        self.block_4 = AdaptiveAttentionalDenorm_ResBlock(1024, 512, 512, id_channels, num_blocks)
        self.block_5 = AdaptiveAttentionalDenorm_ResBlock(512, 256, 256, id_channels, num_blocks)
        self.block_6 = AdaptiveAttentionalDenorm_ResBlock(256, 128, 128, id_channels, num_blocks)
        self.block_7 = AdaptiveAttentionalDenorm_ResBlock(128, 64, 64, id_channels, num_blocks)
        # own attribute for the final stage, so it no longer overwrites the 128 -> 64 stage above
        self.block_8 = AdaptiveAttentionalDenorm_ResBlock(64, 3, 64, id_channels, num_blocks)
        self.apply(initialize_weight)

    def forward(self, target_attributes : UNetAttributes, source_embedding : Tensor) -> Tensor:
        """
        :param target_attributes: the eight attribute maps, coarsest first.
        :param source_embedding: embedding of the source, one row per sample.
        :return: generated image with values in [-1, 1].
        """
        upsampled_blocks = (self.block_1, self.block_2, self.block_3, self.block_4, self.block_5, self.block_6, self.block_7)
        feature_map = self.upsample(source_embedding)

        # zip stops after the seven upsampled stages; the eighth attribute map feeds block_8
        for block, attributes in zip(upsampled_blocks, target_attributes):
            feature_map = block(feature_map, attributes, source_embedding)
            feature_map = nn.functional.interpolate(feature_map, scale_factor = 2, mode = 'bilinear', align_corners = False)
        output = self.block_8(feature_map, target_attributes[7], source_embedding)
        return torch.tanh(output)
class UNet(nn.Module):
    """
    Encoder / decoder with skip connections that extracts attribute maps.

    Six stride-2 convolutions and a bottleneck halve the resolution seven
    times; six transposed convolutions double it again, each concatenating
    the encoder feature of matching resolution.
    """
    def __init__(self) -> None:
        super().__init__()
        self.downsampler_1 = Conv4x4(3, 32)
        self.downsampler_2 = Conv4x4(32, 64)
        self.downsampler_3 = Conv4x4(64, 128)
        self.downsampler_4 = Conv4x4(128, 256)
        self.downsampler_5 = Conv4x4(256, 512)
        self.downsampler_6 = Conv4x4(512, 1024)
        self.bottleneck = Conv4x4(1024, 1024)
        # the decoder inputs are doubled from the second stage on because of the concatenated skips
        self.upsampler_1 = DeConv4x4(1024, 1024)
        self.upsampler_2 = DeConv4x4(2048, 512)
        self.upsampler_3 = DeConv4x4(1024, 256)
        self.upsampler_4 = DeConv4x4(512, 128)
        self.upsampler_5 = DeConv4x4(256, 64)
        self.upsampler_6 = DeConv4x4(128, 32)
        self.apply(initialize_weight)

    def forward(self, input_tensor : Tensor) -> UNetAttributes:
        """
        :param input_tensor: batch of three-channel images.
        :return: the bottleneck output, the six decoder features and the last
            decoder feature bilinearly upsampled by two, in that order.
        """
        downsamplers = (self.downsampler_1, self.downsampler_2, self.downsampler_3, self.downsampler_4, self.downsampler_5, self.downsampler_6)
        upsamplers = (self.upsampler_1, self.upsampler_2, self.upsampler_3, self.upsampler_4, self.upsampler_5, self.upsampler_6)
        skip_features = []
        current_feature = input_tensor

        for downsampler in downsamplers:
            current_feature = downsampler(current_feature)
            skip_features.append(current_feature)
        bottleneck_output = self.bottleneck(current_feature)
        upsample_features = []
        current_feature = bottleneck_output

        # the deepest encoder feature is paired with the first decoder stage
        for upsampler, skip_feature in zip(upsamplers, reversed(skip_features)):
            current_feature = upsampler(current_feature, skip_feature)
            upsample_features.append(current_feature)
        output = nn.functional.interpolate(current_feature, scale_factor = 2, mode = 'bilinear', align_corners = False)
        return (bottleneck_output, *upsample_features, output)
class AdaptiveAttentionalDenorm_Layer(nn.Module):
    """
    Normalisation layer that blends attribute and identity modulation.

    The instance-normalised feature map is modulated once with per-pixel
    scale and shift derived from the attribute map and once with per-channel
    scale and shift derived from the identity embedding; a learned sigmoid
    mask decides per pixel how much of each result is kept.
    """
    def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None:
        """
        :param input_channels: channel count of the feature map.
        :param attr_channels: channel count of the attribute map.
        :param id_channels: length of the identity embedding.
        """
        super().__init__()
        self.attr_channels = attr_channels
        self.id_channels = id_channels
        self.input_channels = input_channels
        self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
        self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
        self.fc_gamma = nn.Linear(id_channels, input_channels)
        self.fc_beta = nn.Linear(id_channels, input_channels)
        self.instance_norm = nn.InstanceNorm2d(input_channels, affine = False)
        self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = True)

    def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:
        """Return the mask-weighted mix of the attribute and identity modulated feature map."""
        normalized_map = self.instance_norm(feature_map)
        channel_shape = (normalized_map.shape[0], self.input_channels, 1, 1)
        attr_modulation = self.conv_gamma(attr_embedding) * normalized_map + self.conv_beta(attr_embedding)
        # the per-channel identity statistics are broadcast over the spatial dimensions
        id_gamma = self.fc_gamma(id_embedding).reshape(*channel_shape).expand_as(normalized_map)
        id_beta = self.fc_beta(id_embedding).reshape(*channel_shape).expand_as(normalized_map)
        id_modulation = id_gamma * normalized_map + id_beta
        feature_mask = torch.sigmoid(self.conv_mask(normalized_map))
        return (1 - feature_mask) * attr_modulation + feature_mask * id_modulation
class AddBlocksSequential(nn.Sequential):
    """
    Sequential container for repeated groups of three modules.

    The first module of every group is called with the running feature map
    plus the attribute and identity embeddings; the two modules after it only
    receive the running feature map.
    """
    def forward(self, *inputs : Tensor) -> Tensor:
        """
        :param inputs: exactly the feature map, the attribute embedding and
            the identity embedding, in that order.
        :return: the feature map after the last module.
        """
        feature_map, attr_embedding, id_embedding = inputs

        for index, module in enumerate(self._modules.values()):
            # dispatch on the position inside the group instead of inspecting the runtime type of the value
            if index % 3 == 0:
                feature_map = module(feature_map, attr_embedding, id_embedding)
            else:
                feature_map = module(feature_map)
        return feature_map
class AdaptiveAttentionalDenorm_ResBlock(nn.Module):
    """
    Residual block built from AAD layer, ReLU and 3x3 convolution groups.

    The primary path stacks ``num_blocks`` groups and only changes the channel
    count in the last one. When input and output channels differ, the shortcut
    runs through one extra group so both paths can be summed.
    """
    def __init__(self, in_channels : int, out_channels : int, attr_channels : int, id_channels : int, num_blocks : int) -> None:
        """
        :param in_channels: channel count of the incoming feature map.
        :param out_channels: channel count of the block output.
        :param attr_channels: channel count of the attribute map.
        :param id_channels: length of the identity embedding.
        :param num_blocks: number of groups on the primary path.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        primary_modules = []

        for block_index in range(num_blocks):
            is_last_block = block_index >= (num_blocks - 1)
            conv_channels = out_channels if is_last_block else in_channels
            primary_modules.append(AdaptiveAttentionalDenorm_Layer(in_channels, attr_channels, id_channels))
            primary_modules.append(nn.ReLU(inplace = True))
            primary_modules.append(nn.Conv2d(in_channels, conv_channels, kernel_size = 3, stride = 1, padding = 1, bias = False))
        self.primary_add_blocks = AddBlocksSequential(*primary_modules)

        if in_channels != out_channels:
            self.auxiliary_add_blocks = AddBlocksSequential(
                AdaptiveAttentionalDenorm_Layer(in_channels, attr_channels, id_channels),
                nn.ReLU(inplace = True),
                nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
            )

    def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:
        """Return the sum of the primary path and the (optionally projected) shortcut."""
        primary_feature = self.primary_add_blocks(feature_map, attr_embedding, id_embedding)
        shortcut_feature = feature_map

        # the shortcut only needs projecting when the channel count changes
        if self.in_channels != self.out_channels:
            shortcut_feature = self.auxiliary_add_blocks(feature_map, attr_embedding, id_embedding)
        return primary_feature + shortcut_feature
class Conv4x4(nn.Module):
    """Bias-free 4x4 stride-2 convolution followed by batch norm and LeakyReLU(0.1); halves the resolution."""
    def __init__(self, in_channels : int, out_channels : int) -> None:
        """
        :param in_channels: channel count of the input.
        :param out_channels: channel count of the output.
        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

    def forward(self, input : Tensor) -> Tensor:
        """Apply convolution, normalisation and activation in that order."""
        return self.leaky_relu(self.batch_norm(self.conv(input)))
class DeConv4x4(nn.Module):
    """
    Bias-free 4x4 stride-2 transposed convolution with batch norm and LeakyReLU(0.1).

    The resolution is doubled and the skip connection is appended along the
    channel dimension afterwards.
    """
    def __init__(self, in_channels : int, out_channels : int) -> None:
        """
        :param in_channels: channel count of the input.
        :param out_channels: channel count before the skip is concatenated.
        """
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

    def forward(self, input : Tensor, skip_connection : Tensor) -> Tensor:
        """Upsample the input and concatenate the skip connection behind it."""
        upsampled = self.leaky_relu(self.batch_norm(self.deconv(input)))
        return torch.cat((upsampled, skip_connection), dim = 1)
class Upsample(nn.Module):
    """Turn a flat embedding into a 2x2 feature map using a convolution and a pixel shuffle."""
    def __init__(self, in_channels : int, out_channels : int) -> None:
        """
        :param in_channels: length of the embedding.
        :param out_channels: channel count before the pixel shuffle, which divides it by four.
        """
        super().__init__()
        self.initial_conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)

    def forward(self, input : Tensor) -> Tensor:
        """View the embedding as a 1x1 map, convolve it and shuffle channels into space."""
        batch_size = input.shape[0]
        embedding_map = input.view(batch_size, -1, 1, 1)
        return self.pixel_shuffle(self.initial_conv(embedding_map))
def initialize_weight(module : nn.Module) -> None:
    """
    Initialise a single module in place; intended for ``nn.Module.apply``.

    Linear weights are drawn from a normal distribution with standard
    deviation 0.001 and an existing bias is zeroed. Conv2d and ConvTranspose2d
    weights use Xavier normal initialisation. Other modules are left as is.

    :param module: the module to initialise.
    """
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(0, 0.001)
        # linear layers created with bias = False carry no bias tensor
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_normal_(module.weight.data)
+50
View File
@@ -0,0 +1,50 @@
import configparser
import torch
import torch.nn as nn
from .discriminator import MultiscaleDiscriminator
from .generator import AdaptiveEmbeddingIntegrationNetwork
# module-level configuration, read once at import time from 'config.ini' (resolved against the current working directory)
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def load_generator() -> nn.Module:
    """Create a freshly initialised generator sized by the 'training.generator' config section."""
    return AdaptiveEmbeddingIntegrationNetwork(
        id_channels = CONFIG.getint('training.generator', 'id_channels'),
        num_blocks = CONFIG.getint('training.generator', 'num_blocks')
    )
def load_discriminator() -> nn.Module:
    """Create a freshly initialised discriminator sized by the 'training.discriminator' config section."""
    # listed in the positional order of the MultiscaleDiscriminator constructor
    option_names = [ 'input_channels', 'num_filters', 'num_layers', 'num_discriminators' ]
    option_values = [ CONFIG.getint('training.discriminator', option_name) for option_name in option_names ]
    return MultiscaleDiscriminator(*option_values)
def load_arcface() -> nn.Module:
    """Load the pickled ArcFace model named by 'arcface_path' onto the CPU and switch it to eval mode."""
    arcface_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
    # NOTE(review): weights_only = False unpickles arbitrary objects - only point this at trusted files
    arcface_model = torch.load(arcface_path, map_location = 'cpu', weights_only = False)
    arcface_model.eval()
    return arcface_model
def load_landmarker() -> nn.Module:
    """Load the pickled landmarker model named by 'landmarker_path' onto the CPU and switch it to eval mode."""
    landmarker_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
    # NOTE(review): weights_only = False unpickles arbitrary objects - only point this at trusted files
    landmarker_model = torch.load(landmarker_path, map_location = 'cpu', weights_only = False)
    landmarker_model.eval()
    return landmarker_model
def load_motion_extractor() -> nn.Module:
    """
    Build the LivePortrait motion extractor and load its weights on the CPU.

    The LivePortrait import is kept inside the function, so the package is
    only required when this loader is actually called.
    """
    from LivePortrait.src.modules.motion_extractor import MotionExtractor

    weights_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
    extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
    state_dict = torch.load(weights_path, map_location = 'cpu', weights_only = True)
    extractor.load_state_dict(state_dict)
    extractor.eval()
    return extractor
+12
View File
@@ -0,0 +1,12 @@
from typing import Any, Tuple
from numpy.typing import NDArray
from torch import Tensor
from torch.utils.data import DataLoader
# one dataset sample: two tensors followed by an integer flag
Batch = Tuple[Tensor, Tensor, int]
# data loader whose items are tuples of tensors
Loader = DataLoader[Tuple[Tensor, ...]]
# variable-length tuple of tensors
UNetAttributes = Tuple[Tensor, ...]
# NumPy arrays of unspecified dtype; presumably a face embedding and an image frame - TODO confirm against the callers
Embedding = NDArray[Any]
VisionFrame = NDArray[Any]
+5
View File
@@ -0,0 +1,5 @@
from .model_loader import load_motion_extractor
def train() -> None:
    """Load the motion extractor and print it to standard output."""
    motion_extractor = load_motion_extractor()
    print(motion_extractor)
+7
View File
@@ -0,0 +1,7 @@
#!/usr/bin/env python3
from src.training import train
# call train() only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
train()
+1
View File
@@ -5,3 +5,4 @@ disallow_untyped_calls = True
disallow_untyped_defs = True
ignore_missing_imports = True
strict_optional = False
exclude = ^LivePortrait