mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
clean generator, discriminator and typing
This commit is contained in:
@@ -4,4 +4,4 @@ plugins = flake8-import-order
|
||||
application_import_names = arcface_converter
|
||||
import-order-style = pycharm
|
||||
per-file-ignores = preparing.py:E402
|
||||
exclude = face_swapper
|
||||
exclude = face_swapper/LivePortrait
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
import torch.nn as nn
|
||||
|
||||
from .typing import Tensor, DiscriminatorOutputs
|
||||
from .typing import DiscriminatorOutputs, List, Tensor
|
||||
|
||||
|
||||
class NLayerDiscriminator(nn.Module):
|
||||
@@ -60,11 +58,9 @@ class MultiscaleDiscriminator(nn.Module):
|
||||
setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
|
||||
self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
|
||||
|
||||
|
||||
def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
|
||||
return [ model_layers(input_tensor) ]
|
||||
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
|
||||
discriminator_outputs = []
|
||||
downsampled_input = input_tensor
|
||||
|
||||
+101
-110
@@ -1,132 +1,118 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .typing import IDEmbedding, TargetAttributes, Tensor, Tuple
|
||||
|
||||
|
||||
class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
    """
    Pair the UNet attribute encoder with the AAD generator.

    The encoder turns the target image into a tuple of attribute feature
    maps; the generator combines those maps with the source identity
    embedding to produce the swapped image.
    """

    def __init__(self, id_channels : int, num_blocks : int) -> None:
        """
        :param id_channels: length of the identity embedding passed to the generator.
        :param num_blocks: forwarded unchanged to AADGenerator.
        """
        super().__init__()
        self.encoder = UNet()
        self.generator = AADGenerator(id_channels, num_blocks)

    def forward(self, target : Tensor, source_embedding : IDEmbedding) -> Tuple[Tensor, TargetAttributes]:
        """
        Encode the target, then generate the swap from its attributes.

        :returns: the generator output together with the target attributes,
            so callers can reuse the attributes without a second encoder pass.
        """
        target_attributes = self.get_attributes(target)
        swap = self.generator(target_attributes, source_embedding)
        return swap, target_attributes

    def get_attributes(self, target : Tensor) -> TargetAttributes:
        """Run only the encoder and return its attribute feature maps."""
        return self.encoder(target)
|
||||
|
||||
|
||||
class AADGenerator(nn.Module):
    """
    Decode an identity embedding into an image, conditioned on target attributes.

    The embedding is first expanded by PixelShuffleUpsample, then passed through
    eight AADResBlock stages. Each of the first seven stages is followed by a
    2x bilinear upsampling; the eighth emits 3 channels that are squashed with tanh.
    """

    def __init__(self, id_channels : int, num_blocks : int) -> None:
        """
        :param id_channels: length of the identity embedding.
        :param num_blocks: forwarded unchanged to every AADResBlock.
        """
        super().__init__()
        self.upsample = PixelShuffleUpsample(id_channels, 1024 * 4)
        # AADResBlock arguments: in_channels, out_channels, attr_channels, id_channels, num_blocks.
        # Attribute names and registration order are kept as-is because they
        # define the state-dict keys and the order in which weights are initialised.
        self.res_block_1 = AADResBlock(1024, 1024, 1024, id_channels, num_blocks)
        self.res_block_2 = AADResBlock(1024, 1024, 2048, id_channels, num_blocks)
        self.res_block_3 = AADResBlock(1024, 1024, 1024, id_channels, num_blocks)
        self.res_block_4 = AADResBlock(1024, 512, 512, id_channels, num_blocks)
        self.res_block_5 = AADResBlock(512, 256, 256, id_channels, num_blocks)
        self.res_block_6 = AADResBlock(256, 128, 128, id_channels, num_blocks)
        self.res_block_7 = AADResBlock(128, 64, 64, id_channels, num_blocks)
        self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks)
        self.apply(initialize_weight)

    def forward(self, target_attributes : TargetAttributes, source_embedding : IDEmbedding) -> Tensor:
        """
        :param target_attributes: at least eight attribute maps; entry ``i`` feeds stage ``i + 1``.
        :param source_embedding: identity embedding consumed by every stage.
        :returns: the tanh-activated output of the last stage.
        """
        upscaled_res_blocks = (
            self.res_block_1,
            self.res_block_2,
            self.res_block_3,
            self.res_block_4,
            self.res_block_5,
            self.res_block_6,
            self.res_block_7
        )
        feature_map = self.upsample(source_embedding)

        # zip stops after the seven blocks, so only attributes 0..6 are used here
        for res_block, target_attribute in zip(upscaled_res_blocks, target_attributes):
            feature_map = res_block(feature_map, target_attribute, source_embedding)
            feature_map = torch.nn.functional.interpolate(feature_map, scale_factor = 2, mode = 'bilinear', align_corners = False)

        # the last stage keeps the resolution and is not upsampled
        output = self.res_block_8(feature_map, target_attributes[7], source_embedding)
        return torch.tanh(output)
|
||||
|
||||
|
||||
class UNet(nn.Module):
    """
    Encoder-decoder that extracts multi-scale attribute maps from the target image.

    Seven DownSample stages (six downsamplers plus the bottleneck) are followed
    by six Upsample stages; each Upsample concatenates the matching downsample
    feature as a skip connection, which is why the decoder input widths are
    twice the preceding decoder output widths.
    """

    def __init__(self) -> None:
        super().__init__()
        self.downsampler_1 = DownSample(3, 32)
        self.downsampler_2 = DownSample(32, 64)
        self.downsampler_3 = DownSample(64, 128)
        self.downsampler_4 = DownSample(128, 256)
        self.downsampler_5 = DownSample(256, 512)
        self.downsampler_6 = DownSample(512, 1024)
        self.bottleneck = DownSample(1024, 1024)
        # input widths from upsampler_2 on include the concatenated skip tensor
        self.upsampler_1 = Upsample(1024, 1024)
        self.upsampler_2 = Upsample(2048, 512)
        self.upsampler_3 = Upsample(1024, 256)
        self.upsampler_4 = Upsample(512, 128)
        self.upsampler_5 = Upsample(256, 64)
        self.upsampler_6 = Upsample(128, 32)
        self.apply(initialize_weight)

    def forward(self, target : Tensor) -> TargetAttributes:
        """
        :param target: image batch with 3 channels.
        :returns: eight tensors, coarse to fine: the bottleneck output, the six
            decoder outputs and a 2x bilinear upsampling of the last decoder output.
        """
        downsample_feature_1 = self.downsampler_1(target)
        downsample_feature_2 = self.downsampler_2(downsample_feature_1)
        downsample_feature_3 = self.downsampler_3(downsample_feature_2)
        downsample_feature_4 = self.downsampler_4(downsample_feature_3)
        downsample_feature_5 = self.downsampler_5(downsample_feature_4)
        downsample_feature_6 = self.downsampler_6(downsample_feature_5)
        bottleneck_output = self.bottleneck(downsample_feature_6)

        # decoder: every stage is paired with the encoder feature of equal resolution
        upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
        upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
        upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
        upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
        upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
        upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
        output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
        return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
|
||||
|
||||
|
||||
class AADLayer(nn.Module):
    """
    Blend attribute-driven and identity-driven modulations of a feature map.

    The input is instance-normalised, modulated once with per-pixel scale and
    shift predicted from the attribute map and once with per-channel scale and
    shift predicted from the identity embedding, and the two results are mixed
    with a learned single-channel sigmoid mask.
    """

    def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None:
        """
        :param input_channels: channels of the feature map being modulated.
        :param attr_channels: channels of the attribute map.
        :param id_channels: length of the identity embedding.
        """
        super().__init__()
        self.attr_channels = attr_channels
        self.id_channels = id_channels
        self.input_channels = input_channels
        self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
        self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
        self.fc_beta = nn.Linear(id_channels, input_channels)
        self.fc_gamma = nn.Linear(id_channels, input_channels)
        self.instance_norm = nn.InstanceNorm2d(input_channels, affine = False)
        self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = True)

    def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : IDEmbedding) -> Tensor:
        """
        :param feature_map: tensor with ``input_channels`` channels.
        :param attr_embedding: attribute map with ``attr_channels`` channels; its
            1x1-convolved outputs are multiplied with and added to the feature map.
        :param id_embedding: identity embedding with ``id_channels`` features per sample.
        :returns: ``(1 - mask) * attribute_modulation + mask * identity_modulation``.
        """
        feature_map = self.instance_norm(feature_map)

        attr_gamma = self.conv_gamma(attr_embedding)
        attr_beta = self.conv_beta(attr_embedding)
        attr_modulation = attr_gamma * feature_map + attr_beta

        # identity scale and shift are per channel: reshape to (batch, channels, 1, 1)
        # and expand over the spatial positions of the feature map
        id_shape = (feature_map.shape[0], self.input_channels, 1, 1)
        id_gamma = self.fc_gamma(id_embedding).reshape(id_shape).expand_as(feature_map)
        id_beta = self.fc_beta(id_embedding).reshape(id_shape).expand_as(feature_map)
        id_modulation = id_gamma * feature_map + id_beta

        # the mask is predicted from the normalised feature map itself
        feature_mask = torch.sigmoid(self.conv_mask(feature_map))
        return (1 - feature_mask) * attr_modulation + feature_mask * id_modulation
|
||||
|
||||
|
||||
class AddBlocksSequential(nn.Sequential):
|
||||
def forward(self, *inputs):
|
||||
h, attr_embedding, id_embedding = inputs
|
||||
def forward(self, *inputs : Tuple[Tensor, Tensor, IDEmbedding]) -> Tuple[Tuple[Tensor, Tensor, IDEmbedding], ...]:
|
||||
_, attr_embedding, id_embedding = inputs
|
||||
|
||||
for index, module in enumerate(self._modules.values()):
|
||||
if index % 3 == 0 and index > 0:
|
||||
inputs = (inputs, attr_embedding, id_embedding)
|
||||
inputs = (inputs, attr_embedding, id_embedding) # type:ignore[assignment]
|
||||
if type(inputs) == tuple:
|
||||
inputs = module(*inputs)
|
||||
else:
|
||||
@@ -134,9 +120,9 @@ class AddBlocksSequential(nn.Sequential):
|
||||
return inputs
|
||||
|
||||
|
||||
class AAD_ResBlk(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, attr_channels, id_channels, num_blocks):
|
||||
super(AAD_ResBlk, self).__init__()
|
||||
class AADResBlock(nn.Module):
|
||||
def __init__(self, in_channels : int, out_channels : int, attr_channels : int, id_channels : int, num_blocks : int) -> None:
|
||||
super(AADResBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
primary_add_blocks = []
|
||||
@@ -146,8 +132,8 @@ class AAD_ResBlk(nn.Module):
|
||||
primary_add_blocks.extend(
|
||||
[
|
||||
AADLayer(in_channels, attr_channels, id_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_channels, intermediate_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(in_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
|
||||
])
|
||||
self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks)
|
||||
|
||||
@@ -155,12 +141,12 @@ class AAD_ResBlk(nn.Module):
|
||||
auxiliary_add_blocks = \
|
||||
[
|
||||
AADLayer(in_channels, attr_channels, id_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
|
||||
]
|
||||
self.auxiliary_add_blocks = AddBlocksSequential(*auxiliary_add_blocks)
|
||||
|
||||
def forward(self, feature_map, attr_embedding, id_embedding):
|
||||
def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : IDEmbedding) -> Tensor:
|
||||
primary_feature = self.primary_add_blocks(feature_map, attr_embedding, id_embedding)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
@@ -169,50 +155,47 @@ class AAD_ResBlk(nn.Module):
|
||||
return output_feature
|
||||
|
||||
|
||||
class DownSample(nn.Module):
    """
    Strided 4x4 convolution followed by batch normalisation and LeakyReLU(0.1).

    With kernel 4, stride 2 and padding 1 an even spatial size is halved.
    """

    def __init__(self, in_channels : int, out_channels : int) -> None:
        """
        :param in_channels: channels of the incoming tensor.
        :param out_channels: channels produced by the convolution.
        """
        super().__init__()
        # no convolution bias: the following batch norm carries its own shift
        self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

    def forward(self, temp : Tensor) -> Tensor:
        """Apply convolution, batch norm and activation in that order."""
        temp = self.conv(temp)
        temp = self.batch_norm(temp)
        temp = self.leaky_relu(temp)
        return temp
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """
    Transposed 4x4 convolution, batch norm and LeakyReLU(0.1), then skip concatenation.

    With kernel 4, stride 2 and padding 1 the spatial size is doubled. The
    result is concatenated with the skip tensor along the channel axis, so the
    output has ``out_channels`` plus the skip tensor's channels.
    """

    def __init__(self, in_channels : int, out_channels : int) -> None:
        """
        :param in_channels: channels of the incoming tensor.
        :param out_channels: channels produced by the transposed convolution,
            before the skip tensor is appended.
        """
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

    def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor:
        """
        :param temp: tensor to upsample.
        :param skip_tensor: tensor appended after upsampling; it must match the
            upsampled tensor in every dimension except channels.
        """
        temp = self.deconv(temp)
        temp = self.batch_norm(temp)
        temp = self.leaky_relu(temp)
        return torch.cat((temp, skip_tensor), dim = 1)
|
||||
|
||||
|
||||
def initialize_weight(module):
|
||||
class PixelShuffleUpsample(nn.Module):
    """
    Expand a flat embedding into a small feature map via convolution and pixel shuffle.

    The input is reshaped to ``(batch, features, 1, 1)``, projected to
    ``out_channels`` with a 3x3 convolution and rearranged by PixelShuffle(2),
    giving ``(batch, out_channels / 4, 2, 2)``.
    """

    def __init__(self, in_channels : int, out_channels : int) -> None:
        """
        :param in_channels: number of features per sample after flattening.
        :param out_channels: convolution output channels; PixelShuffle(2)
            divides this by four.
        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)

    def forward(self, temp : Tensor) -> Tensor:
        """Reshape the input to a 1x1 map, convolve it and pixel-shuffle the result."""
        # reshape instead of view: view raises when the trailing dimensions of a
        # non-contiguous tensor cannot be merged without a copy
        temp = self.conv(temp.reshape(temp.shape[0], -1, 1, 1))
        temp = self.pixel_shuffle(temp)
        return temp
|
||||
|
||||
|
||||
def initialize_weight(module : nn.Module) -> None:
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(0, 0.001)
|
||||
module.bias.data.zero_()
|
||||
@@ -222,3 +205,11 @@ def initialize_weight(module):
|
||||
|
||||
if isinstance(module, nn.ConvTranspose2d):
|
||||
nn.init.xavier_normal_(module.weight.data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: push one random identity embedding and one random 256x256
    # target through the network and print the shape of the swapped output.
    model = AdaptiveEmbeddingIntegrationNetwork(512, 2)
    src = torch.randn(1, 512)
    trg = torch.randn(1, 3, 256, 256)

    # only the output shape is inspected, so skip building the autograd graph
    with torch.no_grad():
        out = model(trg, src)
    print(out[0].shape)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Tuple, List, Dict, Optional
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from torch import Tensor
|
||||
@@ -9,6 +9,7 @@ Loader = DataLoader[Tuple[Tensor, ...]]
|
||||
TargetAttributes = Tuple[Tensor, ...]
|
||||
DiscriminatorOutputs = List[List[Tensor]]
|
||||
LossDict = Dict[str, Tensor]
|
||||
IDEmbedding = Tensor
|
||||
|
||||
Embedding = NDArray[Any]
|
||||
VisionFrame = NDArray[Any]
|
||||
|
||||
Reference in New Issue
Block a user